Files
dify/api/controllers/console/datasets/hit_testing_base.py
Asuka Minato 05fe92a541 refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-08 15:31:19 +09:00

90 lines
3.2 KiB
Python

import logging
from typing import Any
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return dataset
@staticmethod
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logger.exception("Hit testing failed.")
raise InternalServerError(str(e))