mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 16:00:32 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
|
||||
mode: str,
|
||||
credentials: dict,
|
||||
):
|
||||
model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_runtime import DifyFileReferenceFactory
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
from dify_graph.graph import Graph
|
||||
@@ -81,6 +82,7 @@ def init_http_node(config: dict):
|
||||
http_client=ssrf_proxy,
|
||||
tool_file_manager_factory=ToolFileManager,
|
||||
file_manager=file_manager,
|
||||
file_reference_factory=DifyFileReferenceFactory(init_params.run_context),
|
||||
)
|
||||
|
||||
return node
|
||||
@@ -728,6 +730,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
http_client=ssrf_proxy,
|
||||
tool_file_manager_factory=ToolFileManager,
|
||||
file_manager=file_manager,
|
||||
file_reference_factory=DifyFileReferenceFactory(init_params.run_context),
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
@@ -9,8 +9,10 @@ from core.llm_generator.output_parser.structured_output import _parse_structured
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.file_saver import LLMFileSaver
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@@ -66,6 +68,11 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
variable_pool.add(["abc", "output"], "sunny")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol)
|
||||
prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [
|
||||
message.model_dump(mode="json") for message in prompt_messages
|
||||
]
|
||||
llm_file_saver = MagicMock(spec=LLMFileSaver)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -75,7 +82,8 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
template_renderer=MagicMock(spec=TemplateRenderer),
|
||||
llm_file_saver=llm_file_saver,
|
||||
prompt_message_serializer=prompt_message_serializer,
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
@@ -159,7 +167,7 @@ def test_execute_llm():
|
||||
return mock_model_instance
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_1(*_args, **_kwargs):
|
||||
def mock_fetch_prompt_messages_1(**_kwargs):
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
return [
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_runtime import DifyPromptMessageSerializer
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
@@ -77,6 +78,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
memory=memory,
|
||||
prompt_message_serializer=DifyPromptMessageSerializer(),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.template_rendering import TemplateRenderError
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ def test_execute_template_transform():
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
template_renderer=_SimpleJinja2Renderer(),
|
||||
jinja2_template_renderer=_SimpleJinja2Renderer(),
|
||||
)
|
||||
|
||||
# execute node
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_runtime import DifyToolNodeRuntime
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
@@ -64,6 +65,7 @@ def init_tool_node(config: dict):
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
runtime=DifyToolNodeRuntime(init_params.run_context),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
Reference in New Issue
Block a user