feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -33,20 +33,6 @@ ignore_imports =
# TODO(QuantumGhost): fix the import violation later
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities

[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
dify_graph
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis

[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
@@ -89,45 +75,7 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.llm.node -> core.tools.signature
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.__base.large_language_model -> configs
dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
dify_graph.model_runtime.model_providers.model_provider_factory -> configs
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids

[importlinter:contract:rsc]
name = RSC
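For context, import-linter evaluates each `[importlinter:contract:*]` section on every run: a `forbidden` contract fails CI when any module under `source_modules` imports one of the `forbidden_modules`, `allow_indirect_imports = True` limits the check to direct imports, and `ignore_imports` whitelists known violations (such as the TODO-tagged ones above) so they can be burned down incrementally. A minimal sketch of such a contract, with illustrative module names that are not from this repo:

[importlinter]
root_package = myapp

[importlinter:contract:no-db-in-domain]
name = Domain must not touch the database layer
type = forbidden
source_modules =
    myapp.domain
forbidden_modules =
    myapp.db
ignore_imports =
    myapp.domain.legacy -> myapp.db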
@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController

@@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl

plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)

plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)

datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
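The value/lock pairs above enable a per-request, double-checked initialization of the plugin provider caches. A hedged sketch of that pattern, assuming `RecyclableContextVar` mirrors `ContextVar`'s `get`/`set` interface and that the lock is populated during request setup (`fetch_providers` is a hypothetical stand-in for the real plugin-daemon call):

def get_plugin_model_providers() -> list["PluginModelProviderEntity"]:
    providers = plugin_model_providers.get()      # None until first population
    if providers is not None:
        return providers
    with plugin_model_providers_lock.get():       # lock shared across the request
        providers = plugin_model_providers.get()  # re-check under the lock
        if providers is None:
            providers = fetch_providers()         # hypothetical daemon call
            plugin_model_providers.set(providers)
        return providers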
@@ -25,7 +25,7 @@ from controllers.console.wraps (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
@@ -331,7 +331,7 @@ class DatasetListApi(Resource):
)

# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -445,7 +445,7 @@ class DatasetApi(Resource):
data.update({"partial_member_list": part_users_list})

# check embedding setting
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

@@ -452,7 +452,7 @@ class DatasetInitApi(Resource):
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=knowledge_config.embedding_model_provider,

@@ -282,7 +282,7 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -335,7 +335,7 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -386,7 +386,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -571,7 +571,7 @@ class ChildChunkAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

@@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource):
)

if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model=args.model,
)

return jsonable_encoder(

@@ -14,7 +14,7 @@ from controllers.service_api.wraps (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
@@ -139,10 +139,10 @@ class DatasetListApi(DatasetApiResource):
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -253,10 +253,10 @@ class DatasetApi(DatasetApiResource):
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

@@ -105,7 +105,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -159,7 +159,7 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -265,7 +265,7 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -360,7 +360,7 @@ class ChildChunkApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,

@@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)


@@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)


@@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class ModelConfigConverter:
"""
model_config = app_config.model

provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)

@@ -2,9 +2,8 @@ from collections.abc import Mapping
from typing import Any

from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID

@@ -55,7 +54,7 @@ class ModelConfigManager:
raise ValueError("model must be of object type")

# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
@@ -71,7 +70,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
raise ValueError("model.name is required")

provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
)

@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.file import File, FileUploadConfig
from dify_graph.model_runtime.entities.model_entities import AIModelEntity

@@ -15,6 +14,9 @@ if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager


DIFY_RUN_CONTEXT_KEY = "_dify"


class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"

@@ -2,23 +2,34 @@ from __future__ import annotations

from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.app.llm.protocols import CredentialsProvider, ModelFactory
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory


class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager

def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def __init__(
self,
*,
run_context: DifyRunContext,
provider_manager: ProviderManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if provider_manager is None:
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
self.provider_manager = provider_manager

def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
@@ -42,9 +53,21 @@ class DifyModelFactory:
tenant_id: str
model_manager: ModelManager

def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def __init__(
self,
*,
run_context: DifyRunContext,
model_manager: ModelManager | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
if model_manager is None:
model_manager = ModelManager(
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
)
self.model_manager = model_manager

def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
@@ -55,11 +78,35 @@ class DifyModelFactory:
)


def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, ModelFactory]:
"""Create LLM access adapters that share the same tenant-bound manager graph."""
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)

return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
DifyModelFactory(run_context=run_context, model_manager=model_manager),
)


def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Split node-level completion params into provider parameters and stop sequences.

Workflow LLM-compatible nodes still consume runtime invocation settings from
``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the
``ModelInstance`` view and the returned config entity aligned here so callers
do not need to duplicate normalization logic.
"""
normalized_parameters = dict(completion_params)
stop = normalized_parameters.pop("stop", [])
if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
stop = []

return normalized_parameters, stop


def fetch_model_config(
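To make the new helper's contract concrete, a small usage sketch (the parameter values are illustrative):

>>> _normalize_completion_params({"temperature": 0.7, "stop": ["\n\n"]})
({'temperature': 0.7}, ['\n\n'])
>>> _normalize_completion_params({"temperature": 0.7, "stop": "not-a-list"})
({'temperature': 0.7}, [])

The `stop` key is always popped out of the provider parameters; a malformed value degrades to an empty stop list rather than raising.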
@@ -80,22 +127,18 @@
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()

completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []

model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if model_schema is None:
raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.")

parameters, stop = _normalize_completion_params(node_data_model.completion_params)
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.parameters = parameters
model_instance.stop = tuple(stop)

return model_instance, ModelConfigWithCredentialsEntity(
@@ -103,8 +146,8 @@
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
parameters=parameters,
stop=stop,
provider_model_bundle=provider_model_bundle,
)

api/core/app/llm/protocols.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import Any, Protocol

from core.model_manager import ModelInstance


class CredentialsProvider(Protocol):
"""Workflow-layer port for loading runtime credentials for a provider/model pair."""

def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
"""Return credentials for the target provider/model or raise a domain error."""
...


class ModelFactory(Protocol):
"""Workflow-layer port for creating mutable ModelInstance objects."""

def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for workflow-side hydration."""
...
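Because these are `typing.Protocol` classes, any object with matching method signatures satisfies them structurally; no inheritance is needed. A hedged sketch of test doubles wiring both ports (the stub values are illustrative, not part of this commit):

from typing import Any

from core.app.llm.protocols import CredentialsProvider, ModelFactory  # the new file above
from core.model_manager import ModelInstance


class StaticCredentialsProvider:
    # Structurally satisfies CredentialsProvider without subclassing it.
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return {"api_key": "test-key"}  # illustrative fixed credentials


class StubModelFactory:
    # Structurally satisfies ModelFactory; hands back a pre-built instance.
    def __init__(self, instance: ModelInstance) -> None:
        self._instance = instance

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        return self._instance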
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final

from typing_extensions import override

from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
@@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer):
return

try:
dify_ctx = node.require_dify_context()
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,

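The layer now recovers the tenant context from the node's generic run-context bag instead of a dedicated accessor: `DIFY_RUN_CONTEXT_KEY` is the well-known key, and Pydantic's `model_validate` re-hydrates the `DifyRunContext`. The same retrieval, sketched in isolation (assuming `require_run_context_value` raises when the key is absent):

raw = node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)  # raises if the key is missing
dify_ctx = DifyRunContext.model_validate(raw)               # validates the stored shape
tenant_id = dify_ctx.tenant_id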
@@ -25,12 +25,10 @@ class AudioTrunk:
self.status = status


def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)


def _process_future(
@@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher:
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts")
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
@@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.voice
)
future_queue.put(futures_result)
break
@@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher:
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice)
future_queue.put(futures_result)
if isinstance(text_tmp, str):
self.msg_text = text_tmp

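Dropping `tenant_id` (and the hard-coded `user`) from `_invoice_tts` works because the identity is now baked into the instance at publisher construction time, so every `executor.submit(...)` call site sheds arguments. The binding, condensed (values as in the code above):

# tenant and acting user are bound once, when the publisher is built
manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id="responding_tts")
tts = manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
audio = tts.invoke_tts(content_text="Hello.", voice=voice)  # no tenant_id/user per call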
@@ -1,10 +1,5 @@
from enum import StrEnum, auto
"""Compatibility wrapper for the runtime embedding input enum."""

from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType

class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""

DOCUMENT = auto()
QUERY = auto()
__all__ = ["EmbeddingInputType"]

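This re-export keeps the old import path working while the enum's single definition moves into the runtime package, so both paths resolve to the same class:

from core.entities.embedding_type import EmbeddingInputType as LegacyPath
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType as NewPath

assert LegacyPath is NewPath  # one definition, two import paths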
@@ -19,6 +19,7 @@ from core.entities.provider_entities (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -27,7 +28,6 @@ from dify_graph.model_runtime.entities.provider_entities (
ProviderEntity,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -343,7 +343,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)

model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -902,7 +902,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)

model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1388,7 +1388,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)

# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1397,7 @@
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1499,7 +1499,7 @@
:param model: model name
:return:
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)

model_types: list[ModelType] = []

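Every hunk in this file is the same mechanical substitution: direct `ModelProviderFactory(self.tenant_id)` construction is routed through `create_plugin_model_provider_factory(tenant_id=...)`. The point of the indirection is that call sites no longer depend on the concrete constructor, so the wiring can change in one place. A hedged sketch of what such a factory helper typically looks like (the body is illustrative, not the repo's actual implementation):

def create_plugin_model_provider_factory(*, tenant_id: str) -> ModelProviderFactory:
    # Central seam: swap implementations, add caching, or inject clients here
    # without touching the many call sites above.
    return ModelProviderFactory(tenant_id)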
@@ -4,10 +4,10 @@ from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType

@@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)

try:
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)

# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(

@@ -50,7 +50,10 @@ logger = logging.getLogger(__name__)
class IndexingRunner:
def __init__(self):
self.storage = storage
self.model_manager = ModelManager()

@staticmethod
def _get_model_manager(tenant_id: str) -> ModelManager:
return ModelManager.for_tenant(tenant_id=tenant_id)

def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
@@ -291,20 +294,20 @@ class IndexingRunner:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
@@ -574,7 +577,7 @@ class IndexingRunner:

embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@@ -766,14 +769,14 @@ class IndexingRunner:
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)

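Because `IndexingRunner` receives the tenant per call rather than at construction, the previously cached (tenant-unbound) `self.model_manager` is replaced by the `_get_model_manager` helper, which builds a tenant-bound manager on demand. The shape of the migration at each call site:

# before: one shared, tenant-unbound manager
# self.model_manager.get_model_instance(tenant_id=tenant_id, ...)
# after: a tenant-bound manager created per call
self._get_model_manager(tenant_id).get_model_instance(tenant_id=tenant_id, ...)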
@@ -62,7 +62,7 @@ class LLMGenerator:

prompt += query + "\n"

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -120,7 +120,7 @@ class LLMGenerator:
prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})

try:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -172,7 +172,7 @@ class LLMGenerator:

prompt_messages = [UserPromptMessage(content=prompt_generate)]

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)

model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@@ -219,7 +219,7 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]

# get model instance
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
remove_template_variables=False,
)

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt = GENERATOR_QA_PROMPT.format(language=document_language)

model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:

@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,

@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
@@ -99,7 +96,6 @@ def invoke_llm_with_structured_output(
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
@@ -113,7 +109,6 @@ def invoke_llm_with_structured_output(
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -143,7 +138,6 @@ def invoke_llm_with_structured_output(
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)


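The overload blocks above drop the `user` parameter from every signature while preserving the `Literal`-typed `stream` trick: type checkers pick the generator return type when `stream=True` is passed literally, and the plain result type for `stream=False`. A self-contained sketch of the same pattern:

from collections.abc import Generator
from typing import Literal, overload

@overload
def invoke(*, stream: Literal[True]) -> Generator[str, None, None]: ...
@overload
def invoke(*, stream: Literal[False]) -> str: ...
@overload
def invoke(*, stream: bool = True) -> str | Generator[str, None, None]: ...
def invoke(*, stream: bool = True) -> str | Generator[str, None, None]:
    if stream:
        return (chunk for chunk in ("a", "b"))  # streamed chunks
    return "ab"  # full response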
@@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)

class ModelInstance:
"""
Model instance class
Model instance class.
"""

def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -110,7 +111,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator: ...

@@ -122,7 +122,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResult: ...

@@ -134,7 +133,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]: ...

@@ -145,7 +143,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]:
"""
@@ -156,7 +153,6 @@ class ModelInstance:
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -173,7 +169,6 @@ class ModelInstance:
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
@@ -202,13 +197,12 @@ class ModelInstance:
)

def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
"""
Invoke large language model

:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -221,7 +215,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
@@ -229,14 +222,12 @@ class ModelInstance:
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke large language model

:param multimodel_documents: multimodel documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -249,7 +240,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
@@ -279,7 +269,6 @@ class ModelInstance:
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -288,7 +277,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +291,6 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)

@@ -313,7 +300,6 @@ class ModelInstance:
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -322,7 +308,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +322,14 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)

def invoke_moderation(self, text: str, user: str | None = None) -> bool:
def invoke_moderation(self, text: str) -> bool:
"""
Invoke moderation model

:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +341,14 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
),
)

def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
def invoke_speech2text(self, file: IO[bytes]) -> str:
"""
Invoke large language model

:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +360,15 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
),
)

def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
"""
Invoke large language tts model

:param content_text: text content to be translated
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +380,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
@@ -477,10 +453,20 @@ class ModelInstance:


class ModelManager:
def __init__(self):
self._provider_manager = ProviderManager()
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager

def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))

def get_model_instance(
self,
tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
@@ -496,7 +482,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)

return ModelInstance(provider_model_bundle, model)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance

def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

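`ModelManager` now requires its `ProviderManager` explicitly, and `for_tenant` is the convenience constructor most call sites in this commit migrate to. The two equivalent ways to build one, per the code above (provider and model names illustrative):

# explicit wiring (useful in tests, or to share one provider manager)
manager = ModelManager(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id))

# convenience constructor used throughout this commit
manager = ModelManager.for_tenant(tenant_id=tenant_id)
instance = manager.get_model_instance(
    tenant_id=tenant_id,
    provider="openai",
    model_type=ModelType.LLM,
    model="gpt-4o",
)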
@@ -50,7 +50,7 @@ class OpenAIModeration(Moderation):

def _is_violated(self, inputs: dict):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)

@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
*,
tenant_id: str,
user_id: str | None,
provider: str,
model_type: ModelType,
model: str,
):
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=provider,
model_type=model_type,
model=model,
)

@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)

if isinstance(response, Generator):
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)

@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)

# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)

return response

@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)

return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)

# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)

def handle() -> Generator[dict, None, None]:
for chunk in response:
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)

response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)

return {
"result": response,
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)

# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)

return {
"result": response,

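All of the `invoke_*` entry points above now funnel through `_get_bound_model_instance`, so the tenant/user pair is stated exactly once and the subsequent embedding/TTS/moderation/speech2text calls can drop their `user` and `tenant_id` keyword arguments, which now travel with the instance. The migration, condensed:

# before: unbound manager, identity repeated at every invoke call
# instance = ModelManager().get_model_instance(tenant_id=tenant.id, ...)
# result = instance.invoke_moderation(text=payload.text, user=user_id)

# after: identity bound once, invokes stay minimal
instance = cls._get_bound_model_instance(
    tenant_id=tenant.id,
    user_id=user_id,
    provider=payload.provider,
    model_type=payload.model_type,
    model=payload.model,
)
result = instance.invoke_moderation(text=payload.text)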
@@ -1,6 +1,6 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO
|
||||
from typing import IO, Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
@staticmethod
|
||||
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"data": data}
|
||||
if user_id is not None:
|
||||
payload["user_id"] = user_id
|
||||
return payload
|
||||
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
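A note on the `_dispatch_payload` helper introduced above: it centralizes the request envelope that every dispatch method previously built inline, and it omits the `user_id` key entirely when no user is bound, rather than sending a null. A minimal standalone sketch of that behavior (this mirrors the helper for illustration; it is not the actual class method):

from typing import Any

def dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
    # Wrap the request body; the "user_id" key is dropped, not set to null, when unbound.
    payload: dict[str, Any] = {"data": data}
    if user_id is not None:
        payload["user_id"] = user_id
    return payload

assert dispatch_payload(user_id=None, data={"provider": "openai"}) == {"data": {"provider": "openai"}}
assert dispatch_payload(user_id="user-1", data={})["user_id"] == "user-1"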
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
def get_model_schema(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/schema",
PluginModelSchemaEntity,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
return None

def validate_provider_credentials(
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
def validate_model_credentials(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
"POST",
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
data=self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
"credentials": credentials,
},
},
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
def invoke_llm(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "llm",
"model": model,
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
"stop": stop,
"stream": stream,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
def get_llm_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model_type: str,
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": model_type,
"model": model,
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
"prompt_messages": prompt_messages,
"tools": tools,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
def invoke_text_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
"texts": texts,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
"documents": documents,
"input_type": input_type,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
def get_text_embedding_num_tokens(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"texts": texts,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
def invoke_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "rerank",
"model": model,
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
"score_threshold": score_threshold,
"top_n": top_n,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
def invoke_tts(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
"content_text": content_text,
"voice": voice,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
def get_tts_model_voices(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "tts",
"model": model,
"credentials": credentials,
"language": language,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
def invoke_speech_to_text(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "speech2text",
"model": model,
"credentials": credentials,
"file": binascii.hexlify(file.read()).decode(),
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
def invoke_moderation(
self,
tenant_id: str,
user_id: str,
user_id: str | None,
plugin_id: str,
provider: str,
model: str,
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
self._dispatch_payload(
user_id=user_id,
data={
"provider": provider,
"model_type": "moderation",
"model": model,
"credentials": credentials,
"text": text,
},
}
)
),
headers={
"X-Plugin-ID": plugin_id,

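The speech-to-text dispatch above hex-encodes the raw audio bytes so they can travel inside a JSON payload to the plugin daemon. A small round-trip sketch of that encoding (standalone illustration, not the client code):

import binascii
import io

audio = io.BytesIO(b"\x00\x01\xff")  # stand-in for an uploaded audio file
encoded = binascii.hexlify(audio.read()).decode()  # -> "0001ff"
assert binascii.unhexlify(encoded) == b"\x00\x01\xff"  # the daemon can restore the bytes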
api/core/plugin/impl/model_runtime.py (new file, 490 lines)
@@ -0,0 +1,490 @@
from __future__ import annotations

import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union

from pydantic import ValidationError
from redis import RedisError

from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)


class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and a default user."""

tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock

def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
if client is None:
raise ValueError("client is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()

def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities

with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)

return self._provider_entities

def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)

if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}

extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)

def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)

def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)

cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)

if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)

return schema

def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)

def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0

plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)

def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)

def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)

def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)

def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)

def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)

def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)

def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)

def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)

def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)

def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.

Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""

if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""

return provider.provider

def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration

def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity

def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
sorted_credentials = sorted(credentials.items()) if credentials else []
return cache_key + ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)

def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
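Two details of the runtime above are easy to miss. `fetch_model_providers` memoizes the provider entities behind a double-checked lock, so concurrent callers trigger at most one fetch from the plugin daemon, and `_split_provider` delegates to `ModelProviderID` to split a fully qualified provider string into its plugin and provider parts. A condensed, generic sketch of the double-checked locking pattern (not the plugin-specific code):

from threading import Lock

class MemoizedFetch:
    def __init__(self) -> None:
        self._value: tuple[str, ...] | None = None
        self._lock = Lock()

    def fetch(self) -> tuple[str, ...]:
        if self._value is not None:  # fast path: no lock once populated
            return self._value
        with self._lock:
            if self._value is None:  # re-check under the lock
                self._value = ("expensive", "result")  # stand-in for the daemon call
        return self._value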
api/core/plugin/impl/model_runtime_factory.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from core.plugin.impl.model import PluginModelClient

if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory


def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
"""Create a plugin runtime with its client dependency fully composed."""
from core.plugin.impl.model_runtime import PluginModelRuntime

return PluginModelRuntime(
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
)


def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))


def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
from core.provider_manager import ProviderManager

return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))


def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
from core.model_manager import ModelManager

return ModelManager(
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
)
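The factory module keeps its heavy imports inside the function bodies and behind `TYPE_CHECKING`, which plausibly sidesteps import cycles between the plugin layer and the core managers. A call site composes a fully wired, tenant-bound manager in one line; the argument values below are illustrative only:

manager = create_plugin_model_manager(tenant_id="tenant-123", user_id="user-456")  # hypothetical IDs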
@@ -1,50 +1,7 @@
from typing import Literal
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig

from pydantic import BaseModel

from dify_graph.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
"""
Chat Message.
"""

text: str
role: PromptMessageRole
edition_type: Literal["basic", "jinja2"] | None = None


class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""

text: str
edition_type: Literal["basic", "jinja2"] | None = None


class MemoryConfig(BaseModel):
"""
Memory Config.
"""

class RolePrefix(BaseModel):
"""
Role Prefix.
"""

user: str
assistant: str

class WindowConfig(BaseModel):
"""
Window Config.
"""

enabled: bool
size: int | None = None

role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
__all__ = [
"ChatModelMessage",
"CompletionModelPromptTemplate",
"MemoryConfig",
]

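The prompt entity classes move wholesale to `dify_graph.prompt_entities`, and the legacy module shrinks to a re-export plus an `__all__` list, so both import paths keep resolving to the same class objects. A quick check of the shim's transparency (assuming the legacy path is `core.prompt.entities.advanced_prompt_entities`, as imported elsewhere in this diff):

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage as LegacyChatModelMessage
from dify_graph.prompt_entities import ChatModelMessage

assert LegacyChatModelMessage is ChatModelMessage  # one class, two import paths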
@@ -1,9 +1,11 @@
from __future__ import annotations

import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -53,15 +55,22 @@ from models.provider import (
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService

if TYPE_CHECKING:
from dify_graph.model_runtime.runtime import ModelRuntime


class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager manages tenant-scoped model provider configuration.

The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
"""

def __init__(self):
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime

def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -127,7 +136,7 @@ class ProviderManager:
)

# Get all provider entities
model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()

# Get All preferred provider types of the workspace
@@ -321,7 +330,7 @@ class ProviderManager:
if not default_model:
return None

model_provider_factory = ModelProviderFactory(tenant_id)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)

return DefaultModelEntity(

@@ -106,9 +106,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

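This hunk shows the call-site migration that recurs through the rest of the commit: the parameterless `ModelManager()` becomes `ModelManager.for_tenant(tenant_id=...)`, and per-invoke `user=` arguments disappear because the user is bound when the manager is constructed. It also swaps bracket indexing for `.get(...)`, so missing rerank settings yield `None` instead of raising. The dictionary-access change in isolation (standalone sketch):

reranking_model = {"reranking_provider_name": "", "reranking_model_name": "bge-reranker"}

# Bracket indexing raised KeyError on absent keys; .get() plus a truthiness
# check treats missing or empty values as "no reranking configured".
provider = reranking_model.get("reranking_provider_name")
model = reranking_model.get("reranking_model_name")
result = None if not provider or not model else (provider, model)
assert result is None  # an empty provider name means reranking is skipped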
@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)

def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)

embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

@@ -72,7 +72,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError

from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.rag.embedding.embedding_base import Embeddings
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)

class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
self._model_instance = self._bind_model_instance(model_instance, user)

@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
if user is None:
return model_instance

tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
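`CacheEmbedding` now rebinds its model instance once at construction when a user is supplied, instead of threading `user=self._user` through every embedding call; `RerankModelRunner._get_invoke_model_instance` below follows the same rebind-or-passthrough shape. A runnable skeleton of that shape, with a hypothetical stand-in type in place of `ModelInstance`:

from dataclasses import dataclass

@dataclass
class StubInstance:  # hypothetical stand-in for ModelInstance
    provider: str
    model: str
    user: str | None = None

def bind_model_instance(instance: StubInstance, user: str | None) -> StubInstance:
    # No user: reuse the instance as-is; otherwise resolve a user-bound twin
    # with the same provider/model coordinates (here: construct one directly).
    if user is None:
        return instance
    return StubInstance(provider=instance.provider, model=instance.model, user=user)

base = StubInstance(provider="openai", model="text-embedding-3-small")
assert bind_model_instance(base, None) is base
assert bind_model_instance(base, "user-9").user == "user-9"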
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
|
||||
batch_texts = embedding_queue_texts[i : i + max_chunks]
|
||||
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
|
||||
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
|
||||
)
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=batch_multimodel_documents,
|
||||
user=self._user,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
texts=[text], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
@@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
# If default prompt doesn't have {language} placeholder, use it as-is
|
||||
pass
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id, model_provider_name, ModelType.LLM
|
||||
)
|
||||
|
||||
@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
def __init__(self, rerank_model_instance: ModelInstance):
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
def _get_invoke_model_instance(self, user: str | None) -> ModelInstance:
|
||||
if user is None:
|
||||
return self.rerank_model_instance
|
||||
|
||||
tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model_type=ModelType.RERANK,
|
||||
model=self.rerank_model_instance.model_name,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
@@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
|
||||
)
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
@@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
rerank_result = self._get_invoke_model_instance(user).invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
|
||||
@@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
|
||||
@@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
"""
|
||||
query_vector_scores = []
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -160,7 +160,7 @@ class DatasetRetrieval:
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@@ -387,7 +387,7 @@ class DatasetRetrieval:
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
@@ -1430,7 +1430,6 @@ class DatasetRetrieval:
|
||||
model_parameters=model_config.parameters,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1533,7 +1532,7 @@ class DatasetRetrieval:
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
self, tenant_id: str, model: ModelConfig
|
||||
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
@@ -1543,7 +1542,7 @@ class DatasetRetrieval:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
@@ -3,13 +3,14 @@ from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@@ -150,19 +151,24 @@ class ReactMultiDatasetRouter:
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
|
||||
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@@ -201,8 +201,10 @@ class HumanInputFormRepositoryImpl:
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def _delivery_method_to_model(
|
||||
self,
|
||||
@@ -340,6 +342,9 @@ class HumanInputFormRepositoryImpl:
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
app_id = params.app_id or self._app_id
|
||||
if not app_id:
|
||||
raise ValueError("app_id is required to create a human input form")
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# Generate unique form ID
|
||||
@@ -359,7 +364,7 @@ class HumanInputFormRepositoryImpl:
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=params.app_id,
|
||||
app_id=app_id,
|
||||
workflow_run_id=params.workflow_execution_id,
|
||||
form_kind=params.form_kind,
|
||||
node_id=params.node_id,
|
||||
|
||||
@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
runtime = self.runtime
|
||||
file = tool_parameters.get("audio_file")
|
||||
if file.type != FileType.AUDIO: # type: ignore
|
||||
yield self.create_text_message("not a valid audio file")
|
||||
@@ -29,20 +32,19 @@ class ASRTool(BuiltinTool):
|
||||
audio_binary = io.BytesIO(download(file)) # type: ignore
|
||||
audio_binary.name = "temp.mp3"
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=runtime.tenant_id,
|
||||
provider=provider,
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
model=model,
|
||||
)
|
||||
text = model_instance.invoke_speech2text(
|
||||
file=audio_binary,
|
||||
user=user_id,
|
||||
)
|
||||
text = model_instance.invoke_speech2text(file=audio_binary)
|
||||
yield self.create_text_message(text)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(
|
||||
tenant_id=self.runtime.tenant_id, model_type="speech2text"
|
||||
|
||||
@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
runtime = self.runtime
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tenant_id=runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
@@ -39,12 +40,7 @@ class TTSTool(BuiltinTool):
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
voice=voice,
|
||||
)
|
||||
tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type]
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts:
|
||||
buffer.write(chunk)
|
||||
|
||||
@@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.reranking_provider_name,
|
||||
|
||||
@@ -37,7 +37,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@@ -65,7 +65,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
@@ -92,7 +92,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
@@ -136,7 +136,6 @@ class ModelInvocationUtils:
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
|
||||
@@ -9,8 +9,8 @@ from sqlalchemy.orm import Session
 from typing_extensions import override

 from configs import dify_config
-from core.app.entities.app_invoke_entities import DifyRunContext
-from core.app.llm.model_access import build_dify_model_access
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
+from core.app.llm.model_access import build_dify_model_access, fetch_model_config
 from core.helper.code_executor.code_executor import (
     CodeExecutionError,
     CodeExecutor,
@@ -20,8 +20,17 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from core.tools.tool_file_manager import ToolFileManager
 from core.trigger.constants import TRIGGER_NODE_TYPES
+from core.workflow.node_runtime import (
+    DifyFileReferenceFactory,
+    DifyHumanInputNodeRuntime,
+    DifyPreparedLLM,
+    DifyPromptMessageSerializer,
+    DifyRetrieverAttachmentLoader,
+    DifyToolFileManager,
+    DifyToolNodeRuntime,
+    build_dify_llm_file_saver,
+)
 from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
 from core.workflow.nodes.agent.plugin_strategy_adapter import (
     PluginAgentStrategyPresentationProvider,
@@ -30,11 +39,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
 from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
 from dify_graph.entities.base_node_data import BaseNodeData
 from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
-from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
 from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
 from dify_graph.file.file_manager import file_manager
 from dify_graph.graph.graph import NodeFactory
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from dify_graph.model_runtime.memory import PromptMessageMemory
-from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from dify_graph.nodes.base.node import Node
@@ -44,13 +51,10 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
 from dify_graph.nodes.document_extractor import UnstructuredApiConfig
 from dify_graph.nodes.http_request import build_http_request_config
 from dify_graph.nodes.llm.entities import LLMNodeData
-from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
 from dify_graph.nodes.llm.protocols import TemplateRenderer
 from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
-from dify_graph.nodes.template_transform.template_renderer import (
-    CodeExecutorJinja2TemplateRenderer,
-)
+from dify_graph.template_rendering import CodeExecutorJinja2TemplateRenderer
 from dify_graph.variables.segments import StringSegment
 from extensions.ext_database import db
 from models.model import Conversation
@@ -264,11 +268,22 @@ class DifyNodeFactory(NodeFactory):
            max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
            max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
        )
-       self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
+       self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
        self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
        self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
        self._http_request_http_client = ssrf_proxy
-       self._http_request_tool_file_manager_factory = ToolFileManager
+       self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(self._dify_context)
+       self._file_reference_factory = DifyFileReferenceFactory(self._dify_context)
+       self._prompt_message_serializer = DifyPromptMessageSerializer()
+       self._retriever_attachment_loader = DifyRetrieverAttachmentLoader(
+           file_reference_factory=self._file_reference_factory,
+       )
+       self._llm_file_saver = build_dify_llm_file_saver(
+           run_context=self._dify_context,
+           http_client=self._http_request_http_client,
+       )
+       self._human_input_runtime = DifyHumanInputNodeRuntime(self._dify_context)
+       self._tool_runtime = DifyToolNodeRuntime(self._dify_context)
        self._http_request_file_manager = file_manager
        self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
            api_url=dify_config.UNSTRUCTURED_API_URL,
@@ -284,7 +299,7 @@ class DifyNodeFactory(NodeFactory):
            ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
        )

-       self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
+       self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context)
        self._agent_strategy_resolver = PluginAgentStrategyResolver()
        self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
        self._agent_runtime_support = AgentRuntimeSupport()
@@ -321,22 +336,32 @@ class DifyNodeFactory(NodeFactory):
                "code_limits": self._code_limits,
            },
            BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
-               "template_renderer": self._template_renderer,
+               "jinja2_template_renderer": self._jinja2_template_renderer,
                "max_output_length": self._template_transform_max_output_length,
            },
            BuiltinNodeTypes.HTTP_REQUEST: lambda: {
                "http_request_config": self._http_request_config,
                "http_client": self._http_request_http_client,
-               "tool_file_manager_factory": self._http_request_tool_file_manager_factory,
+               "tool_file_manager_factory": self._bound_tool_file_manager_factory,
                "file_manager": self._http_request_file_manager,
+               "file_reference_factory": self._file_reference_factory,
            },
            BuiltinNodeTypes.HUMAN_INPUT: lambda: {
-               "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
+               "form_repository": HumanInputFormRepositoryImpl(
+                   tenant_id=self._dify_context.tenant_id,
+                   app_id=self._dify_context.app_id,
+               ),
+               "runtime": self._human_input_runtime,
            },
            BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
                node_class=node_class,
                node_data=node_data,
+               wrap_model_instance=True,
+               include_http_client=True,
+               include_llm_file_saver=True,
+               include_prompt_message_serializer=True,
+               include_retriever_attachment_loader=True,
+               include_jinja2_template_renderer=True,
            ),
            BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
                "unstructured_api_config": self._document_extractor_unstructured_api_config,
@@ -345,15 +370,26 @@ class DifyNodeFactory(NodeFactory):
            BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
                node_class=node_class,
                node_data=node_data,
+               wrap_model_instance=True,
+               include_http_client=True,
+               include_llm_file_saver=True,
+               include_prompt_message_serializer=True,
+               include_retriever_attachment_loader=False,
+               include_jinja2_template_renderer=False,
            ),
            BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
                node_class=node_class,
                node_data=node_data,
+               wrap_model_instance=True,
+               include_http_client=False,
+               include_llm_file_saver=False,
+               include_prompt_message_serializer=True,
+               include_retriever_attachment_loader=False,
+               include_jinja2_template_renderer=False,
            ),
            BuiltinNodeTypes.TOOL: lambda: {
-               "tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
+               "tool_file_manager_factory": self._bound_tool_file_manager_factory(),
+               "runtime": self._tool_runtime,
            },
            BuiltinNodeTypes.AGENT: lambda: {
                "strategy_resolver": self._agent_strategy_resolver,
@@ -387,7 +423,12 @@ class DifyNodeFactory(NodeFactory):
        *,
        node_class: type[Node],
        node_data: BaseNodeData,
+       wrap_model_instance: bool,
+       include_http_client: bool,
+       include_llm_file_saver: bool,
+       include_prompt_message_serializer: bool,
+       include_retriever_attachment_loader: bool,
+       include_jinja2_template_renderer: bool,
    ) -> dict[str, object]:
        validated_node_data = cast(
            LLMCompatibleNodeData,
@@ -397,49 +438,33 @@ class DifyNodeFactory(NodeFactory):
        node_init_kwargs: dict[str, object] = {
            "credentials_provider": self._llm_credentials_provider,
            "model_factory": self._llm_model_factory,
-           "model_instance": model_instance,
+           "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
            "memory": self._build_memory_for_llm_node(
                node_data=validated_node_data,
                model_instance=model_instance,
            ),
        }
-       if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
+       if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER:
            node_init_kwargs["template_renderer"] = self._llm_template_renderer
+       if include_http_client:
+           node_init_kwargs["http_client"] = self._http_request_http_client
+       if include_llm_file_saver:
+           node_init_kwargs["llm_file_saver"] = self._llm_file_saver
+       if include_prompt_message_serializer:
+           node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer
+       if include_retriever_attachment_loader:
+           node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader
+       if include_jinja2_template_renderer:
+           node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer
        return node_init_kwargs

    def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
        node_data_model = node_data.model
-       if not node_data_model.mode:
-           raise LLMModeRequiredError("LLM mode is required.")
-
-       credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
-       model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
-       provider_model_bundle = model_instance.provider_model_bundle
-
-       provider_model = provider_model_bundle.configuration.get_provider_model(
-           model=node_data_model.name,
-           model_type=ModelType.LLM,
+       model_instance, _ = fetch_model_config(
+           node_data_model=node_data_model,
+           credentials_provider=self._llm_credentials_provider,
+           model_factory=self._llm_model_factory,
        )
-       if provider_model is None:
-           raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-       provider_model.raise_for_status()
-
-       completion_params = dict(node_data_model.completion_params)
-       stop = completion_params.pop("stop", [])
-       if not isinstance(stop, list):
-           stop = []
-
-       model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
-       if not model_schema:
-           raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-
-       model_instance.provider = node_data_model.provider
-       model_instance.model_name = node_data_model.name
-       model_instance.credentials = credentials
-       model_instance.parameters = completion_params
-       model_instance.stop = tuple(stop)
-       model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
        return model_instance
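The `_build_llm_compatible_node_init_kwargs` helper above assembles constructor kwargs from per-capability booleans instead of special-casing node types inline. A stripped-down sketch of the same pattern, with hypothetical names, runnable on its own:

    from typing import Any

    def build_init_kwargs(*, include_http_client: bool, include_file_saver: bool) -> dict[str, Any]:
        # Every LLM-compatible node gets the shared collaborators; optional
        # ones are attached only when the node type opts in via a flag.
        kwargs: dict[str, Any] = {"model_instance": object()}
        if include_http_client:
            kwargs["http_client"] = object()
        if include_file_saver:
            kwargs["llm_file_saver"] = object()
        return kwargs

    assert "http_client" in build_init_kwargs(include_http_client=True, include_file_saver=False)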
api/core/workflow/node_runtime.py (new file, 532 lines)
@@ -0,0 +1,532 @@
from __future__ import annotations

from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
from core.plugin.impl.plugin import PluginInstaller
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
from core.tools.errors import ToolInvokeError
from core.tools.signature import sign_upload_file
from core.tools.tool_engine import ToolEngine
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.model_runtime.entities import LLMMode
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkWithStructuredOutput,
    LLMResultWithStructuredOutput,
    LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, apply_debug_email_recipient
from dify_graph.nodes.llm.runtime_protocols import (
    PreparedLLMProtocol,
    PromptMessageSerializerProtocol,
    RetrieverAttachmentLoaderProtocol,
)
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.runtime import (
    HumanInputNodeRuntimeProtocol,
    ToolNodeRuntimeProtocol,
)
from dify_graph.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError
from dify_graph.nodes.tool_runtime_entities import (
    ToolRuntimeHandle,
    ToolRuntimeMessage,
    ToolRuntimeParameter,
)
from extensions.ext_database import db
from factories import file_factory
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

if TYPE_CHECKING:
    from core.tools.__base.tool import Tool
    from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
    from dify_graph.file import File
    from dify_graph.nodes.llm.file_saver import LLMFileSaver
    from dify_graph.nodes.tool.entities import ToolNodeData
def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
    if isinstance(run_context, DifyRunContext):
        return run_context

    raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
    if raw_ctx is None:
        raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
    if isinstance(raw_ctx, DifyRunContext):
        return raw_ctx
    return DifyRunContext.model_validate(raw_ctx)


class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
        self._run_context = resolve_dify_run_context(run_context)

    def build_from_mapping(self, *, mapping: Mapping[str, Any]):
        return file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=self._run_context.tenant_id,
        )
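`resolve_dify_run_context` normalizes both call shapes into one validated model. A hedged usage sketch; `DIFY_RUN_CONTEXT_KEY` is `"_dify"` per this commit, while the exact field set of `DifyRunContext` is assumed from its uses in this file:

    raw = {"_dify": {"tenant_id": "t-1", "user_id": "u-1", "app_id": "a-1", "invoke_from": "debugger"}}
    ctx = resolve_dify_run_context(raw)   # mapping path: validated into a DifyRunContext
    same = resolve_dify_run_context(ctx)  # already-typed path: returned as-is
    assert same is ctx and ctx.tenant_id == "t-1"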
class DifyPreparedLLM(PreparedLLMProtocol):
    """Workflow-layer adapter that hides the full `ModelInstance` API from `dify_graph` nodes."""

    def __init__(self, model_instance: ModelInstance) -> None:
        self._model_instance = model_instance

    @property
    def provider(self) -> str:
        return self._model_instance.provider

    @property
    def model_name(self) -> str:
        return self._model_instance.model_name

    @property
    def parameters(self) -> Mapping[str, Any]:
        return self._model_instance.parameters

    @property
    def stop(self) -> Sequence[str] | None:
        return self._model_instance.stop

    def get_model_schema(self) -> AIModelEntity:
        model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema(
            self._model_instance.model_name,
            self._model_instance.credentials,
        )
        if model_schema is None:
            raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
        return model_schema

    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
        return self._model_instance.get_llm_num_tokens(prompt_messages)

    def invoke_llm(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Mapping[str, Any],
        tools: Sequence[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
        return self._model_instance.invoke_llm(
            prompt_messages=list(prompt_messages),
            model_parameters=dict(model_parameters),
            tools=list(tools or []),
            stop=list(stop or []),
            stream=stream,
        )

    def invoke_llm_with_structured_output(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        json_schema: Mapping[str, Any],
        model_parameters: Mapping[str, Any],
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
        return invoke_llm_with_structured_output(
            provider=self.provider,
            model_schema=self.get_model_schema(),
            model_instance=self._model_instance,
            prompt_messages=prompt_messages,
            json_schema=json_schema,
            model_parameters=model_parameters,
            stop=list(stop or []),
            stream=stream,
        )

    def is_structured_output_parse_error(self, error: Exception) -> bool:
        return isinstance(error, OutputParserError)
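Because `DifyPreparedLLM` only exposes the `PreparedLLMProtocol` surface, graph nodes can be exercised against a tiny stand-in instead of a live `ModelInstance`. A hypothetical test double, not part of the commit:

    class FakePreparedLLM:
        provider = "openai"
        model_name = "gpt-4o-mini"
        parameters: dict = {}
        stop = None

        def get_llm_num_tokens(self, prompt_messages) -> int:
            # A rough 4-chars-per-token estimate is enough for limit checks in tests.
            return sum(len(str(message)) // 4 for message in prompt_messages)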
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
    def serialize(
        self,
        *,
        model_mode: LLMMode,
        prompt_messages: Sequence[PromptMessage],
    ) -> Any:
        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_mode,
            prompt_messages=prompt_messages,
        )


class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
    """Resolve retriever attachments through Dify persistence and return graph file references."""

    def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None:
        self._file_reference_factory = file_reference_factory

    def load(self, *, segment_id: str) -> Sequence[File]:
        with Session(db.engine, expire_on_commit=False) as session:
            attachments_with_bindings = session.execute(
                select(SegmentAttachmentBinding, UploadFile)
                .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
                .where(SegmentAttachmentBinding.segment_id == segment_id)
            ).all()

        return [
            self._file_reference_factory.build_from_mapping(
                mapping={
                    "id": upload_file.id,
                    "filename": upload_file.name,
                    "extension": "." + upload_file.extension,
                    "mime_type": upload_file.mime_type,
                    "type": FileType.IMAGE,
                    "transfer_method": FileTransferMethod.LOCAL_FILE,
                    "remote_url": upload_file.source_url,
                    "related_id": upload_file.id,
                    "upload_file_id": upload_file.id,
                    "size": upload_file.size,
                    "storage_key": upload_file.key,
                    "url": sign_upload_file(upload_file.id, upload_file.extension),
                }
            )
            for _, upload_file in attachments_with_bindings
        ]
class DifyToolFileManager(ToolFileManagerProtocol):
    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
        self._run_context = resolve_dify_run_context(run_context)
        self._manager = ToolFileManager()

    def create_file_by_raw(
        self,
        *,
        conversation_id: str | None,
        file_binary: bytes,
        mimetype: str,
        filename: str | None = None,
    ) -> Any:
        return self._manager.create_file_by_raw(
            user_id=self._run_context.user_id,
            tenant_id=self._run_context.tenant_id,
            conversation_id=conversation_id,
            file_binary=file_binary,
            mimetype=mimetype,
            filename=filename,
        )

    def get_file_generator_by_tool_file_id(self, tool_file_id: str):
        return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
@dataclass(frozen=True, slots=True)
class _WorkflowToolRuntimeSpec:
    provider_type: CoreToolProviderType
    provider_id: str
    tool_name: str
    tool_configurations: dict[str, Any]
    credential_id: str | None = None


class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
        self._run_context = resolve_dify_run_context(run_context)
        self._file_reference_factory = DifyFileReferenceFactory(self._run_context)

    @property
    def file_reference_factory(self) -> FileReferenceFactoryProtocol:
        return self._file_reference_factory

    def build_file_reference(self, *, mapping: Mapping[str, Any]):
        return self._file_reference_factory.build_from_mapping(mapping=mapping)

    def get_runtime(
        self,
        *,
        node_id: str,
        node_data: ToolNodeData,
        variable_pool,
    ) -> ToolRuntimeHandle:
        try:
            tool_runtime = ToolManager.get_workflow_tool_runtime(
                self._run_context.tenant_id,
                self._run_context.app_id,
                node_id,
                self._build_tool_runtime_spec(node_data),
                self._run_context.invoke_from,
                variable_pool,
            )
        except ToolNodeError:
            raise
        except Exception as exc:
            raise ToolRuntimeResolutionError(str(exc)) from exc

        return ToolRuntimeHandle(raw=tool_runtime)

    def get_runtime_parameters(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
    ) -> Sequence[ToolRuntimeParameter]:
        tool = self._tool_from_handle(tool_runtime)
        return [
            ToolRuntimeParameter(name=parameter.name, required=parameter.required)
            for parameter in (tool.get_merged_runtime_parameters() or [])
        ]
    def invoke(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
        tool_parameters: Mapping[str, Any],
        workflow_call_depth: int,
        conversation_id: str | None,
        provider_name: str,
    ) -> Generator[ToolRuntimeMessage, None, None]:
        tool = self._tool_from_handle(tool_runtime)
        callback = DifyWorkflowCallbackHandler()

        try:
            messages = ToolEngine.generic_invoke(
                tool=tool,
                tool_parameters=dict(tool_parameters),
                user_id=self._run_context.user_id,
                workflow_tool_callback=callback,
                workflow_call_depth=workflow_call_depth,
                app_id=self._run_context.app_id,
                conversation_id=conversation_id,
            )
        except Exception as exc:
            raise self._map_invocation_exception(exc, provider_name=provider_name) from exc

        transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=self._run_context.user_id,
            tenant_id=self._run_context.tenant_id,
            conversation_id=None,
        )

        return self._adapt_messages(transformed_messages, provider_name=provider_name)

    def get_usage(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
    ) -> LLMUsage:
        latest = getattr(tool_runtime.raw, "latest_usage", None)
        if isinstance(latest, LLMUsage):
            return latest
        if isinstance(latest, dict):
            return LLMUsage.model_validate(latest)
        return LLMUsage.empty_usage()

    def resolve_provider_icons(
        self,
        *,
        provider_name: str,
        default_icon: str | None = None,
    ) -> tuple[str | None, str | None]:
        icon = default_icon
        icon_dark = None

        manager = PluginInstaller()
        plugins = manager.list_plugins(self._run_context.tenant_id)
        try:
            current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name)
            icon = current_plugin.declaration.icon
        except StopIteration:
            pass

        try:
            builtin_tool = next(
                provider
                for provider in BuiltinToolManageService.list_builtin_tools(
                    self._run_context.user_id,
                    self._run_context.tenant_id,
                )
                if provider.name == provider_name
            )
            icon = builtin_tool.icon
            icon_dark = builtin_tool.icon_dark
        except StopIteration:
            pass

        return icon, icon_dark
    @staticmethod
    def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool:
        return cast("Tool", tool_runtime.raw)

    @staticmethod
    def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
        return _WorkflowToolRuntimeSpec(
            provider_type=CoreToolProviderType(node_data.provider_type.value),
            provider_id=node_data.provider_id,
            tool_name=node_data.tool_name,
            tool_configurations=dict(node_data.tool_configurations),
            credential_id=node_data.credential_id,
        )

    def _adapt_messages(
        self,
        messages: Generator[CoreToolInvokeMessage, None, None],
        *,
        provider_name: str,
    ) -> Generator[ToolRuntimeMessage, None, None]:
        try:
            for message in messages:
                yield self._convert_message(message)
        except Exception as exc:
            raise self._map_invocation_exception(exc, provider_name=provider_name) from exc

    def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage:
        graph_message_type = ToolRuntimeMessage.MessageType(message.type.value)
        graph_message = self._convert_message_payload(message.message)
        graph_meta = message.meta.copy() if message.meta is not None else None
        return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta)

    def _convert_message_payload(
        self,
        message: CoreToolInvokeMessage.TextMessage
        | CoreToolInvokeMessage.JsonMessage
        | CoreToolInvokeMessage.BlobChunkMessage
        | CoreToolInvokeMessage.BlobMessage
        | CoreToolInvokeMessage.LogMessage
        | CoreToolInvokeMessage.FileMessage
        | CoreToolInvokeMessage.VariableMessage
        | CoreToolInvokeMessage.RetrieverResourceMessage
        | None,
    ) -> (
        ToolRuntimeMessage.TextMessage
        | ToolRuntimeMessage.JsonMessage
        | ToolRuntimeMessage.BlobChunkMessage
        | ToolRuntimeMessage.BlobMessage
        | ToolRuntimeMessage.LogMessage
        | ToolRuntimeMessage.FileMessage
        | ToolRuntimeMessage.VariableMessage
        | ToolRuntimeMessage.RetrieverResourceMessage
        | None
    ):
        if message is None:
            return None

        from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage

        if isinstance(message, CoreToolInvokeMessage.TextMessage):
            return ToolRuntimeMessage.TextMessage(text=message.text)
        if isinstance(message, CoreToolInvokeMessage.JsonMessage):
            return ToolRuntimeMessage.JsonMessage(
                json_object=message.json_object,
                suppress_output=message.suppress_output,
            )
        if isinstance(message, CoreToolInvokeMessage.BlobMessage):
            return ToolRuntimeMessage.BlobMessage(blob=message.blob)
        if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
            return ToolRuntimeMessage.BlobChunkMessage(
                id=message.id,
                sequence=message.sequence,
                total_length=message.total_length,
                blob=message.blob,
                end=message.end,
            )
        if isinstance(message, CoreToolInvokeMessage.FileMessage):
            return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
        if isinstance(message, CoreToolInvokeMessage.VariableMessage):
            return ToolRuntimeMessage.VariableMessage(
                variable_name=message.variable_name,
                variable_value=message.variable_value,
                stream=message.stream,
            )
        if isinstance(message, CoreToolInvokeMessage.LogMessage):
            return ToolRuntimeMessage.LogMessage(
                id=message.id,
                label=message.label,
                parent_id=message.parent_id,
                error=message.error,
                status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
                data=dict(message.data),
                metadata=dict(message.metadata),
            )
        if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
            retriever_resources = [
                resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
                for resource in message.retriever_resources
            ]
            return ToolRuntimeMessage.RetrieverResourceMessage(
                retriever_resources=retriever_resources,
                context=message.context,
            )

        raise TypeError(f"unsupported tool message payload: {type(message).__name__}")

    @staticmethod
    def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
        if isinstance(exc, ToolNodeError):
            return exc
        if isinstance(exc, PluginInvokeError):
            return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
        if isinstance(exc, PluginDaemonClientSideError):
            return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
        if isinstance(exc, ToolInvokeError):
            return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
        return ToolRuntimeInvocationError(str(exc))
class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
        self._run_context = resolve_dify_run_context(run_context)

    def invoke_source(self) -> str:
        invoke_from = self._run_context.invoke_from
        if isinstance(invoke_from, str):
            return invoke_from
        return str(getattr(invoke_from, "value", invoke_from))

    def apply_delivery_runtime(
        self,
        *,
        methods: Sequence[DeliveryChannelConfig],
    ) -> Sequence[DeliveryChannelConfig]:
        return [
            apply_debug_email_recipient(
                method,
                enabled=self.invoke_source() == "debugger",
                user_id=self._run_context.user_id,
            )
            for method in methods
        ]

    def console_actor_id(self) -> str | None:
        return self._run_context.user_id
def build_dify_llm_file_saver(
    *,
    run_context: Mapping[str, Any] | DifyRunContext,
    http_client: HttpClientProtocol,
) -> LLMFileSaver:
    from dify_graph.nodes.llm.file_saver import FileSaverImpl

    return FileSaverImpl(
        tool_file_manager=DifyToolFileManager(run_context),
        file_reference_factory=DifyFileReferenceFactory(run_context),
        http_client=http_client,
    )
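A wiring sketch for the factory above, mirroring how `DifyNodeFactory.__init__` calls it earlier in this commit; `ctx` and `http_client` are placeholders for a real run context and any `HttpClientProtocol` implementation (the factory passes `ssrf_proxy`):

    saver = build_dify_llm_file_saver(
        run_context=ctx,          # DifyRunContext, or a mapping holding one under "_dify"
        http_client=http_client,  # e.g. the ssrf_proxy client used by DifyNodeFactory
    )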
@@ -3,6 +3,7 @@ from __future__ import annotations
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any

+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
 from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
@@ -59,7 +60,7 @@ class AgentNode(Node[AgentNodeData]):
        return "1"

    def populate_start_event(self, event) -> None:
-       dify_ctx = self.require_dify_context()
+       dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
        event.extras["agent_strategy"] = {
            "name": self.node_data.agent_strategy_name,
            "icon": self._presentation_provider.get_icon(
@@ -71,7 +72,7 @@ class AgentNode(Node[AgentNodeData]):
    def _run(self) -> Generator[NodeEventBase, None, None]:
        from core.plugin.impl.exc import PluginDaemonClientSideError

-       dify_ctx = self.require_dify_context()
+       dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))

        try:
            strategy = self._strategy_resolver.resolve(
@@ -97,6 +98,7 @@ class AgentNode(Node[AgentNodeData]):
            node_data=self.node_data,
            strategy=strategy,
            tenant_id=dify_ctx.tenant_id,
+           user_id=dify_ctx.user_id,
            app_id=dify_ctx.app_id,
            invoke_from=dify_ctx.invoke_from,
        )
@@ -106,6 +108,7 @@ class AgentNode(Node[AgentNodeData]):
            node_data=self.node_data,
            strategy=strategy,
            tenant_id=dify_ctx.tenant_id,
+           user_id=dify_ctx.user_id,
            app_id=dify_ctx.app_id,
            invoke_from=dify_ctx.invoke_from,
            for_log=True,
@@ -14,7 +14,7 @@ from core.agent.plugin_entities import AgentStrategyParameter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.plugin.entities.request import InvokeCredentials
-from core.provider_manager import ProviderManager
+from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
 from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
 from core.tools.tool_manager import ToolManager
 from dify_graph.enums import SystemVariableKey
@@ -38,6 +38,7 @@ class AgentRuntimeSupport:
        node_data: AgentNodeData,
        strategy: ResolvedAgentStrategy,
        tenant_id: str,
+       user_id: str,
        app_id: str,
        invoke_from: Any,
        for_log: bool = False,
@@ -174,7 +175,11 @@ class AgentRuntimeSupport:
                value = tool_value
            if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
                value = cast(dict[str, Any], value)
-               model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
+               model_instance, model_schema = self.fetch_model(
+                   tenant_id=tenant_id,
+                   user_id=user_id,
+                   value=value,
+               )
                history_prompt_messages = []
                if node_data.memory:
                    memory = self.fetch_memory(
@@ -232,8 +237,14 @@ class AgentRuntimeSupport:

        return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

-   def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
-       provider_manager = ProviderManager()
+   def fetch_model(
+       self,
+       *,
+       tenant_id: str,
+       user_id: str,
+       value: dict[str, Any],
+   ) -> tuple[ModelInstance, AIModelEntity | None]:
+       provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=value.get("provider", ""),
@@ -246,7 +257,7 @@ class AgentRuntimeSupport:
        )
        provider_name = provider_model_bundle.configuration.provider.provider
        model_type_instance = provider_model_bundle.model_type_instance
-       model_instance = ModelManager().get_model_instance(
+       model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
            tenant_id=tenant_id,
            provider=provider_name,
            model_type=ModelType(value.get("model_type", "")),
@@ -1,6 +1,7 @@
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any

+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import DatasourceProviderType
 from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -50,8 +51,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
        """
        Run the datasource node
        """
-
-       dify_ctx = self.require_dify_context()
+       dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
        node_data = self.node_data
        variable_pool = self.graph_runtime_state.variable_pool
        datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
@@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal

 from core.app.app_config.entities import DatasetRetrieveConfigEntity
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from dify_graph.entities import GraphInitParams
@@ -160,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
    ) -> tuple[list[Source], LLMUsage]:
-       dify_ctx = self.require_dify_context()
+       dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
        dataset_ids = node_data.dataset_ids
        query = variables.get("query")
        attachments = variables.get("attachments")
@@ -9,9 +9,9 @@ from dify_graph.enums import NodeExecutionType
 from dify_graph.file import FileTransferMethod
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base.node import Node
+from dify_graph.nodes.protocols import FileReferenceFactoryProtocol
 from dify_graph.variables.types import SegmentType
 from dify_graph.variables.variables import FileVariable
-from factories import file_factory
 from factories.variable_factory import build_segment_with_type

 from .entities import ContentType, WebhookData
@@ -23,6 +23,13 @@ class TriggerWebhookNode(Node[WebhookData]):
    node_type = TRIGGER_WEBHOOK_NODE_TYPE
    execution_type = NodeExecutionType.ROOT

+   _file_reference_factory: FileReferenceFactoryProtocol
+
+   def post_init(self) -> None:
+       from core.workflow.node_runtime import DifyFileReferenceFactory
+
+       self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
+
    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {
@@ -70,7 +77,6 @@ class TriggerWebhookNode(Node[WebhookData]):
        )

    def generate_file_var(self, param_name: str, file: dict):
-       dify_ctx = self.require_dify_context()
        related_id = file.get("related_id")
        transfer_method_value = file.get("transfer_method")
        if transfer_method_value:
@@ -84,10 +90,7 @@ class TriggerWebhookNode(Node[WebhookData]):
            file["datasource_file_id"] = related_id

        try:
-           file_obj = file_factory.build_from_mapping(
-               mapping=file,
-               tenant_id=dify_ctx.tenant_id,
-           )
+           file_obj = self._file_reference_factory.build_from_mapping(mapping=file)
            file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
            return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
        except ValueError:
@@ -3,8 +3,6 @@ from typing import Any

 from pydantic import BaseModel, Field

-DIFY_RUN_CONTEXT_KEY = "_dify"
-

 class GraphInitParams(BaseModel):
    """GraphInitParams encapsulates the configurations and contextual information
@@ -14,7 +14,7 @@ from typing import Any
 from pydantic import BaseModel, Field

 from dify_graph.enums import WorkflowExecutionStatus, WorkflowType
-from libs.datetime_utils import naive_utc_now
+from dify_graph.utils.datetime_utils import naive_utc_now


 class WorkflowExecution(BaseModel):
@@ -10,7 +10,6 @@ from pydantic import TypeAdapter
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
 from dify_graph.nodes.base.node import Node
-from libs.typing import is_str

 from .edge import Edge
 from .validation import get_graph_validator
@@ -102,7 +101,7 @@ class Graph:
            source = edge_config.get("source")
            target = edge_config.get("target")

-           if not is_str(source) or not is_str(target):
+           if not isinstance(source, str) or not isinstance(target, str):
                continue

            # Create edge
@@ -110,7 +109,7 @@ class Graph:
            edge_counter += 1

            source_handle = edge_config.get("sourceHandle", "source")
-           if not is_str(source_handle):
+           if not isinstance(source_handle, str):
                continue

            edge = Edge(
@@ -1,9 +1,9 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from datetime import datetime
 from typing import Any

 from pydantic import Field

-from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from dify_graph.entities.pause_reason import PauseReason

 from .base import GraphNodeEventBase
@@ -30,7 +30,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):


 class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
-   retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
+   retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")
@@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel):

 class SimpleProviderEntity(BaseModel):
    """
-   Simple model class for provider.
+   Simplified provider schema exposed to callers.
+
+   `provider` is the canonical runtime identifier. `provider_name` is an optional
+   compatibility alias for short-name lookups and is empty when no alias exists.
    """

    provider: str
+   provider_name: str = ""
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
@@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel):

 class ProviderEntity(BaseModel):
    """
-   Model class for provider.
+   Runtime-native provider schema.
+
+   `provider` is the canonical runtime identifier. `provider_name` is a
+   compatibility alias for callers that still resolve providers by short name and
+   is empty when no alias exists.
    """

    provider: str
+   provider_name: str = ""
    label: I18nObject
    description: I18nObject | None = None
    icon_small: I18nObject | None = None
@@ -153,6 +162,7 @@ class ProviderEntity(BaseModel):
        """
        return SimpleProviderEntity(
            provider=self.provider,
+           provider_name=self.provider_name,
            label=self.label,
            icon_small=self.icon_small,
            supported_model_types=self.supported_model_types,
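A hedged illustration of the alias contract documented above; the identifiers are invented, and only the `provider`/`provider_name` semantics come from the docstrings:

    provider_id = "langgenius/openai/openai"   # canonical runtime identifier
    provider_alias = "openai"                  # optional short-name alias, "" when absent

    def matches(requested: str) -> bool:
        # Short-name lookups keep working while new code keys on the canonical id.
        return requested == provider_id or (bool(provider_alias) and requested == provider_alias)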
@@ -1,6 +1,13 @@
+from typing import TypedDict
+
 from pydantic import BaseModel


+class MultimodalRerankInput(TypedDict):
+    content: str
+    content_type: str
+
+
 class RerankDocument(BaseModel):
    """
    Model class for rerank document.
@@ -1,10 +1,18 @@
 from decimal import Decimal
+from enum import StrEnum, auto

 from pydantic import BaseModel

 from dify_graph.model_runtime.entities.model_entities import ModelUsage


+class EmbeddingInputType(StrEnum):
+    """Embedding request input variants understood by the model runtime."""
+
+    DOCUMENT = auto()
+    QUERY = auto()
+
+
 class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
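With `StrEnum` plus `auto()` (Python 3.11+), each member's value is its lowercased name, so the enum members compare equal to the wire strings:

    from enum import StrEnum, auto

    class EmbeddingInputType(StrEnum):
        DOCUMENT = auto()
        QUERY = auto()

    assert EmbeddingInputType.DOCUMENT == "document"
    assert EmbeddingInputType("query") is EmbeddingInputType.QUERY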
@@ -1,12 +1,5 @@
-import decimal
-import hashlib
-import logging

-from pydantic import BaseModel, ConfigDict, Field, ValidationError
-from redis import RedisError
-
-from configs import dify_config
-from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from dify_graph.model_runtime.entities.common_entities import I18nObject
 from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from dify_graph.model_runtime.entities.model_entities import (
@@ -17,6 +10,7 @@ from dify_graph.model_runtime.entities.model_entities import (
    PriceInfo,
    PriceType,
 )
+from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
 from dify_graph.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
@@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
 )
-from extensions.ext_redis import redis_client
-
-logger = logging.getLogger(__name__)
+from dify_graph.model_runtime.runtime import ModelRuntime


-class AIModel(BaseModel):
+class AIModel:
    """
-   Base class for all models.
+   Runtime-facing base class for all model providers.
+
+   This stays a regular Python class because instances hold live collaborators
+   such as the provider schema and runtime adapter rather than user input that
+   benefits from Pydantic validation. Subclasses must pin ``model_type`` via a
+   class attribute; the base class is not meant to be instantiated directly.
    """

-   tenant_id: str = Field(description="Tenant ID")
-   model_type: ModelType = Field(description="Model type")
-   plugin_id: str = Field(description="Plugin ID")
-   provider_name: str = Field(description="Provider")
-   plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
-   started_at: float = Field(description="Invoke start time", default=0)
+   model_type: ModelType
+   provider_schema: ProviderEntity
+   model_runtime: ModelRuntime
+   started_at: float

-   # pydantic configs
-   model_config = ConfigDict(protected_namespaces=())
+   def __init__(
+       self,
+       provider_schema: ProviderEntity,
+       model_runtime: ModelRuntime,
+       *,
+       started_at: float = 0,
+   ) -> None:
+       if getattr(type(self), "model_type", None) is None:
+           raise TypeError("AIModel subclasses must define model_type as a class attribute")
+
+       self.model_type = type(self).model_type
+       self.provider_schema = provider_schema
+       self.model_runtime = model_runtime
+       self.started_at = started_at

+   @property
+   def provider(self) -> str:
+       return self.provider_schema.provider
+
+   @property
+   def provider_display_name(self) -> str:
+       return self.provider_schema.label.en_US
+
    @property
    def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
        """
-       Map model invoke error to unified error
-       The key is the error type thrown to the caller
-       The value is the error type thrown by the model,
-       which needs to be converted into a unified error type for the caller.
+       Map model invoke error to unified error.

-       :return: Invoke error mapping
+       The key is the error type thrown to the caller, and the value contains
+       runtime-facing exception types that should be normalized to it.
        """
        from core.plugin.entities.plugin_daemon import PluginDaemonInnerError

        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError],
            PluginDaemonInnerError: [PluginDaemonInnerError],
            ValueError: [ValueError],
        }
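A minimal subclass sketch of the contract the new `__init__` enforces; `ModelType.TTS` is just an illustrative member, and schema/runtime construction is elided:

    class TextToSpeechModel(AIModel):
        # Pinning model_type at class level satisfies the __init__ guard; a
        # subclass that omits it raises TypeError on instantiation.
        model_type: ModelType = ModelType.TTS

    # provider_schema/model_runtime are built elsewhere by the provider factory:
    # model = TextToSpeechModel(provider_schema=schema, model_runtime=runtime)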
@@ -79,15 +89,18 @@ class AIModel(BaseModel):
        if invoke_error == InvokeAuthorizationError:
            return InvokeAuthorizationError(
                description=(
-                   f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
+                   f"[{self.provider_display_name}] Incorrect model credentials provided, "
+                   "please check and try again."
                )
            )
        elif isinstance(invoke_error, InvokeError):
-           return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
+           return InvokeError(
+               description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}"
+           )
        else:
            return error

-       return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
+       return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}")

    def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
        """
@@ -144,65 +157,13 @@ class AIModel(BaseModel):
        :param credentials: model credentials
        :return: model schema
        """
-       from core.plugin.impl.model import PluginModelClient
-
-       plugin_model_manager = PluginModelClient()
-       cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
-       sorted_credentials = sorted(credentials.items()) if credentials else []
-       cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
-
-       cached_schema_json = None
-       try:
-           cached_schema_json = redis_client.get(cache_key)
-       except (RedisError, RuntimeError) as exc:
-           logger.warning(
-               "Failed to read plugin model schema cache for model %s: %s",
-               model,
-               str(exc),
-               exc_info=True,
-           )
-       if cached_schema_json:
-           try:
-               return AIModelEntity.model_validate_json(cached_schema_json)
-           except ValidationError:
-               logger.warning(
-                   "Failed to validate cached plugin model schema for model %s",
-                   model,
-                   exc_info=True,
-               )
-               try:
-                   redis_client.delete(cache_key)
-               except (RedisError, RuntimeError) as exc:
-                   logger.warning(
-                       "Failed to delete invalid plugin model schema cache for model %s: %s",
-                       model,
-                       str(exc),
-                       exc_info=True,
-                   )
-
-       schema = plugin_model_manager.get_model_schema(
-           tenant_id=self.tenant_id,
-           user_id="unknown",
-           plugin_id=self.plugin_id,
-           provider=self.provider_name,
-           model_type=self.model_type.value,
+       return self.model_runtime.get_model_schema(
+           provider=self.provider,
+           model_type=self.model_type,
            model=model,
            credentials=credentials or {},
        )
-
-       if schema:
-           try:
-               redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
-           except (RedisError, RuntimeError) as exc:
-               logger.warning(
-                   "Failed to write plugin model schema cache for model %s: %s",
-                   model,
-                   str(exc),
-                   exc_info=True,
-               )
-
-       return schema

    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Get customizable model schema from credentials
@@ -4,9 +4,6 @@ import uuid
 from collections.abc import Callable, Generator, Iterator, Sequence
 from typing import Union

-from pydantic import ConfigDict
-
-from configs import dify_config
 from dify_graph.model_runtime.callbacks.base_callback import Callback
 from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
@@ -140,11 +137,9 @@ def _build_llm_result_from_chunks(
    )


-def _invoke_llm_via_plugin(
+def _invoke_llm_via_runtime(
    *,
-   tenant_id: str,
-   user_id: str,
-   plugin_id: str,
+   llm_model: "LargeLanguageModel",
    provider: str,
    model: str,
    credentials: dict,
@@ -154,25 +149,19 @@ def _invoke_llm_via_plugin(
    stop: Sequence[str] | None,
    stream: bool,
 ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
-   from core.plugin.impl.model import PluginModelClient
-
-   plugin_model_manager = PluginModelClient()
-   return plugin_model_manager.invoke_llm(
-       tenant_id=tenant_id,
-       user_id=user_id,
-       plugin_id=plugin_id,
+   return llm_model.model_runtime.invoke_llm(
        provider=provider,
        model=model,
        credentials=credentials,
        model_parameters=model_parameters,
        prompt_messages=list(prompt_messages),
        tools=tools,
-       stop=list(stop) if stop else None,
+       stop=stop,
        stream=stream,
    )


-def _normalize_non_stream_plugin_result(
+def _normalize_non_stream_runtime_result(
    model: str,
    prompt_messages: Sequence[PromptMessage],
    result: Union[LLMResult, Iterator[LLMResultChunk]],
@@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel):

    model_type: ModelType = ModelType.LLM

-   # pydantic configs
-   model_config = ConfigDict(protected_namespaces=())
-
    def invoke(
        self,
        model: str,
@@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel):
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
-       user: str | None = None,
        callbacks: list[Callback] | None = None,
    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
        """
@@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel):
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
-       :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """
@@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel):

        callbacks = callbacks or []

-       if dify_config.DEBUG:
+       if logger.isEnabledFor(logging.DEBUG):
            callbacks.append(LoggingCallback())

        # trigger before invoke callbacks
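Switching from the global `dify_config.DEBUG` flag to `logger.isEnabledFor(logging.DEBUG)` ties the verbose callback to standard logging configuration, which respects per-logger levels and inheritance:

    import logging

    logging.getLogger("dify_graph").setLevel(logging.DEBUG)  # opt in for one subtree only
    logger = logging.getLogger("dify_graph.model_runtime")

    if logger.isEnabledFor(logging.DEBUG):  # True via parent-logger inheritance
        print("attach LoggingCallback only when someone will actually see the output")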
@@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
result = _invoke_llm_via_plugin(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
result = _invoke_llm_via_runtime(
|
||||
llm_model=self,
|
||||
provider=self.provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
@@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel):
|
||||
)
|
||||
|
||||
if not stream:
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
result = _normalize_non_stream_runtime_result(
|
||||
model=model, prompt_messages=prompt_messages, result=result
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
elif isinstance(result, LLMResult):
|
||||
@@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel):
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# Following https://github.com/langgenius/dify/issues/17799,
|
||||
@@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
)
|
||||
return 0
|
||||
return self.model_runtime.get_llm_num_tokens(
|
||||
provider=self.provider,
|
||||
model_type=self.model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
def calc_response_usage(
|
||||
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
||||
|
||||
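For context on the `_normalize_non_stream_plugin_result` to `_normalize_non_stream_runtime_result` rename above: when `stream=False` the runtime may still hand back a chunk generator, and the helper has to collapse it into a single result. A minimal sketch of that behavior, using illustrative stand-in classes rather than the real `LLMResult`/`LLMResultChunk` entities:

    # Stand-ins for the dify_graph LLM entities; field names are illustrative only.
    from collections.abc import Iterator
    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class Chunk:  # plays the role of LLMResultChunk
        text: str

    @dataclass
    class Result:  # plays the role of LLMResult
        model: str
        text: str = ""

    def normalize_non_stream_result(model: str, result: Union[Result, Iterator[Chunk]]) -> Result:
        """Collapse a chunk iterator into one result when the caller asked for stream=False."""
        if isinstance(result, Result):
            return result  # the runtime already returned a complete result
        return Result(model=model, text="".join(chunk.text for chunk in result))

    assert normalize_non_stream_result("m", iter([Chunk("a"), Chunk("b")])).text == "ab"
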
@@ -1,7 +1,5 @@
import time

from pydantic import ConfigDict

from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,30 +11,20 @@ class ModerationModel(AIModel):

    model_type: ModelType = ModelType.MODERATION

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
    def invoke(self, model: str, credentials: dict, text: str) -> bool:
        """
        Invoke moderation model

        :param model: model name
        :param credentials: model credentials
        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        self.started_at = time.perf_counter()

        try:
            from core.plugin.impl.model import PluginModelClient

            plugin_model_manager = PluginModelClient()
            return plugin_model_manager.invoke_moderation(
                tenant_id=self.tenant_id,
                user_id=user or "unknown",
                plugin_id=self.plugin_id,
                provider=self.provider_name,
            return self.model_runtime.invoke_moderation(
                provider=self.provider,
                model=model,
                credentials=credentials,
                text=text,

@@ -1,5 +1,5 @@
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel


@@ -18,7 +18,6 @@ class RerankModel(AIModel):
        docs: list[str],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke rerank model
@@ -29,18 +28,11 @@ class RerankModel(AIModel):
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        try:
            from core.plugin.impl.model import PluginModelClient

            plugin_model_manager = PluginModelClient()
            return plugin_model_manager.invoke_rerank(
                tenant_id=self.tenant_id,
                user_id=user or "unknown",
                plugin_id=self.plugin_id,
                provider=self.provider_name,
            return self.model_runtime.invoke_rerank(
                provider=self.provider,
                model=model,
                credentials=credentials,
                query=query,
@@ -55,11 +47,10 @@ class RerankModel(AIModel):
        self,
        model: str,
        credentials: dict,
        query: dict,
        docs: list[dict],
        query: MultimodalRerankInput,
        docs: list[MultimodalRerankInput],
        score_threshold: float | None = None,
        top_n: int | None = None,
        user: str | None = None,
    ) -> RerankResult:
        """
        Invoke multimodal rerank model
@@ -69,18 +60,11 @@ class RerankModel(AIModel):
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        try:
            from core.plugin.impl.model import PluginModelClient

            plugin_model_manager = PluginModelClient()
            return plugin_model_manager.invoke_multimodal_rerank(
                tenant_id=self.tenant_id,
                user_id=user or "unknown",
                plugin_id=self.plugin_id,
                provider=self.provider_name,
            return self.model_runtime.invoke_multimodal_rerank(
                provider=self.provider,
                model=model,
                credentials=credentials,
                query=query,

@@ -1,7 +1,5 @@
from typing import IO

from pydantic import ConfigDict

from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -13,28 +11,18 @@ class Speech2TextModel(AIModel):

    model_type: ModelType = ModelType.SPEECH2TEXT

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
    def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
        """
        Invoke speech to text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        try:
            from core.plugin.impl.model import PluginModelClient

            plugin_model_manager = PluginModelClient()
            return plugin_model_manager.invoke_speech_to_text(
                tenant_id=self.tenant_id,
                user_id=user or "unknown",
                plugin_id=self.plugin_id,
                provider=self.provider_name,
            return self.model_runtime.invoke_speech_to_text(
                provider=self.provider,
                model=model,
                credentials=credentials,
                file=file,

@@ -1,8 +1,5 @@
from pydantic import ConfigDict

from core.entities.embedding_type import EmbeddingInputType
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel


@@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel):

    model_type: ModelType = ModelType.TEXT_EMBEDDING

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str] | None = None,
        multimodel_documents: list[dict] | None = None,
        user: str | None = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> EmbeddingResult:
        """
@@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel):
        :param credentials: model credentials
        :param texts: texts to embed
        :param files: files to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
        from core.plugin.impl.model import PluginModelClient

        try:
            plugin_model_manager = PluginModelClient()
            if texts:
                return plugin_model_manager.invoke_text_embedding(
                    tenant_id=self.tenant_id,
                    user_id=user or "unknown",
                    plugin_id=self.plugin_id,
                    provider=self.provider_name,
                return self.model_runtime.invoke_text_embedding(
                    provider=self.provider,
                    model=model,
                    credentials=credentials,
                    texts=texts,
                    input_type=input_type,
                )
            if multimodel_documents:
                return plugin_model_manager.invoke_multimodal_embedding(
                    tenant_id=self.tenant_id,
                    user_id=user or "unknown",
                    plugin_id=self.plugin_id,
                    provider=self.provider_name,
                return self.model_runtime.invoke_multimodal_embedding(
                    provider=self.provider,
                    model=model,
                    credentials=credentials,
                    documents=multimodel_documents,
@@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel):
        :param texts: texts to embed
        :return:
        """
        from core.plugin.impl.model import PluginModelClient

        plugin_model_manager = PluginModelClient()
        return plugin_model_manager.get_text_embedding_num_tokens(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=self.plugin_id,
            provider=self.provider_name,
        return self.model_runtime.get_text_embedding_num_tokens(
            provider=self.provider,
            model=model,
            credentials=credentials,
            texts=texts,

@@ -1,8 +1,6 @@
import logging
from collections.abc import Iterable

from pydantic import ConfigDict

from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -16,38 +14,25 @@ class TTSModel(AIModel):

    model_type: ModelType = ModelType.TTS

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: str | None = None,
    ) -> Iterable[bytes]:
        """
        Invoke large language model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param user: unique user id
        :return: translated audio file
        """
        try:
            from core.plugin.impl.model import PluginModelClient

            plugin_model_manager = PluginModelClient()
            return plugin_model_manager.invoke_tts(
                tenant_id=self.tenant_id,
                user_id=user or "unknown",
                plugin_id=self.plugin_id,
                provider=self.provider_name,
            return self.model_runtime.invoke_tts(
                provider=self.provider,
                model=model,
                credentials=credentials,
                content_text=content_text,
@@ -65,14 +50,8 @@ class TTSModel(AIModel):
        :param credentials: The credentials required to access the TTS model.
        :return: A list of voices supported by the TTS model.
        """
        from core.plugin.impl.model import PluginModelClient

        plugin_model_manager = PluginModelClient()
        return plugin_model_manager.get_tts_model_voices(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=self.plugin_id,
            provider=self.provider_name,
        return self.model_runtime.get_tts_model_voices(
            provider=self.provider,
            model=model,
            credentials=credentials,
            language=language,

@@ -1,16 +1,7 @@
from __future__ import annotations

import hashlib
import logging
from collections.abc import Sequence
from threading import Lock

from pydantic import ValidationError
from redis import RedisError

import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
from dify_graph.model_runtime.runtime import ModelRuntime
from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import (
    ProviderCredentialSchemaValidator,
)
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)


class ModelProviderFactory:
    def __init__(self, tenant_id: str):
        from core.plugin.impl.model import PluginModelClient
    """Factory for provider schemas and model-type instances backed by a runtime adapter."""

        self.tenant_id = tenant_id
        self.plugin_model_manager = PluginModelClient()
    def __init__(self, model_runtime: ModelRuntime):
        if model_runtime is None:
            raise ValueError("model_runtime is required.")
        self.model_runtime = model_runtime

    def get_providers(self) -> Sequence[ProviderEntity]:
        """
        Get all providers
        :return: list of providers
        Get all providers.
        """
        # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
        # The plugin server should return providers in the desired order
        plugin_providers = self.get_plugin_model_providers()
        return [provider.declaration for provider in plugin_providers]
        return list(self.get_model_providers())

    def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
    def get_model_providers(self) -> Sequence[ProviderEntity]:
        """
        Get all plugin model providers
        :return: list of plugin model providers
        Get all model providers exposed by the runtime adapter.
        """
        # check if context is set
        try:
            contexts.plugin_model_providers.get()
        except LookupError:
            contexts.plugin_model_providers.set(None)
            contexts.plugin_model_providers_lock.set(Lock())

        with contexts.plugin_model_providers_lock.get():
            plugin_model_providers = contexts.plugin_model_providers.get()
            if plugin_model_providers is not None:
                return plugin_model_providers

            plugin_model_providers = []
            contexts.plugin_model_providers.set(plugin_model_providers)

            # Fetch plugin model providers
            plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)

            for provider in plugin_providers:
                provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
                plugin_model_providers.append(provider)

            return plugin_model_providers
        return self.model_runtime.fetch_model_providers()

    def get_provider_schema(self, provider: str) -> ProviderEntity:
        """
        Get provider schema
        :param provider: provider name
        :return: provider schema
        Get provider schema.
        """
        plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
        return plugin_model_provider_entity.declaration
        return self.get_model_provider(provider=provider)

    def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
    def get_model_provider(self, provider: str) -> ProviderEntity:
        """
        Get plugin model provider
        :param provider: provider name
        :return: provider schema
        Get provider schema.
        """
        if "/" not in provider:
            provider = str(ModelProviderID(provider))

        # fetch plugin model providers
        plugin_model_provider_entities = self.get_plugin_model_providers()

        # get the provider
        plugin_model_provider_entity = next(
            (p for p in plugin_model_provider_entities if p.declaration.provider == provider),
            None,
        )

        if not plugin_model_provider_entity:
        provider_entity = self._resolve_provider(provider)
        if provider_entity is None:
            raise ValueError(f"Invalid provider: {provider}")

        return plugin_model_provider_entity
        return provider_entity

    def provider_credentials_validate(self, *, provider: str, credentials: dict):
        """
        Validate provider credentials

        :param provider: provider name
        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        :return:
        Validate provider credentials.
        """
        # fetch plugin model provider
        plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
        provider_entity = self.get_model_provider(provider=provider)

        # get provider_credential_schema and validate credentials according to the rules
        provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema
        provider_credential_schema = provider_entity.provider_credential_schema
        if not provider_credential_schema:
            raise ValueError(f"Provider {provider} does not have provider_credential_schema")

        # validate provider credential schema
        validator = ProviderCredentialSchemaValidator(provider_credential_schema)
        filtered_credentials = validator.validate_and_filter(credentials)

        # validate the credentials, raise exception if validation failed
        self.plugin_model_manager.validate_provider_credentials(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=plugin_model_provider_entity.plugin_id,
            provider=plugin_model_provider_entity.provider,
        self.model_runtime.validate_provider_credentials(
            provider=provider_entity.provider,
            credentials=filtered_credentials,
        )

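The validate-then-delegate flow above first filters credentials against the declared schema before handing them across the runtime boundary. A self-contained sketch of that two-step shape (the `KeyFilterValidator` is invented for illustration; the real schema validators are richer than a key filter):

    # Invented minimal validator illustrating the filter-then-delegate flow.
    class KeyFilterValidator:
        def __init__(self, allowed_keys: set[str]):
            self._allowed_keys = allowed_keys

        def validate_and_filter(self, credentials: dict) -> dict:
            missing = self._allowed_keys - credentials.keys()
            if missing:
                raise ValueError(f"missing credentials: {sorted(missing)}")
            # drop anything the schema does not declare before it crosses the boundary
            return {k: v for k, v in credentials.items() if k in self._allowed_keys}

    validator = KeyFilterValidator({"api_key"})
    filtered = validator.validate_and_filter({"api_key": "k", "debug": True})
    assert filtered == {"api_key": "k"}
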
@@ -141,33 +76,20 @@ class ModelProviderFactory:

    def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
        """
        Validate model credentials

        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: model credentials, credentials form defined in `model_credential_schema`.
        :return:
        Validate model credentials.
        """
        # fetch plugin model provider
        plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
        provider_entity = self.get_model_provider(provider=provider)

        # get model_credential_schema and validate credentials according to the rules
        model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema
        model_credential_schema = provider_entity.model_credential_schema
        if not model_credential_schema:
            raise ValueError(f"Provider {provider} does not have model_credential_schema")

        # validate model credential schema
        validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
        filtered_credentials = validator.validate_and_filter(credentials)

        # call validate_credentials method of model type to validate credentials, raise exception if validation failed
        self.plugin_model_manager.validate_model_credentials(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=plugin_model_provider_entity.plugin_id,
            provider=plugin_model_provider_entity.provider,
            model_type=model_type.value,
        self.model_runtime.validate_model_credentials(
            provider=provider_entity.provider,
            model_type=model_type,
            model=model,
            credentials=filtered_credentials,
        )
@@ -178,65 +100,16 @@ class ModelProviderFactory:
        self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
    ) -> AIModelEntity | None:
        """
        Get model schema
        Get model schema.
        """
        plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
        cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
        sorted_credentials = sorted(credentials.items()) if credentials else []
        cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

        cached_schema_json = None
        try:
            cached_schema_json = redis_client.get(cache_key)
        except (RedisError, RuntimeError) as exc:
            logger.warning(
                "Failed to read plugin model schema cache for model %s: %s",
                model,
                str(exc),
                exc_info=True,
            )
        if cached_schema_json:
            try:
                return AIModelEntity.model_validate_json(cached_schema_json)
            except ValidationError:
                logger.warning(
                    "Failed to validate cached plugin model schema for model %s",
                    model,
                    exc_info=True,
                )
                try:
                    redis_client.delete(cache_key)
                except (RedisError, RuntimeError) as exc:
                    logger.warning(
                        "Failed to delete invalid plugin model schema cache for model %s: %s",
                        model,
                        str(exc),
                        exc_info=True,
                    )

        schema = self.plugin_model_manager.get_model_schema(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
        provider_entity = self.get_model_provider(provider)
        return self.model_runtime.get_model_schema(
            provider=provider_entity.provider,
            model_type=model_type,
            model=model,
            credentials=credentials or {},
        )

        if schema:
            try:
                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
            except (RedisError, RuntimeError) as exc:
                logger.warning(
                    "Failed to write plugin model schema cache for model %s: %s",
                    model,
                    str(exc),
                    exc_info=True,
                )

        return schema

    def get_models(
        self,
        *,
@@ -245,143 +118,56 @@ class ModelProviderFactory:
        provider_configs: list[ProviderConfig] | None = None,
    ) -> list[SimpleProviderEntity]:
        """
        Get all models for given model type

        :param provider: provider name
        :param model_type: model type
        :param provider_configs: list of provider configs
        :return: list of models
        Get all models for given model type.
        """
        provider_configs = provider_configs or []

        # scan all providers
        plugin_model_provider_entities = self.get_plugin_model_providers()

        # traverse all model_provider_extensions
        providers = []
        for plugin_model_provider_entity in plugin_model_provider_entities:
            # filter by provider if provider is present
            if provider and plugin_model_provider_entity.declaration.provider != provider:
        for provider_entity in self.get_model_providers():
            if provider and not self._matches_provider(provider_entity, provider):
                continue

            # get provider schema
            provider_schema = plugin_model_provider_entity.declaration

            model_types = provider_schema.supported_model_types
            if model_type:
                if model_type not in model_types:
                    continue

                model_types = [model_type]

            all_model_type_models = []
            for model_schema in provider_schema.models:
                if model_schema.model_type != model_type:
                    continue

                all_model_type_models.append(model_schema)

            simple_provider_schema = provider_schema.to_simple_provider()
            if model_type:
                simple_provider_schema.models = all_model_type_models
            if model_type and model_type not in provider_entity.supported_model_types:
                continue

            simple_provider_schema = provider_entity.to_simple_provider()
            if model_type is not None:
                simple_provider_schema.models = [
                    model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type
                ]
            providers.append(simple_provider_schema)

        return providers

    def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
        """
        Get model type instance by provider name and model type
        :param provider: provider name
        :param model_type: model type
        :return: model type instance
        Get model type instance by provider name and model type.
        """
        plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
        init_params = {
            "tenant_id": self.tenant_id,
            "plugin_id": plugin_id,
            "provider_name": provider_name,
            "plugin_model_provider": self.get_plugin_model_provider(provider),
        }
        provider_schema = self.get_model_provider(provider)

        if model_type == ModelType.LLM:
            return LargeLanguageModel.model_validate(init_params)
        elif model_type == ModelType.TEXT_EMBEDDING:
            return TextEmbeddingModel.model_validate(init_params)
        elif model_type == ModelType.RERANK:
            return RerankModel.model_validate(init_params)
        elif model_type == ModelType.SPEECH2TEXT:
            return Speech2TextModel.model_validate(init_params)
        elif model_type == ModelType.MODERATION:
            return ModerationModel.model_validate(init_params)
        elif model_type == ModelType.TTS:
            return TTSModel.model_validate(init_params)
            return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
        if model_type == ModelType.TEXT_EMBEDDING:
            return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
        if model_type == ModelType.RERANK:
            return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
        if model_type == ModelType.SPEECH2TEXT:
            return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
        if model_type == ModelType.MODERATION:
            return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime)
        if model_type == ModelType.TTS:
            return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime)

        raise ValueError(f"Unsupported model type: {model_type}")

    def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
        """
        Get provider icon
        :param provider: provider name
        :param icon_type: icon type (icon_small or icon_small_dark)
        :param lang: language (zh_Hans or en_US)
        :return: provider icon
        Get provider icon.
        """
        # get the provider schema
        provider_schema = self.get_provider_schema(provider)
        provider_entity = self.get_model_provider(provider)
        return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang)

        if icon_type.lower() == "icon_small":
            if not provider_schema.icon_small:
                raise ValueError(f"Provider {provider} does not have small icon.")
    def _resolve_provider(self, provider: str) -> ProviderEntity | None:
        return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None)

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_small.zh_Hans
            else:
                file_name = provider_schema.icon_small.en_US
        elif icon_type.lower() == "icon_small_dark":
            if not provider_schema.icon_small_dark:
                raise ValueError(f"Provider {provider} does not have small dark icon.")

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_small_dark.zh_Hans
            else:
                file_name = provider_schema.icon_small_dark.en_US
        else:
            raise ValueError(f"Unsupported icon type: {icon_type}.")

        if not file_name:
            raise ValueError(f"Provider {provider} does not have icon.")

        image_mime_types = {
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "png": "image/png",
            "gif": "image/gif",
            "bmp": "image/bmp",
            "tiff": "image/tiff",
            "tif": "image/tiff",
            "webp": "image/webp",
            "svg": "image/svg+xml",
            "ico": "image/vnd.microsoft.icon",
            "heif": "image/heif",
            "heic": "image/heic",
        }

        extension = file_name.split(".")[-1]
        mime_type = image_mime_types.get(extension, "image/png")

        # get icon bytes from plugin asset manager
        from core.plugin.impl.asset import PluginAssetManager

        plugin_asset_manager = PluginAssetManager()
        return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

    def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]:
        """
        Get plugin id and provider name from provider name
        :param provider: provider name
        :return: plugin id and provider name
        """

        provider_id = ModelProviderID(provider)
        return provider_id.plugin_id, provider_id.provider_name
    @staticmethod
    def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool:
        return provider in (provider_entity.provider, provider_entity.provider_name)

api/dify_graph/model_runtime/runtime.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from __future__ import annotations

from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Protocol, Union, runtime_checkable

from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult


@runtime_checkable
class ModelRuntime(Protocol):
    """Port for provider discovery, schema lookup, and model execution.

    `provider` is the model runtime's canonical provider identifier. Adapters may
    derive transport-specific details from it, but those details stay outside
    this boundary.
    """

    def fetch_model_providers(self) -> Sequence[ProviderEntity]: ...

    def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ...

    def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ...

    def validate_model_credentials(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> None: ...

    def get_model_schema(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> AIModelEntity | None: ...

    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ...

    def get_llm_num_tokens(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: Sequence[PromptMessageTool] | None,
    ) -> int: ...

    def invoke_text_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult: ...

    def invoke_multimodal_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        documents: list[dict[str, Any]],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult: ...

    def get_text_embedding_num_tokens(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
    ) -> list[int]: ...

    def invoke_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: str,
        docs: list[str],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult: ...

    def invoke_multimodal_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: MultimodalRerankInput,
        docs: list[MultimodalRerankInput],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult: ...

    def invoke_tts(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        content_text: str,
        voice: str,
    ) -> Iterable[bytes]: ...

    def get_tts_model_voices(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        language: str | None,
    ) -> Any: ...

    def invoke_speech_to_text(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        file: IO[bytes],
    ) -> str: ...

    def invoke_moderation(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        text: str,
    ) -> bool: ...

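Since `ModelRuntime` is declared `@runtime_checkable`, any object that structurally exposes these methods can back `ModelProviderFactory(model_runtime=...)` — no inheritance required. A self-contained sketch of the pattern with a trimmed-down protocol (the `MiniRuntime` and `EchoRuntime` names are invented for illustration):

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class MiniRuntime(Protocol):
        def invoke_moderation(self, *, provider: str, model: str,
                              credentials: dict, text: str) -> bool: ...

    class EchoRuntime:
        """A toy adapter; a real one would delegate to a plugin daemon or local SDK."""
        def invoke_moderation(self, *, provider: str, model: str,
                              credentials: dict, text: str) -> bool:
            return "blocked" in text  # toy policy: flag text containing "blocked"

    runtime = EchoRuntime()
    assert isinstance(runtime, MiniRuntime)  # structural check, no subclassing needed
    assert runtime.invoke_moderation(provider="p", model="m",
                                     credentials={}, text="blocked words") is True

Note that `runtime_checkable` only verifies that the member names exist at `isinstance` time; it does not check signatures, which is presumably why the factory also fails fast on a `None` runtime.
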
@@ -6,13 +6,12 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4

from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
    ErrorStrategy,
    NodeExecutionType,
@@ -60,7 +59,7 @@ from dify_graph.node_events import (
    StreamCompletedEvent,
)
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from dify_graph.utils.datetime_utils import naive_utc_now

NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@@ -68,23 +67,6 @@ _MISSING_RUN_CONTEXT_VALUE = object()
logger = logging.getLogger(__name__)


class DifyRunContextProtocol(Protocol):
    tenant_id: str
    app_id: str
    user_id: str
    user_from: Any
    invoke_from: Any


class _MappingDifyRunContext:
    def __init__(self, mapping: Mapping[str, Any]) -> None:
        self.tenant_id = str(mapping["tenant_id"])
        self.app_id = str(mapping["app_id"])
        self.user_id = str(mapping["user_id"])
        self.user_from = mapping["user_from"]
        self.invoke_from = mapping["invoke_from"]


class Node(Generic[NodeDataT]):
    """BaseNode serves as the foundational class for all node implementations.

@@ -177,8 +159,9 @@ class Node(Generic[NodeDataT]):
        # Skip base class itself
        if cls is Node:
            return
        # Only register production node implementations defined under the
        # canonical workflow namespaces.
        # Only treat nodes from the base dify_graph package as production
        # registrations. Higher-layer packages may still register subclasses,
        # but dify_graph itself should not know their module identities.
        # This prevents test helper subclasses from polluting the global registry and
        # accidentally overriding real node types (e.g., a test Answer node).
        module_name = getattr(cls, "__module__", "")
@@ -186,7 +169,7 @@ class Node(Generic[NodeDataT]):
        node_type = cls.node_type
        version = cls.version()
        bucket = Node._registry.setdefault(node_type, {})
        if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
        if module_name.startswith("dify_graph.nodes."):
            # Production node definitions take precedence and may override
            bucket[version] = cls  # type: ignore[index]
        else:
@@ -299,25 +282,6 @@ class Node(Generic[NodeDataT]):
            raise ValueError(f"run_context missing required key: {key}")
        return value

    def require_dify_context(self) -> DifyRunContextProtocol:
        raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
        if raw_ctx is None:
            raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")

        if isinstance(raw_ctx, Mapping):
            missing_keys = [
                key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
            ]
            if missing_keys:
                raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
            return _MappingDifyRunContext(raw_ctx)

        for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
            if not hasattr(raw_ctx, attr):
                raise TypeError(f"invalid dify context object, missing attribute: {attr}")

        return cast(DifyRunContextProtocol, raw_ctx)

    @property
    def execution_id(self) -> str:
        return self._node_execution_id
@@ -793,16 +757,11 @@ class Node(Generic[NodeDataT]):

    @_dispatch.register
    def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
        from core.rag.entities.citation_metadata import RetrievalSourceMetadata

        retriever_resources = [
            RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
        ]
        return NodeRunRetrieverResourceEvent(
            id=self.execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            retriever_resources=retriever_resources,
            retriever_resources=event.retriever_resources,
            context=event.context,
            node_version=self.version(),
        )

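The registry change above narrows which subclasses count as production registrations by inspecting `__module__`. A self-contained sketch of that gating (the mini-registry and node classes here are invented for illustration; the real registry keys on node type and version):

    # Invented mini-registry illustrating module-prefix gating of subclass registration.
    _REGISTRY: dict[str, type] = {}

    class Base:
        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            module_name = getattr(cls, "__module__", "")
            if module_name.startswith("dify_graph.nodes."):
                _REGISTRY[cls.__name__] = cls  # production definitions may override
            else:
                _REGISTRY.setdefault(cls.__name__, cls)  # helpers never clobber real nodes

    class TestHelper(Base):  # defined outside dify_graph.nodes, so it only fills empty slots
        pass

    assert _REGISTRY["TestHelper"] is TestHelper
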
@@ -11,9 +11,13 @@ from dify_graph.nodes.base import variable_template_parser
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.http_request.executor import Executor
from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.protocols import (
    FileManagerProtocol,
    FileReferenceFactoryProtocol,
    HttpClientProtocol,
    ToolFileManagerProtocol,
)
from dify_graph.variables.segments import ArrayFileSegment
from factories import file_factory

from .config import build_http_request_config, resolve_http_request_config
from .entities import (
@@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
        http_client: HttpClientProtocol,
        tool_file_manager_factory: Callable[[], ToolFileManagerProtocol],
        file_manager: FileManagerProtocol,
        file_reference_factory: FileReferenceFactoryProtocol,
    ) -> None:
        super().__init__(
            id=id,
@@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
        self._http_client = http_client
        self._tool_file_manager_factory = tool_file_manager_factory
        self._file_manager = file_manager
        self._file_reference_factory = file_reference_factory

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
        """
        Extract files from response by checking both Content-Type header and URL
        """
        dify_ctx = self.require_dify_context()
        files: list[File] = []
        is_file = response.is_file
        content_type = response.content_type
@@ -237,20 +242,16 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
        tool_file_manager = self._tool_file_manager_factory()

        tool_file = tool_file_manager.create_file_by_raw(
            user_id=dify_ctx.user_id,
            tenant_id=dify_ctx.tenant_id,
            conversation_id=None,
            file_binary=content,
            mimetype=mime_type,
        )

        mapping = {
            "tool_file_id": tool_file.id,
            "transfer_method": FileTransferMethod.TOOL_FILE,
        }
        file = file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=dify_ctx.tenant_id,
        file = self._file_reference_factory.build_from_mapping(
            mapping={
                "tool_file_id": tool_file.id,
                "transfer_method": FileTransferMethod.TOOL_FILE,
            }
        )
        files.append(file)

@@ -15,15 +15,16 @@ from dify_graph.node_events import (
from dify_graph.node_events.base import NodeEventBase
from dify_graph.node_events.node import StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.runtime import HumanInputNodeRuntimeProtocol
from dify_graph.repositories.human_input_form_repository import (
    FormCreateParams,
    HumanInputFormEntity,
    HumanInputFormRepository,
)
from dify_graph.utils.datetime_utils import naive_utc_now
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now

from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .entities import DeliveryChannelConfig, HumanInputNodeData
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType

if TYPE_CHECKING:
@@ -68,6 +69,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        form_repository: HumanInputFormRepository,
        runtime: HumanInputNodeRuntimeProtocol | None = None,
    ) -> None:
        super().__init__(
            id=id,
@@ -76,6 +78,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
            graph_runtime_state=graph_runtime_state,
        )
        self._form_repository = form_repository
        if runtime is None:
            raise ValueError("runtime is required")
        self._runtime = runtime

    @classmethod
    def version(cls) -> str:
@@ -171,25 +176,14 @@ class HumanInputNode(Node[HumanInputNodeData]):
        return self._node_data.is_webapp_enabled()

    def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
        dify_ctx = self.require_dify_context()
        invoke_from = self._invoke_from_value()
        enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
        if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
            enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
        return [
            apply_debug_email_recipient(
                method,
                enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
                user_id=dify_ctx.user_id,
            )
            for method in enabled_methods
        ]
        return self._runtime.apply_delivery_runtime(methods=enabled_methods)

    def _invoke_from_value(self) -> str:
        invoke_from = self.require_dify_context().invoke_from
        if isinstance(invoke_from, str):
            return invoke_from
        return str(getattr(invoke_from, "value", invoke_from))
        return self._runtime.invoke_source()

    def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
        node_data = self._node_data
@@ -224,11 +218,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
        """
        repo = self._form_repository
        form = repo.get_form(self._workflow_execution_id, self.id)
        dify_ctx = self.require_dify_context()
        if form is None:
            display_in_ui = self._display_in_ui()
            params = FormCreateParams(
                app_id=dify_ctx.app_id,
                workflow_execution_id=self._workflow_execution_id,
                node_id=self.id,
                form_config=self._node_data,
@@ -238,7 +230,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
                resolved_default_values=self.resolve_default_values(),
                console_recipient_required=self._should_require_console_recipient(),
                console_creator_account_id=(
                    dify_ctx.user_id
                    self._runtime.console_actor_id()
                    if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
                    else None
                ),

@@ -34,10 +34,10 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.runtime import VariablePool
from dify_graph.utils.datetime_utils import naive_utc_now
from dify_graph.variables import IntegerVariable, NoneSegment
from dify_graph.variables.segments import ArrayAnySegment, ArraySegment
from dify_graph.variables.variables import Variable
from libs.datetime_utils import naive_utc_now

from .exc import (
    InvalidIteratorValueError,

@@ -3,11 +3,11 @@ from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig


class ModelConfig(BaseModel):

@@ -2,10 +2,8 @@ import mimetypes
import typing as tp

from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol


class LLMFileSaver(tp.Protocol):
@@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol):


class FileSaverImpl(LLMFileSaver):
    _tenant_id: str
    _user_id: str
    _tool_file_manager: ToolFileManagerProtocol
    _file_reference_factory: FileReferenceFactoryProtocol

    def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
        self._user_id = user_id
        self._tenant_id = tenant_id
    def __init__(
        self,
        *,
        tool_file_manager: ToolFileManagerProtocol,
        file_reference_factory: FileReferenceFactoryProtocol,
        http_client: HttpClientProtocol,
    ):
        self._tool_file_manager = tool_file_manager
        self._file_reference_factory = file_reference_factory
        self._http_client = http_client

    def _get_tool_file_manager(self):
        return ToolFileManager()

    def save_remote_url(self, url: str, file_type: FileType) -> File:
        http_response = self._http_client.get(url)
        http_response.raise_for_status()
@@ -83,10 +84,7 @@ class FileSaverImpl(LLMFileSaver):
        file_type: FileType,
        extension_override: str | None = None,
    ) -> File:
        tool_file_manager = self._get_tool_file_manager()
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=self._user_id,
            tenant_id=self._tenant_id,
        tool_file = self._tool_file_manager.create_file_by_raw(
            # TODO(QuantumGhost): what is conversation id?
            conversation_id=None,
            file_binary=data,
@@ -94,19 +92,18 @@ class FileSaverImpl(LLMFileSaver):
        )
        extension_override = _validate_extension_override(extension_override)
        extension = _get_extension(mime_type, extension_override)
        url = sign_tool_file(tool_file.id, extension)

        return File(
            tenant_id=self._tenant_id,
            type=file_type,
            transfer_method=FileTransferMethod.TOOL_FILE,
            filename=tool_file.name,
            extension=extension,
            mime_type=mime_type,
            size=len(data),
            related_id=tool_file.id,
            url=url,
            storage_key=tool_file.file_key,
        return self._file_reference_factory.build_from_mapping(
            mapping={
                "type": file_type,
                "transfer_method": FileTransferMethod.TOOL_FILE,
                "filename": tool_file.name,
                "extension": extension,
                "mime_type": mime_type,
                "size": len(data),
                "tool_file_id": tool_file.id,
                "related_id": tool_file.id,
                "storage_key": tool_file.file_key,
            }
        )

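`FileSaverImpl` above now receives its collaborators through constructor injection instead of constructing `ToolFileManager` itself. A self-contained sketch of that shape with invented test doubles (the `Fake*` classes and `Saver` are illustrative, not Dify types):

    # Self-contained sketch of the constructor-injection pattern FileSaverImpl uses.
    from dataclasses import dataclass

    @dataclass
    class FakeToolFile:
        id: str
        name: str
        file_key: str

    class FakeToolFileManager:
        def create_file_by_raw(self, *, conversation_id, file_binary, mimetype):
            return FakeToolFile(id="tf-1", name="out.png", file_key="storage/tf-1")

    class FakeFileReferenceFactory:
        def build_from_mapping(self, *, mapping):
            return mapping  # a real factory would return a File entity

    class Saver:
        def __init__(self, *, tool_file_manager, file_reference_factory):
            self._tool_file_manager = tool_file_manager
            self._file_reference_factory = file_reference_factory

        def save_binary(self, data: bytes, mime_type: str):
            tool_file = self._tool_file_manager.create_file_by_raw(
                conversation_id=None, file_binary=data, mimetype=mime_type)
            return self._file_reference_factory.build_from_mapping(
                mapping={"tool_file_id": tool_file.id, "size": len(data)})

    saver = Saver(tool_file_manager=FakeToolFileManager(),
                  file_reference_factory=FakeFileReferenceFactory())
    assert saver.save_binary(b"\x89PNG", "image/png")["tool_file_id"] == "tf-1"
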
@@ -1,9 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeAlias, cast

from core.model_manager import ModelInstance
from dify_graph.file import FileType, file_manager
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import (
@@ -35,15 +34,36 @@ from .exc import (
    TemplateTypeNotSupportError,
)
from .protocols import TemplateRenderer
from .runtime_protocols import PreparedLLMProtocol


def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
    model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
        model_instance.model_name,
        dict(model_instance.credentials),
    )
class _LegacyModelInstance(Protocol):
    model_type_instance: object
    model_name: str
    credentials: object
    parameters: Mapping[str, Any]

    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ...


PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance


def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity:
    get_model_schema = getattr(model_instance, "get_model_schema", None)
    if callable(get_model_schema):
        model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema()
    else:
        legacy_model_instance = cast(_LegacyModelInstance, model_instance)
        credentials = legacy_model_instance.credentials
        if isinstance(credentials, Mapping):
            credentials = dict(credentials)
        model_schema = cast(LargeLanguageModel, legacy_model_instance.model_type_instance).get_model_schema(
            legacy_model_instance.model_name,
            credentials,
        )
    if not model_schema:
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
        raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}")
    return model_schema


@@ -116,7 +136,7 @@ def fetch_prompt_messages(
    sys_files: Sequence[File],
    context: str | None = None,
    memory: PromptMessageMemory | None = None,
    model_instance: ModelInstance,
    model_instance: PreparedModelInstance,
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
    stop: Sequence[str] | None = None,
    memory_config: MemoryConfig | None = None,
@@ -391,7 +411,7 @@ def combine_message_content_with_role(
    raise NotImplementedError(f"Role {role} is not supported")


def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: PreparedModelInstance) -> int:
    rest_tokens = 2000
    runtime_model_schema = fetch_model_schema(model_instance=model_instance)
    runtime_model_parameters = model_instance.parameters
@@ -421,7 +441,7 @@ def handle_memory_chat_mode(
    *,
    memory: PromptMessageMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_instance: PreparedModelInstance,
) -> Sequence[PromptMessage]:
    if not memory or not memory_config:
        return []
@@ -436,7 +456,7 @@ def handle_memory_completion_mode(
    *,
    memory: PromptMessageMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_instance: PreparedModelInstance,
) -> str:
    if not memory or not memory_config:
        return ""

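The reworked `fetch_model_schema` above dispatches structurally: it prefers the new prepared instance's zero-argument `get_model_schema()` and falls back to the legacy `model_type_instance` call. A self-contained sketch of that getattr-based dispatch (the `Prepared`/`Legacy` classes are invented stand-ins):

    # Invented stand-ins illustrating the prepared-vs-legacy dispatch.
    class Prepared:
        model_name = "gpt-x"
        def get_model_schema(self):
            return {"model": self.model_name, "source": "prepared"}

    class Legacy:
        model_name = "gpt-x"
        credentials = {"api_key": "k"}

        class _TypeInstance:
            @staticmethod
            def get_model_schema(model_name, credentials):
                return {"model": model_name, "source": "legacy"}

        model_type_instance = _TypeInstance()

    def fetch_schema(instance):
        getter = getattr(instance, "get_model_schema", None)
        if callable(getter):
            return getter()  # new prepared path
        return instance.model_type_instance.get_model_schema(  # legacy fallback
            instance.model_name, dict(instance.credentials))

    assert fetch_schema(Prepared())["source"] == "prepared"
    assert fetch_schema(Legacy())["source"] == "legacy"
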
@@ -7,16 +7,8 @@ import logging
 import re
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
 
-from sqlalchemy import select
-
-from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.model_manager import ModelInstance
-from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
-from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.tools.signature import sign_upload_file
 from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.graph_config import NodeConfigDict
@@ -27,10 +19,11 @@ from dify_graph.enums import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from dify_graph.file import File, FileTransferMethod, FileType
+from dify_graph.file import File, FileType, file_manager
 from dify_graph.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    TextPromptMessageContent,
 )
 from dify_graph.model_runtime.entities.llm_entities import (
@@ -41,7 +34,14 @@ from dify_graph.model_runtime.entities.llm_entities import (
     LLMStructuredOutput,
     LLMUsage,
 )
-from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+from dify_graph.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageContentUnionTypes,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
 from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import (
@@ -55,19 +55,23 @@ from dify_graph.node_events import (
 from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.llm.runtime_protocols import (
+    PreparedLLMProtocol,
+    PromptMessageSerializerProtocol,
+    RetrieverAttachmentLoaderProtocol,
+)
 from dify_graph.nodes.protocols import HttpClientProtocol
+from dify_graph.prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from dify_graph.runtime import VariablePool
 from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError
 from dify_graph.variables import (
     ArrayFileSegment,
     ArraySegment,
     FileSegment,
     NoneSegment,
     ObjectSegment,
     StringSegment,
 )
-from extensions.ext_database import db
-from models.dataset import SegmentAttachmentBinding
-from models.model import UploadFile
 
 from . import llm_utils
 from .entities import (
@@ -79,9 +83,12 @@ from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
     LLMNodeError,
+    MemoryRolePrefixRequiredError,
     NoPromptFoundError,
+    TemplateTypeNotSupportError,
     VariableNotFoundError,
 )
-from .file_saver import FileSaverImpl, LLMFileSaver
+from .file_saver import LLMFileSaver
 
 if TYPE_CHECKING:
     from dify_graph.file.models import File
@@ -101,11 +108,11 @@ class LLMNode(Node[LLMNodeData]):
     _file_outputs: list[File]
 
     _llm_file_saver: LLMFileSaver
-    _credentials_provider: CredentialsProvider
-    _model_factory: ModelFactory
-    _model_instance: ModelInstance
+    _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None
+    _prompt_message_serializer: PromptMessageSerializerProtocol
+    _jinja2_template_renderer: Jinja2TemplateRenderer | None
+    _model_instance: PreparedLLMProtocol
     _memory: PromptMessageMemory | None
     _template_renderer: TemplateRenderer
 
     def __init__(
         self,
@@ -114,13 +121,15 @@ class LLMNode(Node[LLMNodeData]):
         graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
         *,
-        credentials_provider: CredentialsProvider,
-        model_factory: ModelFactory,
-        model_instance: ModelInstance,
+        credentials_provider: object | None = None,
+        model_factory: object | None = None,
+        model_instance: PreparedLLMProtocol,
         http_client: HttpClientProtocol,
         template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
-        llm_file_saver: LLMFileSaver | None = None,
+        llm_file_saver: LLMFileSaver,
+        prompt_message_serializer: PromptMessageSerializerProtocol,
+        retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None,
+        jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
     ):
         super().__init__(
             id=id,
@@ -131,20 +140,14 @@ class LLMNode(Node[LLMNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
 
-        self._credentials_provider = credentials_provider
-        self._model_factory = model_factory
+        _ = credentials_provider, model_factory, http_client
         self._model_instance = model_instance
         self._memory = memory
         self._template_renderer = template_renderer
 
-        if llm_file_saver is None:
-            dify_ctx = self.require_dify_context()
-            llm_file_saver = FileSaverImpl(
-                user_id=dify_ctx.user_id,
-                tenant_id=dify_ctx.tenant_id,
-                http_client=http_client,
-            )
         self._llm_file_saver = llm_file_saver
+        self._prompt_message_serializer = prompt_message_serializer
+        self._retriever_attachment_loader = retriever_attachment_loader
+        self._jinja2_template_renderer = jinja2_template_renderer
 
     @classmethod
     def version(cls) -> str:
@@ -230,7 +233,7 @@ class LLMNode(Node[LLMNodeData]):
             variable_pool=variable_pool,
             jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             context_files=context_files,
-            template_renderer=self._template_renderer,
+            jinja2_template_renderer=self._jinja2_template_renderer,
         )
 
         # handle invoke result
@@ -238,7 +241,6 @@ class LLMNode(Node[LLMNodeData]):
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
-            user_id=self.require_dify_context().user_id,
             structured_output_enabled=self.node_data.structured_output_enabled,
             structured_output=self.node_data.structured_output,
             file_saver=self._llm_file_saver,
@@ -281,7 +283,7 @@ class LLMNode(Node[LLMNodeData]):
 
         process_data = {
             "model_mode": self.node_data.model.mode,
-            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            "prompts": self._prompt_message_serializer.serialize(
                 model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
             ),
             "usage": jsonable_encoder(usage),
@@ -349,10 +351,9 @@ class LLMNode(Node[LLMNodeData]):
     @staticmethod
     def invoke_llm(
         *,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         prompt_messages: Sequence[PromptMessage],
         stop: Sequence[str] | None = None,
-        user_id: str,
         structured_output_enabled: bool,
         structured_output: Mapping[str, Any] | None = None,
         file_saver: LLMFileSaver,
@@ -363,35 +364,28 @@ class LLMNode(Node[LLMNodeData]):
     ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
         model_parameters = model_instance.parameters
         invoke_model_parameters = dict(model_parameters)
 
-        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
-
         if structured_output_enabled:
             output_schema = LLMNode.fetch_structured_output_schema(
                 structured_output=structured_output or {},
             )
             request_start_time = time.perf_counter()
 
-            invoke_result = invoke_llm_with_structured_output(
-                provider=model_instance.provider,
-                model_schema=model_schema,
-                model_instance=model_instance,
+            invoke_result = model_instance.invoke_llm_with_structured_output(
                 prompt_messages=prompt_messages,
                 json_schema=output_schema,
                 model_parameters=invoke_model_parameters,
-                stop=list(stop or []),
+                stop=stop,
                 stream=True,
-                user=user_id,
             )
         else:
             request_start_time = time.perf_counter()
 
             invoke_result = model_instance.invoke_llm(
-                prompt_messages=list(prompt_messages),
+                prompt_messages=prompt_messages,
                 model_parameters=invoke_model_parameters,
-                stop=list(stop or []),
+                tools=None,
+                stop=stop,
                 stream=True,
-                user=user_id,
             )
 
         return LLMNode.handle_invoke_result(
@@ -400,6 +394,7 @@ class LLMNode(Node[LLMNodeData]):
             file_outputs=file_outputs,
             node_id=node_id,
             node_type=node_type,
             model_instance=model_instance,
             reasoning_format=reasoning_format,
+            request_start_time=request_start_time,
         )
@@ -412,6 +407,7 @@ class LLMNode(Node[LLMNodeData]):
         file_outputs: list[File],
         node_id: str,
         node_type: NodeType,
         model_instance: PreparedLLMProtocol | object,
         reasoning_format: Literal["separated", "tagged"] = "tagged",
+        request_start_time: float | None = None,
     ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
@@ -483,8 +479,14 @@ class LLMNode(Node[LLMNodeData]):
                     usage = result.delta.usage
                 if finish_reason is None and result.delta.finish_reason:
                     finish_reason = result.delta.finish_reason
-        except OutputParserError as e:
-            raise LLMNodeError(f"Failed to parse structured output: {e}")
+        except Exception as e:
+            if hasattr(model_instance, "is_structured_output_parse_error") and cast(
+                PreparedLLMProtocol, model_instance
+            ).is_structured_output_parse_error(e):
+                raise LLMNodeError(f"Failed to parse structured output: {e}") from e
+            if type(e).__name__ == "OutputParserError":
+                raise LLMNodeError(f"Failed to parse structured output: {e}") from e
+            raise
 
         # Extract reasoning content from <think> tags in the main text
         full_text = full_text_buffer.getvalue()
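
Review note: the rewritten except-block above drops the direct import of core's OutputParserError. It first defers to the prepared runtime's is_structured_output_parse_error hook, then falls back to matching the exception class by name, so dify_graph can classify parse failures without depending on core. A standalone sketch of the same check (the helper name is illustrative, not part of this commit):

    # Sketch: classify a structured-output parse failure without importing
    # the concrete error class from core. `runtime` is any object that may
    # satisfy PreparedLLMProtocol; `is_parse_error` is a hypothetical helper.
    def is_parse_error(runtime: object, e: Exception) -> bool:
        checker = getattr(runtime, "is_structured_output_parse_error", None)
        if callable(checker) and checker(e):
            return True
        # Fallback keeps recognizing core's OutputParserError by name only.
        return type(e).__name__ == "OutputParserError"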
@@ -687,30 +689,8 @@ class LLMNode(Node[LLMNodeData]):
                 segment_id = retriever_resource.get("segment_id")
                 if not segment_id:
                     continue
-                attachments_with_bindings = db.session.execute(
-                    select(SegmentAttachmentBinding, UploadFile)
-                    .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
-                    .where(
-                        SegmentAttachmentBinding.segment_id == segment_id,
-                    )
-                ).all()
-                if attachments_with_bindings:
-                    for _, upload_file in attachments_with_bindings:
-                        attachment_info = File(
-                            id=upload_file.id,
-                            filename=upload_file.name,
-                            extension="." + upload_file.extension,
-                            mime_type=upload_file.mime_type,
-                            tenant_id=self.require_dify_context().tenant_id,
-                            type=FileType.IMAGE,
-                            transfer_method=FileTransferMethod.LOCAL_FILE,
-                            remote_url=upload_file.source_url,
-                            related_id=upload_file.id,
-                            size=upload_file.size,
-                            storage_key=upload_file.key,
-                            url=sign_upload_file(upload_file.id, upload_file.extension),
-                        )
-                        context_files.append(attachment_info)
+                if self._retriever_attachment_loader is not None:
+                    context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id))
 
         yield RunRetrieverResourceEvent(
             retriever_resources=original_retriever_resource,
             context=context_str.strip(),
@@ -755,7 +735,7 @@ class LLMNode(Node[LLMNodeData]):
         sys_files: Sequence[File],
         context: str | None = None,
         memory: PromptMessageMemory | None = None,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         stop: Sequence[str] | None = None,
         memory_config: MemoryConfig | None = None,
@@ -764,24 +744,186 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
         context_files: list[File] | None = None,
-        template_renderer: TemplateRenderer | None = None,
+        jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
-        return llm_utils.fetch_prompt_messages(
-            sys_query=sys_query,
-            sys_files=sys_files,
-            context=context,
-            memory=memory,
-            model_instance=model_instance,
-            prompt_template=prompt_template,
-            stop=stop,
-            memory_config=memory_config,
-            vision_enabled=vision_enabled,
-            vision_detail=vision_detail,
-            variable_pool=variable_pool,
-            jinja2_variables=jinja2_variables,
-            context_files=context_files,
-            template_renderer=template_renderer,
-        )
+        prompt_messages: list[PromptMessage] = []
+        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
+
+        if isinstance(prompt_template, list):
+            # For chat model
+            prompt_messages.extend(
+                LLMNode.handle_list_messages(
+                    messages=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                    vision_detail_config=vision_detail,
+                    jinja2_template_renderer=jinja2_template_renderer,
+                )
+            )
+
+            # Get memory messages for chat mode
+            memory_messages = _handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_instance=model_instance,
+            )
+            # Extend prompt_messages with memory messages
+            prompt_messages.extend(memory_messages)
+
+            # Add current query to the prompt messages
+            if sys_query:
+                message = LLMNodeChatModelMessage(
+                    text=sys_query,
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                )
+                prompt_messages.extend(
+                    LLMNode.handle_list_messages(
+                        messages=[message],
+                        context="",
+                        jinja2_variables=[],
+                        variable_pool=variable_pool,
+                        vision_detail_config=vision_detail,
+                        jinja2_template_renderer=jinja2_template_renderer,
+                    )
+                )
+
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(
+                _handle_completion_template(
+                    template=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                    jinja2_template_renderer=jinja2_template_renderer,
+                )
+            )
+
+            # Get memory text for completion model
+            memory_text = _handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_instance=model_instance,
+            )
+            # Insert histories into the prompt
+            prompt_content = prompt_messages[0].content
+            # For issue #11247 - Check if prompt content is a string or a list
+            if isinstance(prompt_content, str):
+                prompt_content = str(prompt_content)
+                if "#histories#" in prompt_content:
+                    prompt_content = prompt_content.replace("#histories#", memory_text)
+                else:
+                    prompt_content = memory_text + "\n" + prompt_content
+                prompt_messages[0].content = prompt_content
+            elif isinstance(prompt_content, list):
+                for content_item in prompt_content:
+                    if isinstance(content_item, TextPromptMessageContent):
+                        if "#histories#" in content_item.data:
+                            content_item.data = content_item.data.replace("#histories#", memory_text)
+                        else:
+                            content_item.data = memory_text + "\n" + content_item.data
+            else:
+                raise ValueError("Invalid prompt content type")
+
+            # Add current query to the prompt message
+            if sys_query:
+                if isinstance(prompt_content, str):
+                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
+                    prompt_messages[0].content = prompt_content
+                elif isinstance(prompt_content, list):
+                    for content_item in prompt_content:
+                        if isinstance(content_item, TextPromptMessageContent):
+                            content_item.data = sys_query + "\n" + content_item.data
+                else:
+                    raise ValueError("Invalid prompt content type")
+        else:
+            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
+
+        # The sys_files will be deprecated later
+        if vision_enabled and sys_files:
+            file_prompts = []
+            for file in sys_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            # If last prompt is a user prompt, add files into its contents,
+            # otherwise append a new user prompt
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+        # The context_files
+        if vision_enabled and context_files:
+            file_prompts = []
+            for file in context_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            # If last prompt is a user prompt, add files into its contents,
+            # otherwise append a new user prompt
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+        # Remove empty messages and filter unsupported content
+        filtered_prompt_messages = []
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message.content, list):
+                prompt_message_content: list[PromptMessageContentUnionTypes] = []
+                for content_item in prompt_message.content:
+                    # Skip content if features are not defined
+                    if not model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        prompt_message_content.append(content_item)
+                        continue
+
+                    # Skip content if corresponding feature is not supported
+                    if (
+                        (
+                            content_item.type == PromptMessageContentType.IMAGE
+                            and ModelFeature.VISION not in model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.DOCUMENT
+                            and ModelFeature.DOCUMENT not in model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.VIDEO
+                            and ModelFeature.VIDEO not in model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.AUDIO
+                            and ModelFeature.AUDIO not in model_schema.features
+                        )
+                    ):
+                        continue
+                    prompt_message_content.append(content_item)
+                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
+                    prompt_message.content = prompt_message_content[0].data
+                else:
+                    prompt_message.content = prompt_message_content
+            if prompt_message.is_empty():
+                continue
+            filtered_prompt_messages.append(prompt_message)
+
+        if len(filtered_prompt_messages) == 0:
+            raise NoPromptFoundError(
+                "No prompt found in the LLM configuration. "
+                "Please ensure a prompt is properly configured before proceeding."
+            )
+
+        return filtered_prompt_messages, stop
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
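
Review note: completion-mode templates rely on literal placeholders. The inlined logic above replaces "#histories#" when present (otherwise prepends the memory text) and substitutes "#sys.query#" only when that marker exists. A self-contained illustration of those rules (a sketch, not code from this commit):

    # Illustration of the completion-mode placeholder handling above.
    def merge_history_and_query(prompt: str, memory_text: str, sys_query: str) -> str:
        if "#histories#" in prompt:
            prompt = prompt.replace("#histories#", memory_text)
        else:
            # Without a marker, histories are prepended on their own line.
            prompt = memory_text + "\n" + prompt
        # The current query is only injected at an explicit marker.
        return prompt.replace("#sys.query#", sys_query)

    assert merge_history_and_query("#histories#\nQ: #sys.query#", "H", "hi") == "H\nQ: hi"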
@@ -881,16 +1023,61 @@ class LLMNode(Node[LLMNodeData]):
         jinja2_variables: Sequence[VariableSelector],
         variable_pool: VariablePool,
         vision_detail_config: ImagePromptMessageContent.DETAIL,
-        template_renderer: TemplateRenderer | None = None,
+        jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
     ) -> Sequence[PromptMessage]:
-        return llm_utils.handle_list_messages(
-            messages=messages,
-            context=context,
-            jinja2_variables=jinja2_variables,
-            variable_pool=variable_pool,
-            vision_detail_config=vision_detail_config,
-            template_renderer=template_renderer,
-        )
+        prompt_messages: list[PromptMessage] = []
+        for message in messages:
+            if message.edition_type == "jinja2":
+                result_text = _render_jinja2_message(
+                    template=message.jinja2_text or "",
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                    jinja2_template_renderer=jinja2_template_renderer,
+                )
+                prompt_message = _combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
+                )
+                prompt_messages.append(prompt_message)
+            else:
+                # Get segment group from basic message
+                if context:
+                    template = message.text.replace("{#context#}", context)
+                else:
+                    template = message.text
+                segment_group = variable_pool.convert_template(template)
+
+                # Process segments for images
+                file_contents = []
+                for segment in segment_group.value:
+                    if isinstance(segment, ArrayFileSegment):
+                        for file in segment.value:
+                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                                file_content = file_manager.to_prompt_message_content(
+                                    file, image_detail_config=vision_detail_config
+                                )
+                                file_contents.append(file_content)
+                    elif isinstance(segment, FileSegment):
+                        file = segment.value
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=vision_detail_config
+                            )
+                            file_contents.append(file_content)
+
+                # Create message with text from all segments
+                plain_text = segment_group.text
+                if plain_text:
+                    prompt_message = _combine_message_content_with_role(
+                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
+                    )
+                    prompt_messages.append(prompt_message)
+
+                if file_contents:
+                    # Create message with image contents
+                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
+                    prompt_messages.append(prompt_message)
+
+        return prompt_messages
 
     @staticmethod
     def handle_blocking_result(
@@ -1027,5 +1214,153 @@ class LLMNode(Node[LLMNodeData]):
         return self.node_data.retry_config.retry_enabled
 
     @property
-    def model_instance(self) -> ModelInstance:
+    def model_instance(self) -> PreparedLLMProtocol:
         return self._model_instance
+
+
+def _combine_message_content_with_role(
+    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
+):
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=contents)
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=contents)
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=contents)
+        case _:
+            raise NotImplementedError(f"Role {role} is not supported")
+
+
+def _render_jinja2_message(
+    *,
+    template: str,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    jinja2_template_renderer: Jinja2TemplateRenderer | None,
+):
+    if not template:
+        return ""
+
+    jinja2_inputs = {}
+    for jinja2_variable in jinja2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    if jinja2_template_renderer is None:
+        raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.")
+    return jinja2_template_renderer.render_template(template, jinja2_inputs)
+
+
+def _calculate_rest_token(
+    *,
+    prompt_messages: list[PromptMessage],
+    model_instance: PreparedLLMProtocol,
+) -> int:
+    rest_tokens = 2000
+    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
+    runtime_model_parameters = model_instance.parameters
+
+    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    if model_context_tokens:
+        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+        max_tokens = 0
+        for parameter_rule in runtime_model_schema.parameter_rules:
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    runtime_model_parameters.get(parameter_rule.name)
+                    or runtime_model_parameters.get(str(parameter_rule.use_template))
+                    or 0
+                )
+
+        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+        rest_tokens = max(rest_tokens, 0)
+
+    return rest_tokens
+
+
+def _handle_memory_chat_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: PreparedLLMProtocol,
+) -> Sequence[PromptMessage]:
+    memory_messages: Sequence[PromptMessage] = []
+    # Get messages from memory for chat model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+        )
+        memory_messages = memory.get_history_prompt_messages(
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+        )
+    return memory_messages
+
+
+def _handle_memory_completion_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: PreparedLLMProtocol,
+) -> str:
+    memory_text = ""
+    # Get history text from memory for completion model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+        )
+        if not memory_config.role_prefix:
+            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+        memory_text = llm_utils.fetch_memory_text(
+            memory=memory,
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            human_prefix=memory_config.role_prefix.user,
+            ai_prefix=memory_config.role_prefix.assistant,
+        )
+    return memory_text
+
+
+def _handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: str | None,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    """Handle completion template processing outside of LLMNode class.
+
+    Args:
+        template: The completion model prompt template
+        context: Optional context string
+        jinja2_variables: Variables for jinja2 template rendering
+        variable_pool: Variable pool for template conversion
+
+    Returns:
+        Sequence of prompt messages
+    """
+    prompt_messages = []
+    if template.edition_type == "jinja2":
+        result_text = _render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            jinja2_template_renderer=jinja2_template_renderer,
+        )
+    else:
+        if context:
+            template_text = template.text.replace("{#context#}", context)
+        else:
+            template_text = template.text
+        result_text = variable_pool.convert_template(template_text).text
+    prompt_message = _combine_message_content_with_role(
+        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
+    )
+    prompt_messages.append(prompt_message)
+    return prompt_messages
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections.abc import Mapping
 from typing import Any, Protocol
 
-from core.model_manager import ModelInstance
+from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol
 
 
 class CredentialsProvider(Protocol):
@@ -15,10 +15,10 @@ class CredentialsProvider(Protocol):
 
 
 class ModelFactory(Protocol):
-    """Port for creating initialized LLM model instances for execution."""
+    """Port for creating prepared graph-facing LLM runtimes for execution."""
 
-    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
-        """Create a model instance that is ready for schema lookup and invocation."""
+    def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol:
+        """Create a prepared LLM runtime that is ready for graph execution."""
         ...
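
Review note: ModelFactory implementations now hand back a prepared, graph-facing runtime rather than a core ModelInstance. A minimal conforming factory might look like this (a sketch; the registry-based lookup is an assumption, not part of this commit):

    # Hypothetical factory satisfying the ModelFactory protocol. The host
    # application decides how prepared runtimes are built and cached.
    class RegistryModelFactory:
        def __init__(self, registry: dict[tuple[str, str], PreparedLLMProtocol]):
            self._registry = registry

        def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol:
            # Runtimes in the registry already carry credentials and parameters.
            return self._registry[(provider_name, model_name)]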
api/dify_graph/nodes/llm/runtime_protocols.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, Protocol
+
+from dify_graph.file import File
+from dify_graph.model_runtime.entities import LLMMode, PromptMessage
+from dify_graph.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
+from dify_graph.model_runtime.entities.message_entities import PromptMessageTool
+from dify_graph.model_runtime.entities.model_entities import AIModelEntity
+
+
+class PreparedLLMProtocol(Protocol):
+    """A graph-facing LLM runtime with provider-specific setup already applied."""
+
+    @property
+    def provider(self) -> str: ...
+
+    @property
+    def model_name(self) -> str: ...
+
+    @property
+    def parameters(self) -> Mapping[str, Any]: ...
+
+    @property
+    def stop(self) -> Sequence[str] | None: ...
+
+    def get_model_schema(self) -> AIModelEntity: ...
+
+    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ...
+
+    def invoke_llm(
+        self,
+        *,
+        prompt_messages: Sequence[PromptMessage],
+        model_parameters: Mapping[str, Any],
+        tools: Sequence[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: bool,
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]: ...
+
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        prompt_messages: Sequence[PromptMessage],
+        json_schema: Mapping[str, Any],
+        model_parameters: Mapping[str, Any],
+        stop: Sequence[str] | None,
+        stream: bool,
+    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+    def is_structured_output_parse_error(self, error: Exception) -> bool: ...
+
+
+class PromptMessageSerializerProtocol(Protocol):
+    """Port for converting compiled prompt messages into persisted process data."""
+
+    def serialize(
+        self,
+        *,
+        model_mode: LLMMode,
+        prompt_messages: Sequence[PromptMessage],
+    ) -> Any: ...
+
+
+class RetrieverAttachmentLoaderProtocol(Protocol):
+    """Port for resolving retriever segment attachments into graph file references."""
+
+    def load(self, *, segment_id: str) -> Sequence[File]: ...
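
Review note: PreparedLLMProtocol is a structural Protocol, so any object with matching members conforms without inheriting from it. That makes test doubles cheap; a minimal stub might look like this (a sketch with arbitrary canned values):

    # Hypothetical test double that structurally satisfies PreparedLLMProtocol.
    class StubPreparedLLM:
        provider = "openai"               # canned values for tests only
        model_name = "gpt-4o-mini"
        parameters = {"temperature": 0.0}
        stop = None

        def get_model_schema(self):
            raise NotImplementedError     # return a prebuilt AIModelEntity in real tests

        def get_llm_num_tokens(self, prompt_messages):
            # Rough heuristic: ~4 characters per token, good enough for tests.
            return sum(len(str(m.content or "")) // 4 for m in prompt_messages)

        def invoke_llm(self, *, prompt_messages, model_parameters, tools, stop, stream):
            raise NotImplementedError

        def invoke_llm_with_structured_output(self, *, prompt_messages, json_schema, model_parameters, stop, stream):
            raise NotImplementedError

        def is_structured_output_parse_error(self, error: Exception) -> bool:
            return False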
@@ -31,9 +31,9 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
 from dify_graph.utils.condition.processor import ConditionProcessor
+from dify_graph.utils.datetime_utils import naive_utc_now
 from dify_graph.variables import Segment, SegmentType
 from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
-from libs.datetime_utils import naive_utc_now
 
 if TYPE_CHECKING:
     from dify_graph.graph_engine import GraphEngine
@@ -7,10 +7,10 @@ from pydantic import (
     field_validator,
 )
 
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from dify_graph.entities.base_node_data import BaseNodeData
 from dify_graph.enums import BuiltinNodeTypes, NodeType
 from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
+from dify_graph.prompt_entities import MemoryConfig
 from dify_graph.variables.types import SegmentType
 
 _OLD_BOOL_TYPE_NAME = "bool"
@@ -5,11 +5,6 @@ import uuid
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast
 
-from core.model_manager import ModelInstance
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
-from core.prompt.simple_prompt_transform import ModelMode
-from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import (
     BuiltinNodeTypes,
@@ -17,8 +12,8 @@ from dify_graph.enums import (
     WorkflowNodeExecutionStatus,
 )
 from dify_graph.file import File
-from dify_graph.model_runtime.entities import ImagePromptMessageContent
-from dify_graph.model_runtime.entities.llm_entities import LLMUsage
+from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
+from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from dify_graph.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -27,14 +22,15 @@ from dify_graph.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
 from dify_graph.model_runtime.memory import PromptMessageMemory
-from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base import variable_template_parser
 from dify_graph.nodes.base.node import Node
-from dify_graph.nodes.llm import llm_utils
+from dify_graph.nodes.llm import LLMNode, llm_utils
 from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
 from dify_graph.runtime import VariablePool
 from dify_graph.variables.types import ArrayValidation, SegmentType
 from factories.variable_factory import build_segment_with_type
@@ -66,7 +62,6 @@ logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from dify_graph.entities import GraphInitParams
-    from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
     from dify_graph.runtime import GraphRuntimeState
 
 
@@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
 
     node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR
 
-    _model_instance: ModelInstance
-    _credentials_provider: "CredentialsProvider"
-    _model_factory: "ModelFactory"
+    _model_instance: PreparedLLMProtocol
+    _prompt_message_serializer: PromptMessageSerializerProtocol
     _memory: PromptMessageMemory | None
 
     def __init__(
@@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         *,
-        credentials_provider: "CredentialsProvider",
-        model_factory: "ModelFactory",
-        model_instance: ModelInstance,
+        credentials_provider: object | None = None,
+        model_factory: object | None = None,
+        model_instance: PreparedLLMProtocol,
         memory: PromptMessageMemory | None = None,
+        prompt_message_serializer: PromptMessageSerializerProtocol,
     ) -> None:
         super().__init__(
             id=id,
@@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
-        self._credentials_provider = credentials_provider
-        self._model_factory = model_factory
+        _ = credentials_provider, model_factory
         self._model_instance = model_instance
+        self._prompt_message_serializer = prompt_message_serializer
         self._memory = memory
 
     @classmethod
@@ -164,13 +159,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         )
 
         model_instance = self._model_instance
-        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
-            raise InvalidModelTypeError("Model is not a Large Language Model")
-
         try:
             model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
         except ValueError as exc:
             raise ModelSchemaNotFoundError("Model schema not found") from exc
+        if model_schema.model_type != ModelType.LLM:
+            raise InvalidModelTypeError("Model is not a Large Language Model")
         memory = self._memory
 
         if (
@@ -210,8 +204,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
 
         process_data = {
             "model_mode": node_data.model.mode,
-            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-                model_mode=node_data.model.mode, prompt_messages=prompt_messages
+            "prompts": self._prompt_message_serializer.serialize(
+                model_mode=node_data.model.mode,
+                prompt_messages=prompt_messages,
             ),
             "usage": None,
             "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
@@ -287,18 +282,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
 
     def _invoke(
         self,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         prompt_messages: list[PromptMessage],
         tools: list[PromptMessageTool],
-        stop: Sequence[str],
+        stop: Sequence[str] | None,
     ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=dict(model_instance.parameters),
-            tools=tools,
-            stop=list(stop),
-            stream=False,
-            user=self.require_dify_context().user_id,
+        invoke_result = cast(
+            LLMResult,
+            model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=dict(model_instance.parameters),
+                tools=tools or None,
+                stop=stop,
+                stream=False,
+            ),
         )
 
         # handle invoke result
@@ -317,7 +314,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -329,7 +326,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             content=query, structure=json.dumps(node_data.get_parameter_json_schema())
         )
 
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         rest_token = self._calculate_rest_token(
             node_data=node_data,
             query=query,
@@ -340,15 +336,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         prompt_template = self._get_function_calling_prompt_template(
             node_data, query, variable_pool, memory, rest_token
         )
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=files,
-            context="",
-            memory_config=node_data.memory,
-            memory=None,
+        prompt_messages = self._compile_prompt_messages(
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            files=files,
+            vision_enabled=node_data.vision.enabled,
+            image_detail_config=vision_detail,
         )
 
@@ -405,7 +397,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -413,9 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         """
         Generate prompt engineering prompt.
         """
-        model_mode = ModelMode(data.model.mode)
-
-        if model_mode == ModelMode.COMPLETION:
+        if data.model.mode == LLMMode.COMPLETION:
             return self._generate_prompt_engineering_completion_prompt(
                 node_data=data,
                 query=query,
@@ -425,7 +415,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                 files=files,
                 vision_detail=vision_detail,
             )
-        elif model_mode == ModelMode.CHAT:
+        if data.model.mode == LLMMode.CHAT:
             return self._generate_prompt_engineering_chat_prompt(
                 node_data=data,
                 query=query,
@@ -435,15 +425,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                 files=files,
                 vision_detail=vision_detail,
             )
-        else:
-            raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
+        raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}")
 
     def _generate_prompt_engineering_completion_prompt(
         self,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -451,7 +440,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         """
        Generate completion prompt.
        """
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         rest_token = self._calculate_rest_token(
             node_data=node_data,
             query=query,
@@ -462,27 +450,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         prompt_template = self._get_prompt_engineering_prompt_template(
             node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
         )
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
-            query="",
-            files=files,
-            context="",
-            memory_config=node_data.memory,
-            # AdvancedPromptTransform is still typed against TokenBufferMemory.
-            memory=cast(Any, memory),
+        return self._compile_prompt_messages(
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            files=files,
+            vision_enabled=node_data.vision.enabled,
+            image_detail_config=vision_detail,
         )
 
-        return prompt_messages
-
     def _generate_prompt_engineering_chat_prompt(
         self,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         memory: PromptMessageMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -490,7 +471,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         """
        Generate chat prompt.
        """
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
         rest_token = self._calculate_rest_token(
             node_data=node_data,
             query=query,
@@ -508,15 +488,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             max_token_limit=rest_token,
         )
 
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=files,
-            context="",
-            memory_config=node_data.memory,
-            memory=None,
+        prompt_messages = self._compile_prompt_messages(
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            files=files,
+            vision_enabled=node_data.vision.enabled,
+            image_detail_config=vision_detail,
         )
@@ -717,8 +693,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         variable_pool: VariablePool,
         memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
-    ) -> list[ChatModelMessage]:
-        model_mode = ModelMode(node_data.model.mode)
+    ) -> list[LLMNodeChatModelMessage]:
         input_text = query
         memory_str = ""
         instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -727,15 +702,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             memory_str = llm_utils.fetch_memory_text(
                 memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
             )
-        if model_mode == ModelMode.CHAT:
-            system_prompt_messages = ChatModelMessage(
+        if node_data.model.mode == LLMMode.CHAT:
+            system_prompt_messages = LLMNodeChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
                 text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
             )
-            user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
+            user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
             return [system_prompt_messages, user_prompt_message]
         else:
-            raise InvalidModelModeError(f"Model mode {model_mode} not support.")
+            raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
 
     def _get_prompt_engineering_prompt_template(
         self,
@@ -744,8 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         variable_pool: VariablePool,
         memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
-    ):
-        model_mode = ModelMode(node_data.model.mode)
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         input_text = query
         memory_str = ""
         instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -754,64 +727,53 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
             memory_str = llm_utils.fetch_memory_text(
                 memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
             )
-        if model_mode == ModelMode.CHAT:
-            system_prompt_messages = ChatModelMessage(
+        if node_data.model.mode == LLMMode.CHAT:
+            system_prompt_messages = LLMNodeChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
                 text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
             )
-            user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
+            user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text)
             return [system_prompt_messages, user_prompt_message]
-        elif model_mode == ModelMode.COMPLETION:
-            return CompletionModelPromptTemplate(
+        if node_data.model.mode == LLMMode.COMPLETION:
+            return LLMNodeCompletionModelPromptTemplate(
                 text=COMPLETION_GENERATE_JSON_PROMPT.format(
                     histories=memory_str, text=input_text, instruction=instruction
                 )
                 .replace("{γγγ", "")
                 .replace("}γγγ", "")
                 .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())),
            )
-        else:
-            raise InvalidModelModeError(f"Model mode {model_mode} not support.")
+        raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.")
 
     def _calculate_rest_token(
         self,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         context: str | None,
     ) -> int:
         try:
             model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
         except ValueError as exc:
             raise ModelSchemaNotFoundError("Model schema not found") from exc
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
             prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
         else:
             prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
 
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
+        prompt_messages = self._compile_prompt_messages(
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            files=[],
+            vision_enabled=False,
+            context=context,
        )
        rest_tokens = 2000
 
        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
-            model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
-            curr_message_tokens = (
-                model_type_instance.get_num_tokens(
-                    model_instance.model_name, model_instance.credentials, prompt_messages
-                )
-                + 1000
-            )  # add 1000 to ensure tool call messages
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000  # ensure tool call messages
 
             max_tokens = 0
             for parameter_rule in model_schema.parameter_rules:
@@ -828,8 +790,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
 
         return rest_tokens
 
+    def _compile_prompt_messages(
+        self,
+        *,
+        model_instance: PreparedLLMProtocol,
+        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+        files: Sequence[File],
+        vision_enabled: bool,
+        context: str | None = "",
+        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+    ) -> list[PromptMessage]:
+        prompt_messages, _ = LLMNode.fetch_prompt_messages(
+            sys_query="",
+            sys_files=files,
+            context=context,
+            memory=None,
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            stop=model_instance.stop,
+            memory_config=None,
+            vision_enabled=vision_enabled,
+            vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+        return list(prompt_messages)
+
     @property
-    def model_instance(self) -> ModelInstance:
+    def model_instance(self) -> PreparedLLMProtocol:
         return self._model_instance
 
     @classmethod
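
Review note: the rest-token budget in _calculate_rest_token is plain arithmetic: context window minus reserved completion tokens minus current prompt tokens (the extractor adds a 1000-token buffer for tool-call messages), floored at zero. A worked example with made-up numbers:

    # Worked example of the rest-token budget (all numbers illustrative).
    context_size = 8192            # ModelPropertyKey.CONTEXT_SIZE from the schema
    max_tokens = 1024              # "max_tokens" resolved from runtime parameters
    prompt_tokens = 900 + 1000     # measured prompt tokens + tool-call buffer
    rest_tokens = max(context_size - max_tokens - prompt_tokens, 0)
    assert rest_tokens == 5268     # memory text is then trimmed to this budget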
@@ -1,4 +1,4 @@
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Protocol
 
 import httpx
@@ -35,8 +35,6 @@ class ToolFileManagerProtocol(Protocol):
     def create_file_by_raw(
         self,
         *,
-        user_id: str,
-        tenant_id: str,
         conversation_id: str | None,
         file_binary: bytes,
         mimetype: str,
@@ -44,3 +42,7 @@ class ToolFileManagerProtocol(Protocol):
     ) -> Any: ...
 
     def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
+
+
+class FileReferenceFactoryProtocol(Protocol):
+    def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ...
@@ -1,9 +1,9 @@
 from pydantic import BaseModel, Field
 
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from dify_graph.entities.base_node_data import BaseNodeData
 from dify_graph.enums import BuiltinNodeTypes, NodeType
 from dify_graph.nodes.llm import ModelConfig, VisionConfig
+from dify_graph.prompt_entities import MemoryConfig
 
 
 class ClassConfig(BaseModel):
@@ -3,9 +3,6 @@ import re
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
-from core.model_manager import ModelInstance
-from core.prompt.simple_prompt_transform import ModelMode
-from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import (
@@ -14,7 +11,7 @@ from dify_graph.enums import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult
@@ -27,10 +24,11 @@ from dify_graph.nodes.llm import (
     LLMNodeCompletionModelPromptTemplate,
     llm_utils,
 )
-from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.llm.file_saver import LLMFileSaver
+from dify_graph.nodes.llm.protocols import TemplateRenderer
+from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
 from dify_graph.nodes.protocols import HttpClientProtocol
-from libs.json_in_md_parser import parse_and_check_json_markdown
+from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown
 
 from .entities import QuestionClassifierNodeData
 from .exc import InvalidModelTypeError
@@ -49,15 +47,20 @@ if TYPE_CHECKING:
     from dify_graph.runtime import GraphRuntimeState
 
 
+class _PassthroughPromptMessageSerializer:
+    def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
+        _ = model_mode
+        return list(prompt_messages)
+
+
 class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
     execution_type = NodeExecutionType.BRANCH
 
     _file_outputs: list["File"]
     _llm_file_saver: LLMFileSaver
-    _credentials_provider: "CredentialsProvider"
-    _model_factory: "ModelFactory"
-    _model_instance: ModelInstance
+    _prompt_message_serializer: PromptMessageSerializerProtocol
+    _model_instance: PreparedLLMProtocol
     _memory: PromptMessageMemory | None
     _template_renderer: TemplateRenderer
 
@@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         *,
-        credentials_provider: "CredentialsProvider",
-        model_factory: "ModelFactory",
-        model_instance: ModelInstance,
+        credentials_provider: object | None = None,
+        model_factory: object | None = None,
+        model_instance: PreparedLLMProtocol,
         http_client: HttpClientProtocol,
         template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
-        llm_file_saver: LLMFileSaver | None = None,
+        llm_file_saver: LLMFileSaver,
+        prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
     ):
         super().__init__(
             id=id,
@@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
 
-        self._credentials_provider = credentials_provider
-        self._model_factory = model_factory
+        _ = credentials_provider, model_factory, http_client
         self._model_instance = model_instance
         self._memory = memory
         self._template_renderer = template_renderer
 
-        if llm_file_saver is None:
-            dify_ctx = self.require_dify_context()
-            llm_file_saver = FileSaverImpl(
-                user_id=dify_ctx.user_id,
-                tenant_id=dify_ctx.tenant_id,
-                http_client=http_client,
-            )
         self._llm_file_saver = llm_file_saver
+        self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()
 
     @classmethod
     def version(cls):
@@ -169,7 +166,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
-            user_id=self.require_dify_context().user_id,
             structured_output_enabled=False,
             structured_output=None,
             file_saver=self._llm_file_saver,
@@ -205,7 +201,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             category_id = category_id_result
         process_data = {
             "model_mode": node_data.model.mode,
-            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            "prompts": self._prompt_message_serializer.serialize(
                 model_mode=node_data.model.mode, prompt_messages=prompt_messages
             ),
             "usage": jsonable_encoder(usage),
@@ -247,7 +243,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         )
 
     @property
-    def model_instance(self) -> ModelInstance:
+    def model_instance(self) -> PreparedLLMProtocol:
         return self._model_instance
 
     @classmethod
@@ -285,7 +281,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
-        model_instance: ModelInstance,
+        model_instance: PreparedLLMProtocol,
         context: str | None,
     ) -> int:
         model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
@@ -334,7 +330,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         memory: PromptMessageMemory | None,
         max_token_limit: int = 2000,
     ):
-        model_mode = ModelMode(node_data.model.mode)
+        model_mode = LLMMode(node_data.model.mode)
         classes = node_data.classes
         categories = []
         for class_ in classes:
@@ -350,7 +346,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
         )
         prompt_messages: list[LLMNodeChatModelMessage] = []
-        if model_mode == ModelMode.CHAT:
+        if model_mode == LLMMode.CHAT:
             system_prompt_messages = LLMNodeChatModelMessage(
                 role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
             )
@@ -381,7 +377,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             )
             prompt_messages.append(user_prompt_message_3)
             return prompt_messages
-        elif model_mode == ModelMode.COMPLETION:
+        elif model_mode == LLMMode.COMPLETION:
             return LLMNodeCompletionModelPromptTemplate(
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
                     histories=memory_str,
|
||||
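The constructor above now takes the prepared LLM and its collaborators by injection, and `process_data` serialization goes through `PromptMessageSerializerProtocol`, falling back to the passthrough serializer when none is supplied. A minimal sketch of a custom serializer that satisfies the protocol structurally; the class name and output shape are illustrative, not part of this commit:

from collections.abc import Sequence
from typing import Any


class ReprPromptMessageSerializer:
    """Illustrative serializer; any object with this `serialize` signature
    satisfies PromptMessageSerializerProtocol structurally."""

    def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
        # Record the mode once, plus a readable form of each prompt message.
        return {"model_mode": str(model_mode), "prompts": [repr(m) for m in prompt_messages]}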
75 api/dify_graph/nodes/runtime.py Normal file
@@ -0,0 +1,75 @@
from __future__ import annotations

from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Protocol

from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.tool_runtime_entities import (
    ToolRuntimeHandle,
    ToolRuntimeMessage,
    ToolRuntimeParameter,
)

if TYPE_CHECKING:
    from dify_graph.nodes.human_input.entities import DeliveryChannelConfig
    from dify_graph.nodes.tool.entities import ToolNodeData
    from dify_graph.runtime import VariablePool


class ToolNodeRuntimeProtocol(Protocol):
    """Workflow-layer adapter owned by `core.workflow` and consumed by `dify_graph`.

    The graph package depends only on these DTOs and lets the workflow layer
    translate between graph-owned abstractions and `core.tools` internals.
    """

    def get_runtime(
        self,
        *,
        node_id: str,
        node_data: ToolNodeData,
        variable_pool: VariablePool | None,
    ) -> ToolRuntimeHandle: ...

    def get_runtime_parameters(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
    ) -> Sequence[ToolRuntimeParameter]: ...

    def invoke(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
        tool_parameters: Mapping[str, Any],
        workflow_call_depth: int,
        conversation_id: str | None,
        provider_name: str,
    ) -> Generator[ToolRuntimeMessage, None, None]: ...

    def get_usage(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
    ) -> LLMUsage: ...

    def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ...

    def resolve_provider_icons(
        self,
        *,
        provider_name: str,
        default_icon: str | None = None,
    ) -> tuple[str | None, str | None]: ...


class HumanInputNodeRuntimeProtocol(Protocol):
    def invoke_source(self) -> str: ...

    def apply_delivery_runtime(
        self,
        *,
        methods: Sequence[DeliveryChannelConfig],
    ) -> Sequence[DeliveryChannelConfig]: ...

    def console_actor_id(self) -> str | None: ...
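Because `ToolNodeRuntimeProtocol` is a structural `Protocol`, tests can satisfy it without touching `core.tools` at all. A sketch of a stub implementation, assuming this commit's import paths resolve; `StubToolNodeRuntime` itself is illustrative:

from collections.abc import Generator, Mapping, Sequence
from typing import Any

from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.tool_runtime_entities import (
    ToolRuntimeHandle,
    ToolRuntimeMessage,
    ToolRuntimeParameter,
)


class StubToolNodeRuntime:
    """Structurally satisfies ToolNodeRuntimeProtocol for unit tests."""

    def get_runtime(self, *, node_id, node_data, variable_pool) -> ToolRuntimeHandle:
        return ToolRuntimeHandle(raw=object())

    def get_runtime_parameters(self, *, tool_runtime) -> Sequence[ToolRuntimeParameter]:
        return [ToolRuntimeParameter(name="query", required=True)]

    def invoke(
        self, *, tool_runtime, tool_parameters, workflow_call_depth, conversation_id, provider_name
    ) -> Generator[ToolRuntimeMessage, None, None]:
        # Emit one text message, the simplest well-formed stream.
        yield ToolRuntimeMessage(
            type=ToolRuntimeMessage.MessageType.TEXT,
            message=ToolRuntimeMessage.TextMessage(text="ok"),
        )

    def get_usage(self, *, tool_runtime) -> LLMUsage:
        return LLMUsage.empty_usage()

    def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any:
        return mapping

    def resolve_provider_icons(self, *, provider_name, default_icon=None):
        return default_icon, None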
@@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
from dify_graph.nodes.template_transform.template_renderer import (
from dify_graph.template_rendering import (
    Jinja2TemplateRenderer,
    TemplateRenderError,
)
@@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000

class TemplateTransformNode(Node[TemplateTransformNodeData]):
    node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM
    _template_renderer: Jinja2TemplateRenderer
    _jinja2_template_renderer: Jinja2TemplateRenderer
    _max_output_length: int

    def __init__(
@@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        template_renderer: Jinja2TemplateRenderer,
        jinja2_template_renderer: Jinja2TemplateRenderer,
        max_output_length: int | None = None,
    ) -> None:
        super().__init__(
@@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        self._template_renderer = template_renderer
        self._jinja2_template_renderer = jinja2_template_renderer

        if max_output_length is not None and max_output_length <= 0:
            raise ValueError("max_output_length must be a positive integer")
@@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
            variables[variable_name] = value.to_object() if value else None
        # Run code
        try:
            rendered = self._template_renderer.render_template(self.node_data.template, variables)
            rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables)
        except TemplateRenderError as e:
            return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

@@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: TemplateTransformNodeData | Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        return {
            node_id + "." + variable_selector.variable: variable_selector.value_selector
            for variable_selector in node_data.variables
        }
        _ = graph_config
        raw_variables = (
            node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", [])
        )
        variable_mapping: dict[str, Sequence[str]] = {}
        for variable_selector in raw_variables:
            if isinstance(variable_selector, VariableSelector):
                variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
                continue

            if not isinstance(variable_selector, Mapping):
                continue

            variable = variable_selector.get("variable")
            value_selector = variable_selector.get("value_selector")
            if (
                isinstance(variable, str)
                and isinstance(value_selector, Sequence)
                and all(isinstance(selector_part, str) for selector_part in value_selector)
            ):
                variable_mapping[node_id + "." + variable] = list(value_selector)

        return variable_mapping
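The rewritten `_extract_variable_selector_to_variable_mapping` now tolerates both typed node data and raw mappings, silently skipping malformed entries. Assuming the class is importable, a raw-dict call like the one below should produce the same mapping the typed path does:

mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping(
    graph_config={},
    node_id="tpl",
    node_data={"variables": [{"variable": "arg1", "value_selector": ["start", "arg1"]}]},
)
assert mapping == {"tpl.arg1": ["start", "arg1"]}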
@@ -1,13 +1,27 @@
from enum import StrEnum, auto
from typing import Any, Literal, Union

from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType


class ToolProviderType(StrEnum):
    """
    Graph-owned enum for persisted tool provider kinds.
    """

    PLUGIN = auto()
    BUILT_IN = "builtin"
    WORKFLOW = auto()
    API = auto()
    APP = auto()
    DATASET_RETRIEVAL = "dataset-retrieval"
    MCP = auto()


class ToolEntity(BaseModel):
    provider_id: str
    provider_type: ToolProviderType
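On `StrEnum`, `auto()` produces the lowercased member name, which is why `BUILT_IN` and `DATASET_RETRIEVAL` must spell out their persisted values ("built_in" and "dataset_retrieval" would not match stored data). A quick demonstration:

from enum import StrEnum, auto


class Demo(StrEnum):
    PLUGIN = auto()          # auto() on StrEnum -> "plugin"
    BUILT_IN = "builtin"     # auto() would have yielded "built_in"
    DATASET_RETRIEVAL = "dataset-retrieval"  # hyphens are unreachable via auto()


assert Demo.PLUGIN == "plugin"
assert Demo.BUILT_IN == "builtin"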
@@ -4,6 +4,18 @@ class ToolNodeError(ValueError):
    pass


class ToolRuntimeResolutionError(ToolNodeError):
    """Raised when the workflow layer cannot construct a tool runtime."""

    pass


class ToolRuntimeInvocationError(ToolNodeError):
    """Raised when the workflow layer fails while invoking a tool runtime."""

    pass


class ToolParameterError(ToolNodeError):
    """Exception raised for errors in tool parameters."""
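These graph-owned errors let `ToolNode` catch a single `ToolNodeError` where it previously caught `ToolInvokeError` and the plugin exceptions directly (see the tool_node diff below). A hedged sketch of how a workflow-layer adapter might translate, assuming it wraps a `core.tools` invocation; the wrapper name is illustrative:

from core.tools.errors import ToolInvokeError


def _invoke_translated(runtime_invoke, **kwargs):
    """Illustrative wrapper: surface core.tools failures as graph-owned errors."""
    try:
        yield from runtime_invoke(**kwargs)
    except ToolInvokeError as e:
        raise ToolRuntimeInvocationError(str(e)) from e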
@@ -1,12 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
    BuiltinNodeTypes,
@@ -20,10 +14,14 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol
from dify_graph.nodes.tool_runtime_entities import (
    ToolRuntimeHandle,
    ToolRuntimeMessage,
    ToolRuntimeParameter,
)
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from factories import file_factory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .entities import ToolNodeData
from .exc import (
@@ -52,6 +50,7 @@ class ToolNode(Node[ToolNodeData]):
        graph_runtime_state: "GraphRuntimeState",
        *,
        tool_file_manager_factory: ToolFileManagerProtocol,
        runtime: ToolNodeRuntimeProtocol | None = None,
    ):
        super().__init__(
            id=id,
@@ -60,6 +59,9 @@ class ToolNode(Node[ToolNodeData]):
            graph_runtime_state=graph_runtime_state,
        )
        self._tool_file_manager_factory = tool_file_manager_factory
        if runtime is None:
            raise ValueError("runtime is required")
        self._runtime = runtime

    @classmethod
    def version(cls) -> str:
@@ -73,10 +75,6 @@ class ToolNode(Node[ToolNodeData]):
        """
        Run the tool node
        """
        from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError

        dify_ctx = self.require_dify_context()

        # fetch tool icon
        tool_info = {
            "provider_type": self.node_data.provider_type.value,
@@ -86,8 +84,6 @@ class ToolNode(Node[ToolNodeData]):

        # get tool runtime
        try:
            from core.tools.tool_manager import ToolManager

            # This is an issue that caused problems before.
            # Logically, we shouldn't use the node_data.version field for judgment
            # But for backward compatibility with historical data
@@ -95,13 +91,10 @@ class ToolNode(Node[ToolNodeData]):
            variable_pool: VariablePool | None = None
            if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
                variable_pool = self.graph_runtime_state.variable_pool
            tool_runtime = ToolManager.get_workflow_tool_runtime(
                dify_ctx.tenant_id,
                dify_ctx.app_id,
                self._node_id,
                self.node_data,
                dify_ctx.invoke_from,
                variable_pool,
            tool_runtime = self._runtime.get_runtime(
                node_id=self._node_id,
                node_data=self.node_data,
                variable_pool=variable_pool,
            )
        except ToolNodeError as e:
            yield StreamCompletedEvent(
@@ -116,7 +109,7 @@ class ToolNode(Node[ToolNodeData]):
            return

        # get parameters
        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
        tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime)
        parameters = self._generate_parameters(
            tool_parameters=tool_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
@@ -132,14 +125,12 @@ class ToolNode(Node[ToolNodeData]):
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

        try:
            message_stream = ToolEngine.generic_invoke(
                tool=tool_runtime,
            message_stream = self._runtime.invoke(
                tool_runtime=tool_runtime,
                tool_parameters=parameters,
                user_id=dify_ctx.user_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=self.workflow_call_depth,
                app_id=dify_ctx.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
                provider_name=self.node_data.provider_name,
            )
        except ToolNodeError as e:
            yield StreamCompletedEvent(
@@ -159,38 +150,16 @@ class ToolNode(Node[ToolNodeData]):
                messages=message_stream,
                tool_info=tool_info,
                parameters_for_log=parameters_for_log,
                user_id=dify_ctx.user_id,
                tenant_id=dify_ctx.tenant_id,
                node_id=self._node_id,
                tool_runtime=tool_runtime,
            )
        except ToolInvokeError as e:
        except ToolNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except PluginInvokeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
                    error_type=type(e).__name__,
                )
            )
        except PluginDaemonClientSideError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to invoke tool, error: {e.description}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
            )
@@ -198,7 +167,7 @@ class ToolNode(Node[ToolNodeData]):
    def _generate_parameters(
        self,
        *,
        tool_parameters: Sequence[ToolParameter],
        tool_parameters: Sequence[ToolRuntimeParameter],
        variable_pool: "VariablePool",
        node_data: ToolNodeData,
        for_log: bool = False,
@@ -207,7 +176,7 @@ class ToolNode(Node[ToolNodeData]):
        Generate parameters based on the given tool parameters, variable pool, and node data.

        Args:
            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
            tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (ToolNodeData): The data associated with the tool node.

@@ -247,40 +216,29 @@ class ToolNode(Node[ToolNodeData]):

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
        messages: Generator[ToolRuntimeMessage, None, None],
        tool_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
        user_id: str,
        tenant_id: str,
        node_id: str,
        tool_runtime: Tool,
        tool_runtime: ToolRuntimeHandle,
        **_: Any,
    ) -> Generator[NodeEventBase, None, LLMUsage]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        Convert graph-owned tool runtime messages into node outputs.
        """
        # transform message and handle file storage
        from core.plugin.impl.plugin import PluginInstaller

        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json: list[dict | list] = []

        variables: dict[str, Any] = {}

        for message in message_stream:
        for message in messages:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.IMAGE,
                ToolRuntimeMessage.MessageType.IMAGE_LINK,
                ToolRuntimeMessage.MessageType.BINARY_LINK,
                ToolRuntimeMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)

                url = message.message.text
                if message.meta:
@@ -300,14 +258,11 @@ class ToolNode(Node[ToolNodeData]):
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                file = self._runtime.build_file_reference(mapping=mapping)
                files.append(file)
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
            elif message.type == ToolRuntimeMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
                assert message.meta

                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -320,27 +275,22 @@ class ToolNode(Node[ToolNodeData]):
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=tenant_id,
                    )
                )
            elif message.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                files.append(self._runtime.build_file_reference(mapping=mapping))
            elif message.type == ToolRuntimeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
                text += message.message.text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
            elif message.type == ToolRuntimeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
                # JSON message handling for tool node
                if message.message.json_object:
                    json.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            elif message.type == ToolRuntimeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolRuntimeMessage.TextMessage)

                # Check if this LINK message is a file link
                file_obj = (message.meta or {}).get("file")
@@ -356,8 +306,8 @@ class ToolNode(Node[ToolNodeData]):
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
            elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
@@ -374,7 +324,7 @@ class ToolNode(Node[ToolNodeData]):
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == ToolInvokeMessage.MessageType.FILE:
            elif message.type == ToolRuntimeMessage.MessageType.FILE:
                assert message.meta is not None
                assert isinstance(message.meta, dict)
                # Validate that meta contains a 'file' key
@@ -385,38 +335,16 @@ class ToolNode(Node[ToolNodeData]):
                if not isinstance(message.meta["file"], File):
                    raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
                files.append(message.meta["file"])
            elif message.type == ToolInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
            elif message.type == ToolRuntimeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
                if message.message.metadata:
                    icon = tool_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstaller()
                        plugins = manager.list_plugins(tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        icon_dark = None
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    user_id,
                                    tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                            icon_dark = builtin_tool.icon_dark
                        except StopIteration:
                            pass

                        icon, icon_dark = self._runtime.resolve_provider_icons(
                            provider_name=dict_metadata["provider"],
                            default_icon=icon,
                        )
                        dict_metadata["icon"] = icon
                        dict_metadata["icon_dark"] = icon_dark
                    message.message.metadata = dict_metadata
@@ -446,7 +374,7 @@ class ToolNode(Node[ToolNodeData]):
                    is_final=True,
                )

        usage = self._extract_tool_usage(tool_runtime)
        usage = self._runtime.get_usage(tool_runtime=tool_runtime)

        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
@@ -468,21 +396,6 @@ class ToolNode(Node[ToolNodeData]):

        return usage

    @staticmethod
    def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
        # Avoid importing WorkflowTool at module import time; rely on duck typing
        # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
        latest = getattr(tool_runtime, "latest_usage", None)
        # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
        # for any name, so we must type-check here.
        if isinstance(latest, LLMUsage):
            return latest
        if isinstance(latest, dict):
            # Allow dict payloads from external runtimes
            return LLMUsage.model_validate(latest)
        # Fallback to empty usage when attribute is missing or not a valid payload
        return LLMUsage.empty_usage()

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
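With these changes, `ToolNode` no longer reaches into `ToolManager` or `ToolEngine`; the workflow layer must inject an adapter. The `runtime` parameter defaults to `None` only to keep the keyword optional at call sites, and construction fails fast without one. A hypothetical wiring sketch, abridged to the parameters visible in this diff; the `...` placeholders stand for objects the workflow layer prepares:

node = ToolNode(
    id="tool-1",
    graph_init_params=...,          # GraphInitParams
    graph_runtime_state=...,        # GraphRuntimeState
    tool_file_manager_factory=...,  # a ToolFileManagerProtocol implementation
    runtime=StubToolNodeRuntime(),  # any ToolNodeRuntimeProtocol implementation
)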
101 api/dify_graph/nodes/tool_runtime_entities.py Normal file
@@ -0,0 +1,101 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class _ToolRuntimeModel(BaseModel):
    model_config = ConfigDict(extra="forbid")


@dataclass(frozen=True, slots=True)
class ToolRuntimeHandle:
    """Opaque graph-owned handle for a workflow-layer tool runtime."""

    raw: object


@dataclass(frozen=True, slots=True)
class ToolRuntimeParameter:
    """Graph-owned parameter shape used by tool nodes."""

    name: str
    required: bool = False


class ToolRuntimeMessage(_ToolRuntimeModel):
    """Graph-owned tool invocation message DTO."""

    class TextMessage(_ToolRuntimeModel):
        text: str

    class JsonMessage(_ToolRuntimeModel):
        json_object: dict[str, Any] | list[Any]
        suppress_output: bool = Field(default=False)

    class BlobMessage(_ToolRuntimeModel):
        blob: bytes

    class BlobChunkMessage(_ToolRuntimeModel):
        id: str
        sequence: int
        total_length: int
        blob: bytes
        end: bool

    class FileMessage(_ToolRuntimeModel):
        file_marker: str = Field(default="file_marker")

    class VariableMessage(_ToolRuntimeModel):
        variable_name: str
        variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None
        stream: bool = Field(default=False)

    class LogMessage(_ToolRuntimeModel):
        class LogStatus(StrEnum):
            START = auto()
            ERROR = auto()
            SUCCESS = auto()

        id: str
        label: str
        parent_id: str | None = None
        error: str | None = None
        status: LogStatus
        data: dict[str, Any]
        metadata: dict[str, Any] = Field(default_factory=dict)

    class RetrieverResourceMessage(_ToolRuntimeModel):
        retriever_resources: list[dict[str, Any]]
        context: str

    class MessageType(StrEnum):
        TEXT = auto()
        IMAGE = auto()
        LINK = auto()
        BLOB = auto()
        JSON = auto()
        IMAGE_LINK = auto()
        BINARY_LINK = auto()
        VARIABLE = auto()
        FILE = auto()
        LOG = auto()
        BLOB_CHUNK = auto()
        RETRIEVER_RESOURCES = auto()

    type: MessageType = MessageType.TEXT
    message: (
        JsonMessage
        | TextMessage
        | BlobChunkMessage
        | BlobMessage
        | LogMessage
        | FileMessage
        | None
        | VariableMessage
        | RetrieverResourceMessage
    )
    meta: dict[str, Any] | None = None
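The DTOs forbid extra fields (`extra="forbid"`), so payloads are validated at the graph boundary. Constructing a message, assuming this commit's module path:

from dify_graph.nodes.tool_runtime_entities import ToolRuntimeMessage

msg = ToolRuntimeMessage(
    type=ToolRuntimeMessage.MessageType.JSON,
    message=ToolRuntimeMessage.JsonMessage(json_object={"answer": 42}),
)
assert msg.type == "json"  # MessageType is a StrEnum; auto() lowercases member names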
47 api/dify_graph/prompt_entities.py Normal file
@@ -0,0 +1,47 @@
from typing import Literal

from pydantic import BaseModel

from dify_graph.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
    """Graph-owned chat prompt template message."""

    text: str
    role: PromptMessageRole
    edition_type: Literal["basic", "jinja2"] | None = None


class CompletionModelPromptTemplate(BaseModel):
    """Graph-owned completion prompt template."""

    text: str
    edition_type: Literal["basic", "jinja2"] | None = None


class MemoryConfig(BaseModel):
    """Graph-owned memory configuration for prompt assembly."""

    class RolePrefix(BaseModel):
        """Role labels used when serializing completion-model histories."""

        user: str
        assistant: str

    class WindowConfig(BaseModel):
        """History windowing controls."""

        enabled: bool
        size: int | None = None

    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None


__all__ = [
    "ChatModelMessage",
    "CompletionModelPromptTemplate",
    "MemoryConfig",
]
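A construction sketch for the memory configuration; the values are illustrative:

from dify_graph.prompt_entities import MemoryConfig

memory = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
    window=MemoryConfig.WindowConfig(enabled=True, size=10),
)
assert memory.window.size == 10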
@@ -18,9 +18,6 @@ class FormNotFoundError(HumanInputError):

@dataclasses.dataclass
class FormCreateParams:
    # app_id is the identifier for the app that the form belongs to.
    # It is a string with uuid format.
    app_id: str
    # None when creating a delivery test form; set for runtime forms.
    workflow_execution_id: str | None

@@ -45,6 +42,9 @@ class FormCreateParams:
    resolved_default_values: Mapping[str, Any]
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME

    # Optional application identifier. Implementations may bind this at construction time.
    app_id: str | None = None

    # Force creating a console-only recipient for submission in Console.
    console_recipient_required: bool = False
    console_creator_account_id: str | None = None
@@ -8,19 +8,17 @@ from dify_graph.nodes.code.entities import CodeLanguage


class TemplateRenderError(ValueError):
    """Raised when rendering a Jinja2 template fails."""
    """Raised when rendering a template fails."""


class Jinja2TemplateRenderer(Protocol):
    """Render Jinja2 templates for template transform nodes."""
    """Shared contract for rendering Jinja2 templates in graph nodes."""

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        """Render a Jinja2 template with provided variables."""
        raise NotImplementedError
    def render_template(self, template: str, variables: Mapping[str, Any]) -> str: ...


class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
    """Adapter that renders Jinja2 templates via CodeExecutor."""
    """Adapter that renders Jinja2 templates via the workflow code executor."""

    _code_executor: WorkflowCodeExecutor
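Since the renderer is now a structural `Protocol`, tests can satisfy it in process instead of going through `CodeExecutorJinja2TemplateRenderer`. A test-oriented sketch, named here only for illustration; production code should keep using the code-executor adapter:

from collections.abc import Mapping
from typing import Any

from jinja2.sandbox import SandboxedEnvironment


class InProcessJinja2Renderer:
    """Satisfies Jinja2TemplateRenderer structurally, rendering in-process."""

    def __init__(self) -> None:
        self._env = SandboxedEnvironment(autoescape=False)

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        return self._env.from_string(template).render(**variables)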
20 api/dify_graph/utils/datetime_utils.py Normal file
@@ -0,0 +1,20 @@
from __future__ import annotations

import abc
import datetime
from typing import Protocol


class _NowFunction(Protocol):
    @abc.abstractmethod
    def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
        """Return the current time for the requested timezone."""
        ...


_now_func: _NowFunction = datetime.datetime.now


def naive_utc_now() -> datetime.datetime:
    """Return the current UTC time as a naive datetime."""
    return _now_func(datetime.UTC).replace(tzinfo=None)
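The module-level `_now_func` indirection exists so tests can pin the clock. A pytest-style sketch of that pattern:

import datetime

from dify_graph.utils import datetime_utils


def test_naive_utc_now_is_patchable(monkeypatch):
    fixed = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
    # naive_utc_now() calls _now_func(datetime.UTC) positionally.
    monkeypatch.setattr(datetime_utils, "_now_func", lambda tz: fixed)
    assert datetime_utils.naive_utc_now() == fixed.replace(tzinfo=None)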
58 api/dify_graph/utils/json_in_md_parser.py Normal file
@@ -0,0 +1,58 @@
from __future__ import annotations

import json


class OutputParserError(ValueError):
    """Raised when a markdown-wrapped JSON payload cannot be parsed or validated."""


def parse_json_markdown(json_string: str) -> dict | list:
    """Extract and parse the first JSON object or array embedded in markdown text."""
    json_string = json_string.strip()
    starts = ["```json", "```", "``", "`", "{", "["]
    ends = ["```", "``", "`", "}", "]"]
    end_index = -1
    start_index = 0

    for start_marker in starts:
        start_index = json_string.find(start_marker)
        if start_index != -1:
            if json_string[start_index] not in ("{", "["):
                start_index += len(start_marker)
            break

    if start_index != -1:
        for end_marker in ends:
            end_index = json_string.rfind(end_marker, start_index)
            if end_index != -1:
                if json_string[end_index] in ("}", "]"):
                    end_index += 1
                break

    if start_index == -1 or end_index == -1 or start_index >= end_index:
        raise ValueError("could not find json block in the output.")

    extracted_content = json_string[start_index:end_index].strip()
    return json.loads(extracted_content)


def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as exc:
        raise OutputParserError(f"got invalid json object. error: {exc}") from exc

    if isinstance(json_obj, list):
        if len(json_obj) == 1 and isinstance(json_obj[0], dict):
            json_obj = json_obj[0]
        else:
            raise OutputParserError(f"got invalid return object. obj:{json_obj}")

    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserError(
                f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
            )

    return json_obj
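Typical use, mirroring the question classifier's expectation of a fenced JSON block in model output; the sample text is illustrative:

from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown

output = 'Result:\n```json\n{"keywords": ["refund"], "category_name": "Billing"}\n```'
parsed = parse_and_check_json_markdown(output, expected_keys=["category_name"])
assert parsed["category_name"] == "Billing"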
@@ -92,7 +92,7 @@ class AppService:
        default_model_config = default_model_config.copy() if default_model_config else None
        if default_model_config and "model" in default_model_config:
            # get model provider
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "")

            # get default model instance
            try:

@@ -61,7 +61,7 @@ class AudioService:
            message = f"Audio size larger than {FILE_SIZE} mb"
            raise AudioTooLargeServiceError(message)

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user)
        model_instance = model_manager.get_default_model_instance(
            tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
        )
@@ -71,7 +71,7 @@ class AudioService:
        buffer = io.BytesIO(file_content)
        buffer.name = "temp.mp3"

        return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
        return {"text": model_instance.invoke_speech2text(file=buffer)}

    @classmethod
    def transcript_tts(
@@ -109,7 +109,7 @@ class AudioService:

        voice = cast(str | None, text_to_speech_dict.get("voice"))

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user)
        model_instance = model_manager.get_default_model_instance(
            tenant_id=app_model.tenant_id, model_type=ModelType.TTS
        )
@@ -123,9 +123,7 @@ class AudioService:
            else:
                raise ValueError("Sorry, no voice available.")

            return model_instance.invoke_tts(
                content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
            )
            return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice)
        except Exception as e:
            raise e

@@ -155,7 +153,7 @@ class AudioService:

    @classmethod
    def transcript_tts_voices(cls, tenant_id: str, language: str):
        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
        model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
        if model_instance is None:
            raise ProviderNotSupportTextToSpeechServiceError()

@@ -228,7 +228,7 @@ class DatasetService:
            raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
        embedding_model = None
        if indexing_technique == "high_quality":
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
            if embedding_model_provider and embedding_model_name:
                # check if embedding model setting is valid
                DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
@@ -350,7 +350,7 @@ class DatasetService:
    def check_dataset_model_setting(dataset):
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
                model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
@@ -367,7 +367,7 @@ class DatasetService:
    @staticmethod
    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
        try:
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
            model_manager.get_model_instance(
                tenant_id=tenant_id,
                provider=embedding_model_provider,
@@ -384,7 +384,7 @@ class DatasetService:
    @staticmethod
    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
        try:
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
            model_instance = model_manager.get_model_instance(
                tenant_id=tenant_id,
                provider=model_provider,
@@ -405,7 +405,7 @@ class DatasetService:
    @staticmethod
    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
        try:
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
            model_manager.get_model_instance(
                tenant_id=tenant_id,
                provider=reranking_model_provider,
@@ -742,7 +742,7 @@ class DatasetService:
        """
        # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
        try:
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
            assert isinstance(current_user, Account)
            assert current_user.current_tenant_id is not None
            embedding_model = model_manager.get_model_instance(
@@ -860,7 +860,7 @@ class DatasetService:
        """
        # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
        try:
            assert isinstance(current_user, Account)
            assert current_user.current_tenant_id is not None
@@ -954,7 +954,7 @@ class DatasetService:
        dataset.chunk_structure = knowledge_configuration.chunk_structure
        dataset.indexing_technique = knowledge_configuration.indexing_technique
        if knowledge_configuration.indexing_technique == "high_quality":
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
            embedding_model = model_manager.get_model_instance(
                tenant_id=current_user.current_tenant_id,  # ignore type error
                provider=knowledge_configuration.embedding_model_provider or "",
@@ -996,7 +996,7 @@ class DatasetService:
                action = "add"
                # get embedding model setting
                try:
                    model_manager = ModelManager()
                    model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
                    embedding_model = model_manager.get_model_instance(
                        tenant_id=current_user.current_tenant_id,
                        provider=knowledge_configuration.embedding_model_provider,
@@ -1049,7 +1049,7 @@ class DatasetService:
                    or knowledge_configuration.embedding_model != dataset.embedding_model
                ):
                    action = "update"
                    model_manager = ModelManager()
                    model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
                    embedding_model = None
                    try:
                        embedding_model = model_manager.get_model_instance(
@@ -1908,7 +1908,7 @@ class DocumentService:

        dataset.indexing_technique = knowledge_config.indexing_technique
        if knowledge_config.indexing_technique == "high_quality":
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
            if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                dataset_embedding_model = knowledge_config.embedding_model
                dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@@ -2220,7 +2220,7 @@ class DocumentService:

        # dataset.indexing_technique = knowledge_config.indexing_technique
        # if knowledge_config.indexing_technique == "high_quality":
        #     model_manager = ModelManager()
        #     model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
        #     if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
        #         dataset_embedding_model = knowledge_config.embedding_model
        #         dataset_embedding_model_provider = knowledge_config.embedding_model_provider
@@ -3125,7 +3125,7 @@ class SegmentService:
        segment_hash = helper.generate_text_hash(content)
        tokens = 0
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
            embedding_model = model_manager.get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=dataset.embedding_model_provider,
@@ -3208,7 +3208,7 @@ class SegmentService:
        with redis_client.lock(lock_name, timeout=600):
            embedding_model = None
            if dataset.indexing_technique == "high_quality":
                model_manager = ModelManager()
                model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
                embedding_model = model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
@@ -3346,7 +3346,7 @@ class SegmentService:
            # get embedding model instance
            if dataset.indexing_technique == "high_quality":
                # check embedding model setting
                model_manager = ModelManager()
                model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)

                if dataset.embedding_model_provider:
                    embedding_model_instance = model_manager.get_model_instance(
@@ -3409,7 +3409,7 @@ class SegmentService:
        segment_hash = helper.generate_text_hash(content)
        tokens = 0
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id)
            embedding_model = model_manager.get_model_instance(
                tenant_id=current_user.current_tenant_id,
                provider=dataset.embedding_model_provider,
@@ -3450,7 +3450,7 @@ class SegmentService:
            # get embedding model instance
            if dataset.indexing_technique == "high_quality":
                # check embedding model setting
                model_manager = ModelManager()
                model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)

                if dataset.embedding_model_provider:
                    embedding_model_instance = model_manager.get_model_instance(

@@ -255,7 +255,7 @@ class MessageService:
            app_model=app_model, conversation_id=message.conversation_id, user=user
        )

        model_manager = ModelManager()
        model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id)

        if app_model.mode == AppMode.ADVANCED_CHAT:
            workflow_service = WorkflowService()
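Every service-side change above follows one pattern: bind the tenant (and, where relevant, the user) when the manager is created rather than at each call, which is why per-call identity kwargs such as `invoke_speech2text(..., user=...)` drop away. A before/after sketch distilled from these replacements; `for_tenant`'s full signature lives in core.model_manager and may accept more options:

# before: ambient construction, caller identity threaded through every call
model_manager = ModelManager()
text = model_instance.invoke_speech2text(file=buffer, user=end_user)

# after: tenant and user bound up front, per-call identity kwargs removed
model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user)
model_instance = model_manager.get_default_model_instance(
    tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
)
text = model_instance.invoke_speech2text(file=buffer)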
Some files were not shown because too many files have changed in this diff.