feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -106,9 +106,9 @@ class DataPostProcessor:
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@@ -328,7 +328,7 @@ class RetrievalService:
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
if dataset.is_multimodal:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model["reranking_provider_name"],

View File

@@ -303,7 +303,7 @@ class Vector:
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,

View File

@@ -72,7 +72,7 @@ class DatasetDocumentStore:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,

View File

@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.rag.embedding.embedding_base import Embeddings
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance, user: str | None = None):
self._model_instance = model_instance
self._user = user
self._model_instance = self._bind_model_instance(model_instance, user)
@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
if user is None:
return model_instance
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
batch_texts = embedding_queue_texts[i : i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
)
for vector in embedding_result.embeddings:
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
texts=[text], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]

View File

@@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)

View File

@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
def _get_invoke_model_instance(self, user: str | None) -> ModelInstance:
if user is None:
return self.rerank_model_instance
tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
tenant_id=tenant_id,
provider=self.rerank_model_instance.provider,
model_type=ModelType.RERANK,
model=self.rerank_model_instance.model_name,
)
def run(
self,
query: str,
@@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
)
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
@@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner):
docs.append(document.page_content)
unique_documents.append(document)
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
@@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner):
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner):
"""
query_vector_scores = []
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,

View File

@@ -160,7 +160,7 @@ class DatasetRetrieval:
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
@@ -387,7 +387,7 @@ class DatasetRetrieval:
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
@@ -1430,7 +1430,6 @@ class DatasetRetrieval:
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
@@ -1533,7 +1532,7 @@ class DatasetRetrieval:
return filters
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
@@ -1543,7 +1542,7 @@ class DatasetRetrieval:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)

View File

@@ -3,13 +3,14 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_manager import ModelInstance, ModelManager
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelType
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -150,19 +151,24 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=ModelType.LLM,
model=model_instance.model_name,
)
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
)
# handle invoke result
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
return text, usage