feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@@ -1,3 +1,4 @@
import logging
import os
import pathlib
import random
@@ -10,26 +11,34 @@ from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from app_factory import create_app
from configs.app_config import DifyConfig
from extensions.ext_database import db
from models import Account, DifySetup, Tenant, TenantAccountJoin
from services.account_service import AccountService, RegisterService
# Default env-file names; entries later in `env_file_paths` win because each
# `load_dotenv(..., override=True)` call overwrites earlier values.
_DEFAULT_TEST_ENV = ".env"
_DEFAULT_VDB_TEST_ENV = "vdb.env"

_logger = logging.getLogger(__name__)


# Loading the .env file if it exists
def _load_env() -> None:
    """Load test environment variables from `.env`-style files.

    The file locations default to this directory but can be overridden via
    the ``DIFY_TEST_ENV_FILE`` and ``DIFY_VDB_TEST_ENV_FILE`` environment
    variables. Missing files are logged and skipped.
    """
    current_file_path = pathlib.Path(__file__).absolute()
    # Items later in the list have higher precedence.
    env_file_paths = [
        os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_TEST_ENV)),
        os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)),
    ]
    for env_path_str in env_file_paths:
        if not pathlib.Path(env_path_str).exists():
            _logger.warning("specified configuration file %s does not exist", env_path_str)
            continue

        # Imported lazily so importing this module does not require python-dotenv.
        from dotenv import load_dotenv

        # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
        load_dotenv(str(env_path_str), override=True)


_load_env()
@@ -41,6 +50,12 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
# Build the Flask app once at import time and reuse it for every test,
# since app creation is expensive.
_CACHED_APP = create_app()
@pytest.fixture(scope="session")
def dify_config() -> DifyConfig:
    """Provide a session-wide ``DifyConfig`` built from the test environment."""
    return DifyConfig()  # type: ignore
@pytest.fixture
def flask_app() -> Flask:
    """Expose the module-level cached Flask application to tests."""
    app = _CACHED_APP
    return app

View File

@@ -0,0 +1,36 @@
"""
Utilities and helpers for Redis broadcast channel integration tests.
This module provides utility classes and functions for testing
Redis broadcast channel functionality.
"""
from .test_data import (
LARGE_MESSAGES,
SMALL_MESSAGES,
SPECIAL_MESSAGES,
BufferTestConfig,
ConcurrencyTestConfig,
ErrorTestConfig,
)
from .test_helpers import (
ConcurrentPublisher,
SubscriptionMonitor,
assert_message_order,
measure_throughput,
wait_for_condition,
)
# Names re-exported as the public API of this test-support package.
__all__ = [
    "LARGE_MESSAGES",
    "SMALL_MESSAGES",
    "SPECIAL_MESSAGES",
    "BufferTestConfig",
    "ConcurrencyTestConfig",
    "ConcurrentPublisher",
    "ErrorTestConfig",
    "SubscriptionMonitor",
    "assert_message_order",
    "measure_throughput",
    "wait_for_condition",
]

View File

@@ -0,0 +1,315 @@
"""
Test data and configuration classes for Redis broadcast channel integration tests.
This module provides dataclasses and constants for test configurations,
message sets, and test scenarios.
"""
import dataclasses
from typing import Any
from libs.broadcast_channel.channel import Overflow
@dataclasses.dataclass(frozen=True)
class BufferTestConfig:
    """Configuration for buffer management tests."""

    # Capacity of the subscription buffer under test.
    buffer_size: int
    # Strategy applied when the buffer is full.
    overflow_strategy: Overflow
    # Number of messages published during the test.
    message_count: int
    # Symbolic name of the behavior the test asserts (e.g. "drop_oldest").
    expected_behavior: str
    # Human-readable summary of the scenario.
    description: str
@dataclasses.dataclass(frozen=True)
class ConcurrencyTestConfig:
    """Configuration for concurrency tests."""

    # Number of concurrent publisher threads.
    publisher_count: int
    # Number of concurrent subscriber threads.
    subscriber_count: int
    # Messages each publisher sends during the test.
    messages_per_publisher: int
    # Maximum test run time in seconds.
    test_duration: float
    # Human-readable summary of the scenario.
    description: str
@dataclasses.dataclass(frozen=True)
class ErrorTestConfig:
    """Configuration for error handling tests."""

    # Category of the error scenario (e.g. "invalid_buffer_size").
    error_type: str
    # The invalid value passed to the code under test.
    test_input: Any
    # Exception type the code under test is expected to raise.
    expected_exception: type[Exception]
    # Human-readable summary of the scenario.
    description: str
