fix: preserve request-bound model runtime scope

This commit is contained in:
WH-2099
2026-03-21 12:55:03 +08:00
parent f8538ae36a
commit e589ed23a8
10 changed files with 255 additions and 24 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
import logging
import re
@@ -5,7 +7,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@@ -28,6 +30,8 @@ from dify_graph.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from dify_graph.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel):
- Load balancing configurations
- Model enablement/disablement
Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so
nested schema and model lookups reuse the caller scope that was already
resolved by the composition layer.
TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
"""
@@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
_bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _(self):
@@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def bind_model_runtime(self, model_runtime: "ModelRuntime") -> None:
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def _get_model_provider_factory(self) -> ModelProviderFactory:
"""Reuse a bound runtime when available, else fall back to tenant scope."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
Get current credentials.
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
model_provider_factory = self._get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
model_provider_factory = self._get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
model_provider_factory = self._get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
"""
Get model schema
"""
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
model_provider_factory = self._get_model_provider_factory()
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id)
model_provider_factory = self._get_model_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types: list[ModelType] = []

View File

@@ -25,9 +25,12 @@ from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
# `TS` means tenant scope
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and a default user."""
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
tenant_id: str
user_id: str | None
@@ -479,7 +482,10 @@ class PluginModelRuntime(ModelRuntime):
model: str,
credentials: dict[str, Any],
) -> str:
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
# The plugin daemon distinguishes ``None`` from an explicit empty-string
# caller id, so the cache must only collapse ``None`` into tenant scope.
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
sorted_credentials = sorted(credentials.items()) if credentials else []
if not sorted_credentials:
return cache_key

View File

@@ -65,6 +65,9 @@ class ProviderManager:
The runtime adapter is injected by the composition layer so this class stays
focused on configuration assembly instead of constructing plugin runtimes.
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
"""
def __init__(self, model_runtime: ModelRuntime):
@@ -264,6 +267,7 @@ class ProviderManager:
custom_configuration=custom_configuration,
model_settings=model_settings,
)
provider_configuration.bind_model_runtime(self._model_runtime)
provider_configurations[str(provider_id_entity)] = provider_configuration

View File

@@ -161,7 +161,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -384,23 +384,27 @@ class DatasetRetrieval:
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model schema
# Reuse the caller-bound model instance for both schema resolution and
# downstream planner/invoke calls so a single request never mixes
# tenant-scope and request-bound runtimes.
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
model=model_instance.model_name,
credentials=model_instance.credentials,
)
if not model_schema:
return None, []
model_config.provider_model_bundle = model_instance.provider_model_bundle
model_config.credentials = model_instance.credentials
model_config.model_schema = model_schema
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:

View File

@@ -120,6 +120,7 @@ class ReactMultiDatasetRouter:
memory_config=None,
memory=None,
model_config=model_config,
model_instance=model_instance,
)
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,

View File

@@ -436,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory.get_model_schema.return_value = mock_schema
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder:
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_builder.call_count == 2
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
mock_factory.get_model_schema.assert_called_once_with(
provider="openai",
@@ -451,6 +455,31 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
)
def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None:
configuration = _build_provider_configuration()
bound_runtime = Mock()
configuration.bind_model_runtime(bound_runtime)
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory.get_model_schema.return_value = mock_schema
with (
patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory) as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
mock_factory_builder.assert_not_called()
def test_get_provider_model_returns_none_when_model_not_found() -> None:
configuration = _build_provider_configuration()
fake_model = SimpleNamespace(model="other-model")

View File

