ruff check preview (#25653)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: Asuka Minato
Date: 2025-09-16 13:58:12 +09:00
Committed by: GitHub
Parent: a0c7713494
Commit: bdd85b36a4
42 changed files with 224 additions and 342 deletions

View File

@@ -22,7 +22,7 @@ jobs:
           # Fix lint errors
           uv run ruff check --fix .
           # Format code
           uv run ruff format .
       - name: ast-grep
         run: |

View File

@@ -5,7 +5,7 @@ line-length = 120
 quote-style = "double"

 [lint]
-preview = false
+preview = true
 select = [
     "B",  # flake8-bugbear rules
     "C4", # flake8-comprehensions
@@ -65,6 +65,7 @@ ignore = [
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B901", # allow return in yield
     "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict

View File

@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+import operator
 import secrets
 from typing import Any
@@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool):
     click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
     query = "DELETE FROM message_files WHERE id IN :ids"
     with db.engine.begin() as conn:
-        conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
+        conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
     click.echo(
         click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
     )
@@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables(
     if dry_run:
         logger.info("DRY RUN: Would delete the following:")
-        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
+        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
             :10
         ]:  # Show top 10
             logger.info("  App %s: %s variables", app_id, count)

View File

@@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance):
         GEN_AI_FRAMEWORK: "dify",
         TOOL_NAME: node_execution.title,
         TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
-        TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
-        INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
+        TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
+        INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
         OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
     },
     status=self.get_workflow_node_status(node_execution),
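
The `x if x else {}` to `x or {}` rewrite, which recurs in every tracing integration below, is behavior-preserving: both forms substitute the default for any falsy value. A one-loop check:

    for inputs in (None, {}, {"q": 1}):
        assert (inputs if inputs else {}) == (inputs or {})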

View File

@@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         metadata = {str(k): v for k, v in execution_metadata.items()}
         metadata.update(
             {
@@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 "status": status,
             }
         )
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         model_provider = process_data.get("model_provider", None)
         model_name = process_data.get("model_name", None)
         if model_provider is not None and model_name is not None:

View File

@@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
         metadata = {str(key): value for key, value in execution_metadata.items()}
         metadata.update(
@@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             }
         )
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         if process_data and process_data.get("model_mode") == "chat":
             run_type = LangSmithRunType.llm

View File

@@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         metadata = {str(k): v for k, v in execution_metadata.items()}
         metadata.update(
             {
@@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance):
             }
         )
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         provider = None
         model = None

View File

@@ -1,3 +1,4 @@
+import collections
 import json
 import logging
 import os
@@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks
 logger = logging.getLogger(__name__)

-class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
+class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
     def __getitem__(self, provider: str) -> dict[str, Any]:
         match provider:
             case TracingProviderEnum.LANGFUSE:
@@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
             raise KeyError(f"Unsupported tracing provider: {provider}")

-provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
+provider_config_map = OpsTraceProviderConfigMap()

 class OpsTraceManager:
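
The switch to collections.UserDict is more than cosmetic: dict's C-implemented methods such as get() bypass an overridden __getitem__, while UserDict routes every access through it, which a computed mapping like this relies on. A minimal sketch (class names are illustrative):

    import collections

    class ComputedDict(dict):
        def __getitem__(self, key):
            return f"computed:{key}"

    class ComputedUserDict(collections.UserDict):
        def __getitem__(self, key):
            return f"computed:{key}"

    assert ComputedDict().get("langfuse") is None  # override silently skipped
    assert ComputedUserDict().get("langfuse") == "computed:langfuse"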

View File

@@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
         attributes = {str(k): v for k, v in execution_metadata.items()}
         attributes.update(
@@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance):
             }
         )
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         if process_data and process_data.get("model_mode") == "chat":
             attributes.update(
                 {

View File

@@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector):
     for doc, embedding in zip(batch_docs, batch_embeddings):
         # Optimized: minimal checks for common case, fallback for edge cases
-        metadata = doc.metadata if doc.metadata else {}
+        metadata = doc.metadata or {}
         if not isinstance(metadata, dict):
             metadata = {}

View File

@@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector):
     self.client = self._get_client(len(embeddings[0]), True)
     assert self.client is not None
     ids = []
-    for _, doc in enumerate(documents):
+    for doc in documents:
         if doc.metadata is not None:
             doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
             ids.append(doc_id)

View File

@@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector):
         },
     }
     # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
-    if self._client_config.aws_service not in ["aoss"]:
+    if self._client_config.aws_service != "aoss":
         action["_id"] = uuid4().hex
     actions.append(action)

View File

@@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
         else None
     )
     db_model.status = domain_model.status
-    db_model.error = domain_model.error_message if domain_model.error_message else None
+    db_model.error = domain_model.error_message or None
     db_model.total_tokens = domain_model.total_tokens
     db_model.total_steps = domain_model.total_steps
     db_model.exceptions_count = domain_model.exceptions_count

View File

@@ -320,7 +320,7 @@ class AgentNode(BaseNode):
     memory = self._fetch_memory(model_instance)
     if memory:
         prompt_messages = memory.get_history_prompt_messages(
-            message_limit=node_data.memory.window.size if node_data.memory.window.size else None
+            message_limit=node_data.memory.window.size or None
         )
         history_prompt_messages = [
             prompt_message.model_dump(mode="json") for prompt_message in prompt_messages

View File

@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
         imports.append("schedule.queue_monitor_task")
         beat_schedule["datasets-queue-monitor"] = {
             "task": "schedule.queue_monitor_task.queue_monitor_task",
-            "schedule": timedelta(
-                minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
-            ),
+            "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
         }
     if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
         imports.append("schedule.check_upgradable_plugin_task")

View File

@@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files.
 import json
 import logging
+import operator
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import StrEnum, auto
@@ -356,7 +357,7 @@ class FileLifecycleManager:
     # Cleanup old versions for each file
     for base_filename, versions in file_versions.items():
         # Sort by version number
-        versions.sort(key=lambda x: x[0], reverse=True)
+        versions.sort(key=operator.itemgetter(0), reverse=True)
         # Keep the newest max_versions versions, delete the rest
         if len(versions) > max_versions:

View File

@@ -1,3 +1,4 @@
+import operator
 import traceback
 import typing
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
     current_version = version
     latest_version = manifest.latest_version

-    def fix_only_checker(latest_version, current_version):
+    def fix_only_checker(latest_version: str, current_version: str):
         latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
         current_version_tuple = tuple(int(val) for val in current_version.split("."))
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
         return False

     version_checker = {
-        TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
-        current_version: latest_version != current_version,
+        TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
         TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
     }
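
operator.ne packages != as a two-argument callable, so it slots into the strategy table with the same call shape as the custom checker. A sketch under assumed semantics (the fix-only logic here is illustrative, not the repo's):

    import operator

    def fix_only_checker(latest: str, current: str) -> bool:
        # upgrade only within the same major.minor line
        return latest.split(".")[:2] == current.split(".")[:2] and latest != current

    checkers = {"latest": operator.ne, "fix_only": fix_only_checker}
    assert checkers["latest"]("1.2.1", "1.2.0")
    assert not checkers["fix_only"]("2.0.0", "1.9.9")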

View File

@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path

 import pytest
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
     # Test download
     with tempfile.NamedTemporaryFile() as temp_file:
         storage.download(test_filename, temp_file.name)
-        with open(temp_file.name, "rb") as f:
-            downloaded_content = f.read()
+        downloaded_content = Path(temp_file.name).read_bytes()
         assert downloaded_content == test_content

     # Test scan
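
Path.read_bytes and its siblings fold the open/read/close sequence into one call, which is all these test rewrites do. A self-contained sketch:

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        p = Path(tmp) / "sample.bin"
        p.write_bytes(b"hello")            # replaces open(..., "wb") + write
        assert p.read_bytes() == b"hello"  # replaces open(..., "rb") + read
        p.write_text("hi", encoding="utf-8")
        assert p.read_text(encoding="utf-8") == "hi"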

View File

@@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
 import uuid
 from datetime import datetime
+from pathlib import Path
 from unittest.mock import MagicMock, patch

 import pytest
@@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
     mock_storage = mock_external_service_dependencies["storage"]

     def mock_download(key, file_path):
-        with open(file_path, "w", encoding="utf-8") as f:
-            f.write(csv_content)
+        Path(file_path).write_text(csv_content, encoding="utf-8")

     mock_storage.download.side_effect = mock_download
@@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
     db.session.commit()

     # Test each unavailable document
-    for i, document in enumerate(test_cases):
+    for document in test_cases:
         job_id = str(uuid.uuid4())
         batch_create_segment_to_index_task(
             job_id=job_id,
@@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
     mock_storage = mock_external_service_dependencies["storage"]

     def mock_download(key, file_path):
-        with open(file_path, "w", encoding="utf-8") as f:
-            f.write(empty_csv_content)
+        Path(file_path).write_text(empty_csv_content, encoding="utf-8")

     mock_storage.download.side_effect = mock_download
@@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
     mock_storage = mock_external_service_dependencies["storage"]

     def mock_download(key, file_path):
-        with open(file_path, "w", encoding="utf-8") as f:
-            f.write(csv_content)
+        Path(file_path).write_text(csv_content, encoding="utf-8")

     mock_storage.download.side_effect = mock_download

View File

@@ -362,7 +362,7 @@ class TestCleanDatasetTask:
     # Create segments for each document
     segments = []
-    for i, document in enumerate(documents):
+    for document in documents:
         segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
         segments.append(segment)

View File

@@ -15,7 +15,7 @@ class FakeResponse:
         self.status_code = status_code
         self.headers = headers or {}
         self.content = content
-        self.text = text if text else content.decode("utf-8", errors="ignore")
+        self.text = text or content.decode("utf-8", errors="ignore")

 # ---------------------------

View File

@@ -1,3 +1,4 @@
+from pathlib import Path
 from unittest.mock import Mock, create_autospec, patch

 import pytest
@@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
     # Console API create
     console_create_file = "api/controllers/console/datasets/metadata.py"
     if os.path.exists(console_create_file):
-        with open(console_create_file) as f:
-            content = f.read()
-            # Should contain nullable=False, not nullable=True
-            assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
+        content = Path(console_create_file).read_text()
+        # Should contain nullable=False, not nullable=True
+        assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]

     # Service API create
     service_create_file = "api/controllers/service_api/dataset/metadata.py"
     if os.path.exists(service_create_file):
-        with open(service_create_file) as f:
-            content = f.read()
-            # Should contain nullable=False, not nullable=True
-            create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
-            assert "nullable=True" not in create_api_section
+        content = Path(service_create_file).read_text()
+        # Should contain nullable=False, not nullable=True
+        create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
+        assert "nullable=True" not in create_api_section

 class TestMetadataValidationSummary:

View File

@@ -1,6 +1,7 @@
+from pathlib import Path
+
 import yaml  # type: ignore
 from dotenv import dotenv_values
-from pathlib import Path

 BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
     "APP_MAX_EXECUTION_TIME",
@@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f:

 def test_yaml_config():
     # python set == operator is used to compare two sets
-    DIFF_API_WITH_DOCKER = (
-        API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
-    )
+    DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
     if DIFF_API_WITH_DOCKER:
-        print(
-            f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
-        )
+        print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
         raise Exception("API and Docker config sets are different")
     DIFF_API_WITH_DOCKER_COMPOSE = (
-        API_CONFIG_SET
-        - DOCKER_COMPOSE_CONFIG_SET
-        - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
+        API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
     )
     if DIFF_API_WITH_DOCKER_COMPOSE:
-        print(
-            f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
-        )
+        print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
         raise Exception("API and Docker Compose config sets are different")
     print("All tests passed!")

View File

@@ -51,9 +51,7 @@ def cleanup() -> None:
     if sys.stdin.isatty():
         log.separator()
         log.warning("This action cannot be undone!")
-        confirmation = input(
-            "Are you sure you want to remove all config and report files? (yes/no): "
-        )
+        confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ")
         if confirmation.lower() not in ["yes", "y"]:
             log.error("Cleanup cancelled.")

View File

@@ -3,4 +3,4 @@
 from .config_helper import config_helper
 from .logger_helper import Logger, ProgressLogger

-__all__ = ["config_helper", "Logger", "ProgressLogger"]
+__all__ = ["Logger", "ProgressLogger", "config_helper"]

View File

@@ -65,9 +65,9 @@ class ConfigHelper:
         return None

     try:
-        with open(config_path, "r") as f:
+        with open(config_path) as f:
             return json.load(f)
-    except (json.JSONDecodeError, IOError) as e:
+    except (OSError, json.JSONDecodeError) as e:
         print(f"❌ Error reading {filename}: {e}")
         return None
@@ -101,7 +101,7 @@ class ConfigHelper:
         with open(config_path, "w") as f:
             json.dump(data, f, indent=2)
         return True
-    except IOError as e:
+    except OSError as e:
         print(f"❌ Error writing {filename}: {e}")
         return False
@@ -133,7 +133,7 @@ class ConfigHelper:
     try:
         config_path.unlink()
         return True
-    except IOError as e:
+    except OSError as e:
         print(f"❌ Error deleting {filename}: {e}")
         return False
@@ -148,9 +148,9 @@ class ConfigHelper:
         return None

     try:
-        with open(state_path, "r") as f:
+        with open(state_path) as f:
             return json.load(f)
-    except (json.JSONDecodeError, IOError) as e:
+    except (OSError, json.JSONDecodeError) as e:
         print(f"❌ Error reading {self.state_file}: {e}")
         return None
@@ -170,7 +170,7 @@ class ConfigHelper:
         with open(state_path, "w") as f:
             json.dump(data, f, indent=2)
         return True
-    except IOError as e:
+    except OSError as e:
         print(f"❌ Error writing {self.state_file}: {e}")
         return False
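
The IOError to OSError swap is a pure rename: the two names have referred to the same class since Python 3.3, so nothing previously caught is lost. Quick check:

    assert IOError is OSError  # alias since Python 3.3

    try:
        open("/nonexistent/definitely-missing")
    except OSError as e:  # catches everything `except IOError` used to
        assert isinstance(e, FileNotFoundError)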

View File

@@ -159,9 +159,7 @@ class ProgressLogger:
         if self.logger.use_colors:
             progress_bar = self._create_progress_bar()
-            print(
-                f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}"
-            )
+            print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}")
             self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
         else:
             print(f"\n[Step {self.current_step}/{self.total_steps}]")

View File

@@ -6,8 +6,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

 import httpx
-from common import config_helper
-from common import Logger
+from common import Logger, config_helper


 def configure_openai_plugin() -> None:
@@ -72,29 +71,19 @@ def configure_openai_plugin() -> None:
     if response.status_code == 200:
         log.success("OpenAI plugin configured successfully!")
-        log.key_value(
-            "API Base", config_payload["credentials"]["openai_api_base"]
-        )
-        log.key_value(
-            "API Key", config_payload["credentials"]["openai_api_key"]
-        )
+        log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
+        log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
     elif response.status_code == 201:
         log.success("OpenAI plugin credentials created successfully!")
-        log.key_value(
-            "API Base", config_payload["credentials"]["openai_api_base"]
-        )
-        log.key_value(
-            "API Key", config_payload["credentials"]["openai_api_key"]
-        )
+        log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
+        log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
     elif response.status_code == 401:
         log.error("Configuration failed: Unauthorized")
         log.info("Token may have expired. Please run login_admin.py again")
     else:
-        log.error(
-            f"Configuration failed with status code: {response.status_code}"
-        )
+        log.error(f"Configuration failed with status code: {response.status_code}")
         log.debug(f"Response: {response.text}")

 except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def create_api_key() -> None:
@@ -90,9 +90,7 @@ def create_api_key() -> None:
     }
     if config_helper.write_config("api_key_config", api_key_config):
-        log.info(
-            f"API key saved to: {config_helper.get_config_path('benchmark_state')}"
-        )
+        log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}")
 else:
     log.error("No API token received")
     log.debug(f"Response: {json.dumps(response_data, indent=2)}")
@@ -101,9 +99,7 @@ def create_api_key() -> None:
     log.error("API key creation failed: Unauthorized")
     log.info("Token may have expired. Please run login_admin.py again")
 else:
-    log.error(
-        f"API key creation failed with status code: {response.status_code}"
-    )
+    log.error(f"API key creation failed with status code: {response.status_code}")
     log.debug(f"Response: {response.text}")

 except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper, Logger
+
+import httpx
+from common import Logger, config_helper


 def import_workflow_app() -> None:
@@ -30,7 +31,7 @@ def import_workflow_app() -> None:
     log.error(f"DSL file not found: {dsl_path}")
     return

-with open(dsl_path, "r") as f:
+with open(dsl_path) as f:
     yaml_content = f.read()

 log.step("Importing workflow app from DSL...")
@@ -86,9 +87,7 @@ def import_workflow_app() -> None:
     log.success("Workflow app imported successfully!")
     log.key_value("App ID", app_id)
     log.key_value("App Mode", response_data.get("app_mode"))
-    log.key_value(
-        "DSL Version", response_data.get("imported_dsl_version")
-    )
+    log.key_value("DSL Version", response_data.get("imported_dsl_version"))

     # Save app_id to config
     app_config = {
@@ -99,9 +98,7 @@ def import_workflow_app() -> None:
     }
     if config_helper.write_config("app_config", app_config):
-        log.info(
-            f"App config saved to: {config_helper.get_config_path('benchmark_state')}"
-        )
+        log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}")
 else:
     log.error("Import completed but no app_id received")
     log.debug(f"Response: {json.dumps(response_data, indent=2)}")

View File

@@ -5,10 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import time
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def install_openai_plugin() -> None:
@@ -28,9 +28,7 @@ def install_openai_plugin() -> None:
     # API endpoint for plugin installation
     base_url = "http://localhost:5001"
-    install_endpoint = (
-        f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
-    )
+    install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"

     # Plugin identifier
     plugin_payload = {
@@ -83,9 +81,7 @@ def install_openai_plugin() -> None:
     log.info("Polling for task completion...")

     # Poll for task completion
-    task_endpoint = (
-        f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
-    )
+    task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
     max_attempts = 30  # 30 attempts with 2 second delay = 60 seconds max
     attempt = 0
@@ -131,9 +127,7 @@ def install_openai_plugin() -> None:
     plugins = task_info.get("plugins", [])
     if plugins:
         for plugin in plugins:
-            log.list_item(
-                f"{plugin.get('plugin_id')}: {plugin.get('message')}"
-            )
+            log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}")
     break

 # Continue polling if status is "pending" or other
@@ -149,9 +143,7 @@ def install_openai_plugin() -> None:
     log.warning("Plugin may already be installed")
     log.debug(f"Response: {response.text}")
 else:
-    log.error(
-        f"Installation failed with status code: {response.status_code}"
-    )
+    log.error(f"Installation failed with status code: {response.status_code}")
     log.debug(f"Response: {response.text}")

 except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def login_admin() -> None:
@@ -77,16 +77,10 @@ def login_admin() -> None:
     # Save token config
     if config_helper.write_config("token_config", token_config):
-        log.info(
-            f"Token saved to: {config_helper.get_config_path('benchmark_state')}"
-        )
+        log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}")

         # Show truncated token for verification
-        token_display = (
-            f"{access_token[:20]}..."
-            if len(access_token) > 20
-            else "Token saved"
-        )
+        token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved"
         log.key_value("Access token", token_display)

 elif response.status_code == 401:

View File

@@ -3,8 +3,10 @@
 import json
 import time
 import uuid
-from typing import Any, Iterator
-from flask import Flask, request, jsonify, Response
+from collections.abc import Iterator
+from typing import Any
+
+from flask import Flask, Response, jsonify, request

 app = Flask(__name__)
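
Importing Iterator from collections.abc rather than typing follows the deprecation of the typing aliases (PEP 585); since Python 3.9 the abc classes are subscriptable directly. A sketch in the spirit of this mock server (names are illustrative):

    from collections.abc import Iterator

    def sse_lines(events: list[str]) -> Iterator[str]:
        for event in events:
            yield f"data: {event}\n\n"

    assert next(sse_lines(["ping"])) == "data: ping\n\n"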

View File

@@ -5,10 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def publish_workflow() -> None:
@@ -79,9 +79,7 @@ def publish_workflow() -> None:
     try:
         response_data = response.json()
         if response_data:
-            log.debug(
-                f"Response: {json.dumps(response_data, indent=2)}"
-            )
+            log.debug(f"Response: {json.dumps(response_data, indent=2)}")
     except json.JSONDecodeError:
         # Response might be empty or non-JSON
         pass
@@ -93,9 +91,7 @@ def publish_workflow() -> None:
     log.error("Workflow publish failed: App not found")
     log.info("Make sure the app was imported successfully")
 else:
-    log.error(
-        f"Workflow publish failed with status code: {response.status_code}"
-    )
+    log.error(f"Workflow publish failed with status code: {response.status_code}")
     log.debug(f"Response: {response.text}")

 except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper, Logger
+
+import httpx
+from common import Logger, config_helper


 def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
@@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
     event = data.get("event")
     if event == "workflow_started":
-        log.progress(
-            f"Workflow started: {data.get('data', {}).get('id')}"
-        )
+        log.progress(f"Workflow started: {data.get('data', {}).get('id')}")
     elif event == "node_started":
         node_data = data.get("data", {})
         log.progress(
@@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
             # Some lines might not be JSON
             pass
 else:
-    log.error(
-        f"Workflow run failed with status code: {response.status_code}"
-    )
+    log.error(f"Workflow run failed with status code: {response.status_code}")
     log.debug(f"Response: {response.text}")
 else:
     # Handle blocking response
@@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
     log.info("📤 Final Answer:")
     log.info(outputs.get("answer"), indent=2)
 else:
-    log.error(
-        f"Workflow run failed with status code: {response.status_code}"
-    )
+    log.error(f"Workflow run failed with status code: {response.status_code}")
     log.debug(f"Response: {response.text}")

 except httpx.ConnectError:

View File

@@ -6,7 +6,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

 import httpx
-from common import config_helper, Logger
+from common import Logger, config_helper


 def setup_admin_account() -> None:
@@ -24,9 +24,7 @@ def setup_admin_account() -> None:
     # Save credentials to config file
     if config_helper.write_config("admin_config", admin_config):
-        log.info(
-            f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}"
-        )
+        log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}")

     # API setup endpoint
     base_url = "http://localhost:5001"
@@ -56,9 +54,7 @@ def setup_admin_account() -> None:
     log.key_value("Username", admin_config["username"])
 elif response.status_code == 400:
-    log.warning(
-        "Setup may have already been completed or invalid data provided"
-    )
+    log.warning("Setup may have already been completed or invalid data provided")
     log.debug(f"Response: {response.text}")
 else:
     log.error(f"Setup failed with status code: {response.status_code}")

View File

@@ -1,9 +1,9 @@
 #!/usr/bin/env python3
+import socket
 import subprocess
 import sys
 import time
-import socket
 from pathlib import Path

 from common import Logger, ProgressLogger
@@ -93,9 +93,7 @@ def main() -> None:
     if retry.lower() in ["yes", "y"]:
         return main()  # Recursively call main to check again
     else:
-        print(
-            "❌ Setup cancelled. Please start the required services and try again."
-        )
+        print("❌ Setup cancelled. Please start the required services and try again.")
         sys.exit(1)

 log.success("All required services are running!")

View File

@@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first
 """

 import json
-import time
-import random
-import sys
-import threading
-import os
-import logging
-import statistics
-from pathlib import Path
-from collections import deque
-from datetime import datetime
-from dataclasses import dataclass, asdict
-from locust import HttpUser, task, between, events, constant
-from typing import TypedDict, Literal, TypeAlias
+import logging
+import os
+import random
+import statistics
+import sys
+import threading
+import time
+from collections import deque
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import Literal, TypeAlias, TypedDict

 import requests.exceptions
+from locust import HttpUser, between, constant, events, task

 # Add the stress-test directory to path to import common modules
 sys.path.insert(0, str(Path(__file__).parent))

 from common.config_helper import ConfigHelper  # type: ignore[import-not-found]

 # Configure logging
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

 # Configuration from environment
# Configuration from environment # Configuration from environment
@@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[
 class ErrorCounts(TypedDict):
     """Error count tracking"""
+
     connection_error: int
     timeout: int
     invalid_json: int
@@ -65,6 +65,7 @@ class ErrorCounts(TypedDict):
 class SSEEvent(TypedDict):
     """Server-Sent Event structure"""
+
     data: str
     event: str
     id: str | None
@@ -72,11 +73,13 @@ class SSEEvent(TypedDict):
 class WorkflowInputs(TypedDict):
     """Workflow input structure"""
+
     question: str


 class WorkflowRequestData(TypedDict):
     """Workflow request payload"""
+
     inputs: WorkflowInputs
     response_mode: Literal["streaming"]
     user: str
@@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict):
 class ParsedEventData(TypedDict, total=False):
     """Parsed event data from SSE stream"""
+
     event: str
     task_id: str
     workflow_run_id: str
@@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False):
 class LocustStats(TypedDict):
     """Locust statistics structure"""
+
     total_requests: int
     total_failures: int
     avg_response_time: float
@@ -102,6 +107,7 @@ class LocustStats(TypedDict):
 class ReportData(TypedDict):
     """JSON report structure"""
+
     timestamp: str
     duration_seconds: float
     metrics: dict[str, object]  # Metrics as dict for JSON serialization
@@ -154,7 +160,7 @@ class MetricsTracker:
         self.total_connections = 0
         self.total_events = 0
         self.start_time = time.time()

         # Enhanced metrics with memory limits
         self.max_samples = 10000  # Prevent unbounded growth
         self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
@@ -233,9 +239,7 @@
             max_ttfe = max(self.ttfe_samples)
             p50_ttfe = statistics.median(self.ttfe_samples)
             if len(self.ttfe_samples) >= 2:
-                quantiles = statistics.quantiles(
-                    self.ttfe_samples, n=20, method="inclusive"
-                )
+                quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive")
                 p95_ttfe = quantiles[18]  # 19th of 19 quantiles = 95th percentile
             else:
                 p95_ttfe = max_ttfe
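
As the comment notes, statistics.quantiles(data, n=20) returns the 19 cut points that split the data into 20 slices, so index 18 is the 95th percentile. A worked check:

    import statistics

    samples = list(range(1, 101))  # 1..100
    q = statistics.quantiles(samples, n=20, method="inclusive")
    assert len(q) == 19
    assert q[18] == 95.05  # (95 * 19 + 96) / 20, the interpolated 95th percentile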
@@ -255,9 +259,7 @@
                 if durations
                 else 0
             )
-            events_per_stream_avg = (
-                statistics.mean(events_per_stream) if events_per_stream else 0
-            )
+            events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0

             # Calculate inter-event latency statistics
             all_inter_event_times = []
@@ -268,32 +270,20 @@
                 inter_event_latency_avg = statistics.mean(all_inter_event_times)
                 inter_event_latency_p50 = statistics.median(all_inter_event_times)
                 inter_event_latency_p95 = (
-                    statistics.quantiles(
-                        all_inter_event_times, n=20, method="inclusive"
-                    )[18]
+                    statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18]
                     if len(all_inter_event_times) >= 2
                     else max(all_inter_event_times)
                 )
             else:
-                inter_event_latency_avg = inter_event_latency_p50 = (
-                    inter_event_latency_p95
-                ) = 0
+                inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
         else:
-            stream_duration_avg = stream_duration_p50 = stream_duration_p95 = (
-                events_per_stream_avg
-            ) = 0
-            inter_event_latency_avg = inter_event_latency_p50 = (
-                inter_event_latency_p95
-            ) = 0
+            stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0
+            inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0

         # Also calculate overall average rates
         total_elapsed = current_time - self.start_time
-        overall_conn_rate = (
-            self.total_connections / total_elapsed if total_elapsed > 0 else 0
-        )
-        overall_event_rate = (
-            self.total_events / total_elapsed if total_elapsed > 0 else 0
-        )
+        overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0
+        overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0

         return MetricsSnapshot(
             active_connections=self.active_connections,
return MetricsSnapshot( return MetricsSnapshot(
active_connections=self.active_connections, active_connections=self.active_connections,
@@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser):
         # Load questions from file or use defaults
         if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
-            with open(QUESTIONS_FILE, "r") as f:
+            with open(QUESTIONS_FILE) as f:
                 self.questions = [line.strip() for line in f if line.strip()]
         else:
             self.questions = [
@@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser):
             try:
                 # Validate response
                 if response.status_code >= 400:
-                    error_type: ErrorType = (
-                        "http_4xx" if response.status_code < 500 else "http_5xx"
-                    )
+                    error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx"
                     metrics.record_error(error_type)
                     response.failure(f"HTTP {response.status_code}")
                     return

                 content_type = response.headers.get("Content-Type", "")
-                if (
-                    "text/event-stream" not in content_type
-                    and "application/json" not in content_type
-                ):
+                if "text/event-stream" not in content_type and "application/json" not in content_type:
                     logger.error(f"Expected text/event-stream, got: {content_type}")
                     metrics.record_error("invalid_response")
                     response.failure(f"Invalid content type: {content_type}")
@@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser):
                 for line in response.iter_lines(decode_unicode=True):
                     # Check if runner is stopping
-                    if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'):
+                    if getattr(self.environment.runner, "state", "") in (
+                        "stopping",
+                        "stopped",
+                    ):
                         logger.debug("Runner stopping, breaking streaming loop")
                         break

                     if line is not None:
                         bytes_received += len(line.encode("utf-8"))
@@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser):
                         # Track inter-event timing
                         if last_event_time:
-                            inter_event_times.append(
-                                (current_time - last_event_time) * 1000
-                            )
+                            inter_event_times.append((current_time - last_event_time) * 1000)
                         last_event_time = current_time

                         if first_event_time is None:
if first_event_time is None: if first_event_time is None:
@@ -512,15 +498,11 @@ class DifyWorkflowUser(HttpUser):
parsed_event: ParsedEventData = json.loads(event_data) parsed_event: ParsedEventData = json.loads(event_data)
# Check for terminal events # Check for terminal events
if parsed_event.get("event") in TERMINAL_EVENTS: if parsed_event.get("event") in TERMINAL_EVENTS:
logger.debug( logger.debug(f"Received terminal event: {parsed_event.get('event')}")
f"Received terminal event: {parsed_event.get('event')}"
)
request_success = True request_success = True
break break
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.debug( logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}")
f"JSON decode error: {e} for data: {event_data[:100]}"
)
metrics.record_error("invalid_json") metrics.record_error("invalid_json")
except Exception as e: except Exception as e:
@@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None:
     # Periodic stats reporting
     def report_stats() -> None:
-        if not hasattr(environment, 'runner'):
+        if not hasattr(environment, "runner"):
            return
        runner = environment.runner
-        while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]:
+        while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]:
            time.sleep(5)  # Report every 5 seconds
-            if hasattr(runner, 'state') and runner.state == "running":
+            if hasattr(runner, "state") and runner.state == "running":
                stats = metrics.get_stats()
                # Only log on master node in distributed mode
-                is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True
+                is_master = (
+                    not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
+                )
                if is_master:
                    # Clear previous lines and show updated stats
                    logger.info("\n" + "=" * 80)
@@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None:
                    logger.info(
                        f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
                    )
-                    logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)")
+                    logger.info(
+                        f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)"
+                    )
                    logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")

                    # Inter-event latency
                    if stats.inter_event_latency_avg > 0:
                        logger.info("-" * 80)
-                        logger.info(
-                            f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}"
-                        )
+                        logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}")
                        logger.info(
                            f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
                        )
@@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None:
                    logger.info("=" * 80)

        # Show Locust stats summary
-        if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
+        if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
            total = environment.stats.total
-            if hasattr(total, 'num_requests') and total.num_requests > 0:
+            if hasattr(total, "num_requests") and total.num_requests > 0:
                logger.info(
                    f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
                )
@@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("") logger.info("")
logger.info("EVENTS") logger.info("EVENTS")
logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}") logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}")
logger.info(
f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s"
)
logger.info(
f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s"
)
logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s")
logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s")
logger.info("") logger.info("")
logger.info("STREAM METRICS") logger.info("STREAM METRICS")
logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms") logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms") logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms") logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
logger.info( logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}")
f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}"
)
logger.info("") logger.info("")
logger.info("INTER-EVENT LATENCY") logger.info("INTER-EVENT LATENCY")
@@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms") logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms") logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms") logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})") logger.info(
f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})"
)
logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}") logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")
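The `min(10000, stats.ttfe_total_samples)` label implies that percentiles are computed over a bounded window of the most recent samples while a separate lifetime counter tracks the true total. A minimal sketch of that bookkeeping, assuming a deque-backed buffer (the real `metrics` module is not part of this diff):

    from collections import deque

    class TTFEBuffer:
        def __init__(self, maxlen: int = 10_000) -> None:
            self.window = deque(maxlen=maxlen)  # oldest samples fall off automatically
            self.total_samples = 0              # lifetime count, never truncated

        def add(self, ttfe_ms: float) -> None:
            self.window.append(ttfe_ms)
            self.total_samples += 1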
# Error summary # Error summary
@@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("=" * 80 + "\n") logger.info("=" * 80 + "\n")
# Export machine-readable report (only on master node) # Export machine-readable report (only on master node)
is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
if is_master: if is_master:
export_json_report(stats, test_duration, environment) export_json_report(stats, test_duration, environment)
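In distributed mode the JSON export should run once, not once per worker. The check relies on worker runners carrying a `worker_id` attribute while the master (and a plain local runner) does not, so the expression is truthy exactly where the report belongs; extracted into a helper, the idea reads:

    def is_master_node(environment: object) -> bool:
        # Worker runners expose worker_id; master and local runners do not.
        runner = getattr(environment, "runner", None)
        return runner is None or not getattr(runner, "worker_id", None)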
@@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: obj
# Access environment.stats.total attributes safely # Access environment.stats.total attributes safely
locust_stats: LocustStats | None = None locust_stats: LocustStats | None = None
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'): if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
total = environment.stats.total total = environment.stats.total
if hasattr(total, 'num_requests') and total.num_requests > 0: if hasattr(total, "num_requests") and total.num_requests > 0:
locust_stats = LocustStats( locust_stats = LocustStats(
total_requests=total.num_requests, total_requests=total.num_requests,
total_failures=total.num_failures, total_failures=total.num_failures,

View File

@@ -1,7 +1,15 @@
from dify_client.client import ( from dify_client.client import (
ChatClient, ChatClient,
CompletionClient, CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
DifyClient, DifyClient,
KnowledgeBaseClient,
WorkflowClient,
) )
__all__ = [
"ChatClient",
"CompletionClient",
"DifyClient",
"KnowledgeBaseClient",
"WorkflowClient",
]
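Declaring `__all__` pins the package's public surface: linters treat the re-exported names as intentionally public rather than unused imports, and a star import exposes exactly these five classes. Note the list is kept alphabetized, matching the sorted import block above. A usage sketch with a placeholder key:

    from dify_client import *  # pulls in only the names listed in __all__

    client = ChatClient("YOUR_API_KEY")  # placeholder API key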

View File

@@ -8,16 +8,16 @@ class DifyClient:
self.api_key = api_key self.api_key = api_key
self.base_url = base_url self.base_url = base_url
def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False): def _send_request(
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
):
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
url = f"{self.base_url}{endpoint}" url = f"{self.base_url}{endpoint}"
response = requests.request( response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
method, url, json=json, params=params, headers=headers, stream=stream
)
return response return response
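Collapsing the `requests.request(...)` call onto one line is purely formatting; behavior is identical. For orientation, the consolidated method can be exercised like this, where the base URL, endpoint path, and key are illustrative placeholders rather than values from this diff:

    from dify_client import DifyClient

    client = DifyClient("YOUR_API_KEY", "https://api.dify.ai/v1")
    resp = client._send_request("GET", "/parameters", params={"user": "test_user"})
    resp.raise_for_status()
    print(resp.json())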
@@ -25,9 +25,7 @@ class DifyClient:
headers = {"Authorization": f"Bearer {self.api_key}"} headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}{endpoint}" url = f"{self.base_url}{endpoint}"
response = requests.request( response = requests.request(method, url, data=data, headers=headers, files=files)
method, url, data=data, headers=headers, files=files
)
return response return response
@@ -41,9 +39,7 @@ class DifyClient:
def file_upload(self, user: str, files: dict): def file_upload(self, user: str, files: dict):
data = {"user": user} data = {"user": user}
return self._send_request_with_files( return self._send_request_with_files("POST", "/files/upload", data=data, files=files)
"POST", "/files/upload", data=data, files=files
)
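`file_upload` passes `files` straight through to `requests`, so callers use the standard multipart tuple; the form-field name, filename, and MIME type below are assumptions for illustration:

    from dify_client import DifyClient

    client = DifyClient("YOUR_API_KEY", "https://api.dify.ai/v1")
    with open("photo.png", "rb") as f:
        # requests multipart convention: {field: (filename, fileobj, content_type)}
        resp = client.file_upload(
            user="test_user",
            files={"file": ("photo.png", f, "image/png")},
        )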
def text_to_audio(self, text: str, user: str, streaming: bool = False): def text_to_audio(self, text: str, user: str, streaming: bool = False):
data = {"text": text, "user": user, "streaming": streaming} data = {"text": text, "user": user, "streaming": streaming}
@@ -55,7 +51,9 @@ class DifyClient:
class CompletionClient(DifyClient): class CompletionClient(DifyClient):
def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None): def create_completion_message(
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
):
data = { data = {
"inputs": inputs, "inputs": inputs,
"response_mode": response_mode, "response_mode": response_mode,
@@ -99,9 +97,7 @@ class ChatClient(DifyClient):
def get_suggested(self, message_id: str, user: str): def get_suggested(self, message_id: str, user: str):
params = {"user": user} params = {"user": user}
return self._send_request( return self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
"GET", f"/messages/{message_id}/suggested", params=params
)
def stop_message(self, task_id: str, user: str): def stop_message(self, task_id: str, user: str):
data = {"user": user} data = {"user": user}
@@ -112,10 +108,9 @@ class ChatClient(DifyClient):
user: str, user: str,
last_id: str | None = None, last_id: str | None = None,
limit: int | None = None, limit: int | None = None,
pinned: bool | None = None pinned: bool | None = None,
): ):
params = {"user": user, "last_id": last_id, params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
"limit": limit, "pinned": pinned}
return self._send_request("GET", "/conversations", params=params) return self._send_request("GET", "/conversations", params=params)
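Forwarding `last_id`, `limit`, and `pinned` as `None` is safe because `requests` drops `None`-valued entries when it builds the query string; for example:

    from dify_client import ChatClient

    chat = ChatClient("YOUR_API_KEY")  # placeholder key; base_url uses the default
    # The None-valued params are omitted, so this issues
    # GET /conversations?user=test_user
    resp = chat.get_conversations(user="test_user")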
def get_conversation_messages( def get_conversation_messages(
@@ -123,7 +118,7 @@ class ChatClient(DifyClient):
user: str, user: str,
conversation_id: str | None = None, conversation_id: str | None = None,
first_id: str | None = None, first_id: str | None = None,
limit: int | None = None limit: int | None = None,
): ):
params = {"user": user} params = {"user": user}
@@ -136,13 +131,9 @@ class ChatClient(DifyClient):
return self._send_request("GET", "/messages", params=params) return self._send_request("GET", "/messages", params=params)
def rename_conversation( def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
self, conversation_id: str, name: str, auto_generate: bool, user: str
):
data = {"name": name, "auto_generate": auto_generate, "user": user} data = {"name": name, "auto_generate": auto_generate, "user": user}
return self._send_request( return self._send_request("POST", f"/conversations/{conversation_id}/name", data)
"POST", f"/conversations/{conversation_id}/name", data
)
def delete_conversation(self, conversation_id: str, user: str): def delete_conversation(self, conversation_id: str, user: str):
data = {"user": user} data = {"user": user}
@@ -155,9 +146,7 @@ class ChatClient(DifyClient):
class WorkflowClient(DifyClient): class WorkflowClient(DifyClient):
def run( def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"
):
data = {"inputs": inputs, "response_mode": response_mode, "user": user} data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request("POST", "/workflows/run", data) return self._send_request("POST", "/workflows/run", data)
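A usage sketch for the reformatted `run` signature; the key, base URL, and `inputs` payload are app-specific placeholders:

    from dify_client import WorkflowClient

    wf = WorkflowClient("YOUR_API_KEY", "https://api.dify.ai/v1")
    resp = wf.run(inputs={"query": "hello"}, response_mode="blocking", user="test_user")
    print(resp.json())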
@@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", "/datasets", {"name": name}, **kwargs) return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request( return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
)
def create_document_by_text( def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
self, name, text, extra_params: dict | None = None, **kwargs
):
""" """
Create a document by text. Create a document by text.
@@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient):
data = {"name": name, "text": text} data = {"name": name, "text": text}
if extra_params is not None and isinstance(extra_params, dict): if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params) data.update(extra_params)
url = ( url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
)
return self._send_request("POST", url, json=data, **kwargs) return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file( def create_document_by_file(
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient):
if original_document_id is not None: if original_document_id is not None:
data["original_document_id"] = original_document_id data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files( return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
"POST", url, {"data": json.dumps(data)}, files
)
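The create-by-file endpoint packs document settings as a JSON string in a `data` form field beside the binary part. Expressed directly with `requests`, with the dataset id, file name, and settings as illustrative placeholders:

    import json
    import requests

    with open("doc.pdf", "rb") as f:
        requests.post(
            "https://api.dify.ai/v1/datasets/<dataset_id>/document/create_by_file",
            headers={"Authorization": "Bearer YOUR_API_KEY"},
            data={"data": json.dumps({"name": "doc.pdf"})},
            files={"file": f},
        )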
def update_document_by_file( def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
self, document_id: str, file_path: str, extra_params: dict | None = None
):
""" """
Update a document by file. Update a document by file.
@@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient):
data = {} data = {}
if extra_params is not None and isinstance(extra_params, dict): if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params) data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
)
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def batch_indexing_status(self, batch_id: str, **kwargs): def batch_indexing_status(self, batch_id: str, **kwargs):
""" """

View File

@@ -1,6 +1,6 @@
from setuptools import setup from setuptools import setup
with open("README.md", "r", encoding="utf-8") as fh: with open("README.md", encoding="utf-8") as fh:
long_description = fh.read() long_description = fh.read()
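The change above is ruff's redundant-open-modes fix (UP015): `open()` already defaults to read-only text mode, so the explicit `"r"` adds nothing:

    # These two calls are equivalent; mode defaults to "r" (read, text).
    open("README.md", "r", encoding="utf-8")
    open("README.md", encoding="utf-8")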
setup( setup(

View File

@@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__)
class TestKnowledgeBaseClient(unittest.TestCase): class TestKnowledgeBaseClient(unittest.TestCase):
def setUp(self): def setUp(self):
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
self.README_FILE_PATH = os.path.abspath( self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
os.path.join(FILE_PATH_BASE, "../README.md")
)
self.dataset_id = None self.dataset_id = None
self.document_id = None self.document_id = None
self.segment_id = None self.segment_id = None
@@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _get_dataset_kb_client(self): def _get_dataset_kb_client(self):
self.assertIsNotNone(self.dataset_id) self.assertIsNotNone(self.dataset_id)
return KnowledgeBaseClient( return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id
)
def test_001_create_dataset(self): def test_001_create_dataset(self):
response = self.knowledge_base_client.create_dataset(name="test_dataset") response = self.knowledge_base_client.create_dataset(name="test_dataset")
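A note on the naming scheme used throughout this class: `unittest` discovers methods whose names start with `test` and runs them in alphabetical order by default, so the `test_001_` style prefixes enforce the create-then-mutate sequence, while the underscore-prefixed `_test_...` helpers below are invisible to discovery and only run when another test calls them. A minimal illustration:

    import unittest

    # Collects only test_* methods; they execute in sorted (numeric-prefix) order.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestKnowledgeBaseClient)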
@@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_004_update_document_by_text(self): def _test_004_update_document_by_text(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id) self.assertIsNotNone(self.document_id)
response = client.update_document_by_text( response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
self.document_id, "test_document_updated", "test_text_updated"
)
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.assertIn("batch", data) self.assertIn("batch", data)
@@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_006_update_document_by_file(self): def _test_006_update_document_by_file(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id) self.assertIsNotNone(self.document_id)
response = client.update_document_by_file( response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
self.document_id, self.README_FILE_PATH
)
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.assertIn("batch", data) self.assertIn("batch", data)
@@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_010_add_segments(self): def _test_010_add_segments(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
response = client.add_segments( response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
self.document_id, [{"content": "test text segment 1"}]
)
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0) self.assertGreater(len(data["data"]), 0)
@@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase):
self.chat_client = ChatClient(API_KEY) self.chat_client = ChatClient(API_KEY)
def test_create_chat_message(self): def test_create_chat_message(self):
response = self.chat_client.create_chat_message( response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
{}, "Hello, World!", "test_user"
)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_remote_url(self): def test_create_chat_message_with_vision_model_by_remote_url(self):
files = [
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
response = self.chat_client.create_chat_message(
{}, "Describe the picture.", "test_user", files=files
)
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_local_file(self): def test_create_chat_message_with_vision_model_by_local_file(self):
@@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase):
"upload_file_id": "your_file_id", "upload_file_id": "your_file_id",
} }
] ]
response = self.chat_client.create_chat_message( response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
{}, "Describe the picture.", "test_user", files=files
)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_get_conversation_messages(self): def test_get_conversation_messages(self):
response = self.chat_client.get_conversation_messages( response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id")
"test_user", "your_conversation_id"
)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_get_conversations(self): def test_get_conversations(self):
@@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase):
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_completion_message_with_vision_model_by_remote_url(self): def test_create_completion_message_with_vision_model_by_remote_url(self):
files = [ files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
response = self.completion_client.create_completion_message( response = self.completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files {"query": "Describe the picture."}, "blocking", "test_user", files
) )
@@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase):
self.dify_client = DifyClient(API_KEY) self.dify_client = DifyClient(API_KEY)
def test_message_feedback(self): def test_message_feedback(self):
response = self.dify_client.message_feedback( response = self.dify_client.message_feedback("your_message_id", "like", "test_user")
"your_message_id", "like", "test_user"
)
self.assertIn("success", response.text) self.assertIn("success", response.text)
def test_get_application_parameters(self): def test_get_application_parameters(self):