# Test message sets for different scenarios

# Five short payloads for basic publish/subscribe round-trips.
SMALL_MESSAGES = [
    b"msg_1",
    b"msg_2",
    b"msg_3",
    b"msg_4",
    b"msg_5",
]

# Mid-sized payloads (tens of bytes each).
MEDIUM_MESSAGES = [
    b"medium_message_1_with_more_content",
    b"medium_message_2_with_more_content",
    b"medium_message_3_with_more_content",
    b"medium_message_4_with_more_content",
    b"medium_message_5_with_more_content",
]

# Roughly 1 KB payloads.
LARGE_MESSAGES = [
    b"large_message_" + b"x" * 1000,
    b"large_message_" + b"y" * 1000,
    b"large_message_" + b"z" * 1000,
]

# Payloads from ~10 KB up to ~100 KB for size-limit testing.
VERY_LARGE_MESSAGES = [
    b"very_large_message_" + b"x" * 10000,  # ~10KB
    b"very_large_message_" + b"y" * 50000,  # ~50KB
    b"very_large_message_" + b"z" * 100000,  # ~100KB
]

# Payloads that exercise encoding and binary-safety edge cases.
SPECIAL_MESSAGES = [
    b"",  # Empty message
    b"\x00\x01\x02",  # Binary data with null bytes
    "unicode_test_你好".encode(),  # Unicode
    b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?",  # Special characters
    b"newlines\n\r\t",  # Control characters
]

# Pure binary payloads covering low, high, and the full byte range.
BINARY_MESSAGES = [
    bytes(range(256)),  # All possible byte values
    b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8",  # High byte values
    b"\x00\x01\x02\x03\x04\x05\x06\x07",  # Low byte values
]

# Buffer test configurations: each overfills a 3-slot buffer with 5 messages,
# one configuration per Overflow strategy.
BUFFER_TEST_CONFIGS = [
    BufferTestConfig(
        buffer_size=3,
        overflow_strategy=Overflow.DROP_OLDEST,
        message_count=5,
        expected_behavior="drop_oldest",
        description="Drop oldest messages when buffer is full",
    ),
    BufferTestConfig(
        buffer_size=3,
        overflow_strategy=Overflow.DROP_NEWEST,
        message_count=5,
        expected_behavior="drop_newest",
        description="Drop newest messages when buffer is full",
    ),
    BufferTestConfig(
        buffer_size=3,
        overflow_strategy=Overflow.BLOCK,
        message_count=5,
        expected_behavior="block",
        description="Block when buffer is full",
    ),
]
# Concurrency test configurations: the four publisher/subscriber cardinality
# combinations (1:1, N:1, 1:N, N:N).
CONCURRENCY_TEST_CONFIGS = [
    ConcurrencyTestConfig(
        publisher_count=1,
        subscriber_count=1,
        messages_per_publisher=10,
        test_duration=5.0,
        description="Single publisher, single subscriber",
    ),
    ConcurrencyTestConfig(
        publisher_count=3,
        subscriber_count=1,
        messages_per_publisher=10,
        test_duration=5.0,
        description="Multiple publishers, single subscriber",
    ),
    ConcurrencyTestConfig(
        publisher_count=1,
        subscriber_count=3,
        messages_per_publisher=10,
        test_duration=5.0,
        description="Single publisher, multiple subscribers",
    ),
    ConcurrencyTestConfig(
        publisher_count=3,
        subscriber_count=3,
        messages_per_publisher=10,
        test_duration=5.0,
        description="Multiple publishers, multiple subscribers",
    ),
]

# Error test configurations: invalid buffer sizes and the exception type each
# one is expected to raise.
ERROR_TEST_CONFIGS = [
    ErrorTestConfig(
        error_type="invalid_buffer_size",
        test_input=0,
        expected_exception=ValueError,
        description="Zero buffer size should raise ValueError",
    ),
    ErrorTestConfig(
        error_type="invalid_buffer_size",
        test_input=-1,
        expected_exception=ValueError,
        description="Negative buffer size should raise ValueError",
    ),
    ErrorTestConfig(
        error_type="invalid_buffer_size",
        test_input=1.5,
        expected_exception=TypeError,
        description="Float buffer size should raise TypeError",
    ),
    ErrorTestConfig(
        error_type="invalid_buffer_size",
        test_input="invalid",
        expected_exception=TypeError,
        description="String buffer size should raise TypeError",
    ),
]
# Topic name test cases: naming styles, separators, length, and non-ASCII
# characters a channel implementation should tolerate.
TOPIC_NAME_TEST_CASES = [
    "simple_topic",
    "topic_with_underscores",
    "topic-with-dashes",
    "topic.with.dots",
    "topic_with_numbers_123",
    "UPPERCASE_TOPIC",
    "mixed_Case_Topic",
    "topic_with_symbols_!@#$%",
    "very_long_topic_name_" + "x" * 100,
    "unicode_topic_你好",
    "topic:with:colons",
    "topic/with/slashes",
    "topic\\with\\backslashes",
]

