feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.callbacks.base_callback import Callback
from dify_graph.model_runtime.entities.llm_entities import LLMResult
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)
class ModelInstance:
"""
Model instance class
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
@@ -110,7 +111,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Generator: ...
@@ -122,7 +122,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> LLMResult: ...
@@ -134,7 +133,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]: ...
@@ -145,7 +143,6 @@ class ModelInstance:
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator]:
"""
@@ -156,7 +153,6 @@ class ModelInstance:
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
@@ -173,7 +169,6 @@ class ModelInstance:
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
),
)
@@ -202,13 +197,12 @@ class ModelInstance:
)
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> EmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -221,7 +215,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
),
)
@@ -229,14 +222,12 @@ class ModelInstance:
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke large language model
:param multimodel_documents: multimodel documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
@@ -249,7 +240,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
@@ -279,7 +269,6 @@ class ModelInstance:
docs: list[str],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -288,7 +277,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -303,7 +291,6 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
@@ -313,7 +300,6 @@ class ModelInstance:
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model
@@ -322,7 +308,6 @@ class ModelInstance:
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
@@ -337,16 +322,14 @@ class ModelInstance:
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
def invoke_moderation(self, text: str) -> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
@@ -358,16 +341,14 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
),
)
def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
def invoke_speech2text(self, file: IO[bytes]) -> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
@@ -379,18 +360,15 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param voice: model timbre
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@@ -402,8 +380,6 @@ class ModelInstance:
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
),
)
@@ -477,10 +453,20 @@ class ModelInstance:
class ModelManager:
def __init__(self):
self._provider_manager = ProviderManager()
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))
def get_model_instance(
self,
tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
) -> ModelInstance:
"""
Get model instance
:param tenant_id: tenant id
@@ -496,7 +482,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""