mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 22:01:06 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -106,9 +106,9 @@ class DataPostProcessor:
|
||||
) -> ModelInstance | None:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
reranking_provider_name = reranking_model["reranking_provider_name"]
|
||||
reranking_model_name = reranking_model["reranking_model_name"]
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
reranking_provider_name = reranking_model.get("reranking_provider_name")
|
||||
reranking_model_name = reranking_model.get("reranking_model_name")
|
||||
if not reranking_provider_name or not reranking_model_name:
|
||||
return None
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
|
||||
@@ -328,7 +328,7 @@ class RetrievalService:
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
||||
)
|
||||
if dataset.is_multimodal:
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id)
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
|
||||
@@ -303,7 +303,7 @@ class Vector:
|
||||
redis_client.delete(collection_exist_cache_key)
|
||||
|
||||
def _get_embeddings(self) -> Embeddings:
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
|
||||
@@ -72,7 +72,7 @@ class DatasetDocumentStore:
|
||||
max_position = 0
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
@@ -22,8 +22,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CacheEmbedding(Embeddings):
|
||||
def __init__(self, model_instance: ModelInstance, user: str | None = None):
|
||||
self._model_instance = model_instance
|
||||
self._user = user
|
||||
self._model_instance = self._bind_model_instance(model_instance, user)
|
||||
|
||||
@staticmethod
|
||||
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
|
||||
if user is None:
|
||||
return model_instance
|
||||
|
||||
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
@@ -65,7 +77,7 @@ class CacheEmbedding(Embeddings):
|
||||
batch_texts = embedding_queue_texts[i : i + max_chunks]
|
||||
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
|
||||
texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT
|
||||
)
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
@@ -147,7 +159,6 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=batch_multimodel_documents,
|
||||
user=self._user,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
@@ -202,7 +213,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
texts=[text], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
@@ -245,7 +256,7 @@ class CacheEmbedding(Embeddings):
|
||||
return [float(x) for x in decoded_embedding]
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
@@ -410,7 +410,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
# If default prompt doesn't have {language} placeholder, use it as-is
|
||||
pass
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id, model_provider_name, ModelType.LLM
|
||||
)
|
||||
|
||||
@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
def __init__(self, rerank_model_instance: ModelInstance):
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
def _get_invoke_model_instance(self, user: str | None) -> ModelInstance:
|
||||
if user is None:
|
||||
return self.rerank_model_instance
|
||||
|
||||
tenant_id = self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
|
||||
return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model_type=ModelType.RERANK,
|
||||
model=self.rerank_model_instance.model_name,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
@@ -34,7 +46,9 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id
|
||||
)
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
@@ -102,8 +116,8 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
rerank_result = self._get_invoke_model_instance(user).invoke_rerank(
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
|
||||
@@ -180,8 +194,8 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
rerank_result = self._get_invoke_model_instance(user).invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
|
||||
@@ -163,7 +163,7 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
"""
|
||||
query_vector_scores = []
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -160,7 +160,7 @@ class DatasetRetrieval:
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@@ -387,7 +387,7 @@ class DatasetRetrieval:
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
|
||||
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id)
|
||||
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._get_prompt_template(
|
||||
@@ -1430,7 +1430,6 @@ class DatasetRetrieval:
|
||||
model_parameters=model_config.parameters,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1533,7 +1532,7 @@ class DatasetRetrieval:
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
self, tenant_id: str, model: ModelConfig
|
||||
self, tenant_id: str, model: ModelConfig, user_id: str | None = None
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
@@ -1543,7 +1542,7 @@ class DatasetRetrieval:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
@@ -3,13 +3,14 @@ from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@@ -150,19 +151,24 @@ class ReactMultiDatasetRouter:
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
|
||||
bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=completion_param,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
Reference in New Issue
Block a user