# Performance test configurations: message size (bytes) vs. message count
# trade-offs. Keys: name, message_size, message_count, description.
PERFORMANCE_TEST_CONFIGS = [
    {
        "name": "small_messages_high_frequency",
        "message_size": 50,
        "message_count": 1000,
        "description": "Many small messages",
    },
    {
        "name": "medium_messages_medium_frequency",
        "message_size": 500,
        "message_count": 100,
        "description": "Medium messages",
    },
    {
        "name": "large_messages_low_frequency",
        "message_size": 5000,
        "message_count": 10,
        "description": "Large messages",
    },
]

# Stress test configurations: publisher/subscriber counts and per-publisher
# message volume for load scenarios.
STRESS_TEST_CONFIGS = [
    {
        "name": "high_frequency_publishing",
        "publisher_count": 5,
        "messages_per_publisher": 100,
        "subscriber_count": 3,
        "description": "High frequency publishing with multiple publishers",
    },
    {
        "name": "many_subscribers",
        "publisher_count": 1,
        "messages_per_publisher": 50,
        "subscriber_count": 10,
        "description": "Many subscribers to single publisher",
    },
    {
        "name": "mixed_load",
        "publisher_count": 3,
        "messages_per_publisher": 100,
        "subscriber_count": 5,
        "description": "Mixed load with multiple publishers and subscribers",
    },
]
# Edge case test data: minimal, boundary-valued, and multi-byte payloads.
EDGE_CASE_MESSAGES = [
    b"",  # Empty message
    b"\x00",  # Single null byte
    b"\xff",  # Single max byte value
    b"a",  # Single ASCII character
    "ä".encode(),  # Single unicode character (2 bytes)
    "𐍈".encode(),  # Unicode character outside BMP (4 bytes)
    b"\x00" * 1000,  # 1000 null bytes
    b"\xff" * 1000,  # 1000 max byte values
]

# Message validation test data. Keys: name, input, should_pass, description.
MESSAGE_VALIDATION_TEST_CASES = [
    {
        "name": "valid_bytes",
        "input": b"valid_message",
        "should_pass": True,
        "description": "Valid bytes message",
    },
    {
        "name": "empty_bytes",
        "input": b"",
        "should_pass": True,
        "description": "Empty bytes message",
    },
    {
        "name": "binary_data",
        "input": bytes(range(256)),
        "should_pass": True,
        "description": "Binary data with all byte values",
    },
    {
        "name": "large_message",
        "input": b"x" * 1000000,  # 1MB
        "should_pass": True,
        "description": "Large message (1MB)",
    },
]

# Redis connection test scenarios. Keys: name, should_fail, description.
REDIS_CONNECTION_TEST_SCENARIOS = [
    {
        "name": "normal_connection",
        "should_fail": False,
        "description": "Normal Redis connection",
    },
    {
        "name": "connection_timeout",
        "should_fail": True,
        "description": "Connection timeout scenario",
    },
    {
        "name": "connection_refused",
        "should_fail": True,
        "description": "Connection refused scenario",
    },
]

# Test constants: default wait timeouts in seconds.
DEFAULT_TIMEOUT = 10.0
SHORT_TIMEOUT = 2.0
LONG_TIMEOUT = 30.0

# Message size limits for testing (bytes).
MAX_SMALL_MESSAGE_SIZE = 100
MAX_MEDIUM_MESSAGE_SIZE = 1000
MAX_LARGE_MESSAGE_SIZE = 10000

# Thread counts for concurrency testing.
MIN_THREAD_COUNT = 1
MAX_THREAD_COUNT = 10
DEFAULT_THREAD_COUNT = 3

# Buffer sizes for testing.
MIN_BUFFER_SIZE = 1
MAX_BUFFER_SIZE = 1000
DEFAULT_BUFFER_SIZE = 10

View File

