refactor(api): migrate tenant/user via DI: apikey, extension, data_source_bearer, oauth_server (#36660)

This commit is contained in:
chariri
2026-05-26 17:22:35 +09:00
committed by GitHub
parent 59e99ee1ae
commit eed8d659d1
9 changed files with 500 additions and 64 deletions

View File

@@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_response_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
from .wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
class ApiKeyItem(ResponseModel):
@@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
def get(self, resource_id):
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
@@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
@edit_permission_required
def post(self, resource_id):
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count: int = (
db.session.scalar(
@@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
return api_token
class BaseApiKeyResource(Resource):
@@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None
resource_id_field: str | None = None
def delete(self, resource_id: str, api_key_id: str):
def delete(
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
) -> tuple[str, int]:
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
return "", 204
def _delete_api_key(
self,
resource_id: str,
api_key_id: str,
current_tenant_id: str,
current_user: Account,
) -> None:
assert self.resource_id_field is not None, "resource_id_field must be set"
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
@@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return "", 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
@@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
"""Get all API keys for an app"""
return super().get(resource_id)
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
"""Create a new API key for an app"""
return super().post(resource_id)
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
resource_type = ApiTokenType.APP
resource_model = App
@@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
"""Delete an API key for an app"""
return super().delete(str(resource_id), str(api_key_id))
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
resource_type = ApiTokenType.APP
resource_model = App
@@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
"""Get all API keys for a dataset"""
return super().get(resource_id)
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
"""Create a new API key for a dataset"""
return super().post(resource_id)
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
resource_type = ApiTokenType.DATASET
resource_model = Dataset
@@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
"""Delete an API key for a dataset"""
return super().delete(str(resource_id), str(api_key_id))
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
resource_type = ApiTokenType.DATASET
resource_model = Dataset

View File

@@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_response_schema_models, register_schema_models
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
class ApiKeyAuthBindingPayload(BaseModel):
@@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
@@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
@@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.response(204, "Binding deleted successfully")
def delete(self, binding_id: UUID):
@with_current_tenant_id
def delete(self, current_tenant_id: str, binding_id: UUID):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
return "", 204

View File

@@ -8,9 +8,9 @@ from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
user_account_id = current_user.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{

View File

@@ -19,6 +19,7 @@ def test_get_api_key_auth_data_source(
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant.id,
category="api_key",
@@ -26,8 +27,16 @@ def test_get_api_key_auth_data_source(
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(binding)
foreign_binding = DataSourceApiKeyAuthBinding(
tenant_id=foreign_tenant.id,
category="api_key",
provider="foreign_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add_all([binding, foreign_binding])
db_session_with_containers.commit()
authenticate_console_client(test_client_with_containers, foreign_account)
response = test_client_with_containers.get(
"/console/api/api-key-auth/data-source",
@@ -60,20 +69,23 @@ def test_create_binding_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
account, tenant = create_console_account_and_tenant(db_session_with_containers)
tenant_id = tenant.id
payload = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}
with (
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
):
response = test_client_with_containers.post(
"/console/api/api-key-auth/data-source/binding",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
json=payload,
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
create_auth.assert_called_once_with(tenant_id, payload)
def test_create_binding_failure(
@@ -129,3 +141,35 @@ def test_delete_binding_successful(
)
is None
)
def test_delete_binding_scopes_to_authenticated_tenant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_binding = DataSourceApiKeyAuthBinding(
tenant_id=foreign_tenant.id,
category="api_key",
provider="custom_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(foreign_binding)
db_session_with_containers.commit()
foreign_binding_id = foreign_binding.id
authenticate_console_client(test_client_with_containers, foreign_account)
response = test_client_with_containers.delete(
f"/console/api/api-key-auth/data-source/{foreign_binding_id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == foreign_binding_id)
)
is not None
)

View File

@@ -6,11 +6,13 @@ from unittest.mock import patch
import pytest
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from libs.rsa import generate_key_pair
from models import Tenant
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
@@ -27,13 +29,14 @@ def _masked_api_key(api_key: str) -> str:
def api_extension_client(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> tuple[FlaskClient, dict[str, str], Tenant]:
) -> tuple[FlaskClient, dict[str, str], str]:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
tenant_id = tenant.id
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db_session_with_containers.commit()
headers = authenticate_console_client(test_client_with_containers, account)
return test_client_with_containers, headers, tenant
return test_client_with_containers, headers, tenant_id
@pytest.fixture(autouse=True)
@@ -44,9 +47,10 @@ def mock_api_based_extension_ping():
def test_create_response_masks_plaintext_api_key(
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
api_extension_client: tuple[FlaskClient, dict[str, str], str],
db_session_with_containers: Session,
) -> None:
client, headers, _ = api_extension_client
client, headers, tenant_id = api_extension_client
api_key = "plain-secret-12345"
response = client.post(
@@ -62,10 +66,57 @@ def test_create_response_masks_plaintext_api_key(
assert response.status_code == 201
assert response.json is not None
assert response.json["api_key"] == _masked_api_key(api_key)
extension = db_session_with_containers.scalar(
select(APIBasedExtension).where(APIBasedExtension.id == response.json["id"]).limit(1)
)
assert extension is not None
assert extension.tenant_id == tenant_id
def test_list_scopes_api_based_extensions_to_authenticated_tenant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
account_headers = authenticate_console_client(test_client_with_containers, account)
tenant.encrypt_public_key = generate_key_pair(tenant.id)
_foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_tenant_id = foreign_tenant.id
foreign_tenant.encrypt_public_key = generate_key_pair(foreign_tenant.id)
db_session_with_containers.commit()
account_create_response = test_client_with_containers.post(
"/console/api/api-based-extension",
headers=account_headers,
json={
"name": "Tenant API",
"api_endpoint": "https://tenant.example.com/hook",
"api_key": "tenant-secret-12345",
},
)
assert account_create_response.status_code == 201
APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=foreign_tenant_id,
name="Foreign API",
api_endpoint="https://foreign.example.com/hook",
api_key="foreign-secret-12345",
)
)
response = test_client_with_containers.get(
"/console/api/api-based-extension",
headers=account_headers,
)
assert response.status_code == 200
assert response.json is not None
assert [item["name"] for item in response.json] == ["Tenant API"]
def test_update_response_masks_new_plaintext_api_key(
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
api_extension_client: tuple[FlaskClient, dict[str, str], str],
) -> None:
client, headers, _ = api_extension_client
new_api_key = "new-secret-67890"
@@ -96,7 +147,7 @@ def test_update_response_masks_new_plaintext_api_key(
def test_update_response_masks_existing_plaintext_api_key_when_hidden_value_is_submitted(
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
api_extension_client: tuple[FlaskClient, dict[str, str], str],
) -> None:
client, headers, _ = api_extension_client
existing_api_key = "old-secret-12345"

View File

@@ -7,9 +7,11 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy import delete
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.enums import ApiTokenType
from models.model import ApiToken, App, AppMode
from tests.test_containers_integration_tests.controllers.console.helpers import (
@@ -58,6 +60,24 @@ class TestAppApiKeyListResource:
assert data["token"].startswith("app-")
assert data["id"] is not None
def test_create_api_key_persists_authenticated_tenant(
self,
setup_app: tuple[FlaskClient, dict[str, str], App],
db_session_with_containers: Session,
) -> None:
client, headers, app = setup_app
tenant_id = app.tenant_id
resp = client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers)
assert resp.status_code == 201
assert resp.json is not None
api_token = db_session_with_containers.scalar(select(ApiToken).where(ApiToken.id == resp.json["id"]))
assert api_token is not None
assert api_token.tenant_id == tenant_id
assert api_token.app_id == app.id
assert api_token.type == ApiTokenType.APP
def test_get_keys_after_create(self, setup_app: tuple[FlaskClient, dict[str, str], App]) -> None:
client, headers, app = setup_app
client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers)
@@ -93,6 +113,21 @@ class TestAppApiKeyListResource:
)
assert resp.status_code == 404
def test_get_foreign_app_keys_not_found(
self,
setup_app: tuple[FlaskClient, dict[str, str], App],
db_session_with_containers: Session,
) -> None:
client, headers, _ = setup_app
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_app = create_console_app(
db_session_with_containers, foreign_tenant.id, foreign_account.id, AppMode.CHAT
)
resp = client.get(f"/console/api/apps/{foreign_app.id}/api-keys", headers=headers)
assert resp.status_code == 404
class TestAppApiKeyResource:
"""Tests for DELETE /apps/<resource_id>/api-keys/<api_key_id>."""
@@ -139,16 +174,13 @@ class TestAppApiKeyResource:
resource.resource_model = MagicMock()
resource.resource_id_field = "app_id"
non_admin = MagicMock()
non_admin.is_admin_or_owner = False
non_admin = Account(name="Normal User", email="normal@example.com", status=AccountStatus.ACTIVE)
non_admin.id = "normal-user"
non_admin.role = TenantAccountRole.NORMAL
with (
flask_app_with_containers.test_request_context("/"),
patch(
"controllers.console.apikey.current_account_with_tenant",
return_value=(non_admin, "tenant-id"),
),
patch("controllers.console.apikey._get_resource"),
):
with pytest.raises(Forbidden):
BaseApiKeyResource.delete(resource, "rid", "kid")
BaseApiKeyResource.delete(resource, "rid", "kid", "tenant-id", non_admin)

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
from controllers.console import console_ns
from controllers.console.auth.data_source_bearer_auth import (
ApiKeyAuthDataSource,
ApiKeyAuthDataSourceBinding,
ApiKeyAuthDataSourceBindingDelete,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _payload_patch(payload: dict):
return patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSource()
method = _unwrap(api.get)
binding = SimpleNamespace(
id="binding-1",
category="api_key",
provider="custom",
disabled=False,
created_at=datetime(2026, 1, 1, tzinfo=UTC),
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
return_value=[binding],
) as get_provider_auth_list:
result = method(api, "tenant-1")
get_provider_auth_list.assert_called_once_with("tenant-1")
assert result["sources"][0]["id"] == "binding-1"
assert result["sources"][0]["provider"] == "custom"
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBinding()
method = _unwrap(api.post)
payload = {
"category": "api_key",
"provider": "custom",
"credentials": {"auth_type": "api_key", "config": {"api_key": "secret"}},
}
with (
_payload_patch(payload),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
):
result, status = method(api, "tenant-1")
create_auth.assert_called_once_with("tenant-1", payload)
assert result == {"result": "success"}
assert status == 200
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBindingDelete()
method = _unwrap(api.delete)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
) as delete_provider_auth:
result, status = method(api, "tenant-1", "binding-1")
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
assert result == ""
assert status == 204

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from unittest.mock import patch
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.model import OAuthProviderApp
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_account() -> Account:
account = Account(
name="Test User",
email="test@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "account-1"
account.role = TenantAccountRole.OWNER
return account
def _make_oauth_provider_app() -> OAuthProviderApp:
return OAuthProviderApp(
app_icon="icon",
client_id="client-1",
client_secret="secret",
app_label={"en-US": "Test App"},
redirect_uris=["https://example.com/callback"],
scope="read",
)
def test_oauth_authorize_uses_injected_current_user() -> None:
api = OAuthServerUserAuthorizeApi()
method = _unwrap(api.post)
account = _make_account()
oauth_provider_app = _make_oauth_provider_app()
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="authorization-code",
) as sign_oauth_authorization_code:
response = method(api, oauth_provider_app, account)
sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1")
assert response == {"code": "authorization-code"}

View File

@@ -0,0 +1,141 @@
from __future__ import annotations
import inspect
from collections.abc import Callable
from types import SimpleNamespace
from typing import cast
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.enums import ApiTokenType
from models.model import ApiToken, App
def _make_list_resource() -> BaseApiKeyListResource:
resource = BaseApiKeyListResource()
resource.resource_type = ApiTokenType.APP
resource.resource_model = App
resource.resource_id_field = "app_id"
resource.token_prefix = "app-"
return resource
def _make_key_resource() -> BaseApiKeyResource:
resource = BaseApiKeyResource()
resource.resource_type = ApiTokenType.APP
resource.resource_model = App
resource.resource_id_field = "app_id"
return resource
def _make_account(role: TenantAccountRole) -> Account:
account = Account(
name="Test User",
email=f"{role.value}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = f"{role.value}-user"
account.role = role
return account
def test_list_api_keys_uses_injected_tenant_id() -> None:
resource = _make_list_resource()
api_key = SimpleNamespace(
id="key-1",
type=ApiTokenType.APP,
token="app-token",
last_used_at=None,
created_at=None,
)
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
):
db_mock.session.scalars.return_value.all.return_value = [api_key]
result = resource.get("app-1", "tenant-1")
get_resource.assert_called_once_with("app-1", "tenant-1", App)
assert result == {
"data": [
{
"id": "key-1",
"type": "app",
"token": "app-token",
"last_used_at": None,
"created_at": None,
}
]
}
def test_create_api_key_uses_injected_tenant_id() -> None:
resource = _make_list_resource()
raw_post = cast(
Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]],
inspect.unwrap(BaseApiKeyListResource.post),
)
def add_api_token(api_token: ApiToken) -> None:
api_token.id = "key-1"
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"),
):
db_mock.session.scalar.return_value = 0
db_mock.session.add.side_effect = add_api_token
result, status = raw_post(resource, "app-1", "tenant-1")
get_resource.assert_called_once_with("app-1", "tenant-1", App)
assert status == 201
assert result["token"] == "app-generated-token"
api_token = db_mock.session.add.call_args.args[0]
assert api_token.app_id == "app-1"
assert api_token.tenant_id == "tenant-1"
assert api_token.type == ApiTokenType.APP
db_mock.session.commit.assert_called_once()
def test_delete_api_key_rejects_non_admin_account() -> None:
resource = _make_key_resource()
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
):
with pytest.raises(Forbidden):
resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL))
get_resource.assert_called_once_with("app-1", "tenant-1", App)
db_mock.session.scalar.assert_not_called()
def test_delete_api_key_uses_injected_user_and_tenant() -> None:
resource = _make_key_resource()
api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP)
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache,
):
db_mock.session.scalar.return_value = api_key
result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER))
get_resource.assert_called_once_with("app-1", "tenant-1", App)
delete_cache.assert_called_once_with("app-token", ApiTokenType.APP)
db_mock.session.execute.assert_called_once()
db_mock.session.commit.assert_called_once()
assert result == ""
assert status == 204