refactor: replace dict with BedrockRetrievalSetting BaseModel in knowledge_service (#34080)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Faiz Khairi
2026-03-25 20:33:24 +08:00
committed by GitHub
parent 56593f20b0
commit ff9cf6c7a4
3 changed files with 28 additions and 11 deletions

View File

@@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
def _build_dataset_detail_model():
@@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel):
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
retrieval_setting: "BedrockRetrievalSetting"
query: str
knowledge_id: str

View File

@@ -1,12 +1,20 @@
import boto3
from pydantic import BaseModel, Field
from configs import dify_config
class BedrockRetrievalSetting(BaseModel):
    """Typed retrieval options for Amazon Bedrock knowledge base queries.

    Attributes:
        top_k: Upper bound on the number of results to retrieve; ``None``
            when the caller does not supply one.
        score_threshold: Lowest relevance score a result may carry;
            defaults to ``0.0``.
    """

    top_k: int | None = Field(
        default=None,
        description="Maximum number of results to retrieve",
    )
    score_threshold: float = Field(
        default=0.0,
        description="Minimum relevance score threshold",
    )
class ExternalDatasetTestService:
# this service is only for internal testing
@staticmethod
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str):
# get bedrock client
client = boto3.client(
"bedrock-agent-runtime",
@@ -20,7 +28,7 @@ class ExternalDatasetTestService:
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"numberOfResults": retrieval_setting.top_k,
"overrideSearchType": "HYBRID",
}
},
@@ -33,7 +41,7 @@ class ExternalDatasetTestService:
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
# filter out results with score less than threshold
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
if retrieval_result.get("score") < retrieval_setting.score_threshold:
continue
result = {
"metadata": retrieval_result.get("metadata"),

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from services.knowledge_service import ExternalDatasetTestService
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
class TestKnowledgeService:
@@ -24,7 +24,7 @@ class TestKnowledgeService:
mock_client = MagicMock()
mock_boto_client.return_value = mock_client
retrieval_setting = {"top_k": 4, "score_threshold": 0.5}
retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5)
query = "test query"
knowledge_id = "kb-123"
@@ -87,7 +87,10 @@ class TestKnowledgeService:
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []}
# Act
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
result = cast(
dict[str, Any],
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
)
# Assert
assert result["records"] == []
@@ -104,7 +107,10 @@ class TestKnowledgeService:
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}}
# Act
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
result = cast(
dict[str, Any],
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
)
# Assert
assert result["records"] == []
@@ -114,7 +120,7 @@ class TestKnowledgeService:
with patch("services.knowledge_service.boto3.client") as mock_boto:
mock_boto.side_effect = Exception("client init failed")
with pytest.raises(Exception) as exc_info:
ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb")
assert "client init failed" in str(exc_info.value)
# ===== Edge Cases =====
@@ -139,7 +145,10 @@ class TestKnowledgeService:
# Act
# retrieval_setting omits "score_threshold", so the default of 0.0 applies
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
result = cast(
dict[str, Any],
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
)
# Assert
assert len(result["records"]) == 1