refactor: refactor python sdk (#28118)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2025-11-20 11:10:53 +08:00
committed by GitHub
parent a1b735a4c0
commit 99e9fc751b
14 changed files with 4551 additions and 106 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
"""Base client with common functionality for both sync and async clients."""
import json
import time
import logging
from typing import Dict, Callable, Optional
try:
# Python 3.10+
from typing import ParamSpec
except ImportError:
# Python < 3.10
from typing_extensions import ParamSpec
from urllib.parse import urljoin
import httpx
P = ParamSpec("P")
from .exceptions import (
DifyClientError,
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
NetworkError,
TimeoutError,
)
class BaseClientMixin:
"""Mixin class providing common functionality for Dify clients."""
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
timeout: float = 60.0,
max_retries: int = 3,
retry_delay: float = 1.0,
enable_logging: bool = False,
):
"""Initialize the base client.
Args:
api_key: Your Dify API key
base_url: Base URL for the Dify API
timeout: Request timeout in seconds
max_retries: Maximum number of retry attempts
retry_delay: Delay between retries in seconds
enable_logging: Enable detailed logging
"""
if not api_key:
raise ValidationError("API key is required")
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.enable_logging = enable_logging
# Setup logging
self.logger = logging.getLogger(f"dify_client.{self.__class__.__name__.lower()}")
if enable_logging and not self.logger.handlers:
# Create console handler with formatter
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
self.enable_logging = True
else:
self.enable_logging = enable_logging
def _get_headers(self, content_type: str = "application/json") -> Dict[str, str]:
"""Get common request headers."""
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": content_type,
"User-Agent": "dify-client-python/0.1.12",
}
def _build_url(self, endpoint: str) -> str:
"""Build full URL from endpoint."""
return urljoin(self.base_url + "/", endpoint.lstrip("/"))
def _handle_response(self, response: httpx.Response) -> httpx.Response:
"""Handle HTTP response and raise appropriate exceptions."""
try:
if response.status_code == 401:
raise AuthenticationError(
"Authentication failed. Check your API key.",
status_code=response.status_code,
response=response.json() if response.content else None,
)
elif response.status_code == 429:
retry_after = response.headers.get("Retry-After")
raise RateLimitError(
"Rate limit exceeded. Please try again later.",
retry_after=int(retry_after) if retry_after else None,
)
elif response.status_code >= 400:
try:
error_data = response.json()
message = error_data.get("message", f"HTTP {response.status_code}")
except:
message = f"HTTP {response.status_code}: {response.text}"
raise APIError(
message,
status_code=response.status_code,
response=response.json() if response.content else None,
)
return response
except json.JSONDecodeError:
raise APIError(
f"Invalid JSON response: {response.text}",
status_code=response.status_code,
)
def _retry_request(
self,
request_func: Callable[P, httpx.Response],
request_context: str | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> httpx.Response:
"""Retry a request with exponential backoff.
Args:
request_func: Function that performs the HTTP request
request_context: Context description for logging (e.g., "GET /v1/messages")
*args: Positional arguments to pass to request_func
**kwargs: Keyword arguments to pass to request_func
Returns:
httpx.Response: Successful response
Raises:
NetworkError: On network failures after retries
TimeoutError: On timeout failures after retries
APIError: On API errors (4xx/5xx responses)
DifyClientError: On unexpected failures
"""
last_exception = None
for attempt in range(self.max_retries + 1):
try:
response = request_func(*args, **kwargs)
return response # Let caller handle response processing
except (httpx.NetworkError, httpx.TimeoutException) as e:
last_exception = e
context_msg = f" {request_context}" if request_context else ""
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
self.logger.warning(
f"Request failed{context_msg} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
f"Retrying in {delay:.2f} seconds..."
)
time.sleep(delay)
else:
self.logger.error(f"Request failed{context_msg} after {self.max_retries + 1} attempts: {e}")
# Convert to custom exceptions
if isinstance(e, httpx.TimeoutException):
from .exceptions import TimeoutError
raise TimeoutError(f"Request timed out after {self.max_retries} retries{context_msg}") from e
else:
from .exceptions import NetworkError
raise NetworkError(
f"Network error after {self.max_retries} retries{context_msg}: {str(e)}"
) from e
if last_exception:
raise last_exception
raise DifyClientError("Request failed after retries")
def _validate_params(self, **params) -> None:
"""Validate request parameters."""
for key, value in params.items():
if value is None:
continue
# String validations
if isinstance(value, str):
if not value.strip():
raise ValidationError(f"Parameter '{key}' cannot be empty or whitespace only")
if len(value) > 10000:
raise ValidationError(f"Parameter '{key}' exceeds maximum length of 10000 characters")
# List validations
elif isinstance(value, list):
if len(value) > 1000:
raise ValidationError(f"Parameter '{key}' exceeds maximum size of 1000 items")
# Dictionary validations
elif isinstance(value, dict):
if len(value) > 100:
raise ValidationError(f"Parameter '{key}' exceeds maximum size of 100 items")
# Type-specific validations
if key == "user" and not isinstance(value, str):
raise ValidationError(f"Parameter '{key}' must be a string")
elif key in ["page", "limit", "page_size"] and not isinstance(value, int):
raise ValidationError(f"Parameter '{key}' must be an integer")
elif key == "files" and not isinstance(value, (list, dict)):
raise ValidationError(f"Parameter '{key}' must be a list or dict")
elif key == "rating" and value not in ["like", "dislike"]:
raise ValidationError(f"Parameter '{key}' must be 'like' or 'dislike'")
def _log_request(self, method: str, url: str, **kwargs) -> None:
"""Log request details."""
self.logger.info(f"Making {method} request to {url}")
if kwargs.get("json"):
self.logger.debug(f"Request body: {kwargs['json']}")
if kwargs.get("params"):
self.logger.debug(f"Query params: {kwargs['params']}")
def _log_response(self, response: httpx.Response) -> None:
"""Log response details."""
self.logger.info(f"Received response: {response.status_code} ({len(response.content)} bytes)")

View File

@@ -1,11 +1,20 @@
import json import json
import logging
import os import os
from typing import Literal, Dict, List, Any, IO from typing import Literal, Dict, List, Any, IO, Optional, Union
import httpx import httpx
from .base_client import BaseClientMixin
from .exceptions import (
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
FileUploadError,
)
class DifyClient: class DifyClient(BaseClientMixin):
"""Synchronous Dify API client. """Synchronous Dify API client.
This client uses httpx.Client for efficient connection pooling and resource management. This client uses httpx.Client for efficient connection pooling and resource management.
@@ -21,6 +30,9 @@ class DifyClient:
api_key: str, api_key: str,
base_url: str = "https://api.dify.ai/v1", base_url: str = "https://api.dify.ai/v1",
timeout: float = 60.0, timeout: float = 60.0,
max_retries: int = 3,
retry_delay: float = 1.0,
enable_logging: bool = False,
): ):
"""Initialize the Dify client. """Initialize the Dify client.
@@ -28,9 +40,13 @@ class DifyClient:
api_key: Your Dify API key api_key: Your Dify API key
base_url: Base URL for the Dify API base_url: Base URL for the Dify API
timeout: Request timeout in seconds (default: 60.0) timeout: Request timeout in seconds (default: 60.0)
max_retries: Maximum number of retry attempts (default: 3)
retry_delay: Delay between retries in seconds (default: 1.0)
enable_logging: Whether to enable request logging (default: True)
""" """
self.api_key = api_key # Initialize base client functionality
self.base_url = base_url BaseClientMixin.__init__(self, api_key, base_url, timeout, max_retries, retry_delay, enable_logging)
self._client = httpx.Client( self._client = httpx.Client(
base_url=base_url, base_url=base_url,
timeout=httpx.Timeout(timeout, connect=5.0), timeout=httpx.Timeout(timeout, connect=5.0),
@@ -53,12 +69,12 @@ class DifyClient:
self, self,
method: str, method: str,
endpoint: str, endpoint: str,
json: dict | None = None, json: Dict[str, Any] | None = None,
params: dict | None = None, params: Dict[str, Any] | None = None,
stream: bool = False, stream: bool = False,
**kwargs, **kwargs,
): ):
"""Send an HTTP request to the Dify API. """Send an HTTP request to the Dify API with retry logic.
Args: Args:
method: HTTP method (GET, POST, PUT, PATCH, DELETE) method: HTTP method (GET, POST, PUT, PATCH, DELETE)
@@ -71,23 +87,91 @@ class DifyClient:
Returns: Returns:
httpx.Response object httpx.Response object
""" """
# Validate parameters
if json:
self._validate_params(**json)
if params:
self._validate_params(**params)
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
# httpx.Client automatically prepends base_url def make_request():
response = self._client.request( """Inner function to perform the actual HTTP request."""
method, # Log request if logging is enabled
endpoint, if self.enable_logging:
json=json, self.logger.info(f"Sending {method} request to {endpoint}")
params=params, # Debug logging for detailed information
headers=headers, if self.logger.isEnabledFor(logging.DEBUG):
**kwargs, if json:
) self.logger.debug(f"Request body: {json}")
if params:
self.logger.debug(f"Request params: {params}")
# httpx.Client automatically prepends base_url
response = self._client.request(
method,
endpoint,
json=json,
params=params,
headers=headers,
**kwargs,
)
# Log response if logging is enabled
if self.enable_logging:
self.logger.info(f"Received response: {response.status_code}")
return response
# Use the retry mechanism from base client
request_context = f"{method} {endpoint}"
response = self._retry_request(make_request, request_context)
# Handle error responses (API errors don't retry)
self._handle_error_response(response)
return response return response
def _handle_error_response(self, response, is_upload_request: bool = False) -> None:
"""Handle HTTP error responses and raise appropriate exceptions."""
if response.status_code < 400:
return # Success response
try:
error_data = response.json()
message = error_data.get("message", f"HTTP {response.status_code}")
except (ValueError, KeyError):
message = f"HTTP {response.status_code}"
error_data = None
# Log error response if logging is enabled
if self.enable_logging:
self.logger.error(f"API error: {response.status_code} - {message}")
if response.status_code == 401:
raise AuthenticationError(message, response.status_code, error_data)
elif response.status_code == 429:
retry_after = response.headers.get("Retry-After")
raise RateLimitError(message, retry_after)
elif response.status_code == 422:
raise ValidationError(message, response.status_code, error_data)
elif response.status_code == 400:
# Check if this is a file upload error based on the URL or context
current_url = getattr(response, "url", "") or ""
if is_upload_request or "upload" in str(current_url).lower() or "files" in str(current_url).lower():
raise FileUploadError(message, response.status_code, error_data)
else:
raise APIError(message, response.status_code, error_data)
elif response.status_code >= 500:
# Server errors should raise APIError
raise APIError(message, response.status_code, error_data)
elif response.status_code >= 400:
raise APIError(message, response.status_code, error_data)
def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict):
"""Send an HTTP request with file uploads. """Send an HTTP request with file uploads.
@@ -102,6 +186,12 @@ class DifyClient:
""" """
headers = {"Authorization": f"Bearer {self.api_key}"} headers = {"Authorization": f"Bearer {self.api_key}"}
# Log file upload request if logging is enabled
if self.enable_logging:
self.logger.info(f"Sending {method} file upload request to {endpoint}")
self.logger.debug(f"Form data: {data}")
self.logger.debug(f"Files: {files}")
response = self._client.request( response = self._client.request(
method, method,
endpoint, endpoint,
@@ -110,9 +200,17 @@ class DifyClient:
files=files, files=files,
) )
# Log response if logging is enabled
if self.enable_logging:
self.logger.info(f"Received file upload response: {response.status_code}")
# Handle error responses
self._handle_error_response(response, is_upload_request=True)
return response return response
def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str):
self._validate_params(message_id=message_id, rating=rating, user=user)
data = {"rating": rating, "user": user} data = {"rating": rating, "user": user}
return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) return self._send_request("POST", f"/messages/{message_id}/feedbacks", data)
@@ -144,6 +242,72 @@ class DifyClient:
"""Get file preview by file ID.""" """Get file preview by file ID."""
return self._send_request("GET", f"/files/{file_id}/preview") return self._send_request("GET", f"/files/{file_id}/preview")
# App Configuration APIs
def get_app_site_config(self, app_id: str):
"""Get app site configuration.
Args:
app_id: ID of the app
Returns:
App site configuration
"""
url = f"/apps/{app_id}/site/config"
return self._send_request("GET", url)
def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]):
"""Update app site configuration.
Args:
app_id: ID of the app
config_data: Configuration data to update
Returns:
Updated app site configuration
"""
url = f"/apps/{app_id}/site/config"
return self._send_request("PUT", url, json=config_data)
def get_app_api_tokens(self, app_id: str):
"""Get API tokens for an app.
Args:
app_id: ID of the app
Returns:
List of API tokens
"""
url = f"/apps/{app_id}/api-tokens"
return self._send_request("GET", url)
def create_app_api_token(self, app_id: str, name: str, description: str | None = None):
"""Create a new API token for an app.
Args:
app_id: ID of the app
name: Name for the API token
description: Description for the API token (optional)
Returns:
Created API token information
"""
data = {"name": name, "description": description}
url = f"/apps/{app_id}/api-tokens"
return self._send_request("POST", url, json=data)
def delete_app_api_token(self, app_id: str, token_id: str):
"""Delete an API token.
Args:
app_id: ID of the app
token_id: ID of the token to delete
Returns:
Deletion result
"""
url = f"/apps/{app_id}/api-tokens/{token_id}"
return self._send_request("DELETE", url)
class CompletionClient(DifyClient): class CompletionClient(DifyClient):
def create_completion_message( def create_completion_message(
@@ -151,8 +315,16 @@ class CompletionClient(DifyClient):
inputs: dict, inputs: dict,
response_mode: Literal["blocking", "streaming"], response_mode: Literal["blocking", "streaming"],
user: str, user: str,
files: dict | None = None, files: Dict[str, Any] | None = None,
): ):
# Validate parameters
if not isinstance(inputs, dict):
raise ValidationError("inputs must be a dictionary")
if response_mode not in ["blocking", "streaming"]:
raise ValidationError("response_mode must be 'blocking' or 'streaming'")
self._validate_params(inputs=inputs, response_mode=response_mode, user=user)
data = { data = {
"inputs": inputs, "inputs": inputs,
"response_mode": response_mode, "response_mode": response_mode,
@@ -175,8 +347,18 @@ class ChatClient(DifyClient):
user: str, user: str,
response_mode: Literal["blocking", "streaming"] = "blocking", response_mode: Literal["blocking", "streaming"] = "blocking",
conversation_id: str | None = None, conversation_id: str | None = None,
files: dict | None = None, files: Dict[str, Any] | None = None,
): ):
# Validate parameters
if not isinstance(inputs, dict):
raise ValidationError("inputs must be a dictionary")
if not isinstance(query, str) or not query.strip():
raise ValidationError("query must be a non-empty string")
if response_mode not in ["blocking", "streaming"]:
raise ValidationError("response_mode must be 'blocking' or 'streaming'")
self._validate_params(inputs=inputs, query=query, user=user, response_mode=response_mode)
data = { data = {
"inputs": inputs, "inputs": inputs,
"query": query, "query": query,
@@ -238,7 +420,7 @@ class ChatClient(DifyClient):
data = {"user": user} data = {"user": user}
return self._send_request("DELETE", f"/conversations/{conversation_id}", data) return self._send_request("DELETE", f"/conversations/{conversation_id}", data)
def audio_to_text(self, audio_file: IO[bytes] | tuple, user: str): def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str):
data = {"user": user} data = {"user": user}
files = {"file": audio_file} files = {"file": audio_file}
return self._send_request_with_files("POST", "/audio-to-text", data, files) return self._send_request_with_files("POST", "/audio-to-text", data, files)
@@ -313,7 +495,48 @@ class ChatClient(DifyClient):
""" """
data = {"value": value, "user": user} data = {"value": value, "user": user}
url = f"/conversations/{conversation_id}/variables/{variable_id}" url = f"/conversations/{conversation_id}/variables/{variable_id}"
return self._send_request("PATCH", url, json=data) return self._send_request("PUT", url, json=data)
def delete_annotation_with_response(self, annotation_id: str):
"""Delete an annotation with full response handling."""
url = f"/apps/annotations/{annotation_id}"
return self._send_request("DELETE", url)
def list_conversation_variables_with_pagination(
self, conversation_id: str, user: str, page: int = 1, limit: int = 20
):
"""List conversation variables with pagination."""
params = {"page": page, "limit": limit, "user": user}
url = f"/conversations/{conversation_id}/variables"
return self._send_request("GET", url, params=params)
def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any):
"""Update a conversation variable with full response handling."""
data = {"value": value, "user": user}
url = f"/conversations/{conversation_id}/variables/{variable_id}"
return self._send_request("PUT", url, json=data)
# Enhanced Annotation APIs
def get_annotation_reply_job_status(self, action: str, job_id: str):
"""Get status of an annotation reply action job."""
url = f"/apps/annotation-reply/{action}/status/{job_id}"
return self._send_request("GET", url)
def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None):
"""List annotations with pagination."""
params = {"page": page, "limit": limit, "keyword": keyword}
return self._send_request("GET", "/apps/annotations", params=params)
def create_annotation_with_response(self, question: str, answer: str):
"""Create an annotation with full response handling."""
data = {"question": question, "answer": answer}
return self._send_request("POST", "/apps/annotations", json=data)
def update_annotation_with_response(self, annotation_id: str, question: str, answer: str):
"""Update an annotation with full response handling."""
data = {"question": question, "answer": answer}
url = f"/apps/annotations/{annotation_id}"
return self._send_request("PUT", url, json=data)
class WorkflowClient(DifyClient): class WorkflowClient(DifyClient):
@@ -376,6 +599,68 @@ class WorkflowClient(DifyClient):
stream=(response_mode == "streaming"), stream=(response_mode == "streaming"),
) )
# Enhanced Workflow APIs
def get_workflow_draft(self, app_id: str):
"""Get workflow draft configuration.
Args:
app_id: ID of the workflow app
Returns:
Workflow draft configuration
"""
url = f"/apps/{app_id}/workflow/draft"
return self._send_request("GET", url)
def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]):
"""Update workflow draft configuration.
Args:
app_id: ID of the workflow app
workflow_data: Workflow configuration data
Returns:
Updated workflow draft
"""
url = f"/apps/{app_id}/workflow/draft"
return self._send_request("PUT", url, json=workflow_data)
def publish_workflow(self, app_id: str):
"""Publish workflow from draft.
Args:
app_id: ID of the workflow app
Returns:
Published workflow information
"""
url = f"/apps/{app_id}/workflow/publish"
return self._send_request("POST", url)
def get_workflow_run_history(
self,
app_id: str,
page: int = 1,
limit: int = 20,
status: Literal["succeeded", "failed", "stopped"] | None = None,
):
"""Get workflow run history.
Args:
app_id: ID of the workflow app
page: Page number (default: 1)
limit: Number of items per page (default: 20)
status: Filter by status (optional)
Returns:
Paginated workflow run history
"""
params = {"page": page, "limit": limit}
if status:
params["status"] = status
url = f"/apps/{app_id}/workflow/runs"
return self._send_request("GET", url, params=params)
class WorkspaceClient(DifyClient): class WorkspaceClient(DifyClient):
"""Client for workspace-related operations.""" """Client for workspace-related operations."""
@@ -385,6 +670,41 @@ class WorkspaceClient(DifyClient):
url = f"/workspaces/current/models/model-types/{model_type}" url = f"/workspaces/current/models/model-types/{model_type}"
return self._send_request("GET", url) return self._send_request("GET", url)
def get_available_models_by_type(self, model_type: str):
"""Get available models by model type (enhanced version)."""
url = f"/workspaces/current/models/model-types/{model_type}"
return self._send_request("GET", url)
def get_model_providers(self):
"""Get all model providers."""
return self._send_request("GET", "/workspaces/current/model-providers")
def get_model_provider_models(self, provider_name: str):
"""Get models for a specific provider."""
url = f"/workspaces/current/model-providers/{provider_name}/models"
return self._send_request("GET", url)
def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]):
"""Validate model provider credentials."""
url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate"
return self._send_request("POST", url, json=credentials)
# File Management APIs
def get_file_info(self, file_id: str):
"""Get information about a specific file."""
url = f"/files/{file_id}/info"
return self._send_request("GET", url)
def get_file_download_url(self, file_id: str):
"""Get download URL for a file."""
url = f"/files/{file_id}/download-url"
return self._send_request("GET", url)
def delete_file(self, file_id: str):
"""Delete a file."""
url = f"/files/{file_id}"
return self._send_request("DELETE", url)
class KnowledgeBaseClient(DifyClient): class KnowledgeBaseClient(DifyClient):
def __init__( def __init__(
@@ -416,7 +736,7 @@ class KnowledgeBaseClient(DifyClient):
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs): def create_document_by_text(self, name, text, extra_params: Dict[str, Any] | None = None, **kwargs):
""" """
Create a document by text. Create a document by text.
@@ -458,7 +778,7 @@ class KnowledgeBaseClient(DifyClient):
document_id: str, document_id: str,
name: str, name: str,
text: str, text: str,
extra_params: dict | None = None, extra_params: Dict[str, Any] | None = None,
**kwargs, **kwargs,
): ):
""" """
@@ -497,7 +817,7 @@ class KnowledgeBaseClient(DifyClient):
self, self,
file_path: str, file_path: str,
original_document_id: str | None = None, original_document_id: str | None = None,
extra_params: dict | None = None, extra_params: Dict[str, Any] | None = None,
): ):
""" """
Create a document by file. Create a document by file.
@@ -537,7 +857,12 @@ class KnowledgeBaseClient(DifyClient):
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None): def update_document_by_file(
self,
document_id: str,
file_path: str,
extra_params: Dict[str, Any] | None = None,
):
""" """
Update a document by file. Update a document by file.
@@ -893,3 +1218,50 @@ class KnowledgeBaseClient(DifyClient):
url = f"/datasets/{ds_id}/documents/status/{action}" url = f"/datasets/{ds_id}/documents/status/{action}"
data = {"document_ids": document_ids} data = {"document_ids": document_ids}
return self._send_request("PATCH", url, json=data) return self._send_request("PATCH", url, json=data)
# Enhanced Dataset APIs
def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None):
"""Create a dataset from a predefined template.
Args:
template_name: Name of the template to use
name: Name for the new dataset
description: Description for the dataset (optional)
Returns:
Created dataset information
"""
data = {
"template_name": template_name,
"name": name,
"description": description,
}
return self._send_request("POST", "/datasets/from-template", json=data)
def duplicate_dataset(self, dataset_id: str, name: str):
"""Duplicate an existing dataset.
Args:
dataset_id: ID of dataset to duplicate
name: Name for duplicated dataset
Returns:
New dataset information
"""
data = {"name": name}
url = f"/datasets/{dataset_id}/duplicate"
return self._send_request("POST", url, json=data)
def list_conversation_variables_with_pagination(
self, conversation_id: str, user: str, page: int = 1, limit: int = 20
):
"""List conversation variables with pagination."""
params = {"page": page, "limit": limit, "user": user}
url = f"/conversations/{conversation_id}/variables"
return self._send_request("GET", url, params=params)
def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any):
"""Update a conversation variable with full response handling."""
data = {"value": value, "user": user}
url = f"/conversations/{conversation_id}/variables/{variable_id}"
return self._send_request("PUT", url, json=data)

