diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 133c57d34d..57470dc977 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_response_schema_models from extensions.ext_database import db from fields.base import ResponseModel -from libs.helper import to_timestamp -from libs.login import current_account_with_tenant, login_required +from libs.helper import dump_response, to_timestamp +from libs.login import login_required +from models import Account from models.dataset import Dataset from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache from . import console_ns -from .wraps import account_initialization_required, edit_permission_required, setup_required +from .wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) class ApiKeyItem(ResponseModel): @@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel): data: list[ApiKeyItem] -register_schema_models(console_ns, ApiKeyItem, ApiKeyList) +register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList) def _get_resource(resource_id, tenant_id, resource_model): @@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource): token_prefix: str | None = None max_keys = 10 - def get(self, resource_id): + def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]: + return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id)) + + def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList: assert self.resource_id_field is not None, "resource_id_field must be set" - resource_id = str(resource_id) - _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) keys = db.session.scalars( @@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource): ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ) ).all() - return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") + return ApiKeyList.model_validate({"data": keys}, from_attributes=True) @edit_permission_required - def post(self, resource_id): + def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]: + return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201 + + def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken: assert self.resource_id_field is not None, "resource_id_field must be set" - resource_id = str(resource_id) - _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) current_key_count: int = ( db.session.scalar( @@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201 + return api_token class BaseApiKeyResource(Resource): @@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource): resource_model: type | None = None resource_id_field: str | None = None - def delete(self, resource_id: str, api_key_id: str): + def delete( + self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account + ) -> tuple[str, int]: + self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user) + return "", 204 + + def _delete_api_key( + self, + resource_id: str, + api_key_id: str, + current_tenant_id: str, + current_user: Account, + ) -> None: assert self.resource_id_field is not None, "resource_id_field must be set" - current_user, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: @@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource): db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() - return "", 204 - @console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): @@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]: """Get all API keys for an app""" - return super().get(resource_id) + return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id)) @console_ns.doc("create_app_api_key") @console_ns.doc(description="Create a new API key for an app") @console_ns.doc(params={"resource_id": "App ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id: UUID): + @with_current_tenant_id + @edit_permission_required + def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]: """Create a new API key for an app""" - return super().post(resource_id) + return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201 resource_type = ApiTokenType.APP resource_model = App @@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for an app") @console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key deleted successfully") - def delete(self, resource_id: UUID, api_key_id: UUID): + @with_current_user + @with_current_tenant_id + def delete( + self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID + ) -> tuple[str, int]: """Delete an API key for an app""" - return super().delete(str(resource_id), str(api_key_id)) + self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user) + return "", 204 resource_type = ApiTokenType.APP resource_model = App @@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) - def get(self, resource_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]: """Get all API keys for a dataset""" - return super().get(resource_id) + return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id)) @console_ns.doc("create_dataset_api_key") @console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") - def post(self, resource_id: UUID): + @with_current_tenant_id + @edit_permission_required + def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]: """Create a new API key for a dataset""" - return super().post(resource_id) + return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201 resource_type = ApiTokenType.DATASET resource_model = Dataset @@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource): @console_ns.doc(description="Delete an API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) @console_ns.response(204, "API key deleted successfully") - def delete(self, resource_id: UUID, api_key_id: UUID): + @with_current_user + @with_current_tenant_id + def delete( + self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID + ) -> tuple[str, int]: """Delete an API key for a dataset""" - return super().delete(str(resource_id), str(api_key_id)) + self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user) + return "", 204 resource_type = ApiTokenType.DATASET resource_model = Dataset diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 7544c4dbdc..33f6fb14ae 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -5,12 +5,12 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_response_schema_models, register_schema_models from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService from .. import console_ns from ..auth.error import ApiKeyAuthFailedError -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id class ApiKeyAuthBindingPayload(BaseModel): @@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) if data_source_api_key_bindings: return { @@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource): @account_initialization_required @is_admin_or_owner_required @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__]) - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): # The role of the current user in the table must be admin or owner - _, current_tenant_id = current_account_with_tenant() payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload) data = payload.model_dump() ApiKeyAuthService.validate_api_key_auth_args(data) @@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @account_initialization_required @is_admin_or_owner_required @console_ns.response(204, "Binding deleted successfully") - def delete(self, binding_id: UUID): + @with_current_tenant_id + def delete(self, current_tenant_id: str, binding_id: UUID): # The role of the current user in the table must be admin or owner - _, current_tenant_id = current_account_with_tenant() - ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 727428c8e7..7e48558977 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,9 +8,9 @@ from flask_restx import Resource from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import Account from models.model import OAuthProviderApp from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService @@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @oauth_server_client_id_required - def post(self, oauth_provider_app: OAuthProviderApp): - current_user, _ = current_account_with_tenant() - account = current_user - user_account_id = account.id - + def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account): + user_account_id = current_user.id code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) return jsonable_encoder( { diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py index 00309c25d6..5eb9f71e69 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -19,6 +19,7 @@ def test_get_api_key_auth_data_source( test_client_with_containers: FlaskClient, ) -> None: account, tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) binding = DataSourceApiKeyAuthBinding( tenant_id=tenant.id, category="api_key", @@ -26,8 +27,16 @@ def test_get_api_key_auth_data_source( credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), disabled=False, ) - db_session_with_containers.add(binding) + foreign_binding = DataSourceApiKeyAuthBinding( + tenant_id=foreign_tenant.id, + category="api_key", + provider="foreign_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add_all([binding, foreign_binding]) db_session_with_containers.commit() + authenticate_console_client(test_client_with_containers, foreign_account) response = test_client_with_containers.get( "/console/api/api-key-auth/data-source", @@ -60,20 +69,23 @@ def test_create_binding_successful( db_session_with_containers: Session, test_client_with_containers: FlaskClient, ) -> None: - account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + payload = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}} with ( patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), - patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, ): response = test_client_with_containers.post( "/console/api/api-key-auth/data-source/binding", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + json=payload, headers=authenticate_console_client(test_client_with_containers, account), ) assert response.status_code == 200 assert response.get_json() == {"result": "success"} + create_auth.assert_called_once_with(tenant_id, payload) def test_create_binding_failure( @@ -129,3 +141,35 @@ def test_delete_binding_successful( ) is None ) + + +def test_delete_binding_scopes_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_binding = DataSourceApiKeyAuthBinding( + tenant_id=foreign_tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(foreign_binding) + db_session_with_containers.commit() + foreign_binding_id = foreign_binding.id + authenticate_console_client(test_client_with_containers, foreign_account) + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{foreign_binding_id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == foreign_binding_id) + ) + is not None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py index e7852b8fe1..058f4e5fa3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py @@ -6,11 +6,13 @@ from unittest.mock import patch import pytest from flask.testing import FlaskClient +from sqlalchemy import select from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from libs.rsa import generate_key_pair -from models import Tenant +from models.api_based_extension import APIBasedExtension +from services.api_based_extension_service import APIBasedExtensionService from tests.test_containers_integration_tests.controllers.console.helpers import ( authenticate_console_client, create_console_account_and_tenant, @@ -27,13 +29,14 @@ def _masked_api_key(api_key: str) -> str: def api_extension_client( db_session_with_containers: Session, test_client_with_containers: FlaskClient, -) -> tuple[FlaskClient, dict[str, str], Tenant]: +) -> tuple[FlaskClient, dict[str, str], str]: account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id tenant.encrypt_public_key = generate_key_pair(tenant.id) db_session_with_containers.commit() headers = authenticate_console_client(test_client_with_containers, account) - return test_client_with_containers, headers, tenant + return test_client_with_containers, headers, tenant_id @pytest.fixture(autouse=True) @@ -44,9 +47,10 @@ def mock_api_based_extension_ping(): def test_create_response_masks_plaintext_api_key( - api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], + db_session_with_containers: Session, ) -> None: - client, headers, _ = api_extension_client + client, headers, tenant_id = api_extension_client api_key = "plain-secret-12345" response = client.post( @@ -62,10 +66,57 @@ def test_create_response_masks_plaintext_api_key( assert response.status_code == 201 assert response.json is not None assert response.json["api_key"] == _masked_api_key(api_key) + extension = db_session_with_containers.scalar( + select(APIBasedExtension).where(APIBasedExtension.id == response.json["id"]).limit(1) + ) + assert extension is not None + assert extension.tenant_id == tenant_id + + +def test_list_scopes_api_based_extensions_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + account_headers = authenticate_console_client(test_client_with_containers, account) + tenant.encrypt_public_key = generate_key_pair(tenant.id) + _foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_tenant_id = foreign_tenant.id + foreign_tenant.encrypt_public_key = generate_key_pair(foreign_tenant.id) + db_session_with_containers.commit() + + account_create_response = test_client_with_containers.post( + "/console/api/api-based-extension", + headers=account_headers, + json={ + "name": "Tenant API", + "api_endpoint": "https://tenant.example.com/hook", + "api_key": "tenant-secret-12345", + }, + ) + assert account_create_response.status_code == 201 + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=foreign_tenant_id, + name="Foreign API", + api_endpoint="https://foreign.example.com/hook", + api_key="foreign-secret-12345", + ) + ) + + response = test_client_with_containers.get( + "/console/api/api-based-extension", + headers=account_headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert [item["name"] for item in response.json] == ["Tenant API"] def test_update_response_masks_new_plaintext_api_key( - api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], ) -> None: client, headers, _ = api_extension_client new_api_key = "new-secret-67890" @@ -96,7 +147,7 @@ def test_update_response_masks_new_plaintext_api_key( def test_update_response_masks_existing_plaintext_api_key_when_hidden_value_is_submitted( - api_extension_client: tuple[FlaskClient, dict[str, str], Tenant], + api_extension_client: tuple[FlaskClient, dict[str, str], str], ) -> None: client, headers, _ = api_extension_client existing_api_key = "old-secret-12345" diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py b/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py index b55eaa1d58..8559501c89 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_apikey.py @@ -7,9 +7,11 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask.testing import FlaskClient -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.orm import Session +from models import Account +from models.account import AccountStatus, TenantAccountRole from models.enums import ApiTokenType from models.model import ApiToken, App, AppMode from tests.test_containers_integration_tests.controllers.console.helpers import ( @@ -58,6 +60,24 @@ class TestAppApiKeyListResource: assert data["token"].startswith("app-") assert data["id"] is not None + def test_create_api_key_persists_authenticated_tenant( + self, + setup_app: tuple[FlaskClient, dict[str, str], App], + db_session_with_containers: Session, + ) -> None: + client, headers, app = setup_app + tenant_id = app.tenant_id + + resp = client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers) + + assert resp.status_code == 201 + assert resp.json is not None + api_token = db_session_with_containers.scalar(select(ApiToken).where(ApiToken.id == resp.json["id"])) + assert api_token is not None + assert api_token.tenant_id == tenant_id + assert api_token.app_id == app.id + assert api_token.type == ApiTokenType.APP + def test_get_keys_after_create(self, setup_app: tuple[FlaskClient, dict[str, str], App]) -> None: client, headers, app = setup_app client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers) @@ -93,6 +113,21 @@ class TestAppApiKeyListResource: ) assert resp.status_code == 404 + def test_get_foreign_app_keys_not_found( + self, + setup_app: tuple[FlaskClient, dict[str, str], App], + db_session_with_containers: Session, + ) -> None: + client, headers, _ = setup_app + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_app = create_console_app( + db_session_with_containers, foreign_tenant.id, foreign_account.id, AppMode.CHAT + ) + + resp = client.get(f"/console/api/apps/{foreign_app.id}/api-keys", headers=headers) + + assert resp.status_code == 404 + class TestAppApiKeyResource: """Tests for DELETE /apps//api-keys/.""" @@ -139,16 +174,13 @@ class TestAppApiKeyResource: resource.resource_model = MagicMock() resource.resource_id_field = "app_id" - non_admin = MagicMock() - non_admin.is_admin_or_owner = False + non_admin = Account(name="Normal User", email="normal@example.com", status=AccountStatus.ACTIVE) + non_admin.id = "normal-user" + non_admin.role = TenantAccountRole.NORMAL with ( flask_app_with_containers.test_request_context("/"), - patch( - "controllers.console.apikey.current_account_with_tenant", - return_value=(non_admin, "tenant-id"), - ), patch("controllers.console.apikey._get_resource"), ): with pytest.raises(Forbidden): - BaseApiKeyResource.delete(resource, "rid", "kid") + BaseApiKeyResource.delete(resource, "rid", "kid", "tenant-id", non_admin) diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..51bbc33079 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch + +from controllers.console import console_ns +from controllers.console.auth.data_source_bearer_auth import ( + ApiKeyAuthDataSource, + ApiKeyAuthDataSourceBinding, + ApiKeyAuthDataSourceBindingDelete, +) + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _payload_patch(payload: dict): + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +def test_list_data_source_auth_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSource() + method = _unwrap(api.get) + binding = SimpleNamespace( + id="binding-1", + category="api_key", + provider="custom", + disabled=False, + created_at=datetime(2026, 1, 1, tzinfo=UTC), + updated_at=datetime(2026, 1, 2, tzinfo=UTC), + ) + + with patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", + return_value=[binding], + ) as get_provider_auth_list: + result = method(api, "tenant-1") + + get_provider_auth_list.assert_called_once_with("tenant-1") + assert result["sources"][0]["id"] == "binding-1" + assert result["sources"][0]["provider"] == "custom" + + +def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSourceBinding() + method = _unwrap(api.post) + payload = { + "category": "api_key", + "provider": "custom", + "credentials": {"auth_type": "api_key", "config": {"api_key": "secret"}}, + } + + with ( + _payload_patch(payload), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, + ): + result, status = method(api, "tenant-1") + + create_auth.assert_called_once_with("tenant-1", payload) + assert result == {"result": "success"} + assert status == 200 + + +def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: + api = ApiKeyAuthDataSourceBindingDelete() + method = _unwrap(api.delete) + + with patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" + ) as delete_provider_auth: + result, status = method(api, "tenant-1", "binding-1") + + delete_provider_auth.assert_called_once_with("tenant-1", "binding-1") + assert result == "" + assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..1508d7b50e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from unittest.mock import patch + +from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi +from models import Account +from models.account import AccountStatus, TenantAccountRole +from models.model import OAuthProviderApp + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _make_account() -> Account: + account = Account( + name="Test User", + email="test@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = "account-1" + account.role = TenantAccountRole.OWNER + return account + + +def _make_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon", + client_id="client-1", + client_secret="secret", + app_label={"en-US": "Test App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + + +def test_oauth_authorize_uses_injected_current_user() -> None: + api = OAuthServerUserAuthorizeApi() + method = _unwrap(api.post) + account = _make_account() + oauth_provider_app = _make_oauth_provider_app() + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="authorization-code", + ) as sign_oauth_authorization_code: + response = method(api, oauth_provider_app, account) + + sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1") + assert response == {"code": "authorization-code"} diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py new file mode 100644 index 0000000000..1517ff5ed8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from types import SimpleNamespace +from typing import cast +from unittest.mock import patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource +from models import Account +from models.account import AccountStatus, TenantAccountRole +from models.enums import ApiTokenType +from models.model import ApiToken, App + + +def _make_list_resource() -> BaseApiKeyListResource: + resource = BaseApiKeyListResource() + resource.resource_type = ApiTokenType.APP + resource.resource_model = App + resource.resource_id_field = "app_id" + resource.token_prefix = "app-" + return resource + + +def _make_key_resource() -> BaseApiKeyResource: + resource = BaseApiKeyResource() + resource.resource_type = ApiTokenType.APP + resource.resource_model = App + resource.resource_id_field = "app_id" + return resource + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account( + name="Test User", + email=f"{role.value}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = f"{role.value}-user" + account.role = role + return account + + +def test_list_api_keys_uses_injected_tenant_id() -> None: + resource = _make_list_resource() + api_key = SimpleNamespace( + id="key-1", + type=ApiTokenType.APP, + token="app-token", + last_used_at=None, + created_at=None, + ) + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + ): + db_mock.session.scalars.return_value.all.return_value = [api_key] + + result = resource.get("app-1", "tenant-1") + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + assert result == { + "data": [ + { + "id": "key-1", + "type": "app", + "token": "app-token", + "last_used_at": None, + "created_at": None, + } + ] + } + + +def test_create_api_key_uses_injected_tenant_id() -> None: + resource = _make_list_resource() + raw_post = cast( + Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]], + inspect.unwrap(BaseApiKeyListResource.post), + ) + + def add_api_token(api_token: ApiToken) -> None: + api_token.id = "key-1" + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"), + ): + db_mock.session.scalar.return_value = 0 + db_mock.session.add.side_effect = add_api_token + + result, status = raw_post(resource, "app-1", "tenant-1") + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + assert status == 201 + assert result["token"] == "app-generated-token" + api_token = db_mock.session.add.call_args.args[0] + assert api_token.app_id == "app-1" + assert api_token.tenant_id == "tenant-1" + assert api_token.type == ApiTokenType.APP + db_mock.session.commit.assert_called_once() + + +def test_delete_api_key_rejects_non_admin_account() -> None: + resource = _make_key_resource() + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + ): + with pytest.raises(Forbidden): + resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL)) + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + db_mock.session.scalar.assert_not_called() + + +def test_delete_api_key_uses_injected_user_and_tenant() -> None: + resource = _make_key_resource() + api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP) + + with ( + patch("controllers.console.apikey._get_resource") as get_resource, + patch("controllers.console.apikey.db") as db_mock, + patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache, + ): + db_mock.session.scalar.return_value = api_key + + result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER)) + + get_resource.assert_called_once_with("app-1", "tenant-1", App) + delete_cache.assert_called_once_with("app-token", ApiTokenType.APP) + db_mock.session.execute.assert_called_once() + db_mock.session.commit.assert_called_once() + assert result == "" + assert status == 204