mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 22:00:53 -04:00
fix: preserve request-bound model runtime scope
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -5,7 +7,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +30,8 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||
ProviderEntity,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from dify_graph.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel):
|
||||
- Load balancing configurations
|
||||
- Model enablement/disablement
|
||||
|
||||
Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so
|
||||
nested schema and model lookups reuse the caller scope that was already
|
||||
resolved by the composition layer.
|
||||
|
||||
TODO: much of the logic in this BaseModel entity should be separated out, and the exceptions should be classified.
|
||||
"""
|
||||
|
||||
@@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
_bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
@@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
return self
|
||||
|
||||
def bind_model_runtime(self, model_runtime: "ModelRuntime") -> None:
    """Bind an already-composed runtime to this configuration.

    The runtime is stored on the private ``_bound_model_runtime`` attribute,
    replacing whatever was bound before. Request-bound call chains use it so
    the caller scope resolved earlier is not rebuilt from the tenant id.

    :param model_runtime: runtime to reuse for later factory lookups
    """
    self._bound_model_runtime = model_runtime
|
||||
|
||||
def _get_model_provider_factory(self) -> ModelProviderFactory:
    """Build the provider factory used for schema and model lookups.

    When no runtime has been bound, the factory is created for the tenant
    through ``create_plugin_model_provider_factory``. Otherwise the bound
    runtime is wrapped in a new ``ModelProviderFactory``.

    :return: a model provider factory
    """
    bound_runtime = self._bound_model_runtime
    if bound_runtime is None:
        # Nothing bound: fall back to the tenant-scoped factory.
        return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
    return ModelProviderFactory(model_runtime=bound_runtime)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
"""
|
||||
Get current credentials.
|
||||
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
@@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
@@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
|
||||
# Get the model type instance for the requested model type
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
@@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
return model_provider_factory.get_model_schema(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
@@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
|
||||
model_types: list[ModelType] = []
|
||||
|
||||
@@ -25,9 +25,12 @@ from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# `TS` means tenant scope
|
||||
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and a default user."""
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
@@ -479,7 +482,10 @@ class PluginModelRuntime(ModelRuntime):
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> str:
|
||||
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
|
||||
# The plugin daemon distinguishes ``None`` from an explicit empty-string
|
||||
# caller id, so the cache must only collapse ``None`` into tenant scope.
|
||||
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
|
||||
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
if not sorted_credentials:
|
||||
return cache_key
|
||||
|
||||
@@ -65,6 +65,9 @@ class ProviderManager:
|
||||
|
||||
The runtime adapter is injected by the composition layer so this class stays
|
||||
focused on configuration assembly instead of constructing plugin runtimes.
|
||||
Request-bound managers may carry caller identity in that runtime, and the
|
||||
resulting ``ProviderConfiguration`` objects must reuse it for downstream
|
||||
model-type and schema lookups.
|
||||
"""
|
||||
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
@@ -264,6 +267,7 @@ class ProviderManager:
|
||||
custom_configuration=custom_configuration,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
provider_configuration.bind_model_runtime(self._model_runtime)
|
||||
|
||||
provider_configurations[str(provider_id_entity)] = provider_configuration
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ class DatasetRetrieval:
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
|
||||
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@@ -384,23 +384,27 @@ class DatasetRetrieval:
|
||||
return None, []
|
||||
retrieve_config = config.retrieve_config
|
||||
|
||||
# check whether the model supports tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get model schema
|
||||
# Reuse the caller-bound model instance for both schema resolution and
|
||||
# downstream planner/invoke calls so a single request never mixes
|
||||
# tenant-scope and request-bound runtimes.
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
model=model_instance.model_name,
|
||||
credentials=model_instance.credentials,
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None, []
|
||||
|
||||
model_config.provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_config.credentials = model_instance.credentials
|
||||
model_config.model_schema = model_schema
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
|
||||
@@ -120,6 +120,7 @@ class ReactMultiDatasetRouter:
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
|
||||
@@ -436,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
) as mock_factory_builder:
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_builder.call_count == 2
|
||||
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
|
||||
mock_factory.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
@@ -451,6 +455,31 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None:
    """A bound runtime must back every factory the configuration builds.

    ``get_model_type_instance`` and ``get_model_schema`` are each expected to
    construct ``ModelProviderFactory`` from the bound runtime, and neither may
    call the tenant-scoped ``create_plugin_model_provider_factory`` builder.
    """
    configuration = _build_provider_configuration()
    bound_runtime = Mock()
    configuration.bind_model_runtime(bound_runtime)

    mock_factory = Mock()
    mock_model_type_instance = Mock()
    mock_schema = _build_ai_model("gpt-4o")
    mock_factory.get_model_type_instance.return_value = mock_model_type_instance
    mock_factory.get_model_schema.return_value = mock_schema

    with (
        patch(
            "core.entities.provider_configuration.ModelProviderFactory",
            return_value=mock_factory,
        ) as mock_factory_cls,
        patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
    ):
        model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
        model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})

    assert model_type_instance is mock_model_type_instance
    assert model_schema is mock_schema

    # ``assert_called_with`` inspects only the most recent call, so verify each
    # recorded construction: neither lookup may bypass the bound runtime.
    assert mock_factory_cls.call_count == 2
    for factory_call in mock_factory_cls.call_args_list:
        assert factory_call.args == ()
        assert factory_call.kwargs == {"model_runtime": bound_runtime}
    mock_factory_builder.assert_not_called()
