feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.rag.embedding.embedding_base import Embeddings
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
# Constructor: takes the model instance used for embedding calls and an optional
# user identifier, and stores the instance on `self._model_instance`.
# NOTE(review): this span is a diff hunk rendered without +/- markers or
# indentation. The first two assignments look like the pre-change body and the
# `_bind_model_instance` call looks like their replacement — confirm against the
# real file before relying on which statements are live.
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
# Final value of `_model_instance`: whatever `_bind_model_instance` returns for
# this (model_instance, user) pair.
self._model_instance = self._bind_model_instance(model_instance, user)
@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
    """Return the model instance to use for embedding calls on behalf of *user*.

    When ``user`` is ``None`` the given ``model_instance`` is returned as-is.
    Otherwise the tenant id is read from
    ``model_instance.provider_model_bundle.configuration`` and a model instance
    is requested from ``ModelManager.for_tenant(tenant_id=..., user_id=user)``
    with the same provider, model type and model name as the one passed in.
    """
    if user is not None:
        tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
        user_scoped_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user)
        # Ask for the same provider / model type / model name as the original instance.
        return user_scoped_manager.get_model_instance(
            tenant_id=tenant_id,
            provider=model_instance.provider,
            model_type=model_instance.model_type_instance.model_type,
            model=model_instance.model_name,
        )

    # No user to bind: hand the caller's instance straight back.
    return model_instance
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]