Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledge base limits
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
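For orientation, a minimal sketch of how these limits might be enforced when binding images to a chunk, assuming the values are read through dify_config as exposed by the FileUploadConfig fields below; the validate_chunk_attachments helper and its call site are hypothetical.

from configs import dify_config

def validate_chunk_attachments(attachment_ids: list[str], image_sizes_bytes: list[int]) -> None:
    # Hypothetical guard: reject a chunk whose attachments exceed the configured limits.
    if len(attachment_ids) > dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT:
        raise ValueError(f"At most {dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT} attachments are allowed per chunk")
    max_bytes = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024  # configured in megabytes
    for size in image_sizes_bytes:
        if size > max_bytes:
            raise ValueError("Attachment image exceeds the configured size limit")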

View File

@@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
default=10,
)
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "

View File

@@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel):
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
@@ -423,17 +424,16 @@ class DatasetApi(Resource):
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
payload_data = payload.model_dump(exclude_unset=True)
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
DatasetService.check_embedding_model_setting(
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list

View File

@@ -424,6 +424,10 @@ class DatasetInitApi(Resource):
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."

View File

@@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
@@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel):
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
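For illustration, a request body a client might send to the segment-create endpoint once attachment_ids is accepted by SegmentCreatePayload; all identifiers are placeholders.

segment_create_body = {
    "content": "The quarterly report summarises revenue by region.",
    "answer": None,
    "keywords": ["revenue", "report"],
    # Upload-file ids of images to bind to this chunk (placeholders)
    "attachment_ids": ["<upload-file-id-1>", "<upload-file-id-2>"],
}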

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@@ -54,16 +55,28 @@ class DatasetsHitTestingBase:
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args.get("query"),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required

View File

@@ -83,6 +83,7 @@ class AppRunner:
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
@@ -111,6 +112,7 @@ class AppRunner:
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

View File

@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
context=context,
memory=memory,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation

View File

@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
@@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
query=query,
context=context,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation

View File

@@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,

View File

@@ -7,7 +7,7 @@ import time
import uuid
from typing import Any
from flask import current_app
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
@@ -89,8 +90,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -136,7 +146,7 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
@@ -152,8 +162,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@@ -209,7 +228,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -302,6 +321,7 @@ class IndexingRunner:
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
current_user=None,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=tenant_id,
@@ -551,7 +571,10 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
create_keyword_thread = None
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
):
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@@ -590,7 +613,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if (
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
@@ -635,7 +658,13 @@ class IndexingRunner:
db.session.commit()
def _process_chunk(
self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
self,
flask_app: Flask,
index_processor: BaseIndexProcessor,
chunk_documents: list[Document],
dataset: Dataset,
dataset_document: DatasetDocument,
embedding_model_instance: ModelInstance | None,
):
with flask_app.app_context():
# check document is paused
@@ -646,8 +675,15 @@ class IndexingRunner:
page_content_list = [document.page_content for document in chunk_documents]
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
multimodal_documents = []
for document in chunk_documents:
if document.attachments and dataset.is_multimodal:
multimodal_documents.extend(document.attachments)
# load index
index_processor.load(dataset, chunk_documents, with_keywords=False)
index_processor.load(
dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where(
@@ -710,6 +746,7 @@ class IndexingRunner:
text_docs: list[Document],
doc_language: str,
process_rule: dict,
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
@@ -729,6 +766,7 @@ class IndexingRunner:
documents = index_processor.transform(
text_docs,
current_user,
embedding_model_instance=embedding_model_instance,
process_rule=process_rule,
tenant_id=dataset.tenant_id,
@@ -737,14 +775,16 @@ class IndexingRunner:
return documents
def _load_segments(self, dataset, dataset_document, documents):
def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
doc_store.add_documents(
docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
)
# update document status to indexing
cur_time = naive_utc_now()

View File

@@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@@ -200,7 +200,7 @@ class ModelInstance:
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding model
@@ -212,7 +212,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
TextEmbeddingResult,
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
@@ -223,6 +223,34 @@ class ModelInstance:
),
)
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke multimodal embedding model
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
"""
Get number of tokens for text embedding
@@ -276,6 +304,40 @@ class ModelInstance:
),
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model
@@ -461,6 +523,32 @@ class ModelManager:
model=default_model_entity.model,
)
def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
"""
Check if model supports vision
:param tenant_id: tenant id
:param provider: provider name
:param model: model name
:param model_type: model type of the model to check
:return: True if model supports vision, False otherwise
"""
model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
model_type_instance = model_instance.model_type_instance
match model_type:
case ModelType.LLM:
model_type_instance = cast(LargeLanguageModel, model_type_instance)
case ModelType.TEXT_EMBEDDING:
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
case ModelType.RERANK:
model_type_instance = cast(RerankModel, model_type_instance)
case _:
raise ValueError(f"Model type {model_type} is not supported")
model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
if not model_schema:
return False
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
return False
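A short usage sketch of the new capability check; the retrieval service below uses the same call to decide whether a multimodal rerank is meaningful. Provider and model names are placeholders.

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()
# True only when the resolved model schema advertises ModelFeature.VISION.
supports_vision = model_manager.check_model_support_vision(
    tenant_id="<tenant-id>",
    provider="<provider-name>",
    model="<rerank-model-name>",
    model_type=ModelType.RERANK,
)
if not supports_vision:
    # Callers fall back to the original documents instead of reranking, as embedding_search does.
    pass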
class LBModelManager:
def __init__(

View File

@@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
latency: float
class TextEmbeddingResult(BaseModel):
class EmbeddingResult(BaseModel):
"""
Model class for embedding result (text or multimodal).
"""
@@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel):
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@@ -50,3 +50,43 @@ class RerankModel(AIModel):
)
except Exception as e:
raise self._transform_invoke_error(e)
def invoke_multimodal_rerank(
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@@ -2,7 +2,7 @@ from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
@@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel):
try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
documents=multimodel_documents,
input_type=input_type,
)
raise ValueError("No texts or files provided")
except Exception as e:
raise self._transform_invoke_error(e)
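A hedged sketch of driving this branch through ModelInstance.invoke_multimodal_embedding; the document shape (content/content_type/file_id) mirrors what Vector.create_multimodal assembles later in this commit, and model_instance plus the image bytes are assumptions.

import base64

from core.entities.embedding_type import EmbeddingInputType
from core.rag.index_processor.constant.doc_type import DocType

# model_instance is assumed to be a ModelInstance bound to a multimodal-capable
# text-embedding model; the bytes would normally come from storage.load_once().
image_b64 = base64.b64encode(b"<image bytes>").decode()
result = model_instance.invoke_multimodal_embedding(
    multimodel_documents=[
        {"content": image_b64, "content_type": DocType.IMAGE, "file_id": "<upload-file-id>"},
    ],
    input_type=EmbeddingInputType.DOCUMENT,
)
vector = result.embeddings[0]  # EmbeddingResult carries one vector per input document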

View File

@@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
credentials: dict,
texts: list[str],
input_type: str,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=TextEmbeddingResult,
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke text embedding")
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
"""
Invoke multimodal embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"documents": documents,
"input_type": input_type,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke file embedding")
def get_text_embedding_num_tokens(
self,
tenant_id: str,
@@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke rerank")
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "rerank",
"model": model,
"credentials": credentials,
"query": query,
"docs": docs,
"score_threshold": score_threshold,
"top_n": top_n,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke multimodal rerank")
def invoke_tts(
self,
tenant_id: str,

View File

@@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
inputs = {key: str(value) for key, value in inputs.items()}
@@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
return prompt_messages, stops
@@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform):
)
if query:
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
else:
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
return prompt_messages, None
@@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
# get prompt
prompt, prompt_rules = self._get_prompt_str_and_rules(
@@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform):
if stops is not None and len(stops) == 0:
stops = None
return [self._get_last_user_message(prompt, files, image_detail_config)], stops
return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
def _get_last_user_message(
self,
prompt: str,
files: Sequence["File"],
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> UserPromptMessage:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if files:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if context_files:
for file in context_files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if prompt_message_contents:
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
prompt_message = UserPromptMessage(content=prompt_message_contents)

View File

@@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@@ -30,9 +31,10 @@ class DataPostProcessor:
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)

View File

@@ -1,23 +1,30 @@
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@@ -37,14 +44,15 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
if not query:
if not query and not attachment_ids:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset:
@@ -56,69 +64,52 @@ class RetrievalService:
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
retrieval_service = RetrievalService()
if query:
futures.append(
executor.submit(
cls.keyword_search,
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
retrieval_method=retrieval_method,
dataset=dataset,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=None,
all_documents=all_documents,
exceptions=exceptions,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
if attachment_ids:
for attachment_id in attachment_ids:
futures.append(
executor.submit(
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
query=None,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=attachment_id,
all_documents=all_documents,
exceptions=exceptions,
)
)
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
@@ -223,6 +214,7 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
with flask_app.app_context():
try:
@@ -231,14 +223,30 @@ class RetrievalService:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
documents = []
if query_type == QueryType.TEXT_QUERY:
documents.extend(
vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if query_type == QueryType.IMAGE_QUERY:
if not dataset.is_multimodal:
return
documents.extend(
vector.search_by_file(
file_id=query,
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if documents:
if (
@@ -250,14 +258,37 @@ class RetrievalService:
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
if dataset.is_multimodal:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
model_type=ModelType.RERANK,
)
if is_support_vision:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
else:
# not effective, return original documents
all_documents.extend(documents)
else:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@@ -339,103 +370,159 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
segment_file_map = {}
with Session(db.engine) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
)
)
.first()
)
.first()
)
if not segment:
continue
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
if attachment_info:
segment_file_map[segment.id].append(attachment_info)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = db.session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
records.append(record)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
result = []
for record in records:
@@ -447,6 +534,11 @@ class RetrievalService:
if not isinstance(child_chunks, list):
child_chunks = None
# Extract files, ensuring it's a list or None
files = record.get("files")
if not isinstance(files, list):
files = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
@@ -456,10 +548,149 @@ class RetrievalService:
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e
def _retrieve(
self,
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
):
if not query and not attachment_id:
return
with flask_app.app_context():
all_documents_item: list[Document] = []
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
futures.append(
executor.submit(
self.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
all_documents=all_documents_item,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
if query:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.TEXT_QUERY,
)
)
if attachment_id:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=attachment_id,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.IMAGE_QUERY,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
futures.append(
executor.submit(
self.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
all_documents.extend(all_documents_item)
all_documents_item = self._deduplicate_documents(all_documents_item)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
query = query or attachment_id
if not query:
return
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
)
all_documents.extend(all_documents_item)
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
)
if attachment_binding:
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
return None
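Taken together, the flow above lets retrieval be driven by image attachments as well as text. A hedged caller sketch against the classmethod signature at the top of this file; the dataset id and attachment id are placeholders, and the import path for RetrievalService is assumed.

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrieval_methods import RetrievalMethod

documents = RetrievalService.retrieve(
    retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
    dataset_id="<dataset-id>",
    query="wiring diagram for the control unit",
    top_k=4,
    score_threshold=0.0,
    # Each attachment id spawns its own image-query branch inside _retrieve().
    attachment_ids=["<uploaded-image-file-id>"],
)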

View File

@@ -1,3 +1,4 @@
import base64
import logging
import time
from abc import ABC, abstractmethod
@@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, Whitelist
from models.model import UploadFile
logger = logging.getLogger(__name__)
@@ -203,6 +207,47 @@ class Vector:
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
for document in batch:
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
if upload_file:
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
file_base64_list.append(
{
"content": file_base64_str,
"content_type": doc_type,
"file_id": attachment_id,
}
)
real_batch.append(document)
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@@ -223,6 +268,22 @@ class Vector:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return []
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
multimodal_vector = self._embeddings.embed_multimodal_query(
{
"content": file_base64_str,
"content_type": DocType.IMAGE,
"file_id": file_id,
}
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
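A brief sketch of the two new Vector entry points, assuming a Dataset whose embedding model is multimodal; the Document metadata follows what create_multimodal reads (doc_id pointing at an UploadFile, doc_type from DocType), and the ids are placeholders.

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document

vector = Vector(dataset=dataset)  # dataset: an existing multimodal Dataset

# Index image attachments: each doc_id is an UploadFile id, which create_multimodal()
# resolves from storage, base64-encodes, and embeds in batches.
vector.create_multimodal(
    file_documents=[
        Document(page_content="", metadata={"doc_id": "<upload-file-id>", "doc_type": DocType.IMAGE}),
    ]
)

# Query by image: the referenced file is embedded and a plain vector search is run.
hits = vector.search_by_file(file_id="<query-image-file-id>", top_k=4, score_threshold=0.0)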

View File

@@ -5,9 +5,9 @@ from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
class DatasetDocumentStore:
@@ -120,6 +120,9 @@ class DatasetDocumentStore:
db.session.add(segment_document)
db.session.flush()
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
@@ -144,6 +147,9 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
@@ -233,3 +239,15 @@ class DatasetDocumentStore:
document_segment = db.session.scalar(stmt)
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_id,
attachment_id=multimodel_document.metadata["doc_id"],
)
db.session.add(binding)

View File

@@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_indices.append(i)
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
for vector in embedding_result.embeddings:
try:
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# check for NaN values (see https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan)
if np.isnan(normalized_embedding).any():
# NaN values are not JSON compliant (see issue #11827)
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
except Exception:
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
multimodel_embeddings[i] = n_embedding
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(file_id)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents")
raise ex
return multimodel_embeddings
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
@@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
raise ex
return embedding_results # type: ignore
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
raise ex
try:
# cache the embedding in Redis as a base64-encoded byte string
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
encoded_vector = base64.b64encode(vector_bytes)
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logger.exception(
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
)
raise ex
return embedding_results # type: ignore
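For reference, a self-contained sketch of the Redis cache round-trip used above, assuming only numpy (in the real code the string is stored with a 600-second TTL via redis_client.setex):

import base64

import numpy as np

# encode: normalized vector -> raw bytes -> base64 string (what setex stores)
vector = np.array([0.6, 0.8])
encoded_str = base64.b64encode(vector.tobytes()).decode("utf-8")

# decode: base64 string -> raw bytes -> float64 array (the cache-hit path)
decoded = np.frombuffer(base64.b64decode(encoded_str), dtype="float")
assert np.allclose(decoded, vector)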

View File

@@ -9,11 +9,21 @@ class Embeddings(ABC):
"""Embed search docs."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
raise NotImplementedError
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError

View File

@@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None

View File

@@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
page: int | None = None
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class DocType(StrEnum):
TEXT = "text"
IMAGE = "image"

View File

@@ -1,7 +1,12 @@
from enum import StrEnum
class IndexType(StrEnum):
class IndexStructureType(StrEnum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
class IndexTechniqueType(StrEnum):
ECONOMY = "economy"
HIGH_QUALITY = "high_quality"

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class QueryType(StrEnum):
TEXT_QUERY = "text_query"
IMAGE_QUERY = "image_query"

View File

@@ -1,20 +1,34 @@
"""Abstract interface for document loader implementations."""
import cgi
import logging
import mimetypes
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import Account, ToolFile
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
if TYPE_CHECKING:
from core.model_manager import ModelInstance
@@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
raise NotImplementedError
@abstractmethod
@@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
)
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multi_model_documents: list[AttachmentDocument] = []
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
if current_user:
tool_file_id = match.group(1)
upload_file_id = self._download_tool_file(tool_file_id, current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
continue
if current_user:
upload_file_id = self._download_image(image.split(" ")[0], current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
if not upload_file_id_list:
return multi_model_documents
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
# Create a Document for each occurrence (including duplicates)
for upload_file_id in upload_file_id_list:
upload_file = upload_file_map.get(upload_file_id)
if upload_file:
multi_model_documents.append(
AttachmentDocument(
page_content=upload_file.name,
metadata={
"doc_id": upload_file.id,
"doc_hash": "",
"document_id": document.metadata.get("document_id"),
"dataset_id": document.metadata.get("dataset_id"),
"doc_type": DocType.IMAGE,
},
)
)
return multi_model_documents
def _extract_markdown_images(self, text: str) -> list[str]:
"""
Extract the markdown images from the text.
"""
pattern = r"!\[.*?\]\((.*?)\)"
return re.findall(pattern, text)
def _download_image(self, image_url: str, current_user: Account) -> str | None:
"""
Download the image from the URL.
Image size must not exceed the configured ATTACHMENT_IMAGE_FILE_SIZE_LIMIT (2 MB by default).
"""
from services.file_service import FileService
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
return None
filename = None
content_disposition = response.headers.get("content-disposition")
if content_disposition:
_, params = cgi.parse_header(content_disposition)
if "filename" in params:
filename = params["filename"]
filename = unquote(filename)
if not filename:
parsed_url = urlparse(image_url)
# unquote non-ASCII (e.g., Chinese) characters in the URL path
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename:
filename = "downloaded_image_file"
name, current_ext = os.path.splitext(filename)
content_type = response.headers.get("content-type", "").split(";")[0].strip()
real_ext = mimetypes.guess_extension(content_type)
if (not current_ext and real_ext) or (current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext):
filename = f"{name}{real_ext}"
# Download content with size limit
blob = b""
for chunk in response.iter_bytes(chunk_size=8192):
blob += chunk
if len(blob) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
return None
if not blob:
logging.warning("Image from %s is empty", image_url)
return None
upload_file = FileService(db.engine).upload_file(
filename=filename,
content=blob,
mimetype=content_type,
user=current_user,
)
return upload_file.id
except httpx.TimeoutException:
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
return None
except httpx.RequestError as e:
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
"""
Download the tool file from the ID.
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
upload_file = FileService(db.engine).upload_file(
filename=tool_file.name,
content=blob,
mimetype=tool_file.mimetype,
user=current_user,
)
return upload_file.id
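An illustrative check of the markdown image patterns `_get_content_files` recognizes (the UUIDs are placeholders):

import re

text = (
    "Intro ![chart](/files/123e4567-e89b-12d3-a456-426614174000/file-preview?timestamp=1) "
    "and ![legacy](/files/123e4567-e89b-12d3-a456-426614174001/image-preview)"
)
# same pattern as _extract_markdown_images
image_urls = re.findall(r"!\[.*?\]\((.*?)\)", text)
# same pattern as the post-v0.10.0 branch above
file_preview_ids = [
    m.group(1)
    for url in image_urls
    if (m := re.search(r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?", url))
]
print(image_urls)        # both markdown image URLs
print(file_preview_ids)  # ['123e4567-e89b-12d3-a456-426614174000']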

View File

@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@@ -19,11 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX:
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX:
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
multimodal_documents = (
self._get_content_files(document_node, current_user) if document_node.metadata else None
)
if multimodal_documents:
document_node.attachments = multimodal_documents
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
@@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
documents = []
for content in chunks:
metadata = {
"dataset_id": dataset.id,
@@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
attachments = self._get_content_files(doc)
if attachments:
doc.attachments = attachments
all_multimodal_documents.extend(attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
else:
raise ValueError("Chunks is not a list")
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
for general_chunk in multimodal_general_structure.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(general_chunk.content),
}
doc = Document(page_content=general_chunk.content, metadata=metadata)
if general_chunk.files:
attachments = []
for file in general_chunk.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(
page_content=file.filename or "image_file", metadata=file_metadata
)
attachments.append(file_document)
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
return {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
else:
raise ValueError("Chunks is not a list")

View File

@@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
multimodel_documents = self._get_content_files(document_node, current_user)
if multimodel_documents:
document_node.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
multimodel_documents = self._get_content_files(document)
if multimodel_documents:
document.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
@@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
if parent_child.files and len(parent_child.files) > 0:
attachments = []
for file in parent_child.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
documents.append(doc)
if documents:
# update document parent mode
@@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
all_multimodal_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
vector = Vector(dataset)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
@@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),

View File

@@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
@@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
try:
# Skip the first row
df = pd.read_csv(file)
df = pd.read_csv(file) # type: ignore
text_docs = []
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
@@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
@@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
"chunk_structure": IndexType.QA_INDEX,
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}

View File

@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.file import File
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
@@ -15,7 +17,19 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
class AttachmentDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
provider: str | None = "dify"
vector: list[float] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class Document(BaseModel):
@@ -28,12 +42,31 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
provider: str | None = "dify"
children: list[ChildDocument] | None = None
attachments: list[AttachmentDocument] | None = None
class GeneralChunk(BaseModel):
"""
General Chunk.
"""
content: str
files: list[File] | None = None
class MultimodalGeneralStructureChunk(BaseModel):
"""
Multimodal General Structure Chunk.
"""
general_chunks: list[GeneralChunk]
class GeneralStructureChunk(BaseModel):
"""
@@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel):
parent_content: str
child_contents: list[str]
files: list[File] | None = None
class ParentChildStructureChunk(BaseModel):

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
@@ -12,6 +13,7 @@ class BaseRerankRunner(ABC):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model

View File

@@ -1,6 +1,15 @@
from core.model_manager import ModelInstance
import base64
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
class RerankModelRunner(BaseRerankRunner):
@@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model_type=ModelType.RERANK,
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
)
rerank_documents = []
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=unique_documents[result.index].metadata,
provider=unique_documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def fetch_text_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
doc_ids = set()
unique_documents = []
@@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
return rerank_result, unique_documents
rerank_documents = []
def fetch_multimodal_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch multimodal rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result
"""
docs = []
doc_ids = set()
unique_documents = []
for document in documents:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
)
unique_documents.append(document)
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_result, unique_documents
else:
raise ValueError(f"Upload file not found for query: {query}")
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
else:
raise ValueError(f"Query type {query_type} is not supported")

View File

@@ -7,6 +7,8 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
# weight rerank only support text documents
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
else:
if document not in unique_documents:
unique_documents.append(document)

View File

@@ -8,6 +8,7 @@ from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@@ -99,7 +105,8 @@ class DatasetRetrieval:
message_id: str,
memory: TokenBufferMemory | None = None,
inputs: Mapping[str, Any] | None = None,
) -> str | None:
vision_enabled: bool = False,
) -> tuple[str | None, list[File] | None]:
"""
Retrieve dataset.
:param app_id: app_id
@@ -118,7 +125,7 @@ class DatasetRetrieval:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
@@ -136,7 +143,7 @@ class DatasetRetrieval:
)
if not model_schema:
return None
return None, []
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
@@ -182,8 +189,8 @@ class DatasetRetrieval:
tenant_id,
user_id,
user_from,
available_datasets,
query,
available_datasets,
model_instance,
model_config,
planning_strategy,
@@ -213,6 +220,7 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list: list[DocumentContext] = []
context_files: list[File] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
@@ -248,6 +256,31 @@ class DatasetRetrieval:
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment.id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@@ -288,8 +321,10 @@ class DatasetRetrieval:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""
return str(
"\n".join([document_context.content for document_context in document_context_list])
), context_files
return "", context_files
def single_retrieve(
self,
@@ -297,8 +332,8 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
available_datasets: list,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@@ -336,7 +371,7 @@ class DatasetRetrieval:
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
timer = None
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -406,10 +441,19 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrieval_end(results, message_id, timer)
thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": results,
"message_id": message_id,
"timer": timer,
},
)
thread.start()
return results
return []
@@ -421,7 +465,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
available_datasets: list,
query: str,
query: str | None,
top_k: int,
score_threshold: float,
reranking_mode: str,
@@ -431,10 +475,11 @@ class DatasetRetrieval:
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
return []
threads = []
all_threads = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
@@ -467,131 +512,226 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
},
)
else:
if index_type == "economy":
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id)
all_threads.append(query_thread)
query_thread.start()
if attachment_ids:
for attachment_id in attachment_ids:
attachment_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrieval_end(all_documents, message_id, timer)
return all_documents
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
# add thread to call _on_retrieval_end
retrieval_end_thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": all_documents,
"message_id": message_id,
"timer": timer,
},
)
retrieval_end_thread.start()
retrieval_resource_list = []
doc_ids_filter = []
for document in all_documents:
if document.provider == "dify":
doc_id = document.metadata.get("doc_id")
if doc_id and doc_id not in doc_ids_filter:
doc_ids_filter.append(doc_id)
retrieval_resource_list.append(document)
elif document.provider == "external":
retrieval_resource_list.append(document)
return retrieval_resource_list
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
):
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
with Session(db.engine) as session:
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
segment_id = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = session.scalar(child_chunk_stmt)
if child_chunk:
segment_id = child_chunk.segment_id
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id:
if segment_id not in segment_ids:
segment_ids.append(segment_id)
_ = (
session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
if document.metadata["doc_id"] not in segment_index_node_ids:
segment = (
session.query(DocumentSegment)
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
.first()
)
if segment:
segment_index_node_ids.append(document.metadata["doc_id"])
segment_ids.append(segment.id)
query = session.query(DocumentSegment).where(
DocumentSegment.id == segment.id
)
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id not in segment_ids:
segment_ids.append(segment_id)
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
if query:
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(
DocumentSegment.dataset_id == document.metadata["dataset_id"]
)
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
def _on_query(
self,
query: str | None,
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
user_from: str,
user_id: str,
):
"""
Handle query.
"""
if not query:
if not query and not attachment_ids:
return
dataset_queries = []
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
contents = []
if query:
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
if attachment_ids:
for attachment_id in attachment_ids:
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
if contents:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
db.session.commit()
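With this change DatasetQuery.content stores a JSON array instead of the raw query string; for a text query plus one image attachment the serialized value would look like this (illustrative IDs):

import json

from core.rag.index_processor.constant.query_type import QueryType

contents = [
    {"content_type": QueryType.TEXT_QUERY, "content": "find the wiring diagram"},
    {"content_type": QueryType.IMAGE_QUERY, "content": "0b6c9f6e-0000-0000-0000-000000000000"},
]
# StrEnum members serialize as their string values
print(json.dumps(contents))
# [{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]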
def _retriever(
@@ -603,6 +743,7 @@ class DatasetRetrieval:
all_documents: list,
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -611,7 +752,7 @@ class DatasetRetrieval:
if not dataset:
return []
if dataset.provider == "external":
if dataset.provider == "external" and query:
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
@@ -663,6 +804,7 @@ class DatasetRetrieval:
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
attachment_ids=attachment_ids,
)
all_documents.extend(documents)
@@ -1222,3 +1364,86 @@ class DatasetRetrieval:
usage = LLMUsage.empty_usage()
return full_text, usage
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
top_k: int,
score_threshold: float,
query: str | None,
attachment_id: str | None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)

View File

@@ -0,0 +1,65 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "array",
"title": "Multimodal General Structure",
"description": "Schema for multimodal general structure (v1) - array of objects",
"properties": {
"general_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
},
"description": "List of files"
}
}
},
"required": ["content"]
},
"description": "List of content and files"
}
}
}
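For orientation, a made-up instance of the layout the nested general_chunks definition describes; every value below is invented for illustration:

# Invented example of the general multimodal structure: text chunks, each optionally
# carrying the image files extracted alongside them.
example_general_structure = {
    "general_chunks": [
        {
            "content": "Installation steps for the device.",
            "files": [
                {
                    "name": "wiring-diagram.png",
                    "size": 204800,
                    "extension": "png",
                    "type": "image",
                    "mime_type": "image/png",
                    "transfer_method": "local_file",
                    "url": "https://example.com/files/wiring-diagram.png",
                    "related_id": "upload-file-id-1",
                },
            ],
        },
    ],
}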

View File

@@ -0,0 +1,78 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "object",
"title": "Multimodal Parent-Child Structure",
"description": "Schema for multimodal parent-child structure (v1)",
"properties": {
"parent_mode": {
"type": "string",
"description": "The mode of parent-child relationship"
},
"parent_child_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"parent_content": {
"type": "string",
"description": "The parent content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
},
"required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"]
},
"description": "List of files"
},
"child_contents": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of child contents"
}
},
"required": ["parent_content", "child_contents"]
},
"description": "List of parent-child chunk pairs"
}
},
"required": ["parent_mode", "parent_child_chunks"]
}
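Likewise, an invented instance of the parent-child layout; the parent_mode value and all contents are placeholders:

# Invented example of the parent-child multimodal structure.
example_parent_child_structure = {
    "parent_mode": "paragraph",
    "parent_child_chunks": [
        {
            "parent_content": "Chapter 2: safety guidelines for operating the device.",
            "child_contents": [
                "Always disconnect power before servicing.",
                "Wear protective gloves when handling the blade assembly.",
            ],
            "files": [],
        },
    ],
}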

View File

@@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def sign_upload_file(upload_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature

View File

@@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
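A quick usage sketch of the updated pattern, inlining the regex rather than importing the helper; note that a leading exclamation mark is no longer stripped:

import re

# Same character class as above, minus "!" which this change preserves.
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
print(re.sub(pattern, "", "...Hello"))    # -> "Hello"
print(re.sub(pattern, "", "!Important"))  # -> "!Important" (kept after this change)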

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
@@ -14,6 +15,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")
class ModelInvokeCompletedEvent(NodeEventBase):

View File

@@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
"""
type: str = "knowledge-retrieval"
query_variable_selector: list[str]
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: MultipleRetrievalConfig | None = None

View File

@@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
@@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables = {"query": query}
if not query:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
process_data={},
outputs={},
metadata={},
llm_usage=LLMUsage.empty_usage(),
)
variables: dict[str, Any] = {}
# extract variables
if self._node_data.query_variable_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables["query"] = query
if self._node_data.query_attachment_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Attachments variable is not array file or file type.",
)
if isinstance(variable, ArrayFileSegment):
variables["attachments"] = variable.value
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
@@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
@@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
@@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
@@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = {
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
@@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
@@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
@@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:

View File

@@ -7,8 +7,10 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
@@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
@@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]):
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
context = None
context_files: list[File] = []
for event in generator:
context = event.context
context_files = event.context_files or []
yield event
if context:
node_inputs["#context#"] = context
if context_files:
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
@@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
@@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
yield RunRetrieverResourceEvent(
retriever_resources=[], context=context_value_variable.value, context_files=[]
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource, context=context_str.strip()
retriever_resources=original_retriever_resource,
context=context_str.strip(),
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
@@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]):
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
)
return source
@@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:

View File

@@ -97,11 +97,27 @@ dataset_detail_fields = {
"total_documents": fields.Integer,
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
"is_multimodal": fields.Boolean,
}
file_info_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
content_fields = {
"content_type": fields.String,
"content": fields.String,
"file_info": fields.Nested(file_info_fields, allow_null=True),
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"queries": fields.Nested(content_fields),
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,

View File

@@ -9,6 +9,8 @@ upload_config_fields = {
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
"image_file_batch_limit": fields.Integer,
"single_chunk_attachment_limit": fields.Integer,
}

View File

@@ -43,9 +43,19 @@ child_chunk_fields = {
"score": fields.Float,
}
files_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
hit_testing_record_fields = {
"segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
}

View File

@@ -13,6 +13,15 @@ child_chunk_fields = {
"updated_at": TimestampField,
}
attachment_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
segment_fields = {
"id": fields.String,
"position": fields.Integer,
@@ -39,4 +48,5 @@ segment_fields = {
"error": fields.String,
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
}

View File

@@ -0,0 +1,57 @@
"""support-multi-modal
Revision ID: d57accd375ae
Revises: 7bb281b7a422
Create Date: 2025-11-12 15:37:12.363670
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd57accd375ae'
down_revision = '7bb281b7a422'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('segment_attachment_bindings',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('attachment_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey')
)
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.create_index(
'segment_attachment_binding_tenant_dataset_document_segment_idx',
['tenant_id', 'dataset_id', 'document_id', 'segment_id'],
unique=False
)
batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('is_multimodal')
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.drop_index('segment_attachment_binding_attachment_idx')
batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx')
op.drop_table('segment_attachment_bindings')
# ### end Alembic commands ###

View File

@@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -76,6 +78,7 @@ class Dataset(Base):
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def total_documents(self):
@@ -728,9 +731,7 @@ class DocumentSegment(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
@@ -866,6 +867,47 @@ class DocumentSegment(Base):
return text
@property
def attachments(self) -> list[dict[str, Any]]:
# Use JOIN to fetch attachments in a single query instead of two separate queries
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == self.tenant_id,
SegmentAttachmentBinding.dataset_id == self.dataset_id,
SegmentAttachmentBinding.document_id == self.document_id,
SegmentAttachmentBinding.segment_id == self.id,
)
).all()
if not attachments_with_bindings:
return []
attachment_list = []
for _, attachment in attachments_with_bindings:
upload_file_id = attachment.id
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
reference_url = dify_config.CONSOLE_API_URL or ""
base_url = f"{reference_url}/files/{upload_file_id}/image-preview"
source_url = f"{base_url}?{params}"
attachment_list.append(
{
"id": attachment.id,
"name": attachment.name,
"size": attachment.size,
"extension": attachment.extension,
"mime_type": attachment.mime_type,
"source_url": source_url,
}
)
return attachment_list
class ChildChunk(Base):
__tablename__ = "child_chunks"
@@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase):
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
@property
def queries(self) -> list[dict[str, Any]]:
try:
queries = json.loads(self.content)
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
if file_info:
query["file_info"] = {
"id": file_info.id,
"name": file_info.name,
"size": file_info.size,
"extension": file_info.extension,
"mime_type": file_info.mime_type,
"source_url": sign_upload_file(file_info.id, file_info.extension),
}
else:
query["file_info"] = None
return queries
else:
return [queries]
except JSONDecodeError:
return [
{
"content_type": QueryType.TEXT_QUERY,
"content": self.content,
"file_info": None,
}
]
class DatasetKeywordTable(TypeBase):
__tablename__ = "dataset_keyword_tables"
@@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase):
onupdate=func.current_timestamp(),
init=False,
)
class SegmentAttachmentBinding(Base):
__tablename__ = "segment_attachment_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
sa.Index(
"segment_attachment_binding_tenant_dataset_document_segment_idx",
"tenant_id",
"dataset_id",
"document_id",
"segment_id",
),
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@@ -0,0 +1,31 @@
import base64
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from extensions.ext_storage import storage
from models.model import UploadFile
PREVIEW_WORDS_LIMIT = 3000
class AttachmentService:
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
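A brief usage sketch, assuming the service is built from the application's SQLAlchemy engine; the module path in the import is an assumption:

from extensions.ext_database import db
from services.attachment_service import AttachmentService  # module path is an assumption

# Fetch a stored image as base64, e.g. for a vision-capable embedding call.
attachment_service = AttachmentService(session_factory=db.engine)
image_b64 = attachment_service.get_file_base64("upload-file-id-1")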

View File

@@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, cast
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
@@ -19,9 +19,10 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
@@ -46,6 +47,7 @@ from models.dataset import (
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@@ -363,6 +365,27 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
else:
return False
except LLMBadRequestError:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
@@ -402,13 +425,13 @@ class DatasetService:
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name is exists
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
if data.get("name") and data.get("name") != dataset.name:
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@@ -844,6 +867,12 @@ class DatasetService:
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -880,6 +909,12 @@ class DatasetService:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
@@ -937,6 +972,12 @@ class DatasetService:
)
)
dataset.collection_binding_id = dataset_collection_binding.id
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@@ -2305,6 +2346,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
is_multimodal=knowledge_config.is_multimodal,
)
db.session.add(dataset)
@@ -2685,6 +2727,13 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")
if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs is invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
@@ -2731,11 +2780,23 @@ class SegmentService:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
if args["attachment_ids"]:
for attachment_id in args["attachment_ids"]:
binding = SegmentAttachmentBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
segment_id=segment_document.id,
attachment_id=attachment_id,
)
db.session.add(binding)
db.session.commit()
# save vector index
@@ -2899,7 +2960,7 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
@@ -2926,12 +2987,11 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@@ -2976,7 +3036,7 @@ class SegmentService:
db.session.add(document)
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@@ -3002,15 +3062,15 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
logger.exception("update segment index failed")
segment.enabled = False
@@ -3048,7 +3108,9 @@ class SegmentService:
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
)
db.session.delete(segment)
# update document word count
@@ -3097,7 +3159,9 @@ class SegmentService:
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
)
if document.word_count is None:
document.word_count = 0

View File

@@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
is_multimodal: bool = False
class SegmentCreateArgs(BaseModel):
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdateArgs(BaseModel):
@@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
class ChildChunkUpdateArgs(BaseModel):

View File

@@ -1,3 +1,4 @@
import base64
import hashlib
import os
import uuid
@@ -123,6 +124,15 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from typing import Any
@@ -5,6 +6,7 @@ from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -32,6 +34,7 @@ class HitTestingService:
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
):
start = time.perf_counter()
@@ -41,7 +44,7 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
if metadata_filtering_conditions and query:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
@@ -66,6 +69,7 @@ class HitTestingService:
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
@@ -80,17 +84,24 @@ class HitTestingService:
end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
dataset_queries = []
if query:
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
dataset_queries.append(content)
if attachment_ids:
for attachment_id in attachment_ids:
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
dataset_queries.append(content)
if dataset_queries:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
@@ -168,9 +179,14 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
attachment_ids = args["attachment_ids"]
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
if query and len(query) > 250:
raise ValueError("Query cannot exceed 250 characters")
if attachment_ids and not isinstance(attachment_ids, list):
raise ValueError("Attachment_ids must be a list")
@staticmethod
def escape_query_for_search(query: str) -> str:

View File

@@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@@ -21,9 +24,10 @@ class VectorService:
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
multimodal_documents: list[AttachmentDocument] = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
logger.warning(
@@ -70,12 +74,29 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
documents.append(rag_document)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_document: AttachmentDocument = AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
multimodal_documents.append(multimodal_document)
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
if len(multimodal_documents) > 0:
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
@classmethod
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
@@ -130,6 +151,7 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
# use full doc mode to generate segment's child chunk
@@ -226,3 +248,92 @@ class VectorService:
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
return
attachments = segment.attachments
old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
# Check if there's any actual change needed
if set(attachment_ids) == set(old_attachment_ids):
return
try:
vector = Vector(dataset=dataset)
if dataset.is_multimodal:
# Delete old vectors if they exist
if old_attachment_ids:
vector.delete_by_ids(old_attachment_ids)
# Delete existing segment attachment bindings in one operation
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
synchronize_session=False
)
if not attachment_ids:
db.session.commit()
return
# Bulk fetch upload files - only fetch needed fields
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if not upload_file_list:
db.session.commit()
return
# Create a mapping for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
# Prepare batch operations
bindings = []
documents = []
# Create common metadata base to avoid repetition
base_metadata = {
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
}
# Process attachments in the order specified by attachment_ids
for attachment_id in attachment_ids:
upload_file = upload_file_map.get(attachment_id)
if not upload_file:
logger.warning("Upload file not found for attachment_id: %s", attachment_id)
continue
# Create segment attachment binding
bindings.append(
SegmentAttachmentBinding(
tenant_id=segment.tenant_id,
dataset_id=segment.dataset_id,
document_id=segment.document_id,
segment_id=segment.id,
attachment_id=upload_file.id,
)
)
# Create document for vector indexing
documents.append(
Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
)
# Bulk insert all bindings at once
if bindings:
db.session.add_all(bindings)
# Add documents to vector store if any
if documents and dataset.is_multimodal:
vector.add_texts(documents, duplicate_check=True)
# Single commit for all operations
db.session.commit()
except Exception:
logger.exception("Failed to update multimodal vector for segment %s", segment.id)
db.session.rollback()
raise

View File

@@ -4,9 +4,10 @@ import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str):
)
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()

View File

@@ -18,6 +18,7 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
SegmentAttachmentBinding,
)
from models.model import UploadFile
@@ -58,14 +59,20 @@ def clean_dataset_task(
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
doc_form = IndexType.PARAGRAPH_INDEX
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
@@ -90,6 +97,7 @@ def clean_dataset_task(
for document in documents:
db.session.delete(document)
# delete document file
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
@@ -107,6 +115,19 @@ def clean_dataset_task(
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()

View File

@@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(

View File

@@ -4,9 +4,10 @@ import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
@@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
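
Note: the attachment-to-AttachmentDocument mapping in this hunk is repeated verbatim in deal_dataset_vector_index_task and the enable-segment tasks below. A minimal helper sketch, grounded in the loop shown above; the function itself is illustrative and not part of the commit.

from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument


def build_attachment_documents(segment) -> list[AttachmentDocument]:
    """Map a segment's image attachments to AttachmentDocument entries for multimodal indexing."""
    return [
        AttachmentDocument(
            page_content=attachment["name"],
            metadata={
                "doc_id": attachment["id"],
                "doc_hash": "",
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
                "doc_type": DocType.IMAGE,
            },
        )
        for attachment in segment.attachments
    ]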

View File

@@ -1,14 +1,14 @@
import logging
import time
from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
def deal_dataset_vector_index_task(dataset_id: str, action: str):
"""
Async deal dataset from index
:param dataset_id: dataset_id
@@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)

View File

@@ -6,14 +6,15 @@ from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None
):
"""
Async Remove segment from index
@@ -49,6 +50,21 @@ def delete_segment_from_index_task(
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload files
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
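
Note: with the extra segment_ids parameter, callers now have to collect the segment IDs alongside the index node IDs. A hypothetical dispatch sketch, assuming segments, dataset and document are in scope and that the task is queued via Celery's delay as elsewhere; the real call sites live in the segment service layer.

index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]

delete_segment_from_index_task.delay(
    index_node_ids,
    dataset.id,
    document.id,
    segment_ids,  # newly required so attachment bindings and upload files can be cleaned up
    child_node_ids=None,
)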

View File

@@ -8,7 +8,7 @@ from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()

View File

@@ -4,9 +4,10 @@ import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str):
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str):
)
child_documents.append(child_document)
document.children = child_documents
multimodal_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
# save vector index
index_processor.load(dataset, [document])
index_processor.load(dataset, [document], multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))

View File

@@ -5,9 +5,10 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents)
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask:
# Assert: Verify the expected outcomes
# Verify index processor was called correctly
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes
@@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
@@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask:
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify index processing occurred with all completed segments
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
# Assert: Verify index processing occurred but with empty documents list
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
@@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask:
assert len(remaining_logs) == 0
# Verify index processing occurred normally
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify segments were enabled
@@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify only eligible segments were processed
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters

View File

@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask:
document.updated_at = fake.date_time_this_year()
document.doc_type = kwargs.get("doc_type", "text")
document.doc_metadata = kwargs.get("doc_metadata", {})
document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
document.doc_language = kwargs.get("doc_language", "en")
db_session_with_containers.add(document)
@@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask:
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Extract segment IDs for the task
segment_ids = [segment.id for segment in segments]
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None # Task should return None on success
@@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent dataset
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when dataset not found
@@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent document
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when document not found
@@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with disabled document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is disabled
@@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with archived document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is archived
@@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with incomplete indexing
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when indexing is not completed
@@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask:
fake = Faker()
# Test different document forms
document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
document_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in document_forms:
# Create test data for each document form
@@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None
@@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor to raise an exception
mock_processor = MagicMock()
@@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task - should not raise exception
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without raising exceptions
assert result is None # Task should return None even when exceptions occur
@@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, [])
# Verify the task completed successfully
assert result is None
@@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask:
# Create large number of segments
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
# Assert: Verify index processor was created but load was not called
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure(
@@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()

View File

@@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeConnectionError,
@@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
@pytest.fixture
def sample_embedding_result(self):
"""Create a sample TextEmbeddingResult for testing.
"""Create a sample EmbeddingResult for testing.
Returns:
TextEmbeddingResult: Mock embedding result with proper structure
EmbeddingResult: Mock embedding result with proper structure
"""
# Create normalized embedding vectors (dimension 1536 for ada-002)
embedding_vector = np.random.randn(1536)
@@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_vector],
usage=usage,
@@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
latency=0.6,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=new_embeddings,
usage=usage,
@@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[valid_vector.tolist(), nan_vector],
usage=usage,
@@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[nan_vector],
usage=usage,
@@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage,
)
result_3_small = TextEmbeddingResult(
result_3_small = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[normalized_3_small],
usage=usage,
@@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
latency=0.4,
)
result_openai = TextEmbeddingResult(
result_openai = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_openai],
usage=usage_openai,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
latency=0.7,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage_ada,
@@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
latency=0.4,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
latency=0.1,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
latency=1.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
latency=0.2,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
)
# Model returns embeddings for all texts
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,

View File

@@ -62,7 +62,7 @@ from core.indexing_runner import (
IndexingRunner,
)
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DatasetProcessRule
@@ -112,7 +112,7 @@ def create_mock_dataset_document(
document_id: str | None = None,
dataset_id: str | None = None,
tenant_id: str | None = None,
doc_form: str = IndexType.PARAGRAPH_INDEX,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
data_source_type: str = "upload_file",
doc_language: str = "English",
) -> Mock:
@@ -133,8 +133,8 @@ def create_mock_dataset_document(
Mock: A configured mock DatasetDocument object with all required attributes.
Example:
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
>>> assert doc.doc_form == IndexType.QA_INDEX
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
"""
doc = Mock(spec=DatasetDocument)
doc.id = document_id or str(uuid.uuid4())
@@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
return doc
@@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
doc = Mock(spec=DatasetDocument)
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
"""Test loading with parent-child index structure."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
# Add child documents
@@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.doc_language = "English"
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
@@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
# Mock current_user (Account) for _transform
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
# Setup db.session.query to return different results based on the model
def mock_query_side_effect(model):
mock_query_result = MagicMock()
if model.__name__ == "Dataset":
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
elif model.__name__ == "Account":
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
return mock_query_result
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
# Mock processor
mock_processor = MagicMock()
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
@@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.created_by = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
"""Test loading segments for parent-child index."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
# Add child documents
for doc in sample_documents:
@@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
tenant_id=tenant_id,
extract_settings=extract_settings,
tmp_processing_rule={"mode": "automatic", "rules": {}},
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

View File

@@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
def create_mock_model_instance():
"""Create a properly configured mock ModelInstance for reranking tests."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
class TestRerankModelRunner:
"""Unit tests for RerankModelRunner.
@@ -37,10 +49,23 @@ class TestRerankModelRunner:
- Metadata preservation and score injection
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for reranking."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
@pytest.fixture
@@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
- Parameters are forwarded to runner constructor
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner via factory
runner = RerankRunnerFactory.create_rerank_runner(
@@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
- String values are properly matched
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner using enum value
runner = RerankRunnerFactory.create_rerank_runner(
@@ -886,6 +911,13 @@ class TestRerankIntegration:
- Real-world usage scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_model_reranking_full_workflow(self):
"""Test complete model-based reranking workflow.
@@ -895,7 +927,7 @@ class TestRerankIntegration:
- Top results are returned correctly
"""
# Arrange: Create mock model and documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -951,7 +983,7 @@ class TestRerankIntegration:
- Normalization is consistent
"""
# Arrange: Create mock model with various scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
- Concurrent reranking scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_with_empty_metadata(self):
"""Test reranking when documents have empty metadata.
@@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
- Empty metadata documents are processed correctly
"""
# Arrange: Create documents with empty metadata
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
- Score comparison logic works at boundary
"""
# Arrange: Create mock with various scores including negatives
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
- No overflow or precision issues
"""
# Arrange: All documents with perfect scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
- Content encoding is preserved
"""
# Arrange: Documents with special characters
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
- Content is not truncated unexpectedly
"""
# Arrange: Documents with very long content
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
long_content = "This is a very long document. " * 1000 # ~30,000 characters
mock_rerank_result = RerankResult(
@@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
- All documents are processed correctly
"""
# Arrange: Create 100 documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
num_docs = 100
# Create rerank results for all documents
@@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
- Documents can still be ranked
"""
# Arrange: Empty query
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1325,6 +1364,13 @@ class TestRerankPerformance:
- Score calculation optimization
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_batch_processing(self):
"""Test that documents are processed in a single batch.
@@ -1334,7 +1380,7 @@ class TestRerankPerformance:
- Efficient batch processing
"""
# Arrange: Multiple documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
@@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
- Error propagation
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_model_invocation_error(self):
"""Test handling of model invocation errors.
@@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
- Error context is preserved
"""
# Arrange: Mock model that raises exception
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
documents = [
@@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
- Invalid results don't corrupt output
"""
# Arrange: Rerank result with invalid index
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[

View File

@@ -425,15 +425,15 @@ class TestRetrievalService:
# ==================== Vector Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic vector/semantic search functionality.
This test validates the core vector search flow:
1. Dataset is retrieved from database
2. embedding_search is called via ThreadPoolExecutor
2. _retrieve is called via ThreadPoolExecutor
3. Documents are added to shared all_documents list
4. Results are returned to caller
@@ -447,28 +447,28 @@ class TestRetrievalService:
# Set up the mock dataset that will be "retrieved" from database
mock_get_dataset.return_value = mock_dataset
# Create a side effect function that simulates embedding_search behavior
# In the real implementation, embedding_search:
# 1. Gets the dataset
# 2. Creates a Vector instance
# 3. Calls search_by_vector with embeddings
# 4. Extends all_documents with results
def side_effect_embedding_search(
# Create a side effect function that simulates _retrieve behavior
# _retrieve modifies the all_documents list in place
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
"""Simulate embedding_search adding documents to the shared list."""
all_documents.extend(sample_documents)
"""Simulate _retrieve adding documents to the shared list."""
if all_documents is not None:
all_documents.extend(sample_documents)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
# Define test parameters
query = "What is Python?" # Natural language query
@@ -481,7 +481,7 @@ class TestRetrievalService:
# 1. Check if query is empty (early return if so)
# 2. Get the dataset using _get_dataset
# 3. Create ThreadPoolExecutor
# 4. Submit embedding_search task
# 4. Submit _retrieve task
# 5. Wait for completion
# 6. Return all_documents list
results = RetrievalService.retrieve(
@@ -502,15 +502,13 @@ class TestRetrievalService:
# Verify documents maintain their scores (highest score first in sample_documents)
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
# Verify embedding_search was called exactly once
# Verify _retrieve was called exactly once
# This confirms the search method was invoked by ThreadPoolExecutor
mock_embedding_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_document_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with document ID filtering.
@@ -522,21 +520,25 @@ class TestRetrievalService:
mock_get_dataset.return_value = mock_dataset
filtered_docs = [sample_documents[0]]
def side_effect_embedding_search(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(filtered_docs)
if all_documents is not None:
all_documents.extend(filtered_docs)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
document_ids_filter = [sample_documents[0].metadata["document_id"]]
# Act
@@ -552,12 +554,12 @@ class TestRetrievalService:
assert len(results) == 1
assert results[0].metadata["doc_id"] == "doc1"
# Verify document_ids_filter was passed
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["document_ids_filter"] == document_ids_filter
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search when no results match the query.
@@ -567,8 +569,8 @@ class TestRetrievalService:
"""
# Arrange
mock_get_dataset.return_value = mock_dataset
# embedding_search doesn't add anything to all_documents
mock_embedding_search.side_effect = lambda *args, **kwargs: None
# _retrieve doesn't add anything to all_documents
mock_retrieve.side_effect = lambda *args, **kwargs: None
# Act
results = RetrievalService.retrieve(
@@ -583,9 +585,9 @@ class TestRetrievalService:
# ==================== Keyword Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic keyword search functionality.
@@ -597,12 +599,25 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
def side_effect_keyword_search(
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
def side_effect_retrieve(
flask_app,
retrieval_method,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(sample_documents)
if all_documents is not None:
all_documents.extend(sample_documents)
mock_keyword_search.side_effect = side_effect_keyword_search
mock_retrieve.side_effect = side_effect_retrieve
query = "Python programming"
top_k = 3
@@ -618,7 +633,7 @@ class TestRetrievalService:
# Assert
assert len(results) == 3
assert all(isinstance(doc, Document) for doc in results)
mock_keyword_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
@@ -1147,11 +1162,9 @@ class TestRetrievalService:
# ==================== Metadata Filtering Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_metadata_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with metadata-based document filtering.
@@ -1166,21 +1179,25 @@ class TestRetrievalService:
filtered_doc = sample_documents[0]
filtered_doc.metadata["category"] = "programming"
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(filtered_doc)
if all_documents is not None:
all_documents.append(filtered_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
# Act
results = RetrievalService.retrieve(
@@ -1243,9 +1260,9 @@ class TestRetrievalService:
# Assert
assert results == []
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that exceptions during retrieval are properly handled.
@@ -1256,22 +1273,26 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
# Make embedding_search add an exception to the exceptions list
# Make _retrieve add an exception to the exceptions list
def side_effect_with_exception(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
exceptions.append("Search failed")
if exceptions is not None:
exceptions.append("Search failed")
mock_embedding_search.side_effect = side_effect_with_exception
mock_retrieve.side_effect = side_effect_with_exception
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@@ -1286,9 +1307,9 @@ class TestRetrievalService:
# ==================== Score Threshold Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search with score threshold filtering.
@@ -1306,21 +1327,25 @@ class TestRetrievalService:
provider="dify",
)
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(high_score_doc)
if all_documents is not None:
all_documents.append(high_score_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
score_threshold = 0.8
@@ -1339,9 +1364,9 @@ class TestRetrievalService:
# ==================== Top-K Limiting Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that retrieval respects top_k parameter.
@@ -1362,22 +1387,26 @@ class TestRetrievalService:
for i in range(10)
]
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# Return only top_k documents
all_documents.extend(many_docs[:top_k])
if all_documents is not None:
all_documents.extend(many_docs[:top_k])
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
top_k = 3
@@ -1390,9 +1419,9 @@ class TestRetrievalService:
)
# Assert
# Verify top_k was passed to embedding_search
assert mock_embedding_search.called
call_kwargs = mock_embedding_search.call_args.kwargs
# Verify _retrieve was called
assert mock_retrieve.called
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["top_k"] == top_k
# Verify we got the right number of results
assert len(results) == top_k
@@ -1421,11 +1450,9 @@ class TestRetrievalService:
# ==================== Reranking Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_semantic_search_with_reranking(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test semantic search with reranking model.
@@ -1439,22 +1466,26 @@ class TestRetrievalService:
# Simulate reranking changing order
reranked_docs = list(reversed(sample_documents))
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# embedding_search handles reranking internally
all_documents.extend(reranked_docs)
# _retrieve handles reranking internally
if all_documents is not None:
all_documents.extend(reranked_docs)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
reranking_model = {
"reranking_provider_name": "cohere",
@@ -1473,7 +1504,7 @@ class TestRetrievalService:
# Assert
# For semantic search with reranking, reranking_model should be passed
assert len(results) == 3
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model

View File

@@ -8,7 +8,9 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
# Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols"
# The pattern intentionally excludes ! as per #11868 fix
("@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),

View File

@@ -808,6 +808,19 @@ UPLOAD_FILE_BATCH_LIMIT=5
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Maximum number of files allowed in a single chunk attachment, default 10.
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
# Maximum number of files allowed in an image batch upload operation, default 10.
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed image file size for attachments in megabytes, default 2.
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
# Timeout for downloading image attachments in seconds, default 60.
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
# ETL type, support: `dify`, `Unstructured`
# `dify` Dify's proprietary file extraction scheme
# `Unstructured` Unstructured.io file extraction scheme
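The four new variables above cap, for the multimodal knowledge base, the number of attachments per chunk, the image batch size, the per-image file size, and the image download timeout. A minimal sketch of how the first two limits could be applied when validating a chunk's attachments is shown below; the function, its name, and its error messages are hypothetical and only reuse the documented defaults:

```python
import os

# Read the limits with the defaults documented above.
SINGLE_CHUNK_ATTACHMENT_LIMIT = int(os.getenv("SINGLE_CHUNK_ATTACHMENT_LIMIT", "10"))
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT_MB = int(os.getenv("ATTACHMENT_IMAGE_FILE_SIZE_LIMIT", "2"))

def validate_chunk_attachments(attachment_sizes_bytes: list[int]) -> None:
    """Hypothetical check: raise ValueError if the attachments exceed the configured limits."""
    if len(attachment_sizes_bytes) > SINGLE_CHUNK_ATTACHMENT_LIMIT:
        raise ValueError(f"A chunk may carry at most {SINGLE_CHUNK_ATTACHMENT_LIMIT} attachments.")
    max_bytes = ATTACHMENT_IMAGE_FILE_SIZE_LIMIT_MB * 1024 * 1024
    for size in attachment_sizes_bytes:
        if size > max_bytes:
            raise ValueError(
                f"Each image attachment must be {ATTACHMENT_IMAGE_FILE_SIZE_LIMIT_MB} MB or smaller."
            )

# Example: three 1 MB images pass; a single 5 MB image would raise.
validate_chunk_attachments([1 * 1024 * 1024] * 3)
```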
@@ -1415,4 +1428,4 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1

View File

@@ -364,6 +364,10 @@ x-shared-env: &shared-api-worker-env
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-}
SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10}
IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10}
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2}
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60}
ETL_TYPE: ${ETL_TYPE:-dify}
UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}