refactor: clean messages task

This commit is contained in:
hj24
2025-12-18 16:43:12 +08:00
parent 46c9a59a31
commit 45e2d4627f
8 changed files with 2225 additions and 77 deletions

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@@ -45,6 +47,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -1900,3 +1903,76 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-sandbox-messages", help="Clean expired sandbox messages.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration.",
)
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-after.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleteing")
def clean_expired_sandbox_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for sandbox tenants.
"""
if not dify_config.BILLING_ENABLED:
click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
return
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
stats = SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Messages found: {stats['total_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("Sandbox messages cleanup completed.", fg="green"))

View File

@@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_expired_sandbox_messages,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
@@ -54,6 +55,7 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_expired_sandbox_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -0,0 +1,33 @@
"""feat: add created_at id index to messages
Revision ID: 649d817a739e
Revises: 03ea244985ce
Create Date: 2025-12-18 16:39:33.090454
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '649d817a739e'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
# ### end Alembic commands ###

View File

@@ -965,6 +965,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))

View File

@@ -1,90 +1,54 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
@app.celery.task(queue="retention")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
"""
Clean expired messages from sandbox plan tenants.
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
This task uses SandboxMessagesCleanService to efficiently clean messages in batches.
"""
if not dify_config.BILLING_ENABLED:
click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
return
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
stats = SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
graceful_period=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Messages found: {stats['total_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise

View File

@@ -0,0 +1,488 @@
import datetime
import json
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
"""Lightweight message info containing only essential fields for cleaning."""
id: str
app_id: str
created_at: datetime.datetime
class SandboxMessagesCleanService:
"""
Service for cleaning expired messages from sandbox plan tenants.
"""
# Redis key prefix for tenant plan cache
PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
# Cache TTL: 10 minutes
PLAN_CACHE_TTL = 600
@classmethod
def clean_sandbox_messages_by_time_range(
cls,
start_from: datetime.datetime,
end_before: datetime.datetime,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Clean sandbox messages within a specific time range [start_from, end_before).
Args:
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Statistics about the cleaning operation
Raises:
ValueError: If start_from >= end_before
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
if graceful_period < 0:
raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")
logger.info("clean_messages: start_from=%s, end_before=%s, batch_size=%s", start_from, end_before, batch_size)
return cls._clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def clean_sandbox_messages_by_days(
cls,
days: int = 30,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Clean sandbox messages older than specified days.
Args:
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Statistics about the cleaning operation
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
if graceful_period < 0:
raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info("clean_messages: days=%s, end_before=%s, batch_size=%s", days, end_before, batch_size)
return cls._clean_sandbox_messages_by_time_range(
end_before=end_before,
start_from=None,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def _clean_sandbox_messages_by_time_range(
cls,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Internal method to clean sandbox messages within a time range using cursor-based pagination.
Time range is [start_from, end_before) - left-closed, right-open interval.
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Extract app_ids from messages
3. Query tenant_ids from apps
4. Batch fetch subscription plans
5. Delete messages from sandbox tenants
Args:
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Dict with statistics: batches, total_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"total_deleted": 0,
}
if not dify_config.BILLING_ENABLED:
logger.info("clean_messages: billing is not enabled, skip cleaning messages")
return stats
tenant_whitelist = cls._get_tenant_whitelist()
logger.info("clean_messages: tenant_whitelist=%s", tenant_whitelist)
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
dry_run,
start_from,
end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < end_before)
.order_by(Message.created_at, Message.id)
.limit(batch_size)
)
if start_from:
msg_stmt = msg_stmt.where(Message.created_at >= start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids from this batch
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
# Step 3: Query tenant_ids from apps
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Step 4: End sesion to call billing API to avoid long-running transaction.
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
tenant_ids = list(set(app_to_tenant.values()))
# Batch fetch subscription plans
tenant_plans = cls._batch_fetch_tenant_plans(tenant_ids)
# Step 5: Filter messages from sandbox tenants
sandbox_message_ids = cls._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=graceful_period,
)
if not sandbox_message_ids:
logger.info("clean_messages (batch %s): no sandbox messages found, skip", stats["batches"])
continue
stats["total_messages"] += len(sandbox_message_ids)
# Step 6: Batch delete messages and their relations
if not dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
cls._batch_delete_message_relations(session, sandbox_message_ids)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(sandbox_message_ids))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s sandbox messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
sample_ids = ", ".join(sample_id for sample_id in sandbox_message_ids[:5])
logger.info(
"clean_messages (batch %s, dry_run): would delete %s sandbox messages, sample ids: %s",
stats["batches"],
len(sandbox_message_ids),
sample_ids,
)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["total_deleted"],
)
return stats
@classmethod
def _filter_expired_sandbox_messages(
cls,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
tenant_whitelist: Sequence[str],
graceful_period_days: int,
current_timestamp: int | None = None,
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
graceful_period_days: Grace period in days after expiration
current_timestamp: Current Unix timestamp (defaults to now, injectable for testing)
Returns:
List of message IDs that should be deleted
"""
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
@classmethod
def _get_tenant_whitelist(cls) -> Sequence[str]:
return BillingService.get_expired_subscription_cleanup_whitelist()
@classmethod
def _batch_fetch_tenant_plans(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Batch fetch tenant plans with Redis caching.
This method uses a two-tier strategy:
1. First, batch fetch from Redis cache using mget
2. For cache misses, fetch from billing API
3. Update Redis cache using pipeline for new entries
Args:
tenant_ids: List of tenant IDs
Returns:
Dict mapping tenant_id to SubscriptionPlan (with "plan" and "expiration_date" keys)
"""
if not tenant_ids:
return {}
tenant_plans: dict[str, SubscriptionPlan] = {}
# Step 1: Batch fetch from Redis cache using mget
redis_keys = [f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}" for tenant_id in tenant_ids]
try:
cached_values = redis_client.mget(redis_keys)
# Map cached values back to tenant_ids
cache_hits: dict[str, SubscriptionPlan] = {}
cache_misses: list[str] = []
for tenant_id, cached_value in zip(tenant_ids, cached_values):
if cached_value:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
try:
plan_dict = json.loads(json_str)
if isinstance(plan_dict, dict) and "plan" in plan_dict:
cache_hits[tenant_id] = cast(SubscriptionPlan, plan_dict)
tenant_plans[tenant_id] = cast(SubscriptionPlan, plan_dict)
else:
cache_misses.append(tenant_id)
except json.JSONDecodeError:
cache_misses.append(tenant_id)
else:
cache_misses.append(tenant_id)
logger.info(
"clean_messages: fetch_tenant_plans(cache hits=%s, cache misses=%s)",
len(cache_hits),
len(cache_misses),
)
except Exception as e:
logger.warning("clean_messages: fetch_tenant_plans(redis mget failed: %s, falling back to API)", e)
cache_misses = list(tenant_ids)
# Step 2: Fetch missing plans from billing API
if cache_misses:
bulk_plans = BillingService.get_plan_bulk(cache_misses)
if bulk_plans:
plans_to_cache: dict[str, SubscriptionPlan] = {}
for tenant_id, plan_dict in bulk_plans.items():
if isinstance(plan_dict, dict):
tenant_plans[tenant_id] = plan_dict # type: ignore
plans_to_cache[tenant_id] = plan_dict # type: ignore
# Step 3: Batch update Redis cache using pipeline
if plans_to_cache:
try:
pipe = redis_client.pipeline()
for tenant_id, plan_dict in plans_to_cache.items():
redis_key = f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}"
# Serialize dict to JSON string
json_str = json.dumps(plan_dict)
pipe.setex(redis_key, cls.PLAN_CACHE_TTL, json_str)
pipe.execute()
logger.info(
"clean_messages: cached %s new tenant plans to Redis",
len(plans_to_cache),
)
except Exception as e:
logger.warning("clean_messages: Redis pipeline failed: %s", e)
return tenant_plans
@classmethod
def _batch_delete_message_relations(cls, session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))