View File

@@ -0,0 +1,71 @@
"""Custom exceptions for the Dify client."""
from typing import Optional, Dict, Any
class DifyClientError(Exception):
"""Base exception for all Dify client errors."""
def __init__(self, message: str, status_code: int | None = None, response: Dict[str, Any] | None = None):
super().__init__(message)
self.message = message
self.status_code = status_code
self.response = response
class APIError(DifyClientError):
"""Raised when the API returns an error response."""
def __init__(self, message: str, status_code: int, response: Dict[str, Any] | None = None):
super().__init__(message, status_code, response)
self.status_code = status_code
class AuthenticationError(DifyClientError):
"""Raised when authentication fails."""
pass
class RateLimitError(DifyClientError):
"""Raised when rate limit is exceeded."""
def __init__(self, message: str = "Rate limit exceeded", retry_after: int | None = None):
super().__init__(message)
self.retry_after = retry_after
class ValidationError(DifyClientError):
"""Raised when request validation fails."""
pass
class NetworkError(DifyClientError):
"""Raised when network-related errors occur."""
pass
class TimeoutError(DifyClientError):
"""Raised when request times out."""
pass
class FileUploadError(DifyClientError):
"""Raised when file upload fails."""
pass
class DatasetError(DifyClientError):
"""Raised when dataset operations fail."""
pass
class WorkflowError(DifyClientError):
"""Raised when workflow operations fail."""
pass

