feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -4,7 +4,7 @@ from dify_graph.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result
from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_runtime_result
def _make_chunk(
@@ -20,7 +20,7 @@ def _make_chunk(
return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint)
def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls():
def test__normalize_non_stream_runtime_result__from_first_chunk_str_content_and_tool_calls():
prompt_messages = [UserPromptMessage(content="hi")]
tool_calls = [
@@ -44,7 +44,7 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t
usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1})
chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1")
result = _normalize_non_stream_plugin_result(
result = _normalize_non_stream_runtime_result(
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
)
@@ -62,20 +62,20 @@ def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_t
]
def test__normalize_non_stream_plugin_result__from_first_chunk_list_content():
def test__normalize_non_stream_runtime_result__from_first_chunk_list_content():
prompt_messages = [UserPromptMessage(content="hi")]
content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")]
chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage())
result = _normalize_non_stream_plugin_result(
result = _normalize_non_stream_runtime_result(
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
)
assert result.message.content == content_list
def test__normalize_non_stream_plugin_result__passthrough_llm_result():
def test__normalize_non_stream_runtime_result__passthrough_llm_result():
prompt_messages = [UserPromptMessage(content="hi")]
llm_result = LLMResult(
model="test-model",
@@ -85,15 +85,15 @@ def test__normalize_non_stream_plugin_result__passthrough_llm_result():
)
assert (
_normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
_normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
== llm_result
)
def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
def test__normalize_non_stream_runtime_result__empty_iterator_defaults():
prompt_messages = [UserPromptMessage(content="hi")]
result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
result = _normalize_non_stream_runtime_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
assert result.model == "test-model"
assert result.prompt_messages == prompt_messages
@@ -103,7 +103,7 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
assert result.system_fingerprint is None
def test__normalize_non_stream_plugin_result__accumulates_all_chunks():
def test__normalize_non_stream_runtime_result__accumulates_all_chunks():
"""All chunks are accumulated from the iterator."""
prompt_messages = [UserPromptMessage(content="hi")]
@@ -116,7 +116,7 @@ def test__normalize_non_stream_plugin_result__accumulates_all_chunks():
finally:
closed.append(True)
result = _normalize_non_stream_plugin_result(
result = _normalize_non_stream_runtime_result(
model="test-model",
prompt_messages=prompt_messages,
result=_chunk_iter(),

View File

@@ -2,9 +2,8 @@ import decimal
from unittest.mock import MagicMock, patch
import pytest
from redis import RedisError
from pydantic import BaseModel
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
@@ -17,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import (
PriceConfig,
PriceType,
)
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -28,46 +28,72 @@ from dify_graph.model_runtime.errors.invoke import (
from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
class _ConcreteAIModel(AIModel):
model_type = ModelType.LLM
class TestAIModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def ai_model(self, mock_plugin_model_provider):
return AIModel(
tenant_id="tenant_123",
model_type=ModelType.LLM,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
def provider_schema(self) -> ProviderEntity:
return ProviderEntity(
provider="langgenius/openai/openai",
provider_name="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
def test_invoke_error_mapping(self, ai_model):
@pytest.fixture
def model_runtime(self) -> MagicMock:
return MagicMock()
@pytest.fixture
def ai_model(self, provider_schema: ProviderEntity, model_runtime: MagicMock) -> AIModel:
return _ConcreteAIModel(
provider_schema=provider_schema,
model_runtime=model_runtime,
)
def test_init_stores_runtime_state_and_is_not_pydantic_model(
self, ai_model: AIModel, provider_schema: ProviderEntity, model_runtime: MagicMock
) -> None:
assert ai_model.model_type == ModelType.LLM
assert ai_model.provider_schema is provider_schema
assert ai_model.model_runtime is model_runtime
assert ai_model.provider == "langgenius/openai/openai"
assert ai_model.provider_display_name == "OpenAI"
assert ai_model.started_at == 0
assert not isinstance(ai_model, BaseModel)
def test_direct_base_class_requires_subclass_model_type(
self, provider_schema: ProviderEntity, model_runtime: MagicMock
) -> None:
with pytest.raises(TypeError, match="subclasses must define model_type"):
AIModel(provider_schema=provider_schema, model_runtime=model_runtime)
def test_subclass_uses_class_level_model_type(
self, provider_schema: ProviderEntity, model_runtime: MagicMock
) -> None:
model = _ConcreteAIModel(provider_schema=provider_schema, model_runtime=model_runtime)
assert model.model_type == ModelType.LLM
def test_invoke_error_mapping(self, ai_model: AIModel) -> None:
mapping = ai_model._invoke_error_mapping
assert InvokeConnectionError in mapping
assert InvokeServerUnavailableError in mapping
assert InvokeRateLimitError in mapping
assert InvokeAuthorizationError in mapping
assert InvokeBadRequestError in mapping
assert PluginDaemonInnerError in mapping
assert ValueError in mapping
def test_transform_invoke_error(self, ai_model):
# Case: mapped error (InvokeAuthorizationError)
def test_transform_invoke_error(self, ai_model: AIModel) -> None:
err = Exception("Original error")
with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}):
transformed = ai_model._transform_invoke_error(err)
assert isinstance(transformed, InvokeAuthorizationError)
assert "Incorrect model credentials provided" in str(transformed.description)
# Case: mapped error (InvokeError subclass)
with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}):
transformed = ai_model._transform_invoke_error(err)
assert isinstance(transformed, InvokeError)
assert "[test_provider]" in transformed.description
# Case: mapped error (not InvokeError)
class CustomNonInvokeError(Exception):
pass
@@ -75,52 +101,43 @@ class TestAIModel:
transformed = ai_model._transform_invoke_error(err)
assert transformed == err
# Case: unmapped error
unmapped_err = Exception("Unmapped")
transformed = ai_model._transform_invoke_error(unmapped_err)
transformed = ai_model._transform_invoke_error(Exception("Unmapped"))
assert isinstance(transformed, InvokeError)
assert "Error: Unmapped" in transformed.description
assert transformed.description == "[OpenAI] Error: Unmapped"
def test_get_price(self, ai_model):
def test_get_price(self, ai_model: AIModel) -> None:
model_name = "test_model"
credentials = {"key": "value"}
# Mock get_model_schema
mock_schema = MagicMock(spec=AIModelEntity)
mock_schema.pricing = PriceConfig(
input=decimal.Decimal("0.002"),
output=decimal.Decimal("0.004"),
unit=decimal.Decimal(1000), # 1000 tokens per unit
unit=decimal.Decimal(1000),
currency="USD",
)
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
# Test INPUT
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000)
assert price_info.unit_price == decimal.Decimal("0.002")
# Test OUTPUT
price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000)
assert price_info.unit_price == decimal.Decimal("0.004")
# Case: unit_price is None (returns zeroed PriceInfo)
mock_schema.pricing = None
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000)
assert price_info.total_amount == decimal.Decimal("0.0")
def test_get_price_no_price_config_error(self, ai_model):
model_name = "test_model"
# We need it to be truthy at line 107 and 112 but falsy at line 127.
def test_get_price_no_price_config_error(self, ai_model: AIModel) -> None:
class ChangingPriceConfig:
def __init__(self):
def __init__(self) -> None:
self.input = decimal.Decimal("0.01")
self.unit = decimal.Decimal(1)
self.currency = "USD"
self.called = 0
def __bool__(self):
def __bool__(self) -> bool:
self.called += 1
return self.called <= 2
@@ -128,11 +145,12 @@ class TestAIModel:
mock_schema.pricing = ChangingPriceConfig()
with patch.object(AIModel, "get_model_schema", return_value=mock_schema):
with pytest.raises(ValueError) as excinfo:
ai_model.get_price(model_name, {}, PriceType.INPUT, 1000)
assert "Price config not found" in str(excinfo.value)
with pytest.raises(ValueError, match="Price config not found"):
ai_model.get_price("test_model", {}, PriceType.INPUT, 1000)
def test_get_model_schema_cache_hit(self, ai_model):
def test_get_model_schema_delegates_to_runtime(
self, ai_model: AIModel, model_runtime: MagicMock, provider_schema: ProviderEntity
) -> None:
model_name = "test_model"
credentials = {"api_key": "abc"}
@@ -144,120 +162,21 @@ class TestAIModel:
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
model_runtime.get_model_schema.return_value = mock_schema
with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis:
mock_redis.get.return_value = mock_schema.model_dump_json().encode()
schema = ai_model.get_model_schema(model_name, credentials)
schema = ai_model.get_model_schema(model_name, credentials)
assert schema.model == "test_model"
mock_redis.get.assert_called_once()
def test_get_model_schema_cache_miss(self, ai_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
assert schema == mock_schema
model_runtime.get_model_schema.assert_called_once_with(
provider=provider_schema.provider,
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
model=model_name,
credentials=credentials,
)
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = None
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = mock_schema
schema = ai_model.get_model_schema(model_name, credentials)
assert schema == mock_schema
mock_manager.get_model_schema.assert_called_once()
mock_redis.setex.assert_called_once()
def test_get_model_schema_redis_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.side_effect = RedisError("Connection refused")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_manager.get_model_schema.assert_called_once()
def test_get_model_schema_validation_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = b"invalid json"
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
# This should trigger ValidationError at line 166 and go to delete()
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_redis.delete.assert_called()
def test_get_model_schema_redis_delete_error(self, ai_model):
model_name = "test_model"
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = b'{"invalid": "schema"}'
mock_redis.delete.side_effect = RedisError("Delete failed")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = None
schema = ai_model.get_model_schema(model_name, {})
assert schema is None
mock_redis.delete.assert_called()
def test_get_model_schema_redis_setex_error(self, ai_model):
model_name = "test_model"
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
parameter_rules=[],
)
with (
patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
patch("core.plugin.impl.model.PluginModelClient") as mock_client,
):
mock_redis.get.return_value = None
mock_redis.setex.side_effect = RuntimeError("Setex failed")
mock_manager = mock_client.return_value
mock_manager.get_model_schema.return_value = mock_schema
schema = ai_model.get_model_schema(model_name, {})
assert schema == mock_schema
mock_redis.setex.assert_called()
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model):
model_name = "test_model"
def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(
self, ai_model: AIModel
) -> None:
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
@@ -275,12 +194,11 @@ class TestAIModel:
)
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
assert schema is not None
assert schema.parameter_rules[0].use_template == "invalid_template_name"
def test_get_customizable_model_schema_from_credentials(self, ai_model):
model_name = "test_model"
def test_get_customizable_model_schema_from_credentials(self, ai_model: AIModel) -> None:
mock_schema = AIModelEntity(
model="test_model",
label=I18nObject(en_US="Test Model"),
@@ -310,27 +228,27 @@ class TestAIModel:
)
with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
schema = ai_model.get_customizable_model_schema_from_credentials("test_model", {})
assert schema is not None
assert schema.parameter_rules[0].max == 1.0
assert schema.parameter_rules[1].help is not None
assert schema.parameter_rules[1].help.en_US != ""
assert schema.parameter_rules[2].help is not None
assert schema.parameter_rules[2].help.zh_Hans != ""
assert schema.parameter_rules[3].use_template is None
def test_get_customizable_model_schema_from_credentials_none(self, ai_model):
def test_get_customizable_model_schema_from_credentials_none(self, ai_model: AIModel) -> None:
with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
assert schema is None
def test_get_customizable_model_schema_default(self, ai_model):
def test_get_customizable_model_schema_default(self, ai_model: AIModel) -> None:
assert ai_model.get_customizable_model_schema("model", {}) is None
def test_get_default_parameter_rule_variable_map(self, ai_model):
# Valid
res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
assert res["default"] == 0.0
def test_get_default_parameter_rule_variable_map(self, ai_model: AIModel) -> None:
result = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
assert result["default"] == 0.0
# Invalid
with pytest.raises(Exception) as excinfo:
with pytest.raises(Exception, match="Invalid model parameter rule name"):
ai_model._get_default_parameter_rule_variable_map("invalid_name")
assert "Invalid model parameter rule name" in str(excinfo.value)

