refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533)

This commit is contained in:
chariri
2026-05-27 15:51:42 +09:00
committed by GitHub
parent a8d380bcaf
commit b034449a0c
13 changed files with 732 additions and 441 deletions

View File

@@ -24,12 +24,53 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
from models.account import Account
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
# ---------------------------------------------------------------------------
# HitTestingPayload Model Tests
# ---------------------------------------------------------------------------
def hit_testing_record() -> dict[str, object]:
    """Build one hit-testing retrieval record for use as test input.

    The record holds a fully populated ``segment`` (with its nested
    ``document``) and a ``score``. The list-like fields ``keywords``,
    ``child_chunks`` and ``files`` are deliberately ``None`` so callers can
    exercise handling of nullable list fields.

    Returns:
        A newly built dict on every call, so a test may mutate its copy
        without affecting other tests.
    """
    return {
        "segment": {
            "id": "segment-1",
            "position": 1,
            "document_id": "document-1",
            "content": "Chunk text",
            "sign_content": "Chunk text",
            "answer": None,
            "word_count": 2,
            "tokens": 3,
            "keywords": None,
            "index_node_id": None,
            "index_node_hash": None,
            "hit_count": 0,
            "enabled": True,
            "disabled_at": None,
            "disabled_by": None,
            "status": "completed",
            "created_by": "account-1",
            "created_at": 1_700_000_000,
            "indexing_at": None,
            "completed_at": None,
            "error": None,
            "stopped_at": None,
            "document": {
                "id": "document-1",
                "data_source_type": "upload_file",
                "name": "guide.md",
                "doc_type": None,
                "doc_metadata": None,
            },
        },
        "child_chunks": None,
        "files": None,
        "score": 0.9,
    }
class TestHitTestingPayload:
"""Test suite for HitTestingPayload Pydantic model."""
@@ -48,7 +89,7 @@ class TestHitTestingPayload:
}
payload = HitTestingPayload(
query="test query",
retrieval_model=retrieval_model_data,
retrieval_model=RetrievalModel.model_validate(retrieval_model_data),
external_retrieval_model={"provider": "openai"},
attachment_ids=["att_1", "att_2"],
)
@@ -68,6 +109,12 @@ class TestHitTestingPayload:
payload = HitTestingPayload(query="x" * 250)
assert len(payload.query) == 250
def test_payload_ignores_unknown_fields_for_compatibility(self):
    """Top-level fields outside the documented schema remain ignored as before."""
    payload = HitTestingPayload.model_validate({"query": "test query", "top_k": 3})
    # The extra top-level `top_k` key must not appear in the dumped payload;
    # only the `query` field is expected to remain.
    assert payload.model_dump(exclude_none=True) == {"query": "test query"}
# ---------------------------------------------------------------------------
# HitTestingApi Tests
@@ -80,8 +127,11 @@ class TestHitTestingPayload:
class TestHitTestingApiPost:
"""Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator."""
@staticmethod
def _dataset(dataset_id: str, tenant_id: str) -> Dataset:
    """Construct a ``Dataset`` with the given id and tenant id.

    ``name`` and ``created_by`` are fixed placeholder values
    ("Dataset" and "account-1").
    """
    return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1")
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
@@ -90,7 +140,6 @@ class TestHitTestingApiPost:
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app: Flask,
):
@@ -98,15 +147,13 @@ class TestHitTestingApiPost:
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset = self._dataset(dataset_id, tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
mock_ns.payload = {"query": "test query"}
@@ -115,11 +162,10 @@ class TestHitTestingApiPost:
# Skip billing decorator via __wrapped__
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
assert response["query"] == "test query"
assert response["query"] == {"content": "test query"}
mock_hit_svc.retrieve.assert_called_once()
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
@@ -128,7 +174,6 @@ class TestHitTestingApiPost:
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app: Flask,
):
@@ -136,8 +181,7 @@ class TestHitTestingApiPost:
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset = self._dataset(dataset_id, tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
@@ -152,7 +196,6 @@ class TestHitTestingApiPost:
mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
mock_ns.payload = {
"query": "complex query",
@@ -164,7 +207,7 @@ class TestHitTestingApiPost:
api = HitTestingApi()
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
assert response["query"] == "complex query"
assert response["query"] == {"content": "complex query"}
call_kwargs = mock_hit_svc.retrieve.call_args
# retrieval_model is serialized via model_dump, verify key fields
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
@@ -173,7 +216,6 @@ class TestHitTestingApiPost:
assert passed_retrieval_model["top_k"] == 10
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
@@ -182,7 +224,6 @@ class TestHitTestingApiPost:
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app: Flask,
):
@@ -190,14 +231,12 @@ class TestHitTestingApiPost:
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset = self._dataset(dataset_id, tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
metadata_filtering_conditions = {
"logical_operator": "and",
@@ -229,7 +268,6 @@ class TestHitTestingApiPost:
assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
@@ -238,30 +276,23 @@ class TestHitTestingApiPost:
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_marshal,
mock_ns,
app: Flask,
):
"""Test service API prepares nullable list fields from marshalled records."""
"""Test service API prepares nullable list fields from retrieval records."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset = self._dataset(dataset_id, tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]}
mock_hit_svc.retrieve.return_value = {
"query": {"content": "legacy query"},
"records": [hit_testing_record()],
}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = [
{
"segment": {"id": "segment-1", "keywords": None},
"child_chunks": None,
"files": None,
"score": 0.9,
}
]
mock_ns.payload = {"query": "legacy query"}
@@ -269,15 +300,15 @@ class TestHitTestingApiPost:
api = HitTestingApi()
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
assert response["query"] == "legacy query"
assert response["records"] == [
{
"segment": {"id": "segment-1", "keywords": []},
"child_chunks": [],
"files": [],
"score": 0.9,
}
]
assert response["query"] == {"content": "legacy query"}
record = response["records"][0]
assert record["segment"]["id"] == "segment-1"
assert record["segment"]["keywords"] == []
assert record["child_chunks"] == []
assert record["files"] == []
assert record["score"] == 0.9
assert record["tsne_position"] is None
assert record["summary"] is None
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@@ -315,8 +346,7 @@ class TestHitTestingApiPost:
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
mock_dataset = Mock()
mock_dataset.id = dataset_id
mock_dataset = self._dataset(dataset_id, tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(