mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 07:00:37 -04:00
test: expand provider manager coverage
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -8,12 +9,20 @@ from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def _build_provider_manager(mocker: MockerFixture) -> ProviderManager:
|
||||
return ProviderManager(model_runtime=mocker.Mock())
|
||||
|
||||
|
||||
def _build_session_context(session: Mock) -> MagicMock:
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
return session_cm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity():
|
||||
mock_entity = Mock()
|
||||
@@ -235,6 +244,85 @@ def test_get_default_model_uses_first_available_active_model(mocker: MockerFixtu
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture):
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = None
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = []
|
||||
manager = _build_provider_manager(mocker)
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db.session", mock_session),
|
||||
patch.object(manager, "get_configurations", return_value=provider_configurations),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result is None
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
mock_factory_cls.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture):
|
||||
existing_default_model = TenantDefaultModel(
|
||||
tenant_id="tenant-id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type=ModelType.LLM.to_origin_model_type(),
|
||||
)
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = existing_default_model
|
||||
manager = _build_provider_manager(mocker)
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db.session", mock_session),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4"
|
||||
assert result.provider.provider == "openai"
|
||||
|
||||
|
||||
def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_records = {"openai": [SimpleNamespace(provider_name="openai")]}
|
||||
provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]}
|
||||
preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")}
|
||||
|
||||
with (
|
||||
patch.object(manager, "_get_all_providers", return_value=provider_records),
|
||||
patch.object(manager, "_init_trial_provider_records", return_value=provider_records),
|
||||
patch.object(manager, "_get_all_provider_models", return_value=provider_model_records),
|
||||
patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records),
|
||||
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
mock_factory_cls.return_value.get_providers.return_value = []
|
||||
|
||||
result = manager.get_configurations("tenant-id")
|
||||
|
||||
expected_alias = str(ModelProviderID("openai"))
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
assert expected_alias in provider_records
|
||||
assert expected_alias in provider_model_records
|
||||
assert expected_alias in preferred_provider_records
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider_name", "expected_provider_names"),
|
||||
[
|
||||
@@ -276,6 +364,32 @@ def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker:
|
||||
assert result is expected_bundle
|
||||
|
||||
|
||||
def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = []
|
||||
|
||||
with patch.object(manager, "get_configurations", return_value=provider_configurations):
|
||||
result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result == (None, None)
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False)
|
||||
|
||||
|
||||
def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = [
|
||||
Mock(model="gpt-4", provider=Mock(provider="openai")),
|
||||
Mock(model="gpt-4o", provider=Mock(provider="openai")),
|
||||
]
|
||||
|
||||
with patch.object(manager, "get_configurations", return_value=provider_configurations):
|
||||
result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result == ("openai", "gpt-4")
|
||||
|
||||
|
||||
def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
|
||||
@@ -346,3 +460,98 @@ def test_update_default_model_record_creates_record_with_origin_model_type(mocke
|
||||
assert created_default_model.model_name == "gpt-4"
|
||||
assert created_default_model.model_type == ModelType.LLM.to_origin_model_type()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None:
|
||||
session = Mock()
|
||||
openai_provider = SimpleNamespace(provider_name="openai")
|
||||
gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google")
|
||||
session.scalars.return_value = [openai_provider, gemini_provider]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_providers("tenant-id")
|
||||
|
||||
assert list(result[str(ModelProviderID("openai"))]) == [openai_provider]
|
||||
assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method_name",
|
||||
[
|
||||
"_get_all_provider_models",
|
||||
"_get_all_provider_model_settings",
|
||||
"_get_all_provider_model_credentials",
|
||||
],
|
||||
)
|
||||
def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None:
|
||||
session = Mock()
|
||||
openai_primary = SimpleNamespace(provider_name="openai")
|
||||
openai_secondary = SimpleNamespace(provider_name="openai")
|
||||
anthropic_record = SimpleNamespace(provider_name="anthropic")
|
||||
session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = getattr(ProviderManager, method_name)("tenant-id")
|
||||
|
||||
assert list(result["openai"]) == [openai_primary, openai_secondary]
|
||||
assert list(result["anthropic"]) == [anthropic_record]
|
||||
|
||||
|
||||
def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None:
|
||||
session = Mock()
|
||||
openai_preference = SimpleNamespace(provider_name="openai")
|
||||
anthropic_preference = SimpleNamespace(provider_name="anthropic")
|
||||
session.scalars.return_value = [openai_preference, anthropic_preference]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_preferred_model_providers("tenant-id")
|
||||
|
||||
assert result == {
|
||||
"openai": openai_preference,
|
||||
"anthropic": anthropic_preference,
|
||||
}
|
||||
|
||||
|
||||
def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None:
|
||||
with (
|
||||
patch("core.provider_manager.redis_client.get", return_value=b"False"),
|
||||
patch("core.provider_manager.FeatureService.get_features") as mock_get_features,
|
||||
patch("core.provider_manager.Session") as mock_session_cls,
|
||||
):
|
||||
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
|
||||
|
||||
assert result == {}
|
||||
mock_get_features.assert_not_called()
|
||||
mock_session_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None:
|
||||
session = Mock()
|
||||
openai_config = SimpleNamespace(provider_name="openai")
|
||||
anthropic_config = SimpleNamespace(provider_name="anthropic")
|
||||
session.scalars.return_value = [openai_config, anthropic_config]
|
||||
|
||||
with (
|
||||
patch("core.provider_manager.db", SimpleNamespace(engine=object())),
|
||||
patch("core.provider_manager.redis_client.get", return_value=None),
|
||||
patch("core.provider_manager.redis_client.setex") as mock_setex,
|
||||
patch(
|
||||
"core.provider_manager.FeatureService.get_features",
|
||||
return_value=SimpleNamespace(model_load_balancing_enabled=True),
|
||||
),
|
||||
patch("core.provider_manager.Session", return_value=_build_session_context(session)),
|
||||
):
|
||||
result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id")
|
||||
|
||||
mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True")
|
||||
assert list(result["openai"]) == [openai_config]
|
||||
assert list(result["anthropic"]) == [anthropic_config]
|
||||
|
||||
Reference in New Issue
Block a user