mirror of
https://github.com/langgenius/dify.git
synced 2026-05-16 16:00:26 -04:00
fix: knowledge hit-testing render failed. (#36106)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
@@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
raise ValueError("Invalid hit testing records response")
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
raise ValueError("Invalid hit testing record response")
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
@@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
|
||||
limit=10,
|
||||
)
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ import logging
|
||||
import time
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@@ -13,6 +15,7 @@ from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,6 +44,59 @@ class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
|
||||
|
||||
|
||||
class HitTestingService:
|
||||
@staticmethod
|
||||
def _dump_dataset_document(document: DatasetDocument) -> dict[str, Any]:
|
||||
return {
|
||||
"id": document.id,
|
||||
"data_source_type": document.data_source_type,
|
||||
"name": document.name,
|
||||
"doc_type": document.doc_type,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]:
|
||||
dumped_records = [record.model_dump() for record in records]
|
||||
document_ids = {
|
||||
segment.get("document_id")
|
||||
for record in dumped_records
|
||||
if isinstance(record, dict)
|
||||
for segment in [record.get("segment")]
|
||||
if isinstance(segment, dict) and segment.get("document_id")
|
||||
}
|
||||
if not document_ids:
|
||||
return dumped_records
|
||||
|
||||
documents = {
|
||||
document.id: cls._dump_dataset_document(document)
|
||||
for document in db.session.scalars(
|
||||
select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
|
||||
).all()
|
||||
}
|
||||
|
||||
records_with_documents: list[dict[str, Any]] = []
|
||||
missing_document_ids: set[str] = set()
|
||||
for record in dumped_records:
|
||||
segment = record.get("segment")
|
||||
if not isinstance(segment, dict):
|
||||
records_with_documents.append(record)
|
||||
continue
|
||||
|
||||
document_id = segment.get("document_id")
|
||||
if document_id in documents:
|
||||
segment["document"] = documents[document_id]
|
||||
records_with_documents.append(record)
|
||||
elif document_id:
|
||||
missing_document_ids.add(document_id)
|
||||
|
||||
if missing_document_ids:
|
||||
logger.warning(
|
||||
"Skipping hit-testing records with missing documents, document_ids=%s",
|
||||
sorted(missing_document_ids),
|
||||
)
|
||||
|
||||
return records_with_documents
|
||||
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
@@ -174,7 +230,7 @@ class HitTestingService:
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"records": [record.model_dump() for record in records],
|
||||
"records": cls._dump_retrieval_records(records),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -120,7 +120,7 @@ class TestParseArgs:
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
response = {
|
||||
"query": "hello",
|
||||
"query": {"content": "hello"},
|
||||
"records": [],
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ class TestPerformHitTesting:
|
||||
assert result["query"] == "hello"
|
||||
assert result["records"] == []
|
||||
|
||||
def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset):
|
||||
def test_success_prepares_nullable_list_fields(self, dataset):
|
||||
response = {
|
||||
"query": {"content": "hello"},
|
||||
"records": [
|
||||
@@ -170,6 +170,18 @@ class TestPerformHitTesting:
|
||||
}
|
||||
]
|
||||
|
||||
def test_invalid_query_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing query response"):
|
||||
DatasetsHitTestingBase._extract_hit_testing_query("hello")
|
||||
|
||||
def test_invalid_records_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing records response"):
|
||||
DatasetsHitTestingBase._prepare_hit_testing_records({"records": []})
|
||||
|
||||
def test_invalid_record_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing record response"):
|
||||
DatasetsHitTestingBase._prepare_hit_testing_records(["record"])
|
||||
|
||||
def test_index_not_initialized(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestHitTestingApiPost:
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "test query", "records": []}
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestHitTestingApiPost:
|
||||
"score_threshold": 0.8,
|
||||
}
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
@@ -194,7 +194,7 @@ class TestHitTestingApiPost:
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []}
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
@@ -232,7 +232,7 @@ class TestHitTestingApiPost:
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_normalizes_legacy_query_and_nullable_list_fields(
|
||||
def test_post_prepares_nullable_list_fields(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
@@ -241,7 +241,7 @@ class TestHitTestingApiPost:
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Test service API normalizes legacy query shape and nullable list fields."""
|
||||
"""Test service API prepares nullable list fields from marshalled records."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
def _retrieval_record(payload: dict):
|
||||
record = Mock()
|
||||
record.model_dump.return_value = payload
|
||||
return record
|
||||
|
||||
|
||||
def _dataset_document(
|
||||
document_id: str = "document-1",
|
||||
name: str = "guide.md",
|
||||
data_source_type: str = "upload_file",
|
||||
doc_type: str | None = None,
|
||||
doc_metadata: dict | None = None,
|
||||
):
|
||||
document = Mock()
|
||||
document.id = document_id
|
||||
document.name = name
|
||||
document.data_source_type = data_source_type
|
||||
document.doc_type = doc_type
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
|
||||
class TestHitTestingServiceDumpRecords:
|
||||
def test_dump_dataset_document_returns_frontend_required_fields(self):
|
||||
document = _dataset_document(doc_metadata={"source": "manual"})
|
||||
|
||||
assert HitTestingService._dump_dataset_document(document) == {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
"doc_type": None,
|
||||
"doc_metadata": {"source": "manual"},
|
||||
}
|
||||
|
||||
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
|
||||
record = _retrieval_record({"segment": None, "score": 0.95})
|
||||
|
||||
assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}]
|
||||
|
||||
def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self):
|
||||
record_without_segment = _retrieval_record({"segment": None, "score": 0.95})
|
||||
record_with_document = _retrieval_record(
|
||||
{
|
||||
"segment": {
|
||||
"id": "segment-1",
|
||||
"document_id": "document-1",
|
||||
},
|
||||
"score": 0.9,
|
||||
}
|
||||
)
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [_dataset_document()]
|
||||
|
||||
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
|
||||
result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document])
|
||||
|
||||
assert result[0] == {"segment": None, "score": 0.95}
|
||||
assert result[1]["segment"]["document"] == {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
"doc_type": None,
|
||||
"doc_metadata": None,
|
||||
}
|
||||
|
||||
def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog):
|
||||
record = _retrieval_record(
|
||||
{
|
||||
"segment": {
|
||||
"id": "segment-1",
|
||||
"document_id": "missing-document",
|
||||
},
|
||||
"score": 0.95,
|
||||
}
|
||||
)
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = []
|
||||
|
||||
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
|
||||
result = HitTestingService._dump_retrieval_records([record])
|
||||
|
||||
assert result == []
|
||||
assert "Skipping hit-testing records with missing documents" in caplog.text
|
||||
Reference in New Issue
Block a user