diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 766f55f642..068ba686fa 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -22,7 +22,7 @@ jobs: # Fix lint errors uv run ruff check --fix . # Format code - uv run ruff format . + uv run ruff format . - name: ast-grep run: | diff --git a/api/.ruff.toml b/api/.ruff.toml index 9a15754d9a..67ad3b1449 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -5,7 +5,7 @@ line-length = 120 quote-style = "double" [lint] -preview = false +preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions @@ -65,6 +65,7 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg + "B901", # allow return in yield "B903", # class-as-data-structure "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict diff --git a/api/commands.py b/api/commands.py index 0adef498cd..58054a9adf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,6 +1,7 @@ import base64 import json import logging +import operator import secrets from typing import Any @@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) @@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables( if dry_run: logger.info("DRY RUN: Would delete the following:") - for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[ + for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[ :10 ]: # Show top 10 logger.info(" App %s: %s variables", app_id, count) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index db8684139d..d9519bb078 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance): GEN_AI_FRAMEWORK: "dify", TOOL_NAME: node_execution.title, TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False), - TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), - INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False), + TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False), + INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False), OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), }, status=self.get_workflow_node_status(node_execution), diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index f05e59c28e..119dd52a5f 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = 
node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index ff8a346d27..6c24ac0e47 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( @@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 20e880e940..98e9cb2dcb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { @@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} provider = 
None model = None diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index aeed928581..08d4adb2ff 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -1,3 +1,4 @@ +import collections import json import logging import os @@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks logger = logging.getLogger(__name__) -class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): +class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: case TracingProviderEnum.LANGFUSE: @@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): raise KeyError(f"Unsupported tracing provider: {provider}") -provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap() +provider_config_map = OpsTraceProviderConfigMap() class OpsTraceManager: diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index c8532c63be..13a4529311 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance): if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = node_execution.inputs if node_execution.inputs else {} - outputs = node_execution.outputs if node_execution.outputs else {} + inputs = node_execution.inputs or {} + outputs = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = node_execution.metadata if node_execution.metadata else {} + execution_metadata = node_execution.metadata or {} node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( @@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = node_execution.process_data if node_execution.process_data else {} + process_data = node_execution.process_data or {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 7a4f11a453..e55e5f3101 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector): for doc, embedding in zip(batch_docs, batch_embeddings): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata if doc.metadata else {} + metadata = doc.metadata or {} if not isinstance(metadata, dict): metadata = {} diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index f6e8fa6d0f..6fe396dc1e 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector): self.client = self._get_client(len(embeddings[0]), True) assert self.client is not None ids = [] - for _, doc in enumerate(documents): + for doc in documents: if doc.metadata is not None: doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) ids.append(doc_id) diff 
--git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 764248e6fa..3eb1df027e 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector): }, } # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 - if self._client_config.aws_service not in ["aoss"]: + if self._client_config.aws_service != "aoss": action["_id"] = uuid4().hex actions.append(action) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index ca5e33d83a..7d1069e28f 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): else None ) db_model.status = domain_model.status - db_model.error = domain_model.error_message if domain_model.error_message else None + db_model.error = domain_model.error_message or None db_model.total_tokens = domain_model.total_tokens db_model.total_steps = domain_model.total_steps db_model.exceptions_count = domain_model.exceptions_count diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 70ec2ee2b9..c075aa3e64 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -320,7 +320,7 @@ class AgentNode(BaseNode): memory = self._fetch_memory(model_instance) if memory: prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size if node_data.memory.window.size else None + message_limit=node_data.memory.window.size or None ) history_prompt_messages = [ prompt_message.model_dump(mode="json") for prompt_message in prompt_messages diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 1884f7f9a9..585539e2ce 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery: imports.append("schedule.queue_monitor_task") beat_schedule["datasets-queue-monitor"] = { "task": "schedule.queue_monitor_task.queue_monitor_task", - "schedule": timedelta( - minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 - ), + "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30), } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 5d5d3ee295..6ab02ad8cc 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files. 
import json import logging +import operator from dataclasses import asdict, dataclass from datetime import datetime from enum import StrEnum, auto @@ -356,7 +357,7 @@ class FileLifecycleManager: # Cleanup old versions for each file for base_filename, versions in file_versions.items(): # Sort by version number - versions.sort(key=lambda x: x[0], reverse=True) + versions.sort(key=operator.itemgetter(0), reverse=True) # Keep the newest max_versions versions, delete the rest if len(versions) > max_versions: diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index d871b297e0..bae8f1c4db 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,3 +1,4 @@ +import operator import traceback import typing @@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task( current_version = version latest_version = manifest.latest_version - def fix_only_checker(latest_version, current_version): + def fix_only_checker(latest_version: str, current_version: str): latest_version_tuple = tuple(int(val) for val in latest_version.split(".")) current_version_tuple = tuple(int(val) for val in current_version.split(".")) @@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task( return False version_checker = { - TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version, - current_version: latest_version != current_version, + TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne, TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker, } diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py index 293b469ef3..7e60f60adc 100644 --- a/api/tests/integration_tests/storage/test_clickzetta_volume.py +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -3,6 +3,7 @@ import os import tempfile import unittest +from pathlib import Path import pytest @@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase): # Test download with tempfile.NamedTemporaryFile() as temp_file: storage.download(test_filename, temp_file.name) - with open(temp_file.name, "rb") as f: - downloaded_content = f.read() + downloaded_content = Path(temp_file.name).read_bytes() assert downloaded_content == test_content # Test scan diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 065bcc2cd7..fcae93c669 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances. 
import uuid from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask: db.session.commit() # Test each unavailable document - for i, document in enumerate(test_cases): + for document in test_cases: job_id = str(uuid.uuid4()) batch_create_segment_to_index_task( job_id=job_id, @@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(empty_csv_content) + Path(file_path).write_text(empty_csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download @@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask: mock_storage = mock_external_service_dependencies["storage"] def mock_download(key, file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(csv_content) + Path(file_path).write_text(csv_content, encoding="utf-8") mock_storage.download.side_effect = mock_download diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 0083011070..e0c2da63b9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -362,7 +362,7 @@ class TestCleanDatasetTask: # Create segments for each document segments = [] - for i, document in enumerate(documents): + for document in documents: segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document) segments.append(segment) diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 57ddacd13d..0bf4a3cf91 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -15,7 +15,7 @@ class FakeResponse: self.status_code = status_code self.headers = headers or {} self.content = content - self.text = text if text else content.decode("utf-8", errors="ignore") + self.text = text or content.decode("utf-8", errors="ignore") # --------------------------- diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index ad65175e89..0ff1edc950 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import Mock, create_autospec, patch import pytest @@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation: # Console API create console_create_file = "api/controllers/console/datasets/metadata.py" if os.path.exists(console_create_file): - with open(console_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] + content = 
Path(console_create_file).read_text() + # Should contain nullable=False, not nullable=True + assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] # Service API create service_create_file = "api/controllers/service_api/dataset/metadata.py" if os.path.exists(service_create_file): - with open(service_create_file) as f: - content = f.read() - # Should contain nullable=False, not nullable=True - create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] - assert "nullable=True" not in create_api_section + content = Path(service_create_file).read_text() + # Should contain nullable=False, not nullable=True + create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] + assert "nullable=True" not in create_api_section class TestMetadataValidationSummary: diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 63d0cbaf3a..1ec95deb09 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -1,6 +1,7 @@ +from pathlib import Path + import yaml # type: ignore from dotenv import dotenv_values -from pathlib import Path BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "APP_MAX_EXECUTION_TIME", @@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f: def test_yaml_config(): # python set == operator is used to compare two sets - DIFF_API_WITH_DOCKER = ( - API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF - ) + DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF if DIFF_API_WITH_DOCKER: - print( - f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}" - ) + print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}") raise Exception("API and Docker config sets are different") DIFF_API_WITH_DOCKER_COMPOSE = ( - API_CONFIG_SET - - DOCKER_COMPOSE_CONFIG_SET - - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF + API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF ) if DIFF_API_WITH_DOCKER_COMPOSE: - print( - f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}" - ) + print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}") raise Exception("API and Docker Compose config sets are different") print("All tests passed!") diff --git a/scripts/stress-test/cleanup.py b/scripts/stress-test/cleanup.py index 5fb1b96ddc..05b97be7ca 100755 --- a/scripts/stress-test/cleanup.py +++ b/scripts/stress-test/cleanup.py @@ -51,9 +51,7 @@ def cleanup() -> None: if sys.stdin.isatty(): log.separator() log.warning("This action cannot be undone!") - confirmation = input( - "Are you sure you want to remove all config and report files? (yes/no): " - ) + confirmation = input("Are you sure you want to remove all config and report files? 
(yes/no): ") if confirmation.lower() not in ["yes", "y"]: log.error("Cleanup cancelled.") diff --git a/scripts/stress-test/common/__init__.py b/scripts/stress-test/common/__init__.py index 920c31482e..a38d972ffb 100644 --- a/scripts/stress-test/common/__init__.py +++ b/scripts/stress-test/common/__init__.py @@ -3,4 +3,4 @@ from .config_helper import config_helper from .logger_helper import Logger, ProgressLogger -__all__ = ["config_helper", "Logger", "ProgressLogger"] +__all__ = ["Logger", "ProgressLogger", "config_helper"] diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py index 9a3581fb54..75fcbffa6f 100644 --- a/scripts/stress-test/common/config_helper.py +++ b/scripts/stress-test/common/config_helper.py @@ -65,9 +65,9 @@ class ConfigHelper: return None try: - with open(config_path, "r") as f: + with open(config_path) as f: return json.load(f) - except (json.JSONDecodeError, IOError) as e: + except (OSError, json.JSONDecodeError) as e: print(f"❌ Error reading {filename}: {e}") return None @@ -101,7 +101,7 @@ class ConfigHelper: with open(config_path, "w") as f: json.dump(data, f, indent=2) return True - except IOError as e: + except OSError as e: print(f"❌ Error writing {filename}: {e}") return False @@ -133,7 +133,7 @@ class ConfigHelper: try: config_path.unlink() return True - except IOError as e: + except OSError as e: print(f"❌ Error deleting {filename}: {e}") return False @@ -148,9 +148,9 @@ class ConfigHelper: return None try: - with open(state_path, "r") as f: + with open(state_path) as f: return json.load(f) - except (json.JSONDecodeError, IOError) as e: + except (OSError, json.JSONDecodeError) as e: print(f"❌ Error reading {self.state_file}: {e}") return None @@ -170,7 +170,7 @@ class ConfigHelper: with open(state_path, "w") as f: json.dump(data, f, indent=2) return True - except IOError as e: + except OSError as e: print(f"❌ Error writing {self.state_file}: {e}") return False diff --git a/scripts/stress-test/common/logger_helper.py b/scripts/stress-test/common/logger_helper.py index 03436eed6c..c522685f1d 100644 --- a/scripts/stress-test/common/logger_helper.py +++ b/scripts/stress-test/common/logger_helper.py @@ -159,9 +159,7 @@ class ProgressLogger: if self.logger.use_colors: progress_bar = self._create_progress_bar() - print( - f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}" - ) + print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}") self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)") else: print(f"\n[Step {self.current_step}/{self.total_steps}]") diff --git a/scripts/stress-test/setup/configure_openai_plugin.py b/scripts/stress-test/setup/configure_openai_plugin.py index 5ec23b7d62..110417fa73 100755 --- a/scripts/stress-test/setup/configure_openai_plugin.py +++ b/scripts/stress-test/setup/configure_openai_plugin.py @@ -6,8 +6,7 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) import httpx -from common import config_helper -from common import Logger +from common import Logger, config_helper def configure_openai_plugin() -> None: @@ -72,29 +71,19 @@ def configure_openai_plugin() -> None: if response.status_code == 200: log.success("OpenAI plugin configured successfully!") - log.key_value( - "API Base", config_payload["credentials"]["openai_api_base"] - ) - log.key_value( - "API Key", config_payload["credentials"]["openai_api_key"] - ) + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + 
log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) elif response.status_code == 201: log.success("OpenAI plugin credentials created successfully!") - log.key_value( - "API Base", config_payload["credentials"]["openai_api_base"] - ) - log.key_value( - "API Key", config_payload["credentials"]["openai_api_key"] - ) + log.key_value("API Base", config_payload["credentials"]["openai_api_base"]) + log.key_value("API Key", config_payload["credentials"]["openai_api_key"]) elif response.status_code == 401: log.error("Configuration failed: Unauthorized") log.info("Token may have expired. Please run login_admin.py again") else: - log.error( - f"Configuration failed with status code: {response.status_code}" - ) + log.error(f"Configuration failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") except httpx.ConnectError: diff --git a/scripts/stress-test/setup/create_api_key.py b/scripts/stress-test/setup/create_api_key.py index 08eaed309a..cd04fe57eb 100755 --- a/scripts/stress-test/setup/create_api_key.py +++ b/scripts/stress-test/setup/create_api_key.py @@ -5,10 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import json -from common import config_helper -from common import Logger + +import httpx +from common import Logger, config_helper def create_api_key() -> None: @@ -90,9 +90,7 @@ def create_api_key() -> None: } if config_helper.write_config("api_key_config", api_key_config): - log.info( - f"API key saved to: {config_helper.get_config_path('benchmark_state')}" - ) + log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}") else: log.error("No API token received") log.debug(f"Response: {json.dumps(response_data, indent=2)}") @@ -101,9 +99,7 @@ def create_api_key() -> None: log.error("API key creation failed: Unauthorized") log.info("Token may have expired. 
Please run login_admin.py again") else: - log.error( - f"API key creation failed with status code: {response.status_code}" - ) + log.error(f"API key creation failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") except httpx.ConnectError: diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py index 1d8a9b6009..86d0239e35 100755 --- a/scripts/stress-test/setup/import_workflow_app.py +++ b/scripts/stress-test/setup/import_workflow_app.py @@ -5,9 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import json -from common import config_helper, Logger + +import httpx +from common import Logger, config_helper def import_workflow_app() -> None: @@ -30,7 +31,7 @@ def import_workflow_app() -> None: log.error(f"DSL file not found: {dsl_path}") return - with open(dsl_path, "r") as f: + with open(dsl_path) as f: yaml_content = f.read() log.step("Importing workflow app from DSL...") @@ -86,9 +87,7 @@ def import_workflow_app() -> None: log.success("Workflow app imported successfully!") log.key_value("App ID", app_id) log.key_value("App Mode", response_data.get("app_mode")) - log.key_value( - "DSL Version", response_data.get("imported_dsl_version") - ) + log.key_value("DSL Version", response_data.get("imported_dsl_version")) # Save app_id to config app_config = { @@ -99,9 +98,7 @@ def import_workflow_app() -> None: } if config_helper.write_config("app_config", app_config): - log.info( - f"App config saved to: {config_helper.get_config_path('benchmark_state')}" - ) + log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}") else: log.error("Import completed but no app_id received") log.debug(f"Response: {json.dumps(response_data, indent=2)}") diff --git a/scripts/stress-test/setup/install_openai_plugin.py b/scripts/stress-test/setup/install_openai_plugin.py index d925126dae..055e5661f8 100755 --- a/scripts/stress-test/setup/install_openai_plugin.py +++ b/scripts/stress-test/setup/install_openai_plugin.py @@ -5,10 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import time -from common import config_helper -from common import Logger + +import httpx +from common import Logger, config_helper def install_openai_plugin() -> None: @@ -28,9 +28,7 @@ def install_openai_plugin() -> None: # API endpoint for plugin installation base_url = "http://localhost:5001" - install_endpoint = ( - f"{base_url}/console/api/workspaces/current/plugin/install/marketplace" - ) + install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace" # Plugin identifier plugin_payload = { @@ -83,9 +81,7 @@ def install_openai_plugin() -> None: log.info("Polling for task completion...") # Poll for task completion - task_endpoint = ( - f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}" - ) + task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}" max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max attempt = 0 @@ -131,9 +127,7 @@ def install_openai_plugin() -> None: plugins = task_info.get("plugins", []) if plugins: for plugin in plugins: - log.list_item( - f"{plugin.get('plugin_id')}: {plugin.get('message')}" - ) + log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}") break # Continue polling if status is "pending" or other @@ -149,9 +143,7 @@ def install_openai_plugin() -> None: log.warning("Plugin may already be installed") 
log.debug(f"Response: {response.text}") else: - log.error( - f"Installation failed with status code: {response.status_code}" - ) + log.error(f"Installation failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") except httpx.ConnectError: diff --git a/scripts/stress-test/setup/login_admin.py b/scripts/stress-test/setup/login_admin.py index 0e3da687f6..572b8fb650 100755 --- a/scripts/stress-test/setup/login_admin.py +++ b/scripts/stress-test/setup/login_admin.py @@ -5,10 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import json -from common import config_helper -from common import Logger + +import httpx +from common import Logger, config_helper def login_admin() -> None: @@ -77,16 +77,10 @@ def login_admin() -> None: # Save token config if config_helper.write_config("token_config", token_config): - log.info( - f"Token saved to: {config_helper.get_config_path('benchmark_state')}" - ) + log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}") # Show truncated token for verification - token_display = ( - f"{access_token[:20]}..." - if len(access_token) > 20 - else "Token saved" - ) + token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved" log.key_value("Access token", token_display) elif response.status_code == 401: diff --git a/scripts/stress-test/setup/mock_openai_server.py b/scripts/stress-test/setup/mock_openai_server.py index 9e3e8d590a..7333c66e57 100755 --- a/scripts/stress-test/setup/mock_openai_server.py +++ b/scripts/stress-test/setup/mock_openai_server.py @@ -3,8 +3,10 @@ import json import time import uuid -from typing import Any, Iterator -from flask import Flask, request, jsonify, Response +from collections.abc import Iterator +from typing import Any + +from flask import Flask, Response, jsonify, request app = Flask(__name__) diff --git a/scripts/stress-test/setup/publish_workflow.py b/scripts/stress-test/setup/publish_workflow.py index f23db5049c..b772eccebd 100755 --- a/scripts/stress-test/setup/publish_workflow.py +++ b/scripts/stress-test/setup/publish_workflow.py @@ -5,10 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import json -from common import config_helper -from common import Logger + +import httpx +from common import Logger, config_helper def publish_workflow() -> None: @@ -79,9 +79,7 @@ def publish_workflow() -> None: try: response_data = response.json() if response_data: - log.debug( - f"Response: {json.dumps(response_data, indent=2)}" - ) + log.debug(f"Response: {json.dumps(response_data, indent=2)}") except json.JSONDecodeError: # Response might be empty or non-JSON pass @@ -93,9 +91,7 @@ def publish_workflow() -> None: log.error("Workflow publish failed: App not found") log.info("Make sure the app was imported successfully") else: - log.error( - f"Workflow publish failed with status code: {response.status_code}" - ) + log.error(f"Workflow publish failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") except httpx.ConnectError: diff --git a/scripts/stress-test/setup/run_workflow.py b/scripts/stress-test/setup/run_workflow.py index 13ab1280fc..6da0ff17be 100755 --- a/scripts/stress-test/setup/run_workflow.py +++ b/scripts/stress-test/setup/run_workflow.py @@ -5,9 +5,10 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) -import httpx import json -from common import config_helper, Logger + +import httpx +from 
common import Logger, config_helper def run_workflow(question: str = "fake question", streaming: bool = True) -> None: @@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non event = data.get("event") if event == "workflow_started": - log.progress( - f"Workflow started: {data.get('data', {}).get('id')}" - ) + log.progress(f"Workflow started: {data.get('data', {}).get('id')}") elif event == "node_started": node_data = data.get("data", {}) log.progress( @@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non # Some lines might not be JSON pass else: - log.error( - f"Workflow run failed with status code: {response.status_code}" - ) + log.error(f"Workflow run failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") else: # Handle blocking response @@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non log.info("📤 Final Answer:") log.info(outputs.get("answer"), indent=2) else: - log.error( - f"Workflow run failed with status code: {response.status_code}" - ) + log.error(f"Workflow run failed with status code: {response.status_code}") log.debug(f"Response: {response.text}") except httpx.ConnectError: diff --git a/scripts/stress-test/setup/setup_admin.py b/scripts/stress-test/setup/setup_admin.py index 3cc83aa773..a5e9161210 100755 --- a/scripts/stress-test/setup/setup_admin.py +++ b/scripts/stress-test/setup/setup_admin.py @@ -6,7 +6,7 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) import httpx -from common import config_helper, Logger +from common import Logger, config_helper def setup_admin_account() -> None: @@ -24,9 +24,7 @@ def setup_admin_account() -> None: # Save credentials to config file if config_helper.write_config("admin_config", admin_config): - log.info( - f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}" - ) + log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}") # API setup endpoint base_url = "http://localhost:5001" @@ -56,9 +54,7 @@ def setup_admin_account() -> None: log.key_value("Username", admin_config["username"]) elif response.status_code == 400: - log.warning( - "Setup may have already been completed or invalid data provided" - ) + log.warning("Setup may have already been completed or invalid data provided") log.debug(f"Response: {response.text}") else: log.error(f"Setup failed with status code: {response.status_code}") diff --git a/scripts/stress-test/setup_all.py b/scripts/stress-test/setup_all.py index 7a4c5c01b2..ece420f925 100755 --- a/scripts/stress-test/setup_all.py +++ b/scripts/stress-test/setup_all.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 +import socket import subprocess import sys import time -import socket from pathlib import Path from common import Logger, ProgressLogger @@ -93,9 +93,7 @@ def main() -> None: if retry.lower() in ["yes", "y"]: return main() # Recursively call main to check again else: - print( - "❌ Setup cancelled. Please start the required services and try again." - ) + print("❌ Setup cancelled. 
Please start the required services and try again.") sys.exit(1) log.success("All required services are running!") diff --git a/scripts/stress-test/sse_benchmark.py b/scripts/stress-test/sse_benchmark.py index e08fe6c33b..99fe2b20f4 100644 --- a/scripts/stress-test/sse_benchmark.py +++ b/scripts/stress-test/sse_benchmark.py @@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first """ import json -import time +import logging +import os import random +import statistics import sys import threading -import os -import logging -import statistics -from pathlib import Path +import time from collections import deque +from dataclasses import asdict, dataclass from datetime import datetime -from dataclasses import dataclass, asdict -from locust import HttpUser, task, between, events, constant -from typing import TypedDict, Literal, TypeAlias +from pathlib import Path +from typing import Literal, TypeAlias, TypedDict + import requests.exceptions +from locust import HttpUser, between, constant, events, task # Add the stress-test directory to path to import common modules sys.path.insert(0, str(Path(__file__).parent)) from common.config_helper import ConfigHelper # type: ignore[import-not-found] # Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # Configuration from environment @@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[ class ErrorCounts(TypedDict): """Error count tracking""" + connection_error: int timeout: int invalid_json: int @@ -65,6 +65,7 @@ class ErrorCounts(TypedDict): class SSEEvent(TypedDict): """Server-Sent Event structure""" + data: str event: str id: str | None @@ -72,11 +73,13 @@ class SSEEvent(TypedDict): class WorkflowInputs(TypedDict): """Workflow input structure""" + question: str class WorkflowRequestData(TypedDict): """Workflow request payload""" + inputs: WorkflowInputs response_mode: Literal["streaming"] user: str @@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict): class ParsedEventData(TypedDict, total=False): """Parsed event data from SSE stream""" + event: str task_id: str workflow_run_id: str @@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False): class LocustStats(TypedDict): """Locust statistics structure""" + total_requests: int total_failures: int avg_response_time: float @@ -102,6 +107,7 @@ class LocustStats(TypedDict): class ReportData(TypedDict): """JSON report structure""" + timestamp: str duration_seconds: float metrics: dict[str, object] # Metrics as dict for JSON serialization @@ -154,7 +160,7 @@ class MetricsTracker: self.total_connections = 0 self.total_events = 0 self.start_time = time.time() - + # Enhanced metrics with memory limits self.max_samples = 10000 # Prevent unbounded growth self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples) @@ -233,9 +239,7 @@ class MetricsTracker: max_ttfe = max(self.ttfe_samples) p50_ttfe = statistics.median(self.ttfe_samples) if len(self.ttfe_samples) >= 2: - quantiles = statistics.quantiles( - self.ttfe_samples, n=20, method="inclusive" - ) + quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive") p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile else: p95_ttfe = max_ttfe @@ -255,9 +259,7 @@ class MetricsTracker: if durations else 0 ) - events_per_stream_avg = ( - statistics.mean(events_per_stream) if 
events_per_stream else 0 - ) + events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0 # Calculate inter-event latency statistics all_inter_event_times = [] @@ -268,32 +270,20 @@ class MetricsTracker: inter_event_latency_avg = statistics.mean(all_inter_event_times) inter_event_latency_p50 = statistics.median(all_inter_event_times) inter_event_latency_p95 = ( - statistics.quantiles( - all_inter_event_times, n=20, method="inclusive" - )[18] + statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18] if len(all_inter_event_times) >= 2 else max(all_inter_event_times) ) else: - inter_event_latency_avg = inter_event_latency_p50 = ( - inter_event_latency_p95 - ) = 0 + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 else: - stream_duration_avg = stream_duration_p50 = stream_duration_p95 = ( - events_per_stream_avg - ) = 0 - inter_event_latency_avg = inter_event_latency_p50 = ( - inter_event_latency_p95 - ) = 0 + stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0 + inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0 # Also calculate overall average rates total_elapsed = current_time - self.start_time - overall_conn_rate = ( - self.total_connections / total_elapsed if total_elapsed > 0 else 0 - ) - overall_event_rate = ( - self.total_events / total_elapsed if total_elapsed > 0 else 0 - ) + overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0 + overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0 return MetricsSnapshot( active_connections=self.active_connections, @@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser): # Load questions from file or use defaults if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE): - with open(QUESTIONS_FILE, "r") as f: + with open(QUESTIONS_FILE) as f: self.questions = [line.strip() for line in f if line.strip()] else: self.questions = [ @@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser): try: # Validate response if response.status_code >= 400: - error_type: ErrorType = ( - "http_4xx" if response.status_code < 500 else "http_5xx" - ) + error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx" metrics.record_error(error_type) response.failure(f"HTTP {response.status_code}") return content_type = response.headers.get("Content-Type", "") - if ( - "text/event-stream" not in content_type - and "application/json" not in content_type - ): + if "text/event-stream" not in content_type and "application/json" not in content_type: logger.error(f"Expected text/event-stream, got: {content_type}") metrics.record_error("invalid_response") response.failure(f"Invalid content type: {content_type}") @@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser): for line in response.iter_lines(decode_unicode=True): # Check if runner is stopping - if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'): + if getattr(self.environment.runner, "state", "") in ( + "stopping", + "stopped", + ): logger.debug("Runner stopping, breaking streaming loop") break - + if line is not None: bytes_received += len(line.encode("utf-8")) @@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser): # Track inter-event timing if last_event_time: - inter_event_times.append( - (current_time - last_event_time) * 1000 - ) + inter_event_times.append((current_time - last_event_time) * 1000) last_event_time = current_time if first_event_time is None: @@ -512,15 +498,11 @@ class 
DifyWorkflowUser(HttpUser): parsed_event: ParsedEventData = json.loads(event_data) # Check for terminal events if parsed_event.get("event") in TERMINAL_EVENTS: - logger.debug( - f"Received terminal event: {parsed_event.get('event')}" - ) + logger.debug(f"Received terminal event: {parsed_event.get('event')}") request_success = True break except json.JSONDecodeError as e: - logger.debug( - f"JSON decode error: {e} for data: {event_data[:100]}" - ) + logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}") metrics.record_error("invalid_json") except Exception as e: @@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None: # Periodic stats reporting def report_stats() -> None: - if not hasattr(environment, 'runner'): + if not hasattr(environment, "runner"): return runner = environment.runner - while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]: + while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]: time.sleep(5) # Report every 5 seconds - if hasattr(runner, 'state') and runner.state == "running": + if hasattr(runner, "state") and runner.state == "running": stats = metrics.get_stats() # Only log on master node in distributed mode - is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True + is_master = ( + not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True + ) if is_master: # Clear previous lines and show updated stats logger.info("\n" + "=" * 80) @@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None: logger.info( f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}" ) - logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)") + logger.info( + f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)" + ) logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}") # Inter-event latency if stats.inter_event_latency_avg > 0: logger.info("-" * 80) - logger.info( - f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}" - ) + logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}") logger.info( f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}" ) @@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None: logger.info("=" * 80) # Show Locust stats summary - if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'): + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): total = environment.stats.total - if hasattr(total, 'num_requests') and total.num_requests > 0: + if hasattr(total, "num_requests") and total.num_requests > 0: logger.info( f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}" ) @@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None: logger.info("") logger.info("EVENTS") logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}") - logger.info( - f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s" - ) - logger.info( - f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s" - ) + logger.info(f" {'Average Throughput:':<30} 
{stats.overall_event_rate:>10.2f} events/s") + logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s") logger.info("") logger.info("STREAM METRICS") logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms") logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms") logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms") - logger.info( - f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}" - ) + logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}") logger.info("") logger.info("INTER-EVENT LATENCY") @@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None: logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms") logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms") logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms") - logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})") + logger.info( + f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})" + ) logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}") # Error summary @@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None: logger.info("=" * 80 + "\n") # Export machine-readable report (only on master node) - is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True + is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True if is_master: export_json_report(stats, test_duration, environment) @@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: obj # Access environment.stats.total attributes safely locust_stats: LocustStats | None = None - if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'): + if hasattr(environment, "stats") and hasattr(environment.stats, "total"): total = environment.stats.total - if hasattr(total, 'num_requests') and total.num_requests > 0: + if hasattr(total, "num_requests") and total.num_requests > 0: locust_stats = LocustStats( total_requests=total.num_requests, total_failures=total.num_failures, diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index d00c207afa..e866472f45 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1,7 +1,15 @@ from dify_client.client import ( ChatClient, CompletionClient, - WorkflowClient, - KnowledgeBaseClient, DifyClient, + KnowledgeBaseClient, + WorkflowClient, ) + +__all__ = [ + "ChatClient", + "CompletionClient", + "DifyClient", + "KnowledgeBaseClient", + "WorkflowClient", +] diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index f113d4f0a8..791cb98a1b 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -8,16 +8,16 @@ class DifyClient: self.api_key = api_key self.base_url = base_url - def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False): + def _send_request( + self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False + ): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } url = 
f"{self.base_url}{endpoint}" - response = requests.request( - method, url, json=json, params=params, headers=headers, stream=stream - ) + response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream) return response @@ -25,9 +25,7 @@ class DifyClient: headers = {"Authorization": f"Bearer {self.api_key}"} url = f"{self.base_url}{endpoint}" - response = requests.request( - method, url, data=data, headers=headers, files=files - ) + response = requests.request(method, url, data=data, headers=headers, files=files) return response @@ -41,9 +39,7 @@ class DifyClient: def file_upload(self, user: str, files: dict): data = {"user": user} - return self._send_request_with_files( - "POST", "/files/upload", data=data, files=files - ) + return self._send_request_with_files("POST", "/files/upload", data=data, files=files) def text_to_audio(self, text: str, user: str, streaming: bool = False): data = {"text": text, "user": user, "streaming": streaming} @@ -55,7 +51,9 @@ class DifyClient: class CompletionClient(DifyClient): - def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None): + def create_completion_message( + self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None + ): data = { "inputs": inputs, "response_mode": response_mode, @@ -99,9 +97,7 @@ class ChatClient(DifyClient): def get_suggested(self, message_id: str, user: str): params = {"user": user} - return self._send_request( - "GET", f"/messages/{message_id}/suggested", params=params - ) + return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) def stop_message(self, task_id: str, user: str): data = {"user": user} @@ -112,10 +108,9 @@ class ChatClient(DifyClient): user: str, last_id: str | None = None, limit: int | None = None, - pinned: bool | None = None + pinned: bool | None = None, ): - params = {"user": user, "last_id": last_id, - "limit": limit, "pinned": pinned} + params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} return self._send_request("GET", "/conversations", params=params) def get_conversation_messages( @@ -123,7 +118,7 @@ class ChatClient(DifyClient): user: str, conversation_id: str | None = None, first_id: str | None = None, - limit: int | None = None + limit: int | None = None, ): params = {"user": user} @@ -136,13 +131,9 @@ class ChatClient(DifyClient): return self._send_request("GET", "/messages", params=params) - def rename_conversation( - self, conversation_id: str, name: str, auto_generate: bool, user: str - ): + def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request( - "POST", f"/conversations/{conversation_id}/name", data - ) + return self._send_request("POST", f"/conversations/{conversation_id}/name", data) def delete_conversation(self, conversation_id: str, user: str): data = {"user": user} @@ -155,9 +146,7 @@ class ChatClient(DifyClient): class WorkflowClient(DifyClient): - def run( - self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123" - ): + def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"): data = {"inputs": inputs, "response_mode": response_mode, "user": user} return self._send_request("POST", "/workflows/run", data) @@ -197,13 +186,9 @@ class 
KnowledgeBaseClient(DifyClient): return self._send_request("POST", "/datasets", {"name": name}, **kwargs) def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request( - "GET", f"/datasets?page={page}&limit={page_size}", **kwargs - ) + return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs) - def create_document_by_text( - self, name, text, extra_params: dict | None = None, **kwargs - ): + def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs): """ Create a document by text. @@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient): data = {"name": name, "text": text} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( @@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient): if original_document_id is not None: data["original_document_id"] = original_document_id url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - def update_document_by_file( - self, document_id: str, file_path: str, extra_params: dict | None = None - ): + def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None): """ Update a document by file. @@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient): data = {} if extra_params is not None and isinstance(extra_params, dict): data.update(extra_params) - url = ( - f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - ) - return self._send_request_with_files( - "POST", url, {"data": json.dumps(data)}, files - ) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" + return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) def batch_indexing_status(self, batch_id: str, **kwargs): """ diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py index 7340fffb4c..a05f6410fb 100644 --- a/sdks/python-client/setup.py +++ b/sdks/python-client/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -with open("README.md", "r", encoding="utf-8") as fh: +with open("README.md", encoding="utf-8") as fh: long_description = fh.read() setup( diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py index 52032417c0..fce1b11eba 100644 --- a/sdks/python-client/tests/test_client.py +++ b/sdks/python-client/tests/test_client.py @@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__) class TestKnowledgeBaseClient(unittest.TestCase): def setUp(self): self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) - self.README_FILE_PATH = os.path.abspath( - os.path.join(FILE_PATH_BASE, "../README.md") - ) + self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) self.dataset_id = None self.document_id = None self.segment_id = None @@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _get_dataset_kb_client(self): self.assertIsNotNone(self.dataset_id) - return KnowledgeBaseClient( - API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id - ) + return 
KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id) def test_001_create_dataset(self): response = self.knowledge_base_client.create_dataset(name="test_dataset") @@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_004_update_document_by_text(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_text( - self.document_id, "test_document_updated", "test_text_updated" - ) + response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_006_update_document_by_file(self): client = self._get_dataset_kb_client() self.assertIsNotNone(self.document_id) - response = client.update_document_by_file( - self.document_id, self.README_FILE_PATH - ) + response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) data = response.json() self.assertIn("document", data) self.assertIn("batch", data) @@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase): def _test_010_add_segments(self): client = self._get_dataset_kb_client() - response = client.add_segments( - self.document_id, [{"content": "test text segment 1"}] - ) + response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) data = response.json() self.assertIn("data", data) self.assertGreater(len(data["data"]), 0) @@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase): self.chat_client = ChatClient(API_KEY) def test_create_chat_message(self): - response = self.chat_client.create_chat_message( - {}, "Hello, World!", "test_user" - ) + response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user") self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] + response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_create_chat_message_with_vision_model_by_local_file(self): @@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase): "upload_file_id": "your_file_id", } ] - response = self.chat_client.create_chat_message( - {}, "Describe the picture.", "test_user", files=files - ) + response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) self.assertIn("answer", response.text) def test_get_conversation_messages(self): - response = self.chat_client.get_conversation_messages( - "test_user", "your_conversation_id" - ) + response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id") self.assertIn("answer", response.text) def test_get_conversations(self): @@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase): self.assertIn("answer", response.text) def test_create_completion_message_with_vision_model_by_remote_url(self): - files = [ - {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"} - ] + files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] response = 
self.completion_client.create_completion_message( {"query": "Describe the picture."}, "blocking", "test_user", files ) @@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase): self.dify_client = DifyClient(API_KEY) def test_message_feedback(self): - response = self.dify_client.message_feedback( - "your_message_id", "like", "test_user" - ) + response = self.dify_client.message_feedback("your_message_id", "like", "test_user") self.assertIn("success", response.text) def test_get_application_parameters(self):