mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
fix(api): SegmentType.is_valid() raises AssertionError for SegmentType.GROUP (#28249)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements.
|
||||
@@ -155,6 +158,17 @@ class SegmentType(StrEnum):
|
||||
return isinstance(value, File)
|
||||
elif self == SegmentType.NONE:
|
||||
return value is None
|
||||
elif self == SegmentType.GROUP:
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import Segment
|
||||
|
||||
if isinstance(value, SegmentGroup):
|
||||
return all(isinstance(item, Segment) for item in value.value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return all(isinstance(item, Segment) for item in value)
|
||||
|
||||
return False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
||||
@@ -12,6 +12,16 @@ import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
@@ -202,6 +212,45 @@ def get_none_cases() -> list[ValidationTestCase]:
|
||||
]
|
||||
|
||||
|
||||
def get_group_cases() -> list[ValidationTestCase]:
|
||||
"""Get test cases for valid group values."""
|
||||
test_file = create_test_file()
|
||||
segments = [
|
||||
StringSegment(value="hello"),
|
||||
IntegerSegment(value=42),
|
||||
BooleanSegment(value=True),
|
||||
ObjectSegment(value={"key": "value"}),
|
||||
FileSegment(value=test_file),
|
||||
NoneSegment(value=None),
|
||||
]
|
||||
|
||||
return [
|
||||
# valid cases
|
||||
ValidationTestCase(
|
||||
SegmentType.GROUP, SegmentGroup(value=segments), True, "Valid SegmentGroup with mixed segments"
|
||||
),
|
||||
ValidationTestCase(
|
||||
SegmentType.GROUP, [StringSegment(value="test"), IntegerSegment(value=123)], True, "List of Segment objects"
|
||||
),
|
||||
ValidationTestCase(SegmentType.GROUP, SegmentGroup(value=[]), True, "Empty SegmentGroup"),
|
||||
ValidationTestCase(SegmentType.GROUP, [], True, "Empty list"),
|
||||
# invalid cases
|
||||
ValidationTestCase(SegmentType.GROUP, "not a list", False, "String value"),
|
||||
ValidationTestCase(SegmentType.GROUP, 123, False, "Integer value"),
|
||||
ValidationTestCase(SegmentType.GROUP, True, False, "Boolean value"),
|
||||
ValidationTestCase(SegmentType.GROUP, None, False, "None value"),
|
||||
ValidationTestCase(SegmentType.GROUP, {"key": "value"}, False, "Dict value"),
|
||||
ValidationTestCase(SegmentType.GROUP, test_file, False, "File value"),
|
||||
ValidationTestCase(SegmentType.GROUP, ["string", 123, True], False, "List with non-Segment objects"),
|
||||
ValidationTestCase(
|
||||
SegmentType.GROUP,
|
||||
[StringSegment(value="test"), "not a segment"],
|
||||
False,
|
||||
"Mixed list with some non-Segment objects",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
||||
"""Get test cases for ARRAY_ANY validation."""
|
||||
return [
|
||||
@@ -477,11 +526,77 @@ class TestSegmentTypeIsValid:
|
||||
def test_none_validation_valid_cases(self, case):
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
def test_unsupported_segment_type_raises_assertion_error(self):
|
||||
"""Test that unsupported SegmentType values raise AssertionError."""
|
||||
# GROUP is not handled in is_valid method
|
||||
with pytest.raises(AssertionError, match="this statement should be unreachable"):
|
||||
SegmentType.GROUP.is_valid("any value")
|
||||
@pytest.mark.parametrize("case", get_group_cases(), ids=lambda case: case.description)
|
||||
def test_group_validation(self, case):
|
||||
"""Test GROUP type validation with various inputs."""
|
||||
assert case.segment_type.is_valid(case.value) == case.expected
|
||||
|
||||
def test_group_validation_edge_cases(self):
|
||||
"""Test GROUP validation edge cases."""
|
||||
test_file = create_test_file()
|
||||
|
||||
# Test with nested SegmentGroups
|
||||
inner_group = SegmentGroup(value=[StringSegment(value="inner"), IntegerSegment(value=42)])
|
||||
outer_group = SegmentGroup(value=[StringSegment(value="outer"), inner_group])
|
||||
assert SegmentType.GROUP.is_valid(outer_group) is True
|
||||
|
||||
# Test with ArrayFileSegment (which is also a Segment)
|
||||
file_segment = FileSegment(value=test_file)
|
||||
array_file_segment = ArrayFileSegment(value=[test_file, test_file])
|
||||
group_with_arrays = SegmentGroup(value=[file_segment, array_file_segment, StringSegment(value="test")])
|
||||
assert SegmentType.GROUP.is_valid(group_with_arrays) is True
|
||||
|
||||
# Test performance with large number of segments
|
||||
large_segment_list = [StringSegment(value=f"item_{i}") for i in range(1000)]
|
||||
large_group = SegmentGroup(value=large_segment_list)
|
||||
assert SegmentType.GROUP.is_valid(large_group) is True
|
||||
|
||||
def test_no_truly_unsupported_segment_types_exist(self):
|
||||
"""Test that all SegmentType enum values are properly handled in is_valid method.
|
||||
|
||||
This test ensures there are no SegmentType values that would raise AssertionError.
|
||||
If this test fails, it means a new SegmentType was added without proper validation support.
|
||||
"""
|
||||
# Test that ALL segment types are handled and don't raise AssertionError
|
||||
all_segment_types = set(SegmentType)
|
||||
|
||||
for segment_type in all_segment_types:
|
||||
# Create a valid test value for each type
|
||||
test_value: Any = None
|
||||
if segment_type == SegmentType.STRING:
|
||||
test_value = "test"
|
||||
elif segment_type in {SegmentType.NUMBER, SegmentType.INTEGER}:
|
||||
test_value = 42
|
||||
elif segment_type == SegmentType.FLOAT:
|
||||
test_value = 3.14
|
||||
elif segment_type == SegmentType.BOOLEAN:
|
||||
test_value = True
|
||||
elif segment_type == SegmentType.OBJECT:
|
||||
test_value = {"key": "value"}
|
||||
elif segment_type == SegmentType.SECRET:
|
||||
test_value = "secret"
|
||||
elif segment_type == SegmentType.FILE:
|
||||
test_value = create_test_file()
|
||||
elif segment_type == SegmentType.NONE:
|
||||
test_value = None
|
||||
elif segment_type == SegmentType.GROUP:
|
||||
test_value = SegmentGroup(value=[StringSegment(value="test")])
|
||||
elif segment_type.is_array_type():
|
||||
test_value = [] # Empty array is valid for all array types
|
||||
else:
|
||||
# If we get here, there's a segment type we don't know how to test
|
||||
# This should prompt us to add validation logic
|
||||
pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")
|
||||
|
||||
# This should NOT raise AssertionError
|
||||
try:
|
||||
result = segment_type.is_valid(test_value)
|
||||
assert isinstance(result, bool), f"is_valid should return boolean for {segment_type}"
|
||||
except AssertionError as e:
|
||||
pytest.fail(
|
||||
f"SegmentType.{segment_type.name}.is_valid() raised AssertionError: {e}. "
|
||||
"This segment type needs to be handled in the is_valid method."
|
||||
)
|
||||
|
||||
|
||||
class TestSegmentTypeArrayValidation:
|
||||
@@ -611,6 +726,7 @@ class TestSegmentTypeValidationIntegration:
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
]
|
||||
|
||||
for segment_type in non_array_types:
|
||||
@@ -630,6 +746,8 @@ class TestSegmentTypeValidationIntegration:
|
||||
valid_value = create_test_file()
|
||||
elif segment_type == SegmentType.NONE:
|
||||
valid_value = None
|
||||
elif segment_type == SegmentType.GROUP:
|
||||
valid_value = SegmentGroup(value=[StringSegment(value="test")])
|
||||
else:
|
||||
continue # Skip unsupported types
|
||||
|
||||
@@ -656,6 +774,7 @@ class TestSegmentTypeValidationIntegration:
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
# Array types
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
@@ -667,7 +786,6 @@ class TestSegmentTypeValidationIntegration:
|
||||
|
||||
# Types that are not handled by is_valid (should raise AssertionError)
|
||||
unhandled_types = {
|
||||
SegmentType.GROUP,
|
||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||
}
|
||||
@@ -696,6 +814,8 @@ class TestSegmentTypeValidationIntegration:
|
||||
assert segment_type.is_valid(create_test_file()) is True
|
||||
elif segment_type == SegmentType.NONE:
|
||||
assert segment_type.is_valid(None) is True
|
||||
elif segment_type == SegmentType.GROUP:
|
||||
assert segment_type.is_valid(SegmentGroup(value=[StringSegment(value="test")])) is True
|
||||
|
||||
def test_boolean_vs_integer_type_distinction(self):
|
||||
"""Test the important distinction between boolean and integer types in validation."""
|
||||
|
||||
Reference in New Issue
Block a user