refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533)

This commit is contained in:
chariri
2026-05-27 15:51:42 +09:00
committed by GitHub
parent a8d380bcaf
commit b034449a0c
13 changed files with 732 additions and 441 deletions

View File

@@ -1,5 +1,5 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import PropertyMock, patch
import pytest
from flask import Flask
@@ -8,7 +8,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.hit_testing import HitTestingApi
from controllers.console.datasets.hit_testing_base import HitTestingPayload
from models.dataset import Dataset
def unwrap(func):
@@ -32,7 +32,48 @@ def dataset_id():
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
def hit_testing_record() -> dict[str, object]:
return {
"segment": {
"id": "segment-1",
"position": 1,
"document_id": "document-1",
"content": "Chunk text",
"sign_content": "Chunk text",
"answer": None,
"word_count": 2,
"tokens": 3,
"keywords": [],
"index_node_id": None,
"index_node_hash": None,
"hit_count": 0,
"enabled": True,
"disabled_at": None,
"disabled_by": None,
"status": "completed",
"created_by": "account-1",
"created_at": 1_700_000_000,
"indexing_at": None,
"completed_at": None,
"error": None,
"stopped_at": None,
"document": {
"id": "document-1",
"data_source_type": "upload_file",
"name": "guide.md",
"doc_type": None,
"doc_metadata": None,
},
},
"child_chunks": [],
"score": None,
"tsne_position": None,
"files": [],
"summary": None,
}
@pytest.fixture(autouse=True)
@@ -63,7 +104,6 @@ class TestHitTestingApi:
payload = {
"query": "what is vector search",
"top_k": 3,
}
with (
@@ -74,11 +114,6 @@ class TestHitTestingApi:
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
@@ -91,7 +126,7 @@ class TestHitTestingApi:
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": "what is vector search", "records": []},
return_value={"query": {"content": "what is vector search"}, "records": []},
),
):
result = method(api, dataset_id)
@@ -107,16 +142,7 @@ class TestHitTestingApi:
payload = {
"query": "what is vector search",
}
records = [
{
"segment": None,
"child_chunks": [],
"score": None,
"tsne_position": None,
"files": [],
"summary": None,
}
]
records = [hit_testing_record()]
with (
app.test_request_context("/"),
@@ -126,11 +152,6 @@ class TestHitTestingApi:
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
@@ -143,13 +164,16 @@ class TestHitTestingApi:
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": payload["query"], "records": records},
return_value={"query": {"content": payload["query"]}, "records": records},
),
):
result = method(api, dataset_id)
assert result["query"] == payload["query"]
assert result["records"] == records
assert result["query"] == {"content": payload["query"]}
assert result["records"][0]["segment"]["keywords"] == []
assert result["records"][0]["child_chunks"] == []
assert result["records"][0]["files"] == []
assert result["records"][0]["score"] is None
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
api = HitTestingApi()
@@ -192,11 +216,6 @@ class TestHitTestingApi:
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",

View File

@@ -22,6 +22,7 @@ from core.errors.error import (
)
from graphon.model_runtime.errors.invoke import InvokeError
from models.account import Account
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@@ -43,7 +44,45 @@ def patch_current_user(mocker, account):
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
def hit_testing_record() -> dict[str, object]:
return {
"segment": {
"id": "segment-1",
"position": 1,
"document_id": "document-1",
"content": "Chunk text",
"answer": None,
"word_count": 2,
"tokens": 3,
"keywords": None,
"index_node_id": None,
"index_node_hash": None,
"hit_count": 0,
"enabled": True,
"disabled_at": None,
"disabled_by": None,
"status": "completed",
"created_by": "account-1",
"created_at": 1_700_000_000,
"indexing_at": None,
"completed_at": None,
"error": None,
"stopped_at": None,
"document": {
"id": "document-1",
"data_source_type": "upload_file",
"name": "guide.md",
"doc_type": None,
"doc_metadata": None,
},
},
"child_chunks": None,
"files": None,
"score": 0.8,
}
class TestGetAndValidateDataset:
@@ -116,6 +155,13 @@ class TestParseArgs:
with pytest.raises(ValueError):
DatasetsHitTestingBase.parse_args(payload)
def test_parse_args_ignores_unknown_fields_for_compatibility(self):
payload = {"query": "hello", "top_k": 3}
result = DatasetsHitTestingBase.parse_args(payload)
assert result == {"query": "hello"}
class TestPerformHitTesting:
def test_success(self, dataset):
@@ -131,48 +177,42 @@ class TestPerformHitTesting:
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == "hello"
assert result["query"] == {"content": "hello"}
assert result["records"] == []
def test_success_prepares_nullable_list_fields(self, dataset):
response = {
"query": {"content": "hello"},
"records": [
{
"segment": {"id": "segment-1", "keywords": None},
"child_chunks": None,
"files": None,
"score": 0.8,
}
],
"records": [hit_testing_record()],
}
with patch.object(
HitTestingService,
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == {"content": "hello"}
record = result["records"][0]
assert record["segment"]["keywords"] == []
assert record["segment"]["sign_content"] is None
assert record["child_chunks"] == []
assert record["files"] == []
assert record["score"] == 0.8
assert record["tsne_position"] is None
assert record["summary"] is None
def test_invalid_query_response_raises_value_error(self, dataset):
with (
patch.object(
HitTestingService,
"retrieve",
return_value=response,
),
patch(
"controllers.console.datasets.hit_testing_base.marshal",
return_value=response["records"],
return_value={"query": "hello", "records": []},
),
pytest.raises(ValueError, match="Invalid hit testing query response"),
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == "hello"
assert result["records"] == [
{
"segment": {"id": "segment-1", "keywords": []},
"child_chunks": [],
"files": [],
"score": 0.8,
}
]
def test_invalid_query_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing query response"):
DatasetsHitTestingBase._extract_hit_testing_query("hello")
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_invalid_records_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing records response"):