View File

@@ -0,0 +1,170 @@
from decimal import Decimal
from io import BytesIO
from unittest.mock import MagicMock
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
def _provider_schema(model_type: ModelType) -> ProviderEntity:
return ProviderEntity(
provider="langgenius/openai/openai",
provider_name="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[model_type],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
def _embedding_usage() -> EmbeddingUsage:
return EmbeddingUsage(
tokens=1,
total_tokens=1,
unit_price=Decimal(0),
price_unit=Decimal(0),
total_price=Decimal(0),
currency="USD",
latency=0.0,
)
def test_large_language_model_invokes_runtime_without_user_id() -> None:
runtime = MagicMock()
runtime.invoke_llm.return_value = LLMResult(
model="gpt-4o-mini",
prompt_messages=[],
message=AssistantPromptMessage(content="ok"),
usage=LLMUsage.empty_usage(),
)
model = LargeLanguageModel(provider_schema=_provider_schema(ModelType.LLM), model_runtime=runtime)
model.invoke(
model="gpt-4o-mini",
credentials={"api_key": "secret"},
prompt_messages=[UserPromptMessage(content="hi")],
stream=False,
)
assert "user_id" not in runtime.invoke_llm.call_args.kwargs
def test_text_embedding_model_invokes_runtime_without_user_id_for_text_requests() -> None:
runtime = MagicMock()
runtime.invoke_text_embedding.return_value = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[[0.1]],
usage=_embedding_usage(),
)
model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime)
model.invoke(
model="text-embedding-3-small",
credentials={"api_key": "secret"},
texts=["hello"],
)
assert "user_id" not in runtime.invoke_text_embedding.call_args.kwargs
def test_text_embedding_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None:
runtime = MagicMock()
runtime.invoke_multimodal_embedding.return_value = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[[0.1]],
usage=_embedding_usage(),
)
model = TextEmbeddingModel(provider_schema=_provider_schema(ModelType.TEXT_EMBEDDING), model_runtime=runtime)
model.invoke(
model="text-embedding-3-small",
credentials={"api_key": "secret"},
multimodel_documents=[{"content": "hello", "content_type": "text"}],
)
assert "user_id" not in runtime.invoke_multimodal_embedding.call_args.kwargs
def test_rerank_model_invokes_runtime_without_user_id_for_text_requests() -> None:
runtime = MagicMock()
runtime.invoke_rerank.return_value = RerankResult(model="rerank", docs=[])
model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime)
model.invoke(
model="rerank",
credentials={"api_key": "secret"},
query="q",
docs=["d1"],
)
assert "user_id" not in runtime.invoke_rerank.call_args.kwargs
def test_rerank_model_invokes_runtime_without_user_id_for_multimodal_requests() -> None:
runtime = MagicMock()
runtime.invoke_multimodal_rerank.return_value = RerankResult(model="rerank", docs=[])
model = RerankModel(provider_schema=_provider_schema(ModelType.RERANK), model_runtime=runtime)
model.invoke_multimodal_rerank(
model="rerank",
credentials={"api_key": "secret"},
query={"content": "q", "content_type": "text"},
docs=[{"content": "d1", "content_type": "text"}],
)
assert "user_id" not in runtime.invoke_multimodal_rerank.call_args.kwargs
def test_tts_model_invokes_runtime_without_user_id() -> None:
runtime = MagicMock()
runtime.invoke_tts.return_value = [b"chunk"]
model = TTSModel(provider_schema=_provider_schema(ModelType.TTS), model_runtime=runtime)
list(
model.invoke(
model="tts-1",
credentials={"api_key": "secret"},
content_text="hello",
voice="alloy",
)
)
assert "user_id" not in runtime.invoke_tts.call_args.kwargs
def test_speech_to_text_model_invokes_runtime_without_user_id() -> None:
runtime = MagicMock()
runtime.invoke_speech_to_text.return_value = "transcript"
model = Speech2TextModel(provider_schema=_provider_schema(ModelType.SPEECH2TEXT), model_runtime=runtime)
model.invoke(
model="whisper-1",
credentials={"api_key": "secret"},
file=BytesIO(b"audio"),
)
assert "user_id" not in runtime.invoke_speech_to_text.call_args.kwargs
def test_moderation_model_invokes_runtime_without_user_id() -> None:
runtime = MagicMock()
runtime.invoke_moderation.return_value = True
model = ModerationModel(provider_schema=_provider_schema(ModelType.MODERATION), model_runtime=runtime)
model.invoke(
model="omni-moderation-latest",
credentials={"api_key": "secret"},
text="unsafe?",
)
assert "user_id" not in runtime.invoke_moderation.call_args.kwargs