mirror of
https://github.com/langgenius/dify.git
synced 2026-03-06 06:35:24 -05:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
@@ -10,26 +11,34 @@ from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app_factory import create_app
|
||||
from configs.app_config import DifyConfig
|
||||
from extensions.ext_database import db
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
# NOTE(review): "_DEFUALT_TEST_ENV" misspells "DEFAULT"; the typo is only
# flagged here because renaming it would have to touch every reference.
_DEFUALT_TEST_ENV = ".env"
# Env file with vector-database-specific test settings; loaded after the
# default file so its values take precedence (see _load_env).
_DEFAULT_VDB_TEST_ENV = "vdb.env"

_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Loading the .env file if it exists
|
||||
def _load_env() -> None:
    """Load test environment files into the process environment.

    The file paths can be overridden via the ``DIFY_TEST_ENV_FILE`` and
    ``DIFY_VDB_TEST_ENV_FILE`` environment variables; they default to
    ``.env`` / ``vdb.env`` next to this file. Files later in the list are
    loaded with ``override=True``, so their values take priority.

    Fixes a botched merge: the original computed ``env_file_paths`` twice
    (the second assignment clobbered the env-var-aware one) and ran two
    half-duplicated loops that called ``load_dotenv`` with two different
    loop variables.
    """
    current_file_path = pathlib.Path(__file__).absolute()
    # Items later in the list have higher precedence.
    env_file_paths = [
        os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV)),
        os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)),
    ]

    for env_path_str in env_file_paths:
        if not pathlib.Path(env_path_str).exists():
            _logger.warning("specified configuration file %s not exist", env_path_str)
            # Nothing to load from a missing file.
            continue

        # Imported lazily so merely importing this module does not require dotenv.
        from dotenv import load_dotenv

        # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
        load_dotenv(str(env_path_str), override=True)
|
||||
|
||||
|
||||
# Load env files before anything reads configuration below.
_load_env()

# Default storage scheme for tests; "fs" keeps files on the local filesystem.
os.environ.setdefault("OPENDAL_SCHEME", "fs")

# Create the Flask app once and reuse it for the whole test session.
_CACHED_APP = create_app()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def dify_config() -> DifyConfig:
    """Session-wide DifyConfig instance built from the already-loaded env vars."""
    return DifyConfig()  # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