|
||||
|
||||
|
||||
def test_get_provider_model_returns_none_when_model_not_found() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
fake_model = SimpleNamespace(model="other-model")
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl import model_runtime as model_runtime_module
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime, TENANT_SCOPE_SCHEMA_CACHE_USER_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
@@ -416,6 +416,69 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None:
|
||||
assert first == second
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_distinct_user_scopes() -> None:
    """Two different caller ids under one tenant must yield different cache keys."""
    runtimes = [
        PluginModelRuntime(tenant_id="tenant", user_id=caller_id, client=Mock(spec=PluginModelClient))
        for caller_id in ("user-a", "user-b")
    ]

    # The lookup arguments are identical; only the runtime's user id differs.
    key_for_a, key_for_b = (
        runtime._get_schema_cache_key(
            provider="langgenius/openai/openai",
            model_type=ModelType.LLM,
            model="gpt-4o-mini",
            credentials={"a": "1"},
        )
        for runtime in runtimes
    )

    assert key_for_a != key_for_b
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None:
    """A runtime without a user id must not share a cache key with a user-bound one.

    The tenant-scoped key is also expected to carry the
    ``TENANT_SCOPE_SCHEMA_CACHE_USER_ID`` marker after a ``:`` separator.
    """
    runtimes = {
        "tenant": PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)),
        "user": PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)),
    }

    keys = {
        scope: runtime._get_schema_cache_key(
            provider="langgenius/openai/openai",
            model_type=ModelType.LLM,
            model="gpt-4o-mini",
            credentials={"a": "1"},
        )
        for scope, runtime in runtimes.items()
    }

    assert keys["tenant"] != keys["user"]
    assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in keys["tenant"]
|
||||
|
||||
|
||||
def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None:
    """``user_id=None`` and ``user_id=""`` must map to different cache keys.

    With empty credentials, the empty-string key is expected to end in the bare
    ``:`` separator and must not contain the tenant-scope marker.
    """

    def key_of(runtime: PluginModelRuntime) -> str:
        # Empty credentials on both sides, so only the user id can differ.
        return runtime._get_schema_cache_key(
            provider="langgenius/openai/openai",
            model_type=ModelType.LLM,
            model="gpt-4o-mini",
            credentials={},
        )

    tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient))
    empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient))

    tenant_key = key_of(tenant_runtime)
    empty_user_key = key_of(empty_user_runtime)

    assert tenant_key != empty_user_key
    assert empty_user_key.endswith(":")
    assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key