View File

@@ -0,0 +1,396 @@
"""Response models for the Dify client with proper type hints."""
from typing import Optional, List, Dict, Any, Literal, Union
from dataclasses import dataclass, field
from datetime import datetime
@dataclass
class BaseResponse:
"""Base response model."""
success: bool = True
message: str | None = None
@dataclass
class ErrorResponse(BaseResponse):
"""Error response model."""
error_code: str | None = None
details: Dict[str, Any] | None = None
success: bool = False
@dataclass
class FileInfo:
"""File information model."""
id: str
name: str
size: int
mime_type: str
url: str | None = None
created_at: datetime | None = None
@dataclass
class MessageResponse(BaseResponse):
"""Message response model."""
id: str = ""
answer: str = ""
conversation_id: str | None = None
created_at: int | None = None
metadata: Dict[str, Any] | None = None
files: List[Dict[str, Any]] | None = None
@dataclass
class ConversationResponse(BaseResponse):
"""Conversation response model."""
id: str = ""
name: str = ""
inputs: Dict[str, Any] | None = None
status: str | None = None
created_at: int | None = None
updated_at: int | None = None
@dataclass
class DatasetResponse(BaseResponse):
"""Dataset response model."""
id: str = ""
name: str = ""
description: str | None = None
permission: str | None = None
indexing_technique: str | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: Dict[str, Any] | None = None
document_count: int | None = None
word_count: int | None = None
app_count: int | None = None
created_at: int | None = None
updated_at: int | None = None
@dataclass
class DocumentResponse(BaseResponse):
"""Document response model."""
id: str = ""
name: str = ""
data_source_type: str | None = None
data_source_info: Dict[str, Any] | None = None
dataset_process_rule_id: str | None = None
batch: str | None = None
position: int | None = None
enabled: bool | None = None
disabled_at: float | None = None
disabled_by: str | None = None
archived: bool | None = None
archived_reason: str | None = None
archived_at: float | None = None
archived_by: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: Dict[str, Any] | None = None
created_at: float | None = None
updated_at: float | None = None
indexing_status: str | None = None
completed_at: float | None = None
paused_at: float | None = None
error: str | None = None
stopped_at: float | None = None
@dataclass
class DocumentSegmentResponse(BaseResponse):
"""Document segment response model."""
id: str = ""
position: int | None = None
document_id: str | None = None
content: str | None = None
answer: str | None = None
word_count: int | None = None
tokens: int | None = None
keywords: List[str] | None = None
index_node_id: str | None = None
index_node_hash: str | None = None
hit_count: int | None = None
enabled: bool | None = None
disabled_at: float | None = None
disabled_by: str | None = None
status: str | None = None
created_by: str | None = None
created_at: float | None = None
indexing_at: float | None = None
completed_at: float | None = None
error: str | None = None
stopped_at: float | None = None
@dataclass
class WorkflowRunResponse(BaseResponse):
"""Workflow run response model."""
id: str = ""
workflow_id: str | None = None
status: Literal["running", "succeeded", "failed", "stopped"] | None = None
inputs: Dict[str, Any] | None = None
outputs: Dict[str, Any] | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: float | None = None
finished_at: float | None = None
@dataclass
class ApplicationParametersResponse(BaseResponse):
"""Application parameters response model."""
opening_statement: str | None = None
suggested_questions: List[str] | None = None
speech_to_text: Dict[str, Any] | None = None
text_to_speech: Dict[str, Any] | None = None
retriever_resource: Dict[str, Any] | None = None
sensitive_word_avoidance: Dict[str, Any] | None = None
file_upload: Dict[str, Any] | None = None
system_parameters: Dict[str, Any] | None = None
user_input_form: List[Dict[str, Any]] | None = None
@dataclass
class AnnotationResponse(BaseResponse):
"""Annotation response model."""
id: str = ""
question: str = ""
answer: str = ""
content: str | None = None
created_at: float | None = None
updated_at: float | None = None
created_by: str | None = None
updated_by: str | None = None
hit_count: int | None = None
@dataclass
class PaginatedResponse(BaseResponse):
"""Paginated response model."""
data: List[Any] = field(default_factory=list)
has_more: bool = False
limit: int = 0
total: int = 0
page: int | None = None
@dataclass
class ConversationVariableResponse(BaseResponse):
"""Conversation variable response model."""
conversation_id: str = ""
variables: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class FileUploadResponse(BaseResponse):
"""File upload response model."""
id: str = ""
name: str = ""
size: int = 0
mime_type: str = ""
url: str | None = None
created_at: float | None = None
@dataclass
class AudioResponse(BaseResponse):
"""Audio generation/response model."""
audio: str | None = None # Base64 encoded audio data or URL
audio_url: str | None = None
duration: float | None = None
sample_rate: int | None = None
@dataclass
class SuggestedQuestionsResponse(BaseResponse):
"""Suggested questions response model."""
message_id: str = ""
questions: List[str] = field(default_factory=list)
@dataclass
class AppInfoResponse(BaseResponse):
"""App info response model."""
id: str = ""
name: str = ""
description: str | None = None
icon: str | None = None
icon_background: str | None = None
mode: str | None = None
tags: List[str] | None = None
enable_site: bool | None = None
enable_api: bool | None = None
api_token: str | None = None
@dataclass
class WorkspaceModelsResponse(BaseResponse):
"""Workspace models response model."""
models: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class HitTestingResponse(BaseResponse):
"""Hit testing response model."""
query: str = ""
records: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class DatasetTagsResponse(BaseResponse):
"""Dataset tags response model."""
tags: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class WorkflowLogsResponse(BaseResponse):
"""Workflow logs response model."""
logs: List[Dict[str, Any]] = field(default_factory=list)
total: int = 0
page: int = 0
limit: int = 0
has_more: bool = False
@dataclass
class ModelProviderResponse(BaseResponse):
"""Model provider response model."""
provider_name: str = ""
provider_type: str = ""
models: List[Dict[str, Any]] = field(default_factory=list)
is_enabled: bool = False
credentials: Dict[str, Any] | None = None
@dataclass
class FileInfoResponse(BaseResponse):
"""File info response model."""
id: str = ""
name: str = ""
size: int = 0
mime_type: str = ""
url: str | None = None
created_at: int | None = None
metadata: Dict[str, Any] | None = None
@dataclass
class WorkflowDraftResponse(BaseResponse):
"""Workflow draft response model."""
id: str = ""
app_id: str = ""
draft_data: Dict[str, Any] = field(default_factory=dict)
version: int = 0
created_at: int | None = None
updated_at: int | None = None
@dataclass
class ApiTokenResponse(BaseResponse):
"""API token response model."""
id: str = ""
name: str = ""
token: str = ""
description: str | None = None
created_at: int | None = None
last_used_at: int | None = None
is_active: bool = True
@dataclass
class JobStatusResponse(BaseResponse):
"""Job status response model."""
job_id: str = ""
job_status: str = ""
error_msg: str | None = None
progress: float | None = None
created_at: int | None = None
updated_at: int | None = None
@dataclass
class DatasetQueryResponse(BaseResponse):
"""Dataset query response model."""
query: str = ""
records: List[Dict[str, Any]] = field(default_factory=list)
total: int = 0
search_time: float | None = None
retrieval_model: Dict[str, Any] | None = None
@dataclass
class DatasetTemplateResponse(BaseResponse):
"""Dataset template response model."""
template_name: str = ""
display_name: str = ""
description: str = ""
category: str = ""
icon: str | None = None
config_schema: Dict[str, Any] = field(default_factory=dict)
# Type aliases for common response types
ResponseType = Union[
BaseResponse,
ErrorResponse,
MessageResponse,
ConversationResponse,
DatasetResponse,
DocumentResponse,
DocumentSegmentResponse,
WorkflowRunResponse,
ApplicationParametersResponse,
AnnotationResponse,
PaginatedResponse,
ConversationVariableResponse,
FileUploadResponse,
AudioResponse,
SuggestedQuestionsResponse,
AppInfoResponse,
WorkspaceModelsResponse,
HitTestingResponse,
DatasetTagsResponse,
WorkflowLogsResponse,
ModelProviderResponse,
FileInfoResponse,
WorkflowDraftResponse,
ApiTokenResponse,
JobStatusResponse,
DatasetQueryResponse,
DatasetTemplateResponse,
]

View File

