feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -14,7 +14,7 @@ class UserTool(BaseModel):
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = None
labels: list[str] | None = None
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]

View File

@@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
# summary
summary: Optional[str] = None
# operation_id
operation_id: str = None
operation_id: str | None = None
# parameters
parameters: Optional[list[ToolParameter]] = None
# author

View File

@@ -244,18 +244,19 @@ class ToolParameter(BaseModel):
"""
# convert options to ToolParameterOption
if options:
options = [
options_tool_parametor = [
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
]
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
type=type,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
options=options,
options=options_tool_parametor,
)
@@ -331,7 +332,7 @@ class ToolProviderCredentials(BaseModel):
"default": self.default,
"options": self.options,
"help": self.help.to_dict() if self.help else None,
"label": self.label.to_dict(),
"label": self.label.to_dict() if self.label else None,
"url": self.url,
"placeholder": self.placeholder.to_dict() if self.placeholder else None,
}
@@ -374,7 +375,10 @@ class ToolRuntimeVariablePool(BaseModel):
pool[index] = ToolRuntimeImageVariable(**variable)
super().__init__(**data)
def dict(self) -> dict:
def dict(self) -> dict: # type: ignore
"""
FIXME: just ignore the type check for now
"""
return {
"conversation_id": self.conversation_id,
"user_id": self.user_id,

View File

@@ -1,9 +1,14 @@
from typing import Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolCredentialsOption,
ToolDescription,
ToolIdentity,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
@@ -64,21 +69,18 @@ class ApiToolProviderController(ToolProviderController):
pass
else:
raise ValueError(f"invalid auth type {auth_type}")
user_name = db_provider.user.name if db_provider.user_id else ""
user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
return ApiToolProviderController(
**{
"identity": {
"author": user_name,
"name": db_provider.name,
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
"icon": db_provider.icon,
},
"credentials_schema": credentials_schema,
"provider_id": db_provider.id or "",
}
identity=ToolProviderIdentity(
author=user_name,
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=credentials_schema,
provider_id=db_provider.id or "",
tools=None,
)
@property
@@ -93,24 +95,22 @@ class ApiToolProviderController(ToolProviderController):
:return: the tool
"""
return ApiTool(
**{
"api_bundle": tool_bundle,
"identity": {
"author": tool_bundle.author,
"name": tool_bundle.operation_id,
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
"icon": self.identity.icon,
"provider": self.provider_id,
},
"description": {
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
"llm": tool_bundle.summary or "",
},
"parameters": tool_bundle.parameters or [],
}
api_bundle=tool_bundle,
identity=ToolIdentity(
author=tool_bundle.author,
name=tool_bundle.operation_id or "",
label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
icon=self.identity.icon if self.identity else None,
provider=self.provider_id,
),
description=ToolDescription(
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
llm=tool_bundle.summary or "",
),
parameters=tool_bundle.parameters or [],
)
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
"""
load bundled tools
@@ -121,7 +121,7 @@ class ApiToolProviderController(ToolProviderController):
return self.tools
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -131,6 +131,8 @@ class ApiToolProviderController(ToolProviderController):
"""
if self.tools is not None:
return self.tools
if self.identity is None:
return None
tools: list[Tool] = []
@@ -151,7 +153,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
def get_tool(self, tool_name: str) -> Tool:
"""
get tool by name
@@ -161,7 +163,9 @@ class ApiToolProviderController(ToolProviderController):
if self.tools is None:
self.get_tools()
for tool in self.tools:
for tool in self.tools or []:
if tool.identity is None:
continue
if tool.identity.name == tool_name:
return tool

View File

@@ -1,9 +1,10 @@
import logging
from typing import Any
from typing import Any, Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
@@ -20,10 +21,10 @@ class AppToolProviderEntity(ToolProviderController):
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
db_tools: list[PublishedAppTool] = (
db.session.query(PublishedAppTool)
.filter(
@@ -38,7 +39,7 @@ class AppToolProviderEntity(ToolProviderController):
tools: list[Tool] = []
for db_tool in db_tools:
tool = {
tool: dict[str, Any] = {
"identity": {
"author": db_tool.author,
"name": db_tool.tool_name,
@@ -52,7 +53,7 @@ class AppToolProviderEntity(ToolProviderController):
"parameters": [],
}
# get app from db
app: App = db_tool.app
app: Optional[App] = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
@@ -79,6 +80,7 @@ class AppToolProviderEntity(ToolProviderController):
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif form_type == "select":
@@ -92,6 +94,7 @@ class AppToolProviderEntity(ToolProviderController):
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
placeholder=I18nObject(en_US="", zh_Hans=""),
options=[
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
@@ -99,5 +102,5 @@ class AppToolProviderEntity(ToolProviderController):
)
)
tools.append(Tool(**tool))
tools.append(ApiTool(**tool))
return tools

View File

@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import UserToolProvider
class BuiltinToolProviderSort:
_position = {}
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:

View File

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from typing import Any, Union
from httpx import get, post
from requests import get as requests_get
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
"""
_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_style_cache = {}
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
_api_token_cache_lock = Lock()
_style_cache_lock = Lock()
_api_token_cache_lock: Lock = Lock()
_style_cache_lock: Lock = Lock()
_task = {}
_task: dict[str, Any] = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
_tool: BuiltinTool
_tool: BuiltinTool | None
def __init__(self, tool: BuiltinTool = None):
def __init__(self, tool: BuiltinTool | None = None):
self._tool = tool
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
@@ -68,8 +70,8 @@ class AIPPTGenerateToolAdapter:
)
# get suit
color: str = tool_parameters.get("color")
style: str = tool_parameters.get("style")
color: str = tool_parameters.get("color", "")
style: str = tool_parameters.get("style", "")
if color == "__default__":
color_id = ""
@@ -226,7 +228,7 @@ class AIPPTGenerateToolAdapter:
return ""
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
"""
Generate a ppt
@@ -362,7 +364,9 @@ class AIPPTGenerateToolAdapter:
).decode("utf-8")
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
def _get_styles(
cls, credentials: dict[str, str], user_id: str
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
"""
@@ -415,7 +419,7 @@ class AIPPTGenerateToolAdapter:
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
def get_runtime_parameters(self) -> list[ToolParameter]:

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Optional
import arxiv
import arxiv # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -11,19 +11,21 @@ from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
provider, model = tool_parameters.get("model").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
provider, model = tool_parameters.get("model", "").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"),
content_text=tool_parameters.get("text", ""),
user=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
voice=voice,
)
buffer = io.BytesIO()
@@ -41,8 +43,11 @@ class TTSTool(BuiltinTool):
]
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@@ -62,6 +67,8 @@ class TTSTool(BuiltinTool):
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@@ -83,6 +90,7 @@ class TTSTool(BuiltinTool):
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)

