From 7e8147295bf1807388811136e98309f6d7ec8cc0 Mon Sep 17 00:00:00 2001 From: EvanYao <2869018789@qq.com> Date: Mon, 18 May 2026 15:16:31 +0800 Subject: [PATCH] refactor: convert isinstance chains to match/case (part 5) (#36298) Co-authored-by: Stephen Zhou --- api/core/app/apps/base_app_generator.py | 56 ++++----- api/core/app/workflow/layers/persistence.py | 71 ++++------- .../output_parser/structured_output.py | 36 +++--- api/core/workflow/node_runtime.py | 113 +++++++++--------- api/libs/helper.py | 38 +++--- api/models/utils/file_input_compat.py | 43 +++---- 6 files changed, 175 insertions(+), 182 deletions(-) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 8e8ccf2b90..d7ef5165f0 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -195,22 +195,23 @@ class BaseAppGenerator: ) if variable_entity.type == VariableEntityType.NUMBER: - if isinstance(value, (int, float)): - return value - elif isinstance(value, str): - # handle empty string case - if not value.strip(): - return None - # may raise ValueError if user_input_value is not a valid number - try: - if "." in value: - return float(value) - else: - return int(value) - except ValueError: - raise ValueError(f"{variable_entity.variable} in input form must be a valid number") - else: - raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}") + match value: + case int() | float(): + return value + case str(): + # handle empty string case + if not value.strip(): + return None + # may raise ValueError if user_input_value is not a valid number + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") + case _: + raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}") match variable_entity.type: case VariableEntityType.SELECT: @@ -241,17 +242,18 @@ class BaseAppGenerator: f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" ) case VariableEntityType.CHECKBOX: - if isinstance(value, str): - normalized_value = value.strip().lower() - if normalized_value in {"true", "1", "yes", "on"}: - value = True - elif normalized_value in {"false", "0", "no", "off"}: - value = False - elif isinstance(value, (int, float)): - if value == 1: - value = True - elif value == 0: - value = False + match value: + case str(): + normalized_value = value.strip().lower() + if normalized_value in {"true", "1", "yes", "on"}: + value = True + elif normalized_value in {"false", "0", "no", "off"}: + value = False + case int() | float(): + if value == 1: + value = True + elif value == 0: + value = False case VariableEntityType.JSON_OBJECT: if value and not isinstance(value, dict): raise ValueError(f"{variable_entity.variable} in input form must be a dict") diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 19152cebae..c5dba65232 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -105,52 +105,31 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._node_sequence = 0 def on_event(self, event: GraphEngineEvent) -> None: - if isinstance(event, GraphRunStartedEvent): - self._handle_graph_run_started() - return - - if isinstance(event, GraphRunSucceededEvent): - self._handle_graph_run_succeeded(event) - return - - if isinstance(event, GraphRunPartialSucceededEvent): - self._handle_graph_run_partial_succeeded(event) - return - - if isinstance(event, GraphRunFailedEvent): - self._handle_graph_run_failed(event) - return - - if isinstance(event, GraphRunAbortedEvent): - self._handle_graph_run_aborted(event) - return - - if isinstance(event, GraphRunPausedEvent): - self._handle_graph_run_paused(event) - return - - if isinstance(event, NodeRunRetryEvent): - self._handle_node_retry(event) - return - - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - - if isinstance(event, NodeRunSucceededEvent): - self._handle_node_succeeded(event) - return - - if isinstance(event, NodeRunFailedEvent): - self._handle_node_failed(event) - return - - if isinstance(event, NodeRunExceptionEvent): - self._handle_node_exception(event) - return - - if isinstance(event, NodeRunPauseRequestedEvent): - self._handle_node_pause_requested(event) + match event: + case GraphRunStartedEvent(): + self._handle_graph_run_started() + case GraphRunSucceededEvent(): + self._handle_graph_run_succeeded(event) + case GraphRunPartialSucceededEvent(): + self._handle_graph_run_partial_succeeded(event) + case GraphRunFailedEvent(): + self._handle_graph_run_failed(event) + case GraphRunAbortedEvent(): + self._handle_graph_run_aborted(event) + case GraphRunPausedEvent(): + self._handle_graph_run_paused(event) + case NodeRunRetryEvent(): + self._handle_node_retry(event) + case NodeRunStartedEvent(): + self._handle_node_started(event) + case NodeRunSucceededEvent(): + self._handle_node_succeeded(event) + case NodeRunFailedEvent(): + self._handle_node_failed(event) + case NodeRunExceptionEvent(): + self._handle_node_exception(event) + case NodeRunPauseRequestedEvent(): + self._handle_node_pause_requested(event) def on_graph_end(self, error: Exception | None) -> None: return diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index d2e375626f..6cba4fbdf6 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -288,11 +288,13 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]: except ValidationError: # if the result_text is not a valid json, try to repair it temp_parsed = json_repair.loads(result_text) - if not isinstance(temp_parsed, dict): - # handle reasoning model like deepseek-r1 got '\n\n\n' prefix - if isinstance(temp_parsed, list): + match temp_parsed: + case dict(): + pass + case list(): + # handle reasoning model like deepseek-r1 got '\n\n\n' prefix temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {}) - else: + case _: raise OutputParserError(f"Failed to parse structured output: {result_text}") structured_output = cast(dict, temp_parsed) return structured_output @@ -341,12 +343,13 @@ def remove_additional_properties(schema: dict[str, Any]) -> None: # Process nested structures recursively for value in schema.values(): - if isinstance(value, dict): - remove_additional_properties(value) - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - remove_additional_properties(item) + match value: + case dict(): + remove_additional_properties(value) + case list(): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) def convert_boolean_to_string(schema: dict[str, Any]) -> None: @@ -364,9 +367,10 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None: # Process nested dictionaries and lists recursively for value in schema.values(): - if isinstance(value, dict): - convert_boolean_to_string(value) - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - convert_boolean_to_string(item) + match value: + case dict(): + convert_boolean_to_string(value) + case list(): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 7d6ac74791..7ea03a4a33 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -608,64 +608,67 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage - if isinstance(message, CoreToolInvokeMessage.TextMessage): - return ToolRuntimeMessage.TextMessage(text=message.text) - if isinstance(message, CoreToolInvokeMessage.JsonMessage): - return ToolRuntimeMessage.JsonMessage( - json_object=message.json_object, - suppress_output=message.suppress_output, - ) - if isinstance(message, CoreToolInvokeMessage.BlobMessage): - return ToolRuntimeMessage.BlobMessage(blob=message.blob) - if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): - return ToolRuntimeMessage.BlobChunkMessage( - id=message.id, - sequence=message.sequence, - total_length=message.total_length, - blob=message.blob, - end=message.end, - ) - if isinstance(message, CoreToolInvokeMessage.FileMessage): - return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) - if isinstance(message, CoreToolInvokeMessage.VariableMessage): - return ToolRuntimeMessage.VariableMessage( - variable_name=message.variable_name, - variable_value=message.variable_value, - stream=message.stream, - ) - if isinstance(message, CoreToolInvokeMessage.LogMessage): - return ToolRuntimeMessage.LogMessage( - id=message.id, - label=message.label, - parent_id=message.parent_id, - error=message.error, - status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), - data=dict(message.data), - metadata=dict(message.metadata), - ) - if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): - retriever_resources = [ - resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) - for resource in message.retriever_resources - ] - return ToolRuntimeMessage.RetrieverResourceMessage( - retriever_resources=retriever_resources, - context=message.context, - ) - - raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + match message: + case CoreToolInvokeMessage.TextMessage(): + return ToolRuntimeMessage.TextMessage(text=message.text) + case CoreToolInvokeMessage.JsonMessage(): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + case CoreToolInvokeMessage.BlobMessage(): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + case CoreToolInvokeMessage.BlobChunkMessage(): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + case CoreToolInvokeMessage.FileMessage(): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + case CoreToolInvokeMessage.VariableMessage(): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + case CoreToolInvokeMessage.LogMessage(): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + case CoreToolInvokeMessage.RetrieverResourceMessage(): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + case _: + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") @staticmethod def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: - if isinstance(exc, ToolNodeError): - return exc - if isinstance(exc, PluginInvokeError): - return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) - if isinstance(exc, PluginDaemonClientSideError): - return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") - if isinstance(exc, ToolInvokeError): - return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") - return ToolRuntimeInvocationError(str(exc)) + match exc: + case ToolNodeError(): + return exc + case PluginInvokeError(): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + case PluginDaemonClientSideError(): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + case ToolInvokeError(): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + case _: + return ToolRuntimeInvocationError(str(exc)) class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): diff --git a/api/libs/helper.py b/api/libs/helper.py index 7b9128ab7a..04900f385c 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -98,12 +98,13 @@ def extract_tenant_id(user: "Account | EndUser") -> str | None: from models import Account from models.model import EndUser - if isinstance(user, Account): - return user.current_tenant_id - elif isinstance(user, EndUser): - return user.tenant_id - else: - raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.") + match user: + case Account(): + return user.current_tenant_id + case EndUser(): + return user.tenant_id + case _: + raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.") def run(script): @@ -422,18 +423,19 @@ def length_prefixed_response( # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data return struct.pack(" FileType: - if isinstance(value, FileType): - return value - if isinstance(value, str): - return FileType.value_of(value) - raise ValueError("file type is required in file mapping") + match value: + case FileType(): + return value + case str(): + return FileType.value_of(value) + case _: + raise ValueError("file type is required in file mapping") mapping = dict(file_mapping) transfer_method_value = mapping.get("transfer_method") - if isinstance(transfer_method_value, FileTransferMethod): - transfer_method = transfer_method_value - elif isinstance(transfer_method_value, str): - transfer_method = FileTransferMethod.value_of(transfer_method_value) - else: - raise ValueError("transfer_method is required in file mapping") + match transfer_method_value: + case FileTransferMethod(): + transfer_method = transfer_method_value + case str(): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + case _: + raise ValueError("transfer_method is required in file mapping") file_id = mapping.get("file_id") if not isinstance(file_id, str) or not file_id: @@ -151,15 +154,15 @@ def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: so historical JSON blobs remain readable without reintroducing global graph patches or test-local coercion. """ - if isinstance(value, list): - return [rebuild_serialized_graph_files_without_lookup(item) for item in value] - - if isinstance(value, dict): - if maybe_file_object(value): - return build_file_from_mapping_without_lookup(file_mapping=value) - return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} - - return value + match value: + case list(): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + case dict(): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + case _: + return value def build_file_from_stored_mapping(