mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 04:00:59 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.callbacks.base_callback import Callback
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
@@ -30,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelInstance:
|
||||
"""
|
||||
Model instance class
|
||||
Model instance class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
@@ -110,7 +111,6 @@ class ModelInstance:
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[True] = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Generator: ...
|
||||
|
||||
@@ -122,7 +122,6 @@ class ModelInstance:
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[False] = False,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> LLMResult: ...
|
||||
|
||||
@@ -134,7 +133,6 @@ class ModelInstance:
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator]: ...
|
||||
|
||||
@@ -145,7 +143,6 @@ class ModelInstance:
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@@ -156,7 +153,6 @@ class ModelInstance:
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
@@ -173,7 +169,6 @@ class ModelInstance:
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
)
|
||||
@@ -202,13 +197,12 @@ class ModelInstance:
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
|
||||
self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
@@ -221,7 +215,6 @@ class ModelInstance:
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
input_type=input_type,
|
||||
),
|
||||
)
|
||||
@@ -229,14 +222,12 @@ class ModelInstance:
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
multimodel_documents: list[dict],
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param multimodel_documents: multimodel documents to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
@@ -249,7 +240,6 @@ class ModelInstance:
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
user=user,
|
||||
input_type=input_type,
|
||||
),
|
||||
)
|
||||
@@ -279,7 +269,6 @@ class ModelInstance:
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
@@ -288,7 +277,6 @@ class ModelInstance:
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
@@ -303,7 +291,6 @@ class ModelInstance:
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -313,7 +300,6 @@ class ModelInstance:
|
||||
docs: list[dict],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
@@ -322,7 +308,6 @@ class ModelInstance:
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
@@ -337,16 +322,14 @@ class ModelInstance:
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
|
||||
def invoke_moderation(self, text: str) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
@@ -358,16 +341,14 @@ class ModelInstance:
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str:
|
||||
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
@@ -379,18 +360,15 @@ class ModelInstance:
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]:
|
||||
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
|
||||
:param content_text: text content to be translated
|
||||
:param tenant_id: user tenant id
|
||||
:param voice: model timbre
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
@@ -402,8 +380,6 @@ class ModelInstance:
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
),
|
||||
)
|
||||
@@ -477,10 +453,20 @@ class ModelInstance:
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self):
|
||||
self._provider_manager = ProviderManager()
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
self._provider_manager = provider_manager
|
||||
|
||||
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
||||
@classmethod
|
||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||
return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id))
|
||||
|
||||
def get_model_instance(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
) -> ModelInstance:
|
||||
"""
|
||||
Get model instance
|
||||
:param tenant_id: tenant id
|
||||
@@ -496,7 +482,8 @@ class ModelManager:
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
model_instance = ModelInstance(provider_model_bundle, model)
|
||||
return model_instance
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user