mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
ruff check preview (#25653)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
# Fix lint errors
|
# Fix lint errors
|
||||||
uv run ruff check --fix .
|
uv run ruff check --fix .
|
||||||
# Format code
|
# Format code
|
||||||
uv run ruff format .
|
uv run ruff format ..
|
||||||
|
|
||||||
- name: ast-grep
|
- name: ast-grep
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ line-length = 120
|
|||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
|
|
||||||
[lint]
|
[lint]
|
||||||
preview = false
|
preview = true
|
||||||
select = [
|
select = [
|
||||||
"B", # flake8-bugbear rules
|
"B", # flake8-bugbear rules
|
||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
@@ -65,6 +65,7 @@ ignore = [
|
|||||||
"B006", # mutable-argument-default
|
"B006", # mutable-argument-default
|
||||||
"B007", # unused-loop-control-variable
|
"B007", # unused-loop-control-variable
|
||||||
"B026", # star-arg-unpacking-after-keyword-arg
|
"B026", # star-arg-unpacking-after-keyword-arg
|
||||||
|
"B901", # allow return in yield
|
||||||
"B903", # class-as-data-structure
|
"B903", # class-as-data-structure
|
||||||
"B904", # raise-without-from-inside-except
|
"B904", # raise-without-from-inside-except
|
||||||
"B905", # zip-without-explicit-strict
|
"B905", # zip-without-explicit-strict
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool):
|
|||||||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||||
with db.engine.begin() as conn:
|
with db.engine.begin() as conn:
|
||||||
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||||
)
|
)
|
||||||
@@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables(
|
|||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
logger.info("DRY RUN: Would delete the following:")
|
logger.info("DRY RUN: Would delete the following:")
|
||||||
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
|
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
|
||||||
:10
|
:10
|
||||||
]: # Show top 10
|
]: # Show top 10
|
||||||
logger.info(" App %s: %s variables", app_id, count)
|
logger.info(" App %s: %s variables", app_id, count)
|
||||||
|
|||||||
@@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||||||
GEN_AI_FRAMEWORK: "dify",
|
GEN_AI_FRAMEWORK: "dify",
|
||||||
TOOL_NAME: node_execution.title,
|
TOOL_NAME: node_execution.title,
|
||||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||||
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||||
},
|
},
|
||||||
status=self.get_workflow_node_status(node_execution),
|
status=self.get_workflow_node_status(node_execution),
|
||||||
|
|||||||
@@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
if node_type == NodeType.LLM:
|
if node_type == NodeType.LLM:
|
||||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
else:
|
else:
|
||||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
inputs = node_execution.inputs or {}
|
||||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
outputs = node_execution.outputs or {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
execution_metadata = node_execution.metadata or {}
|
||||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||||
metadata.update(
|
metadata.update(
|
||||||
{
|
{
|
||||||
@@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
"status": status,
|
"status": status,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
process_data = node_execution.process_data or {}
|
||||||
model_provider = process_data.get("model_provider", None)
|
model_provider = process_data.get("model_provider", None)
|
||||||
model_name = process_data.get("model_name", None)
|
model_name = process_data.get("model_name", None)
|
||||||
if model_provider is not None and model_name is not None:
|
if model_provider is not None and model_name is not None:
|
||||||
|
|||||||
@@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
if node_type == NodeType.LLM:
|
if node_type == NodeType.LLM:
|
||||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
else:
|
else:
|
||||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
inputs = node_execution.inputs or {}
|
||||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
outputs = node_execution.outputs or {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
execution_metadata = node_execution.metadata or {}
|
||||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||||
metadata.update(
|
metadata.update(
|
||||||
@@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
process_data = node_execution.process_data or {}
|
||||||
|
|
||||||
if process_data and process_data.get("model_mode") == "chat":
|
if process_data and process_data.get("model_mode") == "chat":
|
||||||
run_type = LangSmithRunType.llm
|
run_type = LangSmithRunType.llm
|
||||||
|
|||||||
@@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
if node_type == NodeType.LLM:
|
if node_type == NodeType.LLM:
|
||||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
else:
|
else:
|
||||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
inputs = node_execution.inputs or {}
|
||||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
outputs = node_execution.outputs or {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
execution_metadata = node_execution.metadata or {}
|
||||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||||
metadata.update(
|
metadata.update(
|
||||||
{
|
{
|
||||||
@@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
process_data = node_execution.process_data or {}
|
||||||
|
|
||||||
provider = None
|
provider = None
|
||||||
model = None
|
model = None
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||||
match provider:
|
match provider:
|
||||||
case TracingProviderEnum.LANGFUSE:
|
case TracingProviderEnum.LANGFUSE:
|
||||||
@@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
|||||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
|
provider_config_map = OpsTraceProviderConfigMap()
|
||||||
|
|
||||||
|
|
||||||
class OpsTraceManager:
|
class OpsTraceManager:
|
||||||
|
|||||||
@@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||||||
if node_type == NodeType.LLM:
|
if node_type == NodeType.LLM:
|
||||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
else:
|
else:
|
||||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
inputs = node_execution.inputs or {}
|
||||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
outputs = node_execution.outputs or {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
execution_metadata = node_execution.metadata or {}
|
||||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||||
attributes.update(
|
attributes.update(
|
||||||
@@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
process_data = node_execution.process_data or {}
|
||||||
if process_data and process_data.get("model_mode") == "chat":
|
if process_data and process_data.get("model_mode") == "chat":
|
||||||
attributes.update(
|
attributes.update(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector):
|
|||||||
|
|
||||||
for doc, embedding in zip(batch_docs, batch_embeddings):
|
for doc, embedding in zip(batch_docs, batch_embeddings):
|
||||||
# Optimized: minimal checks for common case, fallback for edge cases
|
# Optimized: minimal checks for common case, fallback for edge cases
|
||||||
metadata = doc.metadata if doc.metadata else {}
|
metadata = doc.metadata or {}
|
||||||
|
|
||||||
if not isinstance(metadata, dict):
|
if not isinstance(metadata, dict):
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector):
|
|||||||
self.client = self._get_client(len(embeddings[0]), True)
|
self.client = self._get_client(len(embeddings[0]), True)
|
||||||
assert self.client is not None
|
assert self.client is not None
|
||||||
ids = []
|
ids = []
|
||||||
for _, doc in enumerate(documents):
|
for doc in documents:
|
||||||
if doc.metadata is not None:
|
if doc.metadata is not None:
|
||||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||||
ids.append(doc_id)
|
ids.append(doc_id)
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
|
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
|
||||||
if self._client_config.aws_service not in ["aoss"]:
|
if self._client_config.aws_service != "aoss":
|
||||||
action["_id"] = uuid4().hex
|
action["_id"] = uuid4().hex
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
|
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
db_model.status = domain_model.status
|
db_model.status = domain_model.status
|
||||||
db_model.error = domain_model.error_message if domain_model.error_message else None
|
db_model.error = domain_model.error_message or None
|
||||||
db_model.total_tokens = domain_model.total_tokens
|
db_model.total_tokens = domain_model.total_tokens
|
||||||
db_model.total_steps = domain_model.total_steps
|
db_model.total_steps = domain_model.total_steps
|
||||||
db_model.exceptions_count = domain_model.exceptions_count
|
db_model.exceptions_count = domain_model.exceptions_count
|
||||||
|
|||||||
@@ -320,7 +320,7 @@ class AgentNode(BaseNode):
|
|||||||
memory = self._fetch_memory(model_instance)
|
memory = self._fetch_memory(model_instance)
|
||||||
if memory:
|
if memory:
|
||||||
prompt_messages = memory.get_history_prompt_messages(
|
prompt_messages = memory.get_history_prompt_messages(
|
||||||
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
|
message_limit=node_data.memory.window.size or None
|
||||||
)
|
)
|
||||||
history_prompt_messages = [
|
history_prompt_messages = [
|
||||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||||
|
|||||||
@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
imports.append("schedule.queue_monitor_task")
|
imports.append("schedule.queue_monitor_task")
|
||||||
beat_schedule["datasets-queue-monitor"] = {
|
beat_schedule["datasets-queue-monitor"] = {
|
||||||
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
||||||
"schedule": timedelta(
|
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
|
||||||
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||||
imports.append("schedule.check_upgradable_plugin_task")
|
imports.append("schedule.check_upgradable_plugin_task")
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
@@ -356,7 +357,7 @@ class FileLifecycleManager:
|
|||||||
# Cleanup old versions for each file
|
# Cleanup old versions for each file
|
||||||
for base_filename, versions in file_versions.items():
|
for base_filename, versions in file_versions.items():
|
||||||
# Sort by version number
|
# Sort by version number
|
||||||
versions.sort(key=lambda x: x[0], reverse=True)
|
versions.sort(key=operator.itemgetter(0), reverse=True)
|
||||||
|
|
||||||
# Keep the newest max_versions versions, delete the rest
|
# Keep the newest max_versions versions, delete the rest
|
||||||
if len(versions) > max_versions:
|
if len(versions) > max_versions:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import operator
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
|
|||||||
current_version = version
|
current_version = version
|
||||||
latest_version = manifest.latest_version
|
latest_version = manifest.latest_version
|
||||||
|
|
||||||
def fix_only_checker(latest_version, current_version):
|
def fix_only_checker(latest_version: str, current_version: str):
|
||||||
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
|
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
|
||||||
current_version_tuple = tuple(int(val) for val in current_version.split("."))
|
current_version_tuple = tuple(int(val) for val in current_version.split("."))
|
||||||
|
|
||||||
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
version_checker = {
|
version_checker = {
|
||||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
|
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
|
||||||
current_version: latest_version != current_version,
|
|
||||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
|
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
|
|||||||
# Test download
|
# Test download
|
||||||
with tempfile.NamedTemporaryFile() as temp_file:
|
with tempfile.NamedTemporaryFile() as temp_file:
|
||||||
storage.download(test_filename, temp_file.name)
|
storage.download(test_filename, temp_file.name)
|
||||||
with open(temp_file.name, "rb") as f:
|
downloaded_content = Path(temp_file.name).read_bytes()
|
||||||
downloaded_content = f.read()
|
|
||||||
assert downloaded_content == test_content
|
assert downloaded_content == test_content
|
||||||
|
|
||||||
# Test scan
|
# Test scan
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||||||
mock_storage = mock_external_service_dependencies["storage"]
|
mock_storage = mock_external_service_dependencies["storage"]
|
||||||
|
|
||||||
def mock_download(key, file_path):
|
def mock_download(key, file_path):
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
Path(file_path).write_text(csv_content, encoding="utf-8")
|
||||||
f.write(csv_content)
|
|
||||||
|
|
||||||
mock_storage.download.side_effect = mock_download
|
mock_storage.download.side_effect = mock_download
|
||||||
|
|
||||||
@@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Test each unavailable document
|
# Test each unavailable document
|
||||||
for i, document in enumerate(test_cases):
|
for document in test_cases:
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
batch_create_segment_to_index_task(
|
batch_create_segment_to_index_task(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
@@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||||||
mock_storage = mock_external_service_dependencies["storage"]
|
mock_storage = mock_external_service_dependencies["storage"]
|
||||||
|
|
||||||
def mock_download(key, file_path):
|
def mock_download(key, file_path):
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
Path(file_path).write_text(empty_csv_content, encoding="utf-8")
|
||||||
f.write(empty_csv_content)
|
|
||||||
|
|
||||||
mock_storage.download.side_effect = mock_download
|
mock_storage.download.side_effect = mock_download
|
||||||
|
|
||||||
@@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||||||
mock_storage = mock_external_service_dependencies["storage"]
|
mock_storage = mock_external_service_dependencies["storage"]
|
||||||
|
|
||||||
def mock_download(key, file_path):
|
def mock_download(key, file_path):
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
Path(file_path).write_text(csv_content, encoding="utf-8")
|
||||||
f.write(csv_content)
|
|
||||||
|
|
||||||
mock_storage.download.side_effect = mock_download
|
mock_storage.download.side_effect = mock_download
|
||||||
|
|
||||||
|
|||||||
@@ -362,7 +362,7 @@ class TestCleanDatasetTask:
|
|||||||
|
|
||||||
# Create segments for each document
|
# Create segments for each document
|
||||||
segments = []
|
segments = []
|
||||||
for i, document in enumerate(documents):
|
for document in documents:
|
||||||
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
|
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
|
||||||
segments.append(segment)
|
segments.append(segment)
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class FakeResponse:
|
|||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
self.content = content
|
self.content = content
|
||||||
self.text = text if text else content.decode("utf-8", errors="ignore")
|
self.text = text or content.decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from pathlib import Path
|
||||||
from unittest.mock import Mock, create_autospec, patch
|
from unittest.mock import Mock, create_autospec, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
|
|||||||
# Console API create
|
# Console API create
|
||||||
console_create_file = "api/controllers/console/datasets/metadata.py"
|
console_create_file = "api/controllers/console/datasets/metadata.py"
|
||||||
if os.path.exists(console_create_file):
|
if os.path.exists(console_create_file):
|
||||||
with open(console_create_file) as f:
|
content = Path(console_create_file).read_text()
|
||||||
content = f.read()
|
# Should contain nullable=False, not nullable=True
|
||||||
# Should contain nullable=False, not nullable=True
|
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
|
||||||
|
|
||||||
# Service API create
|
# Service API create
|
||||||
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
||||||
if os.path.exists(service_create_file):
|
if os.path.exists(service_create_file):
|
||||||
with open(service_create_file) as f:
|
content = Path(service_create_file).read_text()
|
||||||
content = f.read()
|
# Should contain nullable=False, not nullable=True
|
||||||
# Should contain nullable=False, not nullable=True
|
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
assert "nullable=True" not in create_api_section
|
||||||
assert "nullable=True" not in create_api_section
|
|
||||||
|
|
||||||
|
|
||||||
class TestMetadataValidationSummary:
|
class TestMetadataValidationSummary:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
from dotenv import dotenv_values
|
from dotenv import dotenv_values
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
|
BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
|
||||||
"APP_MAX_EXECUTION_TIME",
|
"APP_MAX_EXECUTION_TIME",
|
||||||
@@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f:
|
|||||||
|
|
||||||
def test_yaml_config():
|
def test_yaml_config():
|
||||||
# python set == operator is used to compare two sets
|
# python set == operator is used to compare two sets
|
||||||
DIFF_API_WITH_DOCKER = (
|
DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
|
||||||
API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
|
|
||||||
)
|
|
||||||
if DIFF_API_WITH_DOCKER:
|
if DIFF_API_WITH_DOCKER:
|
||||||
print(
|
print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
|
||||||
f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
|
|
||||||
)
|
|
||||||
raise Exception("API and Docker config sets are different")
|
raise Exception("API and Docker config sets are different")
|
||||||
DIFF_API_WITH_DOCKER_COMPOSE = (
|
DIFF_API_WITH_DOCKER_COMPOSE = (
|
||||||
API_CONFIG_SET
|
API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
|
||||||
- DOCKER_COMPOSE_CONFIG_SET
|
|
||||||
- BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
|
|
||||||
)
|
)
|
||||||
if DIFF_API_WITH_DOCKER_COMPOSE:
|
if DIFF_API_WITH_DOCKER_COMPOSE:
|
||||||
print(
|
print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
|
||||||
f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
|
|
||||||
)
|
|
||||||
raise Exception("API and Docker Compose config sets are different")
|
raise Exception("API and Docker Compose config sets are different")
|
||||||
print("All tests passed!")
|
print("All tests passed!")
|
||||||
|
|
||||||
|
|||||||
@@ -51,9 +51,7 @@ def cleanup() -> None:
|
|||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
log.separator()
|
log.separator()
|
||||||
log.warning("This action cannot be undone!")
|
log.warning("This action cannot be undone!")
|
||||||
confirmation = input(
|
confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ")
|
||||||
"Are you sure you want to remove all config and report files? (yes/no): "
|
|
||||||
)
|
|
||||||
|
|
||||||
if confirmation.lower() not in ["yes", "y"]:
|
if confirmation.lower() not in ["yes", "y"]:
|
||||||
log.error("Cleanup cancelled.")
|
log.error("Cleanup cancelled.")
|
||||||
|
|||||||
@@ -3,4 +3,4 @@
|
|||||||
from .config_helper import config_helper
|
from .config_helper import config_helper
|
||||||
from .logger_helper import Logger, ProgressLogger
|
from .logger_helper import Logger, ProgressLogger
|
||||||
|
|
||||||
__all__ = ["config_helper", "Logger", "ProgressLogger"]
|
__all__ = ["Logger", "ProgressLogger", "config_helper"]
|
||||||
|
|||||||
@@ -65,9 +65,9 @@ class ConfigHelper:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(config_path, "r") as f:
|
with open(config_path) as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except (json.JSONDecodeError, IOError) as e:
|
except (OSError, json.JSONDecodeError) as e:
|
||||||
print(f"❌ Error reading {filename}: {e}")
|
print(f"❌ Error reading {filename}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -101,7 +101,7 @@ class ConfigHelper:
|
|||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
return True
|
return True
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
print(f"❌ Error writing {filename}: {e}")
|
print(f"❌ Error writing {filename}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class ConfigHelper:
|
|||||||
try:
|
try:
|
||||||
config_path.unlink()
|
config_path.unlink()
|
||||||
return True
|
return True
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
print(f"❌ Error deleting {filename}: {e}")
|
print(f"❌ Error deleting {filename}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ class ConfigHelper:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(state_path, "r") as f:
|
with open(state_path) as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except (json.JSONDecodeError, IOError) as e:
|
except (OSError, json.JSONDecodeError) as e:
|
||||||
print(f"❌ Error reading {self.state_file}: {e}")
|
print(f"❌ Error reading {self.state_file}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class ConfigHelper:
|
|||||||
with open(state_path, "w") as f:
|
with open(state_path, "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
return True
|
return True
|
||||||
except IOError as e:
|
except OSError as e:
|
||||||
print(f"❌ Error writing {self.state_file}: {e}")
|
print(f"❌ Error writing {self.state_file}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -159,9 +159,7 @@ class ProgressLogger:
|
|||||||
|
|
||||||
if self.logger.use_colors:
|
if self.logger.use_colors:
|
||||||
progress_bar = self._create_progress_bar()
|
progress_bar = self._create_progress_bar()
|
||||||
print(
|
print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}")
|
||||||
f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}"
|
|
||||||
)
|
|
||||||
self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
|
self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
|
||||||
else:
|
else:
|
||||||
print(f"\n[Step {self.current_step}/{self.total_steps}]")
|
print(f"\n[Step {self.current_step}/{self.total_steps}]")
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ from pathlib import Path
|
|||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from common import config_helper
|
from common import Logger, config_helper
|
||||||
from common import Logger
|
|
||||||
|
|
||||||
|
|
||||||
def configure_openai_plugin() -> None:
|
def configure_openai_plugin() -> None:
|
||||||
@@ -72,29 +71,19 @@ def configure_openai_plugin() -> None:
|
|||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
log.success("OpenAI plugin configured successfully!")
|
log.success("OpenAI plugin configured successfully!")
|
||||||
log.key_value(
|
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
|
||||||
"API Base", config_payload["credentials"]["openai_api_base"]
|
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
|
||||||
)
|
|
||||||
log.key_value(
|
|
||||||
"API Key", config_payload["credentials"]["openai_api_key"]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif response.status_code == 201:
|
elif response.status_code == 201:
|
||||||
log.success("OpenAI plugin credentials created successfully!")
|
log.success("OpenAI plugin credentials created successfully!")
|
||||||
log.key_value(
|
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
|
||||||
"API Base", config_payload["credentials"]["openai_api_base"]
|
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
|
||||||
)
|
|
||||||
log.key_value(
|
|
||||||
"API Key", config_payload["credentials"]["openai_api_key"]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif response.status_code == 401:
|
elif response.status_code == 401:
|
||||||
log.error("Configuration failed: Unauthorized")
|
log.error("Configuration failed: Unauthorized")
|
||||||
log.info("Token may have expired. Please run login_admin.py again")
|
log.info("Token may have expired. Please run login_admin.py again")
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"Configuration failed with status code: {response.status_code}")
|
||||||
f"Configuration failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
|
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json
|
import json
|
||||||
from common import config_helper
|
|
||||||
from common import Logger
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def create_api_key() -> None:
|
def create_api_key() -> None:
|
||||||
@@ -90,9 +90,7 @@ def create_api_key() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config_helper.write_config("api_key_config", api_key_config):
|
if config_helper.write_config("api_key_config", api_key_config):
|
||||||
log.info(
|
log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}")
|
||||||
f"API key saved to: {config_helper.get_config_path('benchmark_state')}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
log.error("No API token received")
|
log.error("No API token received")
|
||||||
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
|
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
|
||||||
@@ -101,9 +99,7 @@ def create_api_key() -> None:
|
|||||||
log.error("API key creation failed: Unauthorized")
|
log.error("API key creation failed: Unauthorized")
|
||||||
log.info("Token may have expired. Please run login_admin.py again")
|
log.info("Token may have expired. Please run login_admin.py again")
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"API key creation failed with status code: {response.status_code}")
|
||||||
f"API key creation failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
|
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json
|
import json
|
||||||
from common import config_helper, Logger
|
|
||||||
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def import_workflow_app() -> None:
|
def import_workflow_app() -> None:
|
||||||
@@ -30,7 +31,7 @@ def import_workflow_app() -> None:
|
|||||||
log.error(f"DSL file not found: {dsl_path}")
|
log.error(f"DSL file not found: {dsl_path}")
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(dsl_path, "r") as f:
|
with open(dsl_path) as f:
|
||||||
yaml_content = f.read()
|
yaml_content = f.read()
|
||||||
|
|
||||||
log.step("Importing workflow app from DSL...")
|
log.step("Importing workflow app from DSL...")
|
||||||
@@ -86,9 +87,7 @@ def import_workflow_app() -> None:
|
|||||||
log.success("Workflow app imported successfully!")
|
log.success("Workflow app imported successfully!")
|
||||||
log.key_value("App ID", app_id)
|
log.key_value("App ID", app_id)
|
||||||
log.key_value("App Mode", response_data.get("app_mode"))
|
log.key_value("App Mode", response_data.get("app_mode"))
|
||||||
log.key_value(
|
log.key_value("DSL Version", response_data.get("imported_dsl_version"))
|
||||||
"DSL Version", response_data.get("imported_dsl_version")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save app_id to config
|
# Save app_id to config
|
||||||
app_config = {
|
app_config = {
|
||||||
@@ -99,9 +98,7 @@ def import_workflow_app() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config_helper.write_config("app_config", app_config):
|
if config_helper.write_config("app_config", app_config):
|
||||||
log.info(
|
log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}")
|
||||||
f"App config saved to: {config_helper.get_config_path('benchmark_state')}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
log.error("Import completed but no app_id received")
|
log.error("Import completed but no app_id received")
|
||||||
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
|
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import time
|
import time
|
||||||
from common import config_helper
|
|
||||||
from common import Logger
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def install_openai_plugin() -> None:
|
def install_openai_plugin() -> None:
|
||||||
@@ -28,9 +28,7 @@ def install_openai_plugin() -> None:
|
|||||||
|
|
||||||
# API endpoint for plugin installation
|
# API endpoint for plugin installation
|
||||||
base_url = "http://localhost:5001"
|
base_url = "http://localhost:5001"
|
||||||
install_endpoint = (
|
install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
|
||||||
f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plugin identifier
|
# Plugin identifier
|
||||||
plugin_payload = {
|
plugin_payload = {
|
||||||
@@ -83,9 +81,7 @@ def install_openai_plugin() -> None:
|
|||||||
log.info("Polling for task completion...")
|
log.info("Polling for task completion...")
|
||||||
|
|
||||||
# Poll for task completion
|
# Poll for task completion
|
||||||
task_endpoint = (
|
task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
|
||||||
f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max
|
max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max
|
||||||
attempt = 0
|
attempt = 0
|
||||||
@@ -131,9 +127,7 @@ def install_openai_plugin() -> None:
|
|||||||
plugins = task_info.get("plugins", [])
|
plugins = task_info.get("plugins", [])
|
||||||
if plugins:
|
if plugins:
|
||||||
for plugin in plugins:
|
for plugin in plugins:
|
||||||
log.list_item(
|
log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}")
|
||||||
f"{plugin.get('plugin_id')}: {plugin.get('message')}"
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# Continue polling if status is "pending" or other
|
# Continue polling if status is "pending" or other
|
||||||
@@ -149,9 +143,7 @@ def install_openai_plugin() -> None:
|
|||||||
log.warning("Plugin may already be installed")
|
log.warning("Plugin may already be installed")
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"Installation failed with status code: {response.status_code}")
|
||||||
f"Installation failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
|
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json
|
import json
|
||||||
from common import config_helper
|
|
||||||
from common import Logger
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def login_admin() -> None:
|
def login_admin() -> None:
|
||||||
@@ -77,16 +77,10 @@ def login_admin() -> None:
|
|||||||
|
|
||||||
# Save token config
|
# Save token config
|
||||||
if config_helper.write_config("token_config", token_config):
|
if config_helper.write_config("token_config", token_config):
|
||||||
log.info(
|
log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}")
|
||||||
f"Token saved to: {config_helper.get_config_path('benchmark_state')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show truncated token for verification
|
# Show truncated token for verification
|
||||||
token_display = (
|
token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved"
|
||||||
f"{access_token[:20]}..."
|
|
||||||
if len(access_token) > 20
|
|
||||||
else "Token saved"
|
|
||||||
)
|
|
||||||
log.key_value("Access token", token_display)
|
log.key_value("Access token", token_display)
|
||||||
|
|
||||||
elif response.status_code == 401:
|
elif response.status_code == 401:
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Iterator
|
from collections.abc import Iterator
|
||||||
from flask import Flask, request, jsonify, Response
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import Flask, Response, jsonify, request
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json
|
import json
|
||||||
from common import config_helper
|
|
||||||
from common import Logger
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def publish_workflow() -> None:
|
def publish_workflow() -> None:
|
||||||
@@ -79,9 +79,7 @@ def publish_workflow() -> None:
|
|||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
if response_data:
|
if response_data:
|
||||||
log.debug(
|
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
|
||||||
f"Response: {json.dumps(response_data, indent=2)}"
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Response might be empty or non-JSON
|
# Response might be empty or non-JSON
|
||||||
pass
|
pass
|
||||||
@@ -93,9 +91,7 @@ def publish_workflow() -> None:
|
|||||||
log.error("Workflow publish failed: App not found")
|
log.error("Workflow publish failed: App not found")
|
||||||
log.info("Make sure the app was imported successfully")
|
log.info("Make sure the app was imported successfully")
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"Workflow publish failed with status code: {response.status_code}")
|
||||||
f"Workflow publish failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
|
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
|
||||||
import json
|
import json
|
||||||
from common import config_helper, Logger
|
|
||||||
|
import httpx
|
||||||
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
|
def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
|
||||||
@@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
|
|||||||
event = data.get("event")
|
event = data.get("event")
|
||||||
|
|
||||||
if event == "workflow_started":
|
if event == "workflow_started":
|
||||||
log.progress(
|
log.progress(f"Workflow started: {data.get('data', {}).get('id')}")
|
||||||
f"Workflow started: {data.get('data', {}).get('id')}"
|
|
||||||
)
|
|
||||||
elif event == "node_started":
|
elif event == "node_started":
|
||||||
node_data = data.get("data", {})
|
node_data = data.get("data", {})
|
||||||
log.progress(
|
log.progress(
|
||||||
@@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
|
|||||||
# Some lines might not be JSON
|
# Some lines might not be JSON
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"Workflow run failed with status code: {response.status_code}")
|
||||||
f"Workflow run failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
else:
|
else:
|
||||||
# Handle blocking response
|
# Handle blocking response
|
||||||
@@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
|
|||||||
log.info("📤 Final Answer:")
|
log.info("📤 Final Answer:")
|
||||||
log.info(outputs.get("answer"), indent=2)
|
log.info(outputs.get("answer"), indent=2)
|
||||||
else:
|
else:
|
||||||
log.error(
|
log.error(f"Workflow run failed with status code: {response.status_code}")
|
||||||
f"Workflow run failed with status code: {response.status_code}"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
|
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
sys.path.append(str(Path(__file__).parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from common import config_helper, Logger
|
from common import Logger, config_helper
|
||||||
|
|
||||||
|
|
||||||
def setup_admin_account() -> None:
|
def setup_admin_account() -> None:
|
||||||
@@ -24,9 +24,7 @@ def setup_admin_account() -> None:
|
|||||||
|
|
||||||
# Save credentials to config file
|
# Save credentials to config file
|
||||||
if config_helper.write_config("admin_config", admin_config):
|
if config_helper.write_config("admin_config", admin_config):
|
||||||
log.info(
|
log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}")
|
||||||
f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# API setup endpoint
|
# API setup endpoint
|
||||||
base_url = "http://localhost:5001"
|
base_url = "http://localhost:5001"
|
||||||
@@ -56,9 +54,7 @@ def setup_admin_account() -> None:
|
|||||||
log.key_value("Username", admin_config["username"])
|
log.key_value("Username", admin_config["username"])
|
||||||
|
|
||||||
elif response.status_code == 400:
|
elif response.status_code == 400:
|
||||||
log.warning(
|
log.warning("Setup may have already been completed or invalid data provided")
|
||||||
"Setup may have already been completed or invalid data provided"
|
|
||||||
)
|
|
||||||
log.debug(f"Response: {response.text}")
|
log.debug(f"Response: {response.text}")
|
||||||
else:
|
else:
|
||||||
log.error(f"Setup failed with status code: {response.status_code}")
|
log.error(f"Setup failed with status code: {response.status_code}")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import socket
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from common import Logger, ProgressLogger
|
from common import Logger, ProgressLogger
|
||||||
@@ -93,9 +93,7 @@ def main() -> None:
|
|||||||
if retry.lower() in ["yes", "y"]:
|
if retry.lower() in ["yes", "y"]:
|
||||||
return main() # Recursively call main to check again
|
return main() # Recursively call main to check again
|
||||||
else:
|
else:
|
||||||
print(
|
print("❌ Setup cancelled. Please start the required services and try again.")
|
||||||
"❌ Setup cancelled. Please start the required services and try again."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
log.success("All required services are running!")
|
log.success("All required services are running!")
|
||||||
|
|||||||
@@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
import statistics
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import os
|
import time
|
||||||
import logging
|
|
||||||
import statistics
|
|
||||||
from pathlib import Path
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dataclasses import dataclass, asdict
|
from pathlib import Path
|
||||||
from locust import HttpUser, task, between, events, constant
|
from typing import Literal, TypeAlias, TypedDict
|
||||||
from typing import TypedDict, Literal, TypeAlias
|
|
||||||
import requests.exceptions
|
import requests.exceptions
|
||||||
|
from locust import HttpUser, between, constant, events, task
|
||||||
|
|
||||||
# Add the stress-test directory to path to import common modules
|
# Add the stress-test directory to path to import common modules
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
from common.config_helper import ConfigHelper # type: ignore[import-not-found]
|
from common.config_helper import ConfigHelper # type: ignore[import-not-found]
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Configuration from environment
|
# Configuration from environment
|
||||||
@@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[
|
|||||||
|
|
||||||
class ErrorCounts(TypedDict):
|
class ErrorCounts(TypedDict):
|
||||||
"""Error count tracking"""
|
"""Error count tracking"""
|
||||||
|
|
||||||
connection_error: int
|
connection_error: int
|
||||||
timeout: int
|
timeout: int
|
||||||
invalid_json: int
|
invalid_json: int
|
||||||
@@ -65,6 +65,7 @@ class ErrorCounts(TypedDict):
|
|||||||
|
|
||||||
class SSEEvent(TypedDict):
|
class SSEEvent(TypedDict):
|
||||||
"""Server-Sent Event structure"""
|
"""Server-Sent Event structure"""
|
||||||
|
|
||||||
data: str
|
data: str
|
||||||
event: str
|
event: str
|
||||||
id: str | None
|
id: str | None
|
||||||
@@ -72,11 +73,13 @@ class SSEEvent(TypedDict):
|
|||||||
|
|
||||||
class WorkflowInputs(TypedDict):
|
class WorkflowInputs(TypedDict):
|
||||||
"""Workflow input structure"""
|
"""Workflow input structure"""
|
||||||
|
|
||||||
question: str
|
question: str
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRequestData(TypedDict):
|
class WorkflowRequestData(TypedDict):
|
||||||
"""Workflow request payload"""
|
"""Workflow request payload"""
|
||||||
|
|
||||||
inputs: WorkflowInputs
|
inputs: WorkflowInputs
|
||||||
response_mode: Literal["streaming"]
|
response_mode: Literal["streaming"]
|
||||||
user: str
|
user: str
|
||||||
@@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict):
|
|||||||
|
|
||||||
class ParsedEventData(TypedDict, total=False):
|
class ParsedEventData(TypedDict, total=False):
|
||||||
"""Parsed event data from SSE stream"""
|
"""Parsed event data from SSE stream"""
|
||||||
|
|
||||||
event: str
|
event: str
|
||||||
task_id: str
|
task_id: str
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
@@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False):
|
|||||||
|
|
||||||
class LocustStats(TypedDict):
|
class LocustStats(TypedDict):
|
||||||
"""Locust statistics structure"""
|
"""Locust statistics structure"""
|
||||||
|
|
||||||
total_requests: int
|
total_requests: int
|
||||||
total_failures: int
|
total_failures: int
|
||||||
avg_response_time: float
|
avg_response_time: float
|
||||||
@@ -102,6 +107,7 @@ class LocustStats(TypedDict):
|
|||||||
|
|
||||||
class ReportData(TypedDict):
|
class ReportData(TypedDict):
|
||||||
"""JSON report structure"""
|
"""JSON report structure"""
|
||||||
|
|
||||||
timestamp: str
|
timestamp: str
|
||||||
duration_seconds: float
|
duration_seconds: float
|
||||||
metrics: dict[str, object] # Metrics as dict for JSON serialization
|
metrics: dict[str, object] # Metrics as dict for JSON serialization
|
||||||
@@ -154,7 +160,7 @@ class MetricsTracker:
|
|||||||
self.total_connections = 0
|
self.total_connections = 0
|
||||||
self.total_events = 0
|
self.total_events = 0
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
|
|
||||||
# Enhanced metrics with memory limits
|
# Enhanced metrics with memory limits
|
||||||
self.max_samples = 10000 # Prevent unbounded growth
|
self.max_samples = 10000 # Prevent unbounded growth
|
||||||
self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
|
self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
|
||||||
@@ -233,9 +239,7 @@ class MetricsTracker:
|
|||||||
max_ttfe = max(self.ttfe_samples)
|
max_ttfe = max(self.ttfe_samples)
|
||||||
p50_ttfe = statistics.median(self.ttfe_samples)
|
p50_ttfe = statistics.median(self.ttfe_samples)
|
||||||
if len(self.ttfe_samples) >= 2:
|
if len(self.ttfe_samples) >= 2:
|
||||||
quantiles = statistics.quantiles(
|
quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive")
|
||||||
self.ttfe_samples, n=20, method="inclusive"
|
|
||||||
)
|
|
||||||
p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile
|
p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile
|
||||||
else:
|
else:
|
||||||
p95_ttfe = max_ttfe
|
p95_ttfe = max_ttfe
|
||||||
@@ -255,9 +259,7 @@ class MetricsTracker:
|
|||||||
if durations
|
if durations
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
events_per_stream_avg = (
|
events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0
|
||||||
statistics.mean(events_per_stream) if events_per_stream else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate inter-event latency statistics
|
# Calculate inter-event latency statistics
|
||||||
all_inter_event_times = []
|
all_inter_event_times = []
|
||||||
@@ -268,32 +270,20 @@ class MetricsTracker:
|
|||||||
inter_event_latency_avg = statistics.mean(all_inter_event_times)
|
inter_event_latency_avg = statistics.mean(all_inter_event_times)
|
||||||
inter_event_latency_p50 = statistics.median(all_inter_event_times)
|
inter_event_latency_p50 = statistics.median(all_inter_event_times)
|
||||||
inter_event_latency_p95 = (
|
inter_event_latency_p95 = (
|
||||||
statistics.quantiles(
|
statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18]
|
||||||
all_inter_event_times, n=20, method="inclusive"
|
|
||||||
)[18]
|
|
||||||
if len(all_inter_event_times) >= 2
|
if len(all_inter_event_times) >= 2
|
||||||
else max(all_inter_event_times)
|
else max(all_inter_event_times)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inter_event_latency_avg = inter_event_latency_p50 = (
|
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
|
||||||
inter_event_latency_p95
|
|
||||||
) = 0
|
|
||||||
else:
|
else:
|
||||||
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = (
|
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0
|
||||||
events_per_stream_avg
|
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
|
||||||
) = 0
|
|
||||||
inter_event_latency_avg = inter_event_latency_p50 = (
|
|
||||||
inter_event_latency_p95
|
|
||||||
) = 0
|
|
||||||
|
|
||||||
# Also calculate overall average rates
|
# Also calculate overall average rates
|
||||||
total_elapsed = current_time - self.start_time
|
total_elapsed = current_time - self.start_time
|
||||||
overall_conn_rate = (
|
overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0
|
||||||
self.total_connections / total_elapsed if total_elapsed > 0 else 0
|
overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0
|
||||||
)
|
|
||||||
overall_event_rate = (
|
|
||||||
self.total_events / total_elapsed if total_elapsed > 0 else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return MetricsSnapshot(
|
return MetricsSnapshot(
|
||||||
active_connections=self.active_connections,
|
active_connections=self.active_connections,
|
||||||
@@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser):
|
|||||||
|
|
||||||
# Load questions from file or use defaults
|
# Load questions from file or use defaults
|
||||||
if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
|
if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
|
||||||
with open(QUESTIONS_FILE, "r") as f:
|
with open(QUESTIONS_FILE) as f:
|
||||||
self.questions = [line.strip() for line in f if line.strip()]
|
self.questions = [line.strip() for line in f if line.strip()]
|
||||||
else:
|
else:
|
||||||
self.questions = [
|
self.questions = [
|
||||||
@@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser):
|
|||||||
try:
|
try:
|
||||||
# Validate response
|
# Validate response
|
||||||
if response.status_code >= 400:
|
if response.status_code >= 400:
|
||||||
error_type: ErrorType = (
|
error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx"
|
||||||
"http_4xx" if response.status_code < 500 else "http_5xx"
|
|
||||||
)
|
|
||||||
metrics.record_error(error_type)
|
metrics.record_error(error_type)
|
||||||
response.failure(f"HTTP {response.status_code}")
|
response.failure(f"HTTP {response.status_code}")
|
||||||
return
|
return
|
||||||
|
|
||||||
content_type = response.headers.get("Content-Type", "")
|
content_type = response.headers.get("Content-Type", "")
|
||||||
if (
|
if "text/event-stream" not in content_type and "application/json" not in content_type:
|
||||||
"text/event-stream" not in content_type
|
|
||||||
and "application/json" not in content_type
|
|
||||||
):
|
|
||||||
logger.error(f"Expected text/event-stream, got: {content_type}")
|
logger.error(f"Expected text/event-stream, got: {content_type}")
|
||||||
metrics.record_error("invalid_response")
|
metrics.record_error("invalid_response")
|
||||||
response.failure(f"Invalid content type: {content_type}")
|
response.failure(f"Invalid content type: {content_type}")
|
||||||
@@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser):
|
|||||||
|
|
||||||
for line in response.iter_lines(decode_unicode=True):
|
for line in response.iter_lines(decode_unicode=True):
|
||||||
# Check if runner is stopping
|
# Check if runner is stopping
|
||||||
if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'):
|
if getattr(self.environment.runner, "state", "") in (
|
||||||
|
"stopping",
|
||||||
|
"stopped",
|
||||||
|
):
|
||||||
logger.debug("Runner stopping, breaking streaming loop")
|
logger.debug("Runner stopping, breaking streaming loop")
|
||||||
break
|
break
|
||||||
|
|
||||||
if line is not None:
|
if line is not None:
|
||||||
bytes_received += len(line.encode("utf-8"))
|
bytes_received += len(line.encode("utf-8"))
|
||||||
|
|
||||||
@@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser):
|
|||||||
|
|
||||||
# Track inter-event timing
|
# Track inter-event timing
|
||||||
if last_event_time:
|
if last_event_time:
|
||||||
inter_event_times.append(
|
inter_event_times.append((current_time - last_event_time) * 1000)
|
||||||
(current_time - last_event_time) * 1000
|
|
||||||
)
|
|
||||||
last_event_time = current_time
|
last_event_time = current_time
|
||||||
|
|
||||||
if first_event_time is None:
|
if first_event_time is None:
|
||||||
@@ -512,15 +498,11 @@ class DifyWorkflowUser(HttpUser):
|
|||||||
parsed_event: ParsedEventData = json.loads(event_data)
|
parsed_event: ParsedEventData = json.loads(event_data)
|
||||||
# Check for terminal events
|
# Check for terminal events
|
||||||
if parsed_event.get("event") in TERMINAL_EVENTS:
|
if parsed_event.get("event") in TERMINAL_EVENTS:
|
||||||
logger.debug(
|
logger.debug(f"Received terminal event: {parsed_event.get('event')}")
|
||||||
f"Received terminal event: {parsed_event.get('event')}"
|
|
||||||
)
|
|
||||||
request_success = True
|
request_success = True
|
||||||
break
|
break
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.debug(
|
logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}")
|
||||||
f"JSON decode error: {e} for data: {event_data[:100]}"
|
|
||||||
)
|
|
||||||
metrics.record_error("invalid_json")
|
metrics.record_error("invalid_json")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None:
|
|||||||
|
|
||||||
# Periodic stats reporting
|
# Periodic stats reporting
|
||||||
def report_stats() -> None:
|
def report_stats() -> None:
|
||||||
if not hasattr(environment, 'runner'):
|
if not hasattr(environment, "runner"):
|
||||||
return
|
return
|
||||||
runner = environment.runner
|
runner = environment.runner
|
||||||
while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]:
|
while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]:
|
||||||
time.sleep(5) # Report every 5 seconds
|
time.sleep(5) # Report every 5 seconds
|
||||||
if hasattr(runner, 'state') and runner.state == "running":
|
if hasattr(runner, "state") and runner.state == "running":
|
||||||
stats = metrics.get_stats()
|
stats = metrics.get_stats()
|
||||||
|
|
||||||
# Only log on master node in distributed mode
|
# Only log on master node in distributed mode
|
||||||
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True
|
is_master = (
|
||||||
|
not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
|
||||||
|
)
|
||||||
if is_master:
|
if is_master:
|
||||||
# Clear previous lines and show updated stats
|
# Clear previous lines and show updated stats
|
||||||
logger.info("\n" + "=" * 80)
|
logger.info("\n" + "=" * 80)
|
||||||
@@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
|
f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
|
||||||
)
|
)
|
||||||
logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)")
|
logger.info(
|
||||||
|
f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)"
|
||||||
|
)
|
||||||
logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")
|
logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")
|
||||||
|
|
||||||
# Inter-event latency
|
# Inter-event latency
|
||||||
if stats.inter_event_latency_avg > 0:
|
if stats.inter_event_latency_avg > 0:
|
||||||
logger.info("-" * 80)
|
logger.info("-" * 80)
|
||||||
logger.info(
|
logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}")
|
||||||
f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}"
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
|
f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
|
||||||
)
|
)
|
||||||
@@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None:
|
|||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
|
|
||||||
# Show Locust stats summary
|
# Show Locust stats summary
|
||||||
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
|
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
|
||||||
total = environment.stats.total
|
total = environment.stats.total
|
||||||
if hasattr(total, 'num_requests') and total.num_requests > 0:
|
if hasattr(total, "num_requests") and total.num_requests > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
|
f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
|
||||||
)
|
)
|
||||||
@@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
|
|||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("EVENTS")
|
logger.info("EVENTS")
|
||||||
logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}")
|
logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}")
|
||||||
logger.info(
|
logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s")
|
||||||
f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s"
|
logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s")
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("STREAM METRICS")
|
logger.info("STREAM METRICS")
|
||||||
logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
|
logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
|
||||||
logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
|
logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
|
||||||
logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
|
logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
|
||||||
logger.info(
|
logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}")
|
||||||
f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("INTER-EVENT LATENCY")
|
logger.info("INTER-EVENT LATENCY")
|
||||||
@@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
|
|||||||
logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
|
logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
|
||||||
logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
|
logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
|
||||||
logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
|
logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
|
||||||
logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})")
|
logger.info(
|
||||||
|
f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})"
|
||||||
|
)
|
||||||
logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")
|
logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")
|
||||||
|
|
||||||
# Error summary
|
# Error summary
|
||||||
@@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
|
|||||||
logger.info("=" * 80 + "\n")
|
logger.info("=" * 80 + "\n")
|
||||||
|
|
||||||
# Export machine-readable report (only on master node)
|
# Export machine-readable report (only on master node)
|
||||||
is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True
|
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
|
||||||
if is_master:
|
if is_master:
|
||||||
export_json_report(stats, test_duration, environment)
|
export_json_report(stats, test_duration, environment)
|
||||||
|
|
||||||
@@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: obj
|
|||||||
|
|
||||||
# Access environment.stats.total attributes safely
|
# Access environment.stats.total attributes safely
|
||||||
locust_stats: LocustStats | None = None
|
locust_stats: LocustStats | None = None
|
||||||
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
|
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
|
||||||
total = environment.stats.total
|
total = environment.stats.total
|
||||||
if hasattr(total, 'num_requests') and total.num_requests > 0:
|
if hasattr(total, "num_requests") and total.num_requests > 0:
|
||||||
locust_stats = LocustStats(
|
locust_stats = LocustStats(
|
||||||
total_requests=total.num_requests,
|
total_requests=total.num_requests,
|
||||||
total_failures=total.num_failures,
|
total_failures=total.num_failures,
|
||||||
|
|||||||
@@ -1,7 +1,15 @@
|
|||||||
from dify_client.client import (
|
from dify_client.client import (
|
||||||
ChatClient,
|
ChatClient,
|
||||||
CompletionClient,
|
CompletionClient,
|
||||||
WorkflowClient,
|
|
||||||
KnowledgeBaseClient,
|
|
||||||
DifyClient,
|
DifyClient,
|
||||||
|
KnowledgeBaseClient,
|
||||||
|
WorkflowClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChatClient",
|
||||||
|
"CompletionClient",
|
||||||
|
"DifyClient",
|
||||||
|
"KnowledgeBaseClient",
|
||||||
|
"WorkflowClient",
|
||||||
|
]
|
||||||
|
|||||||
@@ -8,16 +8,16 @@ class DifyClient:
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
|
||||||
def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False):
|
def _send_request(
|
||||||
|
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
|
||||||
|
):
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
url = f"{self.base_url}{endpoint}"
|
url = f"{self.base_url}{endpoint}"
|
||||||
response = requests.request(
|
response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
|
||||||
method, url, json=json, params=params, headers=headers, stream=stream
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -25,9 +25,7 @@ class DifyClient:
|
|||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||||
|
|
||||||
url = f"{self.base_url}{endpoint}"
|
url = f"{self.base_url}{endpoint}"
|
||||||
response = requests.request(
|
response = requests.request(method, url, data=data, headers=headers, files=files)
|
||||||
method, url, data=data, headers=headers, files=files
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -41,9 +39,7 @@ class DifyClient:
|
|||||||
|
|
||||||
def file_upload(self, user: str, files: dict):
|
def file_upload(self, user: str, files: dict):
|
||||||
data = {"user": user}
|
data = {"user": user}
|
||||||
return self._send_request_with_files(
|
return self._send_request_with_files("POST", "/files/upload", data=data, files=files)
|
||||||
"POST", "/files/upload", data=data, files=files
|
|
||||||
)
|
|
||||||
|
|
||||||
def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
||||||
data = {"text": text, "user": user, "streaming": streaming}
|
data = {"text": text, "user": user, "streaming": streaming}
|
||||||
@@ -55,7 +51,9 @@ class DifyClient:
|
|||||||
|
|
||||||
|
|
||||||
class CompletionClient(DifyClient):
|
class CompletionClient(DifyClient):
|
||||||
def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None):
|
def create_completion_message(
|
||||||
|
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
|
||||||
|
):
|
||||||
data = {
|
data = {
|
||||||
"inputs": inputs,
|
"inputs": inputs,
|
||||||
"response_mode": response_mode,
|
"response_mode": response_mode,
|
||||||
@@ -99,9 +97,7 @@ class ChatClient(DifyClient):
|
|||||||
|
|
||||||
def get_suggested(self, message_id: str, user: str):
|
def get_suggested(self, message_id: str, user: str):
|
||||||
params = {"user": user}
|
params = {"user": user}
|
||||||
return self._send_request(
|
return self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
|
||||||
"GET", f"/messages/{message_id}/suggested", params=params
|
|
||||||
)
|
|
||||||
|
|
||||||
def stop_message(self, task_id: str, user: str):
|
def stop_message(self, task_id: str, user: str):
|
||||||
data = {"user": user}
|
data = {"user": user}
|
||||||
@@ -112,10 +108,9 @@ class ChatClient(DifyClient):
|
|||||||
user: str,
|
user: str,
|
||||||
last_id: str | None = None,
|
last_id: str | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
pinned: bool | None = None
|
pinned: bool | None = None,
|
||||||
):
|
):
|
||||||
params = {"user": user, "last_id": last_id,
|
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
|
||||||
"limit": limit, "pinned": pinned}
|
|
||||||
return self._send_request("GET", "/conversations", params=params)
|
return self._send_request("GET", "/conversations", params=params)
|
||||||
|
|
||||||
def get_conversation_messages(
|
def get_conversation_messages(
|
||||||
@@ -123,7 +118,7 @@ class ChatClient(DifyClient):
|
|||||||
user: str,
|
user: str,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
first_id: str | None = None,
|
first_id: str | None = None,
|
||||||
limit: int | None = None
|
limit: int | None = None,
|
||||||
):
|
):
|
||||||
params = {"user": user}
|
params = {"user": user}
|
||||||
|
|
||||||
@@ -136,13 +131,9 @@ class ChatClient(DifyClient):
|
|||||||
|
|
||||||
return self._send_request("GET", "/messages", params=params)
|
return self._send_request("GET", "/messages", params=params)
|
||||||
|
|
||||||
def rename_conversation(
|
def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
|
||||||
self, conversation_id: str, name: str, auto_generate: bool, user: str
|
|
||||||
):
|
|
||||||
data = {"name": name, "auto_generate": auto_generate, "user": user}
|
data = {"name": name, "auto_generate": auto_generate, "user": user}
|
||||||
return self._send_request(
|
return self._send_request("POST", f"/conversations/{conversation_id}/name", data)
|
||||||
"POST", f"/conversations/{conversation_id}/name", data
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_conversation(self, conversation_id: str, user: str):
|
def delete_conversation(self, conversation_id: str, user: str):
|
||||||
data = {"user": user}
|
data = {"user": user}
|
||||||
@@ -155,9 +146,7 @@ class ChatClient(DifyClient):
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowClient(DifyClient):
|
class WorkflowClient(DifyClient):
|
||||||
def run(
|
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
|
||||||
self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"
|
|
||||||
):
|
|
||||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||||
return self._send_request("POST", "/workflows/run", data)
|
return self._send_request("POST", "/workflows/run", data)
|
||||||
|
|
||||||
@@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient):
|
|||||||
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
||||||
|
|
||||||
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
||||||
return self._send_request(
|
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
|
||||||
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_document_by_text(
|
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
|
||||||
self, name, text, extra_params: dict | None = None, **kwargs
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Create a document by text.
|
Create a document by text.
|
||||||
|
|
||||||
@@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient):
|
|||||||
data = {"name": name, "text": text}
|
data = {"name": name, "text": text}
|
||||||
if extra_params is not None and isinstance(extra_params, dict):
|
if extra_params is not None and isinstance(extra_params, dict):
|
||||||
data.update(extra_params)
|
data.update(extra_params)
|
||||||
url = (
|
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
||||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
|
||||||
)
|
|
||||||
return self._send_request("POST", url, json=data, **kwargs)
|
return self._send_request("POST", url, json=data, **kwargs)
|
||||||
|
|
||||||
def create_document_by_file(
|
def create_document_by_file(
|
||||||
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient):
|
|||||||
if original_document_id is not None:
|
if original_document_id is not None:
|
||||||
data["original_document_id"] = original_document_id
|
data["original_document_id"] = original_document_id
|
||||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||||
return self._send_request_with_files(
|
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||||
"POST", url, {"data": json.dumps(data)}, files
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_document_by_file(
|
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
|
||||||
self, document_id: str, file_path: str, extra_params: dict | None = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Update a document by file.
|
Update a document by file.
|
||||||
|
|
||||||
@@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient):
|
|||||||
data = {}
|
data = {}
|
||||||
if extra_params is not None and isinstance(extra_params, dict):
|
if extra_params is not None and isinstance(extra_params, dict):
|
||||||
data.update(extra_params)
|
data.update(extra_params)
|
||||||
url = (
|
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||||
)
|
|
||||||
return self._send_request_with_files(
|
|
||||||
"POST", url, {"data": json.dumps(data)}, files
|
|
||||||
)
|
|
||||||
|
|
||||||
def batch_indexing_status(self, batch_id: str, **kwargs):
|
def batch_indexing_status(self, batch_id: str, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
with open("README.md", "r", encoding="utf-8") as fh:
|
with open("README.md", encoding="utf-8") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__)
|
|||||||
class TestKnowledgeBaseClient(unittest.TestCase):
|
class TestKnowledgeBaseClient(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
|
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
|
||||||
self.README_FILE_PATH = os.path.abspath(
|
self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
|
||||||
os.path.join(FILE_PATH_BASE, "../README.md")
|
|
||||||
)
|
|
||||||
self.dataset_id = None
|
self.dataset_id = None
|
||||||
self.document_id = None
|
self.document_id = None
|
||||||
self.segment_id = None
|
self.segment_id = None
|
||||||
@@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
|
|||||||
|
|
||||||
def _get_dataset_kb_client(self):
|
def _get_dataset_kb_client(self):
|
||||||
self.assertIsNotNone(self.dataset_id)
|
self.assertIsNotNone(self.dataset_id)
|
||||||
return KnowledgeBaseClient(
|
return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
|
||||||
API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_001_create_dataset(self):
|
def test_001_create_dataset(self):
|
||||||
response = self.knowledge_base_client.create_dataset(name="test_dataset")
|
response = self.knowledge_base_client.create_dataset(name="test_dataset")
|
||||||
@@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
|
|||||||
def _test_004_update_document_by_text(self):
|
def _test_004_update_document_by_text(self):
|
||||||
client = self._get_dataset_kb_client()
|
client = self._get_dataset_kb_client()
|
||||||
self.assertIsNotNone(self.document_id)
|
self.assertIsNotNone(self.document_id)
|
||||||
response = client.update_document_by_text(
|
response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
|
||||||
self.document_id, "test_document_updated", "test_text_updated"
|
|
||||||
)
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertIn("document", data)
|
self.assertIn("document", data)
|
||||||
self.assertIn("batch", data)
|
self.assertIn("batch", data)
|
||||||
@@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
|
|||||||
def _test_006_update_document_by_file(self):
|
def _test_006_update_document_by_file(self):
|
||||||
client = self._get_dataset_kb_client()
|
client = self._get_dataset_kb_client()
|
||||||
self.assertIsNotNone(self.document_id)
|
self.assertIsNotNone(self.document_id)
|
||||||
response = client.update_document_by_file(
|
response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
|
||||||
self.document_id, self.README_FILE_PATH
|
|
||||||
)
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertIn("document", data)
|
self.assertIn("document", data)
|
||||||
self.assertIn("batch", data)
|
self.assertIn("batch", data)
|
||||||
@@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
|
|||||||
|
|
||||||
def _test_010_add_segments(self):
|
def _test_010_add_segments(self):
|
||||||
client = self._get_dataset_kb_client()
|
client = self._get_dataset_kb_client()
|
||||||
response = client.add_segments(
|
response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
|
||||||
self.document_id, [{"content": "test text segment 1"}]
|
|
||||||
)
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
self.assertIn("data", data)
|
self.assertIn("data", data)
|
||||||
self.assertGreater(len(data["data"]), 0)
|
self.assertGreater(len(data["data"]), 0)
|
||||||
@@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase):
|
|||||||
self.chat_client = ChatClient(API_KEY)
|
self.chat_client = ChatClient(API_KEY)
|
||||||
|
|
||||||
def test_create_chat_message(self):
|
def test_create_chat_message(self):
|
||||||
response = self.chat_client.create_chat_message(
|
response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
|
||||||
{}, "Hello, World!", "test_user"
|
|
||||||
)
|
|
||||||
self.assertIn("answer", response.text)
|
self.assertIn("answer", response.text)
|
||||||
|
|
||||||
def test_create_chat_message_with_vision_model_by_remote_url(self):
|
def test_create_chat_message_with_vision_model_by_remote_url(self):
|
||||||
files = [
|
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
|
||||||
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
|
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
|
||||||
]
|
|
||||||
response = self.chat_client.create_chat_message(
|
|
||||||
{}, "Describe the picture.", "test_user", files=files
|
|
||||||
)
|
|
||||||
self.assertIn("answer", response.text)
|
self.assertIn("answer", response.text)
|
||||||
|
|
||||||
def test_create_chat_message_with_vision_model_by_local_file(self):
|
def test_create_chat_message_with_vision_model_by_local_file(self):
|
||||||
@@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase):
|
|||||||
"upload_file_id": "your_file_id",
|
"upload_file_id": "your_file_id",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
response = self.chat_client.create_chat_message(
|
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
|
||||||
{}, "Describe the picture.", "test_user", files=files
|
|
||||||
)
|
|
||||||
self.assertIn("answer", response.text)
|
self.assertIn("answer", response.text)
|
||||||
|
|
||||||
def test_get_conversation_messages(self):
|
def test_get_conversation_messages(self):
|
||||||
response = self.chat_client.get_conversation_messages(
|
response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id")
|
||||||
"test_user", "your_conversation_id"
|
|
||||||
)
|
|
||||||
self.assertIn("answer", response.text)
|
self.assertIn("answer", response.text)
|
||||||
|
|
||||||
def test_get_conversations(self):
|
def test_get_conversations(self):
|
||||||
@@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase):
|
|||||||
self.assertIn("answer", response.text)
|
self.assertIn("answer", response.text)
|
||||||
|
|
||||||
def test_create_completion_message_with_vision_model_by_remote_url(self):
|
def test_create_completion_message_with_vision_model_by_remote_url(self):
|
||||||
files = [
|
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
|
||||||
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
|
|
||||||
]
|
|
||||||
response = self.completion_client.create_completion_message(
|
response = self.completion_client.create_completion_message(
|
||||||
{"query": "Describe the picture."}, "blocking", "test_user", files
|
{"query": "Describe the picture."}, "blocking", "test_user", files
|
||||||
)
|
)
|
||||||
@@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase):
|
|||||||
self.dify_client = DifyClient(API_KEY)
|
self.dify_client = DifyClient(API_KEY)
|
||||||
|
|
||||||
def test_message_feedback(self):
|
def test_message_feedback(self):
|
||||||
response = self.dify_client.message_feedback(
|
response = self.dify_client.message_feedback("your_message_id", "like", "test_user")
|
||||||
"your_message_id", "like", "test_user"
|
|
||||||
)
|
|
||||||
self.assertIn("success", response.text)
|
self.assertIn("success", response.text)
|
||||||
|
|
||||||
def test_get_application_parameters(self):
|
def test_get_application_parameters(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user