mirror of
https://github.com/langgenius/dify.git
synced 2026-03-27 11:02:18 -04:00
refactor: replace dict with BedrockRetrievalSetting BaseModel in knowledge_service (#34080)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
@@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class BedrockRetrievalPayload(BaseModel):
|
||||
retrieval_setting: dict[str, object]
|
||||
retrieval_setting: "BedrockRetrievalSetting"
|
||||
query: str
|
||||
knowledge_id: str
|
||||
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
import boto3
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class BedrockRetrievalSetting(BaseModel):
|
||||
"""Retrieval settings for Amazon Bedrock knowledge base queries."""
|
||||
|
||||
top_k: int | None = Field(default=None, description="Maximum number of results to retrieve")
|
||||
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
|
||||
|
||||
|
||||
class ExternalDatasetTestService:
|
||||
# this service is only for internal testing
|
||||
@staticmethod
|
||||
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
|
||||
def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str):
|
||||
# get bedrock client
|
||||
client = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
@@ -20,7 +28,7 @@ class ExternalDatasetTestService:
|
||||
knowledgeBaseId=knowledge_id,
|
||||
retrievalConfiguration={
|
||||
"vectorSearchConfiguration": {
|
||||
"numberOfResults": retrieval_setting.get("top_k"),
|
||||
"numberOfResults": retrieval_setting.top_k,
|
||||
"overrideSearchType": "HYBRID",
|
||||
}
|
||||
},
|
||||
@@ -33,7 +41,7 @@ class ExternalDatasetTestService:
|
||||
retrieval_results = response.get("retrievalResults")
|
||||
for retrieval_result in retrieval_results:
|
||||
# filter out results with score less than threshold
|
||||
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
|
||||
if retrieval_result.get("score") < retrieval_setting.score_threshold:
|
||||
continue
|
||||
result = {
|
||||
"metadata": retrieval_result.get("metadata"),
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
|
||||
|
||||
|
||||
class TestKnowledgeService:
|
||||
@@ -24,7 +24,7 @@ class TestKnowledgeService:
|
||||
mock_client = MagicMock()
|
||||
mock_boto_client.return_value = mock_client
|
||||
|
||||
retrieval_setting = {"top_k": 4, "score_threshold": 0.5}
|
||||
retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5)
|
||||
query = "test query"
|
||||
knowledge_id = "kb-123"
|
||||
|
||||
@@ -87,7 +87,10 @@ class TestKnowledgeService:
|
||||
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []}
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
@@ -104,7 +107,10 @@ class TestKnowledgeService:
|
||||
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}}
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
@@ -114,7 +120,7 @@ class TestKnowledgeService:
|
||||
with patch("services.knowledge_service.boto3.client") as mock_boto:
|
||||
mock_boto.side_effect = Exception("client init failed")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")
|
||||
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb")
|
||||
assert "client init failed" in str(exc_info.value)
|
||||
|
||||
# ===== Edge Cases =====
|
||||
@@ -139,7 +145,10 @@ class TestKnowledgeService:
|
||||
|
||||
# Act
|
||||
# retrieval_setting omits "score_threshold", so the model default (0.0) applies
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result["records"]) == 1
|
||||
|
||||
Reference in New Issue
Block a user