feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
@staticmethod
def _get_bound_model_instance(
    *,
    tenant_id: str,
    user_id: str | None,
    provider: str,
    model_type: ModelType,
    model: str,
):
    """
    Look up a model instance through a tenant/user-scoped ModelManager.

    All arguments are keyword-only.

    :param tenant_id: tenant id, passed both to ``ModelManager.for_tenant``
        and to ``get_model_instance``
    :param user_id: user id handed to ``ModelManager.for_tenant``; may be None
    :param provider: provider name forwarded to ``get_model_instance``
    :param model_type: model type forwarded to ``get_model_instance``
    :param model: model name forwarded to ``get_model_instance``
    :return: whatever ``get_model_instance`` returns for these arguments
    """
    # Build the manager for this tenant/user first, then resolve the model on it.
    # NOTE(review): presumably the manager carries user_id into later invoke
    # calls, which is why callers no longer pass ``user=`` -- confirm.
    scoped_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
    lookup_kwargs = {
        "tenant_id": tenant_id,
        "provider": provider,
        "model_type": model_type,
        "model": model,
    }
    return scoped_manager.get_model_instance(**lookup_kwargs)
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
)
if isinstance(response, Generator):
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
response = model_instance.invoke_text_embedding(texts=payload.texts)
return response
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
for chunk in response:
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
response = model_instance.invoke_speech2text(file=temp)
return {
"result": response,
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
model_instance = cls._get_bound_model_instance(
tenant_id=tenant.id,
user_id=user_id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
response = model_instance.invoke_moderation(text=payload.text)
return {
"result": response,