def flask_app() -> Flask:
    """Return the session-cached Flask application (`_CACHED_APP`)."""
    return _CACHED_APP
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Utilities and helpers for Redis broadcast channel integration tests.
|
||||
|
||||
This module provides utility classes and functions for testing
|
||||
Redis broadcast channel functionality.
|
||||
"""
|
||||
|
||||
from .test_data import (
|
||||
LARGE_MESSAGES,
|
||||
SMALL_MESSAGES,
|
||||
SPECIAL_MESSAGES,
|
||||
BufferTestConfig,
|
||||
ConcurrencyTestConfig,
|
||||
ErrorTestConfig,
|
||||
)
|
||||
from .test_helpers import (
|
||||
ConcurrentPublisher,
|
||||
SubscriptionMonitor,
|
||||
assert_message_order,
|
||||
measure_throughput,
|
||||
wait_for_condition,
|
||||
)
|
||||
|
||||
# Explicit public API of this test-helper package; everything re-exported
# above from .test_data and .test_helpers.
__all__ = [
    "LARGE_MESSAGES",
    "SMALL_MESSAGES",
    "SPECIAL_MESSAGES",
    "BufferTestConfig",
    "ConcurrencyTestConfig",
    "ConcurrentPublisher",
    "ErrorTestConfig",
    "SubscriptionMonitor",
    "assert_message_order",
    "measure_throughput",
    "wait_for_condition",
]
|
||||
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
Test data and configuration classes for Redis broadcast channel integration tests.
|
||||
|
||||
This module provides dataclasses and constants for test configurations,
|
||||
message sets, and test scenarios.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any
|
||||
|
||||
from libs.broadcast_channel.channel import Overflow
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class BufferTestConfig:
    """Configuration for buffer management tests."""

    # Capacity of the channel buffer under test.
    buffer_size: int
    # Policy applied when the buffer fills (DROP_OLDEST / DROP_NEWEST / BLOCK
    # are the values used by BUFFER_TEST_CONFIGS).
    overflow_strategy: Overflow
    # Number of messages published during the test.
    message_count: int
    # Expected outcome keyword, e.g. "drop_oldest".
    expected_behavior: str
    # Human-readable summary of the scenario.
    description: str
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ConcurrencyTestConfig:
    """Configuration for concurrency tests."""

    # Number of concurrent publisher threads.
    publisher_count: int
    # Number of concurrent subscriber threads.
    subscriber_count: int
    # Messages each publisher sends.
    messages_per_publisher: int
    # Wall-clock budget for the scenario, in seconds.
    test_duration: float
    # Human-readable summary of the scenario.
    description: str
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ErrorTestConfig:
    """Configuration for error handling tests."""

    # Category label for the error scenario, e.g. "invalid_buffer_size".
    error_type: str
    # The (intentionally invalid) value fed to the code under test.
    test_input: Any
    # Exception class the code under test is expected to raise.
    expected_exception: type[Exception]
    # Human-readable summary of the scenario.
    description: str
||||
|
||||
|
||||
# Test message sets for different scenarios.
# The regular sets are generated rather than spelled out; the special/binary
# sets stay literal because their exact bytes are the point.
SMALL_MESSAGES = [f"msg_{n}".encode() for n in range(1, 6)]

MEDIUM_MESSAGES = [f"medium_message_{n}_with_more_content".encode() for n in range(1, 6)]

LARGE_MESSAGES = [b"large_message_" + filler * 1000 for filler in (b"x", b"y", b"z")]

VERY_LARGE_MESSAGES = [
    b"very_large_message_" + filler * length
    for filler, length in ((b"x", 10000), (b"y", 50000), (b"z", 100000))  # ~10KB / ~50KB / ~100KB
]

SPECIAL_MESSAGES = [
    b"",  # Empty message
    b"\x00\x01\x02",  # Binary data with null bytes
    "unicode_test_你好".encode(),  # Unicode
    b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?",  # Special characters
    b"newlines\n\r\t",  # Control characters
]

BINARY_MESSAGES = [
    bytes(range(256)),  # All possible byte values
    b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8",  # High byte values
    b"\x00\x01\x02\x03\x04\x05\x06\x07",  # Low byte values
]
|
||||
|
||||
# Buffer test configurations: identical buffer geometry, one entry per
# overflow strategy.
BUFFER_TEST_CONFIGS = [
    BufferTestConfig(
        buffer_size=3,
        overflow_strategy=strategy,
        message_count=5,
        expected_behavior=behavior,
        description=summary,
    )
    for strategy, behavior, summary in (
        (Overflow.DROP_OLDEST, "drop_oldest", "Drop oldest messages when buffer is full"),
        (Overflow.DROP_NEWEST, "drop_newest", "Drop newest messages when buffer is full"),
        (Overflow.BLOCK, "block", "Block when buffer is full"),
    )
]
|
||||
|
||||
# Concurrency test configurations: the cartesian corners of 1-vs-3
# publishers and subscribers, with a fixed message load per publisher.
CONCURRENCY_TEST_CONFIGS = [
    ConcurrencyTestConfig(
        publisher_count=publishers,
        subscriber_count=subscribers,
        messages_per_publisher=10,
        test_duration=5.0,
        description=summary,
    )
    for publishers, subscribers, summary in (
        (1, 1, "Single publisher, single subscriber"),
        (3, 1, "Multiple publishers, single subscriber"),
        (1, 3, "Single publisher, multiple subscribers"),
        (3, 3, "Multiple publishers, multiple subscribers"),
    )
]
|
||||
|
||||
# Error test configurations: invalid buffer sizes and the exception each
# one is expected to raise.
ERROR_TEST_CONFIGS = [
    ErrorTestConfig(
        error_type="invalid_buffer_size",
        test_input=bad_value,
        expected_exception=exc_type,
        description=summary,
    )
    for bad_value, exc_type, summary in (
        (0, ValueError, "Zero buffer size should raise ValueError"),
        (-1, ValueError, "Negative buffer size should raise ValueError"),
        (1.5, TypeError, "Float buffer size should raise TypeError"),
        ("invalid", TypeError, "String buffer size should raise TypeError"),
    )
]
|
||||
|
||||
# Topic name test cases: exercises separators, casing, length, Unicode,
# and characters commonly reserved by messaging systems.
TOPIC_NAME_TEST_CASES = [
    "simple_topic",
    "topic_with_underscores",
    "topic-with-dashes",
    "topic.with.dots",
    "topic_with_numbers_123",
    "UPPERCASE_TOPIC",
    "mixed_Case_Topic",
    "topic_with_symbols_!@#$%",
    "very_long_topic_name_" + "x" * 100,
    "unicode_topic_你好",
    "topic:with:colons",
    "topic/with/slashes",
    "topic\\with\\backslashes",
]
|
||||
|
||||
# Performance test configurations: message size (bytes) traded off against
# message count.
PERFORMANCE_TEST_CONFIGS = [
    {
        "name": scenario,
        "message_size": size,
        "message_count": count,
        "description": summary,
    }
    for scenario, size, count, summary in (
        ("small_messages_high_frequency", 50, 1000, "Many small messages"),
        ("medium_messages_medium_frequency", 500, 100, "Medium messages"),
        ("large_messages_low_frequency", 5000, 10, "Large messages"),
    )
]
|
||||
|
||||
# Stress test configurations: publisher/subscriber fan-in and fan-out shapes.
STRESS_TEST_CONFIGS = [
    {
        "name": scenario,
        "publisher_count": publishers,
        "messages_per_publisher": per_publisher,
        "subscriber_count": subscribers,
        "description": summary,
    }
    for scenario, publishers, per_publisher, subscribers, summary in (
        ("high_frequency_publishing", 5, 100, 3, "High frequency publishing with multiple publishers"),
        ("many_subscribers", 1, 50, 10, "Many subscribers to single publisher"),
        ("mixed_load", 3, 100, 5, "Mixed load with multiple publishers and subscribers"),
    )
]
|
||||
|
||||
# Edge case test data: boundary payloads (empty, single-byte extremes,
# multi-byte Unicode, long uniform runs) for transport round-trip checks.
EDGE_CASE_MESSAGES = [
    b"",  # Empty message
    b"\x00",  # Single null byte
    b"\xff",  # Single max byte value
    b"a",  # Single ASCII character
    "ä".encode(),  # Single unicode character (2 bytes)
    "𐍈".encode(),  # Unicode character outside BMP (4 bytes)
    b"\x00" * 1000,  # 1000 null bytes
    b"\xff" * 1000,  # 1000 max byte values
]
|
||||
|
||||
# Message validation test data: every case here is a payload the channel
# must accept (all "should_pass" True).
MESSAGE_VALIDATION_TEST_CASES = [
    {
        "name": case_name,
        "input": payload,
        "should_pass": True,
        "description": summary,
    }
    for case_name, payload, summary in (
        ("valid_bytes", b"valid_message", "Valid bytes message"),
        ("empty_bytes", b"", "Empty bytes message"),
        ("binary_data", bytes(range(256)), "Binary data with all byte values"),
        ("large_message", b"x" * 1000000, "Large message (1MB)"),
    )
]
|
||||
|
||||
# Redis connection test scenarios: one healthy baseline plus two failure modes.
REDIS_CONNECTION_TEST_SCENARIOS = [
    {
        "name": scenario,
        "should_fail": expected_failure,
        "description": summary,
    }
    for scenario, expected_failure, summary in (
        ("normal_connection", False, "Normal Redis connection"),
        ("connection_timeout", True, "Connection timeout scenario"),
        ("connection_refused", True, "Connection refused scenario"),
    )
]
|
||||
|
||||
# Test constants — timeouts in seconds.
DEFAULT_TIMEOUT = 10.0
SHORT_TIMEOUT = 2.0
LONG_TIMEOUT = 30.0

# Message size limits for testing (bytes).
MAX_SMALL_MESSAGE_SIZE = 100
MAX_MEDIUM_MESSAGE_SIZE = 1000
MAX_LARGE_MESSAGE_SIZE = 10000

# Thread counts for concurrency testing.
MIN_THREAD_COUNT = 1
MAX_THREAD_COUNT = 10
DEFAULT_THREAD_COUNT = 3

# Buffer sizes for testing (number of buffered messages).
MIN_BUFFER_SIZE = 1
MAX_BUFFER_SIZE = 1000
DEFAULT_BUFFER_SIZE = 10
||||
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Test helper utilities for Redis broadcast channel integration tests.
|
||||
|
||||
This module provides utility classes and functions for testing concurrent
|
||||
operations, monitoring subscriptions, and measuring performance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConcurrentPublisher:
    """
    Utility class for publishing messages concurrently from multiple threads.

    This class manages multiple publisher threads that can publish messages
    to the same or different topics concurrently, useful for stress testing
    and concurrency validation.

    Fixes vs. the original: `published_messages` is now pre-sized and each
    thread writes its own slot, so `get_thread_messages(thread_id)` is
    correct (results were previously appended in thread-*completion* order,
    so the index did not correspond to the thread id); the garbled
    exception log message ("Pubmsg=lisher %s") is repaired.
    """

    def __init__(self, producer, message_count: int = 10, delay: float = 0.0):
        """
        Initialize the concurrent publisher.

        Args:
            producer: The producer instance to publish with (must expose `publish(bytes)`)
            message_count: Number of messages to publish per thread
            delay: Delay between messages in seconds
        """
        self.producer = producer
        self.message_count = message_count
        self.delay = delay
        self.threads: list[threading.Thread] = []
        # One inner list per thread, indexed by thread id (filled by start_publishers).
        self.published_messages: list[list[bytes]] = []
        self._lock = threading.Lock()
        self._started = False

    def start_publishers(self, thread_count: int = 3) -> None:
        """
        Start multiple publisher threads.

        Args:
            thread_count: Number of publisher threads to start

        Raises:
            RuntimeError: If publishers were already started on this instance.
        """
        if self._started:
            raise RuntimeError("Publishers already started")

        self._started = True
        # Pre-size so each thread owns slot `thread_id`; keeps indices stable
        # regardless of which thread finishes first.
        self.published_messages = [[] for _ in range(thread_count)]

        def _publisher(thread_id: int) -> None:
            # Collect locally first; only the final assignment needs the lock.
            messages: list[bytes] = []
            for i in range(self.message_count):
                message = f"thread_{thread_id}_msg_{i}".encode()
                try:
                    self.producer.publish(message)
                    messages.append(message)
                    if self.delay > 0:
                        time.sleep(self.delay)
                except Exception:
                    # Best-effort: log and keep publishing remaining messages.
                    _logger.exception("Publisher %s failed to publish a message", thread_id)

            with self._lock:
                self.published_messages[thread_id] = messages

        for thread_id in range(thread_count):
            thread = threading.Thread(
                target=_publisher,
                args=(thread_id,),
                name=f"publisher-{thread_id}",
                daemon=True,
            )
            thread.start()
            self.threads.append(thread)

    def wait_for_completion(self, timeout: float = 30.0) -> bool:
        """
        Wait for all publisher threads to complete.

        Args:
            timeout: Maximum time to wait per thread, in seconds

        Returns:
            bool: True if all threads completed successfully
        """
        for thread in self.threads:
            thread.join(timeout)
            if thread.is_alive():
                return False
        return True

    def get_all_messages(self) -> list[bytes]:
        """
        Get all messages published by all threads.

        Returns:
            list[bytes]: Flattened list of all published messages
        """
        with self._lock:
            all_messages: list[bytes] = []
            for thread_messages in self.published_messages:
                all_messages.extend(thread_messages)
            return all_messages

    def get_thread_messages(self, thread_id: int) -> list[bytes]:
        """
        Get messages published by a specific thread.

        Args:
            thread_id: ID of the thread

        Returns:
            list[bytes]: Messages published by the specified thread
            (empty list for an unknown thread id)
        """
        with self._lock:
            if 0 <= thread_id < len(self.published_messages):
                return self.published_messages[thread_id].copy()
            return []
|
||||
|
||||
|
||||
class SubscriptionMonitor:
    """
    Utility class for monitoring subscription activity in tests.

    This class monitors a subscription and tracks message reception,
    errors, and completion status for testing purposes.
    """

    def __init__(self, subscription, timeout: float = 10.0):
        """
        Initialize the subscription monitor.

        Args:
            subscription: The subscription to monitor (iterable of bytes
                messages with a `close()` method — see `stop`)
            timeout: Default timeout for operations
        """
        self.subscription = subscription
        self.timeout = timeout
        # Messages received so far (guarded by _lock).
        self.messages: list[bytes] = []
        # Exceptions raised while iterating the subscription (guarded by _lock).
        self.errors: list[Exception] = []
        # True once the monitor thread finished, normally or with an error.
        self.completed = False
        self._lock = threading.Lock()
        # Condition sharing _lock; notified on every message, error, and completion.
        self._condition = threading.Condition(self._lock)
        self._monitor_thread: threading.Thread | None = None
        self._start_time: float | None = None

    def start_monitoring(self) -> None:
        """Start monitoring the subscription in a separate thread."""
        if self._monitor_thread is not None:
            raise RuntimeError("Monitoring already started")

        self._start_time = time.time()

        def _monitor():
            try:
                # Iterating the subscription blocks until a message arrives or it ends.
                for message in self.subscription:
                    with self._lock:
                        self.messages.append(message)
                        self._condition.notify_all()
            except Exception as e:
                with self._lock:
                    self.errors.append(e)
                    self._condition.notify_all()
            finally:
                # Always flag completion so waiters are released even on error.
                with self._lock:
                    self.completed = True
                    self._condition.notify_all()

        self._monitor_thread = threading.Thread(
            target=_monitor,
            name="subscription-monitor",
            daemon=True,
        )
        self._monitor_thread.start()

    def wait_for_messages(self, count: int, timeout: float | None = None) -> bool:
        """
        Wait for a specific number of messages.

        Args:
            count: Number of messages to wait for
            timeout: Timeout in seconds (uses default if None)

        Returns:
            bool: True if expected messages were received
        """
        if timeout is None:
            timeout = self.timeout

        deadline = time.time() + timeout

        with self._condition:
            # Re-check in a loop: Condition.wait can wake spuriously, and
            # notifications can arrive for errors/completion too.
            while len(self.messages) < count and not self.completed:
                remaining = deadline - time.time()
                if remaining <= 0:
                    return False
                self._condition.wait(remaining)

            # May be satisfied either by enough messages or by early completion.
            return len(self.messages) >= count

    def wait_for_completion(self, timeout: float | None = None) -> bool:
        """
        Wait for monitoring to complete.

        Args:
            timeout: Timeout in seconds (uses default if None)

        Returns:
            bool: True if monitoring completed successfully
        """
        if timeout is None:
            timeout = self.timeout

        deadline = time.time() + timeout

        with self._condition:
            while not self.completed:
                remaining = deadline - time.time()
                if remaining <= 0:
                    return False
                self._condition.wait(remaining)

            return True

    def get_messages(self) -> list[bytes]:
        """
        Get all received messages.

        Returns:
            list[bytes]: Copy of received messages
        """
        with self._lock:
            return self.messages.copy()

    def get_error_count(self) -> int:
        """
        Get the number of errors encountered.

        Returns:
            int: Number of errors
        """
        with self._lock:
            return len(self.errors)

    def get_elapsed_time(self) -> float:
        """
        Get the elapsed monitoring time.

        Returns:
            float: Elapsed time in seconds (0.0 if monitoring never started)
        """
        if self._start_time is None:
            return 0.0
        return time.time() - self._start_time

    def stop(self) -> None:
        """Stop monitoring and close the subscription."""
        if self._monitor_thread is not None:
            # Closing the subscription ends the iteration inside _monitor,
            # which lets the monitor thread exit; then give it 1s to join.
            self.subscription.close()
            self._monitor_thread.join(timeout=1.0)
|
||||
|
||||
|
||||
def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool:
    """
    Check that messages were received in the expected order.

    Args:
        received: List of received messages
        expected: List of expected messages in order

    Returns:
        bool: True if order matches expected; a length mismatch or any
        element mismatch (logged with its index) yields False.
    """
    if len(received) != len(expected):
        return False

    for idx, got in enumerate(received):
        want = expected[idx]
        if got != want:
            _logger.error("Message order mismatch at index %s: expected %s, got %s", idx, want, got)
            return False

    return True
|
||||
|
||||
|
||||
def measure_throughput(
    operation: Callable[[], Any],
    duration: float = 1.0,
) -> tuple[float, int]:
    """
    Measure the throughput of an operation over a specified duration.

    Args:
        operation: The operation to measure
        duration: Duration to run the operation in seconds

    Returns:
        tuple[float, int]: (operations per second, total operations).
        Measurement stops early (logging the failure) if the operation raises.
    """
    started = time.time()
    deadline = started + duration
    completed = 0

    while time.time() < deadline:
        try:
            operation()
        except Exception:
            _logger.exception("Operation failed")
            break
        completed += 1

    elapsed = time.time() - started
    rate = completed / elapsed if elapsed > 0 else 0.0

    return rate, completed
|
||||
|
||||
|
||||
def wait_for_condition(
    condition: Callable[[], bool],
    timeout: float = 10.0,
    interval: float = 0.1,
) -> bool:
    """
    Wait for a condition to become true.

    The condition is always evaluated at least once, so a zero (or negative)
    timeout still returns True when the condition already holds — the
    original checked the deadline *before* the first evaluation and could
    return False without ever calling `condition`.

    Args:
        condition: Function that returns True when condition is met
        timeout: Maximum time to wait in seconds
        interval: Check interval in seconds

    Returns:
        bool: True if condition was met within timeout
    """
    deadline = time.time() + timeout

    while True:
        if condition():
            return True
        if time.time() >= deadline:
            return False
        time.sleep(interval)
|
||||
|
||||
|
||||
def create_stress_test_messages(
    count: int,
    size: int = 100,
) -> list[bytes]:
    """
    Create messages for stress testing.

    Each message is a numbered prefix right-padded with "x" to the target
    size (messages whose prefix already exceeds `size` are not truncated).

    Args:
        count: Number of messages to create
        size: Size of each message in bytes

    Returns:
        list[bytes]: List of test messages
    """
    return [f"stress_test_msg_{idx:06d}_".ljust(size, "x").encode() for idx in range(count)]
|
||||
|
||||
|
||||
def validate_message_integrity(
    original_messages: list[bytes],
    received_messages: list[bytes],
) -> dict[str, Any]:
    """
    Validate the integrity of received messages.

    Comparison is set-based, so duplicates and ordering are ignored; only
    presence/absence of each distinct message is checked.

    Args:
        original_messages: Messages that were sent
        received_messages: Messages that were received

    Returns:
        dict[str, Any]: Validation results (counts, the missing/extra
        messages themselves, and an overall "integrity_ok" flag)
    """
    sent = set(original_messages)
    got = set(received_messages)

    missing = sent - got
    unexpected = got - sent

    return {
        "total_sent": len(original_messages),
        "total_received": len(received_messages),
        "missing_count": len(missing),
        "extra_count": len(unexpected),
        "missing_messages": list(missing),
        "extra_messages": list(unexpected),
        "integrity_ok": not missing and not unexpected,
    }
|
||||
Reference in New Issue
Block a user