From ff9cf6c7a415ee11bc971f7b9139ca0acb1709dd Mon Sep 17 00:00:00 2001 From: Faiz Khairi Date: Wed, 25 Mar 2026 20:33:24 +0800 Subject: [PATCH] refactor: replace dict with BedrockRetrievalSetting BaseModel in knowledge_service (#34080) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/datasets/external.py | 4 ++-- api/services/knowledge_service.py | 14 ++++++++++--- .../services/test_knowledge_service.py | 21 +++++++++++++------ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 86090bcd10..fc6896f123 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService def _build_dataset_detail_model(): @@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel): class BedrockRetrievalPayload(BaseModel): - retrieval_setting: dict[str, object] + retrieval_setting: "BedrockRetrievalSetting" query: str knowledge_id: str diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc..3d6fdb08a3 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,12 +1,20 @@ import boto3 +from pydantic import BaseModel, Field from configs import dify_config +class BedrockRetrievalSetting(BaseModel): + """Retrieval settings for Amazon Bedrock knowledge base queries.""" + + top_k: int | None = Field(default=None, description="Maximum number of results to retrieve") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + + class ExternalDatasetTestService: # this service is only for internal testing @staticmethod - def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str): # get bedrock client client = boto3.client( "bedrock-agent-runtime", @@ -20,7 +28,7 @@ class ExternalDatasetTestService: knowledgeBaseId=knowledge_id, retrievalConfiguration={ "vectorSearchConfiguration": { - "numberOfResults": retrieval_setting.get("top_k"), + "numberOfResults": retrieval_setting.top_k, "overrideSearchType": "HYBRID", } }, @@ -33,7 +41,7 @@ class ExternalDatasetTestService: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + if retrieval_result.get("score") < retrieval_setting.score_threshold: continue result = { "metadata": retrieval_result.get("metadata"), diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py index bc0caee071..53c243ad71 100644 --- a/api/tests/unit_tests/services/test_knowledge_service.py +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService class TestKnowledgeService: @@ -24,7 +24,7 @@ class TestKnowledgeService: mock_client = MagicMock() mock_boto_client.return_value = mock_client - retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5) query = "test query" knowledge_id = "kb-123" @@ -87,7 +87,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -104,7 +107,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -114,7 +120,7 @@ class TestKnowledgeService: with patch("services.knowledge_service.boto3.client") as mock_boto: mock_boto.side_effect = Exception("client init failed") with pytest.raises(Exception) as exc_info: - ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb") assert "client init failed" in str(exc_info.value) # ===== Edge Cases ===== @@ -139,7 +145,10 @@ class TestKnowledgeService: # Act # retrieval_setting missing "score_threshold" - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert len(result["records"]) == 1