diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..3d7cb6cc8d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,12 @@ +import logging import os +from collections.abc import Sequence from typing import Literal import httpx +from pydantic import TypeAdapter from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan @@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client from libs.helper import RateLimiter from models import Account, TenantAccountJoin, TenantAccountRole +logger = logging.getLogger(__name__) + + +class SubscriptionPlan(TypedDict): + """Tenant subscriptionplan information.""" + + plan: str + expiration_date: int + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") @@ -239,3 +252,39 @@ class BillingService: def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str): payload = {"account_id": account_id, "click_id": click_id} return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload) + + @classmethod + def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Bulk fetch billing subscription plan via billing API. + Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) + Returns: + Mapping of tenant_id -> {plan: str, expiration_date: int} + """ + results: dict[str, SubscriptionPlan] = {} + subscription_adapter = TypeAdapter(SubscriptionPlan) + + chunk_size = 200 + for i in range(0, len(tenant_ids), chunk_size): + chunk = tenant_ids[i : i + chunk_size] + try: + resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) + data = resp.get("data", {}) + + for tenant_id, plan in data.items(): + subscription_plan = subscription_adapter.validate_python(plan) + results[tenant_id] = subscription_plan + except Exception: + logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + continue + + return results + + @classmethod + def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]: + resp = cls._send_request("GET", "/subscription/cleanup/whitelist") + data = resp.get("data", []) + tenant_whitelist = [] + for item in data: + tenant_whitelist.append(item["tenant_id"]) + return tenant_whitelist diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 915aee3fa7..f50f744a75 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases: assert "Only team owner or team admin can perform this action" in str(exc_info.value) +class TestBillingServiceSubscriptionOperations: + """Unit tests for subscription operations in BillingService. + + Tests cover: + - Bulk plan retrieval with chunking + - Expired subscription cleanup whitelist retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_plan_bulk_with_empty_list(self, mock_send_request): + """Test bulk plan retrieval with empty tenant list.""" + # Arrange + tenant_ids = [] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + mock_send_request.assert_not_called() + + def test_get_plan_bulk_with_chunking(self, mock_send_request): + """Test bulk plan retrieval with more than 200 tenants (chunking logic).""" + # Arrange - 250 tenants to test chunking (chunk_size = 200) + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk: tenants 0-199 + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk: tenants 200-249 + second_chunk_response = { + "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)} + } + + mock_send_request.side_effect = [first_chunk_response, second_chunk_response] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 250 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert result["tenant-200"]["plan"] == "professional" + assert result["tenant-249"]["plan"] == "professional" + assert mock_send_request.call_count == 2 + + # Verify first chunk call + first_call = mock_send_request.call_args_list[0] + assert first_call[0][0] == "POST" + assert first_call[0][1] == "/subscription/plan/batch" + assert len(first_call[1]["json"]["tenant_ids"]) == 200 + + # Verify second chunk call + second_call = mock_send_request.call_args_list[1] + assert len(second_call[1]["json"]["tenant_ids"]) == 50 + + def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request): + """Test bulk plan retrieval when one batch fails but others succeed.""" + # Arrange - 250 tenants, second batch will fail + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk succeeds + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk fails - need to create a mock that raises when called + def side_effect_func(*args, **kwargs): + if mock_send_request.call_count == 1: + return first_chunk_response + else: + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only have data from first batch + assert len(result) == 200 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert "tenant-200" not in result + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request): + """Test bulk plan retrieval when all batches fail.""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # All chunks fail + def side_effect_func(*args, **kwargs): + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should return empty dict + assert result == {} + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request): + """Test bulk plan retrieval with exactly 200 tenants (boundary condition).""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(200)] + mock_send_request.return_value = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 200 + assert mock_send_request.call_count == 1 + + def test_get_plan_bulk_with_empty_data_response(self, mock_send_request): + """Test bulk plan retrieval with empty data in response.""" + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + mock_send_request.return_value = {"data": {}} + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): + """Test successful retrieval of expired subscription cleanup whitelist.""" + # Arrange + api_response = [ + { + "created_at": "2025-10-16T01:56:17", + "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", + "contact": "example@dify.ai", + "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5", + "expired_at": "2026-01-01T01:56:17", + "updated_at": "2025-10-16T01:56:17", + }, + { + "created_at": "2025-10-16T02:00:00", + "tenant_id": "tenant-2", + "contact": "test@example.com", + "id": "whitelist-id-2", + "expired_at": "2026-02-01T00:00:00", + "updated_at": "2025-10-16T02:00:00", + }, + { + "created_at": "2025-10-16T03:00:00", + "tenant_id": "tenant-3", + "contact": "another@example.com", + "id": "whitelist-id-3", + "expired_at": "2026-03-01T00:00:00", + "updated_at": "2025-10-16T03:00:00", + }, + ] + mock_send_request.return_value = {"data": api_response} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert - should return only tenant_ids + assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"] + assert len(result) == 3 + assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6" + assert result[1] == "tenant-2" + assert result[2] == "tenant-3" + mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist") + + def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request): + """Test retrieval of empty cleanup whitelist.""" + # Arrange + mock_send_request.return_value = {"data": []} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert + assert result == [] + assert len(result) == 0 + + class TestBillingServiceIntegrationScenarios: """Integration-style tests simulating real-world usage scenarios.