diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 10bb8dc042..c8f1a2ab7f 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,6 +30,8 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderEntity, ) from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: "ModelRuntime") -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def _get_model_provider_factory(self) -> ModelProviderFactory: + """Reuse a bound runtime when available, else fall back to tenant scope.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) + model_provider_factory = self._get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) + model_provider_factory = self._get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) + model_provider_factory = self._get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) + model_provider_factory = self._get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = create_plugin_model_provider_factory(tenant_id=self.tenant_id) + model_provider_factory = self._get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 0050eea307..af2a2de3fd 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -25,9 +25,12 @@ from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + class PluginModelRuntime(ModelRuntime): - """Plugin-backed runtime adapter bound to tenant context and a default user.""" + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" tenant_id: str user_id: str | None @@ -479,7 +482,10 @@ class PluginModelRuntime(ModelRuntime): model: str, credentials: dict[str, Any], ) -> str: - cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}" + # The plugin daemon distinguishes ``None`` from an explicit empty-string + # caller id, so the cache must only collapse ``None`` into tenant scope. + cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" sorted_credentials = sorted(credentials.items()) if credentials else [] if not sorted_credentials: return cache_key diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 10cf4ad7e7..35cb70847f 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -65,6 +65,9 @@ class ProviderManager: The runtime adapter is injected by the composition layer so this class stays focused on configuration assembly instead of constructing plugin runtimes. + Request-bound managers may carry caller identity in that runtime, and the + resulting ``ProviderConfiguration`` objects must reuse it for downstream + model-type and schema lookups. """ def __init__(self, model_runtime: ModelRuntime): @@ -264,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 2e2ca0e168..4de09d09ba 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -161,7 +161,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id) + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -384,23 +384,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - # get model schema + # Reuse the caller-bound model instance for both schema resolution and + # downstream planner/invoke calls so a single request never mixes + # tenant-scope and request-bound runtimes. model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials + model=model_instance.model_name, + credentials=model_instance.credentials, ) if not model_schema: return None, [] + model_config.provider_model_bundle = model_instance.provider_model_bundle + model_config.credentials = model_instance.credentials + model_config.model_schema = model_schema + planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 3dddce449e..6051802d33 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -120,6 +120,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 37d8853a73..84ecc3d2e8 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -436,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder: model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema + assert mock_factory_builder.call_count == 2 mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", @@ -451,6 +455,31 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: ) +def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: + configuration = _build_provider_configuration() + bound_runtime = Mock() + configuration.bind_model_runtime(bound_runtime) + + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with ( + patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory) as mock_factory_cls, + patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + ): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + assert mock_factory_cls.call_count == 2 + mock_factory_cls.assert_called_with(model_runtime=bound_runtime) + mock_factory_builder.assert_not_called() + + def test_get_provider_model_returns_none_when_model_not_found() -> None: configuration = _build_provider_configuration() fake_model = SimpleNamespace(model="other-model") diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index ce4dc4f174..4c4d864eb8 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -10,7 +10,7 @@ import pytest from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient -from core.plugin.impl.model_runtime import PluginModelRuntime +from core.plugin.impl.model_runtime import PluginModelRuntime, TENANT_SCOPE_SCHEMA_CACHE_USER_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -416,6 +416,69 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None: assert first == second +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: client = Mock(spec=PluginModelClient) client.fetch_model_providers.return_value = [ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index ebc96ce553..88c364a8fb 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4246,6 +4246,7 @@ class TestKnowledgeRetrievalCoverage: mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4298,9 +4299,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance.get_model_schema.return_value = None with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4313,8 +4318,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": "secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4336,7 +4391,12 @@ class TestRetrieveCoverage: patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4432,7 +4492,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index de8b31da7a..206540de1b 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -5,6 +5,7 @@ from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFini from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.model_runtime.entities.model_entities import ModelType class TestReactMultiDatasetRouter: @@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -178,7 +181,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage - mock_manager.assert_any_call(tenant_id="t1", user_id="u1") + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index f8875f0552..6caf7bc004 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -343,6 +343,35 @@ def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFix manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): manager = _build_provider_manager(mocker) provider_configuration = Mock()