View File

@@ -2,8 +2,8 @@ import json
import logging
from typing import Any, Union
import boto3
from botocore.exceptions import BotoCoreError
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import logging
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -2,7 +2,7 @@ import json
import operator
from typing import Any, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
topk: int = None
sagemaker_endpoint: str | None = None
topk: int | None = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)

View File

@@ -2,7 +2,7 @@ import json
from enum import Enum
from typing import Any, Optional, Union
import boto3
import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str = None
sagemaker_endpoint: str | None = None
s3_client: Any = None
comprehend_client: Any = None

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
import httpx
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import random
from typing import Any, Union
from zhipuai import ZhipuAI
from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -7,18 +7,22 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class SearchRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
view_id = tool_parameters.get("view_id")
field_names = tool_parameters.get("field_names")
sort = tool_parameters.get("sort")
filters = tool_parameters.get("filter")
page_token = tool_parameters.get("page_token")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
view_id = tool_parameters.get("view_id", "")
field_names = tool_parameters.get("field_names", "")
sort = tool_parameters.get("sort", "")
filters = tool_parameters.get("filter", "")
page_token = tool_parameters.get("page_token", "")
automatic_fields = tool_parameters.get("automatic_fields", False)
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 20)

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
app_token = tool_parameters.get("app_token")
table_id = tool_parameters.get("table_id")
table_name = tool_parameters.get("table_name")
records = tool_parameters.get("records")
app_token = tool_parameters.get("app_token", "")
table_id = tool_parameters.get("table_id", "")
table_name = tool_parameters.get("table_name", "")
records = tool_parameters.get("records", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
res = client.update_records(app_token, table_id, table_name, records, user_id_type)

View File

@@ -7,12 +7,16 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class AddEventAttendeesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
event_id = tool_parameters.get("event_id", "")
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class DeleteEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
event_id = tool_parameters.get("event_id", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.delete_event(event_id, need_notification)

View File

@@ -7,8 +7,12 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class GetPrimaryCalendarTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
user_id_type = tool_parameters.get("user_id_type", "open_id")

View File

@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListEventsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
page_token = tool_parameters.get("page_token")
page_size = tool_parameters.get("page_size")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
page_token = tool_parameters.get("page_token", "")
page_size = tool_parameters.get("page_size", 50)
res = client.list_events(start_time, end_time, page_token, page_size)

View File

@@ -7,16 +7,20 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class UpdateEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
event_id = tool_parameters.get("event_id")
summary = tool_parameters.get("summary")
description = tool_parameters.get("description")
event_id = tool_parameters.get("event_id", "")
summary = tool_parameters.get("summary", "")
description = tool_parameters.get("description", "")
need_notification = tool_parameters.get("need_notification", True)
start_time = tool_parameters.get("start_time")
end_time = tool_parameters.get("end_time")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
auto_record = tool_parameters.get("auto_record", False)
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)

View File

@@ -7,13 +7,17 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
title = tool_parameters.get("title")
content = tool_parameters.get("content")
folder_token = tool_parameters.get("folder_token")
title = tool_parameters.get("title", "")
content = tool_parameters.get("content", "")
folder_token = tool_parameters.get("folder_token", "")
res = client.create_document(title, content, folder_token)
return self.create_json_message(res)

View File

@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListDocumentBlockTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
if not self.runtime or not self.runtime.credentials:
raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
if not app_id or not app_secret:
raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get("document_id")
document_id = tool_parameters.get("document_id", "")
page_token = tool_parameters.get("page_token", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 500)

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Union
from jsonpath_ng import parse
from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any, Union
import numexpr as ne
import numexpr as ne # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,4 +1,4 @@
from novita_client import (
from novita_client import ( # type: ignore
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -2,7 +2,7 @@ from base64 import b64decode
from copy import deepcopy
from typing import Any, Union
from novita_client import (
from novita_client import ( # type: ignore
NovitaClient,
)

View File

@@ -13,7 +13,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from pydub import AudioSegment
from pydub import AudioSegment # type: ignore
class PodcastAudioGeneratorTool(BuiltinTool):

View File

@@ -2,10 +2,10 @@ import io
import logging
from typing import Any, Union
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
from qrcode.image.base import BaseImage # type: ignore
from qrcode.image.pure import PyPNGImage # type: ignore
from qrcode.main import QRCode # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
from twilio.rest import Client
from twilio.rest import Client # type: ignore
except ImportError:
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
account_sid = values.get("account_sid")

View File

@@ -1,7 +1,7 @@
from typing import Any
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
from twilio.base.exceptions import TwilioRestException # type: ignore
from twilio.rest import Client # type: ignore
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
from vanna.remote import VannaDefault
from vanna.remote import VannaDefault # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
@@ -14,6 +14,9 @@ class VannaTool(BuiltinTool):
"""
invoke tools
"""
# Ensure runtime and credentials
if not self.runtime or not self.runtime.credentials:
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key", None)
if not api_key:
raise ToolProviderCredentialValidationError("Please input api key")

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional, Union
import wikipedia
import wikipedia # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -3,7 +3,7 @@ from typing import Any, Union
import pandas as pd
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import download
from yfinance import download # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,6 +1,6 @@
from typing import Any, Union
import yfinance
import yfinance # type: ignore
from requests.exceptions import HTTPError, ReadTimeout
from core.tools.entities.tool_entities import ToolInvokeMessage

View File

@@ -1,7 +1,7 @@
from typing import Any, Union
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker
from yfinance import Ticker # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Union
from googleapiclient.discovery import build
from googleapiclient.discovery import build # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -1,6 +1,6 @@
from abc import abstractmethod
from os import listdir, path
from typing import Any
from typing import Any, Optional
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
@@ -50,6 +50,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
if self.tools:
return self.tools
if not self.identity:
return []
provider = self.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
@@ -86,7 +88,7 @@ class BuiltinToolProviderController(ToolProviderController):
return self.credentials_schema.copy()
def get_tools(self) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -94,11 +96,14 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ValueError("tools not found")
return next((t for t in tools if t.identity and t.identity.name == tool_name), None)
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
@@ -107,10 +112,13 @@ class BuiltinToolProviderController(ToolProviderController):
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool.parameters
return tool.parameters or []
@property
def need_credentials(self) -> bool:
@@ -144,6 +152,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the labels of the provider
"""
if self.identity is None:
return []
return self.identity.tags or []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
@@ -159,56 +169,56 @@ class BuiltinToolProviderController(ToolProviderController):
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
for parameter_name in tool_parameters:
if parameter_name not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[parameter_name], str):
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f"parameter {parameter} should be number")
if not isinstance(tool_parameters[parameter_name], int | float):
raise ToolParameterValidationError(f"parameter {parameter_name} should be number")
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
raise ToolParameterValidationError(
f"parameter {parameter} should be greater than {parameter_schema.min}"
f"parameter {parameter_name} should be greater than {parameter_schema.min}"
)
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
raise ToolParameterValidationError(
f"parameter {parameter} should be less than {parameter_schema.max}"
f"parameter {parameter_name} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
if not isinstance(tool_parameters[parameter_name], bool):
raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[parameter_name], str):
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
if tool_parameters[parameter_name] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")
tool_parameters_need_to_validate.pop(parameter)
tool_parameters_need_to_validate.pop(parameter_name)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
for parameter_name in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.required:
raise ToolParameterValidationError(f"parameter {parameter} is required")
raise ToolParameterValidationError(f"parameter {parameter_name} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[parameter] = default_value
tool_parameters[parameter_name] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""

View File

@@ -24,10 +24,12 @@ class ToolProviderController(BaseModel, ABC):
:return: the credentials schema
"""
if self.credentials_schema is None:
return {}
return self.credentials_schema.copy()
@abstractmethod
def get_tools(self) -> list[Tool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -36,7 +38,7 @@ class ToolProviderController(BaseModel, ABC):
pass
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns a tool that the provider can provide
@@ -51,10 +53,13 @@ class ToolProviderController(BaseModel, ABC):
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
tools = self.get_tools()
if tools is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool.parameters
return tool.parameters or []
@property
def provider_type(self) -> ToolProviderType:
@@ -78,55 +83,55 @@ class ToolProviderController(BaseModel, ABC):
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
for tool_parameter in tool_parameters:
if tool_parameter not in tool_parameters_need_to_validate:
raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}")
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
parameter_schema = tool_parameters_need_to_validate[tool_parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[tool_parameter], str):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], int | float):
raise ToolParameterValidationError(f"parameter {parameter} should be number")
if not isinstance(tool_parameters[tool_parameter], int | float):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
raise ToolParameterValidationError(
f"parameter {parameter} should be greater than {parameter_schema.min}"
f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
)
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
raise ToolParameterValidationError(
f"parameter {parameter} should be less than {parameter_schema.max}"
f"parameter {tool_parameter} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
if not isinstance(tool_parameters[tool_parameter], bool):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParameterValidationError(f"parameter {parameter} should be string")
if not isinstance(tool_parameters[tool_parameter], str):
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
if tool_parameters[tool_parameter] not in [x.value for x in options]:
raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")
tool_parameters_need_to_validate.pop(parameter)
tool_parameters_need_to_validate.pop(tool_parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
for tool_parameter_validate in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
if parameter_schema.required:
raise ToolParameterValidationError(f"parameter {parameter} is required")
raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
@@ -144,6 +149,8 @@ class ToolProviderController(BaseModel, ABC):
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
if self.identity is None:
raise ValueError("identity is not set")
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.identity.name}"
)

View File

@@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -116,6 +117,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif features.file_upload:
@@ -128,6 +130,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=False,
form=parameter.form,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
else:
@@ -157,7 +160,7 @@ class WorkflowToolProviderController(ToolProviderController):
label=db_provider.label,
)
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -168,7 +171,7 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = (
db_providers: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
@@ -179,12 +182,14 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
if not db_providers.app:
raise ValueError("app not found")
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
get tool by name
@@ -195,6 +200,8 @@ class WorkflowToolProviderController(ToolProviderController):
return None
for tool in self.tools:
if tool.identity is None:
continue
if tool.identity.name == tool_name:
return tool

View File

@@ -32,11 +32,13 @@ class ApiTool(Tool):
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
if self.api_bundle is None:
raise ValueError("api_bundle is required")
return self.__class__(
identity=self.identity.model_copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.model_copy() if self.description else None,
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
api_bundle=self.api_bundle.model_copy(),
runtime=Tool.Runtime(**runtime),
)
@@ -61,6 +63,8 @@ class ApiTool(Tool):
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
if self.runtime is None:
raise ValueError("runtime is required")
credentials = self.runtime.credentials or {}
if "auth_type" not in credentials:
@@ -88,7 +92,7 @@ class ApiTool(Tool):
headers[api_key_header] = credentials["api_key_value"]
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
@@ -137,7 +141,8 @@ class ApiTool(Tool):
params = {}
path_params = {}
body = {}
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
body: Any = {}
cookies = {}
files = []
@@ -198,7 +203,7 @@ class ApiTool(Tool):
body = body
if method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, method)(
response: httpx.Response = getattr(ssrf_proxy, method)(
url,
params=params,
headers=headers,
@@ -288,6 +293,7 @@ class ApiTool(Tool):
"""
invoke http request
"""
response: httpx.Response | str = ""
# assemble request
headers = self.assembling_request(tool_parameters)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
if self.runtime is None or self.identity is None:
raise ValueError("runtime and identity are required")
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.identity.name,
prompt_messages=prompt_messages,
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
stop=[],
)
return summary.message.content
return cast(str, summary.message.content)
lines = content.split("\n")
new_lines = []
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for j in new_lines:
if len(messages) == 0:
messages.append(i)
messages.append(j)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(j) < max_tokens * 0.5:
messages[-1] += j
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
messages.append(j)
else:
messages[-1] += i
messages[-1] += j
summaries = []
for i in range(len(messages)):

View File

@@ -1,4 +1,5 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -44,12 +46,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str:
threads = []
all_documents = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
@@ -77,11 +79,11 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
for item in all_documents:
if item.metadata.get("score"):
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
@@ -139,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
def _retriever(
self,

View File

@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
@@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
@@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results]))
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
weights=retrieval_model.get("weights"),
)
else:
documents = []
@@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
@@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(
document_segment = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
if not document_segment:
continue
if dataset and document_segment:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"document_id": document_segment.id,
"document_name": document_segment.name,
"data_source_type": document_segment.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool):
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
retrieve_config: Optional[DatasetRetrieveConfigEntity],
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
@@ -51,6 +51,8 @@ class DatasetRetrieverTool(Tool):
invoke_from=invoke_from,
hit_callback=hit_callback,
)
if retrieval_tools is None:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
@@ -83,6 +85,7 @@ class DatasetRetrieverTool(Tool):
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
@@ -102,7 +105,9 @@ class DatasetRetrieverTool(Tool):
return self.create_text_message(text=result)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""