@@ -0,0 +1,396 @@
"""
Test helper utilities for Redis broadcast channel integration tests.
This module provides utility classes and functions for testing concurrent
operations, monitoring subscriptions, and measuring performance.
"""
import logging
import threading
import time
from collections.abc import Callable
from typing import Any
# Module-level logger shared by the helpers below.
_logger = logging.getLogger(__name__)
class ConcurrentPublisher:
    """
    Utility class for publishing messages concurrently from multiple threads.

    This class manages multiple publisher threads that publish messages
    through the same producer concurrently, useful for stress testing
    and concurrency validation.
    """

    def __init__(self, producer, message_count: int = 10, delay: float = 0.0):
        """
        Initialize the concurrent publisher.

        Args:
            producer: The producer instance to publish with
            message_count: Number of messages to publish per thread
            delay: Delay between messages in seconds
        """
        self.producer = producer
        self.message_count = message_count
        self.delay = delay
        self.threads: list[threading.Thread] = []
        # One slot per thread, indexed by thread id; populated by start_publishers.
        self.published_messages: list[list[bytes]] = []
        self._lock = threading.Lock()
        self._started = False

    def start_publishers(self, thread_count: int = 3) -> None:
        """
        Start multiple publisher threads.

        Args:
            thread_count: Number of publisher threads to start

        Raises:
            RuntimeError: If the publishers were already started.
        """
        if self._started:
            raise RuntimeError("Publishers already started")
        self._started = True
        # Pre-allocate one slot per thread so messages can be looked up by
        # thread id regardless of the order in which threads finish.
        # (Appending on completion would index by completion order instead.)
        self.published_messages = [[] for _ in range(thread_count)]

        def _publisher(thread_id: int) -> None:
            messages: list[bytes] = []
            for i in range(self.message_count):
                message = f"thread_{thread_id}_msg_{i}".encode()
                try:
                    self.producer.publish(message)
                    messages.append(message)
                    if self.delay > 0:
                        time.sleep(self.delay)
                except Exception:
                    # A single failed publish should not abort the thread.
                    _logger.exception("Publisher %s failed to publish", thread_id)
            with self._lock:
                self.published_messages[thread_id] = messages

        for thread_id in range(thread_count):
            thread = threading.Thread(
                target=_publisher,
                args=(thread_id,),
                name=f"publisher-{thread_id}",
                daemon=True,
            )
            thread.start()
            self.threads.append(thread)

    def wait_for_completion(self, timeout: float = 30.0) -> bool:
        """
        Wait for all publisher threads to complete.

        Args:
            timeout: Maximum time to wait per thread in seconds

        Returns:
            bool: True if all threads completed successfully
        """
        for thread in self.threads:
            thread.join(timeout)
            if thread.is_alive():
                return False
        return True

    def get_all_messages(self) -> list[bytes]:
        """
        Get all messages published by all threads.

        Returns:
            list[bytes]: Flattened list of all published messages
        """
        with self._lock:
            all_messages: list[bytes] = []
            for thread_messages in self.published_messages:
                all_messages.extend(thread_messages)
            return all_messages

    def get_thread_messages(self, thread_id: int) -> list[bytes]:
        """
        Get messages published by a specific thread.

        Args:
            thread_id: ID of the thread

        Returns:
            list[bytes]: Messages published by the specified thread,
            or an empty list for an unknown thread id.
        """
        with self._lock:
            if 0 <= thread_id < len(self.published_messages):
                return self.published_messages[thread_id].copy()
            return []
