From 1c7cf44af47c9ee77572e5fec9ecbfbdb067d5be Mon Sep 17 00:00:00 2001 From: dataCenter430 <161712630+dataCenter430@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:11:47 -0700 Subject: [PATCH] refactor: convert SegmentType controllers if/elif to match/case (#30001) (#34784) --- .../console/app/workflow_draft_variable.py | 39 ++++++++++--------- .../rag_pipeline_draft_variable.py | 39 ++++++++++--------- .../workflow/nodes/trigger_webhook/node.py | 29 +++++++------- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f6d076320c..9771d6f1e5 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -384,24 +384,27 @@ class VariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 93feec0019..3549f9542d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index ebaac93934..6a0d633627 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == SegmentType.FILE: - # Get File object (already processed by webhook controller) - files = webhook_data.get("files", {}) - if files and isinstance(files, dict): - file = files.get(param_name) - if file and isinstance(file, dict): - file_var = self.generate_file_var(param_name, file) - if file_var: - outputs[param_name] = file_var + match param_type: + case SegmentType.FILE: + # Get File object (already processed by webhook controller) + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files else: outputs[param_name] = files else: outputs[param_name] = files - else: - outputs[param_name] = files - else: - # Get regular body parameter - outputs[param_name] = webhook_data.get("body", {}).get(param_name) + case _: + # Get regular body parameter + outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data