diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 772bb9d0f1..b03d9b4a4c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,12 +1,16 @@ +from datetime import datetime + import flask_restx -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from flask_restx._http import HTTPStatus +from pydantic import field_validator from sqlalchemy import delete, func, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from extensions.ext_database import db -from libs.helper import TimestampField +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.enums import ApiTokenType @@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache from . import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required -api_key_fields = { - "id": fields.String, - "type": fields.String, - "token": fields.String, - "last_used_at": TimestampField, - "created_at": TimestampField, -} -api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -api_key_list_model = console_ns.model( - "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -) +class ApiKeyItem(ResponseModel): + id: str + type: str + token: str + last_used_at: int | None = None + created_at: int | None = None + + @field_validator("last_used_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ApiKeyList(ResponseModel): + data: list[ApiKeyItem] + + +register_schema_models(console_ns, ApiKeyItem, ApiKeyList) def _get_resource(resource_id, tenant_id, resource_model): @@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource): token_prefix: str | None = None max_keys = 10 - @marshal_with(api_key_list_model) def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) @@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource): ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ) ).all() - return {"items": keys} + return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") - @marshal_with(api_key_item_model) @edit_permission_required def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" @@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return api_token, 201 + return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201 class BaseApiKeyResource(Resource): @@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_app_api_keys") @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(200, "Success", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for an app""" return super().get(resource_id) @@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_app_api_key") @console_ns.doc(description="Create a new API key for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for an app""" @@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(200, "Success", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for a dataset""" return super().get(resource_id) @@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_dataset_api_key") @console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for a dataset""" diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f741107b87..e6316bfd62 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,9 @@ from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db @@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone from models import AccountStatus from services.account_service import RegisterService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ActivateCheckQuery(BaseModel): workspace_id: str | None = Field(default=None) @@ -39,8 +38,16 @@ class ActivatePayload(BaseModel): return timezone(value) -for model in (ActivateCheckQuery, ActivatePayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ActivationCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether token is valid") + data: dict | None = Field(default=None, description="Activation data if valid") + + +class ActivationResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse) @console_ns.route("/activate/check") @@ -51,13 +58,7 @@ class ActivateCheckApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "ActivationCheckResponse", - { - "is_valid": fields.Boolean(description="Whether token is valid"), - "data": fields.Raw(description="Activation data if valid"), - }, - ), + console_ns.models[ActivationCheckResponse.__name__], ) def get(self): args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -95,12 +96,7 @@ class ActivateApi(Resource): @console_ns.response( 200, "Account activated successfully", - console_ns.model( - "ActivationResponse", - { - "result": fields.String(description="Operation result"), - }, - ), + console_ns.models[ActivationResponse.__name__], ) @console_ns.response(400, "Already activated or invalid token") def post(self): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f23c7eb431..b2a905366a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -11,10 +11,7 @@ import services from configs import dify_config from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns -from controllers.console.apikey import ( - api_key_item_model, - api_key_list_model, -) +from controllers.console.apikey import ApiKeyItem, ApiKeyList from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.wraps import ( @@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") - @console_ns.response(200, "API keys retrieved successfully", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_key_list_model) def get(self): _, current_tenant_id = current_account_with_tenant() keys = db.session.scalars( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) ).all() - return {"items": keys} + return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") + @console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) + @console_ns.response(400, "Maximum keys exceeded") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @marshal_with(api_key_item_model) def post(self): _, current_tenant_id = current_account_with_tenant() @@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return api_token, 200 + return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200 @console_ns.route("/datasets/api-keys/") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 8555900f4e..94d6c17915 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi: method = unwrap(api.get) mock_key_1 = MagicMock(spec=ApiToken) + mock_key_1.id = "key-1" + mock_key_1.type = "dataset" + mock_key_1.token = "ds-abc" + mock_key_1.last_used_at = None + mock_key_1.created_at = None mock_key_2 = MagicMock(spec=ApiToken) + mock_key_2.id = "key-2" + mock_key_2.type = "dataset" + mock_key_2.token = "ds-def" + mock_key_2.last_used_at = None + mock_key_2.created_at = None with ( app.test_request_context("/"), @@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi: ): response = method(api) - assert "items" in response - assert response["items"] == [mock_key_1, mock_key_2] + assert "data" in response + assert len(response["data"]) == 2 + assert response["data"][0]["id"] == "key-1" + assert response["data"][0]["token"] == "ds-abc" + assert response["data"][1]["id"] == "key-2" + assert response["data"][1]["token"] == "ds-def" def test_post_create_api_key_success(self, app): api = DatasetApiKeyApi() method = unwrap(api.post) + mock_token = MagicMock() + mock_token.id = "new-key-id" + mock_token.last_used_at = None + mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + + mock_api_token_cls = MagicMock() + mock_api_token_cls.return_value = mock_token + mock_api_token_cls.generate_api_key.return_value = "dataset-abc123" + with ( app.test_request_context("/"), patch( @@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi: return_value=3, ), patch( - "controllers.console.datasets.datasets.ApiToken.generate_api_key", - return_value="dataset-abc123", + "controllers.console.datasets.datasets.ApiToken", + mock_api_token_cls, ), patch( "controllers.console.datasets.datasets.db.session.add", @@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi: response, status = method(api) assert status == 200 - assert isinstance(response, ApiToken) - assert response.token == "dataset-abc123" - assert response.type == "dataset" + assert isinstance(response, dict) + assert response["id"] == "new-key-id" + assert response["token"] == "dataset-abc123" + assert response["type"] == "dataset" + assert response["created_at"] is not None def test_post_exceed_max_keys(self, app): api = DatasetApiKeyApi()