Merge branch 'main' into jzh
@@ -62,7 +62,7 @@ jobs:
      - name: Render coverage markdown from structured data
        id: render
        run: |
          comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
          comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
            --base base_report.json \
            < pr_report.json)"

@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
# Path to client private key file for SSL authentication
REDIS_SSL_KEYFILE=
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=

# redis Sentinel configuration.
REDIS_USE_SENTINEL=false

@@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
        default="",
    )

    DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
        description=(
            "PostgreSQL session timezone override injected via startup options."
            " Default is 'UTC' for out-of-the-box consistency."
            " Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
            " together with a database-side default timezone."
        ),
        default="UTC",
    )

    @computed_field  # type: ignore[prop-decorator]
    @property
    def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:

@@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
        connect_args: dict[str, str] = {}
        # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
        if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
            timezone_opt = "-c timezone=UTC"
            if options:
                merged_options = f"{options} {timezone_opt}"
            else:
                merged_options = timezone_opt
            connect_args = {"options": merged_options}
            merged_options = options.strip()
            session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
            if session_timezone_override:
                timezone_opt = f"-c timezone={session_timezone_override}"
                merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
            if merged_options:
                connect_args = {"options": merged_options}

        result: SQLAlchemyEngineOptionsDict = {
            "pool_size": self.SQLALCHEMY_POOL_SIZE,

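The second hunk replaces the hard-coded "-c timezone=UTC" startup option with one derived from DB_SESSION_TIMEZONE_OVERRIDE. A minimal sketch of the resulting merge rule, assuming the same logic as the hunk above (the helper name is made up for illustration):

def build_postgres_options(options: str, session_timezone_override: str) -> str:
    # Illustrative only: mirrors the option-merging shown in the diff above.
    merged_options = options.strip()
    override = session_timezone_override.strip()
    if override:
        timezone_opt = f"-c timezone={override}"
        merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
    return merged_options

# build_postgres_options("-c search_path=myschema", "UTC") -> "-c search_path=myschema -c timezone=UTC"
# build_postgres_options("", "") -> "" (empty override disables app-level timezone injection)
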
api/configs/middleware/cache/redis_config.py
@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
        default=0,
    )

    REDIS_KEY_PREFIX: str = Field(
        description="Optional global prefix for Redis keys, topics, and transport artifacts",
        default="",
    )

    REDIS_USE_SSL: bool = Field(
        description="Enable SSL/TLS for the Redis connection",
        default=False,

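REDIS_KEY_PREFIX defaults to an empty string, so existing deployments keep their unprefixed keys. A rough sketch of how such a prefix would typically be applied to key names (illustrative; this helper is not part of the diff):

def prefixed_key(key: str, prefix: str = "") -> str:
    # Hypothetical helper: prepend the configured prefix, if any, to a Redis key name.
    return f"{prefix}{key}" if prefix else key

# With REDIS_KEY_PREFIX="dify:" a key like "app:123:state" would be stored as "dify:app:123:state";
# with the default empty prefix it is stored unchanged.
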
@@ -6,7 +6,6 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

@@ -31,6 +30,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType

@@ -161,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
    return value


def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
    if icon is None or icon_type is None:
        return None
    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
    if icon_type_value.lower() != IconType.IMAGE:
        return None
    return file_helpers.get_signed_file_url(icon)


class Tag(ResponseModel):
    id: str
    name: str

@@ -292,7 +283,7 @@ class Site(ResponseModel):
    @computed_field(return_type=str | None)  # type: ignore
    @property
    def icon_url(self) -> str | None:
        return _build_icon_url(self.icon_type, self.icon)
        return build_icon_url(self.icon_type, self.icon)

    @field_validator("icon_type", mode="before")
    @classmethod

@@ -342,7 +333,7 @@ class AppPartial(ResponseModel):
    @computed_field(return_type=str | None)  # type: ignore
    @property
    def icon_url(self) -> str | None:
        return _build_icon_url(self.icon_type, self.icon)
        return build_icon_url(self.icon_type, self.icon)

    @field_validator("created_at", "updated_at", mode="before")
    @classmethod

@@ -390,7 +381,7 @@ class AppDetailWithSite(AppDetail):
    @computed_field(return_type=str | None)  # type: ignore
    @property
    def icon_url(self) -> str | None:
        return _build_icon_url(self.icon_type, self.icon)
        return build_icon_url(self.icon_type, self.icon)


class AppPagination(ResponseModel):

@@ -1,44 +1,86 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import (
    conversation_variable_fields,
    paginated_conversation_variable_fields,
)
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class ConversationVariablesQuery(BaseModel):
    conversation_id: str = Field(..., description="Conversation ID to filter variables")


console_ns.schema_model(
    ConversationVariablesQuery.__name__,
    ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
    if isinstance(value, datetime):
        return int(value.timestamp())
    return value

# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)

# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
    fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
    "PaginatedConversationVariable", paginated_conversation_variable_fields_copy
class ConversationVariableResponse(ResponseModel):
    id: str
    name: str
    value_type: str
    value: str | None = None
    description: str | None = None
    created_at: int | None = None
    updated_at: int | None = None

    @field_validator("value_type", mode="before")
    @classmethod
    def _normalize_value_type(cls, value: Any) -> str:
        exposed_type = getattr(value, "exposed_type", None)
        if callable(exposed_type):
            return str(exposed_type().value)
        if isinstance(value, str):
            return value
        try:
            return serialize_value_type(value)
        except Exception:
            return serialize_value_type({"value_type": value})

    @field_validator("value", mode="before")
    @classmethod
    def _normalize_value(cls, value: Any | None) -> str | None:
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(value)

    @field_validator("created_at", "updated_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        return _to_timestamp(value)


class PaginatedConversationVariableResponse(ResponseModel):
    page: int
    limit: int
    total: int
    has_more: bool
    data: list[ConversationVariableResponse]


register_schema_models(
    console_ns,
    ConversationVariablesQuery,
    ConversationVariableResponse,
    PaginatedConversationVariableResponse,
)


@@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
    @console_ns.doc(description="Get conversation variables for an application")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
    @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
    @console_ns.response(
        200,
        "Conversation variables retrieved successfully",
        console_ns.models[PaginatedConversationVariableResponse.__name__],
    )
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.ADVANCED_CHAT)
    @marshal_with(paginated_conversation_variable_model)
    def get(self, app_model):
        args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

@@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            rows = session.scalars(stmt).all()

        return {
            "page": page,
            "limit": page_size,
            "total": len(rows),
            "has_more": False,
            "data": [
                {
                    "created_at": row.created_at,
                    "updated_at": row.updated_at,
                    **row.to_variable().model_dump(),
                }
                for row in rows
            ],
        }
        response = PaginatedConversationVariableResponse.model_validate(
            {
                "page": page,
                "limit": page_size,
                "total": len(rows),
                "has_more": False,
                "data": [
                    ConversationVariableResponse.model_validate(
                        {
                            "created_at": row.created_at,
                            "updated_at": row.updated_at,
                            **row.to_variable().model_dump(),
                        }
                    )
                    for row in rows
                ],
            }
        )
        return response.model_dump(mode="json")

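The same migration pattern repeats throughout this commit: flask_restx field dictionaries and @marshal_with are replaced by Pydantic response models whose "before" validators normalize ORM values (enums, datetimes) and which are serialized with model_dump(mode="json"). A minimal self-contained sketch of the pattern, assuming plain BaseModel in place of the project's ResponseModel base and made-up field names:

from datetime import datetime, timezone
from pydantic import BaseModel, field_validator

class ExampleResponse(BaseModel):
    id: str
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM rows carry datetimes; the API contract wants Unix timestamps.
        return int(value.timestamp()) if isinstance(value, datetime) else value

payload = ExampleResponse.model_validate({"id": "1", "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc)})
print(payload.model_dump(mode="json"))  # {'id': '1', 'created_at': 1704067200}
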
@@ -1,8 +1,9 @@
import logging
from datetime import datetime
from typing import Literal

from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select

@@ -25,10 +26,21 @@ from controllers.console.wraps import (
    setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from fields.base import ResponseModel
from fields.conversation_fields import (
    AgentThought,
    ConversationAnnotation,
    ConversationAnnotationHitHistory,
    Feedback,
    JSONValue,
    MessageFile,
    format_files_contained,
    to_timestamp,
)
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating

@@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
    data: list[str] = Field(description="Suggested question")


class MessageDetailResponse(ResponseModel):
    id: str
    conversation_id: str
    inputs: dict[str, JSONValue]
    query: str
    message: JSONValue | None = None
    message_tokens: int | None = None
    answer: str = Field(validation_alias="re_sign_file_url_answer")
    answer_tokens: int | None = None
    provider_response_latency: float | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    feedbacks: list[Feedback] = Field(default_factory=list)
    workflow_run_id: str | None = None
    annotation: ConversationAnnotation | None = None
    annotation_hit_history: ConversationAnnotationHitHistory | None = None
    created_at: int | None = None
    agent_thoughts: list[AgentThought] = Field(default_factory=list)
    message_files: list[MessageFile] = Field(default_factory=list)
    extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
    metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
    status: str
    error: str | None = None
    parent_message_id: str | None = None

    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
        return format_files_contained(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value


class MessageInfiniteScrollPaginationResponse(ResponseModel):
    limit: int
    has_more: bool
    data: list[MessageDetailResponse]


register_schema_models(
    console_ns,
    ChatMessagesQuery,

@@ -105,124 +162,8 @@ register_schema_models(
    FeedbackExportQuery,
    AnnotationCountResponse,
    SuggestedQuestionsResponse,
)

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models

# Base models
simple_account_model = console_ns.model(
    "SimpleAccount",
    {
        "id": fields.String,
        "name": fields.String,
        "email": fields.String,
    },
)

message_file_model = console_ns.model(
    "MessageFile",
    {
        "id": fields.String,
        "filename": fields.String,
        "type": fields.String,
        "url": fields.String,
        "mime_type": fields.String,
        "size": fields.Integer,
        "transfer_method": fields.String,
        "belongs_to": fields.String(default="user"),
        "upload_file_id": fields.String(default=None),
    },
)

agent_thought_model = console_ns.model(
    "AgentThought",
    {
        "id": fields.String,
        "chain_id": fields.String,
        "message_id": fields.String,
        "position": fields.Integer,
        "thought": fields.String,
        "tool": fields.String,
        "tool_labels": fields.Raw,
        "tool_input": fields.String,
        "created_at": TimestampField,
        "observation": fields.String,
        "files": fields.List(fields.String),
    },
)

# Models that depend on simple_account_model
feedback_model = console_ns.model(
    "Feedback",
    {
        "rating": fields.String,
        "content": fields.String,
        "from_source": fields.String,
        "from_end_user_id": fields.String,
        "from_account": fields.Nested(simple_account_model, allow_null=True),
    },
)

annotation_model = console_ns.model(
    "Annotation",
    {
        "id": fields.String,
        "question": fields.String,
        "content": fields.String,
        "account": fields.Nested(simple_account_model, allow_null=True),
        "created_at": TimestampField,
    },
)

annotation_hit_history_model = console_ns.model(
    "AnnotationHitHistory",
    {
        "annotation_id": fields.String(attribute="id"),
        "annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
        "created_at": TimestampField,
    },
)

# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
    "MessageDetail",
    {
        "id": fields.String,
        "conversation_id": fields.String,
        "inputs": FilesContainedField,
        "query": fields.String,
        "message": fields.Raw,
        "message_tokens": fields.Integer,
        "answer": fields.String(attribute="re_sign_file_url_answer"),
        "answer_tokens": fields.Integer,
        "provider_response_latency": fields.Float,
        "from_source": fields.String,
        "from_end_user_id": fields.String,
        "from_account_id": fields.String,
        "feedbacks": fields.List(fields.Nested(feedback_model)),
        "workflow_run_id": fields.String,
        "annotation": fields.Nested(annotation_model, allow_null=True),
        "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
        "created_at": TimestampField,
        "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
        "message_files": fields.List(fields.Nested(message_file_model)),
        "extra_contents": fields.List(fields.Raw),
        "metadata": fields.Raw(attribute="message_metadata_dict"),
        "status": fields.String,
        "error": fields.String,
        "parent_message_id": fields.String,
    },
)

# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
    "MessageInfiniteScrollPagination",
    {
        "limit": fields.Integer,
        "has_more": fields.Boolean,
        "data": fields.List(fields.Nested(message_detail_model)),
    },
    MessageDetailResponse,
    MessageInfiniteScrollPaginationResponse,
)


@@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
    @console_ns.doc(description="Get chat messages for a conversation with pagination")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
    @console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
    @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
    @console_ns.response(404, "Conversation not found")
    @login_required
    @account_initialization_required
    @setup_required
    @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
    @marshal_with(message_infinite_scroll_pagination_model)
    @edit_permission_required
    def get(self, app_model):
        args = ChatMessagesQuery.model_validate(request.args.to_dict())

@@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
        history_messages = list(reversed(history_messages))
        attach_message_extra_contents(history_messages)

        return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
        return MessageInfiniteScrollPaginationResponse.model_validate(
            InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
            from_attributes=True,
        ).model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/feedbacks")

@@ -468,13 +411,12 @@ class MessageApi(Resource):
    @console_ns.doc("get_message")
    @console_ns.doc(description="Get message details by ID")
    @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
    @console_ns.response(200, "Message retrieved successfully", message_detail_model)
    @console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
    @console_ns.response(404, "Message not found")
    @get_app_model
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(message_detail_model)
    def get(self, app_model, message_id: str):
        message_id = str(message_id)

@@ -486,4 +428,4 @@ class MessageApi(Resource):
            raise NotFound("Message Not Exists.")

        attach_message_extra_contents([message])
        return message
        return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")

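Worth noting in MessageDetailResponse above: fields such as answer and metadata are filled from differently named source attributes via validation_alias, which takes over the role of flask_restx's attribute=... mapping. A small sketch of that mechanism, using an illustrative model name and a plain dict input (the real model also accepts ORM objects through its ResponseModel configuration):

from pydantic import BaseModel, Field

class AnswerView(BaseModel):
    answer: str = Field(validation_alias="re_sign_file_url_answer")

print(AnswerView.model_validate({"re_sign_file_url_answer": "signed text"}).model_dump())
# -> {'answer': 'signed text'}
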
@@ -1,27 +1,26 @@
from datetime import datetime
from typing import Any

from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import (
    build_workflow_app_log_pagination_model,
    build_workflow_archived_log_pagination_model,
)
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class WorkflowAppLogQuery(BaseModel):
    keyword: str | None = Field(default=None, description="Search keyword for filtering logs")

@@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
        raise ValueError("Invalid boolean value for detail")


console_ns.schema_model(
    WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class WorkflowRunForLogResponse(ResponseModel):
    id: str
    version: str | None = None
    status: str | None = None
    triggered_from: str | None = None
    error: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None
    total_steps: int | None = None
    created_at: int | None = None
    finished_at: int | None = None
    exceptions_count: int | None = None

# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
    @field_validator("status", mode="before")
    @classmethod
    def _normalize_status(cls, value: Any) -> str | None:
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(getattr(value, "value", value))

    @field_validator("created_at", "finished_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


class WorkflowRunForArchivedLogResponse(ResponseModel):
    id: str
    status: str | None = None
    triggered_from: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None

    @field_validator("status", mode="before")
    @classmethod
    def _normalize_status(cls, value: Any) -> str | None:
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(getattr(value, "value", value))


class WorkflowAppLogPartialResponse(ResponseModel):
    id: str
    workflow_run: WorkflowRunForLogResponse | None = None
    details: Any = None
    created_from: str | None = None
    created_by_role: str | None = None
    created_by_account: SimpleAccount | None = None
    created_by_end_user: SimpleEndUser | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


class WorkflowArchivedLogPartialResponse(ResponseModel):
    id: str
    workflow_run: WorkflowRunForArchivedLogResponse | None = None
    trigger_metadata: Any = None
    created_by_account: SimpleAccount | None = None
    created_by_end_user: SimpleEndUser | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


class WorkflowAppLogPaginationResponse(ResponseModel):
    page: int
    limit: int
    total: int
    has_more: bool
    data: list[WorkflowAppLogPartialResponse]


class WorkflowArchivedLogPaginationResponse(ResponseModel):
    page: int
    limit: int
    total: int
    has_more: bool
    data: list[WorkflowArchivedLogPartialResponse]


register_schema_models(
    console_ns,
    WorkflowAppLogQuery,
    WorkflowRunForLogResponse,
    WorkflowRunForArchivedLogResponse,
    WorkflowAppLogPartialResponse,
    WorkflowArchivedLogPartialResponse,
    WorkflowAppLogPaginationResponse,
    WorkflowArchivedLogPaginationResponse,
)


@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")

@@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
    @console_ns.doc(description="Get workflow application execution logs")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
    @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
    @console_ns.response(
        200,
        "Workflow app logs retrieved successfully",
        console_ns.models[WorkflowAppLogPaginationResponse.__name__],
    )
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    @marshal_with(workflow_app_log_pagination_model)
    def get(self, app_model: App):
        """
        Get workflow app logs

@@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
            created_by_account=args.created_by_account,
        )

        return workflow_app_log_pagination
        return WorkflowAppLogPaginationResponse.model_validate(
            workflow_app_log_pagination, from_attributes=True
        ).model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")

@@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
    @console_ns.doc(description="Get workflow archived execution logs")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
    @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
    @console_ns.response(
        200,
        "Workflow archived logs retrieved successfully",
        console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
    )
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    @marshal_with(workflow_archived_log_pagination_model)
    def get(self, app_model: App):
        """
        Get workflow archived logs

@@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
            limit=args.limit,
        )

        return workflow_app_log_pagination
        return WorkflowArchivedLogPaginationResponse.model_validate(
            workflow_app_log_pagination, from_attributes=True
        ).model_dump(mode="json")

@@ -1,16 +1,17 @@
import logging
from datetime import datetime

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from flask_restx import Resource
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound

from configs import dify_config
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from fields.base import ResponseModel
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode

@@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required

logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)

triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)

webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)


class Parser(BaseModel):

@@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
    enable_trigger: bool


console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class WorkflowTriggerResponse(ResponseModel):
    id: str
    trigger_type: str
    title: str
    node_id: str
    provider_name: str
    icon: str
    status: str
    created_at: datetime | None = None
    updated_at: datetime | None = None

console_ns.schema_model(
    ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
    @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
    @classmethod
    def _normalize_string_fields(cls, value: object) -> str:
        if isinstance(value, str):
            return value
        return str(value)


class WorkflowTriggerListResponse(ResponseModel):
    data: list[WorkflowTriggerResponse]


class WebhookTriggerResponse(ResponseModel):
    id: str
    webhook_id: str
    webhook_url: str
    webhook_debug_url: str
    node_id: str
    created_at: datetime | None = None

    @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
    @classmethod
    def _normalize_string_fields(cls, value: object) -> str:
        if isinstance(value, str):
            return value
        return str(value)


register_schema_models(
    console_ns,
    Parser,
    ParserEnable,
    WorkflowTriggerResponse,
    WorkflowTriggerListResponse,
    WebhookTriggerResponse,
)


@@ -57,7 +91,7 @@ class WebhookTriggerApi(Resource):
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(webhook_trigger_model)
    @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
    def get(self, app_model: App):
        """Get webhook trigger for a node"""
        args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore

@@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
        if not webhook_trigger:
            raise NotFound("Webhook trigger not found for this node")

        return webhook_trigger
        return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/triggers")

@@ -89,7 +123,7 @@ class AppTriggersApi(Resource):
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(triggers_list_model)
    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
    def get(self, app_model: App):
        """Get app triggers list"""
        assert isinstance(current_user, Account)

@@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
        else:
            trigger.icon = ""  # type: ignore

        return {"data": triggers}
        return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
            mode="json"
        )


@console_ns.route("/apps/<uuid:app_id>/trigger-enable")

@@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
    @account_initialization_required
    @edit_permission_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(trigger_model)
    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
    def post(self, app_model: App):
        """Update app trigger (enable/disable)"""
        args = ParserEnable.model_validate(console_ns.payload)

@@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
        else:
            trigger.icon = ""  # type: ignore

        return trigger
        return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")

@@ -1,3 +1,5 @@
from typing import Any

from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator

@@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):

class ActivationCheckResponse(BaseModel):
    is_valid: bool = Field(description="Whether token is valid")
    data: dict | None = Field(default=None, description="Activation data if valid")
    data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")


class ActivationResponse(BaseModel):

@@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):

        if not isinstance(doc_metadata, dict):
            raise ValueError("doc_metadata must be a dictionary.")
        metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
        metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])

        document.doc_metadata = {}
        if doc_type == "others":

@@ -1,13 +1,13 @@
|
||||
from flask_restx import Resource, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@@ -18,39 +18,92 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit testing completed successfully",
|
||||
model=console_ns.models[HitTestingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app_model = get_or_create_model("InstalledAppInfo", app_fields)
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
installed_app_fields_copy = installed_app_fields.copy()
|
||||
installed_app_fields_copy["app"] = fields.Nested(app_model)
|
||||
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
|
||||
|
||||
installed_app_list_fields_copy = installed_app_list_fields.copy()
|
||||
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
|
||||
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
|
||||
def _safe_primitive(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool, datetime)):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class InstalledAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_like(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class InstalledAppResponse(ResponseModel):
|
||||
id: str
|
||||
app: InstalledAppInfoResponse
|
||||
app_owner_tenant_id: str
|
||||
is_pinned: bool
|
||||
last_used_at: int | None = None
|
||||
editable: bool
|
||||
uninstallable: bool
|
||||
|
||||
@field_validator("app", mode="before")
|
||||
@classmethod
|
||||
def _normalize_app(cls, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return {
|
||||
"id": _safe_primitive(getattr(value, "id", "")) or "",
|
||||
"name": _safe_primitive(getattr(value, "name", None)),
|
||||
"mode": _safe_primitive(getattr(value, "mode", None)),
|
||||
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
|
||||
"icon": _safe_primitive(getattr(value, "icon", None)),
|
||||
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
|
||||
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
|
||||
}
|
||||
|
||||
@field_validator("last_used_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class InstalledAppListResponse(ResponseModel):
|
||||
installed_apps: list[InstalledAppResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
InstalledAppCreatePayload,
|
||||
InstalledAppUpdatePayload,
|
||||
InstalledAppsListQuery,
|
||||
InstalledAppInfoResponse,
|
||||
InstalledAppResponse,
|
||||
InstalledAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps")
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
|
||||
)
|
||||
)
|
||||
|
||||
return {"installed_apps": installed_app_list}
|
||||
return InstalledAppListResponse.model_validate(
|
||||
{"installed_apps": installed_app_list}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@@ -1,66 +1,83 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
app_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"mode": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"icon_background": fields.String,
|
||||
}
|
||||
|
||||
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
|
||||
|
||||
recommended_app_fields = {
|
||||
"app": fields.Nested(app_model, attribute="app"),
|
||||
"app_id": fields.String,
|
||||
"description": fields.String(attribute="description"),
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
|
||||
|
||||
recommended_app_list_fields = {
|
||||
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
|
||||
"categories": fields.List(fields.String),
|
||||
}
|
||||
|
||||
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
|
||||
|
||||
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RecommendedAppsQuery.__name__,
|
||||
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
class RecommendedAppInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
mode: str | None = None
|
||||
icon: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_enum_like(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
@field_validator("mode", "icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> str | None:
|
||||
return cls._normalize_enum_like(value)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class RecommendedAppResponse(ResponseModel):
|
||||
app: RecommendedAppInfoResponse | None = None
|
||||
app_id: str
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
category: str | None = None
|
||||
position: int | None = None
|
||||
is_listed: bool | None = None
|
||||
can_trial: bool | None = None
|
||||
|
||||
|
||||
class RecommendedAppListResponse(ResponseModel):
|
||||
recommended_apps: list[RecommendedAppResponse]
|
||||
categories: list[str]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RecommendedAppsQuery,
|
||||
RecommendedAppInfoResponse,
|
||||
RecommendedAppResponse,
|
||||
RecommendedAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_model)
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
@@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
|
||||
@@ -1,15 +1,18 @@
from datetime import datetime
from typing import Any

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator

from constants import HIDDEN_VALUE
from fields.api_based_extension_fields import api_based_extension_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService

from ..common.schema import register_schema_models
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required

@@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
api_key: str = Field(description="API key for authentication")


register_schema_models(console_ns, APIBasedExtensionPayload)
class CodeBasedExtensionResponse(ResponseModel):
module: str = Field(description="Module name")
data: Any = Field(description="Extension data")


api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
def _mask_api_key(api_key: str) -> str:
if not api_key:
return api_key
if len(api_key) <= 8:
return api_key[0] + "******" + api_key[-1]
return api_key[:3] + "******" + api_key[-3:]

api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))

def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value


class APIBasedExtensionResponse(ResponseModel):
id: str
name: str
api_endpoint: str
api_key: str
created_at: int | None = None

@field_validator("api_key", mode="before")
@classmethod
def _normalize_api_key(cls, value: str) -> str:
return _mask_api_key(value)

@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
console_ns.schema_model(
"APIBasedExtensionListResponse",
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")


@console_ns.route("/code-based-extension")
@@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
console_ns.models[CodeBasedExtensionResponse.__name__],
)
@setup_required
@login_required
@@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
def get(self):
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
return CodeBasedExtensionResponse(
module=query.module,
data=CodeBasedExtensionService.get_code_based_extension(query.module),
).model_dump(mode="json")


@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", api_based_extension_list_model)
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]

@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
@@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)

return APIBasedExtensionService.save(extension_data)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))


@console_ns.route("/api-based-extension/<uuid:id>")
@@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", api_based_extension_model)
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()

return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)

@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key

return APIBasedExtensionService.save(extension_data_from_db)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))

@console_ns.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension")

@@ -35,22 +35,24 @@ def plugin_permission_required(
return view(*args, **kwargs)

if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.install_permission:
case TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
case TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.EVERYONE:
pass

if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.debug_permission:
case TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
case TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.EVERYONE:
pass

return view(*args, **kwargs)


@@ -5,7 +5,7 @@ from typing import Any, Literal

import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select

@@ -37,9 +37,10 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@@ -178,17 +179,57 @@ def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")


integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value

integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},

class AccountIntegrateResponse(ResponseModel):
provider: str
created_at: int | None = None
is_bound: bool
link: str | None = None

@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class AccountIntegrateListResponse(ResponseModel):
data: list[AccountIntegrateResponse]


class EducationVerifyResponse(ResponseModel):
token: str | None = None


class EducationStatusResponse(ResponseModel):
result: bool | None = None
is_student: bool | None = None
expire_at: int | None = None
allow_refresh: bool | None = None

@field_validator("expire_at", mode="before")
@classmethod
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class EducationAutocompleteResponse(ResponseModel):
data: list[str] = Field(default_factory=list)
curr_page: int | None = None
has_next: bool | None = None


register_schema_models(
console_ns,
AccountIntegrateResponse,
AccountIntegrateListResponse,
EducationVerifyResponse,
EducationStatusResponse,
EducationAutocompleteResponse,
)


@@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()

@@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource):
}
)

return {"data": integrate_data}
return AccountIntegrateListResponse(
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
).model_dump(mode="json")


@console_ns.route("/account/delete/verify")
@@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):

@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()

return BillingService.EducationIdentity.verify(account.id, account.email)
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")


@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}

@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@@ -491,37 +525,33 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()

res = BillingService.EducationIdentity.status(account.id)
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return res
return EducationStatusResponse.model_validate(res).model_dump(mode="json")


@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}

@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
def get(self):
payload = request.args.to_dict(flat=True)
args = EducationAutocompleteQuery.model_validate(payload)

return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
return EducationAutocompleteResponse.model_validate(
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
).model_dump(mode="json")


@console_ns.route("/account/change-email")

@@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict[str, Any]


console_ns.schema_model(

@@ -1,8 +1,9 @@
import logging
from datetime import datetime

from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized

@@ -26,6 +27,7 @@ from controllers.console.wraps import (
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
@@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
name: str


class TenantInfoResponse(ResponseModel):
id: str
name: str | None = None
plan: str | None = None
status: str | None = None
created_at: int | None = None
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None

@field_validator("plan", "status", "trial_end_reason", mode="before")
@classmethod
def _normalize_enum_like(cls, value):
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))

@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None):
if isinstance(value, datetime):
return int(value.timestamp())
return value


def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
reg(TenantInfoResponse)

provider_fields = {
"provider_name": fields.String,
@@ -180,7 +214,7 @@ class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@@ -200,7 +234,13 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")

return WorkspaceService.get_tenant_info(tenant), 200
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)


@console_ns.route("/workspaces/switch")

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Any, Literal

from flask import request
from flask_restx import Resource
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
@@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
value: Any


class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None

@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:
return serialize_value_type(value)
except (AttributeError, TypeError, ValueError):
pass

try:
return serialize_value_type({"value_type": value})
except (AttributeError, TypeError, ValueError):
value_attr = getattr(value, "value", None)
if value_attr is not None:
return str(value_attr)
return str(value)

@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value


class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[ConversationVariableResponse]


register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
ConversationVariableResponse,
ConversationVariableInfiniteScrollPaginationResponse,
)


@@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
404: "Conversation not found",
}
)
@service_api_ns.response(
200,
"Variables retrieved successfully",
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.

@@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None

try:
return ConversationService.get_conversational_variable(
pagination = ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
pagination, from_attributes=True
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

@@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
404: "Conversation or variable not found",
}
)
@service_api_ns.response(
200,
"Variable updated successfully",
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.

@@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})

try:
return ConversationService.update_conversation_variable(
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:

@@ -1,13 +1,15 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal

from dateutil.parser import isoparse
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource, fields
from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

@@ -33,9 +35,10 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)


def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value


def _enum_value(value):
return getattr(value, "value", value)


class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value
return _enum_value(obj.status)


class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED:
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:
return {}

outputs = obj.outputs_dict
return outputs or {}


workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}
class WorkflowRunResponse(ResponseModel):
id: str
workflow_id: str
status: str
inputs: dict | list | str | int | float | bool | None = None
outputs: dict = Field(default_factory=dict)
error: str | None = None
total_steps: int | None = None
total_tokens: int | None = None
created_at: int | None = None
finished_at: int | None = None
elapsed_time: float | int | None = None

@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | int | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None

@field_validator("status", "triggered_from", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)

@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None

@field_validator("created_from", "created_by_role", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)

@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]


register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)


def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,
"workflow_id": workflow_run.workflow_id,
"status": status,
"inputs": workflow_run.inputs,
"outputs": outputs,
"error": workflow_run.error,
"total_steps": workflow_run.total_steps,
"total_tokens": workflow_run.total_tokens,
"created_at": workflow_run.created_at,
"finished_at": workflow_run.finished_at,
"elapsed_time": workflow_run.elapsed_time,
}
).model_dump(mode="json")


def _serialize_workflow_log_pagination(pagination) -> dict:
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")


@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
@service_api_ns.response(
200,
"Workflow run details retrieved successfully",
service_api_ns.models[WorkflowRunResponse.__name__],
)
def get(self, app_model: App, workflow_run_id: str):
"""Get a workflow task running detail.

@@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
)
if not workflow_run:
raise NotFound("Workflow run not found.")
return workflow_run
return _serialize_workflow_run(workflow_run)


@service_api_ns.route("/workflows/run")
@@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
@service_api_ns.response(
200,
"Logs retrieved successfully",
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
def get(self, app_model: App):
"""Get workflow app logs.

@@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)

return workflow_app_log_pagination
return _serialize_workflow_log_pagination(workflow_app_log_pagination)

@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Union
from typing import Any, Union

from graphon.model_runtime.entities.llm_entities import LLMResultChunk

@@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None

@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
features: list[AgentFeature] | None = None
meta_version: str | None = None
# pydantic configs

@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:

@classmethod
def validate_and_set_defaults(
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> tuple[dict[str, Any], list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}


@@ -41,7 +41,7 @@ class ModelConfigManager:
)

@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for model config


@@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_full_response(response)
else:

def _generate_full_response() -> Generator[dict | str, Any, None]:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_full_response(response)

return _generate_full_response()
@@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:

def _generate_simple_response() -> Generator[dict | str, Any, None]:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_simple_response(response)

return _generate_simple_response()
@@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError

@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError

@classmethod

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

@@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):

def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION:
# not save log for debugging
return

if not workflow_run_id:
return

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast

import httpx

@@ -14,7 +14,7 @@ class APIBasedExtensionRequestor:
self.api_endpoint = api_endpoint
self.api_key = api_key

def request(self, point: APIBasedExtensionPoint, params: dict):
def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]:
"""
Request the api.

@@ -49,4 +49,4 @@ class APIBasedExtensionRequestor:
if response.status_code != 200:
raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")

return cast(dict, response.json())
return cast(dict[str, Any], response.json())

@@ -21,8 +21,8 @@ class ExtensionModule(StrEnum):
class ModuleExtension(BaseModel):
extension_class: Any | None = None
name: str
label: dict | None = None
form_schema: list | None = None
label: dict[str, Any] | None = None
form_schema: list[dict[str, Any]] | None = None
builtin: bool = True
position: int | None = None


@@ -13,7 +13,7 @@ class ExternalDataToolFactory:
)

@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.


@@ -1,3 +1,5 @@
from typing import Any

from flask import Flask
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel
@@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota):

class HostingProvider(BaseModel):
enabled: bool = False
credentials: dict | None = None
credentials: dict[str, Any] | None = None
quota_unit: QuotaUnit | None = None
quotas: list[HostingQuota] = []


@@ -737,7 +737,7 @@ class IndexingRunner:
def _update_document_index_status(
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: dict[Any, Any] | None = None,
extra_update_params: Mapping[Any, Any] | None = None,
):
"""
Update the document indexing status.
@@ -764,7 +764,7 @@ class IndexingRunner:
db.session.commit()

@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]):
"""
Update the document segment by document id.
"""

@@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, TypedDict, cast
from typing import Any, Protocol, TypedDict, cast

import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@@ -533,7 +533,7 @@ class LLMGenerator:
def __instruction_modify_common(
tenant_id: str,
model_config: ModelConfig,
last_run: dict | None,
last_run: dict[str, Any] | None,
current: str | None,
error_message: str | None,
instruction: str,

@@ -202,7 +202,7 @@ def _handle_native_json_schema(
structured_output_schema: Mapping,
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
) -> dict[str, Any]:
"""
Handle structured output for models with native JSON schema support.

@@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters


def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
"""
Set the appropriate response format parameter based on model rules.

@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}


def remove_additional_properties(schema: dict[str, Any]):
def remove_additional_properties(schema: dict[str, Any]) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
@@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]):
remove_additional_properties(item)


def convert_boolean_to_string(schema: dict):
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
"""
Convert boolean type specifications to string in JSON schema.


@@ -1,6 +1,7 @@
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from typing import Any

from graphon.model_runtime.entities.llm_entities import (
LLMResult,
@@ -226,7 +227,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
# invoke model
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)

def handle() -> Generator[dict, None, None]:
def handle() -> Generator[dict[str, Any], None, None]:
for chunk in response:
yield {"result": hexlify(chunk).decode("utf-8")}


@@ -313,7 +313,7 @@ class SimplePromptTransform(PromptTransform):

return prompt_message

def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str):
def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]:
"""
Get simple prompt rule.
:param app_mode: app mode
@@ -325,7 +325,7 @@ class SimplePromptTransform(PromptTransform):

# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return cast(dict, prompt_file_contents[prompt_file_name])
return cast(dict[str, Any], prompt_file_contents[prompt_file_name])

# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -338,7 +338,7 @@ class SimplePromptTransform(PromptTransform):
# Store the content of the prompt file
prompt_file_contents[prompt_file_name] = content

return cast(dict, content)
return cast(dict[str, Any], content)

def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan

@@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings):

return text_embeddings

def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]

@@ -11,7 +11,7 @@ class Embeddings(ABC):
raise NotImplementedError

@abstractmethod
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
"""Embed file documents."""
raise NotImplementedError


@@ -95,9 +95,9 @@ class ExtractProcessor:
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
assert upload_file is not None, "upload_file is required"
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
@@ -113,6 +113,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
@@ -123,6 +124,7 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -149,12 +151,14 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)

@@ -174,21 +174,25 @@ class FirecrawlApp:
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
response: httpx.Response | None = None
for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
assert response is not None, "retries must be at least 1"
return response

def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
response: httpx.Response | None = None
for attempt in range(retries):
response = httpx.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
assert response is not None, "retries must be at least 1"
return response

def _handle_error(self, response, action):

@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import NamedTuple, Union
from typing import Any, NamedTuple, Union


@dataclass
@@ -10,7 +10,7 @@ class ReactAction:

tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
tool_input: Union[str, dict[str, Any]]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
@@ -19,7 +19,7 @@ class ReactAction:
class ReactFinish(NamedTuple):
"""The final return value of an ReactFinish."""

return_values: dict
return_values: dict[str, Any]
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""

@@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Union
from typing import Any, Union

from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@@ -139,7 +139,7 @@ class ReactMultiDatasetRouter:

def _invoke_llm(
self,
completion_param: dict,
completion_param: dict[str, Any],
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str],

@@ -63,7 +63,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def split_text(self, text: str) -> list[str]:
"""Split text into multiple components."""

def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]:
def create_documents(self, texts: list[str], metadatas: list[dict[str, Any]] | None = None) -> list[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []

@@ -254,7 +254,7 @@ def resolve_dify_schema_refs(
return resolver.resolve(schema)


def _remove_metadata_fields(schema: dict) -> dict:
def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
"""
Remove metadata fields from schema that shouldn't be included in resolved output


@@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field

@@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel):
# icon
icon: str | None = None
# openapi operation
openapi: dict
openapi: dict[str, Any]
# output schema
output_schema: Mapping[str, object] = Field(default_factory=dict)

@@ -47,7 +47,7 @@ class ToolEngine:
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict],
tool_parameters: Union[str, dict[str, Any]],
user_id: str,
tenant_id: str,
message: Message,
@@ -85,7 +85,7 @@ class ToolEngine:
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}

def message_callback(
invocation_meta_dict: dict[str, Any],
invocation_meta_dict: dict[str, ToolInvokeMeta],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:

@@ -1,4 +1,5 @@
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, sessionmaker

from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -19,10 +20,18 @@ class ToolLabelManager:
return list(set(tool_labels))

@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
def update_tool_labels(
cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
) -> None:
"""
Update tool labels

:param controller: tool provider controller
:param labels: list of tool labels
:param session: database session, if None, a new session will be created
:return: None
"""

labels = cls.filter_tool_labels(labels)

if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
@@ -30,26 +39,46 @@ class ToolLabelManager:
else:
raise ValueError("Unsupported tool type")

if session is not None:
cls._update_tool_labels_logics(session, provider_id, controller, labels)
else:
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
cls._update_tool_labels_logics(_session, provider_id, controller, labels)
|
||||
|
||||
@classmethod
|
||||
def _update_tool_labels_logics(
|
||||
cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Update tool labels logics
|
||||
|
||||
:param session: database session
|
||||
:param provider_id: tool provider ID
|
||||
:param controller: tool provider controller
|
||||
:param labels: list of tool labels
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# delete old labels
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
_ = session.execute(
|
||||
delete(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
|
||||
)
|
||||
)
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
|
||||
:param controller: tool provider controller
|
||||
:return: list of tool labels (str)
|
||||
"""
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
@@ -60,9 +89,11 @@ class ToolLabelManager:
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
return list(labels)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
labels: list[str] = list(_session.scalars(stmt).all())
|
||||
|
||||
return labels
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
@@ -78,16 +109,22 @@ class ToolLabelManager:
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
provider_ids: list[str] = []
|
||||
provider_types: set[str] = set()
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
provider_types.add(controller.provider_type)
|
||||
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
labels: list[ToolLabelBinding] = []
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
stmt = select(ToolLabelBinding).where(
|
||||
ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
|
||||
)
|
||||
labels = list(_session.scalars(stmt).all())
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
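A note on the optional session parameter added to update_tool_labels above: when a caller passes its own Session, the label update joins that caller's transaction; when omitted, the method opens and commits a short-lived transaction itself, preserving the old behavior. A hedged, generic sketch of this pattern; the engine and work callback are illustrative, not Dify's actual helpers:

from collections.abc import Callable

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")  # placeholder engine


def run_in_session(work: Callable[[Session], None], session: Session | None = None) -> None:
    # Reuse the caller's session so the work commits (or rolls back) together with the caller.
    if session is not None:
        work(session)
        return
    # Otherwise open and commit a transaction of our own.
    with sessionmaker(engine).begin() as own_session:
        work(own_session)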
@@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@@ -39,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset_id: str
|
||||
user_id: str | None = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
inputs: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
|
||||
@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
|
||||
session.expunge(app)
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]:
|
||||
"""
|
||||
transform the tool parameters
|
||||
|
||||
@@ -355,7 +355,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return result, files
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]):
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
match transfer_method:
|
||||
|
||||
@@ -329,7 +329,7 @@ class EnterpriseMetricHandler:
|
||||
return
|
||||
|
||||
include_content = exporter.include_content
|
||||
attrs: dict = {
|
||||
attrs: dict[str, Any] = {
|
||||
"dify.message.id": payload.get("message_id"),
|
||||
"dify.tenant_id": envelope.tenant_id,
|
||||
"dify.event.id": envelope.event_id,
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import normalize_redis_key_prefix
|
||||
|
||||
|
||||
class _CelerySentinelKwargsDict(TypedDict):
|
||||
@@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
|
||||
password: str | None
|
||||
|
||||
|
||||
class CelerySentinelTransportDict(TypedDict):
|
||||
class CelerySentinelTransportDict(TypedDict, total=False):
|
||||
master_name: str | None
|
||||
sentinel_kwargs: _CelerySentinelKwargsDict
|
||||
global_keyprefix: str
|
||||
|
||||
|
||||
class CelerySSLOptionsDict(TypedDict):
|
||||
@@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
|
||||
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
|
||||
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
|
||||
transport_options: CelerySentinelTransportDict | dict[str, Any]
|
||||
if dify_config.CELERY_USE_SENTINEL:
|
||||
return CelerySentinelTransportDict(
|
||||
transport_options = CelerySentinelTransportDict(
|
||||
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
sentinel_kwargs=_CelerySentinelKwargsDict(
|
||||
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
password=dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
),
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
transport_options = {}
|
||||
|
||||
global_keyprefix = get_celery_redis_global_keyprefix()
|
||||
if global_keyprefix:
|
||||
transport_options["global_keyprefix"] = global_keyprefix
|
||||
|
||||
return transport_options
|
||||
|
||||
|
||||
def get_celery_redis_global_keyprefix() -> str | None:
|
||||
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
|
||||
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
|
||||
if not normalized_prefix:
|
||||
return None
|
||||
return f"{normalized_prefix}:"
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> Celery:
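The helper above turns REDIS_KEY_PREFIX into the broker-level prefix by normalizing it and appending a trailing colon; kombu's Redis transport then applies global_keyprefix verbatim to every queue key it writes. A rough sketch of how that option is typically wired into a Celery app; the broker URL and prefix value are placeholders:

from celery import Celery

celery_app = Celery("dify", broker="redis://localhost:6379/1")  # placeholder broker URL

prefix = "dify"  # stand-in for a non-empty REDIS_KEY_PREFIX
transport_options: dict[str, str] = {}
if prefix.strip():
    # Mirrors get_celery_redis_global_keyprefix(): normalized prefix plus ":" separator.
    transport_options["global_keyprefix"] = f"{prefix.strip()}:"

celery_app.conf.broker_transport_options = transport_options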
@@ -3,7 +3,7 @@ import logging
|
||||
import ssl
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
@@ -18,17 +18,26 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import (
|
||||
normalize_redis_key_prefix,
|
||||
serialize_redis_name,
|
||||
serialize_redis_name_arg,
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_normalize_redis_key_prefix = normalize_redis_key_prefix
|
||||
_serialize_redis_name = serialize_redis_name
|
||||
_serialize_redis_name_arg = serialize_redis_name_arg
|
||||
_serialize_redis_name_args = serialize_redis_name_args
|
||||
|
||||
|
||||
class RedisClientWrapper:
|
||||
"""
|
||||
A wrapper class for the Redis client that addresses the issue where the global
|
||||
@@ -59,68 +68,148 @@ class RedisClientWrapper:
|
||||
if self._client is None:
|
||||
self._client = client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Type hints for IDE support and static analysis
|
||||
# These are not executed at runtime but provide type information
|
||||
def get(self, name: str | bytes) -> Any: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any: ...
|
||||
def delete(self, *names: str | bytes) -> Any: ...
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Lock: ...
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
def _require_client(self) -> redis.Redis | RedisCluster:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client is not initialized. Call init_app first.")
|
||||
return getattr(self._client, item)
|
||||
return self._client
|
||||
|
||||
def _get_prefix(self) -> str:
|
||||
return dify_config.REDIS_KEY_PREFIX
|
||||
|
||||
def get(self, name: str | bytes) -> Any:
|
||||
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any:
|
||||
return self._require_client().set(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
value,
|
||||
ex=ex,
|
||||
px=px,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
keepttl=keepttl,
|
||||
get=get,
|
||||
exat=exat,
|
||||
pxat=pxat,
|
||||
)
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
|
||||
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
|
||||
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any:
|
||||
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
|
||||
|
||||
def delete(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any:
|
||||
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
|
||||
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().expire(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
time,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def exists(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def ttl(self, name: str | bytes) -> Any:
|
||||
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def getdel(self, name: str | bytes) -> Any:
|
||||
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Any:
|
||||
return self._require_client().lock(
|
||||
_serialize_redis_name(name, self._get_prefix()),
|
||||
timeout=timeout,
|
||||
sleep=sleep,
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
thread_local=thread_local,
|
||||
)
|
||||
|
||||
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
|
||||
|
||||
def hgetall(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
|
||||
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
|
||||
|
||||
def hlen(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().zadd(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
cast(Any, mapping),
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
ch=ch,
|
||||
incr=incr,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
|
||||
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
|
||||
|
||||
def zcard(self, name: str | bytes) -> Any:
|
||||
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def pubsub(self) -> PubSub:
|
||||
return self._require_client().pubsub()
|
||||
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
|
||||
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
return getattr(self._require_client(), item)
|
||||
|
||||
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
api/extensions/redis_names.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from configs import dify_config


def normalize_redis_key_prefix(prefix: str | None) -> str:
    """Normalize the configured Redis key prefix for consistent runtime use."""
    if prefix is None:
        return ""
    return prefix.strip()


def get_redis_key_prefix() -> str:
    """Read and normalize the current Redis key prefix from config."""
    return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)


def serialize_redis_name(name: str, prefix: str | None = None) -> str:
    """Convert a logical Redis name into the physical name used in Redis."""
    normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
    if not normalized_prefix:
        return name
    return f"{normalized_prefix}:{name}"


def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
    """Prefix string Redis names while preserving bytes inputs unchanged."""
    if isinstance(name, bytes):
        return name
    return serialize_redis_name(name, prefix)


def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
    return tuple(serialize_redis_name_arg(name, prefix) for name in names)
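To make the prefixing behavior concrete, a small usage sketch of the helpers above; the prefix values are illustrative, and the import assumes the module is loaded inside a configured Dify environment:

from extensions.redis_names import serialize_redis_name, serialize_redis_name_arg

# A non-empty prefix is prepended with a ":" separator.
assert serialize_redis_name("workflow_lock:123", "dify-prod") == "dify-prod:workflow_lock:123"

# An empty (or unset) prefix preserves the historical, unprefixed key names.
assert serialize_redis_name("workflow_lock:123", "") == "workflow_lock:123"

# Bytes names pass through untouched so binary-safe callers keep working.
assert serialize_redis_name_arg(b"raw-key", "dify-prod") == b"raw-key"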
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import opendal
|
||||
from dotenv import dotenv_values
|
||||
@@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
|
||||
if key.startswith(config_prefix):
|
||||
kwargs[key[len(config_prefix) :].lower()] = value
|
||||
|
||||
file_env_vars: dict = dotenv_values(env_file_path) or {}
|
||||
file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {}
|
||||
for key, value in file_env_vars.items():
|
||||
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
|
||||
kwargs[key[len(config_prefix) :].lower()] = value
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@@ -32,12 +33,13 @@ class Topic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.publish(self._topic, payload)
|
||||
self._client.publish(self._redis_topic, payload)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@@ -46,7 +48,7 @@ class Topic:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@@ -30,12 +31,13 @@ class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@@ -44,7 +46,7 @@ class ShardedTopic:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
@@ -35,7 +36,7 @@ class StreamsTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = f"stream:{topic}"
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self.max_length = 5000
|
||||
|
||||
|
||||
@@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
|
||||
timeout=self._ttl_seconds,
|
||||
thread_local=False,
|
||||
)
|
||||
acquired = bool(self._lock.acquire(*args, **kwargs))
|
||||
lock = self._lock
|
||||
if lock is None:
|
||||
raise RuntimeError("Redis lock initialization failed.")
|
||||
acquired = bool(lock.acquire(*args, **kwargs))
|
||||
self._acquired = acquired
|
||||
if acquired:
|
||||
self._start_heartbeat()
|
||||
|
||||
@@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw):
|
||||
obj = obj["app"]
|
||||
|
||||
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(obj.icon)
|
||||
return build_icon_url(obj.icon_type, obj.icon)
|
||||
return None
|
||||
|
||||
|
||||
def build_icon_url(icon_type: Any, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
|
||||
from models.model import IconType
|
||||
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
|
||||
class AvatarUrlField(fields.Raw):
|
||||
def output(self, key, obj, **kwargs):
|
||||
if obj is None:
|
||||
|
||||
@@ -1551,7 +1551,7 @@ class PipelineBuiltInTemplate(TypeBase):
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
@@ -1584,7 +1584,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
@@ -1657,7 +1657,7 @@ class DocumentPipelineExecutionLog(TypeBase):
|
||||
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
datasource_info: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
||||
@@ -1007,7 +1007,7 @@ class OAuthProviderApp(TypeBase):
|
||||
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
|
||||
app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
|
||||
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
|
||||
scope: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
@@ -2495,7 +2495,7 @@ class TraceAppConfig(TypeBase):
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
|
||||
tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
@@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase):
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
|
||||
|
||||
|
||||
class DatasourceProvider(TypeBase):
|
||||
@@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase):
|
||||
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
|
||||
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1)
|
||||
@@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
|
||||
client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@@ -25,7 +25,7 @@ class DataSourceOauthBinding(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: list[dict[Any, Any]] | None = None,
|
||||
index_params: dict | None = None,
|
||||
index_params: dict[str, Any] | None = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from numpy import ndarray
|
||||
@@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase):
|
||||
__tablename__: str
|
||||
id: Mapped[UUID]
|
||||
text: Mapped[str]
|
||||
meta: Mapped[dict]
|
||||
meta: Mapped[dict[str, Any]]
|
||||
vector: Mapped[ndarray]
|
||||
|
||||
@@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
|
||||
primary_key=True,
|
||||
)
|
||||
text: Mapped[str]
|
||||
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
|
||||
meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB)
|
||||
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
|
||||
|
||||
self._table = _Table
|
||||
|
||||
@@ -136,6 +136,9 @@ dify-vdb-weaviate = { workspace = true }
|
||||
[tool.uv]
|
||||
default-groups = ["storage", "tools", "vdb-all"]
|
||||
package = false
|
||||
override-dependencies = [
|
||||
"pyarrow>=18.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ class RetrievalModel(BaseModel):
|
||||
|
||||
class MetaDataConfig(BaseModel):
|
||||
doc_type: str
|
||||
doc_metadata: dict
|
||||
doc_metadata: dict[str, Any]
|
||||
|
||||
|
||||
class KnowledgeConfig(BaseModel):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
|
||||
@@ -168,7 +169,9 @@ class ModelProviderService:
|
||||
model_name=model,
|
||||
)
|
||||
|
||||
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
|
||||
def get_provider_credential(
|
||||
self, tenant_id: str, provider: str, credential_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
get provider credentials.
|
||||
|
||||
@@ -180,7 +183,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id)
|
||||
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate provider credentials before saving.
|
||||
|
||||
@@ -192,7 +195,7 @@ class ModelProviderService:
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||
self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
@@ -210,7 +213,7 @@ class ModelProviderService:
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
@@ -254,7 +257,7 @@ class ModelProviderService:
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> dict | None:
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve model-specific credentials.
|
||||
|
||||
@@ -270,7 +273,9 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
|
||||
def validate_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any]
|
||||
):
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@@ -287,7 +292,13 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
create and save model credentials.
|
||||
@@ -314,7 +325,7 @@ class ModelProviderService:
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
|
||||
@@ -104,7 +104,7 @@ class RagPipelineService:
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
|
||||
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
@@ -120,7 +120,7 @@ class RagPipelineService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
|
||||
@@ -131,7 +131,7 @@ class RagPipelineService:
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
if built_in_result is None:
|
||||
logger.warning(
|
||||
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
|
||||
@@ -142,7 +142,7 @@ class RagPipelineService:
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return customized_result
|
||||
|
||||
@classmethod
|
||||
@@ -297,7 +297,7 @@ class RagPipelineService:
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
graph: dict,
|
||||
graph: dict[str, Any],
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[VariableBase],
|
||||
@@ -467,7 +467,9 @@ class RagPipelineService:
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
|
||||
def get_default_block_config(
|
||||
self, node_type: str, filters: dict[str, Any] | None = None
|
||||
) -> Mapping[str, object] | None:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
@@ -500,7 +502,7 @@ class RagPipelineService:
|
||||
return default_config
|
||||
|
||||
def run_draft_workflow_node(
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account
|
||||
) -> WorkflowNodeExecutionModel | None:
|
||||
"""
|
||||
Run draft workflow node
|
||||
@@ -582,7 +584,7 @@ class RagPipelineService:
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
account: Account,
|
||||
datasource_type: str,
|
||||
is_published: bool,
|
||||
@@ -749,7 +751,7 @@ class RagPipelineService:
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
account: Account,
|
||||
datasource_type: str,
|
||||
is_published: bool,
|
||||
@@ -979,7 +981,7 @@ class RagPipelineService:
|
||||
return workflow_node_execution
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
@@ -1099,7 +1101,9 @@ class RagPipelineService:
|
||||
]
|
||||
return datasource_provider_variables
|
||||
|
||||
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
|
||||
def get_rag_pipeline_paginate_workflow_runs(
|
||||
self, pipeline: Pipeline, args: dict[str, Any]
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get debug workflow run list
|
||||
Only return triggered_from == debugging
|
||||
@@ -1169,7 +1173,7 @@ class RagPipelineService:
|
||||
return list(node_executions)
|
||||
|
||||
@classmethod
|
||||
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
|
||||
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
|
||||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
@@ -1259,7 +1263,7 @@ class RagPipelineService:
|
||||
)
|
||||
return node_exec
|
||||
|
||||
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
|
||||
def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account):
|
||||
"""
|
||||
Set datasource variables
|
||||
"""
|
||||
@@ -1346,7 +1350,7 @@ class RagPipelineService:
|
||||
)
|
||||
return workflow_node_execution_db_model
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict:
|
||||
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
|
||||
# Query active recommended plugins
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
if type and type != "all":
|
||||
|
||||
@@ -139,62 +139,82 @@ class WorkflowToolManageService:
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
existing_workflow_tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
# query if the name exists for other tools
|
||||
existing_workflow_tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if the name exists raise error
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
# query the workflow tool provider
|
||||
workflow_tool_provider: WorkflowToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
workflow_tool_provider = _session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if not found raise error
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App | None = db.session.scalar(
|
||||
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
# query the app
|
||||
app: App | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
app = _session.scalar(
|
||||
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
# if not found raise error
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
# query the workflow
|
||||
workflow: Workflow | None = app.workflow
|
||||
|
||||
# if not found raise error
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
# check if workflow configuration is synced
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
with sessionmaker(db.engine).begin() as _session:
|
||||
_session.add(workflow_tool_provider)
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
# update workflow tool provider
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
||||
db.session.commit()
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
|
||||
labels,
|
||||
session=_session,
|
||||
)
|
||||
|
||||
return {"result": "success"}
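The service refactor above replaces direct db.session reads with short-lived sessions created through sessionmaker(db.engine, expire_on_commit=False), so objects loaded inside the block remain readable after the transaction closes. A hedged, self-contained sketch of that read pattern; the model and engine here are placeholders, not the real WorkflowToolProvider:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Provider(Base):  # placeholder standing in for WorkflowToolProvider
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# expire_on_commit=False keeps loaded attribute values usable after the block commits.
with sessionmaker(engine, expire_on_commit=False).begin() as session:
    provider = session.scalar(select(Provider).where(Provider.name == "demo").limit(1))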
@@ -241,8 +241,8 @@ class WorkflowService:
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
graph: dict[str, Any],
|
||||
features: dict[str, Any],
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[VariableBase],
|
||||
@@ -576,7 +576,7 @@ class WorkflowService:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
|
||||
|
||||
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
|
||||
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None:
|
||||
"""
|
||||
Validate load balancing credentials for a workflow node.
|
||||
|
||||
@@ -1214,7 +1214,7 @@ class WorkflowService:
|
||||
return variable_pool
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run free workflow node
|
||||
@@ -1361,7 +1361,7 @@ class WorkflowService:
|
||||
node_execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node_execution.error = error
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App:
|
||||
"""
|
||||
Basic mode of chatbot app(expert mode) to workflow
|
||||
Completion App to Workflow App
|
||||
@@ -1421,7 +1421,7 @@ class WorkflowService:
|
||||
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
self._validate_human_input_node_data(node_data)
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
def validate_features_structure(self, app_model: App, features: dict[str, Any]):
|
||||
match app_model.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
@@ -1434,7 +1434,7 @@ class WorkflowService:
|
||||
case _:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
||||
def _validate_human_input_node_data(self, node_data: dict) -> None:
|
||||
def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate HumanInput node data format.
|
||||
|
||||
@@ -1452,7 +1452,7 @@ class WorkflowService:
|
||||
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
@@ -33,6 +33,7 @@ REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
|
||||
@@ -432,7 +432,7 @@ class TestWorkflowAppLogEndpoints:
|
||||
monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker)
|
||||
|
||||
def fake_get_paginate(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_log_module.WorkflowAppService,
|
||||
@@ -443,7 +443,7 @@ class TestWorkflowAppLogEndpoints:
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
@@ -608,7 +608,8 @@ class TestWorkflowTriggerEndpoints:
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is trigger
|
||||
assert isinstance(result, dict)
|
||||
assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())
|
||||
|
||||
|
||||
class TestWrapsEndpoints:
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.app.conversation import _get_conversation
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import AppMode, Conversation
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def test_get_conversation_mark_read_keeps_updated_at_unchanged(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
original_updated_at = datetime(2026, 2, 8, 0, 0, 0)
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
name="read timestamp test",
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode=AppMode.CHAT,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account.id,
|
||||
updated_at=original_updated_at,
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
read_at = datetime(2026, 2, 9, 0, 0, 0)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.app.conversation.current_account_with_tenant",
|
||||
return_value=(account, tenant.id),
|
||||
autospec=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.app.conversation.naive_utc_now",
|
||||
return_value=read_at,
|
||||
autospec=True,
|
||||
),
|
||||
):
|
||||
loaded = _get_conversation(app, conversation.id)
|
||||
|
||||
db_session_with_containers.refresh(conversation)
|
||||
|
||||
assert loaded.id == conversation.id
|
||||
assert conversation.read_at == read_at
|
||||
assert conversation.read_account_id == account.id
|
||||
assert conversation.updated_at == original_updated_at
|
||||
|
||||
|
||||
def test_get_conversation_raises_not_found_for_missing_conversation(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.conversation.current_account_with_tenant",
|
||||
return_value=(account, tenant.id),
|
||||
autospec=True,
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_get_conversation(app, "00000000-0000-0000-0000-000000000000")
|
||||
@@ -1,28 +1,48 @@
|
||||
"""Unit tests for controllers.web.site endpoints."""
|
||||
"""Testcontainers integration tests for controllers.web.site endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
from models import Tenant, TenantStatus
|
||||
from models.model import App, AppMode, CustomizeTokenStrategy, Site
|
||||
|
||||
|
||||
def _tenant(*, status: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=status,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
|
||||
tenant = Tenant(name="test-tenant", status=status)
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
return tenant
|
||||
|
||||
|
||||
def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App:
|
||||
app_model = App(
|
||||
tenant_id=tenant_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="test-app",
|
||||
enable_site=enable_site,
|
||||
enable_api=True,
|
||||
)
|
||||
db_session.add(app_model)
|
||||
db_session.commit()
|
||||
return app_model
|
||||
|
||||
|
||||
def _site() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
def _create_site(db_session: Session, app_id: str) -> Site:
|
||||
site = Site(
|
||||
app_id=app_id,
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
@@ -31,77 +51,64 @@ def _site() -> SimpleNamespace:
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
|
||||
code=f"code-{app_id[-6:]}",
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
return site
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
@patch("controllers.web.site.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
|
||||
def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.scalar.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
# marshal_with serializes AppSiteInfo to a dict
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["app_id"] == app_model.id
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.scalar.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
def test_archived_tenant_raises_forbidden(
|
||||
self, mock_features, app: Flask, db_session_with_containers: Session
|
||||
) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.scalar.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteInfo:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = _tenant()
|
||||
site_obj = _site()
|
||||
tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
|
||||
site_obj = SimpleNamespace()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
@@ -118,7 +125,7 @@ class TestAppSiteInfo:
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
)
|
||||
site_obj = _site()
|
||||
site_obj = SimpleNamespace()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
@@ -0,0 +1,613 @@
|
||||
"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
DatasetAutoDisableLog,
|
||||
DatasetCollectionBinding,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
)
|
||||
from models.enums import DataSourceType
|
||||
from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetPermissionIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session,
|
||||
role: TenantAccountRole = TenantAccountRole.OWNER,
|
||||
) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db_session_with_containers.add_all([account, tenant])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.role = role
|
||||
account._current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_account_in_tenant(
|
||||
db_session_with_containers: Session,
|
||||
tenant: Tenant,
|
||||
role: TenantAccountRole = TenantAccountRole.EDITOR,
|
||||
) -> Account:
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.role = role
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
name: str | None = None,
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
|
||||
enable_api: bool = True,
|
||||
) -> Dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name or f"dataset-{uuid4()}",
|
||||
description="desc",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique=indexing_technique,
|
||||
created_by=created_by,
|
||||
provider="vendor",
|
||||
permission=permission,
|
||||
retrieval_model={"top_k": 2},
|
||||
)
|
||||
dataset.enable_api = enable_api
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
dataset_id: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
) -> DatasetPermission:
|
||||
permission = DatasetPermission(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
has_permission=True,
|
||||
)
|
||||
db_session_with_containers.add(permission)
|
||||
db_session_with_containers.commit()
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
dataset_id: str,
|
||||
) -> AppDatasetJoin:
|
||||
join = AppDatasetJoin(
|
||||
app_id=str(uuid4()),
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
return join
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
collection_type: str = "dataset",
|
||||
) -> DatasetCollectionBinding:
|
||||
binding = DatasetCollectionBinding(
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_name=f"collection_{uuid4().hex}",
|
||||
type=collection_type,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_auto_disable_log(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
) -> DatasetAutoDisableLog:
|
||||
log = DatasetAutoDisableLog(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
)
|
||||
db_session_with_containers.add(log)
|
||||
db_session_with_containers.commit()
|
||||
return log
|
||||
|
||||
|
||||
class TestDatasetServicePermissionsAndLifecycle:
|
||||
def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session):
|
||||
owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
|
||||
result = DatasetService.delete_dataset(str(uuid4()), user=owner)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, user=owner)
|
||||
|
||||
assert result is True
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
send_deleted_signal.assert_called_once_with(dataset)
|
||||
|
||||
def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_app_dataset_join(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
)
|
||||
|
||||
assert DatasetService.dataset_use_check(dataset.id) is True
|
||||
|
||||
def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
)
|
||||
|
||||
assert DatasetService.dataset_use_check(dataset.id) is False
|
||||
|
||||
def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
|
||||
db_session_with_containers
|
||||
)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_permission(dataset, outsider)
|
||||
|
||||
def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_permission(dataset, member)
|
||||
|
||||
def test_check_dataset_permission_rejects_partial_team_user_without_binding(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_permission(dataset, member)
|
||||
|
||||
def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session):
|
||||
creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
|
||||
db_session_with_containers,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=creator.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, creator)
|
||||
|
||||
def test_check_dataset_permission_allows_partial_team_member_with_binding(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=member.id,
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, member)
|
||||
|
||||
def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
|
||||
db_session_with_containers,
|
||||
tenant,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
|
||||
|
||||
def test_check_dataset_operator_permission_rejects_partial_team_without_binding(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
|
||||
db_session_with_containers,
|
||||
tenant,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="do not have permission"):
|
||||
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
|
||||
|
||||
def test_check_dataset_operator_permission_allows_partial_team_with_binding(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
|
||||
db_session_with_containers,
|
||||
tenant,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=operator.id,
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
|
||||
|
||||
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetService.update_dataset_api_status(str(uuid4()), True)
|
||||
|
||||
def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
enable_api=False,
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.current_user", SimpleNamespace(id=None)):
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.update_dataset_api_status(dataset.id, True)
|
||||
|
||||
def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
enable_api=False,
|
||||
)
|
||||
now = datetime(2026, 4, 14, 18, 0, 0)
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", owner),
|
||||
patch("services.dataset_service.naive_utc_now", return_value=now),
|
||||
):
|
||||
DatasetService.update_dataset_api_status(dataset.id, True)
|
||||
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.enable_api is True
|
||||
assert dataset.updated_by == owner.id
|
||||
assert dataset.updated_at == now
|
||||
|
||||
def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional"))
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", owner),
|
||||
patch("services.dataset_service.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result = DatasetService.get_dataset_auto_disable_logs(str(uuid4()))
|
||||
|
||||
assert result == {"document_ids": [], "count": 0}
|
||||
|
||||
def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_auto_disable_log(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=str(uuid4()),
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_auto_disable_log(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=str(uuid4()),
|
||||
)
|
||||
features = SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional"))
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", owner),
|
||||
patch("services.dataset_service.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result = DatasetService.get_dataset_auto_disable_logs(dataset.id)
|
||||
|
||||
assert result["count"] == 2
|
||||
assert len(result["document_ids"]) == 2
|
||||
|
||||
|
||||
class TestDatasetCollectionBindingServiceIntegration:
|
||||
def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session):
|
||||
binding = DatasetPermissionIntegrationFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="provider",
|
||||
model_name="model",
|
||||
)
|
||||
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
|
||||
|
||||
assert result.id == binding.id
|
||||
|
||||
def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session):
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model")
|
||||
|
||||
persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id)
|
||||
assert persisted is not None
|
||||
assert persisted.provider_name == "provider"
|
||||
assert persisted.model_name == "missing-model"
|
||||
assert persisted.type == "dataset"
|
||||
assert persisted.collection_name
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session):
|
||||
binding = DatasetPermissionIntegrationFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="provider",
|
||||
model_name="model",
|
||||
)
|
||||
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id)
|
||||
|
||||
assert result.id == binding.id
|
||||
|
||||
|
||||
class TestDatasetPermissionServiceIntegration:
|
||||
def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=member_a.id,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=member_b.id,
|
||||
)
|
||||
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
|
||||
assert set(result) == {member_a.id, member_b.id}
|
||||
|
||||
def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=stale_member.id,
|
||||
)
|
||||
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant.id,
|
||||
dataset.id,
|
||||
[{"user_id": member_a.id}, {"user_id": member_b.id}],
|
||||
)
|
||||
|
||||
permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
|
||||
assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id}
|
||||
|
||||
def test_check_permission_requires_dataset_editor(self):
|
||||
user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
|
||||
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="does not have permission"):
|
||||
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, [])
|
||||
|
||||
def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
|
||||
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
|
||||
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"):
|
||||
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, [])
|
||||
|
||||
def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
|
||||
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
|
||||
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
with pytest.raises(ValueError, match="Partial member list is required"):
|
||||
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, [])
|
||||
|
||||
def test_check_permission_rejects_dataset_operator_member_list_changes(self):
|
||||
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
|
||||
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
|
||||
with pytest.raises(ValueError, match="cannot change the dataset permissions"):
|
||||
DatasetPermissionService.check_permission(
|
||||
user,
|
||||
dataset,
|
||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
[{"user_id": "user-2"}],
|
||||
)
|
||||
|
||||
def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
|
||||
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
|
||||
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
|
||||
DatasetPermissionService.check_permission(
|
||||
user,
|
||||
dataset,
|
||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
[{"user_id": "user-1"}],
|
||||
)
|
||||
|
||||
def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session):
|
||||
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
|
||||
dataset = DatasetPermissionIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionIntegrationFactory.create_dataset_permission(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
account_id=member.id,
|
||||
)
|
||||
|
||||
DatasetPermissionService.clear_partial_member_list(dataset.id)
|
||||
|
||||
remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
|
||||
assert remaining == []
|
||||
@@ -0,0 +1,387 @@
|
||||
"""Testcontainers integration tests for schedule service SQL-backed behavior."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
|
||||
from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
|
||||
class ScheduleServiceIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session,
|
||||
role: TenantAccountRole = TenantAccountRole.OWNER,
|
||||
) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db_session_with_containers.add_all([account, tenant])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_schedule_plan(
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str | None = None,
|
||||
node_id: str = "start",
|
||||
cron_expression: str = "30 10 * * *",
|
||||
timezone: str = "UTC",
|
||||
next_run_at: datetime | None = None,
|
||||
) -> WorkflowSchedulePlan:
|
||||
schedule = WorkflowSchedulePlan(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id or str(uuid4()),
|
||||
node_id=node_id,
|
||||
cron_expression=cron_expression,
|
||||
timezone=timezone,
|
||||
next_run_at=next_run_at,
|
||||
)
|
||||
db_session_with_containers.add(schedule)
|
||||
db_session_with_containers.commit()
|
||||
return schedule
|
||||
|
||||
|
||||
def _cron_workflow(
|
||||
*,
|
||||
node_id: str = "start",
|
||||
cron_expression: str = "30 10 * * *",
|
||||
timezone: str = "UTC",
|
||||
):
|
||||
return SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": "trigger-schedule",
|
||||
"mode": "cron",
|
||||
"cron_expression": cron_expression,
|
||||
"timezone": timezone,
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _no_schedule_workflow():
|
||||
return SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"data": {"type": "llm"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestScheduleServiceIntegration:
|
||||
def test_create_schedule_persists_schedule(self, db_session_with_containers: Session):
|
||||
account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
expected_next_run = datetime(2026, 1, 1, 10, 30, 0)
|
||||
config = ScheduleConfig(
|
||||
node_id="start",
|
||||
cron_expression="30 10 * * *",
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.schedule_service.calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: expected_next_run,
|
||||
)
|
||||
schedule = ScheduleService.create_schedule(
|
||||
session=db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
app_id=str(uuid4()),
|
||||
config=config,
|
||||
)
|
||||
|
||||
persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id)
|
||||
assert persisted is not None
|
||||
assert persisted.tenant_id == tenant.id
|
||||
assert persisted.node_id == "start"
|
||||
assert persisted.cron_expression == "30 10 * * *"
|
||||
assert persisted.timezone == "UTC"
|
||||
assert persisted.next_run_at == expected_next_run
|
||||
|
||||
def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
cron_expression="30 10 * * *",
|
||||
timezone="UTC",
|
||||
)
|
||||
expected_next_run = datetime(2026, 1, 2, 12, 0, 0)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.schedule_service.calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: expected_next_run,
|
||||
)
|
||||
updated = ScheduleService.update_schedule(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=schedule.id,
|
||||
updates=SchedulePlanUpdate(
|
||||
cron_expression="0 12 * * *",
|
||||
timezone="America/New_York",
|
||||
),
|
||||
)
|
||||
|
||||
db_session_with_containers.refresh(updated)
|
||||
assert updated.cron_expression == "0 12 * * *"
|
||||
assert updated.timezone == "America/New_York"
|
||||
assert updated.next_run_at == expected_next_run
|
||||
|
||||
def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
initial_next_run = datetime(2026, 1, 1, 10, 0, 0)
|
||||
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
next_run_at=initial_next_run,
|
||||
)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
calls: list[tuple] = []
|
||||
|
||||
def _track(*args, **kwargs):
|
||||
calls.append((args, kwargs))
|
||||
return datetime(2026, 1, 9, 10, 0, 0)
|
||||
|
||||
monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track)
|
||||
updated = ScheduleService.update_schedule(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=schedule.id,
|
||||
updates=SchedulePlanUpdate(node_id="node-new"),
|
||||
)
|
||||
|
||||
db_session_with_containers.refresh(updated)
|
||||
assert updated.node_id == "node-new"
|
||||
assert updated.next_run_at == initial_next_run
|
||||
assert calls == []
|
||||
|
||||
def test_update_schedule_not_found_raises(self, db_session_with_containers: Session):
|
||||
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
|
||||
ScheduleService.update_schedule(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=str(uuid4()),
|
||||
updates=SchedulePlanUpdate(node_id="node-new"),
|
||||
)
|
||||
|
||||
def test_delete_schedule_removes_row(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
ScheduleService.delete_schedule(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=schedule.id,
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None
|
||||
|
||||
def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session):
|
||||
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
|
||||
ScheduleService.delete_schedule(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=str(uuid4()),
|
||||
)
|
||||
|
||||
def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session):
|
||||
owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
|
||||
db_session_with_containers,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
|
||||
result = ScheduleService.get_tenant_owner(
|
||||
session=db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
assert result.id == owner.id
|
||||
|
||||
def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session):
|
||||
admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
|
||||
db_session_with_containers,
|
||||
role=TenantAccountRole.ADMIN,
|
||||
)
|
||||
|
||||
result = ScheduleService.get_tenant_owner(
|
||||
session=db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
|
||||
assert result.id == admin.id
|
||||
|
||||
def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
missing_account_id = str(uuid4())
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=missing_account_id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(AccountNotFoundError, match=missing_account_id):
|
||||
ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)
|
||||
|
||||
def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(AccountNotFoundError, match=tenant.id):
|
||||
ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)
|
||||
|
||||
def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
)
|
||||
expected_next_run = datetime(2026, 1, 3, 10, 30, 0)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.schedule_service.calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: expected_next_run,
|
||||
)
|
||||
result = ScheduleService.update_next_run_at(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=schedule.id,
|
||||
)
|
||||
|
||||
db_session_with_containers.refresh(schedule)
|
||||
assert result == expected_next_run
|
||||
assert schedule.next_run_at == expected_next_run
|
||||
|
||||
def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session):
|
||||
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
|
||||
ScheduleService.update_next_run_at(
|
||||
session=db_session_with_containers,
|
||||
schedule_id=str(uuid4()),
|
||||
)
|
||||
|
||||
|
||||
class TestSyncScheduleFromWorkflowIntegration:
|
||||
def test_sync_schedule_create_new(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
expected_next_run = datetime(2026, 1, 4, 10, 30, 0)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.schedule_service.calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: expected_next_run,
|
||||
)
|
||||
result = sync_schedule_from_workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
workflow=_cron_workflow(),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
persisted = db_session_with_containers.execute(
|
||||
select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id)
|
||||
).scalar_one()
|
||||
assert persisted.node_id == "start"
|
||||
assert persisted.cron_expression == "30 10 * * *"
|
||||
assert persisted.timezone == "UTC"
|
||||
assert persisted.next_run_at == expected_next_run
|
||||
|
||||
def test_sync_schedule_update_existing(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
existing = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
node_id="old-start",
|
||||
cron_expression="30 10 * * *",
|
||||
timezone="UTC",
|
||||
)
|
||||
existing_id = existing.id
|
||||
expected_next_run = datetime(2026, 1, 5, 12, 0, 0)
|
||||
|
||||
with pytest.MonkeyPatch.context() as monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.schedule_service.calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: expected_next_run,
|
||||
)
|
||||
result = sync_schedule_from_workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
workflow=_cron_workflow(
|
||||
node_id="start",
|
||||
cron_expression="0 12 * * *",
|
||||
timezone="America/New_York",
|
||||
),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
db_session_with_containers.expire_all()
|
||||
persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id)
|
||||
assert persisted is not None
|
||||
assert persisted.node_id == "start"
|
||||
assert persisted.cron_expression == "0 12 * * *"
|
||||
assert persisted.timezone == "America/New_York"
|
||||
assert persisted.next_run_at == expected_next_run
|
||||
|
||||
def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session):
|
||||
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
existing = ScheduleServiceIntegrationFactory.create_schedule_plan(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
existing_id = existing.id
|
||||
|
||||
result = sync_schedule_from_workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
workflow=_no_schedule_workflow(),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
db_session_with_containers.expire_all()
|
||||
assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
@@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask:
|
||||
redis_client.set(indexing_cache_key, "processing", ex=300)
|
||||
|
||||
# Verify logs exist before processing
|
||||
existing_logs = (
|
||||
db_session_with_containers.query(DatasetAutoDisableLog)
|
||||
.where(DatasetAutoDisableLog.document_id == document.id)
|
||||
.all()
|
||||
)
|
||||
existing_logs = db_session_with_containers.scalars(
|
||||
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
|
||||
).all()
|
||||
assert len(existing_logs) == 2
|
||||
|
||||
# Act: Execute the task
|
||||
add_document_to_index_task(document.id)
|
||||
|
||||
# Assert: Verify auto disable logs were deleted
|
||||
remaining_logs = (
|
||||
db_session_with_containers.query(DatasetAutoDisableLog)
|
||||
.where(DatasetAutoDisableLog.document_id == document.id)
|
||||
.all()
|
||||
)
|
||||
remaining_logs = db_session_with_containers.scalars(
|
||||
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
|
||||
).all()
|
||||
assert len(remaining_logs) == 0
|
||||
|
||||
# Verify index processing occurred normally
|
||||
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit() # Ensure all changes are committed
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_with_image_files(
|
||||
@@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Verify that the task completed successfully by checking the log output
|
||||
@@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
assert deleted_file is None
|
||||
|
||||
# Verify database cleanup
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_dataset_not_found(
|
||||
@@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Document should still exist since cleanup failed
|
||||
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
|
||||
existing_document = db_session_with_containers.scalar(
|
||||
select(Document).where(Document.id == document_id).limit(1)
|
||||
)
|
||||
assert existing_document is not None
|
||||
|
||||
def test_batch_clean_document_task_storage_cleanup_failure(
|
||||
@@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted from database
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted from database
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_multiple_documents(
|
||||
@@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that all upload files are deleted
|
||||
for file_id in file_ids:
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id).limit(1)
|
||||
)
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_different_doc_forms(
|
||||
@@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
except Exception as e:
|
||||
@@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check if the segment still exists (task may have failed before deletion)
|
||||
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
existing_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
if existing_segment is not None:
|
||||
# If segment still exists, the task failed before deletion
|
||||
# This is acceptable in test environments with external service issues
|
||||
@@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that all upload files are deleted
|
||||
for file_id in file_ids:
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id).limit(1)
|
||||
)
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_integration_with_real_database(
|
||||
@@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify initial state
|
||||
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
|
||||
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
|
||||
assert (
|
||||
db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert (
|
||||
db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1))
|
||||
is not None
|
||||
)
|
||||
|
||||
# Store original IDs for verification
|
||||
document_id = document.id
|
||||
@@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.scalar(
|
||||
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
|
||||
)
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
assert deleted_file is None
|
||||
|
||||
# Verify final database state
|
||||
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
|
||||
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None
|
||||
assert (
|
||||
db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None
|
||||
|
||||
@@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch

import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask:
from extensions.ext_redis import redis_client

# Clear all test data
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()

# Clear Redis cache
@@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results

# Check that segments were created
segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(segments) == 3

# Verify segment content and metadata
@@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask:

# Verify no segments were created (since dataset doesn't exist)

segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0

# Verify no documents were modified
documents = db_session_with_containers.query(Document).all()
documents = db_session_with_containers.scalars(select(Document)).all()
assert len(documents) == 0

def test_batch_create_segment_to_index_task_document_not_found(
@@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask:

# Verify no segments were created

segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0

# Verify dataset remains unchanged (no segments were added to the dataset)
db_session_with_containers.refresh(dataset)
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments_for_dataset = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments_for_dataset) == 0

def test_batch_create_segment_to_index_task_document_not_available(
@@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"

# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
assert len(segments) == 0

def test_batch_create_segment_to_index_task_upload_file_not_found(
@@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask:

# Verify no segments were created

segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0

# Verify document remains unchanged
@@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling
# Since exception was raised, no segments should be created

segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0

# Verify document remains unchanged
@@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask:

# Verify results
# Check that new segments were created with correct positions
all_segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
all_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(all_segments) == 6 # 3 existing + 3 new

# Verify position ordering

@@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch

import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -52,18 +53,18 @@ class TestCleanDatasetTask:
from extensions.ext_redis import redis_client

# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DatasetMetadataBinding))
db_session_with_containers.execute(delete(DatasetMetadata))
db_session_with_containers.execute(delete(AppDatasetJoin))
db_session_with_containers.execute(delete(DatasetQuery))
db_session_with_containers.execute(delete(DatasetProcessRule))
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()

# Clear Redis cache
@@ -302,28 +303,40 @@ class TestCleanDatasetTask:

# Verify results
# Check that dataset-related data was cleaned up
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
assert len(documents) == 0

segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments) == 0

# Check that metadata and bindings were cleaned up
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(metadata) == 0

bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(bindings) == 0

# Check that process rules and queries were cleaned up
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.scalars(
select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id)
).all()
assert len(process_rules) == 0

queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.scalars(
select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id)
).all()
assert len(queries) == 0

# Check that app dataset joins were cleaned up
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.scalars(
select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id)
).all()
assert len(app_joins) == 0

# Verify index processor was called
@@ -414,24 +427,32 @@ class TestCleanDatasetTask:

# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0

# Check that metadata and bindings were cleaned up
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0

remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0

# Verify index processor was called
@@ -485,12 +506,14 @@ class TestCleanDatasetTask:

# Check that all data was cleaned up

remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Recreate data for next test case
@@ -538,11 +561,15 @@ class TestCleanDatasetTask:
# Verify results - even with vector cleanup failure, documents and segments should be deleted

# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Verify that index processor was called and failed
@@ -622,18 +649,22 @@ class TestCleanDatasetTask:

# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
remaining_image_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(image_file_ids))
).all()
assert len(remaining_image_files) == 0

# Verify that storage.delete was called for each image file
@@ -738,24 +769,32 @@ class TestCleanDatasetTask:

# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0

# Check that all metadata and bindings were deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0

remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0

# Verify performance expectations
@@ -826,7 +865,9 @@ class TestCleanDatasetTask:
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file.id)
).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@@ -976,19 +1017,27 @@ class TestCleanDatasetTask:

# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0

# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0

# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file_id)
).all()
assert len(remaining_files) == 0

# Check that all metadata was deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0

# Verify that storage.delete was called

@@ -11,6 +11,7 @@ from unittest.mock import Mock, patch

import pytest
from faker import Faker
from sqlalchemy import func, select

from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
@@ -145,11 +146,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()

# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(document_ids))
)
== 3
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 6
)

@@ -158,9 +164,9 @@ class TestCleanNotionDocumentTask:

# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 0
)

@@ -323,9 +329,9 @@ class TestCleanNotionDocumentTask:

# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)

@@ -411,7 +417,9 @@ class TestCleanNotionDocumentTask:

# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)

@@ -499,9 +507,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()

# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== 5
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 10
)

@@ -514,19 +529,26 @@ class TestCleanNotionDocumentTask:

# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id.in_(documents_to_clean))
)
== 0
)

# Verify remaining documents and segments are intact
remaining_docs = [doc.id for doc in documents[3:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 4
)

@@ -613,7 +635,9 @@ class TestCleanNotionDocumentTask:

# Verify all segments exist before cleanup
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 4
)

@@ -622,7 +646,9 @@ class TestCleanNotionDocumentTask:

# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)

@@ -795,11 +821,15 @@ class TestCleanNotionDocumentTask:

# Verify all data exists before cleanup
assert (
db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== num_documents
)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== num_documents * num_segments_per_doc
)

@@ -809,7 +839,9 @@ class TestCleanNotionDocumentTask:

# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)

@@ -906,8 +938,8 @@ class TestCleanNotionDocumentTask:

# Verify all data exists before cleanup
# Note: There may be documents from previous tests, so we check for at least 3
assert db_session_with_containers.query(Document).count() >= 3
assert db_session_with_containers.query(DocumentSegment).count() >= 9
assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3
assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9

# Clean up documents from only the first dataset
target_dataset = datasets[0]
@@ -919,19 +951,26 @@ class TestCleanNotionDocumentTask:

# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == target_document.id)
)
== 0
)

# Verify documents from other datasets remain intact
remaining_docs = [doc.id for doc in all_documents[1:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 6
)

@@ -1028,11 +1067,13 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()

# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len(
document_statuses
)
assert db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
) == len(document_statuses)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== len(document_statuses) * 2
)

@@ -1042,7 +1083,9 @@ class TestCleanNotionDocumentTask:

# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)

@@ -1142,9 +1185,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()

# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id == document.id)
)
== 1
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)

@@ -1153,7 +1203,9 @@ class TestCleanNotionDocumentTask:

# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)


@@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch

import pytest
from faker import Faker
from sqlalchemy import select

from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration:

def _query_document(self, db_session_with_containers, document_id: str) -> Document | None:
"""Return the latest persisted document state."""
return db_session_with_containers.query(Document).where(Document.id == document_id).first()
return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1))

def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None:
"""Assert all target documents are persisted in parsing status."""

@@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
from unittest.mock import MagicMock, patch

from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask:
db_session_with_containers.refresh(segments[1])

# Check that segments are re-enabled after error
updated_segments = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
)
updated_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
).all()

for segment in updated_segments:
assert segment.enabled is True

@@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4

import pytest
from sqlalchemy import delete, func, select, update

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask:
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
{"data_source_info": None}
db_session_with_containers.execute(
update(Document).where(Document.id == context["document"].id).values(data_source_info=None)
)
db_session_with_containers.commit()

@@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR
@@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.COMPLETED
@@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)

assert updated_document is not None
@@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask:
context = self._create_notion_sync_context(db_session_with_containers)

def _delete_dataset_before_clean() -> str:
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id))
db_session_with_containers.commit()
return "2024-01-02T00:00:00Z"

@@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask:

# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch

import pytest
from faker import Faker
from sqlalchemy import func, select

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask:
db_session_with_containers.expire_all()

# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1))
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None

# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining == 0

@@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask:
mock_external_dependencies["runner_instance"].run.assert_called_once()

# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining > 0


@@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):


def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
"""Test that DB_EXTRAS options are properly merged with default timezone setting"""
"""Test that DB_EXTRAS options are merged with the default timezone startup option."""
# Set environment variables
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
@@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
# Create config
config = DifyConfig()

# Get engine options
engine_options = config.SQLALCHEMY_ENGINE_OPTIONS

# Verify options contains both search_path and timezone
options = engine_options["connect_args"]["options"]
options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"]
assert "search_path=myschema" in options
assert "timezone=UTC" in options


def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema")
monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "")

config = DifyConfig()

assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == {
"options": "-c search_path=myschema",
}


def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()

@@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
_ = DifyConfig().normalized_pubsub_redis_url


def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()

monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")

config = DifyConfig(_env_file=None)

assert config.REDIS_KEY_PREFIX == ""


def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()

monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")

config = DifyConfig(_env_file=None)

assert config.REDIS_KEY_PREFIX == "enterprise-a"


@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

@@ -138,12 +138,15 @@ def app_models(app_module):
def patch_signed_url(monkeypatch, app_module):
"""Ensure icon URL generation uses a deterministic helper for tests."""

def _fake_signed_url(key: str | None) -> str | None:
if not key:
def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
if key is None:
return None
icon_type = str(_icon_type).lower()
if icon_type != "image":
return None
return f"signed:{key}"

monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url)


def _ts(hour: int = 12) -> datetime:

@@ -1,42 +0,0 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from controllers.console.app.conversation import _get_conversation


def test_get_conversation_mark_read_keeps_updated_at_unchanged():
app_model = SimpleNamespace(id="app-id")
account = SimpleNamespace(id="account-id")
conversation = MagicMock()
conversation.id = "conversation-id"

with (
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, None),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=datetime(2026, 2, 9, 0, 0, 0),
autospec=True,
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.scalar.return_value = conversation

_get_conversation(app_model, "conversation-id")

statement = mock_session.execute.call_args[0][0]
compiled = statement.compile()
sql_text = str(compiled).lower()
compact_sql_text = sql_text.replace(" ", "")
params = compiled.params

assert "updated_at=current_timestamp" not in compact_sql_text
assert "updated_at=conversations.updated_at" in compact_sql_text
assert "read_at=:read_at" in compact_sql_text
assert "read_account_id=:read_account_id" in compact_sql_text
assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0)
assert params["read_account_id"] == "account-id"
@@ -0,0 +1,108 @@
from __future__ import annotations

from contextlib import nullcontext
from datetime import UTC, datetime
from types import SimpleNamespace

import pytest
from graphon.variables.types import SegmentType
from pydantic import ValidationError

from controllers.console.app import conversation_variables as conversation_variables_module


def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func


def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)

created_at = datetime(2026, 1, 1, tzinfo=UTC)
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
row = SimpleNamespace(
created_at=created_at,
updated_at=updated_at,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-1",
"name": "my_var",
"value_type": "string",
"value": "value",
"description": "desc",
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)

with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))

assert response["page"] == 1
assert response["limit"] == 100
assert response["total"] == 1
assert response["has_more"] is False
assert response["data"][0]["id"] == "var-1"
assert response["data"][0]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["updated_at"] == int(updated_at.timestamp())


def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)

row = SimpleNamespace(
created_at=None,
updated_at=None,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-2",
"name": "my_var_2",
"value_type": SegmentType.INTEGER,
"value": 42,
"description": None,
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)

with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))

assert response["data"][0]["value_type"] == "number"
assert response["data"][0]["value"] == "42"


def test_get_conversation_variables_requires_conversation_id(app) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)

with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
with pytest.raises(ValidationError):
method(app_model=SimpleNamespace(id="app-1"))
@@ -1,5 +1,7 @@
from __future__ import annotations

from datetime import UTC, datetime

import pytest

from controllers.console.app import message as message_module
@@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"


def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = message_module.MessageDetailResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"conversation_id": "550e8400-e29b-41d4-a716-446655440001",
"inputs": {"foo": "bar"},
"query": "hello",
"re_sign_file_url_answer": "world",
"from_source": "user",
"status": "normal",
"created_at": created_at,
"message_metadata_dict": {"token_usage": 3},
}
)
assert response.answer == "world"
assert response.metadata == {"token_usage": 3}
assert response.created_at == int(created_at.timestamp())

@@ -0,0 +1,85 @@
from __future__ import annotations

from datetime import UTC, datetime

from graphon.enums import WorkflowExecutionStatus

from controllers.console.app import workflow_app_log as workflow_app_log_module


def test_workflow_app_log_query_parses_bool_and_datetime():
query = workflow_app_log_module.WorkflowAppLogQuery.model_validate(
{
"detail": "true",
"created_at__before": "2026-01-02T03:04:05Z",
"page": "2",
"limit": "10",
}
)

assert query.detail is True
assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
assert query.page == 2
assert query.limit == 10


def test_workflow_app_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "log-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.SUCCEEDED,
"created_at": created_at,
"finished_at": created_at,
},
"details": {"trigger_metadata": {}},
"created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"},
"created_at": created_at,
}
],
}
).model_dump(mode="json")

assert response["data"][0]["workflow_run"]["status"] == "succeeded"
assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["created_at"] == int(created_at.timestamp())


def test_workflow_archived_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "archived-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.FAILED,
},
"trigger_metadata": {"type": "trigger-plugin"},
"created_by_end_user": {
"id": "eu-1",
"type": "anonymous",
"is_anonymous": True,
"session_id": "session-1",
},
"created_at": created_at,
}
],
}
).model_dump(mode="json")

assert response["data"][0]["workflow_run"]["status"] == "failed"
assert response["data"][0]["created_at"] == int(created_at.timestamp())
@@ -0,0 +1,54 @@
from __future__ import annotations

from datetime import UTC, datetime
from types import SimpleNamespace

from controllers.console.app import workflow_trigger as workflow_trigger_module


def test_parser_models_validate():
parser = workflow_trigger_module.Parser(node_id="node-1")
enable_parser = workflow_trigger_module.ParserEnable(
trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True
)

assert parser.node_id == "node-1"
assert enable_parser.enable_trigger is True


def test_workflow_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
trigger = SimpleNamespace(
id="trigger-1",
trigger_type="trigger-plugin",
title="Trigger",
node_id="node-1",
provider_name="provider",
icon="https://example.com/icon",
status="enabled",
created_at=created_at,
updated_at=created_at,
)

payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(
mode="json"
)
assert payload["id"] == "trigger-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"
assert payload["updated_at"] == "2026-01-02T03:04:05Z"


def test_webhook_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
webhook = {
"id": "webhook-1",
"webhook_id": "whk-1",
"webhook_url": "https://example.com/hook",
"webhook_debug_url": "https://example.com/hook/debug",
"node_id": "node-1",
"created_at": created_at,
}

payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json")
assert payload["webhook_id"] == "whk-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"
Some files were not shown because too many files have changed in this diff