mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 16:00:32 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
@@ -20,7 +20,7 @@ def _make_chunk(
|
||||
return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint)
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls():
|
||||
def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
tool_calls = [
|
||||
@@ -44,7 +44,7 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t
|
||||
usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1})
|
||||
chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1")
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
result = _normalize_non_stream_runtime_result(
|
||||
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
|
||||
)
|
||||
|
||||
@@ -62,20 +62,20 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t
|
||||
]
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_list_content():
|
||||
def test__normalize_non_stream_runtime_result__from_first_chunk_list_content():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")]
|
||||
chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage())
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
result = _normalize_non_stream_runtime_result(
|
||||
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
|
||||
)
|
||||
|
||||
assert result.message.content == content_list
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__passthrough_llm_result():
|
||||
def test__normalize_non_stream_runtime_result__passthrough_llm_result():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
llm_result = LLMResult(
|
||||
model="test-model",
|
||||
@@ -85,15 +85,15 @@ def test__normalize_non_stream_plugin_result__passthrough_llm_result():
|
||||
)
|
||||
|
||||
assert (
|
||||
_normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
|
||||
_normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
|
||||
== llm_result
|
||||
)
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
|
||||
def test__normalize_non_stream_runtime_result__empty_iterator_defaults():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
|
||||
result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
|
||||
|
||||
assert result.model == "test-model"
|
||||
assert result.prompt_messages == prompt_messages
|
||||
@@ -103,7 +103,7 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
|
||||
assert result.system_fingerprint is None
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__accumulates_all_chunks():
|
||||
def test__normalize_non_stream_runtime_result__accumulates_all_chunks():
|
||||
"""All chunks are accumulated from the iterator."""
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
@@ -116,7 +116,7 @@ def test__normalize_non_stream_plugin_result__accumulates_all_chunks():
|
||||
finally:
|
||||
closed.append(True)
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
result = _normalize_non_stream_runtime_result(
|
||||
model="test-model",
|
||||
prompt_messages=prompt_messages,
|
||||
result=_chunk_iter(),
|
||||
|
||||
@@ -2,9 +2,8 @@ import decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@@ -17,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import (
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
@@ -28,46 +28,72 @@ from dify_graph.model_runtime.errors.invoke import (
|
||||
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
class _ConcreteAIModel(AIModel):
|
||||
model_type = ModelType.LLM
|
||||
|
||||
|
||||
class TestAIModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
|
||||
@pytest.fixture
|
||||
def ai_model(self, mock_plugin_model_provider):
|
||||
return AIModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.LLM,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
def provider_schema(self) -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
def test_invoke_error_mapping(self, ai_model):
|
||||
@pytest.fixture
|
||||
def model_runtime(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel:
|
||||
return _ConcreteAIModel(
|
||||
provider_schema=provider_schema,
|
||||
model_runtime=model_runtime,
|
||||
)
|
||||
|
||||
def test_init_stores_runtime_state_and_is_not_pydantic_model(
|
||||
self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock
|
||||
) -> None:
|
||||
assert ai_model.model_type == ModelType.LLM
|
||||
assert ai_model.provider_schema is provider_schema
|
||||
assert ai_model.model_runtime is model_runtime
|
||||
assert ai_model.provider == "langgenius/openai/openai"
|
||||
assert ai_model.provider_display_name == "OpenAI"
|
||||
assert ai_model.started_at == 0
|
||||
assert not isinstance(ai_model, BaseModel)
|
||||
|
||||
def test_direct_base_class_requires_subclass_model_type(
|
||||
self, provider_schema: ProviderEntity, model_runtime: MagicMock
|
||||
) -> None:
|
||||
with pytest.raises(TypeError, match="subclasses must define model_type"):
|
||||
AIModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
def test_subclass_uses_class_level_model_type(
|
||||
self, provider_schema: ProviderEntity, model_runtime: MagicMock
|
||||
) -> None:
|
||||
model = _ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
assert model.model_type == ModelType.LLM
|
||||
|
||||
def test_invoke_error_mapping(self, ai_model: AIModel) -> None:
|
||||
mapping = ai_model._invoke_error_mapping
|
||||
assert InvokeConnectionError in mapping
|
||||
assert InvokeServerUnavailableError in mapping
|
||||
assert InvokeRateLimitError in mapping
|
||||
assert InvokeAuthorizationError in mapping
|
||||
assert InvokeBadRequestError in mapping
|
||||
assert PluginDaemonInnerError in mapping
|
||||
assert ValueError in mapping
|
||||
|
||||
def test_transform_invoke_error(self, ai_model):
|
||||
# Case: mapped error (InvokeAuthorizationError)
|
||||
def test_transform_invoke_error(self, ai_model: AIModel) -> None:
|
||||
err = Exception("Original error")
|
||||
|
||||
with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}):
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert isinstance(transformed, InvokeAuthorizationError)
|
||||
assert "Incorrect model credentials provided" in str(transformed.description)
|
||||
|
||||
# Case: mapped error (InvokeError subclass)
|
||||
with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}):
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert isinstance(transformed, InvokeError)
|
||||
assert "[test_provider]" in transformed.description
|
||||
|
||||
# Case: mapped error (not InvokeError)
|
||||
class CustomNonInvokeError(Exception):
|
||||
pass
|
||||
|
||||
@@ -75,52 +101,43 @@ class TestAIModel:
|
||||
transformed = ai_model._transform_invoke_error(err)
|
||||
assert transformed == err
|
||||
|
||||
# Case: unmapped error
|
||||
unmapped_err = Exception("Unmapped")
|
||||
transformed = ai_model._transform_invoke_error(unmapped_err)
|
||||
transformed = ai_model._transform_invoke_error(Exception("Unmapped"))
|
||||
assert isinstance(transformed, InvokeError)
|
||||
assert "Error: Unmapped" in transformed.description
|
||||
assert transformed.description == "[OpenAI] Error: Unmapped"
|
||||
|
||||
def test_get_price(self, ai_model):
|
||||
def test_get_price(self, ai_model: AIModel) -> None:
|
||||
model_name = "test_model"
|
||||
credentials = {"key": "value"}
|
||||
|
||||
# Mock get_model_schema
|
||||
mock_schema = MagicMock(spec=AIModelEntity)
|
||||
mock_schema.pricing = PriceConfig(
|
||||
input=decimal.Decimal("0.002"),
|
||||
output=decimal.Decimal("0.004"),
|
||||
unit=decimal.Decimal(1000), # 1000 tokens per unit
|
||||
unit=decimal.Decimal(1000),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
# Test INPUT
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000)
|
||||
assert price_info.unit_price == decimal.Decimal("0.002")
|
||||
|
||||
# Test OUTPUT
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000)
|
||||
assert price_info.unit_price == decimal.Decimal("0.004")
|
||||
|
||||
# Case: unit_price is None (returns zeroed PriceInfo)
|
||||
mock_schema.pricing = None
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000)
|
||||
assert price_info.total_amount == decimal.Decimal("0.0")
|
||||
|
||||
def test_get_price_no_price_config_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
# We need it to be truthy at line 107 and 112 but falsy at line 127.
|
||||
def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None:
|
||||
class ChangingPriceConfig:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.input = decimal.Decimal("0.01")
|
||||
self.unit = decimal.Decimal(1)
|
||||
self.currency = "USD"
|
||||
self.called = 0
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
self.called += 1
|
||||
return self.called <= 2
|
||||
|
||||
@@ -128,11 +145,12 @@ class TestAIModel:
|
||||
mock_schema.pricing = ChangingPriceConfig()
|
||||
|
||||
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ai_model.get_price(model_name, {}, PriceType.INPUT, 1000)
|
||||
assert "Price config not found" in str(excinfo.value)
|
||||
with pytest.raises(ValueError, match="Price config not found"):
|
||||
ai_model.get_price("test_model", {}, PriceType.INPUT, 1000)
|
||||
|
||||
def test_get_model_schema_cache_hit(self, ai_model):
|
||||
def test_get_model_schema_delegates_to_runtime(
|
||||
self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity
|
||||
) -> None:
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
@@ -144,120 +162,21 @@ class TestAIModel:
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
model_runtime.get_model_schema.return_value = mock_schema
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = mock_schema.model_dump_json().encode()
|
||||
schema = ai_model.get_model_schema(model_name, credentials)
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, credentials)
|
||||
|
||||
assert schema.model == "test_model"
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
def test_get_model_schema_cache_miss(self, ai_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
assert schema == mock_schema
|
||||
model_runtime.get_model_schema.assert_called_once_with(
|
||||
provider=provider_schema.provider,
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = mock_schema
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, credentials)
|
||||
|
||||
assert schema == mock_schema
|
||||
mock_manager.get_model_schema.assert_called_once()
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_get_model_schema_redis_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.side_effect = RedisError("Connection refused")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_manager.get_model_schema.assert_called_once()
|
||||
|
||||
def test_get_model_schema_validation_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = b"invalid json"
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
# This should trigger ValidationError at line 166 and go to delete()
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_get_model_schema_redis_delete_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = b'{"invalid": "schema"}'
|
||||
mock_redis.delete.side_effect = RedisError("Delete failed")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = None
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema is None
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_get_model_schema_redis_setex_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = RuntimeError("Setex failed")
|
||||
mock_manager = mock_client.return_value
|
||||
mock_manager.get_model_schema.return_value = mock_schema
|
||||
|
||||
schema = ai_model.get_model_schema(model_name, {})
|
||||
|
||||
assert schema == mock_schema
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(
|
||||
self, ai_model: AIModel
|
||||
) -> None:
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
@@ -275,12 +194,11 @@ class TestAIModel:
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
|
||||
assert schema is not None
|
||||
assert schema.parameter_rules[0].use_template == "invalid_template_name"
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials(self, ai_model):
|
||||
model_name = "test_model"
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None:
|
||||
mock_schema = AIModelEntity(
|
||||
model="test_model",
|
||||
label=I18nObject(en_US="Test Model"),
|
||||
@@ -310,27 +228,27 @@ class TestAIModel:
|
||||
)
|
||||
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
|
||||
|
||||
assert schema is not None
|
||||
assert schema.parameter_rules[0].max == 1.0
|
||||
assert schema.parameter_rules[1].help is not None
|
||||
assert schema.parameter_rules[1].help.en_US != ""
|
||||
assert schema.parameter_rules[2].help is not None
|
||||
assert schema.parameter_rules[2].help.zh_Hans != ""
|
||||
assert schema.parameter_rules[3].use_template is None
|
||||
|
||||
def test_get_customizable_model_schema_from_credentials_none(self, ai_model):
|
||||
def test_get_customizable_model_schema_from_credentials_none(self, ai_model: AIModel) -> None:
|
||||
with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
|
||||
schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
|
||||
assert schema is None
|
||||
|
||||
def test_get_customizable_model_schema_default(self, ai_model):
|
||||
def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None:
|
||||
assert ai_model.get_customizable_model_schema("model", {}) is None
|
||||
|
||||
def test_get_default_parameter_rule_variable_map(self, ai_model):
|
||||
# Valid
|
||||
res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
|
||||
assert res["default"] == 0.0
|
||||
def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None:
|
||||
result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
|
||||
assert result["default"] == 0.0
|
||||
|
||||
# Invalid
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
with pytest.raises(Exception, match="Invalid model parameter rule name"):
|
||||
ai_model._get_default_parameter_rule_variable_map("invalid_name")
|
||||
assert "Invalid model parameter rule name" in str(excinfo.value)
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
from decimal import Decimal
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
|
||||
def _provider_schema(model_type: ModelType) -> ProviderEntity:
    """Return a minimal OpenAI-labelled provider schema supporting only ``model_type``.

    Args:
        model_type: The single model type listed in ``supported_model_types``.

    Returns:
        A ``ProviderEntity`` configured with the predefined-model method only.
    """
    display_label = I18nObject(en_US="OpenAI")
    supported_types = [model_type]
    methods = [ConfigurateMethod.PREDEFINED_MODEL]
    return ProviderEntity(
        provider="langgenius/openai/openai",
        provider_name="openai",
        label=display_label,
        supported_model_types=supported_types,
        configurate_methods=methods,
    )
|
||||
|
||||
|
||||
def _embedding_usage() -> EmbeddingUsage:
    """Return a one-token ``EmbeddingUsage`` whose price fields are all zero (USD)."""
    # Decimal is immutable, so one zero instance can back every price field.
    zero = Decimal(0)
    return EmbeddingUsage(
        tokens=1,
        total_tokens=1,
        unit_price=zero,
        price_unit=zero,
        total_price=zero,
        currency="USD",
        latency=0.0,
    )
|
||||
|
||||
|
||||
def test_large_language_model_invokes_runtime_without_user_id() -> None:
    """``LargeLanguageModel.invoke`` must not pass a ``user_id`` keyword to ``runtime.invoke_llm``."""
    runtime = MagicMock()
    runtime.invoke_llm.return_value = LLMResult(
        model="gpt-4o-mini",
        prompt_messages=[],
        message=AssistantPromptMessage(content="ok"),
        usage=LLMUsage.empty_usage(),
    )
    model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime)

    model.invoke(
        model="gpt-4o-mini",
        credentials={"api_key": "secret"},
        prompt_messages=[UserPromptMessage(content="hi")],
        stream=False,
    )

    # ``call_args`` is None on a mock that was never invoked; check the call
    # happened first so a missing delegation fails with a clear message
    # instead of an AttributeError on ``None.kwargs``.
    runtime.invoke_llm.assert_called()
    assert "user_id" not in runtime.invoke_llm.call_args.kwargs