@@ -10,7 +10,7 @@ import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl import model_runtime as model_runtime_module
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.plugin.impl.model_runtime import PluginModelRuntime, TENANT_SCOPE_SCHEMA_CACHE_USER_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -416,6 +416,69 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None:
assert first == second
def test_get_schema_cache_key_separates_distinct_user_scopes() -> None:
first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient))
second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient))
first = first_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"a": "1"},
)
second = second_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"a": "1"},
)
assert first != second
def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None:
tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient))
user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient))
tenant_key = tenant_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"a": "1"},
)
user_key = user_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"a": "1"},
)
assert tenant_key != user_key
assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key
def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None:
tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient))
empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient))
tenant_key = tenant_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={},
)
empty_user_key = empty_user_runtime._get_schema_cache_key(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={},
)
assert tenant_key != empty_user_key
assert empty_user_key.endswith(":")
assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key
def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None:
client = Mock(spec=PluginModelClient)
client.fetch_model_providers.return_value = [

View File

@@ -4246,6 +4246,7 @@ class TestKnowledgeRetrievalCoverage:
mock_model_manager.return_value.get_model_instance.return_value = model_instance
with pytest.raises(Exception) as exc_info:
retrieval.knowledge_retrieval(request)
mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
assert error_cls in type(exc_info.value).__name__
@@ -4298,9 +4299,13 @@ class TestRetrieveCoverage:
),
)
model_config = self._build_model_config()
model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None
model_instance = Mock()
model_instance.model_name = "gpt-4"
model_instance.credentials = {"api_key": "secret"}
model_instance.provider_model_bundle = Mock()
model_instance.model_type_instance.get_model_schema.return_value = None
with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager:
mock_model_manager.return_value.get_model_instance.return_value = Mock()
mock_model_manager.return_value.get_model_instance.return_value = model_instance
result = retrieval.retrieve(
app_id="app-1",
user_id="user-1",
@@ -4313,8 +4318,58 @@ class TestRetrieveCoverage:
hit_callback=Mock(),
message_id="m1",
)
mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
assert result == (None, [])
def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config(
self, retrieval: DatasetRetrieval
) -> None:
config = DatasetEntity(
dataset_ids=["d1"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE,
metadata_filtering_mode="disabled",
),
)
model_config = self._build_model_config(features=[])
model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None
bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL])
bound_bundle = Mock()
bound_model_instance = Mock()
bound_model_instance.model_name = "gpt-4"
bound_model_instance.credentials = {"api_key": "secret"}
bound_model_instance.provider_model_bundle = bound_bundle
bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema
with (
patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager,
patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]),
patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)),
patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve,
):
mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
context, files = retrieval.retrieve(
app_id="app-1",
user_id="user-1",
tenant_id="tenant-1",
model_config=model_config,
config=config,
query="python",
invoke_from=InvokeFrom.WEB_APP,
show_retrieve_source=False,
hit_callback=Mock(),
message_id="m1",
)
mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_single_retrieve.assert_called_once()
assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER
assert model_config.provider_model_bundle is bound_bundle
assert model_config.credentials == {"api_key": "secret"}
assert model_config.model_schema is bound_schema
assert context == ""
assert files == []
def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None:
retrieve_config = DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE,
@@ -4336,7 +4391,12 @@ class TestRetrieveCoverage:
patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)),
patch.object(retrieval, "single_retrieve", return_value=[external_doc]),
):
mock_model_manager.return_value.get_model_instance.return_value = Mock()
bound_model_instance = Mock()
bound_model_instance.model_name = "gpt-4"
bound_model_instance.credentials = {}
bound_model_instance.provider_model_bundle = Mock()
bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
context, files = retrieval.retrieve(
app_id="app-1",
user_id="user-1",
@@ -4432,7 +4492,14 @@ class TestRetrieveCoverage:
patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"),
patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute,
):
mock_model_manager.return_value.get_model_instance.return_value = Mock()
bound_model_instance = Mock()
bound_model_instance.model_name = "gpt-4"
bound_model_instance.credentials = {}
bound_model_instance.provider_model_bundle = Mock()
bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
features=[ModelFeature.TOOL_CALL]
)
mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance
mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets]
context, files = retrieval.retrieve(
app_id="app-1",

View File

@@ -5,6 +5,7 @@ from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFini
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from dify_graph.model_runtime.entities.model_entities import ModelType
class TestReactMultiDatasetRouter:
@@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter:
model_config = Mock()
model_config.mode = "chat"
model_config.parameters = {"temperature": 0.1}
model_instance = Mock()
usage = LLMUsage.empty_usage()
tools = [Mock(name="dataset-1"), Mock(name="dataset-2")]
tools[0].name = "dataset-1"
@@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter:
dataset_id, returned_usage = router._react_invoke(
query="python",
model_config=model_config,
model_instance=Mock(),
model_instance=model_instance,
tools=tools,
user_id="u1",
tenant_id="t1",
)
mock_chat_prompt.assert_called_once()
assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance
assert dataset_id == "dataset-2"
assert returned_usage == usage
@@ -178,7 +181,13 @@ class TestReactMultiDatasetRouter:
assert text == "part"
assert returned_usage == usage
mock_manager.assert_any_call(tenant_id="t1", user_id="u1")
mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1")
mock_manager.return_value.get_model_instance.assert_called_once_with(
tenant_id="t1",
provider=model_instance.provider,
model_type=ModelType.LLM,
model=model_instance.model_name,
)
mock_deduct.assert_called_once()
def test_handle_invoke_result_with_empty_usage(self) -> None:

View File

@@ -343,6 +343,35 @@ def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFix
manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
def test_get_configurations_binds_manager_runtime_to_provider_configuration(
mocker: MockerFixture, mock_provider_entity
):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}),
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration),
):
manager.get_configurations("tenant-id")
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()