Mirror of https://github.com/langgenius/dify.git
refactor: implement tenant self queue for rag tasks (#27559)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
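The tests added here exercise a per-tenant dispatch pattern: a worker finishes its own payload, then pulls at most one waiting task for the same tenant from TenantIsolatedTaskQueue and hands it back to Celery via the task's .delay(). A minimal sketch of that loop, inferred from the test assertions below (the function name and the "document_indexing" key are illustrative assumptions, not taken from this diff):

    from core.rag.pipeline.queue import TenantIsolatedTaskQueue

    def redispatch_next_waiting_task(tenant_id: str, task_func) -> None:
        # Illustrative sketch only: after its own work is done, a worker pulls at most one
        # waiting task for this tenant and re-enqueues it with the kwargs the tests assert on
        # (tenant_id / dataset_id / document_ids).
        queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
        waiting = queue.pull_tasks(1)
        if waiting:
            task_func.delay(**waiting[0])
        else:
            # Presumably the per-tenant "task waiting" marker is cleared when nothing is queued.
            queue.delete_task_key()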
@@ -0,0 +1 @@

@@ -0,0 +1,595 @@
"""
Integration tests for TenantIsolatedTaskQueue using testcontainers.

These tests verify the Redis-based task queue functionality with real Redis instances,
testing tenant isolation, task serialization, and queue operations in a realistic environment.
Includes compatibility tests for migrating from legacy string-only queues.

All tests use generic naming to avoid coupling to specific business implementations.
"""
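# Interface summary (descriptive only, inferred from the tests in this module rather than
# from the queue implementation itself):
#   TenantIsolatedTaskQueue(tenant_id, unique_key)
#       .push_tasks(items)   - items may be plain strings or JSON-serializable dicts
#       .pull_tasks(count)   - FIFO; wrapped dicts come back deserialized, legacy strings unchanged
#       .get_task_key() / .set_task_waiting_time(ttl) / .delete_task_key()
#   TaskWrapper(data=...).serialize() / TaskWrapper.deserialize(raw)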

import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4

import pytest
from faker import Faker

from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole


@dataclass
class TestTask:
    """Test task data structure for testing complex object serialization."""

    task_id: str
    tenant_id: str
    data: dict[str, Any]
    metadata: dict[str, Any]


class TestTenantIsolatedTaskQueueIntegration:
    """Integration tests for TenantIsolatedTaskQueue using testcontainers."""

    @pytest.fixture
    def fake(self):
        """Faker instance for generating test data."""
        return Faker()

    @pytest.fixture
    def test_tenant_and_account(self, db_session_with_containers, fake):
        """Create test tenant and account for testing."""
        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        # Create tenant
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()

        return tenant, account

    @pytest.fixture
    def test_queue(self, test_tenant_and_account):
        """Create a generic test queue for testing."""
        tenant, _ = test_tenant_and_account
        return TenantIsolatedTaskQueue(tenant.id, "test_queue")

    @pytest.fixture
    def secondary_queue(self, test_tenant_and_account):
        """Create a secondary test queue for testing isolation."""
        tenant, _ = test_tenant_and_account
        return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")

    def test_queue_initialization(self, test_tenant_and_account):
        """Test queue initialization with correct key generation."""
        tenant, _ = test_tenant_and_account
        queue = TenantIsolatedTaskQueue(tenant.id, "test-key")

        assert queue._tenant_id == tenant.id
        assert queue._unique_key == "test-key"
        assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
        assert queue._task_key == f"tenant_test-key_task:{tenant.id}"

    def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
        """Test that different tenants have isolated queues."""
        tenant1, _ = test_tenant_and_account

        # Create second tenant
        tenant2 = Tenant(
            name=fake.company(),
            status="normal",
        )
        db_session_with_containers.add(tenant2)
        db_session_with_containers.commit()

        queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key")
        queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key")

        assert queue1._queue != queue2._queue
        assert queue1._task_key != queue2._task_key
        assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
        assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"

    def test_key_isolation(self, test_tenant_and_account):
        """Test that different keys have isolated queues."""
        tenant, _ = test_tenant_and_account
        queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
        queue2 = TenantIsolatedTaskQueue(tenant.id, "key2")

        assert queue1._queue != queue2._queue
        assert queue1._task_key != queue2._task_key
        assert queue1._queue == f"tenant_self_key1_task_queue:{tenant.id}"
        assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}"

    def test_task_key_operations(self, test_queue):
        """Test task key operations (get, set, delete)."""
        # Initially no task key should exist
        assert test_queue.get_task_key() is None

        # Set task waiting time with default TTL
        test_queue.set_task_waiting_time()
        task_key = test_queue.get_task_key()
        # Redis returns bytes, convert to string for comparison
        assert task_key in (b"1", "1")

        # Set task waiting time with custom TTL
        custom_ttl = 30
        test_queue.set_task_waiting_time(custom_ttl)
        task_key = test_queue.get_task_key()
        assert task_key in (b"1", "1")

        # Delete task key
        test_queue.delete_task_key()
        assert test_queue.get_task_key() is None

    def test_push_and_pull_string_tasks(self, test_queue):
        """Test pushing and pulling string tasks."""
        tasks = ["task1", "task2", "task3"]

        # Push tasks
        test_queue.push_tasks(tasks)

        # Pull tasks (FIFO order)
        pulled_tasks = test_queue.pull_tasks(3)

        # Should get tasks in FIFO order (lpush + rpop = FIFO)
        assert pulled_tasks == ["task1", "task2", "task3"]

    def test_push_and_pull_multiple_tasks(self, test_queue):
        """Test pushing and pulling multiple tasks at once."""
        tasks = ["task1", "task2", "task3", "task4", "task5"]

        # Push tasks
        test_queue.push_tasks(tasks)

        # Pull multiple tasks
        pulled_tasks = test_queue.pull_tasks(3)
        assert len(pulled_tasks) == 3
        assert pulled_tasks == ["task1", "task2", "task3"]

        # Pull remaining tasks
        remaining_tasks = test_queue.pull_tasks(5)
        assert len(remaining_tasks) == 2
        assert remaining_tasks == ["task4", "task5"]

    def test_push_and_pull_complex_objects(self, test_queue, fake):
        """Test pushing and pulling complex object tasks."""
        # Create complex task objects as dictionaries (not dataclass instances)
        tasks = [
            {
                "task_id": str(uuid4()),
                "tenant_id": test_queue._tenant_id,
                "data": {
                    "file_id": str(uuid4()),
                    "content": fake.text(),
                    "metadata": {"size": fake.random_int(1000, 10000)},
                },
                "metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)},
            },
            {
                "task_id": str(uuid4()),
                "tenant_id": test_queue._tenant_id,
                "data": {
                    "file_id": str(uuid4()),
                    "content": "测试中文内容",
                    "metadata": {"size": fake.random_int(1000, 10000)},
                },
                "metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]},
            },
        ]

        # Push complex tasks
        test_queue.push_tasks(tasks)

        # Pull tasks
        pulled_tasks = test_queue.pull_tasks(2)
        assert len(pulled_tasks) == 2

        # Verify deserialized tasks match original (FIFO order)
        for i, pulled_task in enumerate(pulled_tasks):
            original_task = tasks[i]  # FIFO order
            assert isinstance(pulled_task, dict)
            assert pulled_task["task_id"] == original_task["task_id"]
            assert pulled_task["tenant_id"] == original_task["tenant_id"]
            assert pulled_task["data"] == original_task["data"]
            assert pulled_task["metadata"] == original_task["metadata"]

    def test_mixed_task_types(self, test_queue, fake):
        """Test pushing and pulling mixed string and object tasks."""
        string_task = "simple_string_task"
        object_task = {
            "task_id": str(uuid4()),
            "dataset_id": str(uuid4()),
            "document_ids": [str(uuid4()) for _ in range(3)],
        }

        tasks = [string_task, object_task, "another_string"]

        # Push mixed tasks
        test_queue.push_tasks(tasks)

        # Pull all tasks
        pulled_tasks = test_queue.pull_tasks(3)
        assert len(pulled_tasks) == 3

        # Verify types and content
        assert pulled_tasks[0] == string_task
        assert isinstance(pulled_tasks[1], dict)
        assert pulled_tasks[1] == object_task
        assert pulled_tasks[2] == "another_string"

    def test_empty_queue_operations(self, test_queue):
        """Test operations on empty queue."""
        # Pull from empty queue
        tasks = test_queue.pull_tasks(5)
        assert tasks == []

        # Pull zero or negative count
        assert test_queue.pull_tasks(0) == []
        assert test_queue.pull_tasks(-1) == []

    def test_task_ttl_expiration(self, test_queue):
        """Test task key TTL expiration."""
        # Set task with short TTL
        short_ttl = 2
        test_queue.set_task_waiting_time(short_ttl)

        # Verify task key exists
        assert test_queue.get_task_key() == b"1" or test_queue.get_task_key() == "1"

        # Wait for TTL to expire
        time.sleep(short_ttl + 1)

        # Verify task key has expired
        assert test_queue.get_task_key() is None

    def test_large_task_batch(self, test_queue, fake):
        """Test handling large batches of tasks."""
        # Create large batch of tasks
        large_batch = []
        for i in range(100):
            task = {
                "task_id": str(uuid4()),
                "index": i,
                "data": fake.text(max_nb_chars=100),
                "metadata": {"batch_id": str(uuid4())},
            }
            large_batch.append(task)

        # Push large batch
        test_queue.push_tasks(large_batch)

        # Pull all tasks
        pulled_tasks = test_queue.pull_tasks(100)
        assert len(pulled_tasks) == 100

        # Verify all tasks were retrieved correctly (FIFO order)
        for i, task in enumerate(pulled_tasks):
            assert isinstance(task, dict)
            assert task["index"] == i  # FIFO order

    def test_queue_operations_isolation(self, test_tenant_and_account, fake):
        """Test concurrent operations on different queues."""
        tenant, _ = test_tenant_and_account

        # Create multiple queues for the same tenant
        queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1")
        queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2")

        # Push tasks to different queues
        queue1.push_tasks(["task1_queue1", "task2_queue1"])
        queue2.push_tasks(["task1_queue2", "task2_queue2"])

        # Verify queues are isolated
        tasks1 = queue1.pull_tasks(2)
        tasks2 = queue2.pull_tasks(2)

        assert tasks1 == ["task1_queue1", "task2_queue1"]
        assert tasks2 == ["task1_queue2", "task2_queue2"]
        assert tasks1 != tasks2

    def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
        """Test TaskWrapper serialization and deserialization roundtrip."""
        # Create complex nested data
        complex_data = {
            "id": str(uuid4()),
            "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}},
            "metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]},
        }

        # Create wrapper and serialize
        wrapper = TaskWrapper(data=complex_data)
        serialized = wrapper.serialize()

        # Verify serialization
        assert isinstance(serialized, str)
        assert "测试中文" in serialized
        assert "🚀" in serialized

        # Deserialize and verify
        deserialized_wrapper = TaskWrapper.deserialize(serialized)
        assert deserialized_wrapper.data == complex_data

    def test_error_handling_invalid_json(self, test_queue):
        """Test error handling for invalid JSON in wrapped tasks."""
        # Manually create invalid JSON task (not a valid TaskWrapper JSON)
        invalid_json_task = "invalid json data"

        # Push invalid task directly to Redis
        redis_client.lpush(test_queue._queue, invalid_json_task)

        # Pull task - should fall back to string since it's not valid JSON
        task = test_queue.pull_tasks(1)
        assert task[0] == invalid_json_task

    def test_real_world_batch_processing_scenario(self, test_queue, fake):
        """Test realistic batch processing scenario."""
        # Simulate batch processing tasks
        batch_tasks = []
        for i in range(3):
            task = {
                "file_id": str(uuid4()),
                "tenant_id": test_queue._tenant_id,
                "user_id": str(uuid4()),
                "processing_config": {
                    "model": fake.random_element(["model_a", "model_b", "model_c"]),
                    "temperature": fake.random.uniform(0.1, 1.0),
                    "max_tokens": fake.random_int(1000, 4000),
                },
                "metadata": {
                    "source": fake.random_element(["upload", "api", "webhook"]),
                    "priority": fake.random_element(["low", "normal", "high"]),
                },
            }
            batch_tasks.append(task)

        # Push tasks
        test_queue.push_tasks(batch_tasks)

        # Process tasks in batches
        batch_size = 2
        processed_tasks = []

        while True:
            batch = test_queue.pull_tasks(batch_size)
            if not batch:
                break

            processed_tasks.extend(batch)

        # Verify all tasks were processed
        assert len(processed_tasks) == 3

        # Verify task structure
        for task in processed_tasks:
            assert isinstance(task, dict)
            assert "file_id" in task
            assert "tenant_id" in task
            assert "processing_config" in task
            assert "metadata" in task
            assert task["tenant_id"] == test_queue._tenant_id


class TestTenantIsolatedTaskQueueCompatibility:
    """Compatibility tests for migrating from legacy string-only queues."""

    @pytest.fixture
    def fake(self):
        """Faker instance for generating test data."""
        return Faker()

    @pytest.fixture
    def test_tenant_and_account(self, db_session_with_containers, fake):
        """Create test tenant and account for testing."""
        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        # Create tenant
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()

        return tenant, account

    def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
        """
        Test compatibility with legacy queues containing only string data.

        This simulates the scenario where Redis queues already contain string data
        from the old architecture, and we need to ensure the new code can read them.
        """
        tenant, _ = test_tenant_and_account
        queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue")

        # Simulate legacy string data in Redis queue (using old format)
        legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]

        # Manually push legacy strings directly to Redis (simulating old system)
        for legacy_string in legacy_strings:
            redis_client.lpush(queue._queue, legacy_string)

        # Verify new code can read legacy string data
        pulled_tasks = queue.pull_tasks(5)
        assert len(pulled_tasks) == 5

        # Verify all tasks are strings (not wrapped)
        for task in pulled_tasks:
            assert isinstance(task, str)
            assert task.startswith("legacy_task_")

        # Verify order (FIFO from Redis list)
        expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
        assert pulled_tasks == expected_order

    def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
        """
        Test complete migration scenario from legacy to new system.

        This simulates the real-world scenario where:
        1. Legacy system has string data in Redis
        2. New system starts processing the same queue
        3. Both legacy and new tasks coexist during migration
        4. New system can handle both formats seamlessly
        """
        tenant, _ = test_tenant_and_account
        queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue")

        # Phase 1: Legacy system has data
        legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
        redis_client.lpush(queue._queue, *legacy_tasks)

        # Phase 2: New system starts processing legacy data
        processed_legacy = []
        while True:
            tasks = queue.pull_tasks(1)
            if not tasks:
                break
            processed_legacy.extend(tasks)

        # Verify legacy data was processed correctly
        assert len(processed_legacy) == 5
        for task in processed_legacy:
            assert isinstance(task, str)
            assert task.startswith("legacy_resource_")

        # Phase 3: New system adds new tasks (mixed types)
        new_string_tasks = ["new_resource_1", "new_resource_2"]
        new_object_tasks = [
            {
                "resource_id": str(uuid4()),
                "tenant_id": tenant.id,
                "processing_type": "new_system",
                "metadata": {"version": "2.0", "features": ["ai", "ml"]},
            },
            {
                "resource_id": str(uuid4()),
                "tenant_id": tenant.id,
                "processing_type": "new_system",
                "metadata": {"version": "2.0", "features": ["ai", "ml"]},
            },
        ]

        # Push new tasks using new system
        queue.push_tasks(new_string_tasks)
        queue.push_tasks(new_object_tasks)

        # Phase 4: Process all new tasks
        processed_new = []
        while True:
            tasks = queue.pull_tasks(1)
            if not tasks:
                break
            processed_new.extend(tasks)

        # Verify new tasks were processed correctly
        assert len(processed_new) == 4

        string_tasks = [task for task in processed_new if isinstance(task, str)]
        object_tasks = [task for task in processed_new if isinstance(task, dict)]

        assert len(string_tasks) == 2
        assert len(object_tasks) == 2

        # Verify string tasks
        for task in string_tasks:
            assert task.startswith("new_resource_")

        # Verify object tasks
        for task in object_tasks:
            assert isinstance(task, dict)
            assert "resource_id" in task
            assert "tenant_id" in task
            assert task["tenant_id"] == tenant.id
            assert task["processing_type"] == "new_system"

    def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
        """
        Test error recovery when legacy queue contains malformed data.

        This ensures the new system can gracefully handle corrupted or
        malformed legacy data without crashing.
        """
        tenant, _ = test_tenant_and_account
        queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue")

        # Create mix of valid and malformed legacy data
        mixed_legacy_data = [
            "valid_legacy_task_1",
            "valid_legacy_task_2",
            "malformed_data_string",  # This should be treated as string
            "valid_legacy_task_3",
            "invalid_json_not_taskwrapper_format",  # This should fall back to string (not valid TaskWrapper JSON)
            "valid_legacy_task_4",
        ]

        # Manually push mixed data directly to Redis
        redis_client.lpush(queue._queue, *mixed_legacy_data)

        # Process all tasks
        processed_tasks = []
        while True:
            tasks = queue.pull_tasks(1)
            if not tasks:
                break
            processed_tasks.extend(tasks)

        # Verify all tasks were processed (no crashes)
        assert len(processed_tasks) == 6

        # Verify all tasks are strings (malformed data falls back to string)
        for task in processed_tasks:
            assert isinstance(task, str)

        # Verify valid tasks are preserved
        valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
        assert len(valid_tasks) == 4

        # Verify malformed data is handled gracefully
        malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
        assert len(malformed_tasks) == 2
        assert "malformed_data_string" in malformed_tasks
        assert "invalid_json_not_taskwrapper_format" in malformed_tasks

@@ -1,17 +1,33 @@
from dataclasses import asdict
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_task import (
    _document_indexing,  # Core function
    _document_indexing_with_tenant_queue,  # Tenant queue wrapper function
    document_indexing_task,  # Deprecated old interface
    normal_document_indexing_task,  # New normal task
    priority_document_indexing_task,  # New priority task
)


class TestDocumentIndexingTask:
    """Integration tests for document_indexing_task using testcontainers."""
class TestDocumentIndexingTasks:
    """Integration tests for document indexing tasks using testcontainers.

    This test class covers:
    - Core _document_indexing function
    - Deprecated document_indexing_task function
    - New normal_document_indexing_task function
    - New priority_document_indexing_task function
    - Tenant queue wrapper _document_indexing_with_tenant_queue function
    """
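    # Call-chain summary (hypothetical, inferred from the assertions in this class rather
    # than quoted from the implementation):
    #   document_indexing_task(dataset_id, document_ids)             -> _document_indexing(...)
    #   normal_document_indexing_task(tenant_id, dataset_id, ids)    -> _document_indexing_with_tenant_queue(...)
    #   priority_document_indexing_task(tenant_id, dataset_id, ids)  -> _document_indexing_with_tenant_queue(...)
    #   _document_indexing_with_tenant_queue(...) runs _document_indexing(...) and then re-dispatches
    #   at most one waiting DocumentTask from TenantIsolatedTaskQueue via task_func.delay(**task).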

    @pytest.fixture
    def mock_external_service_dependencies(self):

@@ -224,7 +240,7 @@ class TestDocumentIndexingTask:
        document_ids = [doc.id for doc in documents]

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify the expected outcomes
        # Verify indexing runner was called correctly

@@ -232,10 +248,11 @@ class TestDocumentIndexingTask:
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were updated to parsing status
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

        # Verify the run method was called with correct documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args

@@ -261,7 +278,7 @@ class TestDocumentIndexingTask:
        document_ids = [fake.uuid4() for _ in range(3)]

        # Act: Execute the task with non-existent dataset
        document_indexing_task(non_existent_dataset_id, document_ids)
        _document_indexing(non_existent_dataset_id, document_ids)

        # Assert: Verify no processing occurred
        mock_external_service_dependencies["indexing_runner"].assert_not_called()

@@ -291,17 +308,18 @@ class TestDocumentIndexingTask:
        all_document_ids = existing_document_ids + non_existent_document_ids

        # Act: Execute the task with mixed document IDs
        document_indexing_task(dataset.id, all_document_ids)
        _document_indexing(dataset.id, all_document_ids)

        # Assert: Verify only existing documents were processed
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify only existing documents were updated
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in existing_document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

        # Verify the run method was called with only existing documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args

@@ -333,7 +351,7 @@ class TestDocumentIndexingTask:
        )

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify exception was handled gracefully
        # The task should complete without raising exceptions

@@ -341,10 +359,11 @@ class TestDocumentIndexingTask:
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were still updated to parsing status before the exception
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing closes the session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

    def test_document_indexing_task_mixed_document_states(
        self, db_session_with_containers, mock_external_service_dependencies

@@ -407,17 +426,18 @@ class TestDocumentIndexingTask:
        document_ids = [doc.id for doc in all_documents]

        # Act: Execute the task with mixed document states
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify processing
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify all documents were updated to parsing status
        for document in all_documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

        # Verify the run method was called with all documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args

@@ -470,15 +490,16 @@ class TestDocumentIndexingTask:
        document_ids = [doc.id for doc in all_documents]

        # Act: Execute the task with too many documents for sandbox plan
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify error handling
        for document in all_documents:
            db.session.refresh(document)
            assert document.indexing_status == "error"
            assert document.error is not None
            assert "batch upload" in document.error
            assert document.stopped_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "error"
            assert updated_document.error is not None
            assert "batch upload" in updated_document.error
            assert updated_document.stopped_at is not None

        # Verify no indexing runner was called
        mock_external_service_dependencies["indexing_runner"].assert_not_called()

@@ -503,17 +524,18 @@ class TestDocumentIndexingTask:
        document_ids = [doc.id for doc in documents]

        # Act: Execute the task with billing disabled
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify successful processing
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were updated to parsing status
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

    def test_document_indexing_task_document_is_paused_error(
        self, db_session_with_containers, mock_external_service_dependencies

@@ -541,7 +563,7 @@ class TestDocumentIndexingTask:
        )

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)
        _document_indexing(dataset.id, document_ids)

        # Assert: Verify exception was handled gracefully
        # The task should complete without raising exceptions

@@ -549,7 +571,317 @@ class TestDocumentIndexingTask:
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were still updated to parsing status before the exception
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

    # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
    def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test document_indexing_task basic functionality.

        This test verifies:
        - Task function calls the wrapper correctly
        - Basic parameter passing works
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        document_ids = [doc.id for doc in documents]

        # Act: Execute the deprecated task (it only takes 2 parameters)
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    def test_normal_document_indexing_task_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test normal_document_indexing_task basic functionality.

        This test verifies:
        - Task function calls the wrapper correctly
        - Basic parameter passing works
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        document_ids = [doc.id for doc in documents]
        tenant_id = dataset.tenant_id

        # Act: Execute the new normal task
        normal_document_indexing_task(tenant_id, dataset.id, document_ids)

        # Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    def test_priority_document_indexing_task_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test priority_document_indexing_task basic functionality.

        This test verifies:
        - Task function calls the wrapper correctly
        - Basic parameter passing works
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        document_ids = [doc.id for doc in documents]
        tenant_id = dataset.tenant_id

        # Act: Execute the new priority task
        priority_document_indexing_task(tenant_id, dataset.id, document_ids)

        # Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

    def test_document_indexing_with_tenant_queue_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test _document_indexing_with_tenant_queue function with no waiting tasks.

        This test verifies:
        - Core indexing logic execution (same as _document_indexing)
        - Tenant queue cleanup when no waiting tasks
        - Task function parameter passing
        - Queue management after processing
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=2
        )
        document_ids = [doc.id for doc in documents]
        tenant_id = dataset.tenant_id

        # Mock the task function
        from unittest.mock import MagicMock

        mock_task_func = MagicMock()

        # Act: Execute the wrapper function
        _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)

        # Assert: Verify core processing occurred (same as _document_indexing)
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were updated (same as _document_indexing)
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

        # Verify the run method was called with correct documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
        assert call_args is not None
        processed_documents = call_args[0][0]
        assert len(processed_documents) == 2

        # Verify task function was not called (no waiting tasks)
        mock_task_func.delay.assert_not_called()

    def test_document_indexing_with_tenant_queue_with_waiting_tasks(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.

        This test verifies:
        - Core indexing logic execution
        - Real Redis-based tenant queue processing of waiting tasks
        - Task function calls for waiting tasks
        - Queue management with multiple tasks using actual Redis operations
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        document_ids = [doc.id for doc in documents]
        tenant_id = dataset.tenant_id
        dataset_id = dataset.id

        # Mock the task function
        from unittest.mock import MagicMock

        mock_task_func = MagicMock()

        # Use real Redis for TenantIsolatedTaskQueue
        from core.rag.pipeline.queue import TenantIsolatedTaskQueue

        # Create real queue instance
        queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")

        # Add waiting tasks to the real Redis queue
        waiting_tasks = [
            DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
            DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]),
        ]
        # Convert DocumentTask objects to dictionaries for serialization
        waiting_task_dicts = [asdict(task) for task in waiting_tasks]
        queue.push_tasks(waiting_task_dicts)

        # Act: Execute the wrapper function
        _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)

        # Assert: Verify core processing occurred
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify the task function was re-dispatched for a waiting task (one task is pulled per run)
        assert mock_task_func.delay.call_count == 1

        # Verify correct parameters for the call
        calls = mock_task_func.delay.call_args_list
        assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}

        # Verify one waiting task remains in the queue (only one was pulled this run)
        remaining_tasks = queue.pull_tasks(count=10)  # Pull more than we added
        assert len(remaining_tasks) == 1

    def test_document_indexing_with_tenant_queue_error_handling(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test error handling in _document_indexing_with_tenant_queue using real Redis.

        This test verifies:
        - Exception handling during core processing
        - Tenant queue cleanup even on errors using real Redis
        - Proper error logging
        - Function completes without raising exceptions
        - Queue management continues despite core processing errors
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        document_ids = [doc.id for doc in documents]
        tenant_id = dataset.tenant_id
        dataset_id = dataset.id

        # Mock IndexingRunner to raise an exception
        mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")

        # Mock the task function
        from unittest.mock import MagicMock

        mock_task_func = MagicMock()

        # Use real Redis for TenantIsolatedTaskQueue
        from core.rag.pipeline.queue import TenantIsolatedTaskQueue

        # Create real queue instance
        queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")

        # Add waiting task to the real Redis queue
        waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
        queue.push_tasks([asdict(waiting_task)])

        # Act: Execute the wrapper function
        _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)

        # Assert: Verify error was handled gracefully
        # The function should not raise exceptions
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were still updated to parsing status before the exception
        # Re-query documents from database since _document_indexing uses a different session
        for doc_id in document_ids:
            updated_document = db.session.query(Document).where(Document.id == doc_id).first()
            assert updated_document.indexing_status == "parsing"
            assert updated_document.processing_started_at is not None

        # Verify waiting task was still processed despite core processing error
        mock_task_func.delay.assert_called_once()

        # Verify correct parameters for the call
        call = mock_task_func.delay.call_args
        assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}

        # Verify queue is empty after processing (task was pulled)
        remaining_tasks = queue.pull_tasks(count=10)
        assert len(remaining_tasks) == 0

    def test_document_indexing_with_tenant_queue_tenant_isolation(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.

        This test verifies:
        - Different tenants have isolated queues
        - Tasks from one tenant don't affect another tenant's queue
        - Queue operations are properly scoped to tenant
        """
        # Arrange: Create test data for two different tenants
        dataset1, documents1 = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )
        dataset2, documents2 = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=1
        )

        tenant1_id = dataset1.tenant_id
        tenant2_id = dataset2.tenant_id
        dataset1_id = dataset1.id
        dataset2_id = dataset2.id
        document_ids1 = [doc.id for doc in documents1]
        document_ids2 = [doc.id for doc in documents2]

        # Mock the task function
        from unittest.mock import MagicMock

        mock_task_func = MagicMock()

        # Use real Redis for TenantIsolatedTaskQueue
        from core.rag.pipeline.queue import TenantIsolatedTaskQueue

        # Create queue instances for both tenants
        queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing")
        queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing")

        # Add waiting tasks to both queues
        waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
        waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])

        queue1.push_tasks([asdict(waiting_task1)])
        queue2.push_tasks([asdict(waiting_task2)])

        # Act: Execute the wrapper function for tenant1 only
        _document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)

        # Assert: Verify core processing occurred for tenant1
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify only tenant1's waiting task was processed
        mock_task_func.delay.assert_called_once()
        call = mock_task_func.delay.call_args
        assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}

        # Verify tenant1's queue is empty
        remaining_tasks1 = queue1.pull_tasks(count=10)
        assert len(remaining_tasks1) == 0

        # Verify tenant2's queue still has its task (isolation)
        remaining_tasks2 = queue2.pull_tasks(count=10)
        assert len(remaining_tasks2) == 1

        # Verify queue keys are different
        assert queue1._queue != queue2._queue
        assert queue1._task_key != queue2._task_key

@@ -0,0 +1,936 @@
import json
import uuid
from unittest.mock import patch

import pytest
from faker import Faker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
    priority_rag_pipeline_run_task,
    run_single_rag_pipeline_task,
)
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task


class TestRagPipelineRunTasks:
    """Integration tests for RAG pipeline run tasks using testcontainers.

    This test class covers:
    - priority_rag_pipeline_run_task function
    - rag_pipeline_run_task function
    - run_single_rag_pipeline_task function
    - Real Redis-based TenantIsolatedTaskQueue operations
    - PipelineGenerator._generate method mocking and parameter validation
    - File operations and cleanup
    - Error handling and queue management
    """
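    # Assumed task flow exercised below (a summary based on the mocks and assertions in this
    # class, not quoted from the implementation):
    #   1. FileService.get_file_content(file_id) returns a JSON list of serialized RagPipelineInvokeEntity objects.
    #   2. PipelineGenerator._generate(...) is invoked once per entity.
    #   3. FileService.delete_file(file_id) removes the payload file after processing.
    #   4. At most one waiting file id is pulled from TenantIsolatedTaskQueue(tenant_id, "pipeline")
    #      and re-dispatched via the task's .delay(rag_pipeline_invoke_entities_file_id=..., tenant_id=...).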

    @pytest.fixture
    def mock_pipeline_generator(self):
        """Mock PipelineGenerator._generate method."""
        with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
            # Mock the _generate method to return a simple response
            mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}}
            yield mock_generate

    @pytest.fixture
    def mock_file_service(self):
        """Mock FileService for file operations."""
        with (
            patch("services.file_service.FileService.get_file_content") as mock_get_content,
            patch("services.file_service.FileService.delete_file") as mock_delete_file,
        ):
            yield {
                "get_content": mock_get_content,
                "delete_file": mock_delete_file,
            }

    def _create_test_pipeline_and_workflow(self, db_session_with_containers):
        """
        Helper method to create test pipeline and workflow for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure

        Returns:
            tuple: (account, tenant, pipeline, workflow) - Created entities
        """
        fake = Faker()

        # Create account and tenant
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db.session.add(account)
        db.session.commit()

        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        # Create workflow
        workflow = Workflow(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            app_id=str(uuid.uuid4()),
            type="workflow",
            version="draft",
            graph="{}",
            features="{}",
            marked_name=fake.company(),
            marked_comment=fake.text(max_nb_chars=100),
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            rag_pipeline_variables=[],
        )
        db.session.add(workflow)
        db.session.commit()

        # Create pipeline
        pipeline = Pipeline(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            workflow_id=workflow.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            created_by=account.id,
        )
        db.session.add(pipeline)
        db.session.commit()

        # Refresh entities to ensure they're properly loaded
        db.session.refresh(account)
        db.session.refresh(tenant)
        db.session.refresh(workflow)
        db.session.refresh(pipeline)

        return account, tenant, pipeline, workflow

    def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
        """
        Helper method to create RAG pipeline invoke entities for testing.

        Args:
            account: Account instance
            tenant: Tenant instance
            pipeline: Pipeline instance
            workflow: Workflow instance
            count: Number of entities to create

        Returns:
            list: List of RagPipelineInvokeEntity instances
        """
        fake = Faker()
        entities = []

        for i in range(count):
            # Create application generate entity
            app_config = {
                "app_id": str(uuid.uuid4()),
                "app_name": fake.company(),
                "mode": "workflow",
                "workflow_id": workflow.id,
                "tenant_id": tenant.id,
                "app_mode": "workflow",
            }

            application_generate_entity = {
                "task_id": str(uuid.uuid4()),
                "app_config": app_config,
                "inputs": {"query": f"Test query {i}"},
                "files": [],
                "user_id": account.id,
                "stream": False,
                "invoke_from": "published",
                "workflow_execution_id": str(uuid.uuid4()),
                "pipeline_config": {
                    "app_id": str(uuid.uuid4()),
                    "app_name": fake.company(),
                    "mode": "workflow",
                    "workflow_id": workflow.id,
                    "tenant_id": tenant.id,
                    "app_mode": "workflow",
                },
                "datasource_type": "upload_file",
                "datasource_info": {},
                "dataset_id": str(uuid.uuid4()),
                "batch": "test_batch",
            }

            entity = RagPipelineInvokeEntity(
                pipeline_id=pipeline.id,
                application_generate_entity=application_generate_entity,
                user_id=account.id,
                tenant_id=tenant.id,
                workflow_id=workflow.id,
                streaming=False,
                workflow_execution_id=str(uuid.uuid4()),
                workflow_thread_pool_id=str(uuid.uuid4()),
            )
            entities.append(entity)

        return entities

    def _create_file_content_for_entities(self, entities):
        """
        Helper method to create file content for RAG pipeline invoke entities.

        Args:
            entities: List of RagPipelineInvokeEntity instances

        Returns:
            str: JSON string containing serialized entities
        """
        entities_data = [entity.model_dump() for entity in entities]
        return json.dumps(entities_data)

    def test_priority_rag_pipeline_run_task_success(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test successful priority RAG pipeline run task execution.

        This test verifies:
        - Task execution with multiple RAG pipeline invoke entities
        - File content retrieval and parsing
        - PipelineGenerator._generate method calls with correct parameters
        - Thread pool execution
        - File cleanup after execution
        - Queue management with no waiting tasks
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Act: Execute the priority task
        priority_rag_pipeline_run_task(file_id, tenant.id)

        # Assert: Verify expected outcomes
        # Verify file operations
        mock_file_service["get_content"].assert_called_once_with(file_id)
        mock_file_service["delete_file"].assert_called_once_with(file_id)

        # Verify PipelineGenerator._generate was called for each entity
        assert mock_pipeline_generator.call_count == 2

        # Verify call parameters for each entity
        calls = mock_pipeline_generator.call_args_list
        for call in calls:
            call_kwargs = call[1]  # Get keyword arguments
            assert call_kwargs["pipeline"].id == pipeline.id
            assert call_kwargs["workflow_id"] == workflow.id
            assert call_kwargs["user"].id == account.id
            assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
            assert call_kwargs["streaming"] == False
            assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)

    def test_rag_pipeline_run_task_success(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test successful regular RAG pipeline run task execution.

        This test verifies:
        - Task execution with multiple RAG pipeline invoke entities
        - File content retrieval and parsing
        - PipelineGenerator._generate method calls with correct parameters
        - Thread pool execution
        - File cleanup after execution
        - Queue management with no waiting tasks
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Act: Execute the regular task
        rag_pipeline_run_task(file_id, tenant.id)

        # Assert: Verify expected outcomes
        # Verify file operations
        mock_file_service["get_content"].assert_called_once_with(file_id)
        mock_file_service["delete_file"].assert_called_once_with(file_id)

        # Verify PipelineGenerator._generate was called for each entity
        assert mock_pipeline_generator.call_count == 3

        # Verify call parameters for each entity
        calls = mock_pipeline_generator.call_args_list
        for call in calls:
            call_kwargs = call[1]  # Get keyword arguments
            assert call_kwargs["pipeline"].id == pipeline.id
            assert call_kwargs["workflow_id"] == workflow.id
            assert call_kwargs["user"].id == account.id
            assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
            assert call_kwargs["streaming"] == False
            assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||

    def test_priority_rag_pipeline_run_task_with_waiting_tasks(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test priority RAG pipeline run task with waiting tasks in queue using real Redis.

        This test verifies:
        - Core task execution
        - Real Redis-based tenant queue processing of waiting tasks
        - Task function calls for waiting tasks
        - Queue management with multiple tasks using actual Redis operations
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting tasks to the real Redis queue
        waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
        queue.push_tasks(waiting_file_ids)

        # Mock the task function calls
        with patch(
            "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
        ) as mock_delay:
            # Act: Execute the priority task
            priority_rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify core processing occurred
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_file_service["delete_file"].assert_called_once_with(file_id)
            assert mock_pipeline_generator.call_count == 1

            # Verify waiting tasks were processed; only one task is pulled at a time by default
            assert mock_delay.call_count == 1

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue still has remaining tasks (only 1 was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 1  # 2 original - 1 pulled = 1 remaining
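
    # NOTE (illustrative, not part of the test suite): the assertions above rely on the task
    # draining its tenant queue one entry at a time. A rough sketch of the hand-off the test
    # expects, assuming the task re-dispatches itself for the next waiting file id (the real
    # implementation lives in tasks.rag_pipeline.priority_rag_pipeline_run_task):
    #
    #     queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
    #     next_file_ids = queue.pull_tasks()          # default: at most one task
    #     for next_file_id in next_file_ids:
    #         priority_rag_pipeline_run_task.delay(
    #             rag_pipeline_invoke_entities_file_id=next_file_id,
    #             tenant_id=tenant_id,
    #         )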

    def test_rag_pipeline_run_task_legacy_compatibility(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.

        This test simulates the scenario where:
        - Old code writes file IDs directly to Redis list using lpush
        - New worker processes these legacy queue entries
        - Ensures backward compatibility during deployment transition

        Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
        New format: TenantIsolatedTaskQueue.push_tasks([file_id])
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Simulate legacy Redis queue format - direct file IDs in Redis list
        from extensions.ext_redis import redis_client

        # Legacy queue key format (old code)
        legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
        legacy_task_key = f"tenant_pipeline_task:{tenant.id}"

        # Add legacy format data to Redis (simulating old code behavior)
        legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
        for file_id_legacy in legacy_file_ids:
            redis_client.lpush(legacy_queue_key, file_id_legacy)

        # Set the task key to indicate there are waiting tasks (legacy behavior)
        redis_client.set(legacy_task_key, 1, ex=60 * 60)

        # Mock the task function calls
        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
            # Act: Execute the regular task with new code but legacy queue data
            rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify core processing occurred
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_file_service["delete_file"].assert_called_once_with(file_id)
            assert mock_pipeline_generator.call_count == 1

            # Verify waiting tasks were processed; only one task is pulled at a time by default
            assert mock_delay.call_count == 1

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify that new code can process legacy queue entries
            # The new TenantIsolatedTaskQueue should be able to read from the legacy format
            queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

            # Verify queue still has remaining tasks (only 1 was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 2  # 3 original - 1 pulled = 2 remaining

            # Cleanup: Remove legacy test data
            redis_client.delete(legacy_queue_key)
            redis_client.delete(legacy_task_key)
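
    # NOTE (illustrative, not part of the test suite): the compatibility test above depends on
    # the new queue reading entries written by the old code path. A sketch of the two producer
    # formats involved, with the key layout taken from the legacy constants used above:
    #
    #     # legacy producer (old code): plain file ids pushed onto a raw Redis list
    #     redis_client.lpush(f"tenant_self_pipeline_task_queue:{tenant_id}", upload_file_id)
    #
    #     # new producer: the same logical queue, addressed through the abstraction
    #     TenantIsolatedTaskQueue(tenant_id, "pipeline").push_tasks([upload_file_id])
    #
    # On the consumer side, pull_tasks() is assumed to accept both plain string entries and the
    # wrapped/serialized format (see TaskWrapper), so workers deployed mid-migration keep
    # draining old entries.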

    def test_rag_pipeline_run_task_with_waiting_tasks(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test regular RAG pipeline run task with waiting tasks in queue using real Redis.

        This test verifies:
        - Core task execution
        - Real Redis-based tenant queue processing of waiting tasks
        - Task function calls for waiting tasks
        - Queue management with multiple tasks using actual Redis operations
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting tasks to the real Redis queue
        waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
        queue.push_tasks(waiting_file_ids)

        # Mock the task function calls
        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
            # Act: Execute the regular task
            rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify core processing occurred
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_file_service["delete_file"].assert_called_once_with(file_id)
            assert mock_pipeline_generator.call_count == 1

            # Verify waiting tasks were processed; only one task is pulled at a time by default
            assert mock_delay.call_count == 1

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue still has remaining tasks (only 1 was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 2  # 3 original - 1 pulled = 2 remaining

    def test_priority_rag_pipeline_run_task_error_handling(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test error handling in priority RAG pipeline run task using real Redis.

        This test verifies:
        - Exception handling during core processing
        - Tenant queue cleanup even on errors using real Redis
        - Proper error logging
        - Function completes without raising exceptions
        - Queue management continues despite core processing errors
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Mock PipelineGenerator to raise an exception
        mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting task to the real Redis queue
        waiting_file_id = str(uuid.uuid4())
        queue.push_tasks([waiting_file_id])

        # Mock the task function calls
        with patch(
            "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
        ) as mock_delay:
            # Act: Execute the priority task (should not raise exception)
            priority_rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify error was handled gracefully
            # The function should not raise exceptions
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_file_service["delete_file"].assert_called_once_with(file_id)
            assert mock_pipeline_generator.call_count == 1

            # Verify waiting task was still processed despite core processing error
            mock_delay.assert_called_once()

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue is empty after processing (task was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 0
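
    # NOTE (illustrative, not part of the test suite): taken together, the error-handling tests
    # here and the *_file_not_found tests below pin down the task's failure behaviour: pipeline
    # failures are swallowed (file deleted, function returns), file-read failures propagate, and
    # in both cases the next waiting task is still dispatched. One control flow consistent with
    # those assertions, using hypothetical helper names:
    #
    #     try:
    #         file_content = get_file_content(file_id)          # errors here propagate
    #         try:
    #             run_pipelines(parse_entities(file_content))   # errors here are logged only
    #         except Exception:
    #             logger.exception("rag pipeline run failed")
    #         delete_file(file_id)
    #     finally:
    #         dispatch_next_waiting_task(tenant_id)             # tenant queue keeps draining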

    def test_rag_pipeline_run_task_error_handling(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test error handling in regular RAG pipeline run task using real Redis.

        This test verifies:
        - Exception handling during core processing
        - Tenant queue cleanup even on errors using real Redis
        - Proper error logging
        - Function completes without raising exceptions
        - Queue management continues despite core processing errors
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        file_content = self._create_file_content_for_entities(entities)

        # Mock file service
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].return_value = file_content

        # Mock PipelineGenerator to raise an exception
        mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting task to the real Redis queue
        waiting_file_id = str(uuid.uuid4())
        queue.push_tasks([waiting_file_id])

        # Mock the task function calls
        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
            # Act: Execute the regular task (should not raise exception)
            rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify error was handled gracefully
            # The function should not raise exceptions
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_file_service["delete_file"].assert_called_once_with(file_id)
            assert mock_pipeline_generator.call_count == 1

            # Verify waiting task was still processed despite core processing error
            mock_delay.assert_called_once()

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue is empty after processing (task was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 0

    def test_priority_rag_pipeline_run_task_tenant_isolation(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test tenant isolation in priority RAG pipeline run task using real Redis.

        This test verifies:
        - Different tenants have isolated queues
        - Tasks from one tenant don't affect another tenant's queue
        - Queue operations are properly scoped to tenant
        """
        # Arrange: Create test data for two different tenants
        account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
        account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)

        entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
        entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)

        file_content1 = self._create_file_content_for_entities(entities1)
        file_content2 = self._create_file_content_for_entities(entities2)

        # Mock file service
        file_id1 = str(uuid.uuid4())
        file_id2 = str(uuid.uuid4())
        mock_file_service["get_content"].side_effect = [file_content1, file_content2]

        # Use real Redis for TenantIsolatedTaskQueue
        queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
        queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")

        # Add waiting tasks to both queues
        waiting_file_id1 = str(uuid.uuid4())
        waiting_file_id2 = str(uuid.uuid4())

        queue1.push_tasks([waiting_file_id1])
        queue2.push_tasks([waiting_file_id2])

        # Mock the task function calls
        with patch(
            "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
        ) as mock_delay:
            # Act: Execute the priority task for tenant1 only
            priority_rag_pipeline_run_task(file_id1, tenant1.id)

            # Assert: Verify core processing occurred for tenant1
            assert mock_file_service["get_content"].call_count == 1
            assert mock_file_service["delete_file"].call_count == 1
            assert mock_pipeline_generator.call_count == 1

            # Verify only tenant1's waiting task was processed
            mock_delay.assert_called_once()
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
            assert call_kwargs.get("tenant_id") == tenant1.id

            # Verify tenant1's queue is empty
            remaining_tasks1 = queue1.pull_tasks(count=10)
            assert len(remaining_tasks1) == 0

            # Verify tenant2's queue still has its task (isolation)
            remaining_tasks2 = queue2.pull_tasks(count=10)
            assert len(remaining_tasks2) == 1

            # Verify queue keys are different
            assert queue1._queue != queue2._queue
            assert queue1._task_key != queue2._task_key
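
    # NOTE (illustrative, not part of the test suite): the isolation assertions above compare
    # the private _queue and _task_key attributes of two queues built from different tenant
    # ids. The sketch below shows the property the tests rely on; the concrete key strings are
    # an implementation detail of TenantIsolatedTaskQueue and are not asserted:
    #
    #     q1 = TenantIsolatedTaskQueue(tenant1_id, "pipeline")
    #     q2 = TenantIsolatedTaskQueue(tenant2_id, "pipeline")
    #     assert q1._queue != q2._queue          # tenant id is part of the queue key
    #     assert q1._task_key != q2._task_key    # and of the waiting-task marker key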

    def test_rag_pipeline_run_task_tenant_isolation(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test tenant isolation in regular RAG pipeline run task using real Redis.

        This test verifies:
        - Different tenants have isolated queues
        - Tasks from one tenant don't affect another tenant's queue
        - Queue operations are properly scoped to tenant
        """
        # Arrange: Create test data for two different tenants
        account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
        account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)

        entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
        entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)

        file_content1 = self._create_file_content_for_entities(entities1)
        file_content2 = self._create_file_content_for_entities(entities2)

        # Mock file service
        file_id1 = str(uuid.uuid4())
        file_id2 = str(uuid.uuid4())
        mock_file_service["get_content"].side_effect = [file_content1, file_content2]

        # Use real Redis for TenantIsolatedTaskQueue
        queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
        queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")

        # Add waiting tasks to both queues
        waiting_file_id1 = str(uuid.uuid4())
        waiting_file_id2 = str(uuid.uuid4())

        queue1.push_tasks([waiting_file_id1])
        queue2.push_tasks([waiting_file_id2])

        # Mock the task function calls
        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
            # Act: Execute the regular task for tenant1 only
            rag_pipeline_run_task(file_id1, tenant1.id)

            # Assert: Verify core processing occurred for tenant1
            assert mock_file_service["get_content"].call_count == 1
            assert mock_file_service["delete_file"].call_count == 1
            assert mock_pipeline_generator.call_count == 1

            # Verify only tenant1's waiting task was processed
            mock_delay.assert_called_once()
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
            assert call_kwargs.get("tenant_id") == tenant1.id

            # Verify tenant1's queue is empty
            remaining_tasks1 = queue1.pull_tasks(count=10)
            assert len(remaining_tasks1) == 0

            # Verify tenant2's queue still has its task (isolation)
            remaining_tasks2 = queue2.pull_tasks(count=10)
            assert len(remaining_tasks2) == 1

            # Verify queue keys are different
            assert queue1._queue != queue2._queue
            assert queue1._task_key != queue2._task_key

    def test_run_single_rag_pipeline_task_success(
        self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
    ):
        """
        Test successful run_single_rag_pipeline_task execution.

        This test verifies:
        - Single RAG pipeline task execution within Flask app context
        - Entity validation and database queries
        - PipelineGenerator._generate method call with correct parameters
        - Proper Flask context handling
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
        entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
        entity_data = entities[0].model_dump()

        # Act: Execute the single task
        with flask_app_with_containers.app_context():
            run_single_rag_pipeline_task(entity_data, flask_app_with_containers)

        # Assert: Verify expected outcomes
        # Verify PipelineGenerator._generate was called
        assert mock_pipeline_generator.call_count == 1

        # Verify call parameters
        call = mock_pipeline_generator.call_args
        call_kwargs = call[1]  # Get keyword arguments
        assert call_kwargs["pipeline"].id == pipeline.id
        assert call_kwargs["workflow_id"] == workflow.id
        assert call_kwargs["user"].id == account.id
        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
        assert call_kwargs["streaming"] is False
        assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
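
    # NOTE (illustrative, not part of the test suite): run_single_rag_pipeline_task takes a
    # plain dict (the invoke entity dumped via model_dump) plus the Flask app so it can push an
    # app context from worker threads. A minimal sketch of how the parent task is assumed to
    # fan entities out to it (the thread-pool detail is an assumption based on the docstrings
    # above):
    #
    #     with ThreadPoolExecutor() as executor:
    #         for entity in entities:
    #             executor.submit(run_single_rag_pipeline_task, entity.model_dump(), flask_app)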

    def test_run_single_rag_pipeline_task_entity_validation_error(
        self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
    ):
        """
        Test run_single_rag_pipeline_task with invalid entity data.

        This test verifies:
        - Proper error handling for invalid entity data
        - Exception logging
        - Function raises ValueError for missing entities
        """
        # Arrange: Create entity data with valid UUIDs but non-existent entities
        fake = Faker()
        invalid_entity_data = {
            "pipeline_id": str(uuid.uuid4()),
            "application_generate_entity": {
                "app_config": {
                    "app_id": str(uuid.uuid4()),
                    "app_name": "Test App",
                    "mode": "workflow",
                    "workflow_id": str(uuid.uuid4()),
                },
                "inputs": {"query": "Test query"},
                "query": "Test query",
                "response_mode": "blocking",
                "user": str(uuid.uuid4()),
                "files": [],
                "conversation_id": str(uuid.uuid4()),
            },
            "user_id": str(uuid.uuid4()),
            "tenant_id": str(uuid.uuid4()),
            "workflow_id": str(uuid.uuid4()),
            "streaming": False,
            "workflow_execution_id": str(uuid.uuid4()),
            "workflow_thread_pool_id": str(uuid.uuid4()),
        }

        # Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
        with flask_app_with_containers.app_context():
            with pytest.raises(ValueError, match="Account .* not found"):
                run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)

        # Assert: Pipeline generator should not be called
        mock_pipeline_generator.assert_not_called()

    def test_run_single_rag_pipeline_task_database_entity_not_found(
        self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
    ):
        """
        Test run_single_rag_pipeline_task with non-existent database entities.

        This test verifies:
        - Proper error handling for missing database entities
        - Exception logging
        - Function raises ValueError for missing entities
        """
        # Arrange: Create test data with non-existent IDs
        fake = Faker()
        entity_data = {
            "pipeline_id": str(uuid.uuid4()),
            "application_generate_entity": {
                "app_config": {
                    "app_id": str(uuid.uuid4()),
                    "app_name": "Test App",
                    "mode": "workflow",
                    "workflow_id": str(uuid.uuid4()),
                },
                "inputs": {"query": "Test query"},
                "query": "Test query",
                "response_mode": "blocking",
                "user": str(uuid.uuid4()),
                "files": [],
                "conversation_id": str(uuid.uuid4()),
            },
            "user_id": str(uuid.uuid4()),
            "tenant_id": str(uuid.uuid4()),
            "workflow_id": str(uuid.uuid4()),
            "streaming": False,
            "workflow_execution_id": str(uuid.uuid4()),
            "workflow_thread_pool_id": str(uuid.uuid4()),
        }

        # Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
        with flask_app_with_containers.app_context():
            with pytest.raises(ValueError, match="Account .* not found"):
                run_single_rag_pipeline_task(entity_data, flask_app_with_containers)

        # Assert: Pipeline generator should not be called
        mock_pipeline_generator.assert_not_called()

    def test_priority_rag_pipeline_run_task_file_not_found(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test priority RAG pipeline run task with non-existent file.

        This test verifies:
        - Proper error handling for missing files
        - Exception logging
        - Function raises Exception for file errors
        - Queue management continues despite file errors
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)

        # Mock file service to raise exception
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].side_effect = Exception("File not found")

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting task to the real Redis queue
        waiting_file_id = str(uuid.uuid4())
        queue.push_tasks([waiting_file_id])

        # Mock the task function calls
        with patch(
            "tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
        ) as mock_delay:
            # Act & Assert: Execute the priority task (should raise Exception)
            with pytest.raises(Exception, match="File not found"):
                priority_rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify error was handled gracefully
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_pipeline_generator.assert_not_called()

            # Verify waiting task was still processed despite file error
            mock_delay.assert_called_once()

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue is empty after processing (task was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 0

    def test_rag_pipeline_run_task_file_not_found(
        self, db_session_with_containers, mock_pipeline_generator, mock_file_service
    ):
        """
        Test regular RAG pipeline run task with non-existent file.

        This test verifies:
        - Proper error handling for missing files
        - Exception logging
        - Function raises Exception for file errors
        - Queue management continues despite file errors
        """
        # Arrange: Create test data
        account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)

        # Mock file service to raise exception
        file_id = str(uuid.uuid4())
        mock_file_service["get_content"].side_effect = Exception("File not found")

        # Use real Redis for TenantIsolatedTaskQueue
        queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")

        # Add waiting task to the real Redis queue
        waiting_file_id = str(uuid.uuid4())
        queue.push_tasks([waiting_file_id])

        # Mock the task function calls
        with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
            # Act & Assert: Execute the regular task (should raise Exception)
            with pytest.raises(Exception, match="File not found"):
                rag_pipeline_run_task(file_id, tenant.id)

            # Assert: Verify error was handled gracefully
            mock_file_service["get_content"].assert_called_once_with(file_id)
            mock_pipeline_generator.assert_not_called()

            # Verify waiting task was still processed despite file error
            mock_delay.assert_called_once()

            # Verify correct parameters for the call
            call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
            assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
            assert call_kwargs.get("tenant_id") == tenant.id

            # Verify queue is empty after processing (task was pulled)
            remaining_tasks = queue.pull_tasks(count=10)
            assert len(remaining_tasks) == 0