|
||||
|
||||
|
||||
def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.fetch_model_providers.return_value = [
|
||||
|
||||
@@ -4246,6 +4246,7 @@ class TestKnowledgeRetrievalCoverage:
|
||||
mock_model_manager.return_value.get_model_instance.return_value = model_instance
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
retrieval.knowledge_retrieval(request)
|
||||
mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
assert error_cls in type(exc_info.value).__name__
|
||||
|
||||
|
||||
@@ -4298,9 +4299,13 @@ class TestRetrieveCoverage:
|
||||
),
|
||||
)
|
||||
model_config = self._build_model_config()
|
||||
model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None
|
||||
model_instance = Mock()
|
||||
model_instance.model_name = "gpt-4"
|
||||
model_instance.credentials = {"api_key": "secret"}
|
||||
model_instance.provider_model_bundle = Mock()
|
||||
model_instance.model_type_instance.get_model_schema.return_value = None
|
||||
with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager:
|
||||
mock_model_manager.return_value.get_model_instance.return_value = Mock()
|
||||
mock_model_manager.return_value.get_model_instance.return_value = model_instance
|
||||
result = retrieval.retrieve(
|
||||
app_id="app-1",
|
||||
user_id="user-1",
|
||||
@@ -4313,8 +4318,58 @@ class TestRetrieveCoverage:
|
||||
hit_callback=Mock(),
|
||||
message_id="m1",
|
||||
)
|
||||
mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
assert result == (None, [])
|
||||
|
||||
def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config(
    self, retrieval: DatasetRetrieval
) -> None:
    """``retrieve`` must take its schema from the manager-provided model instance.

    The incoming ``model_config`` bundle is rigged to return no schema, so the
    call can only proceed by using the instance returned from
    ``ModelManager.for_tenant(...).get_model_instance``. Afterwards the config
    is expected to carry that instance's bundle, credentials and schema.
    """
    # Model instance handed back by the patched manager.
    bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL])
    bound_bundle = Mock()
    bound_model_instance = Mock(
        model_name="gpt-4",
        credentials={"api_key": "secret"},
        provider_model_bundle=bound_bundle,
    )
    bound_model_instance.configure_mock(**{"model_type_instance.get_model_schema.return_value": bound_schema})

    # Incoming config whose own bundle yields no schema.
    model_config = self._build_model_config(features=[])
    model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None

    single_strategy = DatasetRetrieveConfigEntity(
        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE,
        metadata_filtering_mode="disabled",
    )
    config = DatasetEntity(dataset_ids=["d1"], retrieve_config=single_strategy)

    with (
        patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager,
        patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]),
        patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)),
        patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve,
    ):
        mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
        context, files = retrieval.retrieve(
            app_id="app-1",
            user_id="user-1",
            tenant_id="tenant-1",
            model_config=model_config,
            config=config,
            query="python",
            invoke_from=InvokeFrom.WEB_APP,
            show_retrieve_source=False,
            hit_callback=Mock(),
            message_id="m1",
        )

    # The manager must be created with the caller's user id.
    mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")

    # Positional argument 8 of ``single_retrieve`` is the planning strategy.
    mock_single_retrieve.assert_called_once()
    assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER

    # The config now mirrors the bound model instance.
    assert model_config.provider_model_bundle is bound_bundle
    assert model_config.credentials == {"api_key": "secret"}
    assert model_config.model_schema is bound_schema

    assert context == ""
    assert files == []
