mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 07:00:37 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@staticmethod
|
||||
def _get_bound_model_instance(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
):
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_llm(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
|
||||
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke text embedding
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_text_embedding(
|
||||
texts=payload.texts,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_text_embedding(texts=payload.texts)
|
||||
|
||||
return response
|
||||
|
||||
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke rerank
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
docs=payload.docs,
|
||||
score_threshold=payload.score_threshold,
|
||||
top_n=payload.top_n,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke tts
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_tts(
|
||||
content_text=payload.content_text,
|
||||
tenant_id=tenant.id,
|
||||
voice=payload.voice,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
|
||||
|
||||
def handle() -> Generator[dict, None, None]:
|
||||
for chunk in response:
|
||||
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke speech2text
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
temp.flush()
|
||||
temp.seek(0)
|
||||
|
||||
response = model_instance.invoke_speech2text(
|
||||
file=temp,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_speech2text(file=temp)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke moderation
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_moderation(
|
||||
text=payload.text,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_moderation(text=payload.text)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
|
||||
Reference in New Issue
Block a user