mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
fix(api): SegmentType.is_valid() raises AssertionError for SegmentType.GROUP (#28249)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ArrayValidation(StrEnum):
|
class ArrayValidation(StrEnum):
|
||||||
"""Strategy for validating array elements.
|
"""Strategy for validating array elements.
|
||||||
@@ -155,6 +158,17 @@ class SegmentType(StrEnum):
|
|||||||
return isinstance(value, File)
|
return isinstance(value, File)
|
||||||
elif self == SegmentType.NONE:
|
elif self == SegmentType.NONE:
|
||||||
return value is None
|
return value is None
|
||||||
|
elif self == SegmentType.GROUP:
|
||||||
|
from .segment_group import SegmentGroup
|
||||||
|
from .segments import Segment
|
||||||
|
|
||||||
|
if isinstance(value, SegmentGroup):
|
||||||
|
return all(isinstance(item, Segment) for item in value.value)
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
return all(isinstance(item, Segment) for item in value)
|
||||||
|
|
||||||
|
return False
|
||||||
else:
|
else:
|
||||||
raise AssertionError("this statement should be unreachable.")
|
raise AssertionError("this statement should be unreachable.")
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,16 @@ import pytest
|
|||||||
|
|
||||||
from core.file.enums import FileTransferMethod, FileType
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
from core.variables.segment_group import SegmentGroup
|
||||||
|
from core.variables.segments import (
|
||||||
|
ArrayFileSegment,
|
||||||
|
BooleanSegment,
|
||||||
|
FileSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
StringSegment,
|
||||||
|
)
|
||||||
from core.variables.types import ArrayValidation, SegmentType
|
from core.variables.types import ArrayValidation, SegmentType
|
||||||
|
|
||||||
|
|
||||||
@@ -202,6 +212,45 @@ def get_none_cases() -> list[ValidationTestCase]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_group_cases() -> list[ValidationTestCase]:
|
||||||
|
"""Get test cases for valid group values."""
|
||||||
|
test_file = create_test_file()
|
||||||
|
segments = [
|
||||||
|
StringSegment(value="hello"),
|
||||||
|
IntegerSegment(value=42),
|
||||||
|
BooleanSegment(value=True),
|
||||||
|
ObjectSegment(value={"key": "value"}),
|
||||||
|
FileSegment(value=test_file),
|
||||||
|
NoneSegment(value=None),
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
# valid cases
|
||||||
|
ValidationTestCase(
|
||||||
|
SegmentType.GROUP, SegmentGroup(value=segments), True, "Valid SegmentGroup with mixed segments"
|
||||||
|
),
|
||||||
|
ValidationTestCase(
|
||||||
|
SegmentType.GROUP, [StringSegment(value="test"), IntegerSegment(value=123)], True, "List of Segment objects"
|
||||||
|
),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, SegmentGroup(value=[]), True, "Empty SegmentGroup"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, [], True, "Empty list"),
|
||||||
|
# invalid cases
|
||||||
|
ValidationTestCase(SegmentType.GROUP, "not a list", False, "String value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, 123, False, "Integer value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, True, False, "Boolean value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, None, False, "None value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, {"key": "value"}, False, "Dict value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, test_file, False, "File value"),
|
||||||
|
ValidationTestCase(SegmentType.GROUP, ["string", 123, True], False, "List with non-Segment objects"),
|
||||||
|
ValidationTestCase(
|
||||||
|
SegmentType.GROUP,
|
||||||
|
[StringSegment(value="test"), "not a segment"],
|
||||||
|
False,
|
||||||
|
"Mixed list with some non-Segment objects",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
|
||||||
"""Get test cases for ARRAY_ANY validation."""
|
"""Get test cases for ARRAY_ANY validation."""
|
||||||
return [
|
return [
|
||||||
@@ -477,11 +526,77 @@ class TestSegmentTypeIsValid:
|
|||||||
def test_none_validation_valid_cases(self, case):
|
def test_none_validation_valid_cases(self, case):
|
||||||
assert case.segment_type.is_valid(case.value) == case.expected
|
assert case.segment_type.is_valid(case.value) == case.expected
|
||||||
|
|
||||||
def test_unsupported_segment_type_raises_assertion_error(self):
|
@pytest.mark.parametrize("case", get_group_cases(), ids=lambda case: case.description)
|
||||||
"""Test that unsupported SegmentType values raise AssertionError."""
|
def test_group_validation(self, case):
|
||||||
# GROUP is not handled in is_valid method
|
"""Test GROUP type validation with various inputs."""
|
||||||
with pytest.raises(AssertionError, match="this statement should be unreachable"):
|
assert case.segment_type.is_valid(case.value) == case.expected
|
||||||
SegmentType.GROUP.is_valid("any value")
|
|
||||||
|
def test_group_validation_edge_cases(self):
|
||||||
|
"""Test GROUP validation edge cases."""
|
||||||
|
test_file = create_test_file()
|
||||||
|
|
||||||
|
# Test with nested SegmentGroups
|
||||||
|
inner_group = SegmentGroup(value=[StringSegment(value="inner"), IntegerSegment(value=42)])
|
||||||
|
outer_group = SegmentGroup(value=[StringSegment(value="outer"), inner_group])
|
||||||
|
assert SegmentType.GROUP.is_valid(outer_group) is True
|
||||||
|
|
||||||
|
# Test with ArrayFileSegment (which is also a Segment)
|
||||||
|
file_segment = FileSegment(value=test_file)
|
||||||
|
array_file_segment = ArrayFileSegment(value=[test_file, test_file])
|
||||||
|
group_with_arrays = SegmentGroup(value=[file_segment, array_file_segment, StringSegment(value="test")])
|
||||||
|
assert SegmentType.GROUP.is_valid(group_with_arrays) is True
|
||||||
|
|
||||||
|
# Test performance with large number of segments
|
||||||
|
large_segment_list = [StringSegment(value=f"item_{i}") for i in range(1000)]
|
||||||
|
large_group = SegmentGroup(value=large_segment_list)
|
||||||
|
assert SegmentType.GROUP.is_valid(large_group) is True
|
||||||
|
|
||||||
|
def test_no_truly_unsupported_segment_types_exist(self):
|
||||||
|
"""Test that all SegmentType enum values are properly handled in is_valid method.
|
||||||
|
|
||||||
|
This test ensures there are no SegmentType values that would raise AssertionError.
|
||||||
|
If this test fails, it means a new SegmentType was added without proper validation support.
|
||||||
|
"""
|
||||||
|
# Test that ALL segment types are handled and don't raise AssertionError
|
||||||
|
all_segment_types = set(SegmentType)
|
||||||
|
|
||||||
|
for segment_type in all_segment_types:
|
||||||
|
# Create a valid test value for each type
|
||||||
|
test_value: Any = None
|
||||||
|
if segment_type == SegmentType.STRING:
|
||||||
|
test_value = "test"
|
||||||
|
elif segment_type in {SegmentType.NUMBER, SegmentType.INTEGER}:
|
||||||
|
test_value = 42
|
||||||
|
elif segment_type == SegmentType.FLOAT:
|
||||||
|
test_value = 3.14
|
||||||
|
elif segment_type == SegmentType.BOOLEAN:
|
||||||
|
test_value = True
|
||||||
|
elif segment_type == SegmentType.OBJECT:
|
||||||
|
test_value = {"key": "value"}
|
||||||
|
elif segment_type == SegmentType.SECRET:
|
||||||
|
test_value = "secret"
|
||||||
|
elif segment_type == SegmentType.FILE:
|
||||||
|
test_value = create_test_file()
|
||||||
|
elif segment_type == SegmentType.NONE:
|
||||||
|
test_value = None
|
||||||
|
elif segment_type == SegmentType.GROUP:
|
||||||
|
test_value = SegmentGroup(value=[StringSegment(value="test")])
|
||||||
|
elif segment_type.is_array_type():
|
||||||
|
test_value = [] # Empty array is valid for all array types
|
||||||
|
else:
|
||||||
|
# If we get here, there's a segment type we don't know how to test
|
||||||
|
# This should prompt us to add validation logic
|
||||||
|
pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")
|
||||||
|
|
||||||
|
# This should NOT raise AssertionError
|
||||||
|
try:
|
||||||
|
result = segment_type.is_valid(test_value)
|
||||||
|
assert isinstance(result, bool), f"is_valid should return boolean for {segment_type}"
|
||||||
|
except AssertionError as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"SegmentType.{segment_type.name}.is_valid() raised AssertionError: {e}. "
|
||||||
|
"This segment type needs to be handled in the is_valid method."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSegmentTypeArrayValidation:
|
class TestSegmentTypeArrayValidation:
|
||||||
@@ -611,6 +726,7 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
SegmentType.SECRET,
|
SegmentType.SECRET,
|
||||||
SegmentType.FILE,
|
SegmentType.FILE,
|
||||||
SegmentType.NONE,
|
SegmentType.NONE,
|
||||||
|
SegmentType.GROUP,
|
||||||
]
|
]
|
||||||
|
|
||||||
for segment_type in non_array_types:
|
for segment_type in non_array_types:
|
||||||
@@ -630,6 +746,8 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
valid_value = create_test_file()
|
valid_value = create_test_file()
|
||||||
elif segment_type == SegmentType.NONE:
|
elif segment_type == SegmentType.NONE:
|
||||||
valid_value = None
|
valid_value = None
|
||||||
|
elif segment_type == SegmentType.GROUP:
|
||||||
|
valid_value = SegmentGroup(value=[StringSegment(value="test")])
|
||||||
else:
|
else:
|
||||||
continue # Skip unsupported types
|
continue # Skip unsupported types
|
||||||
|
|
||||||
@@ -656,6 +774,7 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
SegmentType.SECRET,
|
SegmentType.SECRET,
|
||||||
SegmentType.FILE,
|
SegmentType.FILE,
|
||||||
SegmentType.NONE,
|
SegmentType.NONE,
|
||||||
|
SegmentType.GROUP,
|
||||||
# Array types
|
# Array types
|
||||||
SegmentType.ARRAY_ANY,
|
SegmentType.ARRAY_ANY,
|
||||||
SegmentType.ARRAY_STRING,
|
SegmentType.ARRAY_STRING,
|
||||||
@@ -667,7 +786,6 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
|
|
||||||
# Types that are not handled by is_valid (should raise AssertionError)
|
# Types that are not handled by is_valid (should raise AssertionError)
|
||||||
unhandled_types = {
|
unhandled_types = {
|
||||||
SegmentType.GROUP,
|
|
||||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||||
}
|
}
|
||||||
@@ -696,6 +814,8 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
assert segment_type.is_valid(create_test_file()) is True
|
assert segment_type.is_valid(create_test_file()) is True
|
||||||
elif segment_type == SegmentType.NONE:
|
elif segment_type == SegmentType.NONE:
|
||||||
assert segment_type.is_valid(None) is True
|
assert segment_type.is_valid(None) is True
|
||||||
|
elif segment_type == SegmentType.GROUP:
|
||||||
|
assert segment_type.is_valid(SegmentGroup(value=[StringSegment(value="test")])) is True
|
||||||
|
|
||||||
def test_boolean_vs_integer_type_distinction(self):
|
def test_boolean_vs_integer_type_distinction(self):
|
||||||
"""Test the important distinction between boolean and integer types in validation."""
|
"""Test the important distinction between boolean and integer types in validation."""
|
||||||
|
|||||||
Reference in New Issue
Block a user