chore: apply ty checks on api code with script and ci action (#24653)

Author: Bowen Liang
Date: 2025-09-02 16:05:13 +08:00
Committed by: GitHub
Parent: c373b734bc
Commit: 7b379e2a61

48 changed files with 188 additions and 142 deletions

@@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
             else:
                 response_chunk.update(sub_stream_response.to_dict())

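Note: ty is Astral's Python type checker, and "# ty: ignore [rule]" is its rule-scoped suppression comment, analogous to mypy's "# type: ignore[code]". The suppression in this hunk (and the similar ones below) appears to be needed because to_ignore_detail_dict() is defined on the concrete stream-response classes rather than on the base type the checker sees for sub_stream_response. A minimal sketch of the situation, with illustrative names standing in for the real classes:

    class StreamResponse:  # the declared type, with no such method
        pass

    class NodeStart(StreamResponse):
        def to_ignore_detail_dict(self) -> dict[str, object]:
            return {}

    class NodeFinish(StreamResponse):
        def to_ignore_detail_dict(self) -> dict[str, object]:
            return {}

    def handle(resp: StreamResponse) -> dict[str, object]:
        if isinstance(resp, NodeStart | NodeFinish):
            # The method exists on both union members at runtime, but if the
            # checker does not narrow past the declared base type for this
            # form of isinstance(), the lookup is reported as unresolved:
            return resp.to_ignore_detail_dict()  # ty: ignore [unresolved-attribute]
        return {}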

@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
             else:
                 response_chunk.update(sub_stream_response.to_dict())
             yield response_chunk

@@ -96,7 +96,11 @@ class RateLimit:
         if isinstance(generator, Mapping):
             return generator
         else:
-            return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
+            return RateLimitGenerator(
+                rate_limit=self,
+                generator=generator,  # ty: ignore [invalid-argument-type]
+                request_id=request_id,
+            )


 class RateLimitGenerator:

@@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e
+            err = e  # ty: ignore [invalid-assignment]
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))

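Note: ty's invalid-assignment rule fires when a value is not assignable to the declared or inferred type of its target, so err = e presumably trips it here because the type recorded for err from the other branches is narrower than InvokeError | ValueError. A minimal reproduction of the rule itself (illustrative, not the code above):

    count: int = 0
    count = "three"  # ty: invalid-assignment -- a str is not assignable to int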

@@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
                 timeout=self.timeout,
                 proxies=proxies,
             )
-        except requests.exceptions.Timeout:
+        except requests.Timeout:
             raise ValueError("request timeout")
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             raise ValueError("request connection error")

         if response.status_code != 200:

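Note: the requests changes throughout this commit are pure aliasing, not behavior changes: the package re-exports its exception classes at the top level, so the shorter spelling catches exactly the same errors.

    import requests

    # The top-level names are the same class objects as the ones under
    # requests.exceptions, so except clauses behave identically.
    assert requests.Timeout is requests.exceptions.Timeout
    assert requests.ConnectionError is requests.exceptions.ConnectionError
    assert requests.RequestException is requests.exceptions.RequestException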
@@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]

 def load_single_subclass_from_source(
-    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
+    *, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
 ) -> type:
     """
     Load a single subclass from the source

@@ -56,11 +56,8 @@ class LLMGenerator:
         prompts = [UserPromptMessage(content=prompt)]

         with measure_time() as timer:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
+            )
         answer = cast(str, response.message.content)
         cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@@ -113,13 +110,10 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages),
-                    model_parameters={"max_tokens": 256, "temperature": 0},
-                    stream=False,
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages),
+                model_parameters={"max_tokens": 256, "temperature": 0},
+                stream=False,
+            )

             text_content = response.message.get_text_content()
@@ -162,11 +156,8 @@ class LLMGenerator:
         )

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )

             rule_config["prompt"] = cast(str, response.message.content)
@@ -212,11 +203,8 @@ class LLMGenerator:
         try:
             try:
                 # the first step to generate the task prompt
-                prompt_content = cast(
-                    LLMResult,
-                    model_instance.invoke_llm(
-                        prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                    ),
-                )
+                prompt_content: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                )
             except InvokeError as e:
                 error = str(e)
@@ -248,11 +236,8 @@ class LLMGenerator:
         statement_messages = [UserPromptMessage(content=statement_generate_prompt)]

         try:
-            parameter_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            parameter_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
+            )
             rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
         except InvokeError as e:
@@ -260,11 +245,8 @@ class LLMGenerator:
             error_step = "generate variables"

         try:
-            statement_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            statement_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
+            )
             rule_config["opening_statement"] = cast(str, statement_content.message.content)
         except InvokeError as e:
@@ -307,11 +289,8 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
         model_parameters = model_config.get("completion_params", {})
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )

             generated_code = cast(str, response.message.content)
@@ -338,13 +317,10 @@ class LLMGenerator:
         prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

-        response = cast(
-            LLMResult,
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters={"temperature": 0.01, "max_tokens": 2000},
-                stream=False,
-            ),
-        )
+        response: LLMResult = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={"temperature": 0.01, "max_tokens": 2000},
+            stream=False,
+        )

         answer = cast(str, response.message.content)
@@ -367,11 +343,8 @@ class LLMGenerator:
         model_parameters = model_config.get("model_parameters", {})

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )

             raw_content = response.message.content
@@ -555,11 +528,8 @@ class LLMGenerator:
         model_parameters = {"temperature": 0.4}

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )

             generated_raw = cast(str, response.message.content)

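Note: the recurring rewrite in this file replaces the typing.cast(LLMResult, ...) wrapper with an annotated assignment. cast() is an unchecked assertion plus an extra runtime call; the annotation states the same expectation once, on the binding itself, where the checker can weigh it against the callee's declared return type. A reduced before/after sketch, assuming a hypothetical invoke_llm whose return type depends on the stream flag (the real ModelInstance.invoke_llm signature may differ):

    from collections.abc import Generator
    from typing import cast


    class LLMResult: ...


    def invoke_llm(*, stream: bool) -> "LLMResult | Generator[LLMResult, None, None]":
        """Hypothetical stand-in for ModelInstance.invoke_llm."""
        raise NotImplementedError


    # Before: the cast silences every checker unconditionally.
    result = cast(LLMResult, invoke_llm(stream=False))

    # After: the expectation is a single, visible, checkable annotation.
    result2: LLMResult = invoke_llm(stream=False)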

@@ -72,7 +72,7 @@ class TraceClient:
             else:
                 logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
                 return False
-        except requests.exceptions.RequestException as e:
+        except requests.RequestException as e:
             logger.debug("AliyunTrace API check failed: %s", str(e))
             raise ValueError(f"AliyunTrace API check failed: {str(e)}")

@@ -64,7 +64,7 @@ class BasePluginClient:
             response = requests.request(
                 method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
             )
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             logger.exception("Request to Plugin Daemon Service failed")
             raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
             collection=self._collection_name,
             metrics=self.config.metrics,
             include_values=True,
-            vector=None,
-            content=None,
+            vector=None,  # ty: ignore [invalid-argument-type]
+            content=None,  # ty: ignore [invalid-argument-type]
             top_k=1,
             filter=f"ref_doc_id='{id}'",
         )
@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"ref_doc_id IN {ids_str}",
         )
         self._client.delete_collection_data(request)
@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
         )
         self._client.delete_collection_data(request)
@@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
             include_values=kwargs.pop("include_values", True),
             metrics=self.config.metrics,
             vector=query_vector,
-            content=None,
+            content=None,  # ty: ignore [invalid-argument-type]
             top_k=kwargs.get("top_k", 4),
             filter=where_clause,
         )
@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
             collection=self._collection_name,
             include_values=kwargs.pop("include_values", True),
             metrics=self.config.metrics,
-            vector=None,
+            vector=None,  # ty: ignore [invalid-argument-type]
             content=query,
             top_k=kwargs.get("top_k", 4),
             filter=where_clause,

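Note: the invalid-argument-type suppressions in this file mark calls where the SDK's generated request models declare non-optional parameters that the service nonetheless treats as optional: passing None is correct at runtime (for example, a full-text query sends no vector) but fails the declared signature. A generic sketch of the mismatch, with an illustrative request class standing in for the SDK model:

    class QueryRequest:
        # The generated model declares a concrete parameter type...
        def __init__(self, vector: list[float], top_k: int) -> None:
            self.vector = vector
            self.top_k = top_k

    # ...while callers deliberately send "no vector" for full-text-only
    # search, so the call works in practice but not by the declared type:
    request = QueryRequest(
        vector=None,  # ty: ignore [invalid-argument-type]
        top_k=4,
    )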

@@ -12,7 +12,7 @@ import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator

 if TYPE_CHECKING:
-    from clickzetta import Connection
+    from clickzetta.connector.v0.connection import Connection  # type: ignore

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field

@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # ty: ignore [too-many-positional-arguments]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )

@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
             if not client.ping():
                 raise ConnectionError("Failed to connect to Elasticsearch")
-        except requests.exceptions.ConnectionError as e:
+        except requests.ConnectionError as e:
             raise ConnectionError(f"Vector database connection error: {str(e)}")
         except Exception as e:
             raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")

@@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
         if config.token:
             client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
         else:
-            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+            client = MilvusClient(
+                uri=config.uri,
+                user=config.user or "",
+                password=config.password or "",
+                db_name=config.database,
+            )
         return client

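Note: config.user or "" is the suppression-free counterpart to the ty: ignore comments elsewhere: the Optional[str] config values are coalesced to str at the call boundary, satisfying MilvusClient's non-optional parameters (pymilvus defaults user and password to an empty string, so behavior should be unchanged). A sketch of the pattern:

    def do_connect(*, user: str, password: str) -> None:
        """Stand-in for a client constructor typed with plain str parameters."""

    def connect(user: str | None, password: str | None) -> None:
        # Coalescing at the boundary keeps the fix local and checkable;
        # "" mirrors the client's own parameter defaults.
        do_connect(user=user or "", password=password or "")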

@@ -32,9 +32,9 @@ class VikingDBConfig(BaseModel):
     scheme: str
     connection_timeout: int
     socket_timeout: int
-    index_type: str = IndexType.HNSW
-    distance: str = DistanceType.L2
-    quant: str = QuantType.Float
+    index_type: str = str(IndexType.HNSW)
+    distance: str = str(DistanceType.L2)
+    quant: str = str(QuantType.Float)


 class VikingDBVector(BaseVector):

@@ -37,22 +37,22 @@ class WeaviateVector(BaseVector):
         self._attributes = attributes

     def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
-        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+        auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")

-        weaviate.connect.connection.has_grpc = False
+        weaviate.connect.connection.has_grpc = False  # ty: ignore [unresolved-attribute]

         # Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
         # by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
         # TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
         # which does not contain the deprecation check.
-        if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
-            weaviate.connect.connection.PYPI_TIMEOUT = 0.001
+        if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):  # ty: ignore [unresolved-attribute]
+            weaviate.connect.connection.PYPI_TIMEOUT = 0.001  # ty: ignore [unresolved-attribute]

         try:
             client = weaviate.Client(
                 url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
             )
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             raise ConnectionError("Vector database connection error")

         client.batch.configure(

@@ -1,4 +1,4 @@
-from typing import Union, cast
+from typing import Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
@@ -28,14 +28,11 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result = cast(
-            LLMResult,
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=dataset_tools,
-                stream=False,
-                model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
-            ),
-        )
+        result: LLMResult = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=dataset_tools,
+            stream=False,
+            model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+        )
         if result.message.tool_calls:
             # get retrieval model config

@@ -1,5 +1,5 @@
 from collections.abc import Generator, Sequence
-from typing import Union, cast
+from typing import Union

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
@@ -150,15 +150,12 @@ class ReactMultiDatasetRouter:
         :param stop: stop
         :return:
         """
-        invoke_result = cast(
-            Generator[LLMResult, None, None],
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters=completion_param,
-                stop=stop,
-                stream=True,
-                user=user_id,
-            ),
-        )
+        invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters=completion_param,
+            stop=stop,
+            stream=True,
+            user=user_id,
+        )

         # handle invoke result

@@ -74,7 +74,7 @@ class BuiltinToolProviderController(ToolProviderController):
             tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)

             # get tool class, import the module
-            assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
+            assistant_tool_class: type = load_single_subclass_from_source(
                 module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                 script_path=path.join(
                     path.dirname(path.realpath(__file__)),

@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id
+            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
         else:
             raise ValueError("Unsupported tool type")
@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id
+            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:
@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)
+            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]

         labels: list[ToolLabelBinding] = (
             db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()

@@ -1,7 +1,6 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import Optional

-from msal_extensions.persistence import ABC  # type: ignore
 from pydantic import BaseModel, ConfigDict

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler

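Note: unlike the suppressions elsewhere, this hunk is a straight bug fix surfaced by the type check: ABC was being imported from msal_extensions.persistence, a third-party module that merely happens to expose the name, instead of from the standard library. The corrected pattern, in self-contained form:

    from abc import ABC, abstractmethod


    class RetrievalToolBase(ABC):  # illustrative name
        @abstractmethod
        def retrieve(self, query: str) -> list[str]: ...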

@@ -52,12 +52,12 @@ class AnswerStreamProcessor(StreamProcessor):
                 yield event
             elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
                 yield event
-                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:  # ty: ignore [unresolved-attribute]
                     # update self.route_position after all stream event finished
-                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:  # ty: ignore [unresolved-attribute]
                         self.route_position[answer_node_id] += 1

-                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]  # ty: ignore [unresolved-attribute]

                 self._remove_unreachable_nodes(event)

@@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
             else:
                 # TODO: Update database then remove this
                 # Fallback to old structure if cases are not defined
-                input_conditions, group_result, final_result = _should_not_use_old_function(
+                input_conditions, group_result, final_result = _should_not_use_old_function(  # ty: ignore [deprecated]
                     condition_processor=condition_processor,
                     variable_pool=self.graph_runtime_state.variable_pool,
                     conditions=self._node_data.conditions or [],

@@ -441,8 +441,8 @@ class IterationNode(BaseNode):
             iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
             next_index = int(current_index) + 1
             for event in rst:
-                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
-                    event.in_iteration_id = self.node_id
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:  # ty: ignore [unresolved-attribute]
+                    event.in_iteration_id = self.node_id  # ty: ignore [unresolved-attribute]

                 if (
                     isinstance(event, BaseNodeEvent)

@@ -299,8 +299,8 @@ class LoopNode(BaseNode):
             check_break_result = False
             for event in rst:
-                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
-                    event.in_loop_id = self.node_id
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:  # ty: ignore [unresolved-attribute]
+                    event.in_loop_id = self.node_id  # ty: ignore [unresolved-attribute]

                 if (
                     isinstance(event, BaseNodeEvent)