rm type ignore (#25715)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Authored by Asuka Minato on 2025-10-21 12:26:58 +09:00, committed by GitHub
parent c11cdf7468
commit 32c715c4d0
78 changed files with 229 additions and 204 deletions

View File

@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql",
)
-@computed_field # type: ignore[misc]
+@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1,
)
-@computed_field # type: ignore[misc]
+@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
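
The two hunks above retag the suppression on Pydantic's `@computed_field` + `@property` stack. A minimal sketch of the pattern, using a hypothetical settings class: mypy 1.12 moved the "decorator on top of property" diagnostic from the catch-all `misc` code to the dedicated `prop-decorator` code, so the narrower ignore silences only that diagnostic and survives checker upgrades.

```python
# Hypothetical settings class illustrating the @computed_field pattern above.
from pydantic import computed_field
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    DB_HOST: str = "localhost"
    DB_PORT: int = 5432

    @computed_field  # type: ignore[prop-decorator]
    @property
    def DB_URL(self) -> str:
        # Read-only field derived from the plain settings above.
        return f"postgresql://{self.DB_HOST}:{self.DB_PORT}"
```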

View File

@@ -24,7 +24,7 @@ except ImportError:
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
-magic = None # type: ignore
+magic = None # type: ignore[assignment]
from pydantic import BaseModel
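
The hunk above narrows the bare ignore on the libmagic fallback. A short sketch of the optional-dependency idiom, assuming only python-magic is involved: rebinding the module name to `None` trips mypy's `assignment` code, and the narrowed ignore keeps any other error on that line visible.

```python
try:
    import magic  # python-magic; needs libmagic installed at runtime
except ImportError:
    magic = None  # type: ignore[assignment]  # module name rebound to None


def guess_mimetype(buf: bytes) -> str | None:
    # Callers handle the degraded mode explicitly instead of crashing.
    if magic is None:
        return None
    return magic.from_buffer(buf, mime=True)
```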

View File

@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
-# FIXME: Type hinting issue here, ignore it for now, will fix it later
-return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
+return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
+response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -98,7 +98,7 @@ class RateLimit:
else:
return RateLimitGenerator(
rate_limit=self,
-generator=generator, # ty: ignore [invalid-argument-type]
+generator=generator,
request_id=request_id,
)

View File

@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
-err = e # ty: ignore [invalid-assignment]
+err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
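
A sketch of the narrowing that makes the dropped suppression above unnecessary, with hypothetical exception types: since Python 3.10 (PEP 604), `X | Y` is valid as the second argument to `isinstance`, and inside the branch a checker narrows `e` to that union, so the assignment needs no ignore.

```python
def normalize_error(e: Exception) -> Exception:
    if isinstance(e, KeyError | ValueError):  # PEP 604 union accepted at runtime
        return e  # e is narrowed to KeyError | ValueError in this branch
    description = getattr(e, "description", None)
    return Exception(description if description is not None else str(e))
```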

View File

@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
if "/" not in key:
key = str(ModelProviderID(key))
-return self.configurations.get(key, default) # type: ignore
+return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):

View File

@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
# FIXME: mypy does not support the type of spec.loader
-spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
+spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:
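
For context, a self-contained sketch of the dynamic-import idiom in the hunk above, assuming a plain `.py` path on disk: `spec_from_file_location` is typed as returning `ModuleSpec | None` and `spec.loader` can also be `None`, which is why the explicit guard below is needed and why the suppression could be narrowed rather than dropped.

```python
import importlib.util
import sys


def import_from_path(module_name: str, py_file_path: str):
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    if not spec or not spec.loader:
        raise ImportError(f"Failed to load module {module_name} from {py_file_path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before exec, per the importlib docs
    spec.loader.exec_module(module)
    return module
```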

View File

@@ -2,7 +2,7 @@ import logging
import os
from datetime import datetime, timedelta
-from langfuse import Langfuse # type: ignore
+from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance

View File

@@ -180,7 +180,7 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
-return type_(**response.json()) # type: ignore
+return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
self,

View File

@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
-self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
+self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
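
The pattern in this hunk (and the next) replaces an ignore with control-flow narrowing. A minimal sketch with a hypothetical `User` stand-in: after the `raise` guard, checkers narrow `str | None` to `str`, so the attribute assignment type-checks on its own.

```python
from dataclasses import dataclass


@dataclass
class User:  # hypothetical stand-in for the real Account/EndUser types
    current_tenant_id: str | None = None


class ExampleRepository:
    def __init__(self, user: User) -> None:
        tenant_id = user.current_tenant_id
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        # tenant_id is narrowed from str | None to str past the guard.
        self._tenant_id: str = tenant_id
```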

View File

@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
-self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
+self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
-return repository_class( # type: ignore[no-any-return]
+return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,
@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
-return repository_class( # type: ignore[no-any-return]
+return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,

View File

@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the tool that the provider can provide
"""
-return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
+return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:

View File

@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
-voice=voice, # type: ignore
+voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:

View File

@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
+# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:

View File

@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
-return converted_datetime.strftime(format=time_format) # type: ignore
+return converted_datetime.strftime(time_format)
except Exception as e:
raise ToolInvokeError(str(e))
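
A runnable sketch of the conversion this tool performs, assuming pytz-style zone objects: parse a naive timestamp, attach the source zone with `localize`, convert with `astimezone`, and format. `strftime` is called positionally, matching its stub signature and avoiding the keyword form the removed ignore papered over.

```python
from datetime import datetime

import pytz


def convert_timezone(ts: str, src: str, dst: str, fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
    naive = datetime.strptime(ts, fmt)
    localized = pytz.timezone(src).localize(naive)  # attach the source zone
    return localized.astimezone(pytz.timezone(dst)).strftime(fmt)


# e.g. convert_timezone("2025-10-21 12:00:00", "Asia/Tokyo", "UTC")
```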

View File

@@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
"""
pass
-def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
+def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name
"""
@@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=self.sse_read_timeout,
)
-def get_tools(self) -> list[MCPTool]: # type: ignore
+def get_tools(self) -> list[MCPTool]:
"""
get all tools
"""

View File

@@ -26,7 +26,7 @@ class ToolLabelManager:
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
+provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
@@ -51,7 +51,7 @@ class ToolLabelManager:
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
+provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
@@ -85,7 +85,7 @@ class ToolLabelManager:
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
+provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

View File

@@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
-document = db.session.scalar(dataset_document_stmt) # type: ignore
+document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
-document_id=document.id, # type: ignore
-document_name=document.name, # type: ignore
-data_source_type=document.data_source_type, # type: ignore
+document_id=document.id,
+document_name=document.name,
+data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
-doc_metadata=document.doc_metadata, # type: ignore
+doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":

View File

@@ -6,8 +6,8 @@ from typing import Any, cast
from urllib.parse import unquote
import chardet
-import cloudscraper # type: ignore
-from readabilipy import simple_json_from_html_string # type: ignore
+import cloudscraper
+from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
-scraper.perform_request = ssrf_proxy.make_request # type: ignore
-response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
+scraper.perform_request = ssrf_proxy.make_request
+response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."

View File

@@ -3,7 +3,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any
-import yaml # type: ignore
+import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)

View File

@@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
+return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user

View File

@@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
-value: list[Segment] = None # type: ignore
+value: list[Segment]
@property
def text(self):

View File

@@ -19,7 +19,7 @@ class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
-value: Any = None
+value: Any
@field_validator("value_type")
@classmethod
@@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
-value: str = None # type: ignore
+value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT
-value: float = None # type: ignore
+value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
@@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER
-value: int = None # type: ignore
+value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
-value: Mapping[str, Any] = None # type: ignore
+value: Mapping[str, Any]
@property
def text(self) -> str:
@@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
-value: File = None # type: ignore
+value: File
@property
def markdown(self) -> str:
@@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
-value: bool = None # type: ignore
+value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
-value: Sequence[Any] = None # type: ignore
+value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
-value: Sequence[str] = None # type: ignore
+value: Sequence[str]
@property
def text(self) -> str:
@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
-value: Sequence[float | int] = None # type: ignore
+value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
-value: Sequence[Mapping[str, Any]] = None # type: ignore
+value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
-value: Sequence[File] = None # type: ignore
+value: Sequence[File]
@property
def markdown(self) -> str:
@@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-value: Sequence[bool] = None # type: ignore
+value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:
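
These Segment hunks all make the same change: dropping the contradictory `= None` default turns `value` into a required Pydantic field instead of smuggling `None` past the annotation with an ignore. A minimal sketch with a hypothetical model; the trade-off is that any call site that previously relied on the implicit `None` must now pass a value explicitly.

```python
from pydantic import BaseModel, ValidationError


class StringSegmentSketch(BaseModel):  # hypothetical stand-in for StringSegment
    value: str  # required: no default, so no ignore needed


try:
    StringSegmentSketch()  # constructing without `value` now fails loudly
except ValidationError as exc:
    print(exc.error_count(), "validation error")  # -> 1 validation error
```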

View File

@@ -1,5 +1,6 @@
import json
from abc import ABC
+from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
@@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
-def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
-# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
-return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
+return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
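
A sketch of the widened validator above: `isinstance` accepts a single type or a tuple of types, so annotating `element_type` that way removes the need for a suppression. The `type_` alias mirrors the hunk's `builtins` import, which sidesteps any local name shadowing the `type` builtin.

```python
from builtins import type as type_


def validate_array(value: object, element_type: type_ | tuple[type_, ...]) -> bool:
    # Mirrors isinstance's own signature: a class or a tuple of classes.
    return isinstance(value, list) and all(isinstance(x, element_type) for x in value)


assert validate_array([1, 2], int)
assert validate_array([1, 2.5], (int, float))
assert not validate_array("not a list", str)
```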

View File

@@ -10,10 +10,10 @@ from typing import Any
import chardet
import docx
import pandas as pd
-import pypandoc # type: ignore
-import pypdfium2 # type: ignore
-import webvtt # type: ignore
-import yaml # type: ignore
+import pypandoc
+import pypdfium2
+import webvtt
+import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P

View File

@@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls):
return "1"
-def _run(self) -> NodeRunResult: # type: ignore
+def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or", # type: ignore
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
@@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
@@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
-): # type: ignore
+):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))

View File

@@ -260,7 +260,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
-self.add(selector, value) # type: ignore
+self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":

View File

@@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
-from flask_compress import Compress # type: ignore
+from flask_compress import Compress
compress = Compress()
compress.init_app(app)

View File

@@ -1,6 +1,6 @@
import json
-import flask_login # type: ignore
+import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

View File

@@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
-import flask_migrate # type: ignore
+import flask_migrate
from extensions.ext_database import db

View File

@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
def shutdown_tracer():
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
-provider.force_flush() # ty: ignore [call-non-callable]
+provider.force_flush()
class ExceptionLoggingHandler(logging.Handler):
"""Custom logging handler that creates spans for logging.exception() calls"""

View File

@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
-app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
+app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]

View File

@@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
-from langfuse import parse_error # type: ignore
+from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException

View File

@@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
-import oss2 as aliyun_s3 # type: ignore
+import oss2 as aliyun_s3
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator
-from baidubce.auth.bce_credentials import BceCredentials # type: ignore
-from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
-from baidubce.services.bos.bos_client import BosClient # type: ignore
+from baidubce.auth.bce_credentials import BceCredentials
+from baidubce.bce_client_configuration import BceClientConfiguration
+from baidubce.services.bos.bos_client import BosClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO
from pathlib import Path
-import clickzetta # type: ignore[import]
+import clickzetta
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage

View File

@@ -34,7 +34,7 @@ class VolumePermissionManager:
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary
-import clickzetta # type: ignore[import-untyped]
+import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(

View File

@@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
-from google.cloud import storage as google_cloud_storage # type: ignore
+from google.cloud import storage as google_cloud_storage
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
-from obs import ObsClient # type: ignore
+from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -1,7 +1,7 @@
from collections.abc import Generator
-import boto3 # type: ignore
-from botocore.exceptions import ClientError # type: ignore
+import boto3
+from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
-from qcloud_cos import CosConfig, CosS3Client # type: ignore
+from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
-import tos # type: ignore
+import tos
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@@ -146,6 +146,6 @@ class ExternalApi(Api):
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
-super().__init__(app=None, *args, **kwargs) # type: ignore
+super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)

View File

@@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
-import gmpy2 # type: ignore
+import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
-m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
+m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
-m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
+m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
-invalid = bord(y) | int(one_pos < hLen) # type: ignore
+invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
-invalid |= bord(x) # type: ignore
+invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]:
-invalid |= bord(x) # type: ignore
+invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4

View File

@@ -3,7 +3,7 @@ from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
-from flask_login.config import EXEMPT_METHODS # type: ignore
+from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
@@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
-return g._login_user # type: ignore
+return g._login_user
return None

View File

@@ -1,8 +1,8 @@
import logging
-import sendgrid # type: ignore
+import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
-from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
+from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
-from flask_login import UserMixin # type: ignore[import-untyped]
+from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
-from flask_login import UserMixin # type: ignore[import-untyped]
+from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@@ -16,7 +16,25 @@
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis",
"opentelemetry.instrumentation.httpx"
"langfuse",
"cloudscraper",
"readabilipy",
"pypandoc",
"pypdfium2",
"webvtt",
"flask_compress",
"oss2",
"baidubce.auth.bce_credentials",
"baidubce.bce_client_configuration",
"baidubce.services.bos.bos_client",
"clickzetta",
"google.cloud",
"obs",
"qcloud_cos",
"tos",
"gmpy2",
"sendgrid",
"sendgrid.helpers.mail"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
@@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"

View File

@@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
-return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
@@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
-return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

View File

@@ -7,7 +7,7 @@ from enum import StrEnum
from urllib.parse import urlparse
from uuid import uuid4
-import yaml # type: ignore
+import yaml
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
@@ -563,7 +563,7 @@ class AppDslService:
else:
cls._append_model_config_export_data(export_data, app_model)
-return yaml.dump(export_data, allow_unicode=True) # type: ignore
+return yaml.dump(export_data, allow_unicode=True)
@classmethod
def _append_workflow_export_data(

View File

@@ -241,9 +241,9 @@ class DatasetService:
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
-dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore
-dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore
-dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore
+dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
+dataset.embedding_model = embedding_model.model if embedding_model else None
+dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
@@ -1416,6 +1416,8 @@ class DocumentService:
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
+assert knowledge_config.data_source
+assert knowledge_config.data_source.info_list.file_info_list
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -1424,15 +1426,16 @@ class DocumentService:
count = 0
if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
+upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-for notion_info in notion_info_list: # type: ignore
+notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
+for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
-count = len(website_info.urls) # type: ignore
+assert website_info
+count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
@@ -1444,7 +1447,7 @@ class DocumentService:
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
-dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
+dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -1481,7 +1484,7 @@ class DocumentService:
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
-) # type: ignore
+)
documents = []
if knowledge_config.original_document_id:
@@ -1523,11 +1526,12 @@ class DocumentService:
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
with redis_client.lock(lock_name, timeout=600):
+assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
@@ -1540,7 +1544,7 @@ class DocumentService:
raise FileNotExistsError()
file_name = file.name
-data_source_info = {
+data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
@@ -1557,7 +1561,7 @@ class DocumentService:
.first()
)
if document:
-document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
+document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
@@ -1571,8 +1575,8 @@ class DocumentService:
continue
document = DocumentService.build_document(
dataset,
-dataset_process_rule.id, # type: ignore
-knowledge_config.data_source.info_list.data_source_type, # type: ignore
+dataset_process_rule.id,
+knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@@ -1587,7 +1591,7 @@ class DocumentService:
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
@@ -1616,15 +1620,15 @@ class DocumentService:
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
-dataset_process_rule.id, # type: ignore
-knowledge_config.data_source.info_list.data_source_type, # type: ignore
+dataset_process_rule.id,
+knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@@ -1644,8 +1648,8 @@ class DocumentService:
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
@@ -1663,8 +1667,8 @@ class DocumentService:
document_name = url
document = DocumentService.build_document(
dataset,
-dataset_process_rule.id, # type: ignore
-knowledge_config.data_source.info_list.data_source_type, # type: ignore
+dataset_process_rule.id,
+knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@@ -2071,7 +2075,7 @@ class DocumentService:
# update document data source
if document_data.data_source:
file_name = ""
-data_source_info = {}
+data_source_info: dict[str, str | bool] = {}
if document_data.data_source.info_list.data_source_type == "upload_file":
if not document_data.data_source.info_list.file_info_list:
raise ValueError("No file info list found.")
@@ -2128,7 +2132,7 @@ class DocumentService:
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content, # type: ignore
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
document.data_source_type = document_data.data_source.info_list.data_source_type
@@ -2154,7 +2158,7 @@ class DocumentService:
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: "re_segment"}
-) # type: ignore
+)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -2164,25 +2168,26 @@ class DocumentService:
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
+assert knowledge_config.data_source
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
count = 0
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = (
-knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
-if knowledge_config.data_source.info_list.file_info_list # type: ignore
+knowledge_config.data_source.info_list.file_info_list.file_ids
+if knowledge_config.data_source.info_list.file_info_list
else []
)
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
if notion_info_list:
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if website_info:
count = len(website_info.urls)
if features.billing.subscription.plan == "sandbox" and count > 1:
@@ -2196,9 +2201,11 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
+assert knowledge_config.embedding_model_provider
+assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-knowledge_config.embedding_model_provider, # type: ignore
-knowledge_config.embedding_model, # type: ignore
+knowledge_config.embedding_model_provider,
+knowledge_config.embedding_model,
)
dataset_collection_binding_id = dataset_collection_binding.id
if knowledge_config.retrieval_model:
@@ -2215,7 +2222,7 @@ class DocumentService:
dataset = Dataset(
tenant_id=tenant_id,
name="",
-data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore
+data_source_type=knowledge_config.data_source.info_list.data_source_type,
indexing_technique=knowledge_config.indexing_technique,
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
@@ -2224,7 +2231,7 @@ class DocumentService:
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
)
-db.session.add(dataset) # type: ignore
+db.session.add(dataset)
db.session.flush()
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)

View File

@@ -88,7 +88,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
-return cls.compact_retrieve_response(query, all_documents) # type: ignore
+return cls.compact_retrieve_response(query, all_documents)
@classmethod
def external_retrieve(

View File

@@ -1,4 +1,4 @@
-import boto3 # type: ignore
+import boto3
from configs import dify_config

View File

@@ -89,7 +89,7 @@ class MetadataService:
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
-return metadata # type: ignore
+return metadata
except Exception:
logger.exception("Update metadata name failed")
finally:

View File

@@ -137,7 +137,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
-return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
+return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
@@ -225,7 +225,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
-return provider_configuration.get_custom_model_credential( # type: ignore
+return provider_configuration.get_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)

View File

@@ -146,7 +146,7 @@ class PluginMigration:
futures.append(
thread_pool.submit(
process_tenant,
-current_app._get_current_object(), # type: ignore[attr-defined]
+current_app._get_current_object(), # type: ignore
tenant_id,
)
)

View File

@@ -544,8 +544,8 @@ class BuiltinToolManageService:
try:
# handle include, exclude
if is_filtered(
-include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
-exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
+include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.entity.identity.name,
):

View File

@@ -308,7 +308,7 @@ class MCPToolManageService:
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
-config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
+config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)

View File

@@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
-segment_hash = helper.generate_text_hash(content) # type: ignore
+segment_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)

View File

@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from pymochow import MochowClient # type: ignore
-from pymochow.model.database import Database # type: ignore
-from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
-from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
-from pymochow.model.table import Table # type: ignore
+from pymochow import MochowClient
+from pymochow.model.database import Database
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
+from pymochow.model.schema import HNSWParams, VectorIndex
+from pymochow.model.table import Table
class AttrDict(UserDict):

View File

@@ -3,15 +3,15 @@ from typing import Any, Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from tcvectordb import RPCVectorDBClient # type: ignore
+from tcvectordb import RPCVectorDBClient
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
-from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
-from tcvectordb.model.enum import ReadConsistency # type: ignore
-from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
+from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
+from tcvectordb.model.enum import ReadConsistency
+from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
-from xinference_client.types import Embedding # type: ignore
+from xinference_client.types import Embedding
class MockTcvectordbClass:

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from volcengine.viking_db import ( # type: ignore
+from volcengine.viking_db import (
Collection,
Data,
DistanceType,

View File

@@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
"""Test with None input"""
# The method signature expects Union[dict, list, Segment], but implementation handles None
# We'll test the actual behavior by passing an empty dict instead
-result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
+result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
assert result == []
def test_fetch_files_from_variable_value_with_empty_dict(self):

View File

@@ -235,7 +235,7 @@ class TestIndividualHandlers:
# Type assertion needed due to union type
text_content = result.content[0]
assert hasattr(text_content, "text")
assert text_content.text == "test answer" # type: ignore[attr-defined]
assert text_content.text == "test answer"
def test_handle_call_tool_no_end_user(self):
"""Test call tool handler without end user"""

View File

@@ -212,7 +212,7 @@ class TestValidateResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],
@@ -400,7 +400,7 @@ class TestTransformResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],
@@ -414,7 +414,7 @@ class TestTransformResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],

View File

@@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
# Test that SystemVariable should forbid extra keys
with pytest.raises(ValidationError):
# This should fail because there is an unexpected key.
-SystemVariable(invalid_key=1) # type: ignore
+SystemVariable(invalid_key=1)

View File

@@ -14,36 +14,36 @@ def _create_api_app():
api = ExternalApi(bp)
@api.route("/bad-request")
-class Bad(Resource): # type: ignore
-def get(self): # type: ignore
+class Bad(Resource):
+def get(self):
raise BadRequest("invalid input")
@api.route("/unauth")
-class Unauth(Resource): # type: ignore
-def get(self): # type: ignore
+class Unauth(Resource):
+def get(self):
raise Unauthorized("auth required")
@api.route("/value-error")
-class ValErr(Resource): # type: ignore
-def get(self): # type: ignore
+class ValErr(Resource):
+def get(self):
raise ValueError("boom")
@api.route("/quota")
-class Quota(Resource): # type: ignore
-def get(self): # type: ignore
+class Quota(Resource):
+def get(self):
raise AppInvokeQuotaExceededError("quota exceeded")
@api.route("/general")
-class Gen(Resource): # type: ignore
-def get(self): # type: ignore
+class Gen(Resource):
+def get(self):
raise RuntimeError("oops")
# Note: We avoid altering default_mediatype to keep normal error paths
# Special 400 message rewrite
@api.route("/json-empty")
-class JsonEmpty(Resource): # type: ignore
-def get(self): # type: ignore
+class JsonEmpty(Resource):
+def get(self):
e = BadRequest()
# Force the specific message the handler rewrites
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
@@ -51,11 +51,11 @@ def _create_api_app():
# 400 mapping payload path
@api.route("/param-errors")
-class ParamErrors(Resource): # type: ignore
-def get(self): # type: ignore
+class ParamErrors(Resource):
+def get(self):
e = BadRequest()
# Coerce a mapping description to trigger param error shaping
e.description = {"field": "is required"} # type: ignore[assignment]
e.description = {"field": "is required"}
raise e
app.register_blueprint(bp, url_prefix="/api")
@@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
orig_exc_info = ext.sys.exc_info
try:
-ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
+ext.sys.exc_info = lambda: (None, None, None)
app = _create_api_app()
client = app.test_client()

View File

@@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
# without preserve_flask_contexts
result["user_accessible"] = current_user.is_authenticated
except Exception as e:
result["error"] = str(e) # type: ignore
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread)
@@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
else:
result["user_accessible"] = False
except Exception as e:
result["error"] = str(e) # type: ignore
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread_with_manager)

View File

@@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
oauth.get_raw_user_info("token")
with pytest.raises(NotImplementedError):
-oauth._transform_user_info({}) # type: ignore[name-defined]
+oauth._transform_user_info({})

View File

@@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client # type: ignore
-from qcloud_cos.streambody import StreamBody # type: ignore
+from qcloud_cos import CosS3Client
+from qcloud_cos.streambody import StreamBody
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2 # type: ignore
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
+from tos import TosClientV2
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
-from qcloud_cos import CosConfig # type: ignore
+from qcloud_cos import CosConfig
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
-from tos import TosClientV2 # type: ignore
+from tos import TosClientV2
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@@ -125,13 +125,13 @@ class TestApiKeyAuthService:
mock_session.commit = Mock()
args_copy = self.mock_args.copy()
original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
original_key = args_copy["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
assert args_copy["credentials"]["config"]["api_key"] != original_key
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@@ -268,7 +268,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_empty_credentials(self):
"""Test API key auth args validation - empty credentials"""
args = self.mock_args.copy()
args["credentials"] = None # type: ignore
args["credentials"] = None
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -284,7 +284,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_missing_auth_type(self):
"""Test API key auth args validation - missing auth_type"""
args = self.mock_args.copy()
del args["credentials"]["auth_type"] # type: ignore
del args["credentials"]["auth_type"]
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -292,7 +292,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_empty_auth_type(self):
"""Test API key auth args validation - empty auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = "" # type: ignore
args["credentials"]["auth_type"] = ""
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -380,7 +380,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
"""Test API key auth args validation - dict credentials with list auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
args["credentials"]["auth_type"] = ["api_key"]
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
# So this should not raise exception, this test should pass

View File

@@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
-encrypter.encrypt_oauth_params(None) # type: ignore
+encrypter.encrypt_oauth_params(None)
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
encrypter.encrypt_oauth_params("not_a_dict")
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
@@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
-encrypter.decrypt_oauth_params(123) # type: ignore
+encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
-encrypter.decrypt_oauth_params(None) # type: ignore
+encrypter.decrypt_oauth_params(None)
assert "encrypted_data must be a string" in str(exc_info.value)
@@ -461,14 +461,14 @@ class TestConvenienceFunctions:
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
-encrypt_system_oauth_params(None) # type: ignore
+encrypt_system_oauth_params(None)
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
with pytest.raises(ValueError):
-decrypt_system_oauth_params(None) # type: ignore
+decrypt_system_oauth_params(None)
class TestErrorHandling:
@@ -501,7 +501,7 @@ class TestErrorHandling:
# Test non-string error
with pytest.raises(ValueError) as exc_info:
-encrypter.decrypt_oauth_params(123) # type: ignore
+encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error