From fbb74a4af9c2bfb40c36546baf17660dfd6e6bfd Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 15 Mar 2026 15:34:47 +0800 Subject: [PATCH] feat: extract model runtime Signed-off-by: -LAN- --- api/.importlinter | 52 -- api/contexts/__init__.py | 9 - api/controllers/console/datasets/datasets.py | 6 +- .../console/datasets/datasets_document.py | 2 +- .../console/datasets/datasets_segments.py | 8 +- api/controllers/console/workspace/models.py | 12 +- .../service_api/dataset/dataset.py | 6 +- .../service_api/dataset/segment.py | 8 +- api/core/agent/cot_agent_runner.py | 1 - api/core/agent/fc_agent_runner.py | 1 - .../model_config/converter.py | 4 +- .../easy_ui_based_app/model_config/manager.py | 7 +- api/core/app/entities/app_invoke_entities.py | 4 +- api/core/app/llm/model_access.py | 89 ++- api/core/app/llm/protocols.py | 21 + api/core/app/workflow/layers/llm_quota.py | 3 +- .../base/tts/app_generator_tts_publisher.py | 14 +- api/core/entities/embedding_type.py | 11 +- api/core/entities/provider_configuration.py | 12 +- api/core/helper/moderation.py | 4 +- api/core/indexing_runner.py | 17 +- api/core/llm_generator/llm_generator.py | 16 +- .../output_parser/structured_output.py | 6 - api/core/model_manager.py | 55 +- .../openai_moderation/openai_moderation.py | 2 +- api/core/plugin/backwards_invocation/model.py | 63 +- api/core/plugin/impl/model.py | 155 ++--- api/core/plugin/impl/model_runtime.py | 490 ++++++++++++++++ api/core/plugin/impl/model_runtime_factory.py | 45 ++ .../entities/advanced_prompt_entities.py | 55 +- api/core/provider_manager.py | 19 +- .../data_post_processor.py | 6 +- api/core/rag/datasource/retrieval_service.py | 2 +- api/core/rag/datasource/vdb/vector_factory.py | 2 +- api/core/rag/docstore/dataset_docstore.py | 2 +- api/core/rag/embedding/cached_embedding.py | 25 +- .../processor/paragraph_index_processor.py | 4 +- api/core/rag/rerank/rerank_model.py | 24 +- api/core/rag/rerank/weight_rerank.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 11 +- .../router/multi_dataset_react_route.py | 14 +- .../repositories/human_input_repository.py | 7 +- .../builtin_tool/providers/audio/tools/asr.py | 14 +- .../builtin_tool/providers/audio/tools/tts.py | 16 +- .../dataset_multi_retriever_tool.py | 2 +- .../tools/utils/model_invocation_utils.py | 7 +- api/core/workflow/node_factory.py | 117 ++-- api/core/workflow/node_runtime.py | 532 +++++++++++++++++ api/core/workflow/nodes/agent/agent_node.py | 7 +- .../workflow/nodes/agent/runtime_support.py | 21 +- .../nodes/datasource/datasource_node.py | 4 +- .../knowledge_retrieval_node.py | 3 +- .../workflow/nodes/trigger_webhook/node.py | 15 +- api/dify_graph/entities/graph_init_params.py | 2 - api/dify_graph/entities/workflow_execution.py | 2 +- api/dify_graph/graph/graph.py | 5 +- api/dify_graph/graph_events/node.py | 6 +- .../entities/provider_entities.py | 14 +- .../model_runtime/entities/rerank_entities.py | 7 + .../entities/text_embedding_entities.py | 8 + .../model_providers/__base/ai_model.py | 133 ++--- .../__base/large_language_model.py | 66 +-- .../__base/moderation_model.py | 18 +- .../model_providers/__base/rerank_model.py | 30 +- .../__base/speech2text_model.py | 18 +- .../__base/text_embedding_model.py | 37 +- .../model_providers/__base/tts_model.py | 29 +- .../model_providers/model_provider_factory.py | 344 +++-------- api/dify_graph/model_runtime/runtime.py | 159 +++++ api/dify_graph/nodes/base/node.py | 55 +- api/dify_graph/nodes/http_request/node.py | 25 +- .../nodes/human_input/human_input_node.py | 28 +- 
.../nodes/iteration/iteration_node.py | 2 +- api/dify_graph/nodes/llm/entities.py | 2 +- api/dify_graph/nodes/llm/file_saver.py | 53 +- api/dify_graph/nodes/llm/llm_utils.py | 46 +- api/dify_graph/nodes/llm/node.py | 547 ++++++++++++++---- api/dify_graph/nodes/llm/protocols.py | 8 +- api/dify_graph/nodes/llm/runtime_protocols.py | 74 +++ api/dify_graph/nodes/loop/loop_node.py | 2 +- .../nodes/parameter_extractor/entities.py | 2 +- .../parameter_extractor_node.py | 204 +++---- api/dify_graph/nodes/protocols.py | 8 +- .../nodes/question_classifier/entities.py | 2 +- .../question_classifier_node.py | 56 +- api/dify_graph/nodes/runtime.py | 75 +++ .../template_transform_node.py | 44 +- api/dify_graph/nodes/tool/entities.py | 16 +- api/dify_graph/nodes/tool/exc.py | 12 + api/dify_graph/nodes/tool/tool_node.py | 189 ++---- api/dify_graph/nodes/tool_runtime_entities.py | 101 ++++ api/dify_graph/prompt_entities.py | 47 ++ .../human_input_form_repository.py | 6 +- ...late_renderer.py => template_rendering.py} | 10 +- api/dify_graph/utils/datetime_utils.py | 20 + api/dify_graph/utils/json_in_md_parser.py | 58 ++ api/services/app_service.py | 2 +- api/services/audio_service.py | 12 +- api/services/dataset_service.py | 34 +- api/services/message_service.py | 2 +- api/services/model_load_balancing_service.py | 21 +- api/services/model_provider_service.py | 43 +- api/services/summary_index_service.py | 2 +- api/services/vector_service.py | 2 +- api/services/workflow_service.py | 11 +- .../batch_create_segment_to_index_task.py | 2 +- .../test_datasource_node_integration.py | 2 +- .../workflow/nodes/__mock/model.py | 4 +- .../workflow/nodes/test_http.py | 3 + .../workflow/nodes/test_llm.py | 14 +- .../nodes/test_parameter_extractor.py | 2 + .../workflow/nodes/test_template_transform.py | 4 +- .../workflow/nodes/test_tool.py | 2 + .../test_human_input_resume_node_execution.py | 2 + .../services/test_agent_service.py | 2 +- .../services/test_app_dsl_service.py | 2 +- .../services/test_app_service.py | 2 +- .../services/test_dataset_service.py | 6 +- .../test_dataset_service_update_dataset.py | 6 +- .../services/test_message_service.py | 2 +- .../services/test_saved_message_service.py | 2 +- .../services/test_web_conversation_service.py | 2 +- .../services/test_workflow_app_service.py | 2 +- .../services/test_workflow_run_service.py | 2 +- .../test_workflow_tools_manage_service.py | 2 +- ...test_batch_create_segment_to_index_task.py | 5 +- .../core/llm_generator/test_llm_generator.py | 6 +- .../test_model_provider_factory.py | 60 ++ .../moderation/test_content_moderation.py | 30 +- .../core/plugin/test_model_runtime_adapter.py | 229 ++++++++ .../rag/docstore/test_dataset_docstore.py | 2 +- .../core/rag/indexing/test_indexing_runner.py | 6 +- .../core/rag/rerank/test_reranker.py | 22 +- .../rag/retrieval/test_dataset_retrieval.py | 10 +- .../unit_tests/core/test_provider_manager.py | 21 +- .../core/tools/test_builtin_tools_extra.py | 23 +- .../utils/test_model_invocation_utils.py | 8 +- .../graph_engine/layers/test_llm_quota.py | 21 +- .../graph_engine/test_command_system.py | 4 +- .../test_human_input_pause_multi_branch.py | 2 + .../test_human_input_pause_single_branch.py | 2 + .../test_mock_iteration_simple.py | 2 +- .../workflow/graph_engine/test_mock_nodes.py | 24 +- .../test_mock_nodes_template_code.py | 2 +- .../workflow/graph_engine/test_mock_simple.py | 2 +- .../test_parallel_human_input_join_resume.py | 3 + ...rallel_human_input_pause_missing_finish.py | 3 + .../test_pause_deferred_ready_nodes.py | 2 
+ .../graph_engine/test_pause_resume_state.py | 2 + .../graph_engine/test_table_runner.py | 4 +- .../nodes/datasource/test_datasource_node.py | 2 +- .../http_request/test_http_request_node.py | 2 + .../nodes/human_input/test_entities.py | 7 +- .../test_human_input_form_filled_event.py | 7 +- .../workflow/nodes/list_operator/node_spec.py | 2 +- .../core/workflow/nodes/llm/test_node.py | 174 ++++-- .../template_transform_node_spec.py | 28 +- .../core/workflow/nodes/test_base_node.py | 5 +- .../core/workflow/nodes/test_if_else.py | 3 +- .../core/workflow/nodes/test_list_operator.py | 3 +- .../workflow/nodes/tool/test_tool_node.py | 79 ++- .../nodes/tool/test_tool_node_runtime.py | 126 ++++ .../v1/test_variable_assigner_v1.py | 3 +- .../v2/test_variable_assigner_v2.py | 3 +- .../webhook/test_webhook_file_conversion.py | 4 +- .../nodes/webhook/test_webhook_node.py | 4 +- .../core/workflow/test_node_factory.py | 236 ++++---- ...large_language_model_non_stream_parsing.py | 22 +- .../model_providers/__base/test_ai_model.py | 252 +++----- .../__base/test_runtime_user_forwarding.py | 170 ++++++ .../services/dataset_service_update_delete.py | 2 +- .../services/document_service_validation.py | 6 +- .../unit_tests/services/segment_service.py | 2 +- .../unit_tests/services/test_app_service.py | 8 +- .../unit_tests/services/test_audio_service.py | 24 +- .../services/test_message_service.py | 4 +- ...est_model_provider_service_sanitization.py | 2 +- .../unit_tests/services/vector_service.py | 4 +- 178 files changed, 4343 insertions(+), 2134 deletions(-) create mode 100644 api/core/app/llm/protocols.py create mode 100644 api/core/plugin/impl/model_runtime.py create mode 100644 api/core/plugin/impl/model_runtime_factory.py create mode 100644 api/core/workflow/node_runtime.py create mode 100644 api/dify_graph/model_runtime/runtime.py create mode 100644 api/dify_graph/nodes/llm/runtime_protocols.py create mode 100644 api/dify_graph/nodes/runtime.py create mode 100644 api/dify_graph/nodes/tool_runtime_entities.py create mode 100644 api/dify_graph/prompt_entities.py rename api/dify_graph/{nodes/template_transform/template_renderer.py => template_rendering.py} (79%) create mode 100644 api/dify_graph/utils/datetime_utils.py create mode 100644 api/dify_graph/utils/json_in_md_parser.py create mode 100644 api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py create mode 100644 api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py create mode 100644 api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_runtime_user_forwarding.py diff --git a/api/.importlinter b/api/.importlinter index a836d09088..fe1d132535 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -33,20 +33,6 @@ ignore_imports = # TODO(QuantumGhost): fix the import violation later dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities -[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - [importlinter:contract:workflow-external-imports] name = Workflow 
External Imports type = forbidden @@ -89,45 +75,7 @@ forbidden_modules = core.trigger core.variables ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids [importlinter:contract:rsc] name = RSC diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a57..764f9f8ee2 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import 
DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cd..a5c6c01958 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo @@ -331,7 +331,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -445,7 +445,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bc90c4ffbd..dd63ab5994 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -452,7 +452,7 @@ class DatasetInitApi(Resource): if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b712..2b37fdf3a9 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ 
b/api/controllers/console/datasets/datasets_segments.py @@ -282,7 +282,7 @@ class DatasetDocumentSegmentApi(Resource): if dataset.indexing_technique == "high_quality": # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -335,7 +335,7 @@ class DatasetDocumentSegmentAddApi(Resource): # check embedding model setting if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -386,7 +386,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): if dataset.indexing_technique == "high_quality": # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -571,7 +571,7 @@ class ChildChunkAddApi(Resource): # check embedding model setting if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c..f9e20eddda 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab..b606117278 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,7 +14,7 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag @@ -139,10 +139,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) 
# check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -253,10 +253,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..534513a3c4 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -105,7 +105,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -159,7 +159,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -265,7 +265,7 @@ class DatasetSegmentApi(DatasetApiResource): if dataset.indexing_technique == "high_quality": # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -360,7 +360,7 @@ class ChildChunkApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bd..7abac06dde 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b21..d80f3579b9 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - user=self.user_id, callbacks=[], ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py 
index 558b6e69a0..3d3cccd45d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -21,7 +21,7 @@ class ModelConfigConverter: """ model_config = app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e33..1890ab5a42 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,8 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -55,7 +54,7 @@ class ModelConfigManager: raise ValueError("model must be of object type") # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: @@ -71,7 +70,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) models = provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..f949c2409d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import 
DIFY_RUN_CONTEXT_KEY from dify_graph.file import File, FileUploadConfig from dify_graph.model_runtime.entities.model_entities import AIModelEntity @@ -15,6 +14,9 @@ if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..630dd98a68 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,34 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity +from core.app.llm.protocols import CredentialsProvider, ModelFactory from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.nodes.llm.entities import ModelConfig from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +53,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,11 +78,35 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, ModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = 
ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. + """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( @@ -80,22 +127,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +146,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/llm/protocols.py b/api/core/app/llm/protocols.py new file mode 100644 index 0000000000..6381e7ac17 --- /dev/null +++ b/api/core/app/llm/protocols.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from core.model_manager import ModelInstance + + +class CredentialsProvider(Protocol): + """Workflow-layer port for loading runtime credentials for a provider/model pair.""" + + def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + """Return credentials for the target provider/model or raise a domain error.""" + ... + + +class ModelFactory(Protocol): + """Workflow-layer port for creating mutable ModelInstance objects.""" + + def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: + """Create a model instance that is ready for workflow-side hydration.""" + ... 
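Usage sketch (editorial note, not part of the patch): the CredentialsProvider and ModelFactory ports added in api/core/app/llm/protocols.py are produced together by build_dify_model_access() from the model_access.py hunk above, sharing one tenant-bound provider manager. The snippet below is illustrative only; it assumes DifyRunContext can be constructed from tenant_id and user_id alone, and the provider/model names are placeholders.

# Illustrative only -- not part of the patch.
# Assumes DifyRunContext(tenant_id=..., user_id=...) is sufficient to build the run context,
# and that "openai" / "gpt-4o" stand in for any configured provider/model pair.
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access

run_context = DifyRunContext(tenant_id="tenant-123", user_id="user-456")

# Both adapters are backed by the same tenant-bound provider manager graph.
credentials_provider, model_factory = build_dify_model_access(run_context)

# CredentialsProvider port: resolve runtime credentials for a provider/model pair.
credentials = credentials_provider.fetch(provider_name="openai", model_name="gpt-4o")

# ModelFactory port: create a ModelInstance ready for workflow-side hydration
# (parameters, stop sequences, credentials are attached later, e.g. by fetch_model_config).
model_instance = model_factory.init_model_instance(provider_name="openai", model_name="gpt-4o")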
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index faf1516c40..706caf120e 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance @@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666..c0ea1d6f65 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -25,12 +25,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher: self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2ef..f8589a89ec 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. 
- """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2..10bb8dc042 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -19,6 +19,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, @@ -27,7 +28,6 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderEntity, ) from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -343,7 +343,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +902,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1388,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1397,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1499,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a4093..eb6967be33 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,10 +4,10 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID +from 
core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType @@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..d0cfe2c07e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -291,20 +294,20 @@ class IndexingRunner: raise ValueError("Dataset not found.") if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: if indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -574,7 +577,7 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -766,14 +769,14 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( 
tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d9..eb38fb8125 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -62,7 +62,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +120,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +172,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +219,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -306,7 +306,7 @@ class LLMGenerator: remove_template_variables=False, ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -337,7 +337,7 @@ class LLMGenerator: def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -362,7 +362,7 @@ class LLMGenerator: @classmethod def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -536,7 +536,7 @@ class LLMGenerator: injected_instruction = injected_instruction.replace(CURRENT, current or "null") if ERROR_MESSAGE in injected_instruction: injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 77ea1713ea..8945021d47 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -55,7 +55,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... 
@overload @@ -70,7 +69,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload @@ -85,7 +83,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..06b35dffcd 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.callbacks.base_callback import Callback from dify_graph.model_runtime.entities.llm_entities import LLMResult @@ -30,7 +31,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -110,7 +111,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +122,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +133,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... 
@@ -145,7 +143,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +153,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +169,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +197,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +215,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +222,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +240,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +269,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -288,7 +277,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -303,7 +291,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +300,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +308,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +322,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +341,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model 
:param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +360,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +380,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +453,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +482,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf4..78b0cd91c3 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -50,7 +50,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bac..e929097d2f 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + 
model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = 
model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb..30c11bb740 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. 
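Note on the hunk above: `_dispatch_payload` only attaches `user_id` when one is provided, so anonymous dispatches omit the key rather than sending a null. A minimal illustrative sketch of the resulting payload shapes (not part of the patch; the ids and model names are placeholders):

    # Illustrative only: payload shapes produced by PluginModelClient._dispatch_payload.
    from core.plugin.impl.model import PluginModelClient

    with_user = PluginModelClient._dispatch_payload(
        user_id="user-123",  # placeholder user id
        data={"provider": "openai", "model_type": "llm", "model": "example-model"},  # placeholder model
    )
    # -> {"data": {...}, "user_id": "user-123"}

    anonymous = PluginModelClient._dispatch_payload(
        user_id=None,
        data={"provider": "openai"},
    )
    # -> {"data": {"provider": "openai"}}   (no "user_id" key at all)
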
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +259,9 @@ class 
PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ 
-467,9 +474,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file mode 100644 index 0000000000..507d5fa4f5 --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from 
dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from dify_graph.model_runtime.runtime import ModelRuntime +from extensions.ext_redis import redis_client +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and a default user.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] | None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + 
credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, 
provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_moderation( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + text=text, + ) + + def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> 
str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = self._get_provider_short_name_alias(provider) + return declaration + + def _get_provider_schema(self, provider: str) -> ProviderEntity: + providers = self.fetch_model_providers() + provider_entity = next((item for item in providers if item.provider == provider), None) + if provider_entity is None: + provider_entity = next((item for item in providers if provider == item.provider_name), None) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + return provider_entity + + def _get_schema_cache_key( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> str: + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + return cache_key + ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 0000000000..a5d1c81d5d --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.plugin.impl.model import PluginModelClient + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: + """Create a tenant-bound model provider factory for service flows.""" + from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id)) + + +def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> 
ProviderManager: + """Create a tenant-bound provider manager for service flows.""" + from core.provider_manager import ProviderManager + + return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id)) + + +def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: + """Create a tenant-bound model manager for service flows.""" + from core.model_manager import ModelManager + + return ModelManager( + provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id), + ) diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 667f5ef099..a6aed8eeb5 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,50 +1,7 @@ -from typing import Literal +from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """ - Chat Message. - """ - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """ - Completion Model Prompt Template. - """ - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """ - Memory Config. - """ - - class RolePrefix(BaseModel): - """ - Role Prefix. - """ - - user: str - assistant: str - - class WindowConfig(BaseModel): - """ - Window Config. - """ - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..d0da840db5 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import contextlib import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -53,15 +55,22 @@ from models.provider import ( from models.provider_ids import ModelProviderID from services.feature_service import FeatureService +if TYPE_CHECKING: + from dify_graph.model_runtime.runtime import ModelRuntime + class ProviderManager: """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + ProviderManager manages tenant-scoped model provider configuration. + + The runtime adapter is injected by the composition layer so this class stays + focused on configuration assembly instead of constructing plugin runtimes. 
""" - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +136,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -321,7 +330,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963a..66e46df6cb 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -106,9 +106,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9d..2c449e9163 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae..300a2e14ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..3851a27d14 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -72,7 +72,7 @@ class DatasetDocumentStore: max_position = 0 embedding_model = None if self._dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = 
model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a055..3c706d800b 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError from configs import dify_config from core.entities.embedding_type import EmbeddingInputType -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.rag.embedding.embedding_base import Embeddings from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -22,8 +22,20 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): def __init__(self, model_instance: ModelInstance, user: str | None = None): - self._model_instance = model_instance - self._user = user + self._model_instance = self._bind_model_instance(model_instance, user) + + @staticmethod + def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance: + if user is None: + return model_instance + + tenant_id = model_instance.provider_model_bundle.configuration.tenant_id + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=model_instance.model_type_instance.model_type, + model=model_instance.model_name, + ) def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..80b0e20784 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts 
import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52..44983588bd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance): self.rerank_model_instance = rerank_model_instance + def _get_invoke_model_instance(self, user: str | None) -> ModelInstance: + if user is None: + return self.rerank_model_instance + + tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance( + tenant_id=tenant_id, + provider=self.rerank_model_instance.provider, + model_type=ModelType.RERANK, + model=self.rerank_model_instance.model_name, + ) + def run( self, query: str, @@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner): docs.append(document.page_content) unique_documents.append(document) - rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + rerank_result = self._get_invoke_model_instance(user).invoke_rerank( + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner): "content": file_query, "content_type": DocType.IMAGE, } - rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d1..c270bace6e 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, 
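The rerank and embedding call sites above follow the same pattern as the rest of the patch: the tenant and user are bound once via `ModelManager.for_tenant`, and the per-call `user=` argument disappears from `invoke_*`. A minimal sketch of the new calling convention, assuming placeholder tenant/user ids and provider/model names (illustrative only, not part of the patch):

    # Illustrative sketch of the tenant/user-bound invocation pattern introduced by this patch.
    # "tenant-id", "user-id", "example-rerank-provider" and "example-rerank-model" are placeholders.
    from core.model_manager import ModelManager
    from dify_graph.model_runtime.entities.model_entities import ModelType

    manager = ModelManager.for_tenant(tenant_id="tenant-id", user_id="user-id")
    rerank_instance = manager.get_model_instance(
        tenant_id="tenant-id",
        provider="example-rerank-provider",
        model_type=ModelType.RERANK,
        model="example-rerank-model",
    )

    # user= is no longer passed per call; the identity travels with the bound runtime.
    result = rerank_instance.invoke_rerank(
        query="what changed in this release?",
        docs=["doc one", "doc two"],
        score_threshold=None,
        top_n=2,
    )
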
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..3b808b36dc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -160,7 +160,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -387,7 +387,7 @@ class DatasetRetrieval: model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) @@ -1411,7 +1411,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1430,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1532,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1542,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a7..3dddce449e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -3,13 +3,14 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and 
accurately as possible. You have access to the following tools:""" @@ -150,19 +151,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032..8966e4edca 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -201,8 +201,10 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, + app_id: str | None = None, ): self._tenant_id = tenant_id + self._app_id = app_id def _delivery_method_to_model( self, @@ -340,6 +342,9 @@ class HumanInputFormRepositoryImpl: def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = params.app_id or self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,7 +364,7 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, + app_id=app_id, workflow_run_id=params.workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index dacc49c746..0860df2225 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -22,6 +22,9 @@ class ASRTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: + if not self.runtime: + raise ValueError("Runtime is required") + runtime = self.runtime file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore yield self.create_text_message("not a valid audio file") @@ -29,20 +32,19 @@ class ASRTool(BuiltinTool): audio_binary = io.BytesIO(download(file)) # type: ignore audio_binary.name = "temp.mp3" provider, model = tool_parameters.get("model").split("#") # type: ignore - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=runtime.tenant_id, provider=provider, model_type=ModelType.SPEECH2TEXT, model=model, ) - text = model_instance.invoke_speech2text( - file=audio_binary, - user=user_id, - ) + text = model_instance.invoke_speech2text(file=audio_binary) yield self.create_text_message(text) def get_available_models(self) -> list[tuple[str, str]]: + if not self.runtime: + raise 
ValueError("Runtime is required") model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type( tenant_id=self.runtime.tenant_id, model_type="speech2text" diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 7818bff0ab..0b32b74e95 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -20,13 +20,14 @@ class TTSTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +40,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..e8b69c35f5 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8f958563bd..80d7a2c226 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -37,7 +37,7 @@ class ModelInvocationUtils: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -65,7 +65,7 @@ class ModelInvocationUtils: """ # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -92,7 +92,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) # get model instance 
model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -136,7 +136,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index ab34263a79..2afe9aebd6 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -9,8 +9,8 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -20,8 +20,17 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, @@ -30,11 +39,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import ( from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.nodes.base.node import Node @@ -44,13 +51,10 @@ from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.nodes.document_extractor import UnstructuredApiConfig from dify_graph.nodes.http_request import build_http_request_config from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError from dify_graph.nodes.llm.protocols import TemplateRenderer from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) +from dify_graph.template_rendering import CodeExecutorJinja2TemplateRenderer from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -264,11 +268,22 @@ class 
DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(self._dify_context) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime(self._dify_context) + self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +299,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -321,22 +336,32 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "form_repository": HumanInputFormRepositoryImpl( + tenant_id=self._dify_context.tenant_id, + app_id=self._dify_context.app_id, + ), + "runtime": self._human_input_runtime, }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +370,26 @@ class DifyNodeFactory(NodeFactory): 
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +423,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +438,33 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: node_init_kwargs["template_renderer"] = self._llm_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not 
exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 0000000000..107e55993a --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,532 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.signature import sign_upload_file +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.file import FileTransferMethod, FileType +from dify_graph.model_runtime.entities import LLMMode +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, apply_debug_email_recipient +from dify_graph.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.runtime import ( + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from dify_graph.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from extensions.ext_database import db +from factories import 
file_factory +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + from dify_graph.file import File + from dify_graph.nodes.llm.file_saver import LLMFileSaver + from dify_graph.nodes.tool.entities import ToolNodeData + + +def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext: + if isinstance(run_context, DifyRunContext): + return run_context + + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + +class DifyFileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def build_from_mapping(self, *, mapping: Mapping[str, Any]): + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self._run_context.tenant_id, + ) + + +class DifyPreparedLLM(PreparedLLMProtocol): + """Workflow-layer adapter that hides the full `ModelInstance` API from `dify_graph` nodes.""" + + def __init__(self, model_instance: ModelInstance) -> None: + self._model_instance = model_instance + + @property + def provider(self) -> str: + return self._model_instance.provider + + @property + def model_name(self) -> str: + return self._model_instance.model_name + + @property + def parameters(self) -> Mapping[str, Any]: + return self._model_instance.parameters + + @property + def stop(self) -> Sequence[str] | None: + return self._model_instance.stop + + def get_model_schema(self) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( + self._model_instance.model_name, + self._model_instance.credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {self._model_instance.model_name}") + return model_schema + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: + return self._model_instance.get_llm_num_tokens(prompt_messages) + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + return self._model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=dict(model_parameters), + tools=list(tools or []), + stop=list(stop or []), + stream=stream, + ) + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + return invoke_llm_with_structured_output( + provider=self.provider, + model_schema=self.get_model_schema(), + model_instance=self._model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + stop=list(stop or []), + stream=stream, + ) + + def is_structured_output_parse_error(self, error: Exception) 
-> bool: + return isinstance(error, OutputParserError) + + +class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: + return PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_mode, + prompt_messages=prompt_messages, + ) + + +class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): + """Resolve retriever attachments through Dify persistence and return graph file references.""" + + def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + self._file_reference_factory = file_reference_factory + + def load(self, *, segment_id: str) -> Sequence[File]: + with Session(db.engine, expire_on_commit=False) as session: + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.segment_id == segment_id) + ).all() + + return [ + self._file_reference_factory.build_from_mapping( + mapping={ + "id": upload_file.id, + "filename": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "remote_url": upload_file.source_url, + "related_id": upload_file.id, + "upload_file_id": upload_file.id, + "size": upload_file.size, + "storage_key": upload_file.key, + "url": sign_upload_file(upload_file.id, upload_file.extension), + } + ) + for _, upload_file in attachments_with_bindings + ] + + +class DifyToolFileManager(ToolFileManagerProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._manager = ToolFileManager() + + def create_file_by_raw( + self, + *, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: + return self._manager.create_file_by_raw( + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=conversation_id, + file_binary=file_binary, + mimetype=mimetype, + filename=filename, + ) + + def get_file_generator_by_tool_file_id(self, tool_file_id: str): + return self._manager.get_file_generator_by_tool_file_id(tool_file_id) + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeSpec: + provider_type: CoreToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None = None + + +class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._file_reference_factory = DifyFileReferenceFactory(self._run_context) + + @property + def file_reference_factory(self) -> FileReferenceFactoryProtocol: + return self._file_reference_factory + + def build_file_reference(self, *, mapping: Mapping[str, Any]): + return self._file_reference_factory.build_from_mapping(mapping=mapping) + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool, + ) -> ToolRuntimeHandle: + try: + tool_runtime = ToolManager.get_workflow_tool_runtime( + self._run_context.tenant_id, + self._run_context.app_id, + node_id, + self._build_tool_runtime_spec(node_data), + self._run_context.invoke_from, + variable_pool, + ) + except ToolNodeError: + raise + except Exception as 
exc: + raise ToolRuntimeResolutionError(str(exc)) from exc + + return ToolRuntimeHandle(raw=tool_runtime) + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: + tool = self._tool_from_handle(tool_runtime) + return [ + ToolRuntimeParameter(name=parameter.name, required=parameter.required) + for parameter in (tool.get_merged_runtime_parameters() or []) + ] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + conversation_id: str | None, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + tool = self._tool_from_handle(tool_runtime) + callback = DifyWorkflowCallbackHandler() + + try: + messages = ToolEngine.generic_invoke( + tool=tool, + tool_parameters=dict(tool_parameters), + user_id=self._run_context.user_id, + workflow_tool_callback=callback, + workflow_call_depth=workflow_call_depth, + app_id=self._run_context.app_id, + conversation_id=conversation_id, + ) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=None, + ) + + return self._adapt_messages(transformed_messages, provider_name=provider_name) + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: + latest = getattr(tool_runtime.raw, "latest_usage", None) + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + return LLMUsage.model_validate(latest) + return LLMUsage.empty_usage() + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + icon = default_icon + icon_dark = None + + manager = PluginInstaller() + plugins = manager.list_plugins(self._run_context.tenant_id) + try: + current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name) + icon = current_plugin.declaration.icon + except StopIteration: + pass + + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self._run_context.user_id, + self._run_context.tenant_id, + ) + if provider.name == provider_name + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + return icon, icon_dark + + @staticmethod + def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: + return cast("Tool", tool_runtime.raw) + + @staticmethod + def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + return _WorkflowToolRuntimeSpec( + provider_type=CoreToolProviderType(node_data.provider_type.value), + provider_id=node_data.provider_id, + tool_name=node_data.tool_name, + tool_configurations=dict(node_data.tool_configurations), + credential_id=node_data.credential_id, + ) + + def _adapt_messages( + self, + messages: Generator[CoreToolInvokeMessage, None, None], + *, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + try: + for message in messages: + yield self._convert_message(message) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage: + graph_message_type = ToolRuntimeMessage.MessageType(message.type.value) + 
graph_message = self._convert_message_payload(message.message) + graph_meta = message.meta.copy() if message.meta is not None else None + return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta) + + def _convert_message_payload( + self, + message: CoreToolInvokeMessage.TextMessage + | CoreToolInvokeMessage.JsonMessage + | CoreToolInvokeMessage.BlobChunkMessage + | CoreToolInvokeMessage.BlobMessage + | CoreToolInvokeMessage.LogMessage + | CoreToolInvokeMessage.FileMessage + | CoreToolInvokeMessage.VariableMessage + | CoreToolInvokeMessage.RetrieverResourceMessage + | None, + ) -> ( + ToolRuntimeMessage.TextMessage + | ToolRuntimeMessage.JsonMessage + | ToolRuntimeMessage.BlobChunkMessage + | ToolRuntimeMessage.BlobMessage + | ToolRuntimeMessage.LogMessage + | ToolRuntimeMessage.FileMessage + | ToolRuntimeMessage.VariableMessage + | ToolRuntimeMessage.RetrieverResourceMessage + | None + ): + if message is None: + return None + + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + if isinstance(message, CoreToolInvokeMessage.TextMessage): + return ToolRuntimeMessage.TextMessage(text=message.text) + if isinstance(message, CoreToolInvokeMessage.JsonMessage): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + if isinstance(message, CoreToolInvokeMessage.BlobMessage): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + if isinstance(message, CoreToolInvokeMessage.FileMessage): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + if isinstance(message, CoreToolInvokeMessage.VariableMessage): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + if isinstance(message, CoreToolInvokeMessage.LogMessage): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + + @staticmethod + def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: + if isinstance(exc, ToolNodeError): + return exc + if isinstance(exc, PluginInvokeError): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + if isinstance(exc, PluginDaemonClientSideError): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + if isinstance(exc, ToolInvokeError): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + return ToolRuntimeInvocationError(str(exc)) + + +class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): + def __init__(self, run_context: 
Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def invoke_source(self) -> str: + invoke_from = self._run_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + + def apply_delivery_runtime( + self, + *, + methods: Sequence[DeliveryChannelConfig], + ) -> Sequence[DeliveryChannelConfig]: + return [ + apply_debug_email_recipient( + method, + enabled=self.invoke_source() == "debugger", + user_id=self._run_context.user_id, + ) + for method in methods + ] + + def console_actor_id(self) -> str | None: + return self._run_context.user_id + + +def build_dify_llm_file_saver( + *, + run_context: Mapping[str, Any] | DifyRunContext, + http_client: HttpClientProtocol, +) -> LLMFileSaver: + from dify_graph.nodes.llm.file_saver import FileSaverImpl + + return FileSaverImpl( + tool_file_manager=DifyToolFileManager(run_context), + file_reference_factory=DifyFileReferenceFactory(run_context), + http_client=http_client, + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5699ccf404..202f6d9921 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent @@ -59,7 +60,7 @@ class AgentNode(Node[AgentNodeData]): return "1" def populate_start_event(self, event) -> None: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { "name": self.node_data.agent_strategy_name, "icon": self._presentation_provider.get_icon( @@ -71,7 +72,7 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) try: strategy = self._strategy_resolver.resolve( @@ -97,6 +98,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, ) @@ -106,6 +108,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, for_log=True, diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9..3c8f9bb2cd 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -14,7 +14,7 @@ from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.plugin.entities.request import InvokeCredentials -from core.provider_manager import ProviderManager +from 
core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager from dify_graph.enums import SystemVariableKey @@ -38,6 +38,7 @@ class AgentRuntimeSupport: node_data: AgentNodeData, strategy: ResolvedAgentStrategy, tenant_id: str, + user_id: str, app_id: str, invoke_from: Any, for_log: bool = False, @@ -174,7 +175,11 @@ class AgentRuntimeSupport: value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) - model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + model_instance, model_schema = self.fetch_model( + tenant_id=tenant_id, + user_id=user_id, + value=value, + ) history_prompt_messages = [] if node_data.memory: memory = self.fetch_memory( @@ -232,8 +237,14 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() + def fetch_model( + self, + *, + tenant_id: str, + user_id: str, + value: dict[str, Any], + ) -> tuple[ModelInstance, AIModelEntity | None]: + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), @@ -246,7 +257,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 44f4a23a5a..9a80dd9af6 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,6 +1,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError @@ -50,8 +51,7 @@ class DatasourceNode(Node[DatasourceNodeData]): """ Run the datasource node """ - - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 80f59140be..871bd2e968 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from 
core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.entities import GraphInitParams @@ -160,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 317844cbda..c4a98345c6 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -9,9 +9,9 @@ from dify_graph.enums import NodeExecutionType from dify_graph.file import FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol from dify_graph.variables.types import SegmentType from dify_graph.variables.variables import FileVariable -from factories import file_factory from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData @@ -23,6 +23,13 @@ class TriggerWebhookNode(Node[WebhookData]): node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT + _file_reference_factory: FileReferenceFactoryProtocol + + def post_init(self) -> None: + from core.workflow.node_runtime import DifyFileReferenceFactory + + self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -70,7 +77,6 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - dify_ctx = self.require_dify_context() related_id = file.get("related_id") transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -84,10 +90,7 @@ class TriggerWebhookNode(Node[WebhookData]): file["datasource_file_id"] = related_id try: - file_obj = file_factory.build_from_mapping( - mapping=file, - tenant_id=dify_ctx.tenant_id, - ) + file_obj = self._file_reference_factory.build_from_mapping(mapping=file) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) except ValueError: diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py index f785d58a52..160be9b7b2 100644 --- a/api/dify_graph/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -3,8 +3,6 @@ from typing import Any from pydantic import BaseModel, Field -DIFY_RUN_CONTEXT_KEY = "_dify" - class GraphInitParams(BaseModel): """GraphInitParams encapsulates the configurations and contextual information diff --git a/api/dify_graph/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py index 459ac46415..293a66c56e 100644 --- a/api/dify_graph/entities/workflow_execution.py +++ 
b/api/dify_graph/entities/workflow_execution.py @@ -14,7 +14,7 @@ from typing import Any from pydantic import BaseModel, Field from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from libs.datetime_utils import naive_utc_now +from dify_graph.utils.datetime_utils import naive_utc_now class WorkflowExecution(BaseModel): diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py index 85117583e0..add69ad884 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -10,7 +10,6 @@ from pydantic import TypeAdapter from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState from dify_graph.nodes.base.node import Node -from libs.typing import is_str from .edge import Edge from .validation import get_graph_validator @@ -102,7 +101,7 @@ class Graph: source = edge_config.get("source") target = edge_config.get("target") - if not is_str(source) or not is_str(target): + if not isinstance(source, str) or not isinstance(target, str): continue # Create edge @@ -110,7 +109,7 @@ class Graph: edge_counter += 1 source_handle = edge_config.get("sourceHandle", "source") - if not is_str(source_handle): + if not isinstance(source_handle, str): continue edge = Edge( diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py index df19d6c03b..f9a026b5e2 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -1,9 +1,9 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from datetime import datetime +from typing import Any from pydantic import Field -from core.rag.entities.citation_metadata import RetrievalSourceMetadata from dify_graph.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -30,7 +30,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase): class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") + retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py index 97a99ea7ce..55e1775826 100644 --- a/api/dify_graph/model_runtime/entities/provider_entities.py +++ b/api/dify_graph/model_runtime/entities/provider_entities.py @@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel): class SimpleProviderEntity(BaseModel): """ - Simple model class for provider. + Simplified provider schema exposed to callers. + + `provider` is the canonical runtime identifier. `provider_name` is an optional + compatibility alias for short-name lookups and is empty when no alias exists. """ provider: str + provider_name: str = "" label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None @@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel): class ProviderEntity(BaseModel): """ - Model class for provider. + Runtime-native provider schema. + + `provider` is the canonical runtime identifier. `provider_name` is a + compatibility alias for callers that still resolve providers by short name and + is empty when no alias exists. 
""" provider: str + provider_name: str = "" label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None @@ -153,6 +162,7 @@ class ProviderEntity(BaseModel): """ return SimpleProviderEntity( provider=self.provider, + provider_name=self.provider_name, label=self.label, icon_small=self.icon_small, supported_model_types=self.supported_model_types, diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py index 99709e1bcd..8a0bb5fac2 100644 --- a/api/dify_graph/model_runtime/entities/rerank_entities.py +++ b/api/dify_graph/model_runtime/entities/rerank_entities.py @@ -1,6 +1,13 @@ +from typing import TypedDict + from pydantic import BaseModel +class MultimodalRerankInput(TypedDict): + content: str + content_type: str + + class RerankDocument(BaseModel): """ Model class for rerank document. diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py index a0210c169d..7954b410d0 100644 --- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py +++ b/api/dify_graph/model_runtime/entities/text_embedding_entities.py @@ -1,10 +1,18 @@ from decimal import Decimal +from enum import StrEnum, auto from pydantic import BaseModel from dify_graph.model_runtime.entities.model_entities import ModelUsage +class EmbeddingInputType(StrEnum): + """Embedding request input variants understood by the model runtime.""" + + DOCUMENT = auto() + QUERY = auto() + + class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. diff --git a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py index ac7ae9925b..cfce0b0602 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py @@ -1,12 +1,5 @@ import decimal -import hashlib -import logging -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from redis import RedisError - -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from dify_graph.model_runtime.entities.model_entities import ( @@ -17,6 +10,7 @@ from dify_graph.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from extensions.ext_redis import redis_client - -logger = logging.getLogger(__name__) +from dify_graph.model_runtime.runtime import ModelRuntime -class AIModel(BaseModel): +class AIModel: """ - Base class for all models. + Runtime-facing base class for all model providers. + + This stays a regular Python class because instances hold live collaborators + such as the provider schema and runtime adapter rather than user input that + benefits from Pydantic validation. Subclasses must pin ``model_type`` via a + class attribute; the base class is not meant to be instantiated directly. 
""" - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) + model_type: ModelType + provider_schema: ProviderEntity + model_runtime: ModelRuntime + started_at: float - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) + def __init__( + self, + provider_schema: ProviderEntity, + model_runtime: ModelRuntime, + *, + started_at: float = 0, + ) -> None: + if getattr(type(self), "model_type", None) is None: + raise TypeError("AIModel subclasses must define model_type as a class attribute") + + self.model_type = type(self).model_type + self.provider_schema = provider_schema + self.model_runtime = model_runtime + self.started_at = started_at + + @property + def provider(self) -> str: + return self.provider_schema.provider + + @property + def provider_display_name(self) -> str: + return self.provider_schema.label.en_US @property def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. + Map model invoke error to unified error. - :return: Invoke error mapping + The key is the error type thrown to the caller, and the value contains + runtime-facing exception types that should be normalized to it. """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - return { InvokeConnectionError: [InvokeConnectionError], InvokeServerUnavailableError: [InvokeServerUnavailableError], InvokeRateLimitError: [InvokeRateLimitError], InvokeAuthorizationError: [InvokeAuthorizationError], InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], ValueError: [ValueError], } @@ -79,15 +89,18 @@ class AIModel(BaseModel): if invoke_error == InvokeAuthorizationError: return InvokeAuthorizationError( description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." + f"[{self.provider_display_name}] Incorrect model credentials provided, " + "please check and try again." 
) ) elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + return InvokeError( + description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" + ) else: return error - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") + return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: """ @@ -144,65 +157,13 @@ class AIModel(BaseModel): :param credentials: model credentials :return: model schema """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, + return self.model_runtime.get_model_schema( + provider=self.provider, + model_type=self.model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py index bf864ca227..5caac13c15 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py @@ -4,9 +4,6 @@ import uuid from collections.abc import Callable, Generator, Iterator, Sequence from typing import Union -from pydantic import ConfigDict - -from configs import dify_config from dify_graph.model_runtime.callbacks.base_callback import Callback from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage @@ -140,11 +137,9 @@ def _build_llm_result_from_chunks( ) -def _invoke_llm_via_plugin( +def _invoke_llm_via_runtime( *, - tenant_id: str, - user_id: str, - plugin_id: str, + 
llm_model: "LargeLanguageModel", provider: str, model: str, credentials: dict, @@ -154,25 +149,19 @@ def _invoke_llm_via_plugin( stop: Sequence[str] | None, stream: bool, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_llm( - tenant_id=tenant_id, - user_id=user_id, - plugin_id=plugin_id, + return llm_model.model_runtime.invoke_llm( provider=provider, model=model, credentials=credentials, model_parameters=model_parameters, prompt_messages=list(prompt_messages), tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) -def _normalize_non_stream_plugin_result( +def _normalize_non_stream_runtime_result( model: str, prompt_messages: Sequence[PromptMessage], result: Union[LLMResult, Iterator[LLMResultChunk]], @@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel): model_type: ModelType = ModelType.LLM - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, @@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ @@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if dify_config.DEBUG: + if logger.isEnabledFor(logging.DEBUG): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - result = _invoke_llm_via_plugin( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + result = _invoke_llm_via_runtime( + llm_model=self, + provider=self.provider, model=model, credentials=credentials, model_parameters=model_parameters, @@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel): ) if not stream: - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: @@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) @@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) elif isinstance(result, LLMResult): @@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) # Following https://github.com/langgenius/dify/issues/17799, @@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :return: """ - if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_llm_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - 
model_type=self.model_type.value, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - return 0 + return self.model_runtime.get_llm_num_tokens( + provider=self.provider, + model_type=self.model_type, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + ) def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py index 5fa3d1634b..6983971770 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py @@ -1,7 +1,5 @@ import time -from pydantic import ConfigDict - from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,30 +11,20 @@ class ModerationModel(AIModel): model_type: ModelType = ModelType.MODERATION - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: + def invoke(self, model: str, credentials: dict, text: str) -> bool: """ Invoke moderation model :param model: model name :param credentials: model credentials :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ self.started_at = time.perf_counter() try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_moderation( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_moderation( + provider=self.provider, model=model, credentials=credentials, text=text, diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py index 5da2b84b95..f448690677 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py @@ -1,5 +1,5 @@ from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,7 +18,6 @@ class RerankModel(AIModel): docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -29,18 +28,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, @@ -55,11 +47,10 @@ class RerankModel(AIModel): self, model: str, credentials: dict, - query: dict, - docs: 
list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke multimodal rerank model @@ -69,18 +60,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, diff --git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py index e69069a85d..0cfa5208e9 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py @@ -1,7 +1,5 @@ from typing import IO -from pydantic import ConfigDict - from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,28 +11,18 @@ class Speech2TextModel(AIModel): model_type: ModelType = ModelType.SPEECH2TEXT - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: """ Invoke speech to text model :param model: model name :param credentials: model credentials :param file: audio file - :param user: unique user id :return: text for given audio file """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_speech_to_text( + provider=self.provider, model=model, credentials=credentials, file=file, diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py index 3438da2ada..d03ff5f6c2 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,8 +1,5 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel): model_type: ModelType = ModelType.TEXT_EMBEDDING - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, credentials: dict, texts: list[str] | None = None, multimodel_documents: list[dict] | None = None, - user: str | None = None, input_type: EmbeddingInputType 
= EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ @@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel): :param credentials: model credentials :param texts: texts to embed :param files: files to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ - from core.plugin.impl.model import PluginModelClient - try: - plugin_model_manager = PluginModelClient() if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_text_embedding( + provider=self.provider, model=model, credentials=credentials, texts=texts, input_type=input_type, ) if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_embedding( + provider=self.provider, model=model, credentials=credentials, documents=multimodel_documents, @@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel): :param texts: texts to embed :return: """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_text_embedding_num_tokens( + provider=self.provider, model=model, credentials=credentials, texts=texts, diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py index 0656529f22..23f5640b37 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py @@ -1,8 +1,6 @@ import logging from collections.abc import Iterable -from pydantic import ConfigDict - from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -16,38 +14,25 @@ class TTSModel(AIModel): model_type: ModelType = ModelType.TTS - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, - tenant_id: str, credentials: dict, content_text: str, voice: str, - user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param voice: model timbre :param content_text: text content to be translated - :param user: unique user id :return: translated audio file """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_tts( + provider=self.provider, model=model, credentials=credentials, content_text=content_text, @@ -65,14 +50,8 @@ class TTSModel(AIModel): :param credentials: The credentials required to access the TTS model. :return: A list of voices supported by the TTS model. 
""" - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_tts_model_voices( + provider=self.provider, model=model, credentials=credentials, language=language, diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py index de0677a348..ae5c0ec7c4 100644 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -1,16 +1,7 @@ from __future__ import annotations -import hashlib -import logging from collections.abc import Sequence -from threading import Lock -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankM from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel +from dify_graph.model_runtime.runtime import ModelRuntime from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( ProviderCredentialSchemaValidator, ) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient + """Factory for provider schemas and model-type instances backed by a runtime adapter.""" - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() + def __init__(self, model_runtime: ModelRuntime): + if model_runtime is None: + raise ValueError("model_runtime is required.") + self.model_runtime = model_runtime def get_providers(self) -> Sequence[ProviderEntity]: """ - Get all providers - :return: list of providers + Get all providers. """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] + return list(self.get_model_providers()) - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: + def get_model_providers(self) -> Sequence[ProviderEntity]: """ - Get all plugin model providers - :return: list of plugin model providers + Get all model providers exposed by the runtime adapter. 
""" - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers + return self.model_runtime.fetch_model_providers() def get_provider_schema(self, provider: str) -> ProviderEntity: """ - Get provider schema - :param provider: provider name - :return: provider schema + Get provider schema. """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration + return self.get_model_provider(provider=provider) - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: + def get_model_provider(self, provider: str) -> ProviderEntity: """ - Get plugin model provider - :param provider: provider name - :return: provider schema + Get provider schema. """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: + provider_entity = self._resolve_provider(provider) + if provider_entity is None: raise ValueError(f"Invalid provider: {provider}") - return plugin_model_provider_entity + return provider_entity def provider_credentials_validate(self, *, provider: str, credentials: dict): """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - :return: + Validate provider credentials. 
""" - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) + provider_entity = self.get_model_provider(provider=provider) - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema + provider_credential_schema = provider_entity.provider_credential_schema if not provider_credential_schema: raise ValueError(f"Provider {provider} does not have provider_credential_schema") - # validate provider credential schema validator = ProviderCredentialSchemaValidator(provider_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, + self.model_runtime.validate_provider_credentials( + provider=provider_entity.provider, credentials=filtered_credentials, ) @@ -141,33 +76,20 @@ class ModelProviderFactory: def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. - :return: + Validate model credentials. """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) + provider_entity = self.get_model_provider(provider=provider) - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema + model_credential_schema = provider_entity.model_credential_schema if not model_credential_schema: raise ValueError(f"Provider {provider} does not have model_credential_schema") - # validate model credential schema validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, + self.model_runtime.validate_model_credentials( + provider=provider_entity.provider, + model_type=model_type, model=model, credentials=filtered_credentials, ) @@ -178,65 +100,16 @@ class ModelProviderFactory: self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None ) -> AIModelEntity | None: """ - Get model schema + Get model schema. 
""" - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_model_schema( + provider=provider_entity.provider, + model_type=model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_models( self, *, @@ -245,143 +118,56 @@ class ModelProviderFactory: provider_configs: list[ProviderConfig] | None = None, ) -> list[SimpleProviderEntity]: """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models + Get all models for given model type. 
""" - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: + for provider_entity in self.get_model_providers(): + if provider and not self._matches_provider(provider_entity, provider): continue - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models + if model_type and model_type not in provider_entity.supported_model_types: + continue + simple_provider_schema = provider_entity.to_simple_provider() + if model_type is not None: + simple_provider_schema.models = [ + model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type + ] providers.append(simple_provider_schema) return providers def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance + Get model type instance by provider name and model type. """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } + provider_schema = self.get_model_provider(provider) if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) + return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TEXT_EMBEDDING: + return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.RERANK: + return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.SPEECH2TEXT: + return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.MODERATION: + return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TTS: + return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) raise ValueError(f"Unsupported model type: {model_type}") def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ - 
Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon + Get provider icon. """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") + def _resolve_provider(self, provider: str) -> ProviderEntity | None: + return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name + @staticmethod + def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: + return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/dify_graph/model_runtime/runtime.py b/api/dify_graph/model_runtime/runtime.py new file mode 100644 index 0000000000..577bb6f387 --- /dev/null +++ b/api/dify_graph/model_runtime/runtime.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from collections.abc import Generator, Iterable, Sequence +from typing import IO, Any, Protocol, Union, runtime_checkable + +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult + + +@runtime_checkable +class 
ModelRuntime(Protocol): + """Port for provider discovery, schema lookup, and model execution. + + `provider` is the model runtime's canonical provider identifier. Adapters may + derive transport-specific details from it, but those details stay outside + this boundary. + """ + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: ... + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: ... + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: ... + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: ... + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: ... + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: ... + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: ... + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: ... 
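Note (editorial, not part of the change set): the new runtime.py above only declares the ModelRuntime port. As a minimal sketch of how a concrete adapter could satisfy it and be handed to the reworked ModelProviderFactory, consider the hypothetical StaticModelRuntime below. It is an illustration only: the class name, its fixed provider list, and the `model` attribute it reads from AIModelEntity are assumptions; the protocol method signatures and the ModelProviderFactory(model_runtime=...) constructor are taken from this diff.

from collections.abc import Sequence
from typing import Any

from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory


class StaticModelRuntime:
    """Hypothetical adapter that serves a fixed provider list; real invocations are omitted."""

    def __init__(self, providers: Sequence[ProviderEntity]) -> None:
        self._providers = list(providers)

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        # Provider discovery: return the declared entities as-is.
        return self._providers

    def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
        # Accept anything; a real adapter would verify against the provider backend here.
        return None

    def get_model_schema(
        self, *, provider: str, model_type: ModelType, model: str, credentials: dict[str, Any]
    ) -> AIModelEntity | None:
        # Serve the schema straight from the declared provider entities.
        # Assumes AIModelEntity exposes `model` alongside `model_type`.
        for entity in self._providers:
            if entity.provider != provider:
                continue
            for schema in entity.models:
                if schema.model_type == model_type and schema.model == model:
                    return schema
        return None

    # The remaining ModelRuntime methods (invoke_llm, invoke_tts, invoke_rerank, ...)
    # would be implemented analogously and are left out of this sketch.


# Wiring: the factory only needs an object with the ModelRuntime shape.
# factory = ModelProviderFactory(model_runtime=StaticModelRuntime(providers=[...]))
# llm = factory.get_model_type_instance(provider="some/provider", model_type=ModelType.LLM)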
diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 56b46a5894..7ecea71356 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -6,13 +6,12 @@ from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import ( ErrorStrategy, NodeExecutionType, @@ -60,7 +59,7 @@ from dify_graph.node_events import ( StreamCompletedEvent, ) from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now +from dify_graph.utils.datetime_utils import naive_utc_now NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -68,23 +67,6 @@ _MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -177,8 +159,9 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under the - # canonical workflow namespaces. + # Only treat nodes from the base dify_graph package as production + # registrations. Higher-layer packages may still register subclasses, + # but dify_graph itself should not know their module identities. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). 
module_name = getattr(cls, "__module__", "") @@ -186,7 +169,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): + if module_name.startswith("dify_graph.nodes."): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -299,25 +282,6 @@ class Node(Generic[NodeDataT]): raise ValueError(f"run_context missing required key: {key}") return value - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - @property def execution_id(self) -> str: return self._node_execution_id @@ -793,16 +757,11 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] return NodeRunRetrieverResourceEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, - retriever_resources=retriever_resources, + retriever_resources=event.retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 3e5253d809..d1cfbb0b44 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -11,9 +11,13 @@ from dify_graph.nodes.base import variable_template_parser from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.protocols import ( + FileManagerProtocol, + FileReferenceFactoryProtocol, + HttpClientProtocol, + ToolFileManagerProtocol, +) from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory from .config import build_http_request_config, resolve_http_request_config from .entities import ( @@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): http_client: HttpClientProtocol, tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], file_manager: FileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, ) -> None: super().__init__( id=id, @@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory self._file_manager = file_manager + self._file_reference_factory = file_reference_factory @classmethod def get_default_config(cls, filters: Mapping[str, object] 
| None = None) -> Mapping[str, object]: @@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ - dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -237,20 +242,16 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, conversation_id=None, file_binary=content, mimetype=mime_type, ) - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, + file = self._file_reference_factory.build_from_mapping( + mapping={ + "tool_file_id": tool_file.id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } ) files.append(file) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 794e33d92e..575f918237 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -15,15 +15,16 @@ from dify_graph.node_events import ( from dify_graph.node_events.base import NodeEventBase from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.base.node import Node +from dify_graph.nodes.runtime import HumanInputNodeRuntimeProtocol from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) +from dify_graph.utils.datetime_utils import naive_utc_now from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient +from .entities import DeliveryChannelConfig, HumanInputNodeData from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: @@ -68,6 +69,7 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", form_repository: HumanInputFormRepository, + runtime: HumanInputNodeRuntimeProtocol | None = None, ) -> None: super().__init__( id=id, @@ -76,6 +78,9 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_runtime_state=graph_runtime_state, ) self._form_repository = form_repository + if runtime is None: + raise ValueError("runtime is required") + self._runtime = runtime @classmethod def version(cls) -> str: @@ -171,25 +176,14 @@ class HumanInputNode(Node[HumanInputNodeData]): return self._node_data.is_webapp_enabled() def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() invoke_from = self._invoke_from_value() enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] + return self._runtime.apply_delivery_runtime(methods=enabled_methods) def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if 
isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) + return self._runtime.invoke_source() def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: node_data = self._node_data @@ -224,11 +218,9 @@ class HumanInputNode(Node[HumanInputNodeData]): """ repo = self._form_repository form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() if form is None: display_in_ui = self._display_in_ui() params = FormCreateParams( - app_id=dify_ctx.app_id, workflow_execution_id=self._workflow_execution_id, node_id=self.id, form_config=self._node_data, @@ -238,7 +230,7 @@ class HumanInputNode(Node[HumanInputNodeData]): resolved_default_values=self.resolve_default_values(), console_recipient_required=self._should_require_console_recipient(), console_creator_account_id=( - dify_ctx.user_id + self._runtime.console_actor_id() if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} else None ), diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 033ec8672f..769ef11b62 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -34,10 +34,10 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from dify_graph.runtime import VariablePool +from dify_graph.utils.datetime_utils import naive_utc_now from dify_graph.variables import IntegerVariable, NoneSegment from dify_graph.variables.segments import ArrayAnySegment, ArraySegment from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now from .exc import ( InvalidIteratorValueError, diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 6ca01a21da..24b6035c44 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig class ModelConfig(BaseModel): diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index 50e52a3b6f..7cf4db6ea8 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -2,10 +2,8 @@ import mimetypes import typing as tp from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol class LLMFileSaver(tp.Protocol): @@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol): class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str 
+ _tool_file_manager: ToolFileManagerProtocol + _file_reference_factory: FileReferenceFactoryProtocol - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id + def __init__( + self, + *, + tool_file_manager: ToolFileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, + http_client: HttpClientProtocol, + ): + self._tool_file_manager = tool_file_manager + self._file_reference_factory = file_reference_factory self._http_client = http_client - def _get_tool_file_manager(self): - return ToolFileManager() - def save_remote_url(self, url: str, file_type: FileType) -> File: http_response = self._http_client.get(url) http_response.raise_for_status() @@ -83,10 +84,7 @@ class FileSaverImpl(LLMFileSaver): file_type: FileType, extension_override: str | None = None, ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, + tool_file = self._tool_file_manager.create_file_by_raw( # TODO(QuantumGhost): what is conversation id? conversation_id=None, file_binary=data, @@ -94,19 +92,18 @@ class FileSaverImpl(LLMFileSaver): ) extension_override = _validate_extension_override(extension_override) extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, + return self._file_reference_factory.build_from_mapping( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": tool_file.name, + "extension": extension, + "mime_type": mime_type, + "size": len(data), + "tool_file_id": tool_file.id, + "related_id": tool_file.id, + "storage_key": tool_file.file_key, + } ) diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 2be391a424..9fb32e0939 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -1,9 +1,8 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Mapping, Sequence +from typing import Any, Protocol, TypeAlias, cast -from core.model_manager import ModelInstance from dify_graph.file import FileType, file_manager from dify_graph.file.models import File from dify_graph.model_runtime.entities import ( @@ -35,15 +34,36 @@ from .exc import ( TemplateTypeNotSupportError, ) from .protocols import TemplateRenderer +from .runtime_protocols import PreparedLLMProtocol -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) +class _LegacyModelInstance(Protocol): + model_type_instance: object + model_name: str + credentials: object + parameters: Mapping[str, Any] + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... 
+ + +PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance + + +def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity: + get_model_schema = getattr(model_instance, "get_model_schema", None) + if callable(get_model_schema): + model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema() + else: + legacy_model_instance = cast(_LegacyModelInstance, model_instance) + credentials = legacy_model_instance.credentials + if isinstance(credentials, Mapping): + credentials = dict(credentials) + model_schema = cast(LargeLanguageModel, legacy_model_instance.model_type_instance).get_model_schema( + legacy_model_instance.model_name, + credentials, + ) if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") return model_schema @@ -116,7 +136,7 @@ def fetch_prompt_messages( sys_files: Sequence[File], context: str | None = None, memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedModelInstance, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -391,7 +411,7 @@ def combine_message_content_with_role( raise NotImplementedError(f"Role {role} is not supported") -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: +def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: PreparedModelInstance) -> int: rest_tokens = 2000 runtime_model_schema = fetch_model_schema(model_instance=model_instance) runtime_model_parameters = model_instance.parameters @@ -421,7 +441,7 @@ def handle_memory_chat_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedModelInstance, ) -> Sequence[PromptMessage]: if not memory or not memory_config: return [] @@ -436,7 +456,7 @@ def handle_memory_completion_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedModelInstance, ) -> str: if not memory or not memory_config: return "" diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5ed90ed7e3..0099a68110 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -7,16 +7,8 @@ import logging import re import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict @@ -27,10 +19,11 @@ from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, 
FileTransferMethod, FileType +from dify_graph.file import File, FileType, file_manager from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, TextPromptMessageContent, ) from dify_graph.model_runtime.entities.llm_entities import ( @@ -41,7 +34,14 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import ( @@ -55,19 +55,23 @@ from dify_graph.node_events import ( from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.prompt_entities import CompletionModelPromptTemplate, MemoryConfig from dify_graph.runtime import VariablePool +from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError from dify_graph.variables import ( ArrayFileSegment, ArraySegment, + FileSegment, NoneSegment, ObjectSegment, StringSegment, ) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile from . 
import llm_utils from .entities import ( @@ -79,9 +83,12 @@ from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) -from .file_saver import FileSaverImpl, LLMFileSaver +from .file_saver import LLMFileSaver if TYPE_CHECKING: from dify_graph.file.models import File @@ -101,11 +108,11 @@ class LLMNode(Node[LLMNodeData]): _file_outputs: list[File] _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance + _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None + _prompt_message_serializer: PromptMessageSerializerProtocol + _jinja2_template_renderer: Jinja2TemplateRenderer | None + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer def __init__( self, @@ -114,13 +121,15 @@ class LLMNode(Node[LLMNodeData]): graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol, + retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ): super().__init__( id=id, @@ -131,20 +140,14 @@ class LLMNode(Node[LLMNodeData]): # LLM file outputs, used for MultiModal outputs. 
self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory - self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer + self._retriever_attachment_loader = retriever_attachment_loader + self._jinja2_template_renderer = jinja2_template_renderer @classmethod def version(cls) -> str: @@ -230,7 +233,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, context_files=context_files, - template_renderer=self._template_renderer, + jinja2_template_renderer=self._jinja2_template_renderer, ) # handle invoke result @@ -238,7 +241,6 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -281,7 +283,7 @@ class LLMNode(Node[LLMNodeData]): process_data = { "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=self.node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -349,10 +351,9 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, - user_id: str, structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, @@ -363,35 +364,28 @@ class LLMNode(Node[LLMNodeData]): ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: model_parameters = model_instance.parameters invoke_model_parameters = dict(model_parameters) - - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( structured_output=structured_output or {}, ) request_start_time = time.perf_counter() - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, + invoke_result = model_instance.invoke_llm_with_structured_output( prompt_messages=prompt_messages, json_schema=output_schema, model_parameters=invoke_model_parameters, - stop=list(stop or []), + stop=stop, stream=True, - user=user_id, ) else: request_start_time = time.perf_counter() invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), + prompt_messages=prompt_messages, model_parameters=invoke_model_parameters, - stop=list(stop or []), + tools=None, + stop=stop, stream=True, - user=user_id, ) return LLMNode.handle_invoke_result( @@ -400,6 +394,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs=file_outputs, node_id=node_id, node_type=node_type, + model_instance=model_instance, reasoning_format=reasoning_format, request_start_time=request_start_time, ) @@ -412,6 +407,7 @@ class LLMNode(Node[LLMNodeData]): 
file_outputs: list[File], node_id: str, node_type: NodeType, + model_instance: PreparedLLMProtocol | object, reasoning_format: Literal["separated", "tagged"] = "tagged", request_start_time: float | None = None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: @@ -483,8 +479,14 @@ class LLMNode(Node[LLMNodeData]): usage = result.delta.usage if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") + except Exception as e: + if hasattr(model_instance, "is_structured_output_parse_error") and cast( + PreparedLLMProtocol, model_instance + ).is_structured_output_parse_error(e): + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + if type(e).__name__ == "OutputParserError": + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + raise # Extract reasoning content from tags in the main text full_text = full_text_buffer.getvalue() @@ -687,30 +689,8 @@ class LLMNode(Node[LLMNodeData]): segment_id = retriever_resource.get("segment_id") if not segment_id: continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) + if self._retriever_attachment_loader is not None: + context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) yield RunRetrieverResourceEvent( retriever_resources=original_retriever_resource, context=context_str.strip(), @@ -755,7 +735,7 @@ class LLMNode(Node[LLMNodeData]): sys_files: Sequence[File], context: str | None = None, memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -764,24 +744,186 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + 
+ if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if sys_query: + message = LLMNodeChatModelMessage( + text=sys_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + # For issue #11247 - Check if prompt content is a string or a list + if isinstance(prompt_content, str): + prompt_content = str(prompt_content) + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + + # Add current query to the prompt message + if sys_query: + if isinstance(prompt_content, str): + prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + # The sys_files will be deprecated later + if vision_enabled and sys_files: + file_prompts = [] + for file in sys_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + 
prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Remove empty messages and filter unsupported content + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message.content, list): + prompt_message_content: list[PromptMessageContentUnionTypes] = [] + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_schema.features + ) + ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue + filtered_prompt_messages.append(prompt_message) + + if len(filtered_prompt_messages) == 0: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." 
+ ) + + return filtered_prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping( @@ -881,16 +1023,61 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages @staticmethod def handle_blocking_result( @@ -1027,5 +1214,153 @@ class LLMNode(Node[LLMNodeData]): return self.node_data.retry_config.retry_enabled @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance + + +def _combine_message_content_with_role( + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole +): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + case _: + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None, +): + if not 
template: + return "" + + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + if jinja2_template_renderer is None: + raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") + return jinja2_template_renderer.render_template(template, jinja2_inputs) + + +def _calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: + rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters + + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in runtime_model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> Sequence[PromptMessage]: + memory_messages: Sequence[PromptMessage] = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: str | None, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. 
+ + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER + ) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py index 9e95d341c9..e4e5963812 100644 --- a/api/dify_graph/nodes/llm/protocols.py +++ b/api/dify_graph/nodes/llm/protocols.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any, Protocol -from core.model_manager import ModelInstance +from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol class CredentialsProvider(Protocol): @@ -15,10 +15,10 @@ class CredentialsProvider(Protocol): class ModelFactory(Protocol): - """Port for creating initialized LLM model instances for execution.""" + """Port for creating prepared graph-facing LLM runtimes for execution.""" - def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: - """Create a model instance that is ready for schema lookup and invocation.""" + def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: + """Create a prepared LLM runtime that is ready for graph execution.""" ... diff --git a/api/dify_graph/nodes/llm/runtime_protocols.py b/api/dify_graph/nodes/llm/runtime_protocols.py new file mode 100644 index 0000000000..64c1d8a925 --- /dev/null +++ b/api/dify_graph/nodes/llm/runtime_protocols.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Protocol + +from dify_graph.file import File +from dify_graph.model_runtime.entities import LLMMode, PromptMessage +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from dify_graph.model_runtime.entities.message_entities import PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity + + +class PreparedLLMProtocol(Protocol): + """A graph-facing LLM runtime with provider-specific setup already applied.""" + + @property + def provider(self) -> str: ... + + @property + def model_name(self) -> str: ... + + @property + def parameters(self) -> Mapping[str, Any]: ... + + @property + def stop(self) -> Sequence[str] | None: ... + + def get_model_schema(self) -> AIModelEntity: ... + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... 
+ + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def is_structured_output_parse_error(self, error: Exception) -> bool: ... + + +class PromptMessageSerializerProtocol(Protocol): + """Port for converting compiled prompt messages into persisted process data.""" + + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: ... + + +class RetrieverAttachmentLoaderProtocol(Protocol): + """Port for resolving retriever segment attachments into graph file references.""" + + def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 3c546ffa23..5319864eef 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -31,9 +31,9 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData from dify_graph.utils.condition.processor import ConditionProcessor +from dify_graph.utils.datetime_utils import naive_utc_now from dify_graph.variables import Segment, SegmentType from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now if TYPE_CHECKING: from dify_graph.graph_engine import GraphEngine diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 2fb042c16c..bc6c85c8bc 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -7,10 +7,10 @@ from pydantic import ( field_validator, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.prompt_entities import MemoryConfig from dify_graph.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..c2ee7c04b8 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,11 +5,6 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, @@ -17,8 +12,8 @@ from dify_graph.enums import ( WorkflowNodeExecutionStatus, ) from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from 
dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,14 +22,15 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import variable_template_parser from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm import LLMNode, llm_utils +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol from dify_graph.runtime import VariablePool from dify_graph.variables.types import ArrayValidation, SegmentType from factories.variable_factory import build_segment_with_type @@ -66,7 +62,6 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.runtime import GraphRuntimeState @@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" + _model_instance: PreparedLLMProtocol + _prompt_message_serializer: PromptMessageSerializerProtocol _memory: PromptMessageMemory | None def __init__( @@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None = None, + prompt_message_serializer: PromptMessageSerializerProtocol, ) -> None: super().__init__( id=id, @@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory self._model_instance = model_instance + self._prompt_message_serializer = prompt_message_serializer self._memory = memory @classmethod @@ -164,13 +159,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise 
ModelSchemaNotFoundError("Model schema not found") from exc + if model_schema.model_type != ModelType.LLM: + raise InvalidModelTypeError("Model is not a Large Language Model") memory = self._memory if ( @@ -210,8 +204,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages + "prompts": self._prompt_message_serializer.serialize( + model_mode=node_data.model.mode, + prompt_messages=prompt_messages, ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), @@ -287,18 +282,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: Sequence[str], + stop: Sequence[str] | None, ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, + invoke_result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=dict(model_instance.parameters), + tools=tools or None, + stop=stop, + stream=False, + ), ) # handle invoke result @@ -317,7 +314,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -329,7 +326,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): content=query, structure=json.dumps(node_data.get_parameter_json_schema()) ) - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -340,15 +336,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -405,7 +397,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -413,9 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate prompt engineering prompt. 
""" - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: + if data.model.mode == LLMMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( node_data=data, query=query, @@ -425,7 +415,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - elif model_mode == ModelMode.CHAT: + if data.model.mode == LLMMode.CHAT: return self._generate_prompt_engineering_chat_prompt( node_data=data, query=query, @@ -435,15 +425,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") def _generate_prompt_engineering_completion_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -451,7 +440,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate completion prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -462,27 +450,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), + return self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) - return prompt_messages - def _generate_prompt_engineering_chat_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -490,7 +471,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate chat prompt. 
""" - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -508,15 +488,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): max_token_limit=rest_token, ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -717,8 +693,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage]: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -727,15 +702,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -744,8 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -754,64 +727,53 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + if node_data.model.mode == LLMMode.COMPLETION: + return LLMNodeCompletionModelPromptTemplate( text=COMPLETION_GENERATE_JSON_PROMPT.format( histories=memory_str, text=input_text, instruction=instruction ) .replace("{γγγ", "") .replace("}γγγ", "") + 
.replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _calculate_rest_token( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=[], + vision_enabled=False, + context=context, ) rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) - + 1000 - ) # add 1000 to ensure tool call messages + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 max_tokens = 0 for parameter_rule in model_schema.parameter_rules: @@ -828,8 +790,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens + def _compile_prompt_messages( + self, + *, + model_instance: PreparedLLMProtocol, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + files: Sequence[File], + vision_enabled: bool, + context: str | None = "", + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> list[PromptMessage]: + prompt_messages, _ = LLMNode.fetch_prompt_messages( + sys_query="", + sys_files=files, + context=context, + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=model_instance.stop, + memory_config=None, + vision_enabled=vision_enabled, + vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + return list(prompt_messages) + @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py index 62d3bcdca1..16d1609742 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Protocol import httpx @@ -35,8 +35,6 @@ class ToolFileManagerProtocol(Protocol): def create_file_by_raw( self, *, - user_id: str, - tenant_id: str, conversation_id: str | None, file_binary: bytes, mimetype: str, @@ 
-44,3 +42,7 @@ class ToolFileManagerProtocol(Protocol): ) -> Any: ... def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... + + +class FileReferenceFactoryProtocol(Protocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py index 0c1601d439..04c8e76ef9 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,9 +1,9 @@ from pydantic import BaseModel, Field -from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm import ModelConfig, VisionConfig +from dify_graph.prompt_entities import MemoryConfig class ClassConfig(BaseModel): diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 59d0a2a4d8..fbc21d5548 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -3,9 +3,6 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( @@ -14,7 +11,7 @@ from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole +from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult @@ -27,10 +24,11 @@ from dify_graph.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.protocols import TemplateRenderer +from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown +from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData from .exc import InvalidModelTypeError @@ -49,15 +47,20 @@ if TYPE_CHECKING: from dify_graph.runtime import GraphRuntimeState +class _PassthroughPromptMessageSerializer: + def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: + _ = model_mode + return list(prompt_messages) + + class QuestionClassifierNode(Node[QuestionClassifierNodeData]): node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER execution_type = NodeExecutionType.BRANCH _file_outputs: list["File"] _llm_file_saver: LLMFileSaver - _credentials_provider: 
"CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance + _prompt_message_serializer: PromptMessageSerializerProtocol + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None _template_renderer: TemplateRenderer @@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol | None = None, ): super().__init__( id=id, @@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() @classmethod def version(cls): @@ -169,7 +166,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, @@ -205,7 +201,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_id = category_id_result process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -247,7 +243,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod @@ -285,7 +281,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) @@ -334,7 +330,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): memory: PromptMessageMemory | None, max_token_limit: int = 2000, ): - model_mode = ModelMode(node_data.model.mode) + model_mode = LLMMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: @@ -350,7 +346,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == 
ModelMode.CHAT: + if model_mode == LLMMode.CHAT: system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) @@ -381,7 +377,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) prompt_messages.append(user_prompt_message_3) return prompt_messages - elif model_mode == ModelMode.COMPLETION: + elif model_mode == LLMMode.COMPLETION: return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, diff --git a/api/dify_graph/nodes/runtime.py b/api/dify_graph/nodes/runtime.py new file mode 100644 index 0000000000..01f3398807 --- /dev/null +++ b/api/dify_graph/nodes/runtime.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Protocol + +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) + +if TYPE_CHECKING: + from dify_graph.nodes.human_input.entities import DeliveryChannelConfig + from dify_graph.nodes.tool.entities import ToolNodeData + from dify_graph.runtime import VariablePool + + +class ToolNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter owned by `core.workflow` and consumed by `dify_graph`. + + The graph package depends only on these DTOs and lets the workflow layer + translate between graph-owned abstractions and `core.tools` internals. + """ + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool: VariablePool | None, + ) -> ToolRuntimeHandle: ... + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: ... + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + conversation_id: str | None, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: ... + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: ... + + def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: ... + + +class HumanInputNodeRuntimeProtocol(Protocol): + def invoke_source(self) -> str: ... + + def apply_delivery_runtime( + self, + *, + methods: Sequence[DeliveryChannelConfig], + ) -> Sequence[DeliveryChannelConfig]: ... + + def console_actor_id(self) -> str | None: ... 
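For orientation, the counterpart that satisfies ToolNodeRuntimeProtocol lives on the workflow layer (`core.workflow`), so `dify_graph` never imports `core.tools` directly. The sketch below is illustrative only and not the PR's actual adapter: the class name `WorkflowToolNodeRuntime` and its constructor arguments are assumptions, while `ToolManager.get_workflow_tool_runtime`, `ToolEngine.generic_invoke` and `DifyWorkflowCallbackHandler` are the pre-existing entry points that the removed `tool_node.py` code called inline.

from collections.abc import Generator, Mapping, Sequence
from typing import Any

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeParameter


class WorkflowToolNodeRuntime:
    """Hypothetical core.workflow adapter for ToolNodeRuntimeProtocol (not the PR's code)."""

    def __init__(self, *, tenant_id: str, app_id: str, user_id: str, invoke_from: Any) -> None:
        self._tenant_id = tenant_id
        self._app_id = app_id
        self._user_id = user_id
        self._invoke_from = invoke_from

    def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle:
        # Build the core.tools runtime exactly as the old ToolNode did, then hide it
        # behind the graph-owned opaque handle.
        tool = ToolManager.get_workflow_tool_runtime(
            self._tenant_id, self._app_id, node_id, node_data, self._invoke_from, variable_pool
        )
        return ToolRuntimeHandle(raw=tool)

    def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> Sequence[ToolRuntimeParameter]:
        # Translate core.tools ToolParameter objects into the graph-owned DTO.
        merged = tool_runtime.raw.get_merged_runtime_parameters() or []
        return [ToolRuntimeParameter(name=p.name, required=p.required) for p in merged]

    def invoke(
        self,
        *,
        tool_runtime: ToolRuntimeHandle,
        tool_parameters: Mapping[str, Any],
        workflow_call_depth: int,
        conversation_id: str | None,
        provider_name: str,
    ) -> Generator[Any, None, None]:
        # Delegate to ToolEngine with the identity the adapter was constructed with.
        # A real adapter would also map each ToolInvokeMessage into ToolRuntimeMessage here.
        _ = provider_name  # kept for signature parity with the protocol
        yield from ToolEngine.generic_invoke(
            tool=tool_runtime.raw,
            tool_parameters=dict(tool_parameters),
            user_id=self._user_id,
            workflow_tool_callback=DifyWorkflowCallbackHandler(),
            workflow_call_depth=workflow_call_depth,
            app_id=self._app_id,
            conversation_id=conversation_id,
        )

    def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage:
        # Mirrors the normalization that the removed _extract_tool_usage performed.
        latest = getattr(tool_runtime.raw, "latest_usage", None)
        if isinstance(latest, LLMUsage):
            return latest
        if isinstance(latest, dict):
            return LLMUsage.model_validate(latest)
        return LLMUsage.empty_usage()

    # build_file_reference and resolve_provider_icons are omitted here; they would wrap
    # factories.file_factory.build_from_mapping and the provider icon lookup the old node inlined.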
diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py index dc6fce2b0a..2f90f80112 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( +from dify_graph.template_rendering import ( Jinja2TemplateRenderer, TemplateRenderError, ) @@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer + _jinja2_template_renderer: Jinja2TemplateRenderer _max_output_length: int def __init__( @@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer, + jinja2_template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer + self._jinja2_template_renderer = jinja2_template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") @@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) + rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData | Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } + _ = graph_config + raw_variables = ( + node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) + ) + variable_mapping: dict[str, Sequence[str]] = {} + for variable_selector in raw_variables: + if isinstance(variable_selector, VariableSelector): + variable_mapping[node_id + "." 
+ variable_selector.variable] = variable_selector.value_selector + continue + + if not isinstance(variable_selector, Mapping): + continue + + variable = variable_selector.get("variable") + value_selector = variable_selector.get("value_selector") + if ( + isinstance(variable, str) + and isinstance(value_selector, Sequence) + and all(isinstance(selector_part, str) for selector_part in value_selector) + ): + variable_mapping[node_id + "." + variable] = list(value_selector) + + return variable_mapping diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index b041ee66fd..fc322fd912 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -1,13 +1,27 @@ +from enum import StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.tools.entities.tool_entities import ToolProviderType from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType +class ToolProviderType(StrEnum): + """ + Graph-owned enum for persisted tool provider kinds. + """ + + PLUGIN = auto() + BUILT_IN = "builtin" + WORKFLOW = auto() + API = auto() + APP = auto() + DATASET_RETRIEVAL = "dataset-retrieval" + MCP = auto() + + class ToolEntity(BaseModel): provider_id: str provider_type: ToolProviderType diff --git a/api/dify_graph/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py index 7212e8bfc0..1a309e1084 100644 --- a/api/dify_graph/nodes/tool/exc.py +++ b/api/dify_graph/nodes/tool/exc.py @@ -4,6 +4,18 @@ class ToolNodeError(ValueError): pass +class ToolRuntimeResolutionError(ToolNodeError): + """Raised when the workflow layer cannot construct a tool runtime.""" + + pass + + +class ToolRuntimeInvocationError(ToolNodeError): + """Raised when the workflow layer fails while invoking a tool runtime.""" + + pass + + class ToolParameterError(ToolNodeError): """Exception raised for errors in tool parameters.""" diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 598f0da92e..dc6aedac19 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,12 +1,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, @@ -20,10 +14,14 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEven from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.nodes.protocols import ToolFileManagerProtocol +from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from 
services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import ToolNodeData from .exc import ( @@ -52,6 +50,7 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state: "GraphRuntimeState", *, tool_file_manager_factory: ToolFileManagerProtocol, + runtime: ToolNodeRuntimeProtocol | None = None, ): super().__init__( id=id, @@ -60,6 +59,9 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state=graph_runtime_state, ) self._tool_file_manager_factory = tool_file_manager_factory + if runtime is None: + raise ValueError("runtime is required") + self._runtime = runtime @classmethod def version(cls) -> str: @@ -73,10 +75,6 @@ class ToolNode(Node[ToolNodeData]): """ Run the tool node """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -86,8 +84,6 @@ class ToolNode(Node[ToolNodeData]): # get tool runtime try: - from core.tools.tool_manager import ToolManager - # This is an issue that caused problems before. # Logically, we shouldn't use the node_data.version field for judgment # But for backward compatibility with historical data @@ -95,13 +91,10 @@ class ToolNode(Node[ToolNodeData]): variable_pool: VariablePool | None = None if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, + tool_runtime = self._runtime.get_runtime( + node_id=self._node_id, + node_data=self.node_data, + variable_pool=variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -116,7 +109,7 @@ class ToolNode(Node[ToolNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -132,14 +125,12 @@ class ToolNode(Node[ToolNodeData]): conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, + message_stream = self._runtime.invoke( + tool_runtime=tool_runtime, tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, + provider_name=self.node_data.provider_name, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -159,38 +150,16 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) - except ToolInvokeError as e: + except ToolNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - 
node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", + error=str(e), error_type=type(e).__name__, ) ) @@ -198,7 +167,7 @@ class ToolNode(Node[ToolNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + tool_parameters: Sequence[ToolRuntimeParameter], variable_pool: "VariablePool", node_data: ToolNodeData, for_log: bool = False, @@ -207,7 +176,7 @@ class ToolNode(Node[ToolNodeData]): Generate parameters based on the given tool parameters, variable pool, and node data. Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. variable_pool (VariablePool): The variable pool containing the variables. node_data (ToolNodeData): The data associated with the tool node. @@ -247,40 +216,29 @@ class ToolNode(Node[ToolNodeData]): def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[ToolRuntimeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, node_id: str, - tool_runtime: Tool, + tool_runtime: ToolRuntimeHandle, + **_: Any, ) -> Generator[NodeEventBase, None, LLMUsage]: """ - Convert ToolInvokeMessages into tuple[plain_text, files] + Convert graph-owned tool runtime messages into node outputs. 
""" - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - text = "" files: list[File] = [] json: list[dict | list] = [] variables: dict[str, Any] = {} - for message in message_stream: + for message in messages: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + ToolRuntimeMessage.MessageType.IMAGE_LINK, + ToolRuntimeMessage.MessageType.BINARY_LINK, + ToolRuntimeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) url = message.message.text if message.meta: @@ -300,14 +258,11 @@ class ToolNode(Node[ToolNodeData]): "transfer_method": transfer_method, "url": url, } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) + file = self._runtime.build_file_reference(mapping=mapping) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == ToolRuntimeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] @@ -320,27 +275,22 @@ class ToolNode(Node[ToolNodeData]): "transfer_method": FileTransferMethod.TOOL_FILE, } - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + files.append(self._runtime.build_file_reference(mapping=mapping)) + elif message.type == ToolRuntimeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) text += message.message.text yield StreamChunkEvent( selector=[node_id, "text"], chunk=message.message.text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == ToolRuntimeMessage.MessageType.JSON: + assert isinstance(message.message, ToolRuntimeMessage.JsonMessage) # JSON message handling for tool node if message.message.json_object: json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == ToolRuntimeMessage.MessageType.LINK: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) # Check if this LINK message is a file link file_obj = (message.meta or {}).get("file") @@ -356,8 +306,8 @@ class ToolNode(Node[ToolNodeData]): chunk=stream_text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == ToolRuntimeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolRuntimeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -374,7 +324,7 @@ class ToolNode(Node[ToolNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + 
elif message.type == ToolRuntimeMessage.MessageType.FILE: assert message.meta is not None assert isinstance(message.meta, dict) # Validate that meta contains a 'file' key @@ -385,38 +335,16 @@ class ToolNode(Node[ToolNodeData]): if not isinstance(message.meta["file"], File): raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == ToolRuntimeMessage.MessageType.LOG: + assert isinstance(message.message, ToolRuntimeMessage.LogMessage) if message.message.metadata: icon = tool_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - + icon, icon_dark = self._runtime.resolve_provider_icons( + provider_name=dict_metadata["provider"], + default_icon=icon, + ) dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata @@ -446,7 +374,7 @@ class ToolNode(Node[ToolNodeData]): is_final=True, ) - usage = self._extract_tool_usage(tool_runtime) + usage = self._runtime.get_usage(tool_runtime=tool_runtime) metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, @@ -468,21 +396,6 @@ class ToolNode(Node[ToolNodeData]): return usage - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. 
- if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/dify_graph/nodes/tool_runtime_entities.py b/api/dify_graph/nodes/tool_runtime_entities.py new file mode 100644 index 0000000000..d1a069fc5d --- /dev/null +++ b/api/dify_graph/nodes/tool_runtime_entities.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum, auto +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _ToolRuntimeModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeHandle: + """Opaque graph-owned handle for a workflow-layer tool runtime.""" + + raw: object + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeParameter: + """Graph-owned parameter shape used by tool nodes.""" + + name: str + required: bool = False + + +class ToolRuntimeMessage(_ToolRuntimeModel): + """Graph-owned tool invocation message DTO.""" + + class TextMessage(_ToolRuntimeModel): + text: str + + class JsonMessage(_ToolRuntimeModel): + json_object: dict[str, Any] | list[Any] + suppress_output: bool = Field(default=False) + + class BlobMessage(_ToolRuntimeModel): + blob: bytes + + class BlobChunkMessage(_ToolRuntimeModel): + id: str + sequence: int + total_length: int + blob: bytes + end: bool + + class FileMessage(_ToolRuntimeModel): + file_marker: str = Field(default="file_marker") + + class VariableMessage(_ToolRuntimeModel): + variable_name: str + variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None + stream: bool = Field(default=False) + + class LogMessage(_ToolRuntimeModel): + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() + + id: str + label: str + parent_id: str | None = None + error: str | None = None + status: LogStatus + data: dict[str, Any] + metadata: dict[str, Any] = Field(default_factory=dict) + + class RetrieverResourceMessage(_ToolRuntimeModel): + retriever_resources: list[dict[str, Any]] + context: str + + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() + + type: MessageType = MessageType.TEXT + message: ( + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage + ) + meta: dict[str, Any] | None = None diff --git a/api/dify_graph/prompt_entities.py b/api/dify_graph/prompt_entities.py new file mode 100644 index 0000000000..877b530613 --- /dev/null +++ b/api/dify_graph/prompt_entities.py @@ -0,0 +1,47 @@ +from typing import Literal + +from pydantic import BaseModel + +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """Graph-owned chat prompt template message.""" + + text: str + role: PromptMessageRole + edition_type: Literal["basic", "jinja2"] | None = None + + +class CompletionModelPromptTemplate(BaseModel): + """Graph-owned completion prompt template.""" + + text: str + edition_type: Literal["basic", "jinja2"] | None = None + 
+ +class MemoryConfig(BaseModel): + """Graph-owned memory configuration for prompt assembly.""" + + class RolePrefix(BaseModel): + """Role labels used when serializing completion-model histories.""" + + user: str + assistant: str + + class WindowConfig(BaseModel): + """History windowing controls.""" + + enabled: bool + size: int | None = None + + role_prefix: RolePrefix | None = None + window: WindowConfig + query_prompt_template: str | None = None + + +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py index 88966831cb..e43185e9d7 100644 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ b/api/dify_graph/repositories/human_input_form_repository.py @@ -18,9 +18,6 @@ class FormNotFoundError(HumanInputError): @dataclasses.dataclass class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str # None when creating a delivery test form; set for runtime forms. workflow_execution_id: str | None @@ -45,6 +42,9 @@ class FormCreateParams: resolved_default_values: Mapping[str, Any] form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + # Optional application identifier. Implementations may bind this at construction time. + app_id: str | None = None + # Force creating a console-only recipient for submission in Console. console_recipient_required: bool = False console_creator_account_id: str | None = None diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/template_rendering.py similarity index 79% rename from api/dify_graph/nodes/template_transform/template_renderer.py rename to api/dify_graph/template_rendering.py index 9b679d4497..6fee97dbfc 100644 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ b/api/dify_graph/template_rendering.py @@ -8,19 +8,17 @@ from dify_graph.nodes.code.entities import CodeLanguage class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" + """Raised when rendering a template fails.""" class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" + """Shared contract for rendering Jinja2 templates in graph nodes.""" - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: ... class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" + """Adapter that renders Jinja2 templates via the workflow code executor.""" _code_executor: WorkflowCodeExecutor diff --git a/api/dify_graph/utils/datetime_utils.py b/api/dify_graph/utils/datetime_utils.py new file mode 100644 index 0000000000..71af44ca92 --- /dev/null +++ b/api/dify_graph/utils/datetime_utils.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import abc +import datetime +from typing import Protocol + + +class _NowFunction(Protocol): + @abc.abstractmethod + def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: + """Return the current time for the requested timezone.""" + ... 
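+# Illustrative example (comments only): the module-level ``_now_func`` hook below
+# lets tests substitute a deterministic clock without patching ``datetime`` itself,
+# e.g. something along these lines (module path shown for illustration; restore
+# the original hook afterwards):
+#   datetime_utils._now_func = lambda tz=None: datetime.datetime(2026, 3, 15, tzinfo=tz)
+#   assert datetime_utils.naive_utc_now() == datetime.datetime(2026, 3, 15)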
+ + +_now_func: _NowFunction = datetime.datetime.now + + +def naive_utc_now() -> datetime.datetime: + """Return the current UTC time as a naive datetime.""" + return _now_func(datetime.UTC).replace(tzinfo=None) diff --git a/api/dify_graph/utils/json_in_md_parser.py b/api/dify_graph/utils/json_in_md_parser.py new file mode 100644 index 0000000000..4416b4774b --- /dev/null +++ b/api/dify_graph/utils/json_in_md_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json + + +class OutputParserError(ValueError): + """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" + + +def parse_json_markdown(json_string: str) -> dict | list: + """Extract and parse the first JSON object or array embedded in markdown text.""" + json_string = json_string.strip() + starts = ["```json", "```", "``", "`", "{", "["] + ends = ["```", "``", "`", "}", "]"] + end_index = -1 + start_index = 0 + + for start_marker in starts: + start_index = json_string.find(start_marker) + if start_index != -1: + if json_string[start_index] not in ("{", "["): + start_index += len(start_marker) + break + + if start_index != -1: + for end_marker in ends: + end_index = json_string.rfind(end_marker, start_index) + if end_index != -1: + if json_string[end_index] in ("}", "]"): + end_index += 1 + break + + if start_index == -1 or end_index == -1 or start_index >= end_index: + raise ValueError("could not find json block in the output.") + + extracted_content = json_string[start_index:end_index].strip() + return json.loads(extracted_content) + + +def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as exc: + raise OutputParserError(f"got invalid json object. error: {exc}") from exc + + if isinstance(json_obj, list): + if len(json_obj) == 1 and isinstance(json_obj[0], dict): + json_obj = json_obj[0] + else: + raise OutputParserError(f"got invalid return object. obj:{json_obj}") + + for key in expected_keys: + if key not in json_obj: + raise OutputParserError( + f"got invalid return object. 
expected key `{key}` to be present, but got {json_obj}" + ) + + return json_obj diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..dd38942180 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947..a4bd96b399 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..7171aff481 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -228,7 +228,7 @@ class DatasetService: raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -350,7 +350,7 @@ class DatasetService: def check_dataset_model_setting(dataset): if dataset.indexing_technique == "high_quality": try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ 
-367,7 +367,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -384,7 +384,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -405,7 +405,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -742,7 +742,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( @@ -860,7 +860,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -954,7 +954,7 @@ class DatasetService: dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.indexing_technique = knowledge_configuration.indexing_technique if knowledge_configuration.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -996,7 +996,7 @@ class DatasetService: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1049,7 +1049,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1908,7 +1908,7 @@ class DocumentService: dataset.indexing_technique = knowledge_config.indexing_technique if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2220,7 +2220,7 @@ class 
DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -3125,7 +3125,7 @@ class SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3208,7 +3208,7 @@ class SegmentService: with redis_client.lock(lock_name, timeout=600): embedding_model = None if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3346,7 +3346,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3409,7 +3409,7 @@ class SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3450,7 +3450,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f51..035a4b99a6 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -255,7 +255,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3ed..2e927bc693 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,13 +10,13 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, 
create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +41,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +62,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +84,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -223,7 +224,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -310,7 +311,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -496,7 +497,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -581,7 +582,7 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) if 
isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1a..fae3c66e2a 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +25,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +44,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +61,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +139,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +147,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. 
@@ -391,7 +412,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +497,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +530,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +546,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..ef9970e8f2 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -191,7 +191,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..8687e17bcc 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -47,7 +47,7 @@ class VectorService: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..72dc26a29a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -12,10 +12,12 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import 
DifyHumanInputNodeRuntime from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.graph_config import NodeConfigDict @@ -487,11 +489,10 @@ class WorkflowService: """ try: from core.model_manager import ModelManager - from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType # Get model instance to validate provider+model combination - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -501,7 +502,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) @@ -607,11 +608,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -1139,6 +1139,7 @@ class WorkflowService: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..4ad3cd39e9 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -120,7 +120,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None if dataset_config["indexing_technique"] == "high_quality": - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5b..d20a5a2525 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,5 +1,5 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamCompletedEvent diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py 
b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed1..a66fb0a6b1 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType @@ -15,7 +15,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e55..f71c263914 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_runtime import DifyFileReferenceFactory from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.graph import Graph @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -728,6 +730,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1e..dedccb0fd3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -9,8 +9,10 @@ from core.llm_generator.output_parser.structured_output import _parse_structured from core.model_manager import ModelInstance from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.llm.file_saver import LLMFileSaver from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.nodes.protocols import 
HttpClientProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -66,6 +68,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +82,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -159,7 +167,7 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): + def mock_fetch_prompt_messages_1(**_kwargs): from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af0196..4479a9f881 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyPromptMessageSerializer from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory @@ -77,6 +78,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c3..8324a96c0c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -5,10 +5,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from dify_graph.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -90,7 +90,7 @@ def test_execute_template_transform(): config=config, 
graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada31..8b7251842a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_runtime import DifyToolNodeRuntime from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent @@ -64,6 +65,7 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=DifyToolNodeRuntime(init_params.run_context), ) return node diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df3..a7ab3e2869 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index ee34b65831..f262e7727a 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service diff --git 
a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e..33955d5d84 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..b6f469d7b3 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..cef65de541 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -173,7 +173,7 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( @@ -263,7 +263,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -296,7 +296,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py 
b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..7a307b0416 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -362,7 +362,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -457,7 +457,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -543,7 +543,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index af666a0375..81585bda8f 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index dd743d46c2..1071ca7dca 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -19,7 +19,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 425611744b..d71ae67b58 100644 --- 
a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -24,7 +24,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..f2eccdbf9c 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -27,7 +27,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index e080d6ef6b..56047368b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..cc3b856f04 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py 
b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098..664a0e3a29 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -53,7 +53,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f..33a12528c0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -13,7 +13,7 @@ from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, Inv class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" 
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = instance mock_response = MagicMock() diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000..ad9445f9af --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,60 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py 
b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e7..3a97ad5c5d 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." 
- @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. 
@@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 0000000000..eb738a2063 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,229 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from unittest.mock import Mock, sentinel + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + + def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = 
Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def 
test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0..3ba0628fe2 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,7 +163,7 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..949c77ff71 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -586,7 +586,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load 
tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -754,7 +754,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1..6d47a90643 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -424,7 +424,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +441,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -595,7 +595,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1145,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1257,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False 
yield mock_mm @@ -1527,7 +1527,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1598,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1673,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1715,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1824,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant", autospec=True) as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c..8356e3e715 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -3836,7 +3836,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -4222,7 +4222,7 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): 
mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: @@ -4280,7 +4280,7 @@ class TestRetrieveCoverage: ) model_config = self._build_model_config() model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_model_instance.return_value = Mock() result = retrieval.retrieve( app_id="app-1", @@ -4312,7 +4312,7 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), @@ -4402,7 +4402,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb..375da91527 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, PropertyMock, patch import pytest +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager @@ -9,6 +10,10 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + @pytest.fixture def mock_provider_entity(): mock_entity = Mock() @@ -28,7 +33,7 @@ def mock_provider_entity(): return mock_entity -def test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +74,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +94,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +124,7 @@ def 
test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +142,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +181,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +199,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +209,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5b..3012e64ebb 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -186,13 +186,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +214,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +224,15 @@ def test_tts_invoke_returns_messages(monkeypatch): 
}, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +265,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b2..f36993a3b9 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -60,7 +60,7 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): if error_match: with pytest.raises(InvokeModelError, match=error_match): ModelInvocationUtils.get_max_llm_context_tokens("tenant") @@ -71,7 +71,7 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) @@ -98,7 +98,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -145,7 +145,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): with 
patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..db9b85640a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -2,6 +2,7 @@ import threading from datetime import datetime from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus @@ -11,6 +12,16 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + def _build_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", @@ -32,7 +43,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -53,7 +64,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -74,7 +85,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +102,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,7 +124,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 
765c4deba3..23cd5d7f8f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,8 +3,8 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 538f53c603..99b8d8b34d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -125,6 +126,7 @@ def _build_branching_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 36bba6deb6..6b1c6ad936 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,6 +3,7 @@ import time from unittest import mock from unittest.mock import MagicMock +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -121,6 +122,7 @@ def _build_llm_human_llm_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 3e4247f33f..4fa210bee1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,7 +2,7 @@ Simple test to verify MockNodeFactory works with iteration nodes. 
""" -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9..b2740e91ba 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -20,16 +21,15 @@ from dify_graph.nodes.code import CodeNode from dify_graph.nodes.document_extractor import DocumentExtractorNode from dify_graph.nodes.http_request import HttpRequestNode from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) from dify_graph.nodes.tool import ToolNode +from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError if TYPE_CHECKING: from dify_graph.entities import GraphInitParams @@ -66,20 +66,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. 
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index a8398e8f79..482dbd4a91 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,7 +6,7 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 5b35b3310a..d93c0a59b1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -4,7 +4,7 @@ Simple test to validate the auto-mock system without external dependencies. 
import sys -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc7..6453785c20 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 60167c0441..46ad6775d4 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -162,6 +163,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -171,6 +173,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), delay_seconds=0.2, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 7328ce443f..e2f92663b1 100644 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -201,6 +202,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 15a7de3c52..ddce34c663 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,6 +3,7 @@ import time from typing import Any from unittest.mock import MagicMock +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -112,6 +113,7 @@ def _build_human_input_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8..07190e2654 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,10 +19,10 @@ from functools import lru_cache from pathlib import Path from typing import Any, cast -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3..ab536bbf4b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,5 +1,5 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution 
import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94..b05c932071 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -7,6 +7,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager +from core.workflow.node_runtime import DifyFileReferenceFactory from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c0..a5ec36e1b4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,8 +8,9 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.workflow.node_runtime import DifyHumanInputNodeRuntime from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.node_events import PauseRequestedEvent from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.human_input.entities import ( @@ -364,6 +365,7 @@ class TestHumanInputNodeVariableResolution: graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=mock_repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) run_result = node._run() @@ -427,6 +429,7 @@ class TestHumanInputNodeVariableResolution: graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=mock_repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) run_result = node._run() @@ -500,6 +503,7 @@ class TestHumanInputNodeVariableResolution: graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=mock_repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) run_result = node._run() @@ -597,6 +601,7 @@ class TestHumanInputNodeRenderedContent: graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index b0ed47158d..ddb9328bf1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py 
@@ -1,8 +1,9 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, @@ -85,6 +86,7 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) @@ -149,6 +151,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1..b944c6e785 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock import pytest +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode from dify_graph.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af1..2810239e37 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,11 +5,17 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities import GraphInitParams from dify_graph.file import File, FileTransferMethod, FileType @@ -34,8 +40,9 @@ from dify_graph.nodes.llm.entities import ( VisionConfigOptions, ) from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode +from 
dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment @@ -107,7 +114,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -122,7 +129,7 @@ def llm_node( model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +139,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") - for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +191,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +223,36 @@ def 
test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +269,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,12 +302,13 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): @@ -592,33 +640,6 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" - messages = [ - LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", - ) - ] - - result = llm_node.handle_list_messages( - messages=messages, - context=None, - jinja2_variables=[], - variable_pool=llm_node.graph_runtime_state.variable_pool, - vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, - ) - - 
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, - ) - - def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) memory.get_history_prompt_messages.return_value = [ @@ -642,8 +663,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +680,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -674,7 +695,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -906,3 +927,42 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f9..abe5b31d75 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -5,9 +5,9 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState +from dify_graph.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +62,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +78,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +91,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +111,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -153,7 +153,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +181,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +201,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +221,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -263,7 +263,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -307,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, 
graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +346,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +375,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +405,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b..6cb7e18ca5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum(): graph_runtime_state=runtime_state, ) - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == UserFrom.ACCOUNT assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER assert node.get_run_context_value("missing") is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c746a945fe..f584fcda3e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,9 +4,8 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.graph import Graph diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 6ca72b64b2..2eee22a678 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,8 +2,7 @@ from unittest.mock import MagicMock import pytest -from 
core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.nodes.list_operator.entities import ( diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3cbd96dfef..344ded7180 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -4,15 +4,14 @@ import sys import types from collections.abc import Generator from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment @@ -22,6 +21,39 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only from dify_graph.nodes.tool.tool_node import ToolNode +class _StubToolRuntime: + def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + raise NotImplementedError + + def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: + return [] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: dict[str, Any], + workflow_call_depth: int, + conversation_id: str | None, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + yield from () + + def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage: + return LLMUsage.empty_usage() + + def build_file_reference(self, *, mapping: dict[str, Any]) -> Any: + return mapping + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + return default_icon, None + + @pytest.fixture def tool_node(monkeypatch) -> ToolNode: module_name = "core.ops.ops_trace_manager" @@ -73,6 +105,7 @@ def tool_node(monkeypatch) -> ToolNode: # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + runtime = _StubToolRuntime() node = ToolNode( id="node-instance", @@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode: graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=runtime, ) return node @@ -93,24 +127,15 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: - def _identity_transform(messages, *_args, **_kwargs): - return messages - - tool_runtime = MagicMock() - with patch.object( - ToolFileMessageTransformer, "transform_tool_invoke_messages", 
side_effect=_identity_transform, autospec=True - ): - generator = tool_node._transform_message( - messages=iter([message]), - tool_info={"provider_type": "builtin", "provider_id": "provider"}, - parameters_for_log={}, - user_id="user-id", - tenant_id="tenant-id", - node_id=tool_node._node_id, - tool_runtime=tool_runtime, - ) - return _collect_events(generator) +def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: + generator = tool_node._transform_message( + messages=iter([message]), + tool_info={"provider_type": "builtin", "provider_id": "provider"}, + parameters_for_log={}, + node_id=tool_node._node_id, + tool_runtime=ToolRuntimeHandle(raw=object()), + ) + return _collect_events(generator) def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): @@ -125,9 +150,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): size=123, storage_key="file-key", ) - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), meta={"file": file_obj}, ) @@ -150,9 +175,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): def test_plain_link_messages_remain_links(tool_node: ToolNode): - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"), meta=None, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py new file mode 100644 index 0000000000..b822821f1b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.impl.exc import PluginInvokeError +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.node_runtime import DifyToolNodeRuntime +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.tool.entities import ToolNodeData, ToolProviderType +from dify_graph.nodes.tool.exc import ToolRuntimeInvocationError +from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from tests.workflow_test_utils import build_test_graph_init_params + + +@pytest.fixture +def runtime(monkeypatch) -> DifyToolNodeRuntime: + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute + ops_stub.TraceTask = object # pragma: no cover - stub attribute + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + init_params = 
build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + return DifyToolNodeRuntime(init_params.run_context) + + +def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: + core_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + meta=None, + ) + + with ( + patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, + patch.object( + ToolFileMessageTransformer, + "transform_tool_invoke_messages", + side_effect=lambda *, messages, **_: messages, + ), + ): + messages = list( + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + conversation_id=None, + provider_name="provider", + ) + ) + + assert len(messages) == 1 + graph_message = messages[0] + assert graph_message.type == ToolRuntimeMessage.MessageType.LINK + assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage) + assert graph_message.message.text == "https://dify.ai" + + callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] + assert isinstance(callback, DifyWorkflowCallbackHandler) + + +def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: + invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}') + + with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error): + with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"): + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + conversation_id=None, + provider_name="provider", + ) + + +def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None: + usage_payload = LLMUsage.empty_usage().model_dump() + usage_payload["total_tokens"] = 42 + + usage = runtime.get_usage( + tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)), + ) + + assert usage.total_tokens == 42 + + +def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: + node_data = ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: + runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + + workflow_tool = runtime_mock.call_args.args[3] + assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index e69c05dc0b..6a0548fa1c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,10 +2,9 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from 
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 6874f3fef1..dd338651b9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,10 +2,9 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3..b8d243a66d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,7 @@ when passing files to downstream LLM nodes. 
from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,7 +16,7 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3..15cc07fbeb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,7 +2,7 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,7 +12,7 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime.graph_runtime_state import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 367e3958ad..e427c4b90a 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -3,11 +3,10 @@ from unittest.mock import MagicMock, patch, sentinel import pytest -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey from dify_graph.nodes.code.entities import CodeLanguage from dify_graph.variables.segments import StringSegment @@ -140,40 +139,22 @@ class TestDefaultWorkflowCodeExecutor: assert executor.is_execution_error(RuntimeError("boom")) is False -class TestDefaultLLMTemplateRenderer: - def test_render_jinja2_delegates_to_code_executor(self, monkeypatch): - renderer = node_factory.DefaultLLMTemplateRenderer() - execute_workflow_code_template = MagicMock(return_value={"result": "hello world"}) - monkeypatch.setattr( - node_factory.CodeExecutor, - "execute_workflow_code_template", - 
execute_workflow_code_template, - ) - - result = renderer.render_jinja2( - template="Hello {{ name }}", - inputs={"name": "world"}, - ) - - assert result == "hello world" - execute_workflow_code_template.assert_called_once_with( - language=CodeLanguage.JINJA2, - code="Hello {{ name }}", - inputs={"name": "world"}, - ) - - class TestDifyNodeFactoryInit: def test_init_builds_default_dependencies(self): graph_init_params = SimpleNamespace(run_context={"context": "value"}) graph_runtime_state = sentinel.graph_runtime_state - dify_context = SimpleNamespace(tenant_id="tenant-id") - template_renderer = sentinel.template_renderer + dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id", user_id="user-id") + jinja2_template_renderer = sentinel.jinja2_template_renderer unstructured_api_config = sentinel.unstructured_api_config http_request_config = sentinel.http_request_config + file_reference_factory = sentinel.file_reference_factory + prompt_message_serializer = sentinel.prompt_message_serializer + retriever_attachment_loader = sentinel.retriever_attachment_loader + llm_file_saver = sentinel.llm_file_saver credentials_provider = sentinel.credentials_provider model_factory = sentinel.model_factory - llm_template_renderer = sentinel.llm_template_renderer + human_input_runtime = sentinel.human_input_runtime + tool_runtime = sentinel.tool_runtime with ( patch.object( @@ -184,7 +165,7 @@ class TestDifyNodeFactoryInit: patch.object( node_factory, "CodeExecutorJinja2TemplateRenderer", - return_value=template_renderer, + return_value=jinja2_template_renderer, ) as renderer_factory, patch.object( node_factory, @@ -198,9 +179,34 @@ class TestDifyNodeFactoryInit: ), patch.object( node_factory, - "DefaultLLMTemplateRenderer", - return_value=llm_template_renderer, - ) as llm_renderer_factory, + "DifyFileReferenceFactory", + return_value=file_reference_factory, + ), + patch.object( + node_factory, + "DifyPromptMessageSerializer", + return_value=prompt_message_serializer, + ), + patch.object( + node_factory, + "DifyRetrieverAttachmentLoader", + return_value=retriever_attachment_loader, + ), + patch.object( + node_factory, + "build_dify_llm_file_saver", + return_value=llm_file_saver, + ), + patch.object( + node_factory, + "DifyHumanInputNodeRuntime", + return_value=human_input_runtime, + ), + patch.object( + node_factory, + "DifyToolNodeRuntime", + return_value=tool_runtime, + ), patch.object( node_factory, "build_dify_model_access", @@ -213,18 +219,21 @@ class TestDifyNodeFactoryInit: ) resolve_dify_context.assert_called_once_with(graph_init_params.run_context) - build_dify_model_access.assert_called_once_with("tenant-id") + build_dify_model_access.assert_called_once_with(dify_context) renderer_factory.assert_called_once() - llm_renderer_factory.assert_called_once() assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor assert factory.graph_init_params is graph_init_params assert factory.graph_runtime_state is graph_runtime_state assert factory._dify_context is dify_context - assert factory._template_renderer is template_renderer - - assert factory._llm_template_renderer is llm_template_renderer + assert factory._jinja2_template_renderer is jinja2_template_renderer assert factory._document_extractor_unstructured_api_config is unstructured_api_config assert factory._http_request_config is http_request_config + assert factory._file_reference_factory is file_reference_factory + assert factory._prompt_message_serializer is prompt_message_serializer + assert 
factory._retriever_attachment_loader is retriever_attachment_loader + assert factory._llm_file_saver is llm_file_saver + assert factory._human_input_runtime is human_input_runtime + assert factory._tool_runtime is tool_runtime assert factory._llm_credentials_provider is credentials_provider assert factory._llm_model_factory is model_factory @@ -273,11 +282,16 @@ class TestDifyNodeFactoryCreateNode: factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id") factory._code_executor = sentinel.code_executor factory._code_limits = sentinel.code_limits - factory._template_renderer = sentinel.template_renderer - factory._llm_template_renderer = sentinel.llm_template_renderer + factory._jinja2_template_renderer = sentinel.jinja2_template_renderer factory._template_transform_max_output_length = 2048 factory._http_request_http_client = sentinel.http_client - factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._bound_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._file_reference_factory = sentinel.file_reference_factory + factory._prompt_message_serializer = sentinel.prompt_message_serializer + factory._retriever_attachment_loader = sentinel.retriever_attachment_loader + factory._llm_file_saver = sentinel.llm_file_saver + factory._human_input_runtime = sentinel.human_input_runtime + factory._tool_runtime = sentinel.tool_runtime factory._http_request_file_manager = sentinel.file_manager factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config factory._http_request_config = sentinel.http_request_config @@ -394,16 +408,18 @@ class TestDifyNodeFactoryCreateNode: assert kwargs["code_executor"] is sentinel.code_executor assert kwargs["code_limits"] is sentinel.code_limits elif constructor_name == "TemplateTransformNode": - assert kwargs["template_renderer"] is sentinel.template_renderer + assert kwargs["jinja2_template_renderer"] is sentinel.jinja2_template_renderer assert kwargs["max_output_length"] == 2048 elif constructor_name == "HttpRequestNode": assert kwargs["http_request_config"] is sentinel.http_request_config assert kwargs["http_client"] is sentinel.http_client assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory assert kwargs["file_manager"] is sentinel.file_manager + assert kwargs["file_reference_factory"] is sentinel.file_reference_factory elif constructor_name == "HumanInputNode": assert kwargs["form_repository"] is form_repository - form_repository_impl.assert_called_once_with(tenant_id="tenant-id") + assert kwargs["runtime"] is sentinel.human_input_runtime + form_repository_impl.assert_called_once_with(tenant_id="tenant-id", app_id="app-id") elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config assert kwargs["http_client"] is sentinel.http_client @@ -416,7 +432,10 @@ class TestDifyNodeFactoryCreateNode: "LLMNode", { "http_client": sentinel.http_client, - "template_renderer": sentinel.llm_template_renderer, + "llm_file_saver": sentinel.llm_file_saver, + "prompt_message_serializer": sentinel.prompt_message_serializer, + "retriever_attachment_loader": sentinel.retriever_attachment_loader, + "jinja2_template_renderer": sentinel.jinja2_template_renderer, }, ), ( @@ -424,10 +443,17 @@ class TestDifyNodeFactoryCreateNode: "QuestionClassifierNode", { "http_client": sentinel.http_client, - "template_renderer": sentinel.llm_template_renderer, + "llm_file_saver": 
sentinel.llm_file_saver, + "prompt_message_serializer": sentinel.prompt_message_serializer, + }, + ), + ( + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + "ParameterExtractorNode", + { + "prompt_message_serializer": sentinel.prompt_message_serializer, }, ), - (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), ], ) def test_creates_model_backed_nodes( @@ -464,7 +490,12 @@ class TestDifyNodeFactoryCreateNode: assert helper_kwargs["node_class"] is constructor assert isinstance(helper_kwargs["node_data"], BaseNodeData) assert helper_kwargs["node_data"].type == node_type + assert helper_kwargs["wrap_model_instance"] is True assert helper_kwargs["include_http_client"] is (node_type != BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert helper_kwargs["include_llm_file_saver"] is (node_type != BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert helper_kwargs["include_prompt_message_serializer"] is True + assert helper_kwargs["include_retriever_attachment_loader"] is (node_type == BuiltinNodeTypes.LLM) + assert helper_kwargs["include_jinja2_template_renderer"] is (node_type == BuiltinNodeTypes.LLM) constructor_kwargs = constructor.call_args.kwargs assert constructor_kwargs["id"] == "node-id" @@ -483,96 +514,45 @@ class TestDifyNodeFactoryModelInstance: @pytest.fixture def factory(self): factory = object.__new__(node_factory.DifyNodeFactory) - factory._llm_credentials_provider = MagicMock() - factory._llm_model_factory = MagicMock() + factory._llm_credentials_provider = sentinel.credentials_provider + factory._llm_model_factory = sentinel.model_factory return factory - @pytest.fixture - def llm_model_setup(self, factory): - def _configure( - *, - completion_params=None, - has_provider_model=True, - model_schema=sentinel.model_schema, - ): - credentials = {"api_key": "secret"} - node_data_model = SimpleNamespace( - provider="provider", - name="model", - mode="chat", - completion_params=completion_params or {}, - ) - node_data = SimpleNamespace(model=node_data_model) - provider_model = MagicMock() if has_provider_model else None - provider_model_bundle = SimpleNamespace( - configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model)) - ) - model_type_instance = MagicMock() - model_type_instance.get_model_schema.return_value = model_schema - model_instance = SimpleNamespace( - provider_model_bundle=provider_model_bundle, - model_type_instance=model_type_instance, - provider=None, - model_name=None, - credentials=None, - parameters=None, - stop=None, - ) - factory._llm_credentials_provider.fetch.return_value = credentials - factory._llm_model_factory.init_model_instance.return_value = model_instance - return SimpleNamespace( - node_data=node_data, - credentials=credentials, - provider_model=provider_model, - model_type_instance=model_type_instance, - model_instance=model_instance, - ) - - return _configure - - def test_requires_llm_mode(self, factory): - node_data = SimpleNamespace( - model=SimpleNamespace( - provider="provider", - name="model", - mode="", - completion_params={}, - ) + def test_delegates_to_fetch_model_config(self, monkeypatch, factory): + node_data_model = SimpleNamespace( + provider="provider", + name="model", + mode="chat", + completion_params={"temperature": 0.3, "stop": ["Human:"]}, ) - - with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"): - factory._build_model_instance_for_llm_node(node_data) - - def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup): - setup = 
llm_model_setup(has_provider_model=False) - - with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): - factory._build_model_instance_for_llm_node(setup.node_data) - - def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup): - setup = llm_model_setup(model_schema=None) - - with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): - factory._build_model_instance_for_llm_node(setup.node_data) - - setup.provider_model.raise_for_status.assert_called_once() - - def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup): - setup = llm_model_setup( - completion_params={"temperature": 0.3, "stop": "not-a-list"}, - model_schema={"schema": "value"}, + node_data = SimpleNamespace(model=node_data_model) + model_type_instance = MagicMock() + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + parameters={"temperature": 0.3}, + stop=("Human:",), ) + fetch_model_config = MagicMock(return_value=(model_instance, sentinel.model_config)) + monkeypatch.setattr(node_factory, "fetch_model_config", fetch_model_config) - result = factory._build_model_instance_for_llm_node(setup.node_data) + result = factory._build_model_instance_for_llm_node(node_data) - assert result is setup.model_instance - assert result.provider == "provider" - assert result.model_name == "model" - assert result.credentials == setup.credentials + assert result is model_instance assert result.parameters == {"temperature": 0.3} - assert result.stop == () - assert result.model_type_instance is setup.model_type_instance - setup.provider_model.raise_for_status.assert_called_once() + assert result.stop == ("Human:",) + assert result.model_type_instance is model_type_instance + fetch_model_config.assert_called_once_with( + node_data_model=node_data_model, + credentials_provider=sentinel.credentials_provider, + model_factory=sentinel.model_factory, + ) + + def test_propagates_fetch_model_config_errors(self, monkeypatch, factory): + fetch_model_config = MagicMock(side_effect=ValueError("broken model config")) + monkeypatch.setattr(node_factory, "fetch_model_config", fetch_model_config) + + with pytest.raises(ValueError, match="broken model config"): + factory._build_model_instance_for_llm_node(SimpleNamespace(model=sentinel.node_data_model)) class TestDifyNodeFactoryMemory: diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 8dcfd10ec6..1f19d2c2d2 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -4,7 +4,7 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result +from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result def _make_chunk( @@ -20,7 +20,7 @@ def _make_chunk( return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) -def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls(): +def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls(): 
prompt_messages = [UserPromptMessage(content="hi")] tool_calls = [ @@ -44,7 +44,7 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) ) @@ -62,20 +62,20 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t ] -def test__normalize_non_stream_plugin_result__from_first_chunk_list_content(): +def test__normalize_non_stream_runtime_result__from_first_chunk_list_content(): prompt_messages = [UserPromptMessage(content="hi")] content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage()) - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) ) assert result.message.content == content_list -def test__normalize_non_stream_plugin_result__passthrough_llm_result(): +def test__normalize_non_stream_runtime_result__passthrough_llm_result(): prompt_messages = [UserPromptMessage(content="hi")] llm_result = LLMResult( model="test-model", @@ -85,15 +85,15 @@ def test__normalize_non_stream_plugin_result__passthrough_llm_result(): ) assert ( - _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) + _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) == llm_result ) -def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): +def test__normalize_non_stream_runtime_result__empty_iterator_defaults(): prompt_messages = [UserPromptMessage(content="hi")] - result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) + result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) assert result.model == "test-model" assert result.prompt_messages == prompt_messages @@ -103,7 +103,7 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): assert result.system_fingerprint is None -def test__normalize_non_stream_plugin_result__accumulates_all_chunks(): +def test__normalize_non_stream_runtime_result__accumulates_all_chunks(): """All chunks are accumulated from the iterator.""" prompt_messages = [UserPromptMessage(content="hi")] @@ -116,7 +116,7 @@ def test__normalize_non_stream_plugin_result__accumulates_all_chunks(): finally: closed.append(True) - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model="test-model", prompt_messages=prompt_messages, result=_chunk_iter(), diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py index 382dce876e..e70c0fe87b 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py @@ -2,9 +2,8 @@ import decimal from unittest.mock import MagicMock, patch import pytest -from redis import RedisError +from pydantic import 
BaseModel -from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ( AIModelEntity, @@ -17,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ( PriceConfig, PriceType, ) +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -28,46 +28,72 @@ from dify_graph.model_runtime.errors.invoke import ( from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +class _ConcreteAIModel(AIModel): + model_type = ModelType.LLM + + class TestAIModel: @pytest.fixture - def mock_plugin_model_provider(self): - return MagicMock(spec=PluginModelProviderEntity) - - @pytest.fixture - def ai_model(self, mock_plugin_model_provider): - return AIModel( - tenant_id="tenant_123", - model_type=ModelType.LLM, - plugin_id="plugin_123", - provider_name="test_provider", - plugin_model_provider=mock_plugin_model_provider, + def provider_schema(self) -> ProviderEntity: + return ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ) - def test_invoke_error_mapping(self, ai_model): + @pytest.fixture + def model_runtime(self) -> MagicMock: + return MagicMock() + + @pytest.fixture + def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel: + return _ConcreteAIModel( + provider_schema=provider_schema, + model_runtime=model_runtime, + ) + + def test_init_stores_runtime_state_and_is_not_pydantic_model( + self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock + ) -> None: + assert ai_model.model_type == ModelType.LLM + assert ai_model.provider_schema is provider_schema + assert ai_model.model_runtime is model_runtime + assert ai_model.provider == "langgenius/openai/openai" + assert ai_model.provider_display_name == "OpenAI" + assert ai_model.started_at == 0 + assert not isinstance(ai_model, BaseModel) + + def test_direct_base_class_requires_subclass_model_type( + self, provider_schema: ProviderEntity, model_runtime: MagicMock + ) -> None: + with pytest.raises(TypeError, match="subclasses must define model_type"): + AIModel(provider_schema=provider_schema, model_runtime=model_runtime) + + def test_subclass_uses_class_level_model_type( + self, provider_schema: ProviderEntity, model_runtime: MagicMock + ) -> None: + model = _ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime) + assert model.model_type == ModelType.LLM + + def test_invoke_error_mapping(self, ai_model: AIModel) -> None: mapping = ai_model._invoke_error_mapping assert InvokeConnectionError in mapping assert InvokeServerUnavailableError in mapping assert InvokeRateLimitError in mapping assert InvokeAuthorizationError in mapping assert InvokeBadRequestError in mapping - assert PluginDaemonInnerError in mapping assert ValueError in mapping - def test_transform_invoke_error(self, ai_model): - # Case: mapped error (InvokeAuthorizationError) + def test_transform_invoke_error(self, ai_model: AIModel) -> None: err = Exception("Original error") + with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): transformed = 
ai_model._transform_invoke_error(err) assert isinstance(transformed, InvokeAuthorizationError) assert "Incorrect model credentials provided" in str(transformed.description) - # Case: mapped error (InvokeError subclass) - with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}): - transformed = ai_model._transform_invoke_error(err) - assert isinstance(transformed, InvokeError) - assert "[test_provider]" in transformed.description - - # Case: mapped error (not InvokeError) class CustomNonInvokeError(Exception): pass @@ -75,52 +101,43 @@ class TestAIModel: transformed = ai_model._transform_invoke_error(err) assert transformed == err - # Case: unmapped error - unmapped_err = Exception("Unmapped") - transformed = ai_model._transform_invoke_error(unmapped_err) + transformed = ai_model._transform_invoke_error(Exception("Unmapped")) assert isinstance(transformed, InvokeError) - assert "Error: Unmapped" in transformed.description + assert transformed.description == "[OpenAI] Error: Unmapped" - def test_get_price(self, ai_model): + def test_get_price(self, ai_model: AIModel) -> None: model_name = "test_model" credentials = {"key": "value"} - # Mock get_model_schema mock_schema = MagicMock(spec=AIModelEntity) mock_schema.pricing = PriceConfig( input=decimal.Decimal("0.002"), output=decimal.Decimal("0.004"), - unit=decimal.Decimal(1000), # 1000 tokens per unit + unit=decimal.Decimal(1000), currency="USD", ) with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - # Test INPUT price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) assert price_info.unit_price == decimal.Decimal("0.002") - # Test OUTPUT price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) assert price_info.unit_price == decimal.Decimal("0.004") - # Case: unit_price is None (returns zeroed PriceInfo) mock_schema.pricing = None with patch.object(AIModel, "get_model_schema", return_value=mock_schema): price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) assert price_info.total_amount == decimal.Decimal("0.0") - def test_get_price_no_price_config_error(self, ai_model): - model_name = "test_model" - - # We need it to be truthy at line 107 and 112 but falsy at line 127. 
+ def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None: class ChangingPriceConfig: - def __init__(self): + def __init__(self) -> None: self.input = decimal.Decimal("0.01") self.unit = decimal.Decimal(1) self.currency = "USD" self.called = 0 - def __bool__(self): + def __bool__(self) -> bool: self.called += 1 return self.called <= 2 @@ -128,11 +145,12 @@ class TestAIModel: mock_schema.pricing = ChangingPriceConfig() with patch.object(AIModel, "get_model_schema", return_value=mock_schema): - with pytest.raises(ValueError) as excinfo: - ai_model.get_price(model_name, {}, PriceType.INPUT, 1000) - assert "Price config not found" in str(excinfo.value) + with pytest.raises(ValueError, match="Price config not found"): + ai_model.get_price("test_model", {}, PriceType.INPUT, 1000) - def test_get_model_schema_cache_hit(self, ai_model): + def test_get_model_schema_delegates_to_runtime( + self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity + ) -> None: model_name = "test_model" credentials = {"api_key": "abc"} @@ -144,120 +162,21 @@ class TestAIModel: model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, parameter_rules=[], ) + model_runtime.get_model_schema.return_value = mock_schema - with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis: - mock_redis.get.return_value = mock_schema.model_dump_json().encode() + schema = ai_model.get_model_schema(model_name, credentials) - schema = ai_model.get_model_schema(model_name, credentials) - - assert schema.model == "test_model" - mock_redis.get.assert_called_once() - - def test_get_model_schema_cache_miss(self, ai_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), + assert schema == mock_schema + model_runtime.get_model_schema.assert_called_once_with( + provider=provider_schema.provider, model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], + model=model_name, + credentials=credentials, ) - with ( - patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, - patch("core.plugin.impl.model.PluginModelClient") as mock_client, - ): - mock_redis.get.return_value = None - mock_manager = mock_client.return_value - mock_manager.get_model_schema.return_value = mock_schema - - schema = ai_model.get_model_schema(model_name, credentials) - - assert schema == mock_schema - mock_manager.get_model_schema.assert_called_once() - mock_redis.setex.assert_called_once() - - def test_get_model_schema_redis_error(self, ai_model): - model_name = "test_model" - - with ( - patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, - patch("core.plugin.impl.model.PluginModelClient") as mock_client, - ): - mock_redis.get.side_effect = RedisError("Connection refused") - mock_manager = mock_client.return_value - mock_manager.get_model_schema.return_value = None - - schema = ai_model.get_model_schema(model_name, {}) - - assert schema is None - mock_manager.get_model_schema.assert_called_once() - - def test_get_model_schema_validation_error(self, ai_model): - model_name = "test_model" - - with ( - patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, - patch("core.plugin.impl.model.PluginModelClient") as mock_client, - ): - mock_redis.get.return_value = b"invalid json" - 
mock_manager = mock_client.return_value - mock_manager.get_model_schema.return_value = None - - # This should trigger ValidationError at line 166 and go to delete() - schema = ai_model.get_model_schema(model_name, {}) - - assert schema is None - mock_redis.delete.assert_called() - - def test_get_model_schema_redis_delete_error(self, ai_model): - model_name = "test_model" - - with ( - patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, - patch("core.plugin.impl.model.PluginModelClient") as mock_client, - ): - mock_redis.get.return_value = b'{"invalid": "schema"}' - mock_redis.delete.side_effect = RedisError("Delete failed") - mock_manager = mock_client.return_value - mock_manager.get_model_schema.return_value = None - - schema = ai_model.get_model_schema(model_name, {}) - - assert schema is None - mock_redis.delete.assert_called() - - def test_get_model_schema_redis_setex_error(self, ai_model): - model_name = "test_model" - mock_schema = AIModelEntity( - model="test_model", - label=I18nObject(en_US="Test Model"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - - with ( - patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, - patch("core.plugin.impl.model.PluginModelClient") as mock_client, - ): - mock_redis.get.return_value = None - mock_redis.setex.side_effect = RuntimeError("Setex failed") - mock_manager = mock_client.return_value - mock_manager.get_model_schema.return_value = mock_schema - - schema = ai_model.get_model_schema(model_name, {}) - - assert schema == mock_schema - mock_redis.setex.assert_called() - - def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model): - model_name = "test_model" - + def test_get_customizable_model_schema_from_credentials_template_mapping_value_error( + self, ai_model: AIModel + ) -> None: mock_schema = AIModelEntity( model="test_model", label=I18nObject(en_US="Test Model"), @@ -275,12 +194,11 @@ class TestAIModel: ) with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) + assert schema is not None assert schema.parameter_rules[0].use_template == "invalid_template_name" - def test_get_customizable_model_schema_from_credentials(self, ai_model): - model_name = "test_model" - + def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None: mock_schema = AIModelEntity( model="test_model", label=I18nObject(en_US="Test Model"), @@ -310,27 +228,27 @@ class TestAIModel: ) with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): - schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {}) + assert schema is not None assert schema.parameter_rules[0].max == 1.0 + assert schema.parameter_rules[1].help is not None assert schema.parameter_rules[1].help.en_US != "" + assert schema.parameter_rules[2].help is not None assert schema.parameter_rules[2].help.zh_Hans != "" assert schema.parameter_rules[3].use_template is None - def test_get_customizable_model_schema_from_credentials_none(self, ai_model): + def test_get_customizable_model_schema_from_credentials_none(self, ai_model: 
AIModel) -> None: with patch.object(AIModel, "get_customizable_model_schema", return_value=None): schema = ai_model.get_customizable_model_schema_from_credentials("model", {}) assert schema is None - def test_get_customizable_model_schema_default(self, ai_model): + def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None: assert ai_model.get_customizable_model_schema("model", {}) is None - def test_get_default_parameter_rule_variable_map(self, ai_model): - # Valid - res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) - assert res["default"] == 0.0 + def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None: + result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) + assert result["default"] == 0.0 - # Invalid - with pytest.raises(Exception) as excinfo: + with pytest.raises(Exception, match="Invalid model parameter rule name"): ai_model._get_default_parameter_rule_variable_map("invalid_name") - assert "Invalid model parameter rule name" in str(excinfo.value) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_runtime_user_forwarding.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_runtime_user_forwarding.py new file mode 100644 index 0000000000..f189634943 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_runtime_user_forwarding.py @@ -0,0 +1,170 @@ +from decimal import Decimal +from io import BytesIO +from unittest.mock import MagicMock + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel + + +def _provider_schema(model_type: ModelType) -> ProviderEntity: + return ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[model_type], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + +def _embedding_usage() -> EmbeddingUsage: + return EmbeddingUsage( + tokens=1, + total_tokens=1, + unit_price=Decimal(0), + price_unit=Decimal(0), + total_price=Decimal(0), + currency="USD", + latency=0.0, + ) + + +def test_large_language_model_invokes_runtime_without_user_id() -> None: + runtime = MagicMock() + runtime.invoke_llm.return_value = LLMResult( + model="gpt-4o-mini", + prompt_messages=[], + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + ) + 
model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime) + + model.invoke( + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[UserPromptMessage(content="hi")], + stream=False, + ) + + assert "user_id" not in runtime.invoke_llm.call_args.kwargs + + +def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None: + runtime = MagicMock() + runtime.invoke_text_embedding.return_value = EmbeddingResult( + model="text-embedding-3-small", + embeddings=[[0.1]], + usage=_embedding_usage(), + ) + model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) + + model.invoke( + model="text-embedding-3-small", + credentials={"api_key": "secret"}, + texts=["hello"], + ) + + assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs + + +def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: + runtime = MagicMock() + runtime.invoke_multimodal_embedding.return_value = EmbeddingResult( + model="text-embedding-3-small", + embeddings=[[0.1]], + usage=_embedding_usage(), + ) + model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime) + + model.invoke( + model="text-embedding-3-small", + credentials={"api_key": "secret"}, + multimodel_documents=[{"content": "hello", "content_type": "text"}], + ) + + assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs + + +def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None: + runtime = MagicMock() + runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[]) + model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) + + model.invoke( + model="rerank", + credentials={"api_key": "secret"}, + query="q", + docs=["d1"], + ) + + assert "user_id" not in runtime.invoke_rerank.call_args.kwargs + + +def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None: + runtime = MagicMock() + runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[]) + model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime) + + model.invoke_multimodal_rerank( + model="rerank", + credentials={"api_key": "secret"}, + query={"content": "q", "content_type": "text"}, + docs=[{"content": "d1", "content_type": "text"}], + ) + + assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs + + +def test_tts_model_invokes_runtime_without_user_id() -> None: + runtime = MagicMock() + runtime.invoke_tts.return_value = [b"chunk"] + model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime) + + list( + model.invoke( + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + ) + + assert "user_id" not in runtime.invoke_tts.call_args.kwargs + + +def test_speech_to_text_model_invokes_runtime_without_user_id() -> None: + runtime = MagicMock() + runtime.invoke_speech_to_text.return_value = "transcript" + model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime) + + model.invoke( + model="whisper-1", + credentials={"api_key": "secret"}, + file=BytesIO(b"audio"), + ) + + assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs + + +def test_moderation_model_invokes_runtime_without_user_id() -> None: + runtime = MagicMock() + 
runtime.invoke_moderation.return_value = True + model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime) + + model.invoke( + model="omni-moderation-latest", + credentials={"api_key": "secret"}, + text="unsafe?", + ) + + assert "user_id" not in runtime.invoke_moderation.call_args.kwargs diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e2..00af35aba8 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -591,7 +591,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6829691507..195259680c 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -430,7 +430,7 @@ class TestDatasetServiceCheckDatasetModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_dataset_model_setting_high_quality_success(self, mock_model_manager): @@ -579,7 +579,7 @@ class TestDatasetServiceCheckEmbeddingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_embedding_model_setting_success(self, mock_model_manager): @@ -701,7 +701,7 @@ class TestDatasetServiceCheckRerankingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. 
""" - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_reranking_model_setting_success(self, mock_model_manager): diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..b8c7bfda99 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -265,7 +265,7 @@ class TestSegmentServiceCreateSegment: patch( "services.dataset_service.VectorService.create_segments_vector", autospec=True ) as mock_vector_service, - patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class, + patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py index bff8dc92c6..ac93a0d18f 100644 --- a/api/tests/unit_tests/services/test_app_service.py +++ b/api/tests/unit_tests/services/test_app_service.py @@ -136,7 +136,7 @@ class TestAppServiceCreate: patch("services.app_service.default_app_templates", app_template), patch("services.app_service.App", return_value=app_instance), patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.db") as mock_db, patch("services.app_service.app_was_created") as mock_event, patch("services.app_service.FeatureService.get_system_features") as mock_features, @@ -182,7 +182,7 @@ class TestAppServiceCreate: with ( patch("services.app_service.default_app_templates", app_template), patch("services.app_service.App", return_value=app_instance), - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.db") as mock_db, ): manager = mock_model_manager.return_value @@ -209,7 +209,7 @@ class TestAppServiceCreate: patch("services.app_service.default_app_templates", app_template), patch("services.app_service.App", return_value=app_instance), patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.db") as mock_db, patch("services.app_service.app_was_created") as mock_event, patch("services.app_service.FeatureService.get_system_features") as mock_features, @@ -247,7 +247,7 @@ class TestAppServiceCreate: patch("services.app_service.default_app_templates", app_template), patch("services.app_service.App", return_value=app_instance), patch("services.app_service.AppModelConfig", return_value=app_model_config), - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.db"), patch("services.app_service.app_was_created"), patch( diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d67469105..e60bd41a62 100644 
--- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -214,7 +214,7 @@ def factory(): class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in CHAT mode.""" # Arrange @@ -240,7 +240,7 @@ class TestAudioServiceASR: call_args = mock_model_instance.invoke_speech2text.call_args assert call_args.kwargs["user"] == "user-123" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange @@ -347,7 +347,7 @@ class TestAudioServiceASR: with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"): AudioService.transcript_asr(app_model=app, file=file) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that ASR raises error when no model instance is available.""" # Arrange @@ -370,7 +370,7 @@ class TestAudioServiceASR: class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): """Test successful TTS with text input.""" # Arrange @@ -406,7 +406,7 @@ class TestAudioServiceTTS: ) @patch("services.audio_service.db.session", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" # Arrange @@ -445,7 +445,7 @@ class TestAudioServiceTTS: assert result == b"audio from message" mock_model_instance.invoke_tts.assert_called_once() - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): """Test TTS uses default voice when none specified.""" # Arrange @@ -475,7 +475,7 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "default-voice" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): """Test TTS gets first available voice when none is configured.""" # Arrange @@ -506,7 +506,7 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "auto-voice" @patch("services.audio_service.WorkflowService", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( self, mock_model_manager_class, 
mock_workflow_service_class, factory ): @@ -611,7 +611,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): """Test that TTS raises error when no voices are available.""" # Arrange @@ -637,7 +637,7 @@ class TestAudioServiceTTS: class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): """Test successful retrieval of TTS voices.""" # Arrange @@ -662,7 +662,7 @@ class TestAudioServiceTTSVoices: assert result == expected_voices mock_model_instance.get_tts_voices.assert_called_once_with(language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that TTS voices raises error when no model instance is available.""" # Arrange @@ -677,7 +677,7 @@ class TestAudioServiceTTSVoices: with pytest.raises(ProviderNotSupportTextToSpeechServiceError): AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): """Test that TTS voices propagates exceptions from model instance.""" # Arrange diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index e7740ef93a..101b9bff24 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -933,7 +933,7 @@ class TestMessageServiceSuggestedQuestions: ) # Test 28: get_suggested_questions_after_answer - Advanced Chat success - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.WorkflowService") @patch("services.message_service.AdvancedChatAppConfigManager") @patch("services.message_service.TokenBufferMemory") @@ -983,7 +983,7 @@ class TestMessageServiceSuggestedQuestions: # Test 29: get_suggested_questions_after_answer - Chat app success (no override) @patch("services.message_service.db") - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.TokenBufferMemory") @patch("services.message_service.LLMGenerator") @patch("services.message_service.TraceQueueManager") diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 6a6b63f003..60567e7a74 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -71,7 +71,7 @@ def service_with_fake_configurations(): return _FakeConfigurations(fake_provider_configuration) svc = ModelProviderService() - svc.provider_manager = _FakeProviderManager() + svc._get_provider_manager = lambda 
tenant_id: _FakeProviderManager() return svc diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b2..960c9aa360 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -521,7 +521,7 @@ class TestVectorService: assert call_args[1]["keywords_list"] == keywords_list @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager") + @patch("services.vector_service.ModelManager.for_tenant") @patch("services.vector_service.db") def test_create_segments_vector_parent_child_indexing( self, mock_db, mock_model_manager, mock_generate_child_chunks @@ -1754,7 +1754,7 @@ class TestVector: # ======================================================================== @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager") + @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): """