mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 16:00:32 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
@staticmethod
|
||||
def _get_bound_model_instance(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str | None,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
):
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_llm(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
|
||||
@@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
@@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
@@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke text embedding
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_text_embedding(
|
||||
texts=payload.texts,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_text_embedding(texts=payload.texts)
|
||||
|
||||
return response
|
||||
|
||||
@@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke rerank
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
docs=payload.docs,
|
||||
score_threshold=payload.score_threshold,
|
||||
top_n=payload.top_n,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke tts
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_tts(
|
||||
content_text=payload.content_text,
|
||||
tenant_id=tenant.id,
|
||||
voice=payload.voice,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
|
||||
|
||||
def handle() -> Generator[dict, None, None]:
|
||||
for chunk in response:
|
||||
@@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke speech2text
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
@@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
temp.flush()
|
||||
temp.seek(0)
|
||||
|
||||
response = model_instance.invoke_speech2text(
|
||||
file=temp,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_speech2text(file=temp)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
@@ -252,18 +261,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke moderation
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
model_instance = cls._get_bound_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
user_id=user_id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
# invoke model
|
||||
response = model_instance.invoke_moderation(
|
||||
text=payload.text,
|
||||
user=user_id,
|
||||
)
|
||||
response = model_instance.invoke_moderation(text=payload.text)
|
||||
|
||||
return {
|
||||
"result": response,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO
|
||||
from typing import IO, Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
@@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankResult
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
@staticmethod
|
||||
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"data": data}
|
||||
if user_id is not None:
|
||||
payload["user_id"] = user_id
|
||||
return payload
|
||||
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
@@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/schema",
|
||||
PluginModelSchemaEntity,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient):
|
||||
return None
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
|
||||
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
@@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient):
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
data=self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
@@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_llm(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type_=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "llm",
|
||||
"model": model,
|
||||
@@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
@@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type_=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
@@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"prompt_messages": prompt_messages,
|
||||
"tools": tools,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type_=EmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
@@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"texts": texts,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
|
||||
type_=EmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
@@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"documents": documents,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type_=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
@@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
@@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient):
|
||||
"content_text": content_text,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
@@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient):
|
||||
def invoke_moderation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
@@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient):
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
)
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
|
||||
490
api/core/plugin/impl/model_runtime.py
Normal file
490
api/core/plugin/impl/model_runtime.py
Normal file
@@ -0,0 +1,490 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from dify_graph.model_runtime.runtime import ModelRuntime
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and a default user."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
client: PluginModelClient
|
||||
_provider_entities: tuple[ProviderEntity, ...] | None
|
||||
_provider_entities_lock: Lock
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
|
||||
if client is None:
|
||||
raise ValueError("client is required.")
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.client = client
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
|
||||
with self._provider_entities_lock:
|
||||
if self._provider_entities is None:
|
||||
self._provider_entities = tuple(
|
||||
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
|
||||
)
|
||||
|
||||
return self._provider_entities
|
||||
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
provider_schema = self._get_provider_schema(provider)
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
if not file_name:
|
||||
raise ValueError(f"Provider {provider} does not have icon.")
|
||||
|
||||
image_mime_types = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"bmp": "image/bmp",
|
||||
"tiff": "image/tiff",
|
||||
"tif": "image/tiff",
|
||||
"webp": "image/webp",
|
||||
"svg": "image/svg+xml",
|
||||
"ico": "image/vnd.microsoft.icon",
|
||||
"heif": "image/heif",
|
||||
"heic": "image/heic",
|
||||
}
|
||||
|
||||
extension = file_name.split(".")[-1]
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_model_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> AIModelEntity | None:
|
||||
cache_key = self._get_schema_cache_key(
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
schema = self.client.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
) -> int:
|
||||
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
return 0
|
||||
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_text_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
documents: list[dict[str, Any]],
|
||||
input_type: EmbeddingInputType,
|
||||
) -> EmbeddingResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_multimodal_embedding(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
documents=documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None,
|
||||
top_n: int | None,
|
||||
) -> RerankResult:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_multimodal_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Iterable[bytes]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
language: str | None,
|
||||
) -> Any:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
text: str,
|
||||
) -> bool:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
|
||||
"""
|
||||
Expose a bare provider alias only for the canonical provider mapping.
|
||||
|
||||
Multiple plugins can publish the same short provider slug. If every
|
||||
provider entity keeps that slug in ``provider_name``, callers that still
|
||||
resolve by short name become order-dependent. Restrict the alias to the
|
||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||
remain deterministic while the runtime surface stays canonical.
|
||||
"""
|
||||
try:
|
||||
canonical_provider_id = ModelProviderID(provider.provider)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||
return ""
|
||||
if canonical_provider_id.provider_name != provider.provider:
|
||||
return ""
|
||||
|
||||
return provider.provider
|
||||
|
||||
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||
declaration = provider.declaration.model_copy(deep=True)
|
||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||
declaration.provider_name = self._get_provider_short_name_alias(provider)
|
||||
return declaration
|
||||
|
||||
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
providers = self.fetch_model_providers()
|
||||
provider_entity = next((item for item in providers if item.provider == provider), None)
|
||||
if provider_entity is None:
|
||||
provider_entity = next((item for item in providers if provider == item.provider_name), None)
|
||||
if provider_entity is None:
|
||||
raise ValueError(f"Invalid provider: {provider}")
|
||||
return provider_entity
|
||||
|
||||
def _get_schema_cache_key(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> str:
|
||||
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
return cache_key + ":".join(
|
||||
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
|
||||
)
|
||||
|
||||
def _split_provider(self, provider: str) -> tuple[str, str]:
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
45
api/core/plugin/impl/model_runtime_factory.py
Normal file
45
api/core/plugin/impl/model_runtime_factory.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
|
||||
"""Create a plugin runtime with its client dependency fully composed."""
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
|
||||
return PluginModelRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
client=PluginModelClient(),
|
||||
)
|
||||
|
||||
|
||||
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
|
||||
"""Create a tenant-bound model provider factory for service flows."""
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
|
||||
|
||||
|
||||
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
|
||||
"""Create a tenant-bound provider manager for service flows."""
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
|
||||
|
||||
|
||||
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
|
||||
"""Create a tenant-bound model manager for service flows."""
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
return ModelManager(
|
||||
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
|
||||
)
|
||||
Reference in New Issue
Block a user