@@ -0,0 +1,264 @@
"""
Advanced usage examples for the Dify Python SDK.
This example demonstrates:
- Error handling and retries
- Logging configuration
- Context managers
- Async usage
- File uploads
- Dataset management
"""
import asyncio
import logging
from pathlib import Path
from dify_client import (
ChatClient,
CompletionClient,
AsyncChatClient,
KnowledgeBaseClient,
DifyClient,
)
from dify_client.exceptions import (
APIError,
RateLimitError,
AuthenticationError,
DifyClientError,
)
def setup_logging():
"""Setup logging for the SDK."""
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
def example_chat_with_error_handling():
"""Example of chat with comprehensive error handling."""
api_key = "your-api-key-here"
try:
with ChatClient(api_key, enable_logging=True) as client:
# Simple chat message
response = client.create_chat_message(
inputs={}, query="Hello, how are you?", user="user-123", response_mode="blocking"
)
result = response.json()
print(f"Response: {result.get('answer')}")
except AuthenticationError as e:
print(f"Authentication failed: {e}")
print("Please check your API key")
except RateLimitError as e:
print(f"Rate limit exceeded: {e}")
if e.retry_after:
print(f"Retry after {e.retry_after} seconds")
except APIError as e:
print(f"API error: {e.message}")
print(f"Status code: {e.status_code}")
except DifyClientError as e:
print(f"Dify client error: {e}")
except Exception as e:
print(f"Unexpected error: {e}")
def example_completion_with_files():
"""Example of completion with file upload."""
api_key = "your-api-key-here"
with CompletionClient(api_key) as client:
# Upload an image file first
file_path = "path/to/your/image.jpg"
try:
with open(file_path, "rb") as f:
files = {"file": (Path(file_path).name, f, "image/jpeg")}
upload_response = client.file_upload("user-123", files)
upload_response.raise_for_status()
file_id = upload_response.json().get("id")
print(f"File uploaded with ID: {file_id}")
# Use the uploaded file in completion
files_list = [{"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}]
completion_response = client.create_completion_message(
inputs={"query": "Describe this image"}, response_mode="blocking", user="user-123", files=files_list
)
result = completion_response.json()
print(f"Completion result: {result.get('answer')}")
except FileNotFoundError:
print(f"File not found: {file_path}")
except Exception as e:
print(f"Error during file upload/completion: {e}")
def example_dataset_management():
"""Example of dataset management operations."""
api_key = "your-api-key-here"
with KnowledgeBaseClient(api_key) as kb_client:
try:
# Create a new dataset
create_response = kb_client.create_dataset(name="My Test Dataset")
create_response.raise_for_status()
dataset_id = create_response.json().get("id")
print(f"Created dataset with ID: {dataset_id}")
# Create a client with the dataset ID
dataset_client = KnowledgeBaseClient(api_key, dataset_id=dataset_id)
# Add a document by text
doc_response = dataset_client.create_document_by_text(
name="Test Document", text="This is a test document for the knowledge base."
)
doc_response.raise_for_status()
document_id = doc_response.json().get("document", {}).get("id")
print(f"Created document with ID: {document_id}")
# List documents
list_response = dataset_client.list_documents()
list_response.raise_for_status()
documents = list_response.json().get("data", [])
print(f"Dataset contains {len(documents)} documents")
# Update dataset configuration
update_response = dataset_client.update_dataset(
name="Updated Dataset Name", description="Updated description", indexing_technique="high_quality"
)
update_response.raise_for_status()
print("Dataset updated successfully")
except Exception as e:
print(f"Dataset management error: {e}")
async def example_async_chat():
"""Example of async chat usage."""
api_key = "your-api-key-here"
try:
async with AsyncChatClient(api_key) as client:
# Create chat message
response = await client.create_chat_message(
inputs={}, query="What's the weather like?", user="user-456", response_mode="blocking"
)
result = response.json()
print(f"Async response: {result.get('answer')}")
# Get conversations
conversations = await client.get_conversations("user-456")
conversations.raise_for_status()
conv_data = conversations.json()
print(f"Found {len(conv_data.get('data', []))} conversations")
except Exception as e:
print(f"Async chat error: {e}")
def example_streaming_response():
"""Example of handling streaming responses."""
api_key = "your-api-key-here"
with ChatClient(api_key) as client:
try:
response = client.create_chat_message(
inputs={}, query="Tell me a story", user="user-789", response_mode="streaming"
)
print("Streaming response:")
for line in response.iter_lines(decode_unicode=True):
if line.startswith("data:"):
data = line[5:].strip()
if data:
import json
try:
chunk = json.loads(data)
answer = chunk.get("answer", "")
if answer:
print(answer, end="", flush=True)
except json.JSONDecodeError:
continue
print() # New line after streaming
except Exception as e:
print(f"Streaming error: {e}")
def example_application_info():
"""Example of getting application information."""
api_key = "your-api-key-here"
with DifyClient(api_key) as client:
try:
# Get app info
info_response = client.get_app_info()
info_response.raise_for_status()
app_info = info_response.json()
print(f"App name: {app_info.get('name')}")
print(f"App mode: {app_info.get('mode')}")
print(f"App tags: {app_info.get('tags', [])}")
# Get app parameters
params_response = client.get_application_parameters("user-123")
params_response.raise_for_status()
params = params_response.json()
print(f"Opening statement: {params.get('opening_statement')}")
print(f"Suggested questions: {params.get('suggested_questions', [])}")
except Exception as e:
print(f"App info error: {e}")
def main():
"""Run all examples."""
setup_logging()
print("=== Dify Python SDK Advanced Usage Examples ===\n")
print("1. Chat with Error Handling:")
example_chat_with_error_handling()
print()
print("2. Completion with Files:")
example_completion_with_files()
print()
print("3. Dataset Management:")
example_dataset_management()
print()
print("4. Async Chat:")
asyncio.run(example_async_chat())
print()
print("5. Streaming Response:")
example_streaming_response()
print()
print("6. Application Info:")
example_application_info()
print()
print("All examples completed!")
if __name__ == "__main__":
main()

View File

@@ -5,7 +5,7 @@ description = "A package for interacting with the Dify Service-API"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"httpx>=0.27.0", "httpx[http2]>=0.27.0",
"aiofiles>=23.0.0", "aiofiles>=23.0.0",
] ]
authors = [ authors = [

View File

@@ -1,6 +1,7 @@
import os import os
import time import time
import unittest import unittest
from unittest.mock import Mock, patch, mock_open
from dify_client.client import ( from dify_client.client import (
ChatClient, ChatClient,
@@ -17,38 +18,46 @@ FILE_PATH_BASE = os.path.dirname(__file__)
class TestKnowledgeBaseClient(unittest.TestCase): class TestKnowledgeBaseClient(unittest.TestCase):
def setUp(self): def setUp(self):
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL) self.api_key = "test-api-key"
self.base_url = "https://api.dify.ai/v1"
self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url)
self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
self.dataset_id = None self.dataset_id = "test-dataset-id"
self.document_id = None self.document_id = "test-document-id"
self.segment_id = None self.segment_id = "test-segment-id"
self.batch_id = None self.batch_id = "test-batch-id"
def _get_dataset_kb_client(self): def _get_dataset_kb_client(self):
self.assertIsNotNone(self.dataset_id) return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id)
return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
@patch("dify_client.client.httpx.Client")
def test_001_create_dataset(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Re-create client with mocked httpx
self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url)
def test_001_create_dataset(self):
response = self.knowledge_base_client.create_dataset(name="test_dataset") response = self.knowledge_base_client.create_dataset(name="test_dataset")
data = response.json() data = response.json()
self.assertIn("id", data) self.assertIn("id", data)
self.dataset_id = data["id"]
self.assertEqual("test_dataset", data["name"]) self.assertEqual("test_dataset", data["name"])
# the following tests require to be executed in order because they use # the following tests require to be executed in order because they use
# the dataset/document/segment ids from the previous test # the dataset/document/segment ids from the previous test
self._test_002_list_datasets() self._test_002_list_datasets()
self._test_003_create_document_by_text() self._test_003_create_document_by_text()
time.sleep(1)
self._test_004_update_document_by_text() self._test_004_update_document_by_text()
# self._test_005_batch_indexing_status()
time.sleep(1)
self._test_006_update_document_by_file() self._test_006_update_document_by_file()
time.sleep(1)
self._test_007_list_documents() self._test_007_list_documents()
self._test_008_delete_document() self._test_008_delete_document()
self._test_009_create_document_by_file() self._test_009_create_document_by_file()
time.sleep(1)
self._test_010_add_segments() self._test_010_add_segments()
self._test_011_query_segments() self._test_011_query_segments()
self._test_012_update_document_segment() self._test_012_update_document_segment()
@@ -56,6 +65,12 @@ class TestKnowledgeBaseClient(unittest.TestCase):
self._test_014_delete_dataset() self._test_014_delete_dataset()
def _test_002_list_datasets(self): def _test_002_list_datasets(self):
# Mock the response - using the already mocked client from test_001_create_dataset
mock_response = Mock()
mock_response.json.return_value = {"data": [], "total": 0}
mock_response.status_code = 200
self.knowledge_base_client._client.request.return_value = mock_response
response = self.knowledge_base_client.list_datasets() response = self.knowledge_base_client.list_datasets()
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
@@ -63,45 +78,62 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_003_create_document_by_text(self): def _test_003_create_document_by_text(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.create_document_by_text("test_document", "test_text") response = client.create_document_by_text("test_document", "test_text")
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.document_id = data["document"]["id"]
self.batch_id = data["batch"]
def _test_004_update_document_by_text(self): def _test_004_update_document_by_text(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id) # Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.assertIn("batch", data) self.assertIn("batch", data)
self.batch_id = data["batch"]
def _test_005_batch_indexing_status(self):
client = self._get_dataset_kb_client()
response = client.batch_indexing_status(self.batch_id)
response.json()
self.assertEqual(response.status_code, 200)
def _test_006_update_document_by_file(self): def _test_006_update_document_by_file(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id) # Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.assertIn("batch", data) self.assertIn("batch", data)
self.batch_id = data["batch"]
def _test_007_list_documents(self): def _test_007_list_documents(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": []}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.list_documents() response = client.list_documents()
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
def _test_008_delete_document(self): def _test_008_delete_document(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id) # Mock the response
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.delete_document(self.document_id) response = client.delete_document(self.document_id)
data = response.json() data = response.json()
self.assertIn("result", data) self.assertIn("result", data)
@@ -109,23 +141,37 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_009_create_document_by_file(self): def _test_009_create_document_by_file(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.create_document_by_file(self.README_FILE_PATH) response = client.create_document_by_file(self.README_FILE_PATH)
data = response.json() data = response.json()
self.assertIn("document", data) self.assertIn("document", data)
self.document_id = data["document"]["id"]
self.batch_id = data["batch"]
def _test_010_add_segments(self): def _test_010_add_segments(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0) self.assertGreater(len(data["data"]), 0)
segment = data["data"][0]
self.segment_id = segment["id"]
def _test_011_query_segments(self): def _test_011_query_segments(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.query_segments(self.document_id) response = client.query_segments(self.document_id)
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
@@ -133,7 +179,12 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_012_update_document_segment(self): def _test_012_update_document_segment(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.segment_id) # Mock the response
mock_response = Mock()
mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.update_document_segment( response = client.update_document_segment(
self.document_id, self.document_id,
self.segment_id, self.segment_id,
@@ -141,13 +192,16 @@ class TestKnowledgeBaseClient(unittest.TestCase):
) )
data = response.json() data = response.json()
self.assertIn("data", data) self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0) self.assertEqual("test text segment 1 updated", data["data"]["content"])
segment = data["data"]
self.assertEqual("test text segment 1 updated", segment["content"])
def _test_013_delete_document_segment(self): def _test_013_delete_document_segment(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
self.assertIsNotNone(self.segment_id) # Mock the response
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
client._client.request.return_value = mock_response
response = client.delete_document_segment(self.document_id, self.segment_id) response = client.delete_document_segment(self.document_id, self.segment_id)
data = response.json() data = response.json()
self.assertIn("result", data) self.assertIn("result", data)
@@ -155,94 +209,279 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_014_delete_dataset(self): def _test_014_delete_dataset(self):
client = self._get_dataset_kb_client() client = self._get_dataset_kb_client()
# Mock the response
mock_response = Mock()
mock_response.status_code = 204
client._client.request.return_value = mock_response
response = client.delete_dataset() response = client.delete_dataset()
self.assertEqual(204, response.status_code) self.assertEqual(204, response.status_code)
class TestChatClient(unittest.TestCase): class TestChatClient(unittest.TestCase):
def setUp(self): @patch("dify_client.client.httpx.Client")
self.chat_client = ChatClient(API_KEY) def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.chat_client = ChatClient(self.api_key)
def test_create_chat_message(self): # Set up default mock response for the client
response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user") mock_response = Mock()
mock_response.text = '{"answer": "Hello! This is a test response."}'
mock_response.json.return_value = {"answer": "Hello! This is a test response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_create_chat_message(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "Hello! This is a test response."}'
mock_response.json.return_value = {"answer": "Hello! This is a test response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.create_chat_message({}, "Hello, World!", "test_user")
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_remote_url(self): @patch("dify_client.client.httpx.Client")
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client):
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) # Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "I can see this is a test image description."}'
mock_response.json.return_value = {"answer": "I can see this is a test image description."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}]
response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_local_file(self): @patch("dify_client.client.httpx.Client")
def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "I can see this is a test uploaded image."}'
mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
files = [ files = [
{ {
"type": "image", "type": "image",
"transfer_method": "local_file", "transfer_method": "local_file",
"upload_file_id": "your_file_id", "upload_file_id": "test-file-id",
} }
] ]
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_get_conversation_messages(self): @patch("dify_client.client.httpx.Client")
response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id") def test_get_conversation_messages(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "Here are the conversation messages."}'
mock_response.json.return_value = {"answer": "Here are the conversation messages."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.get_conversation_messages("test_user", "test-conversation-id")
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_get_conversations(self): @patch("dify_client.client.httpx.Client")
response = self.chat_client.get_conversations("test_user") def test_get_conversations(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}'
mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
chat_client = ChatClient(self.api_key)
response = chat_client.get_conversations("test_user")
self.assertIn("data", response.text) self.assertIn("data", response.text)
class TestCompletionClient(unittest.TestCase): class TestCompletionClient(unittest.TestCase):
def setUp(self): @patch("dify_client.client.httpx.Client")
self.completion_client = CompletionClient(API_KEY) def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.completion_client = CompletionClient(self.api_key)
def test_create_completion_message(self): # Set up default mock response for the client
response = self.completion_client.create_completion_message( mock_response = Mock()
mock_response.text = '{"answer": "This is a test completion response."}'
mock_response.json.return_value = {"answer": "This is a test completion response."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_create_completion_message(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}'
mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
response = completion_client.create_completion_message(
{"query": "What's the weather like today?"}, "blocking", "test_user" {"query": "What's the weather like today?"}, "blocking", "test_user"
) )
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_completion_message_with_vision_model_by_remote_url(self): @patch("dify_client.client.httpx.Client")
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}] def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client):
response = self.completion_client.create_completion_message( # Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "This is a test image description from completion API."}'
mock_response.json.return_value = {"answer": "This is a test image description from completion API."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}]
response = completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files {"query": "Describe the picture."}, "blocking", "test_user", files
) )
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
def test_create_completion_message_with_vision_model_by_local_file(self): @patch("dify_client.client.httpx.Client")
def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}'
mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
completion_client = CompletionClient(self.api_key)
files = [ files = [
{ {
"type": "image", "type": "image",
"transfer_method": "local_file", "transfer_method": "local_file",
"upload_file_id": "your_file_id", "upload_file_id": "test-file-id",
} }
] ]
response = self.completion_client.create_completion_message( response = completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files {"query": "Describe the picture."}, "blocking", "test_user", files
) )
self.assertIn("answer", response.text) self.assertIn("answer", response.text)
class TestDifyClient(unittest.TestCase): class TestDifyClient(unittest.TestCase):
def setUp(self): @patch("dify_client.client.httpx.Client")
self.dify_client = DifyClient(API_KEY) def setUp(self, mock_httpx_client):
self.api_key = "test-api-key"
self.dify_client = DifyClient(self.api_key)
def test_message_feedback(self): # Set up default mock response for the client
response = self.dify_client.message_feedback("your_message_id", "like", "test_user") mock_response = Mock()
mock_response.text = '{"result": "success"}'
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
@patch("dify_client.client.httpx.Client")
def test_message_feedback(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"success": true}'
mock_response.json.return_value = {"success": True}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
response = dify_client.message_feedback("test-message-id", "like", "test_user")
self.assertIn("success", response.text) self.assertIn("success", response.text)
def test_get_application_parameters(self): @patch("dify_client.client.httpx.Client")
response = self.dify_client.get_application_parameters("test_user") def test_get_application_parameters(self, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}'
mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
response = dify_client.get_application_parameters("test_user")
self.assertIn("user_input_form", response.text) self.assertIn("user_input_form", response.text)
def test_file_upload(self): @patch("dify_client.client.httpx.Client")
file_path = "your_image_file_path" @patch("builtins.open", new_callable=mock_open, read_data=b"fake image data")
def test_file_upload(self, mock_file_open, mock_httpx_client):
# Mock the HTTP response
mock_response = Mock()
mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}'
mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
# Create client with mocked httpx
dify_client = DifyClient(self.api_key)
file_path = "/path/to/test/panda.jpeg"
file_name = "panda.jpeg" file_name = "panda.jpeg"
mime_type = "image/jpeg" mime_type = "image/jpeg"
with open(file_path, "rb") as file: with open(file_path, "rb") as file:
files = {"file": (file_name, file, mime_type)} files = {"file": (file_name, file, mime_type)}
response = self.dify_client.file_upload("test_user", files) response = dify_client.file_upload("test_user", files)
self.assertIn("name", response.text) self.assertIn("name", response.text)