View File

@@ -0,0 +1,996 @@
"""
Integration tests for SandboxMessagesCleanService using testcontainers.
This module provides comprehensive integration tests for the sandbox message cleanup service
using TestContainers infrastructure with real PostgreSQL and Redis.
"""
import datetime
import json
import uuid
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
class TestSandboxMessagesCleanServiceIntegration:
"""Integration tests for SandboxMessagesCleanService._clean_sandbox_messages_by_time_range."""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before and after each test to ensure isolation."""
yield
# Clear all test data in correct order (respecting foreign key constraints)
db.session.query(DatasetRetrieverResource).delete()
db.session.query(AppAnnotationHitHistory).delete()
db.session.query(SavedMessage).delete()
db.session.query(MessageFile).delete()
db.session.query(MessageAgentThought).delete()
db.session.query(MessageChain).delete()
db.session.query(MessageAnnotation).delete()
db.session.query(MessageFeedback).delete()
db.session.query(Message).delete()
db.session.query(Conversation).delete()
db.session.query(App).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
@pytest.fixture(autouse=True)
def cleanup_redis(self):
"""Clean up Redis cache before each test."""
# Clear tenant plan cache
try:
keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
if keys:
redis_client.delete(*keys)
except Exception:
pass # Redis might not be available in some test environments
yield
# Clean up after test
try:
keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
if keys:
redis_client.delete(*keys)
except Exception:
pass
@pytest.fixture(autouse=True)
def mock_whitelist(self):
"""Mock whitelist to return empty list by default."""
with patch(
"services.sandbox_messages_clean_service.BillingService.get_expired_subscription_cleanup_whitelist"
) as mock:
mock.return_value = []
yield mock
@pytest.fixture(autouse=True)
def mock_billing_enabled(self):
"""Mock BILLING_ENABLED to be True for all tests."""
with patch("services.sandbox_messages_clean_service.dify_config.BILLING_ENABLED", True):
yield
def _create_account_and_tenant(self, plan="sandbox"):
"""Helper to create account and tenant."""
fake = Faker()
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.flush()
tenant = Tenant(
name=fake.company(),
plan=plan,
status="normal",
)
db.session.add(tenant)
db.session.flush()
tenant_account_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
return account, tenant
def _create_app(self, tenant, account):
"""Helper to create an app."""
fake = Faker()
app = App(
tenant_id=tenant.id,
name=fake.company(),
description="Test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
db.session.add(app)
db.session.commit()
return app
def _create_conversation(self, app):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-3.5-turbo",
mode="chat",
name="Test conversation",
inputs={},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(conversation)
db.session.commit()
return conversation
def _create_message(self, app, conversation, created_at=None, with_relations=True):
"""Helper to create a message with optional related records."""
if created_at is None:
created_at = datetime.datetime.now()
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-3.5-turbo",
inputs={},
query="Test query",
answer="Test answer",
message=[{"role": "user", "text": "Test message"}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source="api",
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
db.session.add(message)
db.session.flush()
if with_relations:
self._create_message_relations(message)
db.session.commit()
return message
def _create_message_relations(self, message):
"""Helper to create all message-related records."""
# MessageFeedback
feedback = MessageFeedback(
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
rating="like",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(feedback)
# MessageAnnotation
annotation = MessageAnnotation(
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
question="Test question",
content="Test annotation",
account_id=message.from_account_id,
)
db.session.add(annotation)
# MessageChain
chain = MessageChain(
message_id=message.id,
type="system",
input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}),
)
db.session.add(chain)
db.session.flush()
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to="user",
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(file)
# SavedMessage
saved = SavedMessage(
app_id=message.app_id,
message_id=message.id,
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(saved)
db.session.flush()
# AppAnnotationHitHistory
hit = AppAnnotationHitHistory(
app_id=message.app_id,
annotation_id=annotation.id,
message_id=message.id,
source="annotation",
question="Test question",
account_id=message.from_account_id,
annotation_question="Test annotation question",
annotation_content="Test annotation content",
)
db.session.add(hit)
# DatasetRetrieverResource
resource = DatasetRetrieverResource(
message_id=message.id,
position=1,
dataset_id=str(uuid.uuid4()),
dataset_name="Test dataset",
document_id=str(uuid.uuid4()),
document_name="Test document",
data_source_type="upload_file",
segment_id=str(uuid.uuid4()),
score=0.9,
content="Test content",
hit_count=1,
word_count=10,
segment_position=1,
index_node_hash="test_hash",
retriever_from="dataset",
created_by=message.from_account_id,
)
db.session.add(resource)
def test_clean_no_messages_to_delete(self, db_session_with_containers):
"""Test cleaning when there are no messages to delete."""
# Arrange
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {}
# Act
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert
# Even with no messages, the loop runs once to check
assert stats["batches"] == 1
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
def test_clean_mixed_sandbox_and_paid_tenants(self, db_session_with_containers):
"""Test cleaning with mixed sandbox and paid tenants, correctly filtering sandbox messages."""
# Arrange - Create sandbox tenants with expired messages
sandbox_tenants = []
sandbox_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan="sandbox")
sandbox_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 3 expired messages per sandbox tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
sandbox_message_ids.append(msg.id)
# Create paid tenants with expired messages (should NOT be deleted)
paid_tenants = []
paid_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan="professional")
paid_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 2 expired messages per paid tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(2):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
paid_message_ids.append(msg.id)
# Mock billing service - return plan and expiration_date
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period
plan_map = {}
for tenant in sandbox_tenants:
plan_map[tenant.id] = {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_15_days_ago,
}
for tenant in paid_tenants:
plan_map[tenant.id] = {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=7,
batch_size=100,
)
# Assert
assert stats["total_messages"] == 6 # 2 sandbox tenants * 3 messages
assert stats["total_deleted"] == 6
# Only sandbox messages should be deleted
assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
# Paid messages should remain
assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
# Related records of sandbox messages should be deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
assert (
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
== 0
)
def test_clean_with_cursor_pagination(self, db_session_with_containers):
"""Test cursor pagination works correctly across multiple batches."""
# Arrange - Create sandbox tenant with messages that will span multiple batches
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 10 expired messages with different timestamps
base_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(10):
msg = self._create_message(
app,
conv,
created_at=base_date + datetime.timedelta(hours=i),
with_relations=False, # Skip relations for speed
)
message_ids.append(msg.id)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act - Use small batch size to trigger multiple batches
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=3, # Small batch size to test pagination
)
# 5 batches for 10 messages with batch_size=3, the last batch is empty
assert stats["batches"] == 5
assert stats["total_messages"] == 10
assert stats["total_deleted"] == 10
# All messages should be deleted
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0
def test_clean_with_dry_run(self, db_session_with_containers):
"""Test dry_run mode does not delete messages."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create expired messages
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
message_ids.append(msg.id)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
dry_run=True, # Dry run mode
)
# Assert
assert stats["total_messages"] == 3 # Messages identified
assert stats["total_deleted"] == 0 # But NOT deleted
# All messages should still exist
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
# Related records should also still exist
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3
def test_clean_with_billing_partial_exception_some_known_plans(self, db_session_with_containers):
"""Test when billing service fails but returns partial data, only delete known sandbox messages."""
# Arrange - Create 3 tenants
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date)
tenants_data.append(
{
"tenant": tenant,
"message_id": msg.id,
}
)
# Mock billing service to return partial data with new structure
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing
partial_plan_map = {
tenants_data[0]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[1]["tenant"].id: {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year
},
# tenants_data[2] is missing from response
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = partial_plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only tenant[0]'s message should be deleted
assert stats["total_messages"] == 1
assert stats["total_deleted"] == 1
# Check which messages were deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
) # Sandbox tenant's message deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
) # Professional tenant's message preserved
assert (
db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
) # Unknown tenant's message preserved (safe default)
def test_clean_with_billing_exception_no_data(self, db_session_with_containers):
"""Test when billing service returns empty data, skip deletion for that batch."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_id = None
msg = self._create_message(app, conv, created_at=expired_date)
msg_id = msg.id # Store ID before any operations
db.session.commit()
# Mock billing service to return empty data (simulating failure/no data scenario)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {} # Empty response, tenant plan unknown
# Act - Should not raise exception, just skip deletion
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - No messages should be deleted when plan is unknown
assert stats["total_messages"] == 0 # Cannot determine sandbox messages
assert stats["total_deleted"] == 0
# Message should still exist (safe default - don't delete if plan is unknown)
assert db.session.query(Message).where(Message.id == msg_id).count() == 1
def test_redis_cache_for_tenant_plans(self, db_session_with_containers):
"""Test that tenant plans are cached in Redis and reused."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create messages in two batches (to test cache reuse)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
batch1_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=i), with_relations=False
)
batch1_msgs.append(msg.id)
batch2_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=10 + i), with_relations=False
)
batch2_msgs.append(msg.id)
# Mock billing service with new structure
mock_get_plan_bulk = MagicMock(
return_value={
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk", mock_get_plan_bulk):
# Act - First call
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats1 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=2, # Process 2 messages per batch
)
# Check billing service was called (cache miss)
assert mock_get_plan_bulk.call_count == 1
first_call_count = mock_get_plan_bulk.call_count
# Verify Redis cache was populated
cache_key = f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}{tenant.id}"
cached_plan = redis_client.get(cache_key)
assert cached_plan is not None
cached_plan_data = json.loads(cached_plan.decode("utf-8"))
assert cached_plan_data["plan"] == CloudPlan.SANDBOX
assert cached_plan_data["expiration_date"] == -1
# Act - Second call with same tenant (should use cache)
# Create more messages for the same tenant
batch3_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=20 + i), with_relations=False
)
batch3_msgs.append(msg.id)
stats2 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=2,
)
# Assert - Billing service should not be called again (cache hit)
# The call count should be the same
assert mock_get_plan_bulk.call_count == first_call_count # Same tenant, should use cache
# Verify all messages were deleted
total_expected = len(batch1_msgs) + len(batch2_msgs) + len(batch3_msgs)
assert stats1["total_deleted"] + stats2["total_deleted"] == total_expected
def test_time_range_filtering(self, db_session_with_containers):
"""Test that messages are correctly filtered by [start_from, end_before) time range."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
base_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
# Create messages: before range, in range, after range
msg_before = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from
with_relations=False,
)
msg_before_id = msg_before.id
msg_at_start = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive)
with_relations=False,
)
msg_at_start_id = msg_at_start.id
msg_in_range = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range
with_relations=False,
)
msg_in_range_id = msg_in_range.id
msg_at_end = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive)
with_relations=False,
)
msg_at_end_id = msg_at_end.id
msg_after = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before
with_relations=False,
)
msg_after_id = msg_after.id
db.session.commit() # Commit all messages
# Mock billing service with new structure
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act - Clean with specific time range [2024-01-10, 2024-01-20)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
start_from=datetime.datetime(2024, 1, 10, 12, 0, 0),
end_before=datetime.datetime(2024, 1, 20, 12, 0, 0),
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only messages in [start_from, end_before) should be deleted
assert stats["total_messages"] == 2 # msg_at_start and msg_in_range
assert stats["total_deleted"] == 2
# Verify specific messages using stored IDs
# Before range, kept
assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
# At start (inclusive), deleted
assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
# In range, deleted
assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
# At end (exclusive), kept
assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
# After range, kept
assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1
def test_clean_with_graceful_period_scenarios(self, db_session_with_containers):
"""Test cleaning with different graceful period scenarios."""
# Arrange - Create 5 different tenants with different plan and expiration scenarios
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
graceful_period = 8 # Use 8 days for this test to validate boundary conditions
# Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
# Should NOT be deleted
account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
msg1_id = msg1.id # Save ID before potential deletion
expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period
# Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
# Should be deleted
account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
msg2_id = msg2.id # Save ID before potential deletion
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period
# Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
# Should be deleted
account3, tenant3 = self._create_account_and_tenant(plan="sandbox")
app3 = self._create_app(tenant3, account3)
conv3 = self._create_conversation(app3)
msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
msg3_id = msg3.id # Save ID before potential deletion
# Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
# Should NOT be deleted
account4, tenant4 = self._create_account_and_tenant(plan="professional")
app4 = self._create_app(tenant4, account4)
conv4 = self._create_conversation(app4)
msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
msg4_id = msg4.id # Save ID before potential deletion
future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year
# Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
# Should NOT be deleted (boundary is exclusive: > graceful_period)
account5, tenant5 = self._create_account_and_tenant(plan="sandbox")
app5 = self._create_app(tenant5, account5)
conv5 = self._create_conversation(app5)
msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
msg5_id = msg5.id # Save ID before potential deletion
expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary
db.session.commit()
# Mock billing service with all scenarios
plan_map = {
tenant1.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_5_days_ago,
},
tenant2.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_10_days_ago,
},
tenant3.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1,
},
tenant4.id: {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": future_expiration,
},
tenant5.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_exactly_8_days_ago,
},
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
# Mock datetime.now() to use the same timestamp as test setup
# This ensures deterministic behavior for boundary conditions (scenario 5)
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime.fromtimestamp(
now_timestamp, tz=datetime.UTC
)
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=graceful_period,
batch_size=100,
)
# Assert - Only messages from scenario 2 and 3 should be deleted
assert stats["total_messages"] == 2
assert stats["total_deleted"] == 2
# Verify each scenario using saved IDs
assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted
assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted
assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept
assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
def test_clean_with_tenant_whitelist(self, db_session_with_containers, mock_whitelist):
"""Test that whitelisted tenants' messages are not deleted even if they are sandbox and expired."""
# Arrange - Create 3 sandbox tenants with expired messages
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)
tenants_data.append(
{
"tenant": tenant,
"message_id": msg.id,
}
)
# Mock billing service - all tenants are sandbox with no subscription
plan_map = {
tenants_data[0]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[1]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[2]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
}
# Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not
whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id]
mock_whitelist.return_value = whitelist
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only tenant2's message should be deleted (not whitelisted)
assert stats["total_messages"] == 1
assert stats["total_deleted"] == 1
# Verify tenant0's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
# Verify tenant1's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
# Verify tenant2's message was deleted (not whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
def test_clean_with_whitelist_and_grace_period(self, db_session_with_containers, mock_whitelist):
"""Test that whitelist takes precedence over grace period logic."""
# Arrange - Create 2 sandbox tenants
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Tenant1: whitelisted, expired beyond grace period
account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace
# Tenant2: not whitelisted, within grace period
account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace
# Mock billing service
plan_map = {
tenant1.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_30_days_ago, # Beyond grace period
},
tenant2.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_10_days_ago, # Within grace period
},
}
# Setup whitelist - only tenant1 is whitelisted
whitelist = [tenant1.id]
mock_whitelist.return_value = whitelist
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - No messages should be deleted
# tenant1: whitelisted (would be deleted based on grace period, but protected by whitelist)
# tenant2: within grace period (not eligible for deletion)
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
# Verify both messages still exist
assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period
def test_clean_with_empty_whitelist(self, db_session_with_containers, mock_whitelist):
"""Test that empty whitelist behaves as no whitelist (all eligible messages are deleted)."""
# Arrange - Create sandbox tenant with expired messages
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
msg_ids.append(msg.id)
# Mock billing service
plan_map = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Setup empty whitelist (default behavior from fixture)
mock_whitelist.return_value = []
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - All messages should be deleted (no whitelist protection)
assert stats["total_messages"] == 3
assert stats["total_deleted"] == 3
# Verify all messages were deleted
assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0

