mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 04:00:59 -04:00
refactor(api): migrate tenant/user via DI: apikey, extension, data_source_bearer, oauth_server (#36660)
This commit is contained in:
@@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
@@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
def get(self, resource_id):
|
||||
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
|
||||
|
||||
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
@@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
|
||||
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
|
||||
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
|
||||
|
||||
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
@@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
return api_token
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
def delete(
|
||||
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
|
||||
) -> tuple[str, int]:
|
||||
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
def _delete_api_key(
|
||||
self,
|
||||
resource_id: str,
|
||||
api_key_id: str,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> None:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
@@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
@@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
@@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
@@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
@@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@@ -8,9 +8,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
@@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
|
||||
user_account_id = current_user.id
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@@ -19,6 +19,7 @@ def test_get_api_key_auth_data_source(
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant.id,
|
||||
category="api_key",
|
||||
@@ -26,8 +27,16 @@ def test_get_api_key_auth_data_source(
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
foreign_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=foreign_tenant.id,
|
||||
category="api_key",
|
||||
provider="foreign_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add_all([binding, foreign_binding])
|
||||
db_session_with_containers.commit()
|
||||
authenticate_console_client(test_client_with_containers, foreign_account)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-key-auth/data-source",
|
||||
@@ -60,20 +69,23 @@ def test_create_binding_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
tenant_id = tenant.id
|
||||
payload = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
json=payload,
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
create_auth.assert_called_once_with(tenant_id, payload)
|
||||
|
||||
|
||||
def test_create_binding_failure(
|
||||
@@ -129,3 +141,35 @@ def test_delete_binding_successful(
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_delete_binding_scopes_to_authenticated_tenant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=foreign_tenant.id,
|
||||
category="api_key",
|
||||
provider="custom_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(foreign_binding)
|
||||
db_session_with_containers.commit()
|
||||
foreign_binding_id = foreign_binding.id
|
||||
authenticate_console_client(test_client_with_containers, foreign_account)
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/api-key-auth/data-source/{foreign_binding_id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert (
|
||||
db_session_with_containers.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == foreign_binding_id)
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
@@ -6,11 +6,13 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
@@ -27,13 +29,14 @@ def _masked_api_key(api_key: str) -> str:
|
||||
def api_extension_client(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> tuple[FlaskClient, dict[str, str], Tenant]:
|
||||
) -> tuple[FlaskClient, dict[str, str], str]:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
tenant_id = tenant.id
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
return test_client_with_containers, headers, tenant
|
||||
return test_client_with_containers, headers, tenant_id
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -44,9 +47,10 @@ def mock_api_based_extension_ping():
|
||||
|
||||
|
||||
def test_create_response_masks_plaintext_api_key(
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], str],
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
client, headers, _ = api_extension_client
|
||||
client, headers, tenant_id = api_extension_client
|
||||
api_key = "plain-secret-12345"
|
||||
|
||||
response = client.post(
|
||||
@@ -62,10 +66,57 @@ def test_create_response_masks_plaintext_api_key(
|
||||
assert response.status_code == 201
|
||||
assert response.json is not None
|
||||
assert response.json["api_key"] == _masked_api_key(api_key)
|
||||
extension = db_session_with_containers.scalar(
|
||||
select(APIBasedExtension).where(APIBasedExtension.id == response.json["id"]).limit(1)
|
||||
)
|
||||
assert extension is not None
|
||||
assert extension.tenant_id == tenant_id
|
||||
|
||||
|
||||
def test_list_scopes_api_based_extensions_to_authenticated_tenant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
account_headers = authenticate_console_client(test_client_with_containers, account)
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
_foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_tenant_id = foreign_tenant.id
|
||||
foreign_tenant.encrypt_public_key = generate_key_pair(foreign_tenant.id)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account_create_response = test_client_with_containers.post(
|
||||
"/console/api/api-based-extension",
|
||||
headers=account_headers,
|
||||
json={
|
||||
"name": "Tenant API",
|
||||
"api_endpoint": "https://tenant.example.com/hook",
|
||||
"api_key": "tenant-secret-12345",
|
||||
},
|
||||
)
|
||||
assert account_create_response.status_code == 201
|
||||
|
||||
APIBasedExtensionService.save(
|
||||
APIBasedExtension(
|
||||
tenant_id=foreign_tenant_id,
|
||||
name="Foreign API",
|
||||
api_endpoint="https://foreign.example.com/hook",
|
||||
api_key="foreign-secret-12345",
|
||||
)
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-based-extension",
|
||||
headers=account_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json is not None
|
||||
assert [item["name"] for item in response.json] == ["Tenant API"]
|
||||
|
||||
|
||||
def test_update_response_masks_new_plaintext_api_key(
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], str],
|
||||
) -> None:
|
||||
client, headers, _ = api_extension_client
|
||||
new_api_key = "new-secret-67890"
|
||||
@@ -96,7 +147,7 @@ def test_update_response_masks_new_plaintext_api_key(
|
||||
|
||||
|
||||
def test_update_response_masks_existing_plaintext_api_key_when_hidden_value_is_submitted(
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], Tenant],
|
||||
api_extension_client: tuple[FlaskClient, dict[str, str], str],
|
||||
) -> None:
|
||||
client, headers, _ = api_extension_client
|
||||
existing_api_key = "old-secret-12345"
|
||||
|
||||
@@ -7,9 +7,11 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App, AppMode
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
@@ -58,6 +60,24 @@ class TestAppApiKeyListResource:
|
||||
assert data["token"].startswith("app-")
|
||||
assert data["id"] is not None
|
||||
|
||||
def test_create_api_key_persists_authenticated_tenant(
|
||||
self,
|
||||
setup_app: tuple[FlaskClient, dict[str, str], App],
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
client, headers, app = setup_app
|
||||
tenant_id = app.tenant_id
|
||||
|
||||
resp = client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers)
|
||||
|
||||
assert resp.status_code == 201
|
||||
assert resp.json is not None
|
||||
api_token = db_session_with_containers.scalar(select(ApiToken).where(ApiToken.id == resp.json["id"]))
|
||||
assert api_token is not None
|
||||
assert api_token.tenant_id == tenant_id
|
||||
assert api_token.app_id == app.id
|
||||
assert api_token.type == ApiTokenType.APP
|
||||
|
||||
def test_get_keys_after_create(self, setup_app: tuple[FlaskClient, dict[str, str], App]) -> None:
|
||||
client, headers, app = setup_app
|
||||
client.post(f"/console/api/apps/{app.id}/api-keys", headers=headers)
|
||||
@@ -93,6 +113,21 @@ class TestAppApiKeyListResource:
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_get_foreign_app_keys_not_found(
|
||||
self,
|
||||
setup_app: tuple[FlaskClient, dict[str, str], App],
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
client, headers, _ = setup_app
|
||||
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_app = create_console_app(
|
||||
db_session_with_containers, foreign_tenant.id, foreign_account.id, AppMode.CHAT
|
||||
)
|
||||
|
||||
resp = client.get(f"/console/api/apps/{foreign_app.id}/api-keys", headers=headers)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestAppApiKeyResource:
|
||||
"""Tests for DELETE /apps/<resource_id>/api-keys/<api_key_id>."""
|
||||
@@ -139,16 +174,13 @@ class TestAppApiKeyResource:
|
||||
resource.resource_model = MagicMock()
|
||||
resource.resource_id_field = "app_id"
|
||||
|
||||
non_admin = MagicMock()
|
||||
non_admin.is_admin_or_owner = False
|
||||
non_admin = Account(name="Normal User", email="normal@example.com", status=AccountStatus.ACTIVE)
|
||||
non_admin.id = "normal-user"
|
||||
non_admin.role = TenantAccountRole.NORMAL
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.apikey.current_account_with_tenant",
|
||||
return_value=(non_admin, "tenant-id"),
|
||||
),
|
||||
patch("controllers.console.apikey._get_resource"),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
BaseApiKeyResource.delete(resource, "rid", "kid")
|
||||
BaseApiKeyResource.delete(resource, "rid", "kid", "tenant-id", non_admin)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _payload_patch(payload: dict):
|
||||
return patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
|
||||
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSource()
|
||||
method = _unwrap(api.get)
|
||||
binding = SimpleNamespace(
|
||||
id="binding-1",
|
||||
category="api_key",
|
||||
provider="custom",
|
||||
disabled=False,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
|
||||
return_value=[binding],
|
||||
) as get_provider_auth_list:
|
||||
result = method(api, "tenant-1")
|
||||
|
||||
get_provider_auth_list.assert_called_once_with("tenant-1")
|
||||
assert result["sources"][0]["id"] == "binding-1"
|
||||
assert result["sources"][0]["provider"] == "custom"
|
||||
|
||||
|
||||
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBinding()
|
||||
method = _unwrap(api.post)
|
||||
payload = {
|
||||
"category": "api_key",
|
||||
"provider": "custom",
|
||||
"credentials": {"auth_type": "api_key", "config": {"api_key": "secret"}},
|
||||
}
|
||||
|
||||
with (
|
||||
_payload_patch(payload),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
|
||||
):
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
create_auth.assert_called_once_with("tenant-1", payload)
|
||||
assert result == {"result": "success"}
|
||||
assert status == 200
|
||||
|
||||
|
||||
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBindingDelete()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
|
||||
) as delete_provider_auth:
|
||||
result, status = method(api, "tenant-1", "binding-1")
|
||||
|
||||
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
|
||||
assert result == ""
|
||||
assert status == 204
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "account-1"
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
def _make_oauth_provider_app() -> OAuthProviderApp:
|
||||
return OAuthProviderApp(
|
||||
app_icon="icon",
|
||||
client_id="client-1",
|
||||
client_secret="secret",
|
||||
app_label={"en-US": "Test App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
|
||||
|
||||
def test_oauth_authorize_uses_injected_current_user() -> None:
|
||||
api = OAuthServerUserAuthorizeApi()
|
||||
method = _unwrap(api.post)
|
||||
account = _make_account()
|
||||
oauth_provider_app = _make_oauth_provider_app()
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="authorization-code",
|
||||
) as sign_oauth_authorization_code:
|
||||
response = method(api, oauth_provider_app, account)
|
||||
|
||||
sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1")
|
||||
assert response == {"code": "authorization-code"}
|
||||
141
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
141
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
||||
def _make_list_resource() -> BaseApiKeyListResource:
|
||||
resource = BaseApiKeyListResource()
|
||||
resource.resource_type = ApiTokenType.APP
|
||||
resource.resource_model = App
|
||||
resource.resource_id_field = "app_id"
|
||||
resource.token_prefix = "app-"
|
||||
return resource
|
||||
|
||||
|
||||
def _make_key_resource() -> BaseApiKeyResource:
|
||||
resource = BaseApiKeyResource()
|
||||
resource.resource_type = ApiTokenType.APP
|
||||
resource.resource_model = App
|
||||
resource.resource_id_field = "app_id"
|
||||
return resource
|
||||
|
||||
|
||||
def _make_account(role: TenantAccountRole) -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email=f"{role.value}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = f"{role.value}-user"
|
||||
account.role = role
|
||||
return account
|
||||
|
||||
|
||||
def test_list_api_keys_uses_injected_tenant_id() -> None:
|
||||
resource = _make_list_resource()
|
||||
api_key = SimpleNamespace(
|
||||
id="key-1",
|
||||
type=ApiTokenType.APP,
|
||||
token="app-token",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
):
|
||||
db_mock.session.scalars.return_value.all.return_value = [api_key]
|
||||
|
||||
result = resource.get("app-1", "tenant-1")
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
assert result == {
|
||||
"data": [
|
||||
{
|
||||
"id": "key-1",
|
||||
"type": "app",
|
||||
"token": "app-token",
|
||||
"last_used_at": None,
|
||||
"created_at": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_create_api_key_uses_injected_tenant_id() -> None:
|
||||
resource = _make_list_resource()
|
||||
raw_post = cast(
|
||||
Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]],
|
||||
inspect.unwrap(BaseApiKeyListResource.post),
|
||||
)
|
||||
|
||||
def add_api_token(api_token: ApiToken) -> None:
|
||||
api_token.id = "key-1"
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"),
|
||||
):
|
||||
db_mock.session.scalar.return_value = 0
|
||||
db_mock.session.add.side_effect = add_api_token
|
||||
|
||||
result, status = raw_post(resource, "app-1", "tenant-1")
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
assert status == 201
|
||||
assert result["token"] == "app-generated-token"
|
||||
api_token = db_mock.session.add.call_args.args[0]
|
||||
assert api_token.app_id == "app-1"
|
||||
assert api_token.tenant_id == "tenant-1"
|
||||
assert api_token.type == ApiTokenType.APP
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_key_rejects_non_admin_account() -> None:
|
||||
resource = _make_key_resource()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL))
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
db_mock.session.scalar.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_api_key_uses_injected_user_and_tenant() -> None:
|
||||
resource = _make_key_resource()
|
||||
api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP)
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache,
|
||||
):
|
||||
db_mock.session.scalar.return_value = api_key
|
||||
|
||||
result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER))
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
delete_cache.assert_called_once_with("app-token", ApiTokenType.APP)
|
||||
db_mock.session.execute.assert_called_once()
|
||||
db_mock.session.commit.assert_called_once()
|
||||
assert result == ""
|
||||
assert status == 204
|
||||
Reference in New Issue
Block a user