View File

@@ -91,7 +91,7 @@ class Tool(BaseModel, ABC):
:return: the tool provider type
"""
def load_variables(self, variables: ToolRuntimeVariablePool):
def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
"""
load variables from database
@@ -105,6 +105,8 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
@@ -114,6 +116,8 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_text(self.identity.name, variable_name, text)
@@ -200,7 +204,11 @@ class Tool(BaseModel, ABC):
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
# TODO: Fix type error.
if self.runtime is None:
return []
if self.runtime.runtime_parameters:
# Convert Mapping to dict before updating
tool_parameters = dict(tool_parameters)
tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type
@@ -221,7 +229,7 @@ class Tool(BaseModel, ABC):
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result = deepcopy(tool_parameters)
result: dict[str, Any] = deepcopy(dict(tool_parameters))
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
@@ -234,12 +242,15 @@ class Tool(BaseModel, ABC):
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials
:param credentials: the credentials
:param parameters: the parameters
:param format_only: only return the formatted
"""
pass

View File

@@ -68,20 +68,20 @@ class WorkflowTool(Tool):
if data.get("error"):
raise Exception(data.get("error"))
result = []
r = []
outputs = data.get("outputs")
if outputs == None:
outputs = {}
else:
outputs, files = self._extract_files(outputs)
for file in files:
result.append(self.create_file_message(file))
outputs, extracted_files = self._extract_files(outputs)
for f in extracted_files:
r.append(self.create_file_message(f))
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
result.append(self.create_json_message(outputs))
r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
r.append(self.create_json_message(outputs))
return result
return r
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""

View File

@@ -3,7 +3,7 @@ from collections.abc import Mapping
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from yarl import URL
@@ -46,7 +46,7 @@ class ToolEngine:
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@@ -69,6 +69,8 @@ class ToolEngine:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool
if tool.identity is None:
raise ValueError("tool identity is not set")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
@@ -163,6 +165,8 @@ class ToolEngine:
"""
Invoke the tool with the given arguments.
"""
if tool.identity is None:
raise ValueError("tool identity is not set")
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
@@ -171,7 +175,7 @@ class ToolEngine:
"tool_name": tool.identity.name,
"tool_provider": tool.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {},
"tool_icon": tool.identity.icon,
},
)
@@ -194,9 +198,9 @@ class ToolEngine:
result = ""
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
result += str(response.message) if response.message is not None else ""
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
result += f"result link: {response.message!r}. please tell user to check it."
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
"image has been created and sent to user already, you do not need to create it,"
@@ -205,7 +209,7 @@ class ToolEngine:
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
else:
result += f"tool response: {response.message}."
result += f"tool response: {response.message!r}."
return result
@@ -223,7 +227,7 @@ class ToolEngine:
mimetype = response.meta.get("mime_type")
else:
try:
url = URL(response.message)
url = URL(cast(str, response.message))
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
@@ -237,7 +241,7 @@ class ToolEngine:
result.append(
ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=response.message,
url=cast(str, response.message),
save_as=response.save_as,
)
)
@@ -245,7 +249,7 @@ class ToolEngine:
result.append(
ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream"),
url=response.message,
url=cast(str, response.message),
save_as=response.save_as,
)
)
@@ -257,7 +261,7 @@ class ToolEngine:
mimetype=response.meta.get("mime_type", "octet/stream")
if response.meta
else "octet/stream",
url=response.message,
url=cast(str, response.message),
save_as=response.save_as,
)
)

View File

@@ -84,13 +84,17 @@ class ToolLabelManager:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = [controller.provider_id for controller in tool_providers]
provider_ids = [
controller.provider_id
for controller in tool_providers
if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController))
]
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels = {label.tool_id: [] for label in labels}
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)

View File

@@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock, Thread
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.agent.entities import AgentToolEntity
@@ -15,15 +15,18 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -33,9 +36,9 @@ logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
_builtin_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels = {}
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
@@ -55,7 +58,7 @@ class ToolManager:
return cls._builtin_providers[provider]
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]:
"""
get the builtin tool
@@ -66,13 +69,15 @@ class ToolManager:
"""
provider_controller = cls.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod
def get_tool(
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
) -> Union[BuiltinTool, ApiTool]:
) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool
@@ -103,7 +108,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, ApiTool]:
) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool runtime
@@ -113,6 +118,7 @@ class ToolManager:
:return: the tool
"""
controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController]
if provider_type == "builtin":
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
@@ -129,7 +135,7 @@ class ToolManager:
)
# get credentials
builtin_provider: BuiltinToolProvider = (
builtin_provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@@ -177,7 +183,7 @@ class ToolManager:
}
)
elif provider_type == "workflow":
workflow_provider = (
workflow_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -187,8 +193,13 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools(
user_id="", tenant_id=workflow_provider.tenant_id
)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
return controller_tools[0].fork_tool_runtime(
runtime={
"tenant_id": tenant_id,
"credentials": {},
@@ -215,7 +226,7 @@ class ToolManager:
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = [x.value for x in parameter_rule.options]
options = [x.value for x in parameter_rule.options or []]
if parameter_value is not None and parameter_value not in options:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
@@ -267,6 +278,8 @@ class ToolManager:
identity_id=f"AGENT.{app_id}",
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -312,6 +325,9 @@ class ToolManager:
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -326,6 +342,8 @@ class ToolManager:
"""
# get provider
provider_controller = cls.get_builtin_provider(provider)
if provider_controller.identity is None:
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
absolute_path = path.join(
path.dirname(path.realpath(__file__)),
@@ -381,11 +399,15 @@ class ToolManager:
),
parent_type=BuiltinToolProviderController,
)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
provider_controller: BuiltinToolProviderController = provider_class()
if provider_controller.identity is None:
continue
cls._builtin_providers[provider_controller.identity.name] = provider_controller
for tool in provider_controller.get_tools() or []:
if tool.identity is None:
continue
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
yield provider
yield provider_controller
except Exception as e:
logger.exception(f"load builtin provider {provider}")
@@ -449,9 +471,11 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if provider.identity is None:
continue
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
@@ -472,7 +496,7 @@ class ToolManager:
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers = [
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
@@ -495,7 +519,7 @@ class ToolManager:
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers = []
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
@@ -505,7 +529,9 @@ class ToolManager:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
@@ -527,7 +553,7 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider = (
provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.id == provider_id,
@@ -556,7 +582,7 @@ class ToolManager:
get tool provider
"""
provider_name = provider
provider: ApiToolProvider = (
provider_tool: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@@ -565,17 +591,18 @@ class ToolManager:
.first()
)
if provider is None:
if provider_tool is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
credentials = json.loads(provider.credentials_str) or {}
credentials = json.loads(provider_tool.credentials_str) or {}
except:
credentials = {}
# package tool provider controller
controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
provider_tool,
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@@ -584,25 +611,28 @@ class ToolManager:
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try:
icon = json.loads(provider.icon)
icon = json.loads(provider_tool.icon)
except:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder(
{
"schema_type": provider.schema_type,
"schema": provider.schema,
"tools": provider.tools,
"icon": icon,
"description": provider.description,
"credentials": masked_credentials,
"privacy_policy": provider.privacy_policy,
"custom_disclaimer": provider.custom_disclaimer,
"labels": labels,
}
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_tool.schema_type,
"schema": provider_tool.schema,
"tools": provider_tool.tools,
"icon": icon,
"description": provider_tool.description,
"credentials": masked_credentials,
"privacy_policy": provider_tool.privacy_policy,
"custom_disclaimer": provider_tool.custom_disclaimer,
"labels": labels,
}
),
)
@classmethod
@@ -617,6 +647,7 @@ class ToolManager:
"""
provider_type = provider_type
provider_id = provider_id
provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
if provider_type == "builtin":
return (
dify_config.CONSOLE_API_URL
@@ -626,16 +657,21 @@ class ToolManager:
)
elif provider_type == "api":
try:
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
return json.loads(provider.icon)
if provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon = json.loads(provider.icon)
if isinstance(icon, (str, dict)):
return icon
return {"background": "#252525", "content": "\ud83d\ude01"}
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == "workflow":
provider: WorkflowToolProvider = (
provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -643,7 +679,13 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(provider.icon)
try:
icon = json.loads(provider.icon)
if isinstance(icon, (str, dict)):
return icon
return {"background": "#252525", "content": "\ud83d\ude01"}
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
else:
raise ValueError(f"provider type {provider_type} not found")

View File

@@ -72,9 +72,13 @@ class ToolConfigurationManager(BaseModel):
return a deep copy of credentials with decrypted values
"""
identity_id = ""
if self.provider_controller.identity:
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
@@ -95,9 +99,13 @@ class ToolConfigurationManager(BaseModel):
return credentials
def delete_tool_credentials_cache(self):
identity_id = ""
if self.provider_controller.identity:
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
@@ -199,6 +207,9 @@ class ToolParameterConfigurationManager(BaseModel):
return a deep copy of parameters with decrypted values
"""
if self.tool_runtime is None or self.tool_runtime.identity is None:
raise ValueError("tool_runtime is required")
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",
@@ -232,6 +243,9 @@ class ToolParameterConfigurationManager(BaseModel):
return parameters
def delete_tool_parameters_cache(self):
if self.tool_runtime is None or self.tool_runtime.identity is None:
raise ValueError("tool_runtime is required")
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional, cast
import httpx
@@ -101,7 +101,7 @@ class FeishuRequest:
"""
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -126,15 +126,16 @@ class FeishuRequest:
"content": content,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
@@ -155,9 +156,9 @@ class FeishuRequest:
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data").get("content")
return cast(str, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -173,9 +174,10 @@ class FeishuRequest:
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -191,9 +193,10 @@ class FeishuRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -203,7 +206,7 @@ class FeishuRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -227,9 +230,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_thread_messages(
@@ -245,9 +249,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -260,9 +265,10 @@ class FeishuRequest:
"completed_at": completed_time,
"description": description,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_task(
@@ -278,9 +284,10 @@ class FeishuRequest:
"completed_time": completed_time,
"description": description,
}
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -289,7 +296,7 @@ class FeishuRequest:
payload = {
"task_guid": task_guid,
}
res = self._send_request(url, method="DELETE", payload=payload)
res: dict = self._send_request(url, method="DELETE", payload=payload)
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -300,7 +307,7 @@ class FeishuRequest:
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -312,9 +319,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -322,9 +330,10 @@ class FeishuRequest:
params = {
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_event(
@@ -347,9 +356,10 @@ class FeishuRequest:
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_event(
@@ -363,7 +373,7 @@ class FeishuRequest:
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
payload = {}
payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -376,7 +386,7 @@ class FeishuRequest:
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -384,7 +394,7 @@ class FeishuRequest:
params = {
"need_notification": need_notification,
}
res = self._send_request(url, method="DELETE", params=params)
res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -395,9 +405,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_events(
@@ -418,9 +429,10 @@ class FeishuRequest:
"user_id_type": user_id_type,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -431,9 +443,10 @@ class FeishuRequest:
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_spreadsheet(
@@ -447,9 +460,10 @@ class FeishuRequest:
"title": title,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_spreadsheet(
@@ -463,9 +477,10 @@ class FeishuRequest:
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_spreadsheet_sheets(
@@ -477,9 +492,10 @@ class FeishuRequest:
params = {
"spreadsheet_token": spreadsheet_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_rows(
@@ -499,9 +515,10 @@ class FeishuRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_cols(
@@ -521,9 +538,10 @@ class FeishuRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_rows(
@@ -545,9 +563,10 @@ class FeishuRequest:
"num_rows": num_rows,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_cols(
@@ -569,9 +588,10 @@ class FeishuRequest:
"num_cols": num_cols,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_table(
@@ -593,9 +613,10 @@ class FeishuRequest:
"query": query,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_base(
@@ -609,9 +630,10 @@ class FeishuRequest:
"name": name,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_records(
@@ -633,9 +655,10 @@ class FeishuRequest:
payload = {
"records": convert_add_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_records(
@@ -657,9 +680,10 @@ class FeishuRequest:
payload = {
"records": convert_update_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_records(
@@ -686,9 +710,10 @@ class FeishuRequest:
payload = {
"records": record_id_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_record(
@@ -740,7 +765,7 @@ class FeishuRequest:
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {}
payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -752,10 +777,11 @@ class FeishuRequest:
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_base_info(
@@ -767,9 +793,10 @@ class FeishuRequest:
params = {
"app_token": app_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_table(
@@ -797,9 +824,10 @@ class FeishuRequest:
}
if default_view_name:
payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_tables(
@@ -834,9 +862,10 @@ class FeishuRequest:
"table_names": table_name_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_tables(
@@ -852,9 +881,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_records(
@@ -882,7 +912,8 @@ class FeishuRequest:
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params, payload=payload)
res: dict = self._send_request(url, method="GET", params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional, cast
import httpx
@@ -62,12 +62,10 @@ class LarkRequest:
def tenant_access_token(self) -> str:
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
if redis_client.exists(feishu_tenant_access_token):
return redis_client.get(feishu_tenant_access_token).decode()
res = self.get_tenant_access_token(self.app_id, self.app_secret)
return str(redis_client.get(feishu_tenant_access_token).decode())
res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret)
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
if "tenant_access_token" in res:
return res.get("tenant_access_token")
return ""
return res.get("tenant_access_token", "")
def _send_request(
self,
@@ -91,7 +89,7 @@ class LarkRequest:
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -101,15 +99,16 @@ class LarkRequest:
"content": content,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
@@ -119,9 +118,9 @@ class LarkRequest:
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data").get("content")
return cast(dict, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -134,9 +133,10 @@ class LarkRequest:
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -149,9 +149,10 @@ class LarkRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -161,7 +162,7 @@ class LarkRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -182,9 +183,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_thread_messages(
@@ -197,9 +199,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -211,9 +214,10 @@ class LarkRequest:
"completed_at": completed_time,
"description": description,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_task(
@@ -228,9 +232,10 @@ class LarkRequest:
"completed_time": completed_time,
"description": description,
}
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -238,9 +243,10 @@ class LarkRequest:
payload = {
"task_guid": task_guid,
}
res = self._send_request(url, method="DELETE", payload=payload)
res: dict = self._send_request(url, method="DELETE", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -250,9 +256,10 @@ class LarkRequest:
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -263,9 +270,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -273,9 +281,10 @@ class LarkRequest:
params = {
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_event(
@@ -298,9 +307,10 @@ class LarkRequest:
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_event(
@@ -314,7 +324,7 @@ class LarkRequest:
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
payload = {}
payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -327,7 +337,7 @@ class LarkRequest:
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -335,7 +345,7 @@ class LarkRequest:
params = {
"need_notification": need_notification,
}
res = self._send_request(url, method="DELETE", params=params)
res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -346,9 +356,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_events(
@@ -369,9 +380,10 @@ class LarkRequest:
"user_id_type": user_id_type,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -381,9 +393,10 @@ class LarkRequest:
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_spreadsheet(
@@ -396,9 +409,10 @@ class LarkRequest:
"title": title,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_spreadsheet(
@@ -411,9 +425,10 @@ class LarkRequest:
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_spreadsheet_sheets(
@@ -424,9 +439,10 @@ class LarkRequest:
params = {
"spreadsheet_token": spreadsheet_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_rows(
@@ -445,9 +461,10 @@ class LarkRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_cols(
@@ -466,9 +483,10 @@ class LarkRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_rows(
@@ -489,9 +507,10 @@ class LarkRequest:
"num_rows": num_rows,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_cols(
@@ -512,9 +531,10 @@ class LarkRequest:
"num_cols": num_cols,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_table(
@@ -535,9 +555,10 @@ class LarkRequest:
"query": query,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_base(
@@ -550,9 +571,10 @@ class LarkRequest:
"name": name,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_records(
@@ -573,9 +595,10 @@ class LarkRequest:
payload = {
"records": self.convert_add_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_records(
@@ -596,9 +619,10 @@ class LarkRequest:
payload = {
"records": self.convert_update_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_records(
@@ -624,9 +648,10 @@ class LarkRequest:
payload = {
"records": record_id_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_record(
@@ -678,7 +703,7 @@ class LarkRequest:
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {}
payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -690,9 +715,10 @@ class LarkRequest:
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_base_info(
@@ -703,9 +729,10 @@ class LarkRequest:
params = {
"app_token": app_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_table(
@@ -732,9 +759,10 @@ class LarkRequest:
}
if default_view_name:
payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_tables(
@@ -767,9 +795,10 @@ class LarkRequest:
"table_ids": table_id_list,
"table_names": table_name_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_tables(
@@ -784,9 +813,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_records(
@@ -814,7 +844,8 @@ class LarkRequest:
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="POST", params=params, payload=payload)
res: dict = self._send_request(url, method="POST", params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res

View File

@@ -90,12 +90,12 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
file = message.meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
file_mata = message.meta.get("file")
if isinstance(file_mata, File):
if file_mata.transfer_method == FileTransferMethod.TOOL_FILE:
assert file_mata.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension)
if file_mata.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,

View File

@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import cast
from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
if not schema:
raise InvokeModelError("No model schema found")
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
@@ -133,14 +133,17 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@@ -6,7 +6,7 @@ from json.decoder import JSONDecodeError
from typing import Optional
from requests import get
from yaml import YAMLError, safe_load
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@@ -64,6 +64,9 @@ class ApiBasedToolSchemaParser:
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@@ -108,6 +111,9 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@@ -158,9 +164,9 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ = None
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
@@ -175,6 +181,8 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
@@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],

View File

@@ -9,13 +9,13 @@ import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
import cloudscraper
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from regex import regex
import cloudscraper # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
@@ -300,7 +300,7 @@ def strip_control_characters(text):
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
@@ -332,6 +332,7 @@ def add_content_digest(element):
def content_digest(element):
digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()

View File

@@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> None:
) -> bool:
"""
check is synced

View File

@@ -2,7 +2,7 @@ import logging
from pathlib import Path
from typing import Any
import yaml
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)