mirror of
https://github.com/langgenius/dify.git
synced 2026-02-12 22:01:20 -05:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -14,7 +14,7 @@ class UserTool(BaseModel):
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = None
|
||||
labels: list[str] | None = None
|
||||
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
@@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
|
||||
# summary
|
||||
summary: Optional[str] = None
|
||||
# operation_id
|
||||
operation_id: str = None
|
||||
operation_id: str | None = None
|
||||
# parameters
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
# author
|
||||
|
||||
@@ -244,18 +244,19 @@ class ToolParameter(BaseModel):
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
if options:
|
||||
options = [
|
||||
options_tool_parametor = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
|
||||
]
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=type,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=options,
|
||||
options=options_tool_parametor,
|
||||
)
|
||||
|
||||
|
||||
@@ -331,7 +332,7 @@ class ToolProviderCredentials(BaseModel):
|
||||
"default": self.default,
|
||||
"options": self.options,
|
||||
"help": self.help.to_dict() if self.help else None,
|
||||
"label": self.label.to_dict(),
|
||||
"label": self.label.to_dict() if self.label else None,
|
||||
"url": self.url,
|
||||
"placeholder": self.placeholder.to_dict() if self.placeholder else None,
|
||||
}
|
||||
@@ -374,7 +375,10 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
pool[index] = ToolRuntimeImageVariable(**variable)
|
||||
super().__init__(**data)
|
||||
|
||||
def dict(self) -> dict:
|
||||
def dict(self) -> dict: # type: ignore
|
||||
"""
|
||||
FIXME: just ignore the type check for now
|
||||
"""
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"user_id": self.user_id,
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolCredentialsOption,
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
@@ -64,21 +69,18 @@ class ApiToolProviderController(ToolProviderController):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id else ""
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
}
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
provider_id=db_provider.id or "",
|
||||
tools=None,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -93,24 +95,22 @@ class ApiToolProviderController(ToolProviderController):
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(
|
||||
**{
|
||||
"api_bundle": tool_bundle,
|
||||
"identity": {
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
|
||||
"llm": tool_bundle.summary or "",
|
||||
},
|
||||
"parameters": tool_bundle.parameters or [],
|
||||
}
|
||||
api_bundle=tool_bundle,
|
||||
identity=ToolIdentity(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id or "",
|
||||
label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
|
||||
icon=self.identity.icon if self.identity else None,
|
||||
provider=self.provider_id,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
|
||||
llm=tool_bundle.summary or "",
|
||||
),
|
||||
parameters=tool_bundle.parameters or [],
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
@@ -121,7 +121,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@@ -131,6 +131,8 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
if self.identity is None:
|
||||
return None
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
@@ -151,7 +153,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@@ -161,7 +163,9 @@ class ApiToolProviderController(ToolProviderController):
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
|
||||
for tool in self.tools:
|
||||
for tool in self.tools or []:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
@@ -20,10 +21,10 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
@@ -38,7 +39,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
tool: dict[str, Any] = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
@@ -52,7 +53,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
app: Optional[App] = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
@@ -79,6 +80,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
@@ -92,6 +94,7 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
@@ -99,5 +102,5 @@ class AppToolProviderEntity(ToolProviderController):
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
tools.append(ApiTool(**tool))
|
||||
return tools
|
||||
|
||||
@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import UserToolProvider
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
|
||||
@@ -4,7 +4,7 @@ from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_style_cache = {}
|
||||
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
|
||||
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
|
||||
|
||||
_api_token_cache_lock = Lock()
|
||||
_style_cache_lock = Lock()
|
||||
_api_token_cache_lock: Lock = Lock()
|
||||
_style_cache_lock: Lock = Lock()
|
||||
|
||||
_task = {}
|
||||
_task: dict[str, Any] = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
_tool: BuiltinTool
|
||||
_tool: BuiltinTool | None
|
||||
|
||||
def __init__(self, tool: BuiltinTool = None):
|
||||
def __init__(self, tool: BuiltinTool | None = None):
|
||||
self._tool = tool
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
@@ -68,8 +70,8 @@ class AIPPTGenerateToolAdapter:
|
||||
)
|
||||
|
||||
# get suit
|
||||
color: str = tool_parameters.get("color")
|
||||
style: str = tool_parameters.get("style")
|
||||
color: str = tool_parameters.get("color", "")
|
||||
style: str = tool_parameters.get("style", "")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
@@ -226,7 +228,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a ppt
|
||||
|
||||
@@ -362,7 +364,9 @@ class AIPPTGenerateToolAdapter:
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
def _get_styles(
|
||||
cls, credentials: dict[str, str], user_id: str
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
@@ -415,7 +419,7 @@ class AIPPTGenerateToolAdapter:
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import arxiv
|
||||
import arxiv # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -11,19 +11,21 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
provider, model = tool_parameters.get("model").split("#")
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
provider, model = tool_parameters.get("model", "").split("#")
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
)
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"),
|
||||
content_text=tool_parameters.get("text", ""),
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
voice=voice,
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
@@ -41,8 +43,11 @@ class TTSTool(BuiltinTool):
|
||||
]
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
|
||||
tid: str = self.runtime.tenant_id or ""
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
@@ -62,6 +67,8 @@ class TTSTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name=f"voice#{provider}#{model}",
|
||||
label=I18nObject(en_US=f"Voice of {model}({provider})"),
|
||||
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
|
||||
placeholder=I18nObject(en_US="Select a voice"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
@@ -83,6 +90,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
|
||||
options=options,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
import boto3 # type: ignore
|
||||
from botocore.exceptions import BotoCoreError # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -10,8 +10,8 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
sagemaker_endpoint: str | None = None
|
||||
topk: int | None = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
sagemaker_endpoint: str | None = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai import ZhipuAI # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -7,18 +7,22 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class SearchRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_id = tool_parameters.get("table_id")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
view_id = tool_parameters.get("view_id")
|
||||
field_names = tool_parameters.get("field_names")
|
||||
sort = tool_parameters.get("sort")
|
||||
filters = tool_parameters.get("filter")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
table_name = tool_parameters.get("table_name", "")
|
||||
view_id = tool_parameters.get("view_id", "")
|
||||
field_names = tool_parameters.get("field_names", "")
|
||||
sort = tool_parameters.get("sort", "")
|
||||
filters = tool_parameters.get("filter", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
automatic_fields = tool_parameters.get("automatic_fields", False)
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 20)
|
||||
|
||||
@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class UpdateRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
app_token = tool_parameters.get("app_token")
|
||||
table_id = tool_parameters.get("table_id")
|
||||
table_name = tool_parameters.get("table_name")
|
||||
records = tool_parameters.get("records")
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
table_name = tool_parameters.get("table_name", "")
|
||||
records = tool_parameters.get("records", "")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
res = client.update_records(app_token, table_id, table_name, records, user_id_type)
|
||||
|
||||
@@ -7,12 +7,16 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class AddEventAttendeesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)
|
||||
|
||||
@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class DeleteEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
|
||||
res = client.delete_event(event_id, need_notification)
|
||||
|
||||
@@ -7,8 +7,12 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class GetPrimaryCalendarTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
|
||||
@@ -7,14 +7,18 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class ListEventsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
page_token = tool_parameters.get("page_token")
|
||||
page_size = tool_parameters.get("page_size")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
page_size = tool_parameters.get("page_size", 50)
|
||||
|
||||
res = client.list_events(start_time, end_time, page_token, page_size)
|
||||
|
||||
|
||||
@@ -7,16 +7,20 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class UpdateEventTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
event_id = tool_parameters.get("event_id")
|
||||
summary = tool_parameters.get("summary")
|
||||
description = tool_parameters.get("description")
|
||||
event_id = tool_parameters.get("event_id", "")
|
||||
summary = tool_parameters.get("summary", "")
|
||||
description = tool_parameters.get("description", "")
|
||||
need_notification = tool_parameters.get("need_notification", True)
|
||||
start_time = tool_parameters.get("start_time")
|
||||
end_time = tool_parameters.get("end_time")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
auto_record = tool_parameters.get("auto_record", False)
|
||||
|
||||
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)
|
||||
|
||||
@@ -7,13 +7,17 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class CreateDocumentTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
title = tool_parameters.get("title")
|
||||
content = tool_parameters.get("content")
|
||||
folder_token = tool_parameters.get("folder_token")
|
||||
title = tool_parameters.get("title", "")
|
||||
content = tool_parameters.get("content", "")
|
||||
folder_token = tool_parameters.get("folder_token", "")
|
||||
|
||||
res = client.create_document(title, content, folder_token)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@@ -7,11 +7,15 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class ListDocumentBlockTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ValueError("Runtime is not set")
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ValueError("app_id and app_secret are required")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get("document_id")
|
||||
document_id = tool_parameters.get("document_id", "")
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
user_id_type = tool_parameters.get("user_id_type", "open_id")
|
||||
page_size = tool_parameters.get("page_size", 500)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from jsonpath_ng import parse
|
||||
from jsonpath_ng import parse # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import numexpr as ne
|
||||
import numexpr as ne # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
Txt2ImgV3Embedding,
|
||||
Txt2ImgV3HiresFix,
|
||||
Txt2ImgV3LoRA,
|
||||
|
||||
@@ -2,7 +2,7 @@ from base64 import b64decode
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
NovitaClient,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from base64 import b64decode
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from novita_client import (
|
||||
from novita_client import ( # type: ignore
|
||||
NovitaClient,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from pydub import AudioSegment
|
||||
from pydub import AudioSegment # type: ignore
|
||||
|
||||
|
||||
class PodcastAudioGeneratorTool(BuiltinTool):
|
||||
|
||||
@@ -2,10 +2,10 @@ import io
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
|
||||
from qrcode.image.base import BaseImage
|
||||
from qrcode.image.pure import PyPNGImage
|
||||
from qrcode.main import QRCode
|
||||
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
|
||||
from qrcode.image.base import BaseImage # type: ignore
|
||||
from qrcode.image.pure import PyPNGImage # type: ignore
|
||||
from qrcode.main import QRCode # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
|
||||
def set_validator(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
from twilio.rest import Client
|
||||
from twilio.rest import Client # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
|
||||
account_sid = values.get("account_sid")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from twilio.base.exceptions import TwilioRestException
|
||||
from twilio.rest import Client
|
||||
from twilio.base.exceptions import TwilioRestException # type: ignore
|
||||
from twilio.rest import Client # type: ignore
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from vanna.remote import VannaDefault
|
||||
from vanna.remote import VannaDefault # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
@@ -14,6 +14,9 @@ class VannaTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# Ensure runtime and credentials
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
|
||||
api_key = self.runtime.credentials.get("api_key", None)
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("Please input api key")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import wikipedia
|
||||
import wikipedia # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Union
|
||||
|
||||
import pandas as pd
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
from yfinance import download
|
||||
from yfinance import download # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import yfinance
|
||||
import yfinance # type: ignore
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from requests.exceptions import HTTPError, ReadTimeout
|
||||
from yfinance import Ticker
|
||||
from yfinance import Ticker # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
@@ -50,6 +50,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools:
|
||||
return self.tools
|
||||
if not self.identity:
|
||||
return []
|
||||
|
||||
provider = self.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
|
||||
@@ -86,7 +88,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
@@ -94,11 +96,14 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ValueError("tools not found")
|
||||
return next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
|
||||
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
|
||||
"""
|
||||
@@ -107,10 +112,13 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
return tool.parameters
|
||||
return tool.parameters or []
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@@ -144,6 +152,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
if self.identity is None:
|
||||
return []
|
||||
return self.identity.tags or []
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
@@ -159,56 +169,56 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
|
||||
for parameter_name in tool_parameters:
|
||||
if parameter_name not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter_name]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[parameter_name], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be number")
|
||||
if not isinstance(tool_parameters[parameter_name], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be number")
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be greater than {parameter_schema.min}"
|
||||
f"parameter {parameter_name} should be greater than {parameter_schema.min}"
|
||||
)
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be less than {parameter_schema.max}"
|
||||
f"parameter {parameter_name} should be less than {parameter_schema.max}"
|
||||
)
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
|
||||
if not isinstance(tool_parameters[parameter_name], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[parameter_name], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
|
||||
if tool_parameters[parameter_name] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
tool_parameters_need_to_validate.pop(parameter_name)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
for parameter_name in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter_name]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} is required")
|
||||
raise ToolParameterValidationError(f"parameter {parameter_name} is required")
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
tool_parameters[parameter] = default_value
|
||||
tool_parameters[parameter_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
||||
@@ -24,10 +24,12 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if self.credentials_schema is None:
|
||||
return {}
|
||||
return self.credentials_schema.copy()
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[Tool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
@@ -36,7 +38,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
@@ -51,10 +53,13 @@ class ToolProviderController(BaseModel, ABC):
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:return: list of parameters
|
||||
"""
|
||||
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
tools = self.get_tools()
|
||||
if tools is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
return tool.parameters
|
||||
return tool.parameters or []
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@@ -78,55 +83,55 @@ class ToolProviderController(BaseModel, ABC):
|
||||
for parameter in tool_parameters_schema:
|
||||
tool_parameters_need_to_validate[parameter.name] = parameter
|
||||
|
||||
for parameter in tool_parameters:
|
||||
if parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
|
||||
for tool_parameter in tool_parameters:
|
||||
if tool_parameter not in tool_parameters_need_to_validate:
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}")
|
||||
|
||||
# check type
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
parameter_schema = tool_parameters_need_to_validate[tool_parameter]
|
||||
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[tool_parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
if not isinstance(tool_parameters[parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be number")
|
||||
if not isinstance(tool_parameters[tool_parameter], int | float):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")
|
||||
|
||||
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
|
||||
if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be greater than {parameter_schema.min}"
|
||||
f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
|
||||
)
|
||||
|
||||
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
|
||||
if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
|
||||
raise ToolParameterValidationError(
|
||||
f"parameter {parameter} should be less than {parameter_schema.max}"
|
||||
f"parameter {tool_parameter} should be less than {parameter_schema.max}"
|
||||
)
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
|
||||
if not isinstance(tool_parameters[tool_parameter], bool):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")
|
||||
|
||||
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
|
||||
if not isinstance(tool_parameters[parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be string")
|
||||
if not isinstance(tool_parameters[tool_parameter], str):
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
|
||||
|
||||
options = parameter_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolParameterValidationError(f"parameter {parameter} options should be list")
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")
|
||||
|
||||
if tool_parameters[parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
|
||||
if tool_parameters[tool_parameter] not in [x.value for x in options]:
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")
|
||||
|
||||
tool_parameters_need_to_validate.pop(parameter)
|
||||
tool_parameters_need_to_validate.pop(tool_parameter)
|
||||
|
||||
for parameter in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[parameter]
|
||||
for tool_parameter_validate in tool_parameters_need_to_validate:
|
||||
parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
|
||||
if parameter_schema.required:
|
||||
raise ToolParameterValidationError(f"parameter {parameter} is required")
|
||||
raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
@@ -144,6 +149,8 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
if self.identity is None:
|
||||
raise ValueError("identity is not set")
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.identity.name}"
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
@@ -116,6 +117,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
@@ -128,6 +130,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -157,7 +160,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
label=db_provider.label,
|
||||
)
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@@ -168,7 +171,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider = (
|
||||
db_providers: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
@@ -179,12 +182,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
if not db_providers.app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@@ -195,6 +200,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
@@ -32,11 +32,13 @@ class ApiTool(Tool):
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
if self.api_bundle is None:
|
||||
raise ValueError("api_bundle is required")
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
@@ -61,6 +63,8 @@ class ApiTool(Tool):
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
@@ -88,7 +92,7 @@ class ApiTool(Tool):
|
||||
|
||||
headers[api_key_header] = credentials["api_key_value"]
|
||||
|
||||
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
@@ -137,7 +141,8 @@ class ApiTool(Tool):
|
||||
|
||||
params = {}
|
||||
path_params = {}
|
||||
body = {}
|
||||
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
|
||||
body: Any = {}
|
||||
cookies = {}
|
||||
files = []
|
||||
|
||||
@@ -198,7 +203,7 @@ class ApiTool(Tool):
|
||||
body = body
|
||||
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, method)(
|
||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
@@ -288,6 +293,7 @@ class ApiTool(Tool):
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
response: httpx.Response | str = ""
|
||||
# assemble request
|
||||
headers = self.assembling_request(tool_parameters)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
if self.runtime is None or self.identity is None:
|
||||
raise ValueError("runtime and identity are required")
|
||||
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
max_tokens = self.get_max_tokens()
|
||||
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
|
||||
stop=[],
|
||||
)
|
||||
|
||||
return summary.message.content
|
||||
return cast(str, summary.message.content)
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for j in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(j)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(j) < max_tokens * 0.5:
|
||||
messages[-1] += j
|
||||
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
|
||||
messages.append(j)
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += j
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@@ -44,12 +46,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"all_documents": all_documents,
|
||||
@@ -77,11 +79,11 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
@@ -139,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
@@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
@@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
weights=retrieval_model.get("weights"),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
@@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in documents]
|
||||
@@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
context = {}
|
||||
document = Document.query.filter(
|
||||
document_segment = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
if not document_segment:
|
||||
continue
|
||||
if dataset and document_segment:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"document_id": document_segment.id,
|
||||
"document_name": document_segment.name,
|
||||
"data_source_type": document_segment.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool):
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
retrieve_config: Optional[DatasetRetrieveConfigEntity],
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
@@ -51,6 +51,8 @@ class DatasetRetrieverTool(Tool):
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
if retrieval_tools is None:
|
||||
return []
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
@@ -83,6 +85,7 @@ class DatasetRetrieverTool(Tool):
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -102,7 +105,9 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
|
||||
@@ -91,7 +91,7 @@ class Tool(BaseModel, ABC):
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
|
||||
"""
|
||||
load variables from database
|
||||
|
||||
@@ -105,6 +105,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
if self.identity is None:
|
||||
return
|
||||
|
||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||
|
||||
@@ -114,6 +116,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
if self.identity is None:
|
||||
return
|
||||
|
||||
self.variables.set_text(self.identity.name, variable_name, text)
|
||||
|
||||
@@ -200,7 +204,11 @@ class Tool(BaseModel, ABC):
|
||||
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
# TODO: Fix type error.
|
||||
if self.runtime is None:
|
||||
return []
|
||||
if self.runtime.runtime_parameters:
|
||||
# Convert Mapping to dict before updating
|
||||
tool_parameters = dict(tool_parameters)
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
@@ -221,7 +229,7 @@ class Tool(BaseModel, ABC):
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
result: dict[str, Any] = deepcopy(dict(tool_parameters))
|
||||
for parameter in self.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
@@ -234,12 +242,15 @@ class Tool(BaseModel, ABC):
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
pass
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials
|
||||
|
||||
:param credentials: the credentials
|
||||
:param parameters: the parameters
|
||||
:param format_only: only return the formatted
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -68,20 +68,20 @@ class WorkflowTool(Tool):
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
result = []
|
||||
r = []
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs == None:
|
||||
outputs = {}
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_message(file))
|
||||
outputs, extracted_files = self._extract_files(outputs)
|
||||
for f in extracted_files:
|
||||
r.append(self.create_file_message(f))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
result.append(self.create_json_message(outputs))
|
||||
r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
r.append(self.create_json_message(outputs))
|
||||
|
||||
return result
|
||||
return r
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@@ -46,7 +46,7 @@ class ToolEngine:
|
||||
invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
|
||||
) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
@@ -69,6 +69,8 @@ class ToolEngine:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
# invoke the tool
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is not set")
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||
@@ -163,6 +165,8 @@ class ToolEngine:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is not set")
|
||||
started_at = datetime.now(UTC)
|
||||
meta = ToolInvokeMeta(
|
||||
time_cost=0.0,
|
||||
@@ -171,7 +175,7 @@ class ToolEngine:
|
||||
"tool_name": tool.identity.name,
|
||||
"tool_provider": tool.identity.provider,
|
||||
"tool_provider_type": tool.tool_provider_type().value,
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {},
|
||||
"tool_icon": tool.identity.icon,
|
||||
},
|
||||
)
|
||||
@@ -194,9 +198,9 @@ class ToolEngine:
|
||||
result = ""
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
result += str(response.message) if response.message is not None else ""
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
result += f"result link: {response.message!r}. please tell user to check it."
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
result += (
|
||||
"image has been created and sent to user already, you do not need to create it,"
|
||||
@@ -205,7 +209,7 @@ class ToolEngine:
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
result += f"tool response: {response.message!r}."
|
||||
|
||||
return result
|
||||
|
||||
@@ -223,7 +227,7 @@ class ToolEngine:
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
try:
|
||||
url = URL(response.message)
|
||||
url = URL(cast(str, response.message))
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
@@ -237,7 +241,7 @@ class ToolEngine:
|
||||
result.append(
|
||||
ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||
url=response.message,
|
||||
url=cast(str, response.message),
|
||||
save_as=response.save_as,
|
||||
)
|
||||
)
|
||||
@@ -245,7 +249,7 @@ class ToolEngine:
|
||||
result.append(
|
||||
ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "octet/stream"),
|
||||
url=response.message,
|
||||
url=cast(str, response.message),
|
||||
save_as=response.save_as,
|
||||
)
|
||||
)
|
||||
@@ -257,7 +261,7 @@ class ToolEngine:
|
||||
mimetype=response.meta.get("mime_type", "octet/stream")
|
||||
if response.meta
|
||||
else "octet/stream",
|
||||
url=response.message,
|
||||
url=cast(str, response.message),
|
||||
save_as=response.save_as,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -84,13 +84,17 @@ class ToolLabelManager:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = [controller.provider_id for controller in tool_providers]
|
||||
provider_ids = [
|
||||
controller.provider_id
|
||||
for controller in tool_providers
|
||||
if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController))
|
||||
]
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels = {label.tool_id: [] for label in labels}
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
@@ -4,7 +4,7 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
@@ -15,15 +15,18 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -33,9 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
_builtin_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels = {}
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
@@ -55,7 +58,7 @@ class ToolManager:
|
||||
return cls._builtin_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
@@ -66,13 +69,15 @@ class ToolManager:
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
|
||||
@classmethod
|
||||
def get_tool(
|
||||
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
|
||||
) -> Union[BuiltinTool, ApiTool]:
|
||||
) -> Union[BuiltinTool, ApiTool, Tool]:
|
||||
"""
|
||||
get the tool
|
||||
|
||||
@@ -103,7 +108,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
) -> Union[BuiltinTool, ApiTool]:
|
||||
) -> Union[BuiltinTool, ApiTool, Tool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@@ -113,6 +118,7 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController]
|
||||
if provider_type == "builtin":
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
|
||||
@@ -129,7 +135,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = (
|
||||
builtin_provider: Optional[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
@@ -177,7 +183,7 @@ class ToolManager:
|
||||
}
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
workflow_provider = (
|
||||
workflow_provider: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
@@ -187,8 +193,13 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: Optional[list[Tool]] = controller.get_tools(
|
||||
user_id="", tenant_id=workflow_provider.tenant_id
|
||||
)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
return controller_tools[0].fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": {},
|
||||
@@ -215,7 +226,7 @@ class ToolManager:
|
||||
|
||||
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in parameter_rule.options]
|
||||
options = [x.value for x in parameter_rule.options or []]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||
@@ -267,6 +278,8 @@ class ToolManager:
|
||||
identity_id=f"AGENT.{app_id}",
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
@@ -312,6 +325,9 @@ class ToolManager:
|
||||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@@ -326,6 +342,8 @@ class ToolManager:
|
||||
"""
|
||||
# get provider
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
if provider_controller.identity is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
|
||||
|
||||
absolute_path = path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
@@ -381,11 +399,15 @@ class ToolManager:
|
||||
),
|
||||
parent_type=BuiltinToolProviderController,
|
||||
)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
provider_controller: BuiltinToolProviderController = provider_class()
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
cls._builtin_providers[provider_controller.identity.name] = provider_controller
|
||||
for tool in provider_controller.get_tools() or []:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
yield provider
|
||||
yield provider_controller
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"load builtin provider {provider}")
|
||||
@@ -449,9 +471,11 @@ class ToolManager:
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if provider.identity is None:
|
||||
continue
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
@@ -472,7 +496,7 @@ class ToolManager:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers = [
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
@@ -495,7 +519,7 @@ class ToolManager:
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers = []
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
@@ -505,7 +529,9 @@ class ToolManager:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
@@ -527,7 +553,7 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider = (
|
||||
provider: Optional[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.id == provider_id,
|
||||
@@ -556,7 +582,7 @@ class ToolManager:
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider: ApiToolProvider = (
|
||||
provider_tool: Optional[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@@ -565,17 +591,18 @@ class ToolManager:
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
if provider_tool is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider.credentials_str) or {}
|
||||
credentials = json.loads(provider_tool.credentials_str) or {}
|
||||
except:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
|
||||
provider_tool,
|
||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
@@ -584,25 +611,28 @@ class ToolManager:
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
||||
try:
|
||||
icon = json.loads(provider.icon)
|
||||
icon = json.loads(provider_tool.icon)
|
||||
except:
|
||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider.schema_type,
|
||||
"schema": provider.schema,
|
||||
"tools": provider.tools,
|
||||
"icon": icon,
|
||||
"description": provider.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider.privacy_policy,
|
||||
"custom_disclaimer": provider.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
return cast(
|
||||
dict,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_tool.schema_type,
|
||||
"schema": provider_tool.schema,
|
||||
"tools": provider_tool.tools,
|
||||
"icon": icon,
|
||||
"description": provider_tool.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_tool.privacy_policy,
|
||||
"custom_disclaimer": provider_tool.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -617,6 +647,7 @@ class ToolManager:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
|
||||
if provider_type == "builtin":
|
||||
return (
|
||||
dify_config.CONSOLE_API_URL
|
||||
@@ -626,16 +657,21 @@ class ToolManager:
|
||||
)
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider = (
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
return json.loads(provider.icon)
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
icon = json.loads(provider.icon)
|
||||
if isinstance(icon, (str, dict)):
|
||||
return icon
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == "workflow":
|
||||
provider: WorkflowToolProvider = (
|
||||
provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
@@ -643,7 +679,13 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return json.loads(provider.icon)
|
||||
try:
|
||||
icon = json.loads(provider.icon)
|
||||
if isinstance(icon, (str, dict)):
|
||||
return icon
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
@@ -72,9 +72,13 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
identity_id = ""
|
||||
if self.provider_controller.identity:
|
||||
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
|
||||
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
|
||||
identity_id=identity_id,
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
@@ -95,9 +99,13 @@ class ToolConfigurationManager(BaseModel):
|
||||
return credentials
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
identity_id = ""
|
||||
if self.provider_controller.identity:
|
||||
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
|
||||
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
|
||||
identity_id=identity_id,
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
@@ -199,6 +207,9 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
if self.tool_runtime is None or self.tool_runtime.identity is None:
|
||||
raise ValueError("tool_runtime is required")
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
@@ -232,6 +243,9 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
if self.tool_runtime is None or self.tool_runtime.identity is None:
|
||||
raise ValueError("tool_runtime is required")
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -101,7 +101,7 @@ class FeishuRequest:
|
||||
"""
|
||||
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
res = self._send_request(url, require_token=False, payload=payload)
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def create_document(self, title: str, content: str, folder_token: str) -> dict:
|
||||
@@ -126,15 +126,16 @@ class FeishuRequest:
|
||||
"content": content,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
|
||||
url = f"{self.API_BASE_URL}/document/write_document"
|
||||
payload = {"document_id": document_id, "content": content, "position": position}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
|
||||
@@ -155,9 +156,9 @@ class FeishuRequest:
|
||||
"lang": lang,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/get_document_content"
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data").get("content")
|
||||
return cast(str, res.get("data", {}).get("content"))
|
||||
return ""
|
||||
|
||||
def list_document_blocks(
|
||||
@@ -173,9 +174,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/list_document_blocks"
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
|
||||
@@ -191,9 +193,10 @@ class FeishuRequest:
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
|
||||
@@ -203,7 +206,7 @@ class FeishuRequest:
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res = self._send_request(url, require_token=False, payload=payload)
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def get_chat_messages(
|
||||
@@ -227,9 +230,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_thread_messages(
|
||||
@@ -245,9 +249,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
|
||||
@@ -260,9 +265,10 @@ class FeishuRequest:
|
||||
"completed_at": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_task(
|
||||
@@ -278,9 +284,10 @@ class FeishuRequest:
|
||||
"completed_time": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res = self._send_request(url, method="PATCH", payload=payload)
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_task(self, task_guid: str) -> dict:
|
||||
@@ -289,7 +296,7 @@ class FeishuRequest:
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
}
|
||||
res = self._send_request(url, method="DELETE", payload=payload)
|
||||
res: dict = self._send_request(url, method="DELETE", payload=payload)
|
||||
return res
|
||||
|
||||
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
|
||||
@@ -300,7 +307,7 @@ class FeishuRequest:
|
||||
"member_phone_or_email": member_phone_or_email,
|
||||
"member_role": member_role,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
|
||||
@@ -312,9 +319,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
|
||||
@@ -322,9 +330,10 @@ class FeishuRequest:
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_event(
|
||||
@@ -347,9 +356,10 @@ class FeishuRequest:
|
||||
"auto_record": auto_record,
|
||||
"attendee_ability": attendee_ability,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_event(
|
||||
@@ -363,7 +373,7 @@ class FeishuRequest:
|
||||
auto_record: bool,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
|
||||
payload = {}
|
||||
payload: dict[str, Any] = {}
|
||||
if summary:
|
||||
payload["summary"] = summary
|
||||
if description:
|
||||
@@ -376,7 +386,7 @@ class FeishuRequest:
|
||||
payload["need_notification"] = need_notification
|
||||
if auto_record:
|
||||
payload["auto_record"] = auto_record
|
||||
res = self._send_request(url, method="PATCH", payload=payload)
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
return res
|
||||
|
||||
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
|
||||
@@ -384,7 +394,7 @@ class FeishuRequest:
|
||||
params = {
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, method="DELETE", params=params)
|
||||
res: dict = self._send_request(url, method="DELETE", params=params)
|
||||
return res
|
||||
|
||||
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
|
||||
@@ -395,9 +405,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_events(
|
||||
@@ -418,9 +429,10 @@ class FeishuRequest:
|
||||
"user_id_type": user_id_type,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
|
||||
@@ -431,9 +443,10 @@ class FeishuRequest:
|
||||
"attendee_phone_or_email": attendee_phone_or_email,
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_spreadsheet(
|
||||
@@ -447,9 +460,10 @@ class FeishuRequest:
|
||||
"title": title,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_spreadsheet(
|
||||
@@ -463,9 +477,10 @@ class FeishuRequest:
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_spreadsheet_sheets(
|
||||
@@ -477,9 +492,10 @@ class FeishuRequest:
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_rows(
|
||||
@@ -499,9 +515,10 @@ class FeishuRequest:
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_cols(
|
||||
@@ -521,9 +538,10 @@ class FeishuRequest:
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_rows(
|
||||
@@ -545,9 +563,10 @@ class FeishuRequest:
|
||||
"num_rows": num_rows,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_cols(
|
||||
@@ -569,9 +588,10 @@ class FeishuRequest:
|
||||
"num_cols": num_cols,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_table(
|
||||
@@ -593,9 +613,10 @@ class FeishuRequest:
|
||||
"query": query,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_base(
|
||||
@@ -609,9 +630,10 @@ class FeishuRequest:
|
||||
"name": name,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_records(
|
||||
@@ -633,9 +655,10 @@ class FeishuRequest:
|
||||
payload = {
|
||||
"records": convert_add_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_records(
|
||||
@@ -657,9 +680,10 @@ class FeishuRequest:
|
||||
payload = {
|
||||
"records": convert_update_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_records(
|
||||
@@ -686,9 +710,10 @@ class FeishuRequest:
|
||||
payload = {
|
||||
"records": record_id_list,
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_record(
|
||||
@@ -740,7 +765,7 @@ class FeishuRequest:
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload = {}
|
||||
payload: dict[str, Any] = {}
|
||||
|
||||
if view_id:
|
||||
payload["view_id"] = view_id
|
||||
@@ -752,10 +777,11 @@ class FeishuRequest:
|
||||
payload["filter"] = filter_dict
|
||||
if automatic_fields:
|
||||
payload["automatic_fields"] = automatic_fields
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_base_info(
|
||||
@@ -767,9 +793,10 @@ class FeishuRequest:
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_table(
|
||||
@@ -797,9 +824,10 @@ class FeishuRequest:
|
||||
}
|
||||
if default_view_name:
|
||||
payload["default_view_name"] = default_view_name
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_tables(
|
||||
@@ -834,9 +862,10 @@ class FeishuRequest:
|
||||
"table_names": table_name_list,
|
||||
}
|
||||
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_tables(
|
||||
@@ -852,9 +881,10 @@ class FeishuRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_records(
|
||||
@@ -882,7 +912,8 @@ class FeishuRequest:
|
||||
"record_ids": record_id_list,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params, payload=payload)
|
||||
res: dict = self._send_request(url, method="GET", params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -62,12 +62,10 @@ class LarkRequest:
|
||||
def tenant_access_token(self) -> str:
|
||||
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
|
||||
if redis_client.exists(feishu_tenant_access_token):
|
||||
return redis_client.get(feishu_tenant_access_token).decode()
|
||||
res = self.get_tenant_access_token(self.app_id, self.app_secret)
|
||||
return str(redis_client.get(feishu_tenant_access_token).decode())
|
||||
res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret)
|
||||
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
|
||||
if "tenant_access_token" in res:
|
||||
return res.get("tenant_access_token")
|
||||
return ""
|
||||
return res.get("tenant_access_token", "")
|
||||
|
||||
def _send_request(
|
||||
self,
|
||||
@@ -91,7 +89,7 @@ class LarkRequest:
|
||||
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
|
||||
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
res = self._send_request(url, require_token=False, payload=payload)
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def create_document(self, title: str, content: str, folder_token: str) -> dict:
|
||||
@@ -101,15 +99,16 @@ class LarkRequest:
|
||||
"content": content,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
|
||||
url = f"{self.API_BASE_URL}/document/write_document"
|
||||
payload = {"document_id": document_id, "content": content, "position": position}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
return res
|
||||
|
||||
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
|
||||
@@ -119,9 +118,9 @@ class LarkRequest:
|
||||
"lang": lang,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/get_document_content"
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data").get("content")
|
||||
return cast(dict, res.get("data", {}).get("content"))
|
||||
return ""
|
||||
|
||||
def list_document_blocks(
|
||||
@@ -134,9 +133,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
}
|
||||
url = f"{self.API_BASE_URL}/document/list_document_blocks"
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
|
||||
@@ -149,9 +149,10 @@ class LarkRequest:
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
|
||||
@@ -161,7 +162,7 @@ class LarkRequest:
|
||||
"msg_type": msg_type,
|
||||
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
|
||||
}
|
||||
res = self._send_request(url, require_token=False, payload=payload)
|
||||
res: dict = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
def get_chat_messages(
|
||||
@@ -182,9 +183,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_thread_messages(
|
||||
@@ -197,9 +199,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
|
||||
@@ -211,9 +214,10 @@ class LarkRequest:
|
||||
"completed_at": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_task(
|
||||
@@ -228,9 +232,10 @@ class LarkRequest:
|
||||
"completed_time": completed_time,
|
||||
"description": description,
|
||||
}
|
||||
res = self._send_request(url, method="PATCH", payload=payload)
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_task(self, task_guid: str) -> dict:
|
||||
@@ -238,9 +243,10 @@ class LarkRequest:
|
||||
payload = {
|
||||
"task_guid": task_guid,
|
||||
}
|
||||
res = self._send_request(url, method="DELETE", payload=payload)
|
||||
res: dict = self._send_request(url, method="DELETE", payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
|
||||
@@ -250,9 +256,10 @@ class LarkRequest:
|
||||
"member_phone_or_email": member_phone_or_email,
|
||||
"member_role": member_role,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
|
||||
@@ -263,9 +270,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
|
||||
@@ -273,9 +281,10 @@ class LarkRequest:
|
||||
params = {
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_event(
|
||||
@@ -298,9 +307,10 @@ class LarkRequest:
|
||||
"auto_record": auto_record,
|
||||
"attendee_ability": attendee_ability,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_event(
|
||||
@@ -314,7 +324,7 @@ class LarkRequest:
|
||||
auto_record: bool,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
|
||||
payload = {}
|
||||
payload: dict[str, Any] = {}
|
||||
if summary:
|
||||
payload["summary"] = summary
|
||||
if description:
|
||||
@@ -327,7 +337,7 @@ class LarkRequest:
|
||||
payload["need_notification"] = need_notification
|
||||
if auto_record:
|
||||
payload["auto_record"] = auto_record
|
||||
res = self._send_request(url, method="PATCH", payload=payload)
|
||||
res: dict = self._send_request(url, method="PATCH", payload=payload)
|
||||
return res
|
||||
|
||||
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
|
||||
@@ -335,7 +345,7 @@ class LarkRequest:
|
||||
params = {
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, method="DELETE", params=params)
|
||||
res: dict = self._send_request(url, method="DELETE", params=params)
|
||||
return res
|
||||
|
||||
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
|
||||
@@ -346,9 +356,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_events(
|
||||
@@ -369,9 +380,10 @@ class LarkRequest:
|
||||
"user_id_type": user_id_type,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
|
||||
@@ -381,9 +393,10 @@ class LarkRequest:
|
||||
"attendee_phone_or_email": attendee_phone_or_email,
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_spreadsheet(
|
||||
@@ -396,9 +409,10 @@ class LarkRequest:
|
||||
"title": title,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_spreadsheet(
|
||||
@@ -411,9 +425,10 @@ class LarkRequest:
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_spreadsheet_sheets(
|
||||
@@ -424,9 +439,10 @@ class LarkRequest:
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_rows(
|
||||
@@ -445,9 +461,10 @@ class LarkRequest:
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_cols(
|
||||
@@ -466,9 +483,10 @@ class LarkRequest:
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_rows(
|
||||
@@ -489,9 +507,10 @@ class LarkRequest:
|
||||
"num_rows": num_rows,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_cols(
|
||||
@@ -512,9 +531,10 @@ class LarkRequest:
|
||||
"num_cols": num_cols,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_table(
|
||||
@@ -535,9 +555,10 @@ class LarkRequest:
|
||||
"query": query,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_base(
|
||||
@@ -550,9 +571,10 @@ class LarkRequest:
|
||||
"name": name,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
res: dict = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def add_records(
|
||||
@@ -573,9 +595,10 @@ class LarkRequest:
|
||||
payload = {
|
||||
"records": self.convert_add_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def update_records(
|
||||
@@ -596,9 +619,10 @@ class LarkRequest:
|
||||
payload = {
|
||||
"records": self.convert_update_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_records(
|
||||
@@ -624,9 +648,10 @@ class LarkRequest:
|
||||
payload = {
|
||||
"records": record_id_list,
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def search_record(
|
||||
@@ -678,7 +703,7 @@ class LarkRequest:
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
|
||||
payload = {}
|
||||
payload: dict[str, Any] = {}
|
||||
|
||||
if view_id:
|
||||
payload["view_id"] = view_id
|
||||
@@ -690,9 +715,10 @@ class LarkRequest:
|
||||
payload["filter"] = filter_dict
|
||||
if automatic_fields:
|
||||
payload["automatic_fields"] = automatic_fields
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def get_base_info(
|
||||
@@ -703,9 +729,10 @@ class LarkRequest:
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def create_table(
|
||||
@@ -732,9 +759,10 @@ class LarkRequest:
|
||||
}
|
||||
if default_view_name:
|
||||
payload["default_view_name"] = default_view_name
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def delete_tables(
|
||||
@@ -767,9 +795,10 @@ class LarkRequest:
|
||||
"table_ids": table_id_list,
|
||||
"table_names": table_name_list,
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
res: dict = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def list_tables(
|
||||
@@ -784,9 +813,10 @@ class LarkRequest:
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
res: dict = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
def read_records(
|
||||
@@ -814,7 +844,8 @@ class LarkRequest:
|
||||
"record_ids": record_id_list,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="POST", params=params, payload=payload)
|
||||
res: dict = self._send_request(url, method="POST", params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
data: dict = res.get("data", {})
|
||||
return data
|
||||
return res
|
||||
|
||||
@@ -90,12 +90,12 @@ class ToolFileMessageTransformer:
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
file = message.meta.get("file")
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
file_mata = message.meta.get("file")
|
||||
if isinstance(file_mata, File):
|
||||
if file_mata.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file_mata.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension)
|
||||
if file_mata.type == FileType.IMAGE:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
|
||||
@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
@@ -133,14 +133,17 @@ class ModelInvocationUtils:
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
response: LLMResult = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
@@ -64,6 +64,9 @@ class ApiBasedToolSchemaParser:
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
@@ -108,6 +111,9 @@ class ApiBasedToolSchemaParser:
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
@@ -158,9 +164,9 @@ class ApiBasedToolSchemaParser:
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ = None
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
@@ -175,6 +181,8 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
@@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
|
||||
@@ -9,13 +9,13 @@ import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString
|
||||
from regex import regex
|
||||
import cloudscraper # type: ignore
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
|
||||
from regex import regex # type: ignore
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
|
||||
os.unlink(article_json_path)
|
||||
os.unlink(html_path)
|
||||
|
||||
article_json = {
|
||||
article_json: dict[str, Any] = {
|
||||
"title": None,
|
||||
"byline": None,
|
||||
"date": None,
|
||||
@@ -300,7 +300,7 @@ def strip_control_characters(text):
|
||||
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form = "NFKC"
|
||||
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
@@ -332,6 +332,7 @@ def add_content_digest(element):
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
digest: Any
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
|
||||
@@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Reference in New Issue
Block a user