|
||||
|
||||
|
||||
def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None:
    """A ``texts`` embedding request must not pass ``user_id`` to ``runtime.invoke_text_embedding``."""
    runtime = MagicMock()
    runtime.invoke_text_embedding.return_value = EmbeddingResult(
        model="text-embedding-3-small",
        embeddings=[[0.1]],
        usage=_embedding_usage(),
    )
    model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime)

    model.invoke(
        model="text-embedding-3-small",
        credentials={"api_key": "secret"},
        texts=["hello"],
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_text_embedding.assert_called()
    assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs
|
||||
|
||||
|
||||
def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None:
    """A ``multimodel_documents`` request must not pass ``user_id`` to ``runtime.invoke_multimodal_embedding``."""
    runtime = MagicMock()
    runtime.invoke_multimodal_embedding.return_value = EmbeddingResult(
        model="text-embedding-3-small",
        embeddings=[[0.1]],
        usage=_embedding_usage(),
    )
    model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime)

    model.invoke(
        model="text-embedding-3-small",
        credentials={"api_key": "secret"},
        multimodel_documents=[{"content": "hello", "content_type": "text"}],
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_multimodal_embedding.assert_called()
    assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs
|
||||
|
||||
|
||||
def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None:
    """``RerankModel.invoke`` must not pass a ``user_id`` keyword to ``runtime.invoke_rerank``."""
    runtime = MagicMock()
    runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[])
    model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime)

    model.invoke(
        model="rerank",
        credentials={"api_key": "secret"},
        query="q",
        docs=["d1"],
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_rerank.assert_called()
    assert "user_id" not in runtime.invoke_rerank.call_args.kwargs
|
||||
|
||||
|
||||
def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None:
    """``RerankModel.invoke_multimodal_rerank`` must not pass ``user_id`` to the runtime."""
    runtime = MagicMock()
    runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[])
    model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime)

    model.invoke_multimodal_rerank(
        model="rerank",
        credentials={"api_key": "secret"},
        query={"content": "q", "content_type": "text"},
        docs=[{"content": "d1", "content_type": "text"}],
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_multimodal_rerank.assert_called()
    assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs
|
||||
|
||||
|
||||
def test_tts_model_invokes_runtime_without_user_id() -> None:
    """``TTSModel.invoke`` must not pass a ``user_id`` keyword to ``runtime.invoke_tts``."""
    runtime = MagicMock()
    runtime.invoke_tts.return_value = [b"chunk"]
    model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime)

    # Drain the result in case invoke returns a lazy iterable that only
    # reaches the runtime once it is consumed.
    list(
        model.invoke(
            model="tts-1",
            credentials={"api_key": "secret"},
            content_text="hello",
            voice="alloy",
        )
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_tts.assert_called()
    assert "user_id" not in runtime.invoke_tts.call_args.kwargs
|
||||
|
||||
|
||||
def test_speech_to_text_model_invokes_runtime_without_user_id() -> None:
    """``Speech2TextModel.invoke`` must not pass ``user_id`` to ``runtime.invoke_speech_to_text``."""
    runtime = MagicMock()
    runtime.invoke_speech_to_text.return_value = "transcript"
    model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime)

    model.invoke(
        model="whisper-1",
        credentials={"api_key": "secret"},
        file=BytesIO(b"audio"),
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_speech_to_text.assert_called()
    assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs
|
||||
|
||||
|
||||
def test_moderation_model_invokes_runtime_without_user_id() -> None:
    """``ModerationModel.invoke`` must not pass a ``user_id`` keyword to ``runtime.invoke_moderation``."""
    runtime = MagicMock()
    runtime.invoke_moderation.return_value = True
    model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime)

    model.invoke(
        model="omni-moderation-latest",
        credentials={"api_key": "secret"},
        text="unsafe?",
    )

    # Guard against ``call_args`` being None when the runtime was never
    # reached, which would otherwise raise AttributeError below.
    runtime.invoke_moderation.assert_called()
    assert "user_id" not in runtime.invoke_moderation.call_args.kwargs
|
||||
Reference in New Issue
Block a user