View File

@@ -0,0 +1,79 @@
"""Tests for custom exceptions."""
import unittest
from dify_client.exceptions import (
DifyClientError,
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
NetworkError,
TimeoutError,
FileUploadError,
DatasetError,
WorkflowError,
)
class TestExceptions(unittest.TestCase):
"""Test custom exception classes."""
def test_base_exception(self):
"""Test base DifyClientError."""
error = DifyClientError("Test message", 500, {"error": "details"})
self.assertEqual(str(error), "Test message")
self.assertEqual(error.status_code, 500)
self.assertEqual(error.response, {"error": "details"})
def test_api_error(self):
"""Test APIError."""
error = APIError("API failed", 400)
self.assertEqual(error.status_code, 400)
self.assertEqual(error.message, "API failed")
def test_authentication_error(self):
"""Test AuthenticationError."""
error = AuthenticationError("Invalid API key")
self.assertEqual(str(error), "Invalid API key")
def test_rate_limit_error(self):
"""Test RateLimitError."""
error = RateLimitError("Rate limited", retry_after=60)
self.assertEqual(error.retry_after, 60)
error_default = RateLimitError()
self.assertEqual(error_default.retry_after, None)
def test_validation_error(self):
"""Test ValidationError."""
error = ValidationError("Invalid parameter")
self.assertEqual(str(error), "Invalid parameter")
def test_network_error(self):
"""Test NetworkError."""
error = NetworkError("Connection failed")
self.assertEqual(str(error), "Connection failed")
def test_timeout_error(self):
"""Test TimeoutError."""
error = TimeoutError("Request timed out")
self.assertEqual(str(error), "Request timed out")
def test_file_upload_error(self):
"""Test FileUploadError."""
error = FileUploadError("Upload failed")
self.assertEqual(str(error), "Upload failed")
def test_dataset_error(self):
"""Test DatasetError."""
error = DatasetError("Dataset operation failed")
self.assertEqual(str(error), "Dataset operation failed")
def test_workflow_error(self):
"""Test WorkflowError."""
error = WorkflowError("Workflow failed")
self.assertEqual(str(error), "Workflow failed")
if __name__ == "__main__":
unittest.main()

View File

@@ -152,6 +152,7 @@ class TestHttpxMigrationMocked(unittest.TestCase):
"""Test that json parameter is passed correctly.""" """Test that json parameter is passed correctly."""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"result": "success"} mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200 # Add status_code attribute
mock_client_instance = Mock() mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response mock_client_instance.request.return_value = mock_response
@@ -173,6 +174,7 @@ class TestHttpxMigrationMocked(unittest.TestCase):
"""Test that params parameter is passed correctly.""" """Test that params parameter is passed correctly."""
mock_response = Mock() mock_response = Mock()
mock_response.json.return_value = {"result": "success"} mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200 # Add status_code attribute
mock_client_instance = Mock() mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response mock_client_instance.request.return_value = mock_response

View File