class SubscriptionMonitor:
"""
Utility class for monitoring subscription activity in tests.
This class monitors a subscription and tracks message reception,
errors, and completion status for testing purposes.
"""
def __init__(self, subscription, timeout: float = 10.0):
"""
Initialize the subscription monitor.
Args:
subscription: The subscription to monitor
timeout: Default timeout for operations
"""
self.subscription = subscription
self.timeout = timeout
self.messages: list[bytes] = []
self.errors: list[Exception] = []
self.completed = False
self._lock = threading.Lock()
self._condition = threading.Condition(self._lock)
self._monitor_thread: threading.Thread | None = None
self._start_time: float | None = None
def start_monitoring(self) -> None:
"""Start monitoring the subscription in a separate thread."""
if self._monitor_thread is not None:
raise RuntimeError("Monitoring already started")
self._start_time = time.time()
def _monitor():
try:
for message in self.subscription:
with self._lock:
self.messages.append(message)
self._condition.notify_all()
except Exception as e:
with self._lock:
self.errors.append(e)
self._condition.notify_all()
finally:
with self._lock:
self.completed = True
self._condition.notify_all()
self._monitor_thread = threading.Thread(
target=_monitor,
name="subscription-monitor",
daemon=True,
)
self._monitor_thread.start()
def wait_for_messages(self, count: int, timeout: float | None = None) -> bool:
"""
Wait for a specific number of messages.
Args:
count: Number of messages to wait for
timeout: Timeout in seconds (uses default if None)
Returns:
bool: True if expected messages were received
"""
if timeout is None:
timeout = self.timeout
deadline = time.time() + timeout
with self._condition:
while len(self.messages) < count and not self.completed:
remaining = deadline - time.time()
if remaining <= 0:
return False
self._condition.wait(remaining)
return len(self.messages) >= count
def wait_for_completion(self, timeout: float | None = None) -> bool:
"""
Wait for monitoring to complete.
Args:
timeout: Timeout in seconds (uses default if None)
Returns:
bool: True if monitoring completed successfully
"""
if timeout is None:
timeout = self.timeout
deadline = time.time() + timeout
with self._condition:
while not self.completed:
remaining = deadline - time.time()
if remaining <= 0:
return False
self._condition.wait(remaining)
return True
def get_messages(self) -> list[bytes]:
"""
Get all received messages.
Returns:
list[bytes]: Copy of received messages
"""
with self._lock:
return self.messages.copy()
def get_error_count(self) -> int:
"""
Get the number of errors encountered.
Returns:
int: Number of errors
"""
with self._lock:
return len(self.errors)
def get_elapsed_time(self) -> float:
"""
Get the elapsed monitoring time.
Returns:
float: Elapsed time in seconds
"""
if self._start_time is None:
return 0.0
return time.time() - self._start_time
def stop(self) -> None:
"""Stop monitoring and close the subscription."""
if self._monitor_thread is not None:
self.subscription.close()
self._monitor_thread.join(timeout=1.0)
def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool:
    """
    Check that messages were received in the expected order.

    Args:
        received: List of received messages
        expected: List of expected messages in order

    Returns:
        bool: True if the received sequence matches the expected one
    """
    if len(received) != len(expected):
        return False
    for position, expected_msg in enumerate(expected):
        actual_msg = received[position]
        if actual_msg != expected_msg:
            _logger.error("Message order mismatch at index %s: expected %s, got %s", position, expected_msg, actual_msg)
            return False
    return True
def measure_throughput(
    operation: Callable[[], Any],
    duration: float = 1.0,
) -> tuple[float, int]:
    """
    Measure how many times *operation* completes within *duration* seconds.

    Args:
        operation: The operation to measure
        duration: Duration to run the operation in seconds

    Returns:
        tuple[float, int]: (operations per second, total operations)
    """
    started = time.time()
    stop_at = started + duration
    completed = 0
    while time.time() < stop_at:
        try:
            operation()
        except Exception:
            # Stop measuring on the first failure.
            _logger.exception("Operation failed")
            break
        completed += 1
    elapsed = time.time() - started
    rate = completed / elapsed if elapsed > 0 else 0.0
    return rate, completed
def wait_for_condition(
    condition: Callable[[], bool],
    timeout: float = 10.0,
    interval: float = 0.1,
) -> bool:
    """
    Poll *condition* until it becomes true or *timeout* elapses.

    Args:
        condition: Function that returns True when the condition is met
        timeout: Maximum time to wait in seconds
        interval: Check interval in seconds

    Returns:
        bool: True if the condition was met within the timeout
    """
    give_up_at = time.time() + timeout
    while True:
        if time.time() >= give_up_at:
            return False
        if condition():
            return True
        time.sleep(interval)
def create_stress_test_messages(
    count: int,
    size: int = 100,
) -> list[bytes]:
    """
    Build a batch of padded messages for stress testing.

    Args:
        count: Number of messages to create
        size: Size of each message in bytes

    Returns:
        list[bytes]: List of test messages
    """
    return [f"stress_test_msg_{index:06d}_".ljust(size, "x").encode() for index in range(count)]
def validate_message_integrity(
    original_messages: list[bytes],
    received_messages: list[bytes],
) -> dict[str, Any]:
    """
    Compare sent and received messages and report any discrepancies.

    Args:
        original_messages: Messages that were sent
        received_messages: Messages that were received

    Returns:
        dict[str, Any]: Validation results (counts, missing/extra messages,
        and an overall ``integrity_ok`` flag)
    """
    sent = set(original_messages)
    got = set(received_messages)
    missing = sent - got
    unexpected = got - sent
    return {
        "total_sent": len(original_messages),
        "total_received": len(received_messages),
        "missing_count": len(missing),
        "extra_count": len(unexpected),
        "missing_messages": list(missing),
        "extra_messages": list(unexpected),
        "integrity_ok": len(missing) == 0 and len(unexpected) == 0,
    }