View File

@@ -0,0 +1,588 @@
"""
Unit tests for SandboxMessagesCleanService.
This module tests parameter validation, method invocation, and error handling
without database dependencies (using mocks).
"""
import datetime
from unittest.mock import patch
import pytest
from enums.cloud_plan import CloudPlan
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
class MockMessage:
"""Mock message object for testing."""
def __init__(self, id: str, app_id: str, created_at: datetime.datetime | None = None):
self.id = id
self.app_id = app_id
self.created_at = created_at or datetime.datetime.now()
class TestFilterExpiredSandboxMessages:
"""Unit tests for _filter_expired_sandbox_messages method."""
def test_filter_missing_tenant_mapping(self):
"""Test that messages with missing app-to-tenant mapping are excluded."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
]
app_to_tenant = {} # No mapping
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result == []
def test_filter_missing_tenant_plan(self):
"""Test that messages with missing tenant plan are excluded."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
}
tenant_plans = {} # No plans
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result == []
def test_filter_no_previous_subscription(self):
"""Test that messages with no previous subscription (expiration_date=-1) are deleted."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
MockMessage("msg3", "app3"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert - all messages should be deleted
assert set(result) == {"msg1", "msg2", "msg3"}
def test_filter_all_within_grace_period(self):
"""Test that no messages are deleted when all are within grace period."""
# Arrange
now = 1000000
# All expired recently (within 8 day grace period)
expired_1_day_ago = now - (1 * 24 * 60 * 60)
expired_3_days_ago = now - (3 * 24 * 60 * 60)
expired_7_days_ago = now - (7 * 24 * 60 * 60)
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
MockMessage("msg3", "app3"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_3_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=now,
)
# Assert - no messages should be deleted
assert result == []
def test_filter_partial_expired_beyond_grace_period(self):
"""Test filtering when some messages expired beyond grace period."""
# Arrange
now = 1000000
graceful_period = 8
# Different expiration scenarios
expired_5_days_ago = now - (5 * 24 * 60 * 60) # Within grace - keep
expired_10_days_ago = now - (10 * 24 * 60 * 60) # Beyond grace - delete
expired_30_days_ago = now - (30 * 24 * 60 * 60) # Beyond grace - delete
expired_exactly_8_days_ago = now - (8 * 24 * 60 * 60) # Exactly at boundary - keep
expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond - delete
messages = [
MockMessage("msg1", "app1"), # Within grace
MockMessage("msg2", "app2"), # Beyond grace
MockMessage("msg3", "app3"), # Beyond grace
MockMessage("msg4", "app4"), # No subscription - delete
MockMessage("msg5", "app5"), # Exactly at boundary
MockMessage("msg6", "app6"), # Just beyond grace
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app5": "tenant5",
"app6": "tenant6",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_10_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
"tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant5": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
"tenant6": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=graceful_period,
current_timestamp=now,
)
# Assert - msg2, msg3, msg4, msg6 should be deleted
# msg1 and msg5 are within/at grace period boundary
assert set(result) == {"msg2", "msg3", "msg4", "msg6"}
def test_filter_complex_mixed_scenario(self):
"""Test complex scenario with mixed plans, expirations, and missing mappings."""
# Arrange
now = 1000000
sandbox_expired_old = now - (15 * 24 * 60 * 60) # 15 days ago - beyond grace
sandbox_expired_recent = now - (3 * 24 * 60 * 60) # 3 days ago - within grace
future_expiration = now + (30 * 24 * 60 * 60) # 30 days in future - active paid plan
messages = [
MockMessage("msg1", "app1"), # Sandbox, no subscription - delete
MockMessage("msg2", "app2"), # Sandbox, expired old - delete
MockMessage("msg3", "app3"), # Sandbox, within grace - keep
MockMessage("msg4", "app4"), # Team plan, active - keep
MockMessage("msg5", "app5"), # No tenant mapping - keep
MockMessage("msg6", "app6"), # No plan info - keep
MockMessage("msg7", "app7"), # Sandbox, expired old - delete
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app6": "tenant6", # Has mapping but no plan
"app7": "tenant7",
# app5 has no mapping
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
"tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
# tenant6 has no plan
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only sandbox expired beyond grace period and no subscription
assert set(result) == {"msg1", "msg2", "msg7"}
def test_filter_empty_inputs(self):
"""Test filtering with empty inputs returns empty list."""
# Arrange - empty messages
result1 = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=[],
app_to_tenant={"app1": "tenant1"},
tenant_plans={"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}},
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result1 == []
def test_filter_uses_default_timestamp(self):
"""Test that method uses current time when timestamp not provided."""
# Arrange
messages = [MockMessage("msg1", "app1")]
app_to_tenant = {"app1": "tenant1"}
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
# Act - don't provide current_timestamp
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
# current_timestamp not provided - should use datetime.now()
)
# Assert - should still work and return msg1 (no subscription)
assert result == ["msg1"]
def test_filter_with_whitelist(self):
"""Test that messages from whitelisted tenants are excluded from deletion."""
# Arrange
messages = [
MockMessage("msg1", "app1"), # Whitelisted tenant - should be kept
MockMessage("msg2", "app2"), # Not whitelisted - should be deleted
MockMessage("msg3", "app3"), # Whitelisted tenant - should be kept
MockMessage("msg4", "app4"), # Not whitelisted - should be deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
tenant_whitelist = ["tenant1", "tenant3"] # Whitelist tenant1 and tenant3
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert - only msg2 and msg4 should be deleted (not whitelisted)
assert set(result) == {"msg2", "msg4"}
def test_filter_with_whitelist_and_grace_period(self):
"""Test whitelist takes precedence over grace period logic."""
# Arrange
now = 1000000
expired_long_ago = now - (30 * 24 * 60 * 60) # Expired 30 days ago
messages = [
MockMessage("msg1", "app1"), # Whitelisted, expired long ago - should be kept
MockMessage("msg2", "app2"), # Not whitelisted, expired long ago - should be deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
}
tenant_whitelist = ["tenant1"] # Only tenant1 is whitelisted
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only msg2 should be deleted
assert result == ["msg2"]
def test_filter_whitelist_with_non_sandbox_plans(self):
"""Test that whitelist only affects sandbox plan messages."""
# Arrange
now = 1000000
future_expiration = now + (30 * 24 * 60 * 60)
messages = [
MockMessage("msg1", "app1"), # Sandbox, whitelisted - kept
MockMessage("msg2", "app2"), # Team plan, whitelisted - kept (not sandbox)
MockMessage("msg3", "app3"), # Sandbox, not whitelisted - deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
tenant_whitelist = ["tenant1", "tenant2"]
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only msg3 should be deleted (sandbox, not whitelisted)
assert result == ["msg3"]
class TestCleanSandboxMessagesByTimeRange:
"""Unit tests for clean_sandbox_messages_by_time_range method."""
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_valid_time_range_and_args(self, mock_clean):
"""Test with valid time range and other parameters."""
# Arrange
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
batch_size = 500
dry_run = True
mock_clean.return_value = {
"batches": 5,
"total_messages": 100,
"total_deleted": 100,
}
# Act
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert, expected no exception raised
mock_clean.assert_called_once_with(
start_from=start_from,
end_before=end_before,
graceful_period=21,
batch_size=batch_size,
dry_run=dry_run,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_with_default_args(self, mock_clean):
"""Test with default args."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
mock_clean.return_value = {
"batches": 2,
"total_messages": 50,
"total_deleted": 0,
}
# Act
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
)
# Assert
mock_clean.assert_called_once_with(
start_from=start_from,
end_before=end_before,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
def test_invalid_time_range(self):
"""Test invalid time range raises ValueError."""
# Arrange
same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
# Act & Assert start equals end
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=same_time,
end_before=same_time,
)
# Arrange
start_from = datetime.datetime(2024, 12, 31)
end_before = datetime.datetime(2024, 1, 1)
# Act & Assert start after end
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
)
def test_invalid_batch_size(self):
"""Test invalid batch_size raises ValueError."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
# Act & Assert batch_size = 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=0,
)
# Act & Assert batch_size < 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=-100,
)
class TestCleanSandboxMessagesByDays:
"""Unit tests for clean_sandbox_messages_by_days method."""
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_default_days(self, mock_clean):
"""Test with default 30 days."""
# Arrange
mock_clean.return_value = {"batches": 3, "total_messages": 75, "total_deleted": 75}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
SandboxMessagesCleanService.clean_sandbox_messages_by_days()
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_custom_days(self, mock_clean):
"""Test with custom number of days."""
# Arrange
custom_days = 90
mock_clean.return_value = {"batches": 10, "total_messages": 500, "total_deleted": 500}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
result = SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=custom_days)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=custom_days)
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_zero_days(self, mock_clean):
"""Test with days=0 (clean all messages before now)."""
# Arrange
mock_clean.return_value = {"batches": 0, "total_messages": 0, "total_deleted": 0}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=0)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=0) # same as fixed_now
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
def test_invalid_batch_size(self):
"""Test invalid batch_size raises ValueError."""
# Act & Assert batch_size = 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=30,
batch_size=0,
)
# Act & Assert batch_size < 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=30,
batch_size=-500,
)