|
||||
|
||||
def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None:
|
||||
retrieve_config = DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE,
|
||||
@@ -4336,7 +4391,12 @@ class TestRetrieveCoverage:
|
||||
patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)),
|
||||
patch.object(retrieval, "single_retrieve", return_value=[external_doc]),
|
||||
):
|
||||
mock_model_manager.return_value.get_model_instance.return_value = Mock()
|
||||
bound_model_instance = Mock()
|
||||
bound_model_instance.model_name = "gpt-4"
|
||||
bound_model_instance.credentials = {}
|
||||
bound_model_instance.provider_model_bundle = Mock()
|
||||
bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
|
||||
mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
|
||||
context, files = retrieval.retrieve(
|
||||
app_id="app-1",
|
||||
user_id="user-1",
|
||||
@@ -4432,7 +4492,14 @@ class TestRetrieveCoverage:
|
||||
patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"),
|
||||
patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute,
|
||||
):
|
||||
mock_model_manager.return_value.get_model_instance.return_value = Mock()
|
||||
bound_model_instance = Mock()
|
||||
bound_model_instance.model_name = "gpt-4"
|
||||
bound_model_instance.credentials = {}
|
||||
bound_model_instance.provider_model_bundle = Mock()
|
||||
bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
|
||||
features=[ModelFeature.TOOL_CALL]
|
||||
)
|
||||
mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
|
||||
mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets]
|
||||
context, files = retrieval.retrieve(
|
||||
app_id="app-1",
|
||||
|
||||
@@ -5,6 +5,7 @@ from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFini
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class TestReactMultiDatasetRouter:
|
||||
@@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter:
|
||||
model_config = Mock()
|
||||
model_config.mode = "chat"
|
||||
model_config.parameters = {"temperature": 0.1}
|
||||
model_instance = Mock()
|
||||
usage = LLMUsage.empty_usage()
|
||||
tools = [Mock(name="dataset-1"), Mock(name="dataset-2")]
|
||||
tools[0].name = "dataset-1"
|
||||
@@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter:
|
||||
dataset_id, returned_usage = router._react_invoke(
|
||||
query="python",
|
||||
model_config=model_config,
|
||||
model_instance=Mock(),
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
user_id="u1",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
mock_chat_prompt.assert_called_once()
|
||||
assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance
|
||||
assert dataset_id == "dataset-2"
|
||||
assert returned_usage == usage
|
||||
|
||||
@@ -178,7 +181,13 @@ class TestReactMultiDatasetRouter:
|
||||
|
||||
assert text == "part"
|
||||
assert returned_usage == usage
|
||||
mock_manager.assert_any_call(tenant_id="t1", user_id="u1")
|
||||
mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1")
|
||||
mock_manager.return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id="t1",
|
||||
provider=model_instance.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
mock_deduct.assert_called_once()
|
||||
|
||||
def test_handle_invoke_result_with_empty_usage(self) -> None:
|
||||
|
||||
@@ -343,6 +343,35 @@ def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFix
|
||||
manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
|
||||
|
||||
|
||||
def test_get_configurations_binds_manager_runtime_to_provider_configuration(
    mocker: MockerFixture, mock_provider_entity
):
    """``get_configurations`` must bind the manager's runtime to each configuration it builds.

    All data-loading helpers on the manager are stubbed, and
    ``ProviderConfiguration`` is replaced by a mock so the ``bind_model_runtime``
    call can be observed.
    """
    manager = _build_provider_manager(mocker)

    # Stand-ins for what the patched constructors return.
    built_configuration = Mock()
    stub_factory = Mock(**{"get_providers.return_value": [mock_provider_entity]})
    stub_custom = SimpleNamespace(provider=None, models=[])
    stub_system = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)

    with (
        patch.object(manager, "_get_all_providers", return_value={"openai": []}),
        patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
        patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
        patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
        patch.object(manager, "_get_all_provider_model_settings", return_value={}),
        patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
        patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
        patch.object(manager, "_to_custom_configuration", return_value=stub_custom),
        patch.object(manager, "_to_system_configuration", return_value=stub_system),
        patch.object(manager, "_to_model_settings", return_value=[]),
        patch("core.provider_manager.ModelProviderFactory", return_value=stub_factory),
        patch("core.provider_manager.ProviderConfiguration", return_value=built_configuration),
    ):
        manager.get_configurations("tenant-id")

    built_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_configuration = Mock()
|
||||
|
||||
Reference in New Issue
Block a user