From 2afa39cdcbb31e6cdaae6012386a9a6167141576 Mon Sep 17 00:00:00 2001 From: FFXN <31929997+FFXN@users.noreply.github.com> Date: Wed, 13 May 2026 15:31:38 +0800 Subject: [PATCH] fix: knowledge hit-testing render failed. (#36106) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../console/datasets/hit_testing_base.py | 19 ++-- api/services/hit_testing_service.py | 58 +++++++++++- .../console/datasets/test_hit_testing_base.py | 16 +++- .../service_api/dataset/test_hit_testing.py | 10 +-- .../test_hit_testing_service_dump_records.py | 88 +++++++++++++++++++ 5 files changed, 172 insertions(+), 19 deletions(-) create mode 100644 api/tests/unit_tests/services/test_hit_testing_service_dump_records.py diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 71ab1513ed..bb725a5f6c 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel): class DatasetsHitTestingBase: @staticmethod - def _normalize_hit_testing_query(query: Any) -> str: - """Return the user-visible query string from legacy and current response shapes.""" - if isinstance(query, str): - return query - + def _extract_hit_testing_query(query: Any) -> str: + """Return the query string from the service response shape.""" if isinstance(query, dict): content = query.get("content") if isinstance(content, str): @@ -52,15 +49,15 @@ class DatasetsHitTestingBase: raise ValueError("Invalid hit testing query response") @staticmethod - def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]: - """Coerce nullable collection fields into lists before response validation.""" + def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]: + """Ensure collection fields match the API schema before response validation.""" if not isinstance(records, list): - return [] + raise ValueError("Invalid hit testing records response") normalized_records: list[dict[str, Any]] = [] for record in records: if not isinstance(record, dict): - continue + raise ValueError("Invalid hit testing record response") normalized_record = dict(record) segment = normalized_record.get("segment") @@ -118,8 +115,8 @@ class DatasetsHitTestingBase: limit=10, ) return { - "query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")), - "records": DatasetsHitTestingBase._normalize_hit_testing_records( + "query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")), + "records": DatasetsHitTestingBase._prepare_hit_testing_records( marshal(response.get("records", []), hit_testing_record_fields) ), } diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 2e5987dd28..42c531ae48 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,6 +3,8 @@ import logging import time from typing import Any, TypedDict, cast +from sqlalchemy import select + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType @@ -13,6 +15,7 @@ from extensions.ext_database import db from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery +from models.dataset import Document as DatasetDocument from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) @@ -41,6 +44,59 @@ class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): class HitTestingService: + @staticmethod + def _dump_dataset_document(document: DatasetDocument) -> dict[str, Any]: + return { + "id": document.id, + "data_source_type": document.data_source_type, + "name": document.name, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + } + + @classmethod + def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]: + dumped_records = [record.model_dump() for record in records] + document_ids = { + segment.get("document_id") + for record in dumped_records + if isinstance(record, dict) + for segment in [record.get("segment")] + if isinstance(segment, dict) and segment.get("document_id") + } + if not document_ids: + return dumped_records + + documents = { + document.id: cls._dump_dataset_document(document) + for document in db.session.scalars( + select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) + ).all() + } + + records_with_documents: list[dict[str, Any]] = [] + missing_document_ids: set[str] = set() + for record in dumped_records: + segment = record.get("segment") + if not isinstance(segment, dict): + records_with_documents.append(record) + continue + + document_id = segment.get("document_id") + if document_id in documents: + segment["document"] = documents[document_id] + records_with_documents.append(record) + elif document_id: + missing_document_ids.add(document_id) + + if missing_document_ids: + logger.warning( + "Skipping hit-testing records with missing documents, document_ids=%s", + sorted(missing_document_ids), + ) + + return records_with_documents + @classmethod def retrieve( cls, @@ -174,7 +230,7 @@ class HitTestingService: "query": { "content": query, }, - "records": [record.model_dump() for record in records], + "records": cls._dump_retrieval_records(records), } @classmethod diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index d29b34beb2..77e9cfeb5b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -120,7 +120,7 @@ class TestParseArgs: class TestPerformHitTesting: def test_success(self, dataset): response = { - "query": "hello", + "query": {"content": "hello"}, "records": [], } @@ -134,7 +134,7 @@ class TestPerformHitTesting: assert result["query"] == "hello" assert result["records"] == [] - def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset): + def test_success_prepares_nullable_list_fields(self, dataset): response = { "query": {"content": "hello"}, "records": [ @@ -170,6 +170,18 @@ class TestPerformHitTesting: } ] + def test_invalid_query_response_raises_value_error(self): + with pytest.raises(ValueError, match="Invalid hit testing query response"): + DatasetsHitTestingBase._extract_hit_testing_query("hello") + + def test_invalid_records_response_raises_value_error(self): + with pytest.raises(ValueError, match="Invalid hit testing records response"): + DatasetsHitTestingBase._prepare_hit_testing_records({"records": []}) + + def test_invalid_record_response_raises_value_error(self): + with pytest.raises(ValueError, match="Invalid hit testing record response"): + DatasetsHitTestingBase._prepare_hit_testing_records(["record"]) + def test_index_not_initialized(self, dataset): with patch.object( HitTestingService, diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index a26cdf6563..c02416e5fe 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -103,7 +103,7 @@ class TestHitTestingApiPost: mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_hit_svc.retrieve.return_value = {"query": "test query", "records": []} + mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None mock_marshal.return_value = [] @@ -149,7 +149,7 @@ class TestHitTestingApiPost: "score_threshold": 0.8, } - mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []} + mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None mock_marshal.return_value = [] @@ -194,7 +194,7 @@ class TestHitTestingApiPost: mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []} + mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None mock_marshal.return_value = [] @@ -232,7 +232,7 @@ class TestHitTestingApiPost: @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) - def test_post_normalizes_legacy_query_and_nullable_list_fields( + def test_post_prepares_nullable_list_fields( self, mock_current_user, mock_dataset_svc, @@ -241,7 +241,7 @@ class TestHitTestingApiPost: mock_ns, app, ): - """Test service API normalizes legacy query shape and nullable list fields.""" + """Test service API prepares nullable list fields from marshalled records.""" dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py new file mode 100644 index 0000000000..f933c5e440 --- /dev/null +++ b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py @@ -0,0 +1,88 @@ +from unittest.mock import Mock, patch + +from services.hit_testing_service import HitTestingService + + +def _retrieval_record(payload: dict): + record = Mock() + record.model_dump.return_value = payload + return record + + +def _dataset_document( + document_id: str = "document-1", + name: str = "guide.md", + data_source_type: str = "upload_file", + doc_type: str | None = None, + doc_metadata: dict | None = None, +): + document = Mock() + document.id = document_id + document.name = name + document.data_source_type = data_source_type + document.doc_type = doc_type + document.doc_metadata = doc_metadata + return document + + +class TestHitTestingServiceDumpRecords: + def test_dump_dataset_document_returns_frontend_required_fields(self): + document = _dataset_document(doc_metadata={"source": "manual"}) + + assert HitTestingService._dump_dataset_document(document) == { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": {"source": "manual"}, + } + + def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self): + record = _retrieval_record({"segment": None, "score": 0.95}) + + assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}] + + def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self): + record_without_segment = _retrieval_record({"segment": None, "score": 0.95}) + record_with_document = _retrieval_record( + { + "segment": { + "id": "segment-1", + "document_id": "document-1", + }, + "score": 0.9, + } + ) + scalars_result = Mock() + scalars_result.all.return_value = [_dataset_document()] + + with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): + result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document]) + + assert result[0] == {"segment": None, "score": 0.95} + assert result[1]["segment"]["document"] == { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + } + + def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog): + record = _retrieval_record( + { + "segment": { + "id": "segment-1", + "document_id": "missing-document", + }, + "score": 0.95, + } + ) + scalars_result = Mock() + scalars_result.all.return_value = [] + + with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): + result = HitTestingService._dump_retrieval_records([record]) + + assert result == [] + assert "Skipping hit-testing records with missing documents" in caplog.text