mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
chore: adopt StrEnum and auto() for some string-typed enums (#25129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -477,12 +477,12 @@ def convert_to_agent_apps():
|
||||
click.echo(f"Converting app: {app.id}")
|
||||
|
||||
try:
|
||||
app.mode = AppMode.AGENT_CHAT.value
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT.value}
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
@@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings):
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(enum.StrEnum):
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
# workflow default mode
|
||||
AppMode.WORKFLOW: {
|
||||
"app": {
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"mode": AppMode.WORKFLOW,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
}
|
||||
@@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
# completion default mode
|
||||
AppMode.COMPLETION: {
|
||||
"app": {
|
||||
"mode": AppMode.COMPLETION.value,
|
||||
"mode": AppMode.COMPLETION,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
@@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
# chat default mode
|
||||
AppMode.CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.CHAT.value,
|
||||
"mode": AppMode.CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
@@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
# advanced-chat default mode
|
||||
AppMode.ADVANCED_CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.ADVANCED_CHAT.value,
|
||||
"mode": AppMode.ADVANCED_CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
@@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
# agent-chat default mode
|
||||
AppMode.AGENT_CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT_CHAT.value,
|
||||
"mode": AppMode.AGENT_CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
|
||||
@@ -307,7 +307,7 @@ class ChatConversationApi(Resource):
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
|
||||
match args["sort_by"]:
|
||||
|
||||
@@ -74,7 +74,7 @@ class ModelConfigResource(Resource):
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
|
||||
@@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@@ -150,7 +150,7 @@ class MCPAppApi(Resource):
|
||||
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||
"""Get and convert user input form"""
|
||||
# Get raw user input form based on app mode
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if not app.workflow:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||
|
||||
@@ -29,7 +29,7 @@ class AppParameterApi(Resource):
|
||||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
"""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource):
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import enum
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
@@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
|
||||
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
class AgentStrategyParameterType(StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
ANY = CommonParameterType.ANY.value
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
@@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
|
||||
pass
|
||||
|
||||
|
||||
class AgentFeature(enum.StrEnum):
|
||||
class AgentFeature(StrEnum):
|
||||
"""
|
||||
Agent Feature, used to describe the features of the agent strategy.
|
||||
"""
|
||||
|
||||
@@ -70,7 +70,7 @@ class PromptTemplateConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("prompt_type"):
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
@@ -90,7 +90,7 @@ class PromptTemplateConfigManager:
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
|
||||
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
|
||||
raise ValueError(
|
||||
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
|
||||
Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class PromptType(Enum):
|
||||
class PromptType(StrEnum):
|
||||
"""
|
||||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
|
||||
SIMPLE = "simple"
|
||||
ADVANCED = "advanced"
|
||||
SIMPLE = auto()
|
||||
ADVANCED = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
@@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
Dataset Retrieve Config Entity.
|
||||
"""
|
||||
|
||||
class RetrieveStrategy(Enum):
|
||||
class RetrieveStrategy(StrEnum):
|
||||
"""
|
||||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
|
||||
SINGLE = "single"
|
||||
MULTIPLE = "multiple"
|
||||
SINGLE = auto()
|
||||
MULTIPLE = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
@@ -293,12 +293,12 @@ class AppConfig(BaseModel):
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
class EasyUIBasedAppModelConfigFrom(Enum):
|
||||
class EasyUIBasedAppModelConfigFrom(StrEnum):
|
||||
"""
|
||||
App Model Config From.
|
||||
"""
|
||||
|
||||
ARGS = "args"
|
||||
ARGS = auto()
|
||||
APP_LATEST_CONFIG = "app-latest-config"
|
||||
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent):
|
||||
QueueStopEvent entity
|
||||
"""
|
||||
|
||||
class StopBy(Enum):
|
||||
class StopBy(StrEnum):
|
||||
"""
|
||||
Stop by enum
|
||||
"""
|
||||
|
||||
USER_MANUAL = "user-manual"
|
||||
ANNOTATION_REPLY = "annotation-reply"
|
||||
OUTPUT_MODERATION = "output-moderation"
|
||||
INPUT_MODERATION = "input-moderation"
|
||||
USER_MANUAL = auto()
|
||||
ANNOTATION_REPLY = auto()
|
||||
OUTPUT_MODERATION = auto()
|
||||
INPUT_MODERATION = auto()
|
||||
|
||||
event: QueueEvent = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -50,37 +50,37 @@ class WorkflowTaskState(TaskState):
|
||||
answer: str = ""
|
||||
|
||||
|
||||
class StreamEvent(Enum):
|
||||
class StreamEvent(StrEnum):
|
||||
"""
|
||||
Stream event
|
||||
"""
|
||||
|
||||
PING = "ping"
|
||||
ERROR = "error"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
TTS_MESSAGE = "tts_message"
|
||||
TTS_MESSAGE_END = "tts_message_end"
|
||||
MESSAGE_FILE = "message_file"
|
||||
MESSAGE_REPLACE = "message_replace"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
LOOP_STARTED = "loop_started"
|
||||
LOOP_NEXT = "loop_next"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
PING = auto()
|
||||
ERROR = auto()
|
||||
MESSAGE = auto()
|
||||
MESSAGE_END = auto()
|
||||
TTS_MESSAGE = auto()
|
||||
TTS_MESSAGE_END = auto()
|
||||
MESSAGE_FILE = auto()
|
||||
MESSAGE_REPLACE = auto()
|
||||
AGENT_THOUGHT = auto()
|
||||
AGENT_MESSAGE = auto()
|
||||
WORKFLOW_STARTED = auto()
|
||||
WORKFLOW_FINISHED = auto()
|
||||
NODE_STARTED = auto()
|
||||
NODE_FINISHED = auto()
|
||||
NODE_RETRY = auto()
|
||||
PARALLEL_BRANCH_STARTED = auto()
|
||||
PARALLEL_BRANCH_FINISHED = auto()
|
||||
ITERATION_STARTED = auto()
|
||||
ITERATION_NEXT = auto()
|
||||
ITERATION_COMPLETED = auto()
|
||||
LOOP_STARTED = auto()
|
||||
LOOP_NEXT = auto()
|
||||
LOOP_COMPLETED = auto()
|
||||
TEXT_CHUNK = auto()
|
||||
TEXT_REPLACE = auto()
|
||||
AGENT_LOG = auto()
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
|
||||
@@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
|
||||
@@ -92,7 +92,7 @@ class MessageCycleManager:
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION.value:
|
||||
if conversation.mode != AppMode.COMPLETION:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class PlanningStrategy(Enum):
|
||||
ROUTER = "router"
|
||||
REACT_ROUTER = "react_router"
|
||||
REACT = "react"
|
||||
FUNCTION_CALL = "function_call"
|
||||
class PlanningStrategy(StrEnum):
|
||||
ROUTER = auto()
|
||||
REACT_ROUTER = auto()
|
||||
REACT = auto()
|
||||
FUNCTION_CALL = auto()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class EmbeddingInputType(Enum):
|
||||
class EmbeddingInputType(StrEnum):
|
||||
"""
|
||||
Enum for embedding input type.
|
||||
"""
|
||||
|
||||
DOCUMENT = "document"
|
||||
QUERY = "query"
|
||||
DOCUMENT = auto()
|
||||
QUERY = auto()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
|
||||
|
||||
class ModelStatus(Enum):
|
||||
class ModelStatus(StrEnum):
|
||||
"""
|
||||
Enum class for model status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ACTIVE = auto()
|
||||
NO_CONFIGURE = "no-configure"
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
DISABLED = auto()
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from enum import StrEnum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class CommonParameterType(StrEnum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
SELECT = auto()
|
||||
STRING = auto()
|
||||
NUMBER = auto()
|
||||
FILE = auto()
|
||||
FILES = auto()
|
||||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
BOOLEAN = auto()
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
ANY = "any"
|
||||
ANY = auto()
|
||||
|
||||
# Dynamic select parameter
|
||||
# Once you are not sure about the available options until authorization is done
|
||||
@@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
# MCP object and array type parameters
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = "workflow"
|
||||
COMPLETION = "completion"
|
||||
ALL = auto()
|
||||
CHAT = auto()
|
||||
WORKFLOW = auto()
|
||||
COMPLETION = auto()
|
||||
|
||||
|
||||
class ModelSelectorScope(StrEnum):
|
||||
LLM = "llm"
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
RERANK = auto()
|
||||
TTS = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
VISION = auto()
|
||||
|
||||
|
||||
class ToolSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CUSTOM = "custom"
|
||||
BUILTIN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
ALL = auto()
|
||||
CUSTOM = auto()
|
||||
BUILTIN = auto()
|
||||
WORKFLOW = auto()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
@@ -31,20 +31,20 @@ class ProviderQuotaType(Enum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
TIMES = "times"
|
||||
TOKENS = "tokens"
|
||||
CREDITS = "credits"
|
||||
class QuotaUnit(StrEnum):
|
||||
TIMES = auto()
|
||||
TOKENS = auto()
|
||||
CREDITS = auto()
|
||||
|
||||
|
||||
class SystemConfigurationStatus(Enum):
|
||||
class SystemConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for system configuration status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ACTIVE = auto()
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
UNSUPPORTED = "unsupported"
|
||||
UNSUPPORTED = auto()
|
||||
|
||||
|
||||
class RestrictModel(BaseModel):
|
||||
@@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel):
|
||||
Base model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Type(Enum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
class Type(StrEnum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT
|
||||
SELECT = CommonParameterType.SELECT
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import StrEnum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtensionModule(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
class ExtensionModule(StrEnum):
|
||||
MODERATION = auto()
|
||||
EXTERNAL_DATA_TOOL = auto()
|
||||
|
||||
|
||||
class ModuleExtension(BaseModel):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCacheType(Enum):
|
||||
class ProviderCredentialsCacheType(StrEnum):
|
||||
PROVIDER = "provider"
|
||||
MODEL = "provider_model"
|
||||
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
|
||||
@@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum):
|
||||
|
||||
class ProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolParameterCacheType(Enum):
|
||||
class ToolParameterCacheType(StrEnum):
|
||||
PARAMETER = "tool_parameter"
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class ToolParameterCache:
|
||||
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
|
||||
):
|
||||
self.cache_key = (
|
||||
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
f":identity_id:{identity_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ def handle_call_tool(
|
||||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
@@ -157,7 +157,7 @@ def build_parameter_schema(
|
||||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
@@ -175,9 +175,9 @@ def build_parameter_schema(
|
||||
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW.value:
|
||||
if app.mode == AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION.value:
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
@@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW.value:
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
class PromptMessageRole(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
SYSTEM = auto()
|
||||
USER = auto()
|
||||
ASSISTANT = auto()
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
||||
@@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DOCUMENT = "document"
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
AUDIO = auto()
|
||||
VIDEO = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class PromptMessageContent(ABC, BaseModel):
|
||||
@@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
"""
|
||||
|
||||
class DETAIL(StrEnum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
@@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
class ModelType(StrEnum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = "llm"
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
TTS = "tts"
|
||||
RERANK = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
@@ -26,17 +26,17 @@ class ModelType(Enum):
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type in {"text-generation", cls.LLM.value}:
|
||||
if origin_model_type in {"text-generation", cls.LLM}:
|
||||
return cls.LLM
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type in {"reranking", cls.RERANK.value}:
|
||||
elif origin_model_type in {"reranking", cls.RERANK}:
|
||||
return cls.RERANK
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||
elif origin_model_type in {"tts", cls.TTS}:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
elif origin_model_type == cls.MODERATION:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
@@ -63,7 +63,7 @@ class ModelType(Enum):
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(Enum):
|
||||
class FetchFrom(StrEnum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
@@ -72,7 +72,7 @@ class FetchFrom(Enum):
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(Enum):
|
||||
class ModelFeature(StrEnum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
@@ -80,11 +80,11 @@ class ModelFeature(Enum):
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
STRUCTURED_OUTPUT = "structured-output"
|
||||
|
||||
|
||||
@@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
TOP_K = "top_k"
|
||||
PRESENCE_PENALTY = "presence_penalty"
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
RESPONSE_FORMAT = "response_format"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
TEMPERATURE = auto()
|
||||
TOP_P = auto()
|
||||
TOP_K = auto()
|
||||
PRESENCE_PENALTY = auto()
|
||||
FREQUENCY_PENALTY = auto()
|
||||
MAX_TOKENS = auto()
|
||||
RESPONSE_FORMAT = auto()
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||
@@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
|
||||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(Enum):
|
||||
class ParameterType(StrEnum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = "float"
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
TEXT = "text"
|
||||
FLOAT = auto()
|
||||
INT = auto()
|
||||
STRING = auto()
|
||||
BOOLEAN = auto()
|
||||
TEXT = auto()
|
||||
|
||||
|
||||
class ModelPropertyKey(Enum):
|
||||
class ModelPropertyKey(StrEnum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = "mode"
|
||||
CONTEXT_SIZE = "context_size"
|
||||
MAX_CHUNKS = "max_chunks"
|
||||
FILE_UPLOAD_LIMIT = "file_upload_limit"
|
||||
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
|
||||
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
|
||||
DEFAULT_VOICE = "default_voice"
|
||||
VOICES = "voices"
|
||||
WORD_LIMIT = "word_limit"
|
||||
AUDIO_TYPE = "audio_type"
|
||||
MAX_WORKERS = "max_workers"
|
||||
MODE = auto()
|
||||
CONTEXT_SIZE = auto()
|
||||
MAX_CHUNKS = auto()
|
||||
FILE_UPLOAD_LIMIT = auto()
|
||||
SUPPORTED_FILE_EXTENSIONS = auto()
|
||||
MAX_CHARACTERS_PER_CHUNK = auto()
|
||||
DEFAULT_VOICE = auto()
|
||||
VOICES = auto()
|
||||
WORD_LIMIT = auto()
|
||||
AUDIO_TYPE = auto()
|
||||
MAX_WORKERS = auto()
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
@@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PriceType(Enum):
|
||||
class PriceType(StrEnum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
INPUT = auto()
|
||||
OUTPUT = auto()
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(Enum):
|
||||
class FormType(StrEnum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = "select"
|
||||
RADIO = "radio"
|
||||
SWITCH = "switch"
|
||||
SELECT = auto()
|
||||
RADIO = auto()
|
||||
SWITCH = auto()
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
|
||||
|
||||
@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type.value,
|
||||
input_type=input_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@@ -18,7 +18,7 @@ from pydantic_core import Url
|
||||
from pydantic_extra_types.color import Color
|
||||
|
||||
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def jsonable_encoder(
|
||||
exclude_none: bool = False,
|
||||
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
|
||||
sqlalchemy_safe: bool = True,
|
||||
):
|
||||
) -> Any:
|
||||
custom_encoder = custom_encoder or {}
|
||||
if custom_encoder:
|
||||
if type(obj) in custom_encoder:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,9 +7,9 @@ from pydantic import BaseModel, Field
|
||||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
|
||||
class ModerationAction(Enum):
|
||||
DIRECT_OUTPUT = "direct_output"
|
||||
OVERRIDDEN = "overridden"
|
||||
class ModerationAction(StrEnum):
|
||||
DIRECT_OUTPUT = auto()
|
||||
OVERRIDDEN = auto()
|
||||
|
||||
|
||||
class ModerationInputsResult(BaseModel):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
@@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
class GenAISpanKind(StrEnum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
|
||||
@@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
app = cls._get_app(app_id, tenant_id)
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
@@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
@@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
@@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT.value:
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
@@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT.value:
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import enum
|
||||
import json
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -25,44 +25,44 @@ class PluginParameterOption(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class PluginParameterType(enum.StrEnum):
|
||||
class PluginParameterType(StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
ANY = CommonParameterType.ANY.value
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY.value
|
||||
OBJECT = CommonParameterType.OBJECT.value
|
||||
ARRAY = CommonParameterType.ARRAY
|
||||
OBJECT = CommonParameterType.OBJECT
|
||||
|
||||
|
||||
class MCPServerParameterType(enum.StrEnum):
|
||||
class MCPServerParameterType(StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
PROMPT_INSTRUCTION = "prompt_instruction"
|
||||
class Type(StrEnum):
|
||||
PROMPT_INSTRUCTION = auto()
|
||||
|
||||
type: Type
|
||||
|
||||
@@ -93,7 +93,7 @@ class PluginParameter(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: enum.StrEnum):
|
||||
def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
@@ -102,7 +102,7 @@ def as_normal_type(typ: enum.StrEnum):
|
||||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
|
||||
@@ -190,7 +190,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import enum
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
@@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(enum.StrEnum):
|
||||
Github = "github"
|
||||
Marketplace = "marketplace"
|
||||
Package = "package"
|
||||
Remote = "remote"
|
||||
class PluginInstallationSource(StrEnum):
|
||||
Github = auto()
|
||||
Marketplace = auto()
|
||||
Package = auto()
|
||||
Remote = auto()
|
||||
|
||||
|
||||
class PluginResourceRequirements(BaseModel):
|
||||
@@ -58,10 +58,10 @@ class PluginResourceRequirements(BaseModel):
|
||||
permission: Optional[Permission] = Field(default=None)
|
||||
|
||||
|
||||
class PluginCategory(enum.StrEnum):
|
||||
Tool = "tool"
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
class PluginCategory(StrEnum):
|
||||
Tool = auto()
|
||||
Model = auto()
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
|
||||
|
||||
@@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID):
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
Marketplace = PluginInstallationSource.Marketplace.value
|
||||
Package = PluginInstallationSource.Package.value
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
Marketplace = PluginInstallationSource.Marketplace
|
||||
Package = PluginInstallationSource.Package
|
||||
|
||||
class Github(BaseModel):
|
||||
repo: str
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
@@ -25,9 +25,9 @@ if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class ModelMode(enum.StrEnum):
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
class ModelMode(StrEnum):
|
||||
COMPLETION = auto()
|
||||
CHAT = auto()
|
||||
|
||||
|
||||
prompt_file_contents: dict[str, Any] = {}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class Field(Enum):
|
||||
class Field(StrEnum):
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
VECTOR = auto()
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = "sparse_vector"
|
||||
SPARSE_VECTOR = auto()
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
DOC_ID = "metadata.doc_id"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from clickhouse_connect import get_client
|
||||
@@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel):
|
||||
fts_params: str
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
class SortOrder(StrEnum):
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class DatasourceType(Enum):
|
||||
class DatasourceType(StrEnum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
WEBSITE = "website_crawl"
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class BuiltInField(StrEnum):
|
||||
document_name = "document_name"
|
||||
uploader = "uploader"
|
||||
upload_date = "upload_date"
|
||||
last_update_date = "last_update_date"
|
||||
source = "source"
|
||||
document_name = auto()
|
||||
uploader = auto()
|
||||
upload_date = auto()
|
||||
last_update_date = auto()
|
||||
source = auto()
|
||||
|
||||
|
||||
class MetadataDataSource(Enum):
|
||||
class MetadataDataSource(StrEnum):
|
||||
upload_file = "file_upload"
|
||||
website_crawl = "website"
|
||||
notion_import = "notion"
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
@@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = auto()
|
||||
IMAGE = auto()
|
||||
VIDEOS = auto()
|
||||
WEATHER = auto()
|
||||
FINANCE = auto()
|
||||
DESIGN = auto()
|
||||
TRAVEL = auto()
|
||||
SOCIAL = auto()
|
||||
NEWS = auto()
|
||||
MEDICAL = auto()
|
||||
PRODUCTIVITY = auto()
|
||||
EDUCATION = auto()
|
||||
BUSINESS = auto()
|
||||
ENTERTAINMENT = auto()
|
||||
UTILITIES = auto()
|
||||
OTHER = auto()
|
||||
|
||||
|
||||
class ToolProviderType(enum.StrEnum):
|
||||
class ToolProviderType(StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
PLUGIN = auto()
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
WORKFLOW = auto()
|
||||
API = auto()
|
||||
APP = auto()
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = "mcp"
|
||||
MCP = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
@@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
class ApiProviderSchemaType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
OPENAPI = auto()
|
||||
SWAGGER = auto()
|
||||
OPENAI_PLUGIN = auto()
|
||||
OPENAI_ACTIONS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
@@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
class ApiProviderAuthType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY_HEADER = "api_key_header"
|
||||
API_KEY_QUERY = "api_key_query"
|
||||
NONE = auto()
|
||||
API_KEY_HEADER = auto()
|
||||
API_KEY_QUERY = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
@@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
class LogStatus(StrEnum):
|
||||
START = auto()
|
||||
ERROR = auto()
|
||||
SUCCESS = auto()
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
@@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
BLOB_CHUNK = "blob_chunk"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
class MessageType(StrEnum):
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
LINK = auto()
|
||||
BLOB = auto()
|
||||
JSON = auto()
|
||||
IMAGE_LINK = auto()
|
||||
BINARY_LINK = auto()
|
||||
VARIABLE = auto()
|
||||
FILE = auto()
|
||||
LOG = auto()
|
||||
BLOB_CHUNK = auto()
|
||||
RETRIEVER_RESOURCES = auto()
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
@@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
class ToolParameterType(StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
ANY = PluginParameterType.ANY.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
STRING = PluginParameterType.STRING
|
||||
NUMBER = PluginParameterType.NUMBER
|
||||
BOOLEAN = PluginParameterType.BOOLEAN
|
||||
SELECT = PluginParameterType.SELECT
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
|
||||
ANY = PluginParameterType.ANY
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY.value
|
||||
OBJECT = MCPServerParameterType.OBJECT.value
|
||||
ARRAY = MCPServerParameterType.ARRAY
|
||||
OBJECT = MCPServerParameterType.OBJECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
@@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
class ToolParameterForm(StrEnum):
|
||||
SCHEMA = auto() # should be set while adding tool
|
||||
FORM = auto() # should be set before invoking tool
|
||||
LLM = auto() # will be set by LLM
|
||||
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
@@ -446,14 +445,14 @@ class ToolLabel(BaseModel):
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
class ToolInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
WORKFLOW = auto()
|
||||
AGENT = auto()
|
||||
PLUGIN = auto()
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
@@ -478,9 +477,9 @@ class ToolSelector(BaseModel):
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
class CredentialType(StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
OAUTH2 = auto()
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -11,12 +11,12 @@ from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
class Status(Enum):
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
EXCEPTION = "exception"
|
||||
class Status(StrEnum):
|
||||
RUNNING = auto()
|
||||
SUCCESS = auto()
|
||||
FAILED = auto()
|
||||
PAUSED = auto()
|
||||
EXCEPTION = auto()
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from enum import Enum, StrEnum
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(Enum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = auto()
|
||||
OPEN = auto()
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
@@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(Enum):
|
||||
VAR = "var"
|
||||
TEXT = "text"
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
@@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
|
||||
@@ -9,19 +9,19 @@ import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(Enum):
|
||||
class FileStatus(StrEnum):
|
||||
"""File status enumeration"""
|
||||
|
||||
ACTIVE = "active" # Active status
|
||||
ARCHIVED = "archived" # Archived
|
||||
DELETED = "deleted" # Deleted (soft delete)
|
||||
BACKUP = "backup" # Backup file
|
||||
ACTIVE = auto() # Active status
|
||||
ARCHIVED = auto() # Archived
|
||||
DELETED = auto() # Deleted (soft delete)
|
||||
BACKUP = auto() # Backup file
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VolumePermission(Enum):
|
||||
class VolumePermission(StrEnum):
|
||||
"""Volume permission type enumeration"""
|
||||
|
||||
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
|
||||
|
||||
@@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from flask import render_template
|
||||
@@ -17,30 +17,30 @@ from extensions.ext_mail import mail
|
||||
from services.feature_service import BrandingModel, FeatureService
|
||||
|
||||
|
||||
class EmailType(Enum):
|
||||
class EmailType(StrEnum):
|
||||
"""Enumeration of supported email types."""
|
||||
|
||||
RESET_PASSWORD = "reset_password"
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = "reset_password_when_account_not_exist"
|
||||
INVITE_MEMBER = "invite_member"
|
||||
EMAIL_CODE_LOGIN = "email_code_login"
|
||||
CHANGE_EMAIL_OLD = "change_email_old"
|
||||
CHANGE_EMAIL_NEW = "change_email_new"
|
||||
CHANGE_EMAIL_COMPLETED = "change_email_completed"
|
||||
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
|
||||
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
|
||||
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
|
||||
ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
|
||||
ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
|
||||
ENTERPRISE_CUSTOM = "enterprise_custom"
|
||||
QUEUE_MONITOR_ALERT = "queue_monitor_alert"
|
||||
DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
|
||||
EMAIL_REGISTER = "email_register"
|
||||
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = "email_register_when_account_exist"
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = "reset_password_when_account_not_exist_no_register"
|
||||
RESET_PASSWORD = auto()
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
|
||||
INVITE_MEMBER = auto()
|
||||
EMAIL_CODE_LOGIN = auto()
|
||||
CHANGE_EMAIL_OLD = auto()
|
||||
CHANGE_EMAIL_NEW = auto()
|
||||
CHANGE_EMAIL_COMPLETED = auto()
|
||||
OWNER_TRANSFER_CONFIRM = auto()
|
||||
OWNER_TRANSFER_OLD_NOTIFY = auto()
|
||||
OWNER_TRANSFER_NEW_NOTIFY = auto()
|
||||
ACCOUNT_DELETION_SUCCESS = auto()
|
||||
ACCOUNT_DELETION_VERIFICATION = auto()
|
||||
ENTERPRISE_CUSTOM = auto()
|
||||
QUEUE_MONITOR_ALERT = auto()
|
||||
DOCUMENT_CLEAN_NOTIFY = auto()
|
||||
EMAIL_REGISTER = auto()
|
||||
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
|
||||
|
||||
|
||||
class EmailLanguage(Enum):
|
||||
class EmailLanguage(StrEnum):
|
||||
"""Supported email languages with fallback handling."""
|
||||
|
||||
EN_US = "en-US"
|
||||
|
||||
@@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
|
||||
if isinstance(obj, dict) and "app" in obj:
|
||||
obj = obj["app"]
|
||||
|
||||
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
|
||||
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(obj.icon)
|
||||
return None
|
||||
|
||||
|
||||
@@ -224,35 +224,35 @@ class Dataset(Base):
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.document_name.value,
|
||||
"name": BuiltInField.document_name,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.uploader.value,
|
||||
"name": BuiltInField.uploader,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.upload_date.value,
|
||||
"name": BuiltInField.upload_date,
|
||||
"type": "time",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.last_update_date.value,
|
||||
"name": BuiltInField.last_update_date,
|
||||
"type": "time",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.source.value,
|
||||
"name": BuiltInField.source,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
@@ -544,7 +544,7 @@ class Document(Base):
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.source,
|
||||
"type": "string",
|
||||
"value": MetadataDataSource[self.data_source_type].value,
|
||||
"value": MetadataDataSource[self.data_source_type],
|
||||
}
|
||||
)
|
||||
return built_in_fields
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
@@ -62,9 +62,9 @@ class AppMode(StrEnum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class IconType(Enum):
|
||||
IMAGE = "image"
|
||||
EMOJI = "emoji"
|
||||
class IconType(StrEnum):
|
||||
IMAGE = auto()
|
||||
EMOJI = auto()
|
||||
|
||||
|
||||
class App(Base):
|
||||
@@ -149,15 +149,15 @@ class App(Base):
|
||||
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
|
||||
"strategy", ""
|
||||
) in {"function_call", "react"}:
|
||||
self.mode = AppMode.AGENT_CHAT.value
|
||||
self.mode = AppMode.AGENT_CHAT
|
||||
db.session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def mode_compatible_with_agent(self) -> str:
|
||||
if self.mode == AppMode.CHAT.value and self.is_agent:
|
||||
return AppMode.AGENT_CHAT.value
|
||||
if self.mode == AppMode.CHAT and self.is_agent:
|
||||
return AppMode.AGENT_CHAT
|
||||
|
||||
return str(self.mode)
|
||||
|
||||
@@ -713,7 +713,7 @@ class Conversation(Base):
|
||||
model_config = {}
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = override_model_configs
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
@@ -12,9 +12,9 @@ from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
CUSTOM = "custom"
|
||||
SYSTEM = "system"
|
||||
class ProviderType(StrEnum):
|
||||
CUSTOM = auto()
|
||||
SYSTEM = auto()
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderType":
|
||||
@@ -24,14 +24,14 @@ class ProviderType(Enum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -41,13 +41,13 @@ from .types import EnumText, StringUUID
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowType(Enum):
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = auto()
|
||||
CHAT = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowType":
|
||||
@@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base):
|
||||
return extras
|
||||
|
||||
|
||||
class WorkflowAppLogCreatedFrom(Enum):
|
||||
class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
"""
|
||||
Workflow App Log Created From Enum
|
||||
"""
|
||||
|
||||
@@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
@@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
@@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
|
||||
|
||||
@@ -60,7 +60,7 @@ class AppGenerateService:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
@@ -69,7 +69,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
@@ -78,7 +78,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
elif app_model.mode == AppMode.CHAT:
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
@@ -87,7 +87,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
@@ -103,7 +103,7 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
@@ -155,14 +155,14 @@ class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
@@ -174,14 +174,14 @@ class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
|
||||
@@ -40,15 +40,15 @@ class AppService:
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode == AppMode.WORKFLOW.value)
|
||||
filters.append(App.mode == AppMode.WORKFLOW)
|
||||
elif args["mode"] == "completion":
|
||||
filters.append(App.mode == AppMode.COMPLETION.value)
|
||||
filters.append(App.mode == AppMode.COMPLETION)
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode == AppMode.CHAT.value)
|
||||
filters.append(App.mode == AppMode.CHAT)
|
||||
elif args["mode"] == "advanced-chat":
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT)
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
@@ -171,7 +171,7 @@ class AppService:
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get original app model config
|
||||
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
|
||||
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
|
||||
model_config = app.app_model_config
|
||||
if not model_config:
|
||||
return app
|
||||
|
||||
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
@@ -88,7 +88,7 @@ class AudioService:
|
||||
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
|
||||
with app.app_context():
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
else:
|
||||
|
||||
@@ -1004,7 +1004,7 @@ class DocumentService:
|
||||
if dataset.built_in_field_enabled:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = name
|
||||
doc_metadata[BuiltInField.document_name] = name
|
||||
document.doc_metadata = doc_metadata
|
||||
|
||||
document.name = name
|
||||
|
||||
@@ -229,7 +229,7 @@ class MessageService:
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_service = WorkflowService()
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
@@ -131,11 +131,11 @@ class MetadataService:
|
||||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name.value, "type": "string"},
|
||||
{"name": BuiltInField.uploader.value, "type": "string"},
|
||||
{"name": BuiltInField.upload_date.value, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date.value, "type": "time"},
|
||||
{"name": BuiltInField.source.value, "type": "string"},
|
||||
{"name": BuiltInField.document_name, "type": "string"},
|
||||
{"name": BuiltInField.uploader, "type": "string"},
|
||||
{"name": BuiltInField.upload_date, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date, "type": "time"},
|
||||
{"name": BuiltInField.source, "type": "string"},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -153,11 +153,11 @@ class MetadataService:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
doc_metadata[BuiltInField.document_name] = document.name
|
||||
doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
dataset.built_in_field_enabled = True
|
||||
@@ -183,11 +183,11 @@ class MetadataService:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(BuiltInField.document_name.value, None)
|
||||
doc_metadata.pop(BuiltInField.uploader.value, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.last_update_date.value, None)
|
||||
doc_metadata.pop(BuiltInField.source.value, None)
|
||||
doc_metadata.pop(BuiltInField.document_name, None)
|
||||
doc_metadata.pop(BuiltInField.uploader, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date, None)
|
||||
doc_metadata.pop(BuiltInField.last_update_date, None)
|
||||
doc_metadata.pop(BuiltInField.source, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
@@ -211,11 +211,11 @@ class MetadataService:
|
||||
for metadata_value in operation.metadata_list:
|
||||
doc_metadata[metadata_value.name] = metadata_value.value
|
||||
if dataset.built_in_field_enabled:
|
||||
doc_metadata[BuiltInField.document_name.value] = document.name
|
||||
doc_metadata[BuiltInField.uploader.value] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
doc_metadata[BuiltInField.document_name] = document.name
|
||||
doc_metadata[BuiltInField.uploader] = document.uploader
|
||||
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@@ -256,7 +256,7 @@ class PluginMigration:
|
||||
return []
|
||||
|
||||
agent_app_model_config_ids = [
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
|
||||
@@ -65,7 +65,7 @@ class WorkflowConverter:
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name or app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
|
||||
new_app.icon_type = icon_type or app_model.icon_type
|
||||
new_app.icon = icon or app_model.icon
|
||||
new_app.icon_background = icon_background or app_model.icon_background
|
||||
@@ -203,7 +203,7 @@ class WorkflowConverter:
|
||||
app_mode_enum = AppMode.value_of(app_model.mode)
|
||||
app_config: EasyUIBasedAppConfig
|
||||
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_model.mode = AppMode.AGENT_CHAT
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
@@ -279,7 +279,7 @@ class WorkflowConverter:
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
"inputs": inputs,
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -618,7 +618,7 @@ class WorkflowConverter:
|
||||
:param app_model: App instance
|
||||
:return: AppMode
|
||||
"""
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return AppMode.WORKFLOW
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
@@ -828,7 +828,7 @@ class WorkflowService:
|
||||
# chatbot convert to workflow mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
|
||||
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
@@ -844,11 +844,11 @@ class WorkflowService:
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test data for Baichuan model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
@@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test data for common model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
for model_name in test_cases:
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": model_name,
|
||||
"has_context": "true",
|
||||
@@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
@@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
@@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
@@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService:
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
@@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test edge cases
|
||||
edge_cases = [
|
||||
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "",
|
||||
@@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService:
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
@@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
@@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService:
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT.value,
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION.value,
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
|
||||
@@ -255,7 +255,7 @@ class TestMetadataService:
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Try to create metadata with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name.value
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
@@ -375,7 +375,7 @@ class TestMetadataService:
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Try to update with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name.value
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
|
||||
@@ -540,11 +540,11 @@ class TestMetadataService:
|
||||
field_names = [field["name"] for field in result]
|
||||
field_types = [field["type"] for field in result]
|
||||
|
||||
assert BuiltInField.document_name.value in field_names
|
||||
assert BuiltInField.uploader.value in field_names
|
||||
assert BuiltInField.upload_date.value in field_names
|
||||
assert BuiltInField.last_update_date.value in field_names
|
||||
assert BuiltInField.source.value in field_names
|
||||
assert BuiltInField.document_name in field_names
|
||||
assert BuiltInField.uploader in field_names
|
||||
assert BuiltInField.upload_date in field_names
|
||||
assert BuiltInField.last_update_date in field_names
|
||||
assert BuiltInField.source in field_names
|
||||
|
||||
# Verify field types
|
||||
assert "string" in field_types
|
||||
@@ -682,11 +682,11 @@ class TestMetadataService:
|
||||
|
||||
# Set document metadata with built-in fields
|
||||
document.doc_metadata = {
|
||||
BuiltInField.document_name.value: document.name,
|
||||
BuiltInField.uploader.value: "test_uploader",
|
||||
BuiltInField.upload_date.value: 1234567890.0,
|
||||
BuiltInField.last_update_date.value: 1234567890.0,
|
||||
BuiltInField.source.value: "test_source",
|
||||
BuiltInField.document_name: document.name,
|
||||
BuiltInField.uploader: "test_uploader",
|
||||
BuiltInField.upload_date: 1234567890.0,
|
||||
BuiltInField.last_update_date: 1234567890.0,
|
||||
BuiltInField.source: "test_source",
|
||||
}
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@@ -96,7 +96,7 @@ class TestWorkflowService:
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
@@ -883,7 +883,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create chat mode app
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
@@ -926,7 +926,7 @@ class TestWorkflowService:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
|
||||
assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW
|
||||
assert result.name == conversion_args["name"]
|
||||
assert result.icon == conversion_args["icon"]
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
@@ -945,7 +945,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create completion mode app
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
@@ -988,7 +988,7 @@ class TestWorkflowService:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.mode == AppMode.WORKFLOW.value
|
||||
assert result.mode == AppMode.WORKFLOW
|
||||
assert result.name == conversion_args["name"]
|
||||
assert result.icon == conversion_args["icon"]
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
@@ -1007,7 +1007,7 @@ class TestWorkflowService:
|
||||
|
||||
# Create workflow mode app (already in workflow mode)
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -1030,7 +1030,7 @@ class TestWorkflowService:
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.ADVANCED_CHAT.value
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -1061,7 +1061,7 @@ class TestWorkflowService:
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestHandleMCPRequest:
|
||||
"""Setup test fixtures"""
|
||||
self.app = Mock(spec=App)
|
||||
self.app.name = "test_app"
|
||||
self.app.mode = AppMode.CHAT.value
|
||||
self.app.mode = AppMode.CHAT
|
||||
|
||||
self.mcp_server = Mock(spec=AppMCPServer)
|
||||
self.mcp_server.description = "Test server"
|
||||
@@ -196,7 +196,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_list_tools(self):
|
||||
"""Test list tools handler"""
|
||||
app_name = "test_app"
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
description = "Test server"
|
||||
parameters_dict: dict[str, str] = {}
|
||||
user_input_form: list[VariableEntity] = []
|
||||
@@ -212,7 +212,7 @@ class TestIndividualHandlers:
|
||||
def test_handle_call_tool(self, mock_app_generate):
|
||||
"""Test call tool handler"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
# Create mock request
|
||||
mock_request = Mock()
|
||||
@@ -252,7 +252,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_chat_mode(self):
|
||||
"""Test building parameter schema for chat mode"""
|
||||
app_mode = AppMode.CHAT.value
|
||||
app_mode = AppMode.CHAT
|
||||
parameters_dict: dict[str, str] = {"name": "Enter your name"}
|
||||
|
||||
user_input_form = [
|
||||
@@ -275,7 +275,7 @@ class TestUtilityFunctions:
|
||||
|
||||
def test_build_parameter_schema_workflow_mode(self):
|
||||
"""Test building parameter schema for workflow mode"""
|
||||
app_mode = AppMode.WORKFLOW.value
|
||||
app_mode = AppMode.WORKFLOW
|
||||
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
|
||||
|
||||
user_input_form = [
|
||||
@@ -298,7 +298,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_chat_mode(self):
|
||||
"""Test preparing tool arguments for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
arguments = {"query": "test question", "name": "John"}
|
||||
|
||||
@@ -312,7 +312,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_workflow_mode(self):
|
||||
"""Test preparing tool arguments for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
arguments = {"input_text": "test input"}
|
||||
|
||||
@@ -324,7 +324,7 @@ class TestUtilityFunctions:
|
||||
def test_prepare_tool_arguments_completion_mode(self):
|
||||
"""Test preparing tool arguments for completion mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
arguments = {"name": "John"}
|
||||
|
||||
@@ -336,7 +336,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_chat(self):
|
||||
"""Test extracting answer from mapping response for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
response = {"answer": "test answer", "other": "data"}
|
||||
|
||||
@@ -347,7 +347,7 @@ class TestUtilityFunctions:
|
||||
def test_extract_answer_from_mapping_response_workflow(self):
|
||||
"""Test extracting answer from mapping response for workflow mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
response = {"data": {"outputs": {"result": "test result"}}}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT.value
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
@@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW.value
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
|
||||
Reference in New Issue
Block a user