mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 13:00:48 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CacheEmbedding(Embeddings):
|
||||
def __init__(self, model_instance: ModelInstance, user: str | None = None):
|
||||
self._model_instance = model_instance
|
||||
self._user = user
|
||||
self._model_instance = self._bind_model_instance(model_instance, user)
|
||||
|
||||
@staticmethod
|
||||
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
|
||||
if user is None:
|
||||
return model_instance
|
||||
|
||||
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
|
||||
batch_texts = embedding_queue_texts[i : i + max_chunks]
|
||||
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
|
||||
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
|
||||
)
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=batch_multimodel_documents,
|
||||
user=self._user,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
texts=[text], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
|
||||
Reference in New Issue
Block a user