@@ -0,0 +1,539 @@
"""Integration tests with proper mocking."""
import unittest
from unittest.mock import Mock, patch, MagicMock
import json
import httpx
from dify_client import (
DifyClient,
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
WorkspaceClient,
)
from dify_client.exceptions import (
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
)
class TestDifyClientIntegration(unittest.TestCase):
"""Integration tests for DifyClient with mocked HTTP responses."""
def setUp(self):
self.api_key = "test_api_key"
self.base_url = "https://api.dify.ai/v1"
self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False)
@patch("httpx.Client.request")
def test_get_app_info_integration(self, mock_request):
"""Test get_app_info integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "app_123",
"name": "Test App",
"description": "A test application",
"mode": "chat",
}
mock_request.return_value = mock_response
response = self.client.get_app_info()
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["id"], "app_123")
self.assertEqual(data["name"], "Test App")
mock_request.assert_called_once_with(
"GET",
"/info",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
@patch("httpx.Client.request")
def test_get_application_parameters_integration(self, mock_request):
"""Test get_application_parameters integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"opening_statement": "Hello! How can I help you?",
"suggested_questions": ["What is AI?", "How does this work?"],
"speech_to_text": {"enabled": True},
"text_to_speech": {"enabled": False},
}
mock_request.return_value = mock_response
response = self.client.get_application_parameters("user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["opening_statement"], "Hello! How can I help you?")
self.assertEqual(len(data["suggested_questions"]), 2)
mock_request.assert_called_once_with(
"GET",
"/parameters",
json=None,
params={"user": "user_123"},
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
@patch("httpx.Client.request")
def test_file_upload_integration(self, mock_request):
"""Test file_upload integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "file_123",
"name": "test.txt",
"size": 1024,
"mime_type": "text/plain",
}
mock_request.return_value = mock_response
files = {"file": ("test.txt", "test content", "text/plain")}
response = self.client.file_upload("user_123", files)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["id"], "file_123")
self.assertEqual(data["name"], "test.txt")
@patch("httpx.Client.request")
def test_message_feedback_integration(self, mock_request):
"""Test message_feedback integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"success": True}
mock_request.return_value = mock_response
response = self.client.message_feedback("msg_123", "like", "user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertTrue(data["success"])
mock_request.assert_called_once_with(
"POST",
"/messages/msg_123/feedbacks",
json={"rating": "like", "user": "user_123"},
params=None,
headers={
"Authorization": "Bearer test_api_key",
"Content-Type": "application/json",
},
)
class TestChatClientIntegration(unittest.TestCase):
"""Integration tests for ChatClient."""
def setUp(self):
self.client = ChatClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_create_chat_message_blocking(self, mock_request):
"""Test create_chat_message with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "msg_123",
"answer": "Hello! How can I help you today?",
"conversation_id": "conv_123",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_chat_message(
inputs={"query": "Hello"},
query="Hello, AI!",
user="user_123",
response_mode="blocking",
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["answer"], "Hello! How can I help you today?")
self.assertEqual(data["conversation_id"], "conv_123")
@patch("httpx.Client.request")
def test_create_chat_message_streaming(self, mock_request):
"""Test create_chat_message with streaming response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = [
b'data: {"answer": "Hello"}',
b'data: {"answer": " world"}',
b'data: {"answer": "!"}',
]
mock_request.return_value = mock_response
response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming")
self.assertEqual(response.status_code, 200)
lines = list(response.iter_lines())
self.assertEqual(len(lines), 3)
@patch("httpx.Client.request")
def test_get_conversations_integration(self, mock_request):
"""Test get_conversations integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "conv_1", "name": "Conversation 1"},
{"id": "conv_2", "name": "Conversation 2"},
],
"has_more": False,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.get_conversations("user_123", limit=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
self.assertEqual(data["data"][0]["name"], "Conversation 1")
@patch("httpx.Client.request")
def test_get_conversation_messages_integration(self, mock_request):
"""Test get_conversation_messages integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "msg_1", "role": "user", "content": "Hello"},
{"id": "msg_2", "role": "assistant", "content": "Hi there!"},
]
}
mock_request.return_value = mock_response
response = self.client.get_conversation_messages("user_123", conversation_id="conv_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
self.assertEqual(data["data"][0]["role"], "user")
class TestCompletionClientIntegration(unittest.TestCase):
"""Integration tests for CompletionClient."""
def setUp(self):
self.client = CompletionClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_create_completion_message_blocking(self, mock_request):
"""Test create_completion_message with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "comp_123",
"answer": "This is a completion response.",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_completion_message(
inputs={"prompt": "Complete this sentence"},
response_mode="blocking",
user="user_123",
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["answer"], "This is a completion response.")
@patch("httpx.Client.request")
def test_create_completion_message_with_files(self, mock_request):
"""Test create_completion_message with files."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "comp_124",
"answer": "I can see the image shows...",
"files": [{"id": "file_1", "type": "image"}],
}
mock_request.return_value = mock_response
files = {
"file": {
"type": "image",
"transfer_method": "remote_url",
"url": "https://example.com/image.jpg",
}
}
response = self.client.create_completion_message(
inputs={"prompt": "Describe this image"},
response_mode="blocking",
user="user_123",
files=files,
)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertIn("image", data["answer"])
self.assertEqual(len(data["files"]), 1)
class TestWorkflowClientIntegration(unittest.TestCase):
"""Integration tests for WorkflowClient."""
def setUp(self):
self.client = WorkflowClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_run_workflow_blocking(self, mock_request):
"""Test run workflow with blocking response."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "run_123",
"workflow_id": "workflow_123",
"status": "succeeded",
"inputs": {"query": "Test input"},
"outputs": {"result": "Test output"},
"elapsed_time": 2.5,
}
mock_request.return_value = mock_response
response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["status"], "succeeded")
self.assertEqual(data["outputs"]["result"], "Test output")
@patch("httpx.Client.request")
def test_get_workflow_logs(self, mock_request):
"""Test get_workflow_logs integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"logs": [
{"id": "log_1", "status": "succeeded", "created_at": 1234567890},
{"id": "log_2", "status": "failed", "created_at": 1234567891},
],
"total": 2,
"page": 1,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.get_workflow_logs(page=1, limit=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["logs"]), 2)
self.assertEqual(data["logs"][0]["status"], "succeeded")
class TestKnowledgeBaseClientIntegration(unittest.TestCase):
"""Integration tests for KnowledgeBaseClient."""
def setUp(self):
self.client = KnowledgeBaseClient("test_api_key")
@patch("httpx.Client.request")
def test_create_dataset(self, mock_request):
"""Test create_dataset integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"id": "dataset_123",
"name": "Test Dataset",
"description": "A test dataset",
"created_at": 1234567890,
}
mock_request.return_value = mock_response
response = self.client.create_dataset(name="Test Dataset")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["name"], "Test Dataset")
self.assertEqual(data["id"], "dataset_123")
@patch("httpx.Client.request")
def test_list_datasets(self, mock_request):
"""Test list_datasets integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "dataset_1", "name": "Dataset 1"},
{"id": "dataset_2", "name": "Dataset 2"},
],
"has_more": False,
"limit": 20,
}
mock_request.return_value = mock_response
response = self.client.list_datasets(page=1, page_size=20)
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["data"]), 2)
@patch("httpx.Client.request")
def test_create_document_by_text(self, mock_request):
"""Test create_document_by_text integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"document": {
"id": "doc_123",
"name": "Test Document",
"word_count": 100,
"status": "indexing",
}
}
mock_request.return_value = mock_response
# Mock dataset_id
self.client.dataset_id = "dataset_123"
response = self.client.create_document_by_text(name="Test Document", text="This is test document content.")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(data["document"]["name"], "Test Document")
self.assertEqual(data["document"]["word_count"], 100)
class TestWorkspaceClientIntegration(unittest.TestCase):
"""Integration tests for WorkspaceClient."""
def setUp(self):
self.client = WorkspaceClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_get_available_models(self, mock_request):
"""Test get_available_models integration."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"models": [
{"id": "gpt-4", "name": "GPT-4", "provider": "openai"},
{"id": "claude-3", "name": "Claude 3", "provider": "anthropic"},
]
}
mock_request.return_value = mock_response
response = self.client.get_available_models("llm")
data = response.json()
self.assertEqual(response.status_code, 200)
self.assertEqual(len(data["models"]), 2)
self.assertEqual(data["models"][0]["id"], "gpt-4")
class TestErrorScenariosIntegration(unittest.TestCase):
"""Integration tests for error scenarios."""
def setUp(self):
self.client = DifyClient("test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_authentication_error_integration(self, mock_request):
"""Test authentication error in integration."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Invalid API key"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Invalid API key")
self.assertEqual(context.exception.status_code, 401)
@patch("httpx.Client.request")
def test_rate_limit_error_integration(self, mock_request):
"""Test rate limit error in integration."""
mock_response = Mock()
mock_response.status_code = 429
mock_response.json.return_value = {"message": "Rate limit exceeded"}
mock_response.headers = {"Retry-After": "60"}
mock_request.return_value = mock_response
with self.assertRaises(RateLimitError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Rate limit exceeded")
self.assertEqual(context.exception.retry_after, "60")
@patch("httpx.Client.request")
def test_server_error_with_retry_integration(self, mock_request):
"""Test server error with retry in integration."""
# API errors don't retry by design - only network/timeout errors retry
mock_response_500 = Mock()
mock_response_500.status_code = 500
mock_response_500.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response_500
with patch("time.sleep"): # Skip actual sleep
with self.assertRaises(APIError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_validation_error_integration(self, mock_request):
"""Test validation error in integration."""
mock_response = Mock()
mock_response.status_code = 422
mock_response.json.return_value = {
"message": "Validation failed",
"details": {"field": "query", "error": "required"},
}
mock_request.return_value = mock_response
with self.assertRaises(ValidationError) as context:
self.client.get_app_info()
self.assertEqual(str(context.exception), "Validation failed")
self.assertEqual(context.exception.status_code, 422)
class TestContextManagerIntegration(unittest.TestCase):
"""Integration tests for context manager usage."""
@patch("httpx.Client.close")
@patch("httpx.Client.request")
def test_context_manager_usage(self, mock_request, mock_close):
"""Test context manager properly closes connections."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"id": "app_123", "name": "Test App"}
mock_request.return_value = mock_response
with DifyClient("test_api_key") as client:
response = client.get_app_info()
self.assertEqual(response.status_code, 200)
# Verify close was called
mock_close.assert_called_once()
@patch("httpx.Client.close")
def test_manual_close(self, mock_close):
"""Test manual close method."""
client = DifyClient("test_api_key")
client.close()
mock_close.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,640 @@
"""Unit tests for response models."""
import unittest
import json
from datetime import datetime
from dify_client.models import (
BaseResponse,
ErrorResponse,
FileInfo,
MessageResponse,
ConversationResponse,
DatasetResponse,
DocumentResponse,
DocumentSegmentResponse,
WorkflowRunResponse,
ApplicationParametersResponse,
AnnotationResponse,
PaginatedResponse,
ConversationVariableResponse,
FileUploadResponse,
AudioResponse,
SuggestedQuestionsResponse,
AppInfoResponse,
WorkspaceModelsResponse,
HitTestingResponse,
DatasetTagsResponse,
WorkflowLogsResponse,
ModelProviderResponse,
FileInfoResponse,
WorkflowDraftResponse,
ApiTokenResponse,
JobStatusResponse,
DatasetQueryResponse,
DatasetTemplateResponse,
)
class TestResponseModels(unittest.TestCase):
"""Test cases for response model classes."""
def test_base_response(self):
"""Test BaseResponse model."""
response = BaseResponse(success=True, message="Operation successful")
self.assertTrue(response.success)
self.assertEqual(response.message, "Operation successful")
def test_base_response_defaults(self):
"""Test BaseResponse with default values."""
response = BaseResponse(success=True)
self.assertTrue(response.success)
self.assertIsNone(response.message)
def test_error_response(self):
"""Test ErrorResponse model."""
response = ErrorResponse(
success=False,
message="Error occurred",
error_code="VALIDATION_ERROR",
details={"field": "invalid_value"},
)
self.assertFalse(response.success)
self.assertEqual(response.message, "Error occurred")
self.assertEqual(response.error_code, "VALIDATION_ERROR")
self.assertEqual(response.details["field"], "invalid_value")
def test_file_info(self):
"""Test FileInfo model."""
now = datetime.now()
file_info = FileInfo(
id="file_123",
name="test.txt",
size=1024,
mime_type="text/plain",
url="https://example.com/file.txt",
created_at=now,
)
self.assertEqual(file_info.id, "file_123")
self.assertEqual(file_info.name, "test.txt")
self.assertEqual(file_info.size, 1024)
self.assertEqual(file_info.mime_type, "text/plain")
self.assertEqual(file_info.url, "https://example.com/file.txt")
self.assertEqual(file_info.created_at, now)
def test_message_response(self):
"""Test MessageResponse model."""
response = MessageResponse(
success=True,
id="msg_123",
answer="Hello, world!",
conversation_id="conv_123",
created_at=1234567890,
metadata={"model": "gpt-4"},
files=[{"id": "file_1", "type": "image"}],
)
self.assertTrue(response.success)
self.assertEqual(response.id, "msg_123")
self.assertEqual(response.answer, "Hello, world!")
self.assertEqual(response.conversation_id, "conv_123")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.metadata["model"], "gpt-4")
self.assertEqual(response.files[0]["id"], "file_1")
def test_conversation_response(self):
"""Test ConversationResponse model."""
response = ConversationResponse(
success=True,
id="conv_123",
name="Test Conversation",
inputs={"query": "Hello"},
status="active",
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "conv_123")
self.assertEqual(response.name, "Test Conversation")
self.assertEqual(response.inputs["query"], "Hello")
self.assertEqual(response.status, "active")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_dataset_response(self):
"""Test DatasetResponse model."""
response = DatasetResponse(
success=True,
id="dataset_123",
name="Test Dataset",
description="A test dataset",
permission="read",
indexing_technique="high_quality",
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
retrieval_model={"search_type": "semantic"},
document_count=10,
word_count=5000,
app_count=2,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "dataset_123")
self.assertEqual(response.name, "Test Dataset")
self.assertEqual(response.description, "A test dataset")
self.assertEqual(response.permission, "read")
self.assertEqual(response.indexing_technique, "high_quality")
self.assertEqual(response.embedding_model, "text-embedding-ada-002")
self.assertEqual(response.embedding_model_provider, "openai")
self.assertEqual(response.retrieval_model["search_type"], "semantic")
self.assertEqual(response.document_count, 10)
self.assertEqual(response.word_count, 5000)
self.assertEqual(response.app_count, 2)
def test_document_response(self):
"""Test DocumentResponse model."""
response = DocumentResponse(
success=True,
id="doc_123",
name="test_document.txt",
data_source_type="upload_file",
position=1,
enabled=True,
word_count=1000,
hit_count=5,
doc_form="text_model",
created_at=1234567890.0,
indexing_status="completed",
completed_at=1234567891.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "doc_123")
self.assertEqual(response.name, "test_document.txt")
self.assertEqual(response.data_source_type, "upload_file")
self.assertEqual(response.position, 1)
self.assertTrue(response.enabled)
self.assertEqual(response.word_count, 1000)
self.assertEqual(response.hit_count, 5)
self.assertEqual(response.doc_form, "text_model")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.indexing_status, "completed")
self.assertEqual(response.completed_at, 1234567891.0)
def test_document_segment_response(self):
"""Test DocumentSegmentResponse model."""
response = DocumentSegmentResponse(
success=True,
id="seg_123",
position=1,
document_id="doc_123",
content="This is a test segment.",
answer="Test answer",
word_count=5,
tokens=10,
keywords=["test", "segment"],
hit_count=2,
enabled=True,
status="completed",
created_at=1234567890.0,
completed_at=1234567891.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "seg_123")
self.assertEqual(response.position, 1)
self.assertEqual(response.document_id, "doc_123")
self.assertEqual(response.content, "This is a test segment.")
self.assertEqual(response.answer, "Test answer")
self.assertEqual(response.word_count, 5)
self.assertEqual(response.tokens, 10)
self.assertEqual(response.keywords, ["test", "segment"])
self.assertEqual(response.hit_count, 2)
self.assertTrue(response.enabled)
self.assertEqual(response.status, "completed")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.completed_at, 1234567891.0)
def test_workflow_run_response(self):
"""Test WorkflowRunResponse model."""
response = WorkflowRunResponse(
success=True,
id="run_123",
workflow_id="workflow_123",
status="succeeded",
inputs={"query": "test"},
outputs={"answer": "result"},
elapsed_time=5.5,
total_tokens=100,
total_steps=3,
created_at=1234567890.0,
finished_at=1234567895.5,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "run_123")
self.assertEqual(response.workflow_id, "workflow_123")
self.assertEqual(response.status, "succeeded")
self.assertEqual(response.inputs["query"], "test")
self.assertEqual(response.outputs["answer"], "result")
self.assertEqual(response.elapsed_time, 5.5)
self.assertEqual(response.total_tokens, 100)
self.assertEqual(response.total_steps, 3)
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.finished_at, 1234567895.5)
def test_application_parameters_response(self):
"""Test ApplicationParametersResponse model."""
response = ApplicationParametersResponse(
success=True,
opening_statement="Hello! How can I help you?",
suggested_questions=["What is AI?", "How does this work?"],
speech_to_text={"enabled": True},
text_to_speech={"enabled": False, "voice": "alloy"},
retriever_resource={"enabled": True},
sensitive_word_avoidance={"enabled": False},
file_upload={"enabled": True, "file_size_limit": 10485760},
system_parameters={"max_tokens": 1000},
user_input_form=[{"type": "text", "label": "Query"}],
)
self.assertTrue(response.success)
self.assertEqual(response.opening_statement, "Hello! How can I help you?")
self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"])
self.assertTrue(response.speech_to_text["enabled"])
self.assertFalse(response.text_to_speech["enabled"])
self.assertEqual(response.text_to_speech["voice"], "alloy")
self.assertTrue(response.retriever_resource["enabled"])
self.assertFalse(response.sensitive_word_avoidance["enabled"])
self.assertTrue(response.file_upload["enabled"])
self.assertEqual(response.file_upload["file_size_limit"], 10485760)
self.assertEqual(response.system_parameters["max_tokens"], 1000)
self.assertEqual(response.user_input_form[0]["type"], "text")
def test_annotation_response(self):
"""Test AnnotationResponse model."""
response = AnnotationResponse(
success=True,
id="annotation_123",
question="What is the capital of France?",
answer="Paris",
content="Additional context",
created_at=1234567890.0,
updated_at=1234567891.0,
created_by="user_123",
updated_by="user_123",
hit_count=5,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "annotation_123")
self.assertEqual(response.question, "What is the capital of France?")
self.assertEqual(response.answer, "Paris")
self.assertEqual(response.content, "Additional context")
self.assertEqual(response.created_at, 1234567890.0)
self.assertEqual(response.updated_at, 1234567891.0)
self.assertEqual(response.created_by, "user_123")
self.assertEqual(response.updated_by, "user_123")
self.assertEqual(response.hit_count, 5)
def test_paginated_response(self):
"""Test PaginatedResponse model."""
response = PaginatedResponse(
success=True,
data=[{"id": 1}, {"id": 2}, {"id": 3}],
has_more=True,
limit=10,
total=100,
page=1,
)
self.assertTrue(response.success)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data[0]["id"], 1)
self.assertTrue(response.has_more)
self.assertEqual(response.limit, 10)
self.assertEqual(response.total, 100)
self.assertEqual(response.page, 1)
def test_conversation_variable_response(self):
"""Test ConversationVariableResponse model."""
response = ConversationVariableResponse(
success=True,
conversation_id="conv_123",
variables=[
{"id": "var_1", "name": "user_name", "value": "John"},
{"id": "var_2", "name": "preferences", "value": {"theme": "dark"}},
],
)
self.assertTrue(response.success)
self.assertEqual(response.conversation_id, "conv_123")
self.assertEqual(len(response.variables), 2)
self.assertEqual(response.variables[0]["name"], "user_name")
self.assertEqual(response.variables[0]["value"], "John")
self.assertEqual(response.variables[1]["name"], "preferences")
self.assertEqual(response.variables[1]["value"]["theme"], "dark")
def test_file_upload_response(self):
"""Test FileUploadResponse model."""
response = FileUploadResponse(
success=True,
id="file_123",
name="test.txt",
size=1024,
mime_type="text/plain",
url="https://example.com/files/test.txt",
created_at=1234567890.0,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "file_123")
self.assertEqual(response.name, "test.txt")
self.assertEqual(response.size, 1024)
self.assertEqual(response.mime_type, "text/plain")
self.assertEqual(response.url, "https://example.com/files/test.txt")
self.assertEqual(response.created_at, 1234567890.0)
def test_audio_response(self):
"""Test AudioResponse model."""
response = AudioResponse(
success=True,
audio="base64_encoded_audio_data",
audio_url="https://example.com/audio.mp3",
duration=10.5,
sample_rate=44100,
)
self.assertTrue(response.success)
self.assertEqual(response.audio, "base64_encoded_audio_data")
self.assertEqual(response.audio_url, "https://example.com/audio.mp3")
self.assertEqual(response.duration, 10.5)
self.assertEqual(response.sample_rate, 44100)
def test_suggested_questions_response(self):
"""Test SuggestedQuestionsResponse model."""
response = SuggestedQuestionsResponse(
success=True,
message_id="msg_123",
questions=[
"What is machine learning?",
"How does AI work?",
"Can you explain neural networks?",
],
)
self.assertTrue(response.success)
self.assertEqual(response.message_id, "msg_123")
self.assertEqual(len(response.questions), 3)
self.assertEqual(response.questions[0], "What is machine learning?")
def test_app_info_response(self):
"""Test AppInfoResponse model."""
response = AppInfoResponse(
success=True,
id="app_123",
name="Test App",
description="A test application",
icon="🤖",
icon_background="#FF6B6B",
mode="chat",
tags=["AI", "Chat", "Test"],
enable_site=True,
enable_api=True,
api_token="app_token_123",
)
self.assertTrue(response.success)
self.assertEqual(response.id, "app_123")
self.assertEqual(response.name, "Test App")
self.assertEqual(response.description, "A test application")
self.assertEqual(response.icon, "🤖")
self.assertEqual(response.icon_background, "#FF6B6B")
self.assertEqual(response.mode, "chat")
self.assertEqual(response.tags, ["AI", "Chat", "Test"])
self.assertTrue(response.enable_site)
self.assertTrue(response.enable_api)
self.assertEqual(response.api_token, "app_token_123")
def test_workspace_models_response(self):
"""Test WorkspaceModelsResponse model."""
response = WorkspaceModelsResponse(
success=True,
models=[
{"id": "gpt-4", "name": "GPT-4", "provider": "openai"},
{"id": "claude-3", "name": "Claude 3", "provider": "anthropic"},
],
)
self.assertTrue(response.success)
self.assertEqual(len(response.models), 2)
self.assertEqual(response.models[0]["id"], "gpt-4")
self.assertEqual(response.models[0]["name"], "GPT-4")
self.assertEqual(response.models[0]["provider"], "openai")
def test_hit_testing_response(self):
"""Test HitTestingResponse model."""
response = HitTestingResponse(
success=True,
query="What is machine learning?",
records=[
{"content": "Machine learning is a subset of AI...", "score": 0.95},
{"content": "ML algorithms learn from data...", "score": 0.87},
],
)
self.assertTrue(response.success)
self.assertEqual(response.query, "What is machine learning?")
self.assertEqual(len(response.records), 2)
self.assertEqual(response.records[0]["score"], 0.95)
def test_dataset_tags_response(self):
"""Test DatasetTagsResponse model."""
response = DatasetTagsResponse(
success=True,
tags=[
{"id": "tag_1", "name": "Technology", "color": "#FF0000"},
{"id": "tag_2", "name": "Science", "color": "#00FF00"},
],
)
self.assertTrue(response.success)
self.assertEqual(len(response.tags), 2)
self.assertEqual(response.tags[0]["name"], "Technology")
self.assertEqual(response.tags[0]["color"], "#FF0000")
def test_workflow_logs_response(self):
"""Test WorkflowLogsResponse model."""
response = WorkflowLogsResponse(
success=True,
logs=[
{"id": "log_1", "status": "succeeded", "created_at": 1234567890},
{"id": "log_2", "status": "failed", "created_at": 1234567891},
],
total=50,
page=1,
limit=10,
has_more=True,
)
self.assertTrue(response.success)
self.assertEqual(len(response.logs), 2)
self.assertEqual(response.logs[0]["status"], "succeeded")
self.assertEqual(response.total, 50)
self.assertEqual(response.page, 1)
self.assertEqual(response.limit, 10)
self.assertTrue(response.has_more)
def test_model_serialization(self):
"""Test that models can be serialized to JSON."""
response = MessageResponse(
success=True,
id="msg_123",
answer="Hello, world!",
conversation_id="conv_123",
)
# Convert to dict and then to JSON
response_dict = {
"success": response.success,
"id": response.id,
"answer": response.answer,
"conversation_id": response.conversation_id,
}
json_str = json.dumps(response_dict)
parsed = json.loads(json_str)
self.assertTrue(parsed["success"])
self.assertEqual(parsed["id"], "msg_123")
self.assertEqual(parsed["answer"], "Hello, world!")
self.assertEqual(parsed["conversation_id"], "conv_123")
# Tests for new response models
def test_model_provider_response(self):
"""Test ModelProviderResponse model."""
response = ModelProviderResponse(
success=True,
provider_name="openai",
provider_type="llm",
models=[
{"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192},
{"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096},
],
is_enabled=True,
credentials={"api_key": "sk-..."},
)
self.assertTrue(response.success)
self.assertEqual(response.provider_name, "openai")
self.assertEqual(response.provider_type, "llm")
self.assertEqual(len(response.models), 2)
self.assertEqual(response.models[0]["id"], "gpt-4")
self.assertTrue(response.is_enabled)
self.assertEqual(response.credentials["api_key"], "sk-...")
def test_file_info_response(self):
"""Test FileInfoResponse model."""
response = FileInfoResponse(
success=True,
id="file_123",
name="document.pdf",
size=2048576,
mime_type="application/pdf",
url="https://example.com/files/document.pdf",
created_at=1234567890,
metadata={"pages": 10, "author": "John Doe"},
)
self.assertTrue(response.success)
self.assertEqual(response.id, "file_123")
self.assertEqual(response.name, "document.pdf")
self.assertEqual(response.size, 2048576)
self.assertEqual(response.mime_type, "application/pdf")
self.assertEqual(response.url, "https://example.com/files/document.pdf")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.metadata["pages"], 10)
def test_workflow_draft_response(self):
"""Test WorkflowDraftResponse model."""
response = WorkflowDraftResponse(
success=True,
id="draft_123",
app_id="app_456",
draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}},
version=1,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "draft_123")
self.assertEqual(response.app_id, "app_456")
self.assertEqual(response.draft_data["config"]["name"], "Test Workflow")
self.assertEqual(response.version, 1)
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_api_token_response(self):
"""Test ApiTokenResponse model."""
response = ApiTokenResponse(
success=True,
id="token_123",
name="Production Token",
token="app-xxxxxxxxxxxx",
description="Token for production environment",
created_at=1234567890,
last_used_at=1234567891,
is_active=True,
)
self.assertTrue(response.success)
self.assertEqual(response.id, "token_123")
self.assertEqual(response.name, "Production Token")
self.assertEqual(response.token, "app-xxxxxxxxxxxx")
self.assertEqual(response.description, "Token for production environment")
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.last_used_at, 1234567891)
self.assertTrue(response.is_active)
def test_job_status_response(self):
"""Test JobStatusResponse model."""
response = JobStatusResponse(
success=True,
job_id="job_123",
job_status="running",
error_msg=None,
progress=0.75,
created_at=1234567890,
updated_at=1234567891,
)
self.assertTrue(response.success)
self.assertEqual(response.job_id, "job_123")
self.assertEqual(response.job_status, "running")
self.assertIsNone(response.error_msg)
self.assertEqual(response.progress, 0.75)
self.assertEqual(response.created_at, 1234567890)
self.assertEqual(response.updated_at, 1234567891)
def test_dataset_query_response(self):
"""Test DatasetQueryResponse model."""
response = DatasetQueryResponse(
success=True,
query="What is machine learning?",
records=[
{"content": "Machine learning is...", "score": 0.95},
{"content": "ML algorithms...", "score": 0.87},
],
total=2,
search_time=0.123,
retrieval_model={"method": "semantic_search", "top_k": 3},
)
self.assertTrue(response.success)
self.assertEqual(response.query, "What is machine learning?")
self.assertEqual(len(response.records), 2)
self.assertEqual(response.total, 2)
self.assertEqual(response.search_time, 0.123)
self.assertEqual(response.retrieval_model["method"], "semantic_search")
def test_dataset_template_response(self):
"""Test DatasetTemplateResponse model."""
response = DatasetTemplateResponse(
success=True,
template_name="customer_support",
display_name="Customer Support",
description="Template for customer support knowledge base",
category="support",
icon="🎧",
config_schema={"fields": [{"name": "category", "type": "string"}]},
)
self.assertTrue(response.success)
self.assertEqual(response.template_name, "customer_support")
self.assertEqual(response.display_name, "Customer Support")
self.assertEqual(response.description, "Template for customer support knowledge base")
self.assertEqual(response.category, "support")
self.assertEqual(response.icon, "🎧")
self.assertEqual(response.config_schema["fields"][0]["name"], "category")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,313 @@
"""Unit tests for retry mechanism and error handling."""
import unittest
from unittest.mock import Mock, patch, MagicMock
import httpx
from dify_client.client import DifyClient
from dify_client.exceptions import (
APIError,
AuthenticationError,
RateLimitError,
ValidationError,
NetworkError,
TimeoutError,
FileUploadError,
)
class TestRetryMechanism(unittest.TestCase):
"""Test cases for retry mechanism."""
def setUp(self):
self.api_key = "test_api_key"
self.base_url = "https://api.dify.ai/v1"
self.client = DifyClient(
api_key=self.api_key,
base_url=self.base_url,
max_retries=3,
retry_delay=0.1, # Short delay for tests
enable_logging=False,
)
@patch("httpx.Client.request")
def test_successful_request_no_retry(self, mock_request):
"""Test that successful requests don't trigger retries."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b'{"success": true}'
mock_request.return_value = mock_response
response = self.client._send_request("GET", "/test")
self.assertEqual(response, mock_response)
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_retry_on_network_error(self, mock_sleep, mock_request):
"""Test retry on network errors."""
# First two calls raise network error, third succeeds
mock_request.side_effect = [
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
Mock(status_code=200, content=b'{"success": true}'),
]
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b'{"success": true}'
response = self.client._send_request("GET", "/test")
self.assertEqual(response.status_code, 200)
self.assertEqual(mock_request.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_retry_on_timeout_error(self, mock_sleep, mock_request):
"""Test retry on timeout errors."""
mock_request.side_effect = [
httpx.TimeoutException("Request timed out"),
httpx.TimeoutException("Request timed out"),
Mock(status_code=200, content=b'{"success": true}'),
]
response = self.client._send_request("GET", "/test")
self.assertEqual(response.status_code, 200)
self.assertEqual(mock_request.call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
@patch("httpx.Client.request")
@patch("time.sleep")
def test_max_retries_exceeded(self, mock_sleep, mock_request):
"""Test behavior when max retries are exceeded."""
mock_request.side_effect = httpx.NetworkError("Persistent network error")
with self.assertRaises(NetworkError):
self.client._send_request("GET", "/test")
self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries
self.assertEqual(mock_sleep.call_count, 3)
@patch("httpx.Client.request")
def test_no_retry_on_client_error(self, mock_request):
"""Test that client errors (4xx) don't trigger retries."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Unauthorized"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError):
self.client._send_request("GET", "/test")
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_retry_on_server_error(self, mock_request):
"""Test that server errors (5xx) don't retry - they raise APIError immediately."""
mock_response_500 = Mock()
mock_response_500.status_code = 500
mock_response_500.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response_500
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(context.exception.status_code, 500)
# Should not retry server errors
self.assertEqual(mock_request.call_count, 1)
@patch("httpx.Client.request")
def test_exponential_backoff(self, mock_request):
"""Test exponential backoff timing."""
mock_request.side_effect = [
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"),
httpx.NetworkError("Connection failed"), # All attempts fail
]
with patch("time.sleep") as mock_sleep:
with self.assertRaises(NetworkError):
self.client._send_request("GET", "/test")
# Check exponential backoff: 0.1, 0.2, 0.4
expected_calls = [0.1, 0.2, 0.4]
actual_calls = [call[0][0] for call in mock_sleep.call_args_list]
self.assertEqual(actual_calls, expected_calls)
class TestErrorHandling(unittest.TestCase):
"""Test cases for error handling."""
def setUp(self):
self.client = DifyClient(api_key="test_api_key", enable_logging=False)
@patch("httpx.Client.request")
def test_authentication_error(self, mock_request):
"""Test AuthenticationError handling."""
mock_response = Mock()
mock_response.status_code = 401
mock_response.json.return_value = {"message": "Invalid API key"}
mock_request.return_value = mock_response
with self.assertRaises(AuthenticationError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Invalid API key")
self.assertEqual(context.exception.status_code, 401)
@patch("httpx.Client.request")
def test_rate_limit_error(self, mock_request):
"""Test RateLimitError handling."""
mock_response = Mock()
mock_response.status_code = 429
mock_response.json.return_value = {"message": "Rate limit exceeded"}
mock_response.headers = {"Retry-After": "60"}
mock_request.return_value = mock_response
with self.assertRaises(RateLimitError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Rate limit exceeded")
self.assertEqual(context.exception.retry_after, "60")
@patch("httpx.Client.request")
def test_validation_error(self, mock_request):
"""Test ValidationError handling."""
mock_response = Mock()
mock_response.status_code = 422
mock_response.json.return_value = {"message": "Invalid parameters"}
mock_request.return_value = mock_response
with self.assertRaises(ValidationError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Invalid parameters")
self.assertEqual(context.exception.status_code, 422)
@patch("httpx.Client.request")
def test_api_error(self, mock_request):
"""Test general APIError handling."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.json.return_value = {"message": "Internal server error"}
mock_request.return_value = mock_response
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "Internal server error")
self.assertEqual(context.exception.status_code, 500)
@patch("httpx.Client.request")
def test_error_response_without_json(self, mock_request):
"""Test error handling when response doesn't contain valid JSON."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.content = b"Internal Server Error"
mock_response.json.side_effect = ValueError("No JSON object could be decoded")
mock_request.return_value = mock_response
with self.assertRaises(APIError) as context:
self.client._send_request("GET", "/test")
self.assertEqual(str(context.exception), "HTTP 500")
@patch("httpx.Client.request")
def test_file_upload_error(self, mock_request):
"""Test FileUploadError handling."""
mock_response = Mock()
mock_response.status_code = 400
mock_response.json.return_value = {"message": "File upload failed"}
mock_request.return_value = mock_response
with self.assertRaises(FileUploadError) as context:
self.client._send_request_with_files("POST", "/upload", {}, {})
self.assertEqual(str(context.exception), "File upload failed")
self.assertEqual(context.exception.status_code, 400)
class TestParameterValidation(unittest.TestCase):
"""Test cases for parameter validation."""
def setUp(self):
self.client = DifyClient(api_key="test_api_key", enable_logging=False)
def test_empty_string_validation(self):
"""Test validation of empty strings."""
with self.assertRaises(ValidationError):
self.client._validate_params(empty_string="")
def test_whitespace_only_string_validation(self):
"""Test validation of whitespace-only strings."""
with self.assertRaises(ValidationError):
self.client._validate_params(whitespace_string=" ")
def test_long_string_validation(self):
"""Test validation of overly long strings."""
long_string = "a" * 10001 # Exceeds 10000 character limit
with self.assertRaises(ValidationError):
self.client._validate_params(long_string=long_string)
def test_large_list_validation(self):
"""Test validation of overly large lists."""
large_list = list(range(1001)) # Exceeds 1000 item limit
with self.assertRaises(ValidationError):
self.client._validate_params(large_list=large_list)
def test_large_dict_validation(self):
"""Test validation of overly large dictionaries."""
large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit
with self.assertRaises(ValidationError):
self.client._validate_params(large_dict=large_dict)
def test_valid_parameters_pass(self):
"""Test that valid parameters pass validation."""
# Should not raise any exception
self.client._validate_params(
valid_string="Hello, World!",
valid_list=[1, 2, 3],
valid_dict={"key": "value"},
none_value=None,
)
def test_message_feedback_validation(self):
"""Test validation in message_feedback method."""
with self.assertRaises(ValidationError):
self.client.message_feedback("msg_id", "invalid_rating", "user")
def test_completion_message_validation(self):
"""Test validation in create_completion_message method."""
from dify_client.client import CompletionClient
client = CompletionClient("test_api_key")
with self.assertRaises(ValidationError):
client.create_completion_message(
inputs="not_a_dict", # Should be a dict
response_mode="invalid_mode", # Should be 'blocking' or 'streaming'
user="test_user",
)
def test_chat_message_validation(self):
"""Test validation in create_chat_message method."""
from dify_client.client import ChatClient
client = ChatClient("test_api_key")
with self.assertRaises(ValidationError):
client.create_chat_message(
inputs="not_a_dict", # Should be a dict
query="", # Should not be empty
user="test_user",
response_mode="invalid_mode", # Should be 'blocking' or 'streaming'
)
if __name__ == "__main__":
unittest.main()

View File

@@ -59,7 +59,7 @@ version = "0.1.12"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiofiles" }, { name = "aiofiles" },
{ name = "httpx" }, { name = "httpx", extra = ["http2"] },
] ]
[package.optional-dependencies] [package.optional-dependencies]
@@ -71,7 +71,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiofiles", specifier = ">=23.0.0" }, { name = "aiofiles", specifier = ">=23.0.0" },
{ name = "httpx", specifier = ">=0.27.0" }, { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" },
] ]
@@ -98,6 +98,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
] ]
[[package]]
name = "h2"
version = "4.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "hpack" },
{ name = "hyperframe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" },
]
[[package]]
name = "hpack"
version = "4.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" },
]
[[package]] [[package]]
name = "httpcore" name = "httpcore"
version = "1.0.9" version = "1.0.9"
@@ -126,6 +148,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
] ]
[package.optional-dependencies]
http2 = [
{ name = "h2" },
]
[[package]]
name = "hyperframe"
version = "6.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" },
]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.10" version = "3.10"