mirror of https://github.com/langgenius/dify.git
rm type ignore (#25715)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (

@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
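Both hunks tighten the same pattern: pydantic's `@computed_field` stacked on `@property` trips mypy, and recent mypy versions report it under the dedicated `prop-decorator` error code rather than the catch-all `misc`. A minimal sketch of the pattern (the field and property names below are illustrative, not taken from the real `DatabaseConfig`):

```python
from pydantic import computed_field
from pydantic_settings import BaseSettings


class ExampleDBConfig(BaseSettings):
    # Hypothetical settings standing in for the real DatabaseConfig fields.
    DB_USERNAME: str = "postgres"
    DB_HOST: str = "localhost"

    # mypy does not fully model decorated properties; the narrowed code
    # suppresses exactly that complaint and nothing else.
    @computed_field  # type: ignore[prop-decorator]
    @property
    def EXAMPLE_DSN(self) -> str:
        return f"postgresql://{self.DB_USERNAME}@{self.DB_HOST}"
```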
@@ -24,7 +24,7 @@ except ImportError:
     )
 else:
     warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
-    magic = None  # type: ignore
+    magic = None  # type: ignore[assignment]

 from pydantic import BaseModel
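The fallback assignment keeps a suppression, but a scoped one: rebinding an imported module name to `None` is what the `[assignment]` code covers, and naming it lets any other mistake on that line still surface. A simplified sketch of the optional-dependency pattern, assuming only that `python-magic`/`libmagic` may be absent:

```python
import warnings

try:
    import magic
except ImportError:
    warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
    # Rebinding a module symbol to None is exactly the [assignment] error;
    # callers are expected to check `magic is None` before use.
    magic = None  # type: ignore[assignment]
```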
@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )
-        # FIXME: Type hinting issue here, ignore it for now, will fix it later
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
         self,

@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             data = cls._error_to_stream_response(sub_stream_response.err)
             response_chunk.update(data)
         elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-            response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+            response_chunk.update(sub_stream_response.to_ignore_detail_dict())
         else:
             response_chunk.update(sub_stream_response.model_dump(mode="json"))
         yield response_chunk

@@ -98,7 +98,7 @@ class RateLimit:
         else:
             return RateLimitGenerator(
                 rate_limit=self,
-                generator=generator,  # ty: ignore [invalid-argument-type]
+                generator=generator,
                 request_id=request_id,
             )

@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e  # ty: ignore [invalid-assignment]
+            err = e
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))
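The `isinstance(e, InvokeError | ValueError)` context line above works because Python 3.10+ accepts a PEP 604 union object at runtime as the second `isinstance` argument, and type checkers narrow the variable accordingly, which is why the `err = e` assignment no longer needs a suppression. A standalone sketch (the `InvokeError` class below is a stand-in, not the real Dify exception):

```python
class InvokeError(Exception):
    """Stand-in for the real InvokeError hierarchy."""


def to_user_error(e: Exception) -> Exception:
    if isinstance(e, InvokeError | ValueError):
        # Narrowed to InvokeError | ValueError here, so the assignment
        # type-checks without any ignore comment.
        return e
    description = getattr(e, "description", None)
    return Exception(description if description is not None else str(e))
```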
@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
         if "/" not in key:
             key = str(ModelProviderID(key))

-        return self.configurations.get(key, default)  # type: ignore
+        return self.configurations.get(key, default)


 class ProviderModelBundle(BaseModel):

@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
         # FIXME: mypy does not support the type of spec.loader
-        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
+        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore[assignment]
     if not spec or not spec.loader:
         raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
     if use_lazy_loader:
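`importlib.util.spec_from_file_location` is annotated as returning `ModuleSpec | None`, and the `if not spec or not spec.loader` guard performs the runtime narrowing, so only the assignment itself still carries a (now scoped) suppression. The surrounding function follows the stdlib recipe the comment links to, roughly:

```python
import importlib.util
import sys


def import_from_source(module_name: str, py_file_path: str):
    # Simplified sketch of the documented "importing a source file directly"
    # recipe; returns the loaded module.
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    if not spec or not spec.loader:
        raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
```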
@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta

-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance

@@ -180,7 +180,7 @@ class BasePluginClient:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]

     def _request_with_plugin_daemon_response(
         self,

@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id

@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id
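The two Celery repository hunks are identical in spirit: assuming `extract_tenant_id` returns `str | None`, the early `raise` already narrows `tenant_id` to `str` for everything below it, so the `[assignment]` suppression and its apologetic comment were redundant. A minimal sketch of that narrowing:

```python
def extract_tenant_id(user: object) -> str | None:
    # Hypothetical stand-in for the real helper.
    return getattr(user, "current_tenant_id", None)


class SketchRepository:
    def __init__(self, user: object) -> None:
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        # After the early raise, checkers treat tenant_id as str.
        self._tenant_id: str = tenant_id
```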
@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)

     @property
     def need_credentials(self) -> bool:

@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
             content_text=tool_parameters.get("text"),  # type: ignore
             user=user_id,
             tenant_id=self.runtime.tenant_id,
-            voice=voice,  # type: ignore
+            voice=voice,
         )
         buffer = io.BytesIO()
         for chunk in tts:

@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):

         yield self.create_text_message(f"{timestamp}")

+    # TODO: this method's type is messy
     @staticmethod
     def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
         try:

@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
             datetime_with_tz = input_timezone.localize(local_time)
             # timezone convert
             converted_datetime = datetime_with_tz.astimezone(output_timezone)
-            return converted_datetime.strftime(format=time_format)  # type: ignore
+            return converted_datetime.strftime(time_format)
         except Exception as e:
             raise ToolInvokeError(str(e))
@@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
         """
         pass

-    def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
+    def get_tool(self, tool_name: str) -> MCPTool:
         """
         return tool with given name
         """

@@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
             sse_read_timeout=self.sse_read_timeout,
         )

-    def get_tools(self) -> list[MCPTool]:  # type: ignore
+    def get_tools(self) -> list[MCPTool]:
         """
         get all tools
         """

@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         else:
             raise ValueError("Unsupported tool type")

@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:

@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]
+            provider_ids.append(controller.provider_id)

         labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

@@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                 DatasetDocument.enabled == True,
                 DatasetDocument.archived == False,
             )
-            document = db.session.scalar(dataset_document_stmt)  # type: ignore
+            document = db.session.scalar(dataset_document_stmt)
             if dataset and document:
                 source = RetrievalSourceMetadata(
                     dataset_id=dataset.id,
                     dataset_name=dataset.name,
-                    document_id=document.id,  # type: ignore
-                    document_name=document.name,  # type: ignore
-                    data_source_type=document.data_source_type,  # type: ignore
+                    document_id=document.id,
+                    document_name=document.name,
+                    data_source_type=document.data_source_type,
                     segment_id=segment.id,
                     retriever_from=self.retriever_from,
                     score=record.score or 0.0,
-                    doc_metadata=document.doc_metadata,  # type: ignore
+                    doc_metadata=document.doc_metadata,
                 )

                 if self.retriever_from == "dev":

@@ -6,8 +6,8 @@ from typing import Any, cast
 from urllib.parse import unquote

 import chardet
-import cloudscraper  # type: ignore
-from readabilipy import simple_json_from_html_string  # type: ignore
+import cloudscraper
+from readabilipy import simple_json_from_html_string

 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor

@@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
         scraper = cloudscraper.create_scraper()
-        scraper.perform_request = ssrf_proxy.make_request  # type: ignore
-        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))  # type: ignore
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, timeout=(120, 300))

     if response.status_code != 200:
         return f"URL returned status code {response.status_code}."

@@ -3,7 +3,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any

-import yaml  # type: ignore
+import yaml
 from yaml import YAMLError

 logger = logging.getLogger(__name__)

@@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)

         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
+            return next(filter(lambda x: x.variable == variable_name, variables), None)

         user = db_provider.user
@@ -4,7 +4,7 @@ from .types import SegmentType

 class SegmentGroup(Segment):
     value_type: SegmentType = SegmentType.GROUP
-    value: list[Segment] = None  # type: ignore
+    value: list[Segment]

     @property
     def text(self):

@@ -19,7 +19,7 @@ class Segment(BaseModel):
     model_config = ConfigDict(frozen=True)

     value_type: SegmentType
-    value: Any = None
+    value: Any

     @field_validator("value_type")
     @classmethod

@@ -74,12 +74,12 @@ class NoneSegment(Segment):

 class StringSegment(Segment):
     value_type: SegmentType = SegmentType.STRING
-    value: str = None  # type: ignore
+    value: str


 class FloatSegment(Segment):
     value_type: SegmentType = SegmentType.FLOAT
-    value: float = None  # type: ignore
+    value: float
     # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
     # The following tests cannot pass.
     #

@@ -98,12 +98,12 @@ class FloatSegment(Segment):

 class IntegerSegment(Segment):
     value_type: SegmentType = SegmentType.INTEGER
-    value: int = None  # type: ignore
+    value: int


 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Any] = None  # type: ignore
+    value: Mapping[str, Any]

     @property
     def text(self) -> str:

@@ -136,7 +136,7 @@ class ArraySegment(Segment):

 class FileSegment(Segment):
     value_type: SegmentType = SegmentType.FILE
-    value: File = None  # type: ignore
+    value: File

     @property
     def markdown(self) -> str:

@@ -153,17 +153,17 @@ class FileSegment(Segment):

 class BooleanSegment(Segment):
     value_type: SegmentType = SegmentType.BOOLEAN
-    value: bool = None  # type: ignore
+    value: bool


 class ArrayAnySegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_ANY
-    value: Sequence[Any] = None  # type: ignore
+    value: Sequence[Any]


 class ArrayStringSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_STRING
-    value: Sequence[str] = None  # type: ignore
+    value: Sequence[str]

     @property
     def text(self) -> str:

@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):

 class ArrayNumberSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_NUMBER
-    value: Sequence[float | int] = None  # type: ignore
+    value: Sequence[float | int]


 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
-    value: Sequence[Mapping[str, Any]] = None  # type: ignore
+    value: Sequence[Mapping[str, Any]]


 class ArrayFileSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[File] = None  # type: ignore
+    value: Sequence[File]

     @property
     def markdown(self) -> str:

@@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):

 class ArrayBooleanSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-    value: Sequence[bool] = None  # type: ignore
+    value: Sequence[bool]


 def get_segment_discriminator(v: Any) -> SegmentType | None:
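Every Segment hunk above replaces the `value: T = None  # type: ignore` idiom with a bare annotation. In pydantic v2 a field with no default is simply required, so the dishonest `None` default and the blanket suppression were both unnecessary; constructing a segment without a value now fails loudly at validation time instead of smuggling `None` past the checker. A sketch with illustrative model names:

```python
from collections.abc import Sequence

from pydantic import BaseModel, ValidationError


class StringSegmentSketch(BaseModel):
    # No default: pydantic treats the field as required.
    value: str


class ArrayStringSegmentSketch(BaseModel):
    value: Sequence[str]


print(ArrayStringSegmentSketch(value=["a", "b"]).value)  # ['a', 'b']

try:
    StringSegmentSketch()  # missing the required `value`
except ValidationError as exc:
    print(exc.error_count())  # 1
```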
@@ -1,5 +1,6 @@
 import json
 from abc import ABC
+from builtins import type as type_
 from collections.abc import Sequence
 from enum import StrEnum
 from typing import Any, Union

@@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
             raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

     @staticmethod
-    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
         """Unified array type validation"""
-        # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

     @staticmethod
     def _convert_number(value: str) -> float:
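The new signature leans on the fact that `isinstance` accepts either a single type or a tuple of types as its second argument, hence the `type_ | tuple[type_, ...]` annotation built on the `builtins.type` alias imported in the previous hunk; that is what lets the FIXME and its suppression disappear. A standalone sketch:

```python
def validate_array(value: object, element_type: type | tuple[type, ...]) -> bool:
    # isinstance's classinfo argument may be a type or a tuple of types.
    return isinstance(value, list) and all(isinstance(x, element_type) for x in value)


print(validate_array([1, 2.5], (int, float)))  # True
print(validate_array([1, "x"], int))           # False
```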
@@ -10,10 +10,10 @@ from typing import Any
 import chardet
 import docx
 import pandas as pd
-import pypandoc  # type: ignore
-import pypdfium2  # type: ignore
-import webvtt  # type: ignore
-import yaml  # type: ignore
+import pypandoc
+import pypdfium2
+import webvtt
+import yaml
 from docx.document import Document
 from docx.oxml.table import CT_Tbl
 from docx.oxml.text.paragraph import CT_P

@@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
     def version(cls):
         return "1"

-    def _run(self) -> NodeRunResult:  # type: ignore
+    def _run(self) -> NodeRunResult:
         # extract variables
         variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):

@@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
             metadata_condition = MetadataCondition(
                 logical_operator=node_data.metadata_filtering_conditions.logical_operator
                 if node_data.metadata_filtering_conditions
-                else "or",  # type: ignore
+                else "or",
                 conditions=conditions,
             )
         elif node_data.metadata_filtering_mode == "manual":

@@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
                     expected_value = self.graph_runtime_state.variable_pool.convert_template(
                         expected_value
                     ).value[0]
-                    if expected_value.value_type in {"number", "integer", "float"}:  # type: ignore
-                        expected_value = expected_value.value  # type: ignore
-                    elif expected_value.value_type == "string":  # type: ignore
-                        expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()  # type: ignore
+                    if expected_value.value_type in {"number", "integer", "float"}:
+                        expected_value = expected_value.value
+                    elif expected_value.value_type == "string":
+                        expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                     else:
                         raise ValueError("Invalid expected metadata value type")
                     conditions.append(

@@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
             if (
                 node_data.metadata_filtering_conditions
                 and node_data.metadata_filtering_conditions.logical_operator == "and"
-            ):  # type: ignore
+            ):
                 document_query = document_query.where(and_(*filters))
             else:
                 document_query = document_query.where(or_(*filters))

@@ -260,7 +260,7 @@ class VariablePool(BaseModel):
             # This ensures that we can keep the id of the system variables intact.
             if self._has(selector):
                 continue
-            self.add(selector, value)  # type: ignore
+            self.add(selector, value)

     @classmethod
     def empty(cls) -> "VariablePool":
@@ -7,7 +7,7 @@ def is_enabled() -> bool:


 def init_app(app: DifyApp):
-    from flask_compress import Compress  # type: ignore
+    from flask_compress import Compress

     compress = Compress()
     compress.init_app(app)

@@ -1,6 +1,6 @@
 import json

-import flask_login  # type: ignore
+import flask_login
 from flask import Response, request
 from flask_login import user_loaded_from_request, user_logged_in
 from werkzeug.exceptions import NotFound, Unauthorized

@@ -2,7 +2,7 @@ from dify_app import DifyApp


 def init_app(app: DifyApp):
-    import flask_migrate  # type: ignore
+    import flask_migrate

     from extensions.ext_database import db

@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
     def shutdown_tracer():
         provider = trace.get_tracer_provider()
         if hasattr(provider, "force_flush"):
-            provider.force_flush()  # ty: ignore [call-non-callable]
+            provider.force_flush()

 class ExceptionLoggingHandler(logging.Handler):
     """Custom logging handler that creates spans for logging.exception() calls"""

@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
     if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
         from werkzeug.middleware.proxy_fix import ProxyFix

-        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore
+        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]
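This hunk narrows rather than deletes: replacing the bound `wsgi_app` method with a `ProxyFix` instance is the documented Werkzeug middleware pattern, but mypy flags attribute assignment over a method as `method-assign`, so that one code is worth keeping. Roughly:

```python
from flask import Flask
from werkzeug.middleware.proxy_fix import ProxyFix

app = Flask(__name__)
# Wrapping the WSGI callable per the Werkzeug docs; mypy sees wsgi_app as a
# method, so the assignment needs exactly this suppression and no other.
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]
```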
@@ -5,7 +5,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     if dify_config.SENTRY_DSN:
         import sentry_sdk
-        from langfuse import parse_error  # type: ignore
+        from langfuse import parse_error
         from sentry_sdk.integrations.celery import CeleryIntegration
         from sentry_sdk.integrations.flask import FlaskIntegration
         from werkzeug.exceptions import HTTPException

@@ -1,7 +1,7 @@
 import posixpath
 from collections.abc import Generator

-import oss2 as aliyun_s3  # type: ignore
+import oss2 as aliyun_s3

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -2,9 +2,9 @@ import base64
 import hashlib
 from collections.abc import Generator

-from baidubce.auth.bce_credentials import BceCredentials  # type: ignore
-from baidubce.bce_client_configuration import BceClientConfiguration  # type: ignore
-from baidubce.services.bos.bos_client import BosClient  # type: ignore
+from baidubce.auth.bce_credentials import BceCredentials
+from baidubce.bce_client_configuration import BceClientConfiguration
+from baidubce.services.bos.bos_client import BosClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -11,7 +11,7 @@ from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path

-import clickzetta  # type: ignore[import]
+import clickzetta
 from pydantic import BaseModel, model_validator

 from extensions.storage.base_storage import BaseStorage

@@ -34,7 +34,7 @@ class VolumePermissionManager:
         # Support two initialization methods: connection object or configuration dictionary
         if isinstance(connection_or_config, dict):
             # Create connection from configuration dictionary
-            import clickzetta  # type: ignore[import-untyped]
+            import clickzetta

             config = connection_or_config
             self._connection = clickzetta.connect(

@@ -3,7 +3,7 @@ import io
 import json
 from collections.abc import Generator

-from google.cloud import storage as google_cloud_storage  # type: ignore
+from google.cloud import storage as google_cloud_storage

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from obs import ObsClient  # type: ignore
+from obs import ObsClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -1,7 +1,7 @@
 from collections.abc import Generator

-import boto3  # type: ignore
-from botocore.exceptions import ClientError  # type: ignore
+import boto3
+from botocore.exceptions import ClientError

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from qcloud_cos import CosConfig, CosS3Client  # type: ignore
+from qcloud_cos import CosConfig, CosS3Client

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -1,6 +1,6 @@
 from collections.abc import Generator

-import tos  # type: ignore
+import tos

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage

@@ -146,6 +146,6 @@ class ExternalApi(Api):
         kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False

         # manual separate call on construction and init_app to ensure configs in kwargs effective
-        super().__init__(app=None, *args, **kwargs)  # type: ignore
+        super().__init__(app=None, *args, **kwargs)
         self.init_app(app, **kwargs)
         register_external_error_handlers(self)
@@ -23,7 +23,7 @@ from hashlib import sha1

 import Crypto.Hash.SHA1
 import Crypto.Util.number
-import gmpy2  # type: ignore
+import gmpy2
 from Crypto import Random
 from Crypto.Signature.pss import MGF1
 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes

@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
         # Step 3a (OS2IP)
         em_int = bytes_to_long(em)
         # Step 3b (RSAEP)
-        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
         # Step 3c (I2OSP)
         c = long_to_bytes(m_int, k)
         return c

@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
         ct_int = bytes_to_long(ciphertext)
         # Step 2b (RSADP)
         # m_int = self._key._decrypt(ct_int)
-        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
         # Complete step 2c (I2OSP)
         em = long_to_bytes(m_int, k)
         # Step 3a

@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
         # Step 3g
         one_pos = hLen + db[hLen:].find(b"\x01")
         lHash1 = db[:hLen]
-        invalid = bord(y) | int(one_pos < hLen)  # type: ignore
+        invalid = bord(y) | int(one_pos < hLen)  # type: ignore[arg-type]
         hash_compare = strxor(lHash1, lHash)
         for x in hash_compare:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         for x in db[hLen:one_pos]:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         if invalid != 0:
             raise ValueError("Incorrect decryption.")
         # Step 4
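The dropped `ty: ignore [unresolved-attribute]` comments were covering attribute lookups on `gmpy2`, which the checker config later in this commit lists among the known-untyped modules; the calls themselves are plain modular exponentiation. `gmpy2.powmod(b, e, n)` computes the same value as the builtin `pow(b, e, n)`, which is all RSAEP/RSADP are, as a toy round-trip shows (textbook numbers, not a real key):

```python
# Classic toy RSA parameters: n = 61 * 53, e public, d private exponent.
n, e, d = 3233, 17, 2753
m = 65
c = pow(m, e, n)           # RSAEP: c = m^e mod n  -> 2790
assert pow(c, d, n) == m   # RSADP: m = c^d mod n round-trips
```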
@@ -3,7 +3,7 @@ from functools import wraps
 from typing import Any

 from flask import current_app, g, has_request_context, request
-from flask_login.config import EXEMPT_METHODS  # type: ignore
+from flask_login.config import EXEMPT_METHODS
 from werkzeug.local import LocalProxy

 from configs import dify_config

@@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
         if "_login_user" not in g:
             current_app.login_manager._load_user()  # type: ignore

-        return g._login_user  # type: ignore
+        return g._login_user

     return None

@@ -1,8 +1,8 @@
 import logging

-import sendgrid  # type: ignore
+import sendgrid
 from python_http_client.exceptions import ForbiddenError, UnauthorizedError
-from sendgrid.helpers.mail import Content, Email, Mail, To  # type: ignore
+from sendgrid.helpers.mail import Content, Email, Mail, To

 logger = logging.getLogger(__name__)

@@ -5,7 +5,7 @@ from datetime import datetime
 from typing import Any, Optional

 import sqlalchemy as sa
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from typing_extensions import deprecated

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast

 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
 from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -16,7 +16,25 @@
         "opentelemetry.instrumentation.requests",
         "opentelemetry.instrumentation.sqlalchemy",
         "opentelemetry.instrumentation.redis",
-        "opentelemetry.instrumentation.httpx"
+        "opentelemetry.instrumentation.httpx",
+        "langfuse",
+        "cloudscraper",
+        "readabilipy",
+        "pypandoc",
+        "pypdfium2",
+        "webvtt",
+        "flask_compress",
+        "oss2",
+        "baidubce.auth.bce_credentials",
+        "baidubce.bce_client_configuration",
+        "baidubce.services.bos.bos_client",
+        "clickzetta",
+        "google.cloud",
+        "obs",
+        "qcloud_cos",
+        "tos",
+        "gmpy2",
+        "sendgrid",
+        "sendgrid.helpers.mail"
     ],
     "reportUnknownMemberType": "hint",
     "reportUnknownParameterType": "hint",

@@ -28,7 +46,7 @@
     "reportUnnecessaryComparison": "hint",
     "reportUnnecessaryIsInstance": "hint",
     "reportUntypedFunctionDecorator": "hint",
-    "reportUnnecessaryTypeIgnoreComment": "hint",
+    "reportAttributeAccessIssue": "hint",
     "pythonVersion": "3.11",
     "pythonPlatform": "All"
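The config hunks do two things: register the third-party modules whose imports just lost their blanket ignores, and adjust the diagnostic hints, among them `reportUnnecessaryTypeIgnoreComment`, pyright's check for suppressions with nothing left to suppress, which is exactly the category of comment this commit deletes in bulk. A tiny illustration of what that check catches:

```python
import json  # type: ignore

# With reportUnnecessaryTypeIgnoreComment enabled, pyright marks the ignore
# above as unnecessary: `json` resolves cleanly, so it suppresses nothing.
print(json.dumps({"ok": True}))
```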
@@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(
                 f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"

@@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

@@ -7,7 +7,7 @@ from enum import StrEnum
 from urllib.parse import urlparse
 from uuid import uuid4

-import yaml  # type: ignore
+import yaml
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from packaging import version

@@ -563,7 +563,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)  # type: ignore
+        return yaml.dump(export_data, allow_unicode=True)

     @classmethod
     def _append_workflow_export_data(

@@ -241,9 +241,9 @@ class DatasetService:
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None  # type: ignore
-        dataset.embedding_model = embedding_model.model if embedding_model else None  # type: ignore
-        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None  # type: ignore
+        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
+        dataset.embedding_model = embedding_model.model if embedding_model else None
+        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
         dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
         dataset.provider = provider
         db.session.add(dataset)

@@ -1416,6 +1416,8 @@ class DocumentService:
         # check document limit
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
+        assert knowledge_config.data_source
+        assert knowledge_config.data_source.info_list.file_info_list

         features = FeatureService.get_features(current_user.current_tenant_id)

@@ -1424,15 +1426,16 @@ class DocumentService:
             count = 0
             if knowledge_config.data_source:
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                     count = len(upload_file_list)
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:  # type: ignore
+                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
+                    for notion_info in notion_info_list:
                         count = count + len(notion_info.pages)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)  # type: ignore
+                    assert website_info
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

             if features.billing.subscription.plan == "sandbox" and count > 1:

@@ -1444,7 +1447,7 @@ class DocumentService:

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:

@@ -1481,7 +1484,7 @@ class DocumentService:
                 knowledge_config.retrieval_model.model_dump()
                 if knowledge_config.retrieval_model
                 else default_retrieval_model
-            )  # type: ignore
+            )

         documents = []
         if knowledge_config.original_document_id:

@@ -1523,11 +1526,12 @@ class DocumentService:
             db.session.flush()
         lock_name = f"add_document_lock_dataset_id_{dataset.id}"
         with redis_client.lock(lock_name, timeout=600):
+            assert dataset_process_rule
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
             duplicate_document_ids = []
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)

@@ -1540,7 +1544,7 @@ class DocumentService:
                         raise FileNotExistsError()

                     file_name = file.name
-                    data_source_info = {
+                    data_source_info: dict[str, str | bool] = {
                         "upload_file_id": file_id,
                     }
                     # check duplicate

@@ -1557,7 +1561,7 @@ class DocumentService:
                         .first()
                     )
                     if document:
-                        document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
+                        document.dataset_process_rule_id = dataset_process_rule.id
                         document.updated_at = naive_utc_now()
                         document.created_from = created_from
                         document.doc_form = knowledge_config.doc_form

@@ -1571,8 +1575,8 @@ class DocumentService:
                         continue
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,  # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,

@@ -1587,7 +1591,7 @@ class DocumentService:
                     document_ids.append(document.id)
                     documents.append(document)
                     position += 1
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                 if not notion_info_list:
                     raise ValueError("No notion info list found.")

@@ -1616,15 +1620,15 @@ class DocumentService:
                         "credential_id": notion_info.credential_id,
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                         "type": page.type,
                     }
                     # Truncate page name to 255 characters to prevent DB field length errors
                     truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,  # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,

@@ -1644,8 +1648,8 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if not website_info:
                     raise ValueError("No website info list found.")
                 urls = website_info.urls

@@ -1663,8 +1667,8 @@ class DocumentService:
                     document_name = url
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,  # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,

@@ -2071,7 +2075,7 @@ class DocumentService:
         # update document data source
         if document_data.data_source:
             file_name = ""
-            data_source_info = {}
+            data_source_info: dict[str, str | bool] = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
                 if not document_data.data_source.info_list.file_info_list:
                     raise ValueError("No file info list found.")
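Annotating the empty literal as `dict[str, str | bool]` fixes the root cause the removed ignores papered over: from `{}` alone, a checker infers the value type from the first insertions and then rejects the bool (`only_main_content`) added later. Sketch:

```python
# Declared up front, the dict accepts both value types without complaint.
data_source_info: dict[str, str | bool] = {}
data_source_info["upload_file_id"] = "file-123"  # str value
data_source_info["only_main_content"] = True     # bool value, also fine
print(data_source_info)
```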
@@ -2128,7 +2132,7 @@ class DocumentService:
                     "url": url,
                     "provider": website_info.provider,
                     "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content,  # type: ignore
+                    "only_main_content": website_info.only_main_content,
                     "mode": "crawl",
                 }
                 document.data_source_type = document_data.data_source.info_list.data_source_type

@@ -2154,7 +2158,7 @@ class DocumentService:

             db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
                 {DocumentSegment.status: "re_segment"}
-            )  # type: ignore
+            )
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)

@@ -2164,25 +2168,26 @@ class DocumentService:
     def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
+        assert knowledge_config.data_source

         features = FeatureService.get_features(current_user.current_tenant_id)

         if features.billing.enabled:
             count = 0
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                 upload_file_list = (
-                    knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
-                    if knowledge_config.data_source.info_list.file_info_list  # type: ignore
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
                     else []
                 )
                 count = len(upload_file_list)
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
-                notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 if notion_info_list:
                     for notion_info in notion_info_list:
                         count = count + len(notion_info.pages)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if website_info:
                     count = len(website_info.urls)
             if features.billing.subscription.plan == "sandbox" and count > 1:

@@ -2196,9 +2201,11 @@ class DocumentService:
         dataset_collection_binding_id = None
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
+            assert knowledge_config.embedding_model_provider
+            assert knowledge_config.embedding_model
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider,  # type: ignore
-                knowledge_config.embedding_model,  # type: ignore
+                knowledge_config.embedding_model_provider,
+                knowledge_config.embedding_model,
             )
             dataset_collection_binding_id = dataset_collection_binding.id
         if knowledge_config.retrieval_model:

@@ -2215,7 +2222,7 @@ class DocumentService:
         dataset = Dataset(
             tenant_id=tenant_id,
             name="",
-            data_source_type=knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+            data_source_type=knowledge_config.data_source.info_list.data_source_type,
             indexing_technique=knowledge_config.indexing_technique,
             created_by=account.id,
             embedding_model=knowledge_config.embedding_model,

@@ -2224,7 +2231,7 @@ class DocumentService:
             retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
         )

-        db.session.add(dataset)  # type: ignore
+        db.session.add(dataset)
        db.session.flush()

         documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)

@@ -88,7 +88,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents)  # type: ignore
+        return cls.compact_retrieve_response(query, all_documents)

     @classmethod
     def external_retrieve(

@@ -1,4 +1,4 @@
-import boto3  # type: ignore
+import boto3

 from configs import dify_config

@@ -89,7 +89,7 @@ class MetadataService:
                 document.doc_metadata = doc_metadata
                 db.session.add(document)
             db.session.commit()
-            return metadata  # type: ignore
+            return metadata
         except Exception:
             logger.exception("Update metadata name failed")
         finally:

@@ -137,7 +137,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_provider_credential(credential_id=credential_id)  # type: ignore
+        return provider_configuration.get_provider_credential(credential_id=credential_id)

     def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
         """

@@ -225,7 +225,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_custom_model_credential(  # type: ignore
+        return provider_configuration.get_custom_model_credential(
             model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
         )

@@ -146,7 +146,7 @@ class PluginMigration:
                 futures.append(
                     thread_pool.submit(
                         process_tenant,
-                        current_app._get_current_object(),  # type: ignore[attr-defined]
+                        current_app._get_current_object(),  # type: ignore
                         tenant_id,
                     )
                 )

@@ -544,8 +544,8 @@ class BuiltinToolManageService:
             try:
                 # handle include, exclude
                 if is_filtered(
-                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
-                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
+                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                    data=provider_controller,
                    name_func=lambda x: x.entity.identity.name,
                ):

@@ -308,7 +308,7 @@ class MCPToolManageService:
         provider_controller = MCPToolProviderController.from_db(mcp_provider)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=mcp_provider.tenant_id,
-            config=list(provider_controller.get_credentials_schema()),  # ty: ignore [invalid-argument-type]
+            config=list(provider_controller.get_credentials_schema()),
             provider_config_cache=NoOpProviderCredentialCache(),
         )
         credentials = tool_configuration.encrypt(credentials)

@@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
     for segment, tokens in zip(content, tokens_list):
         content = segment["content"]
         doc_id = str(uuid.uuid4())
-        segment_hash = helper.generate_text_hash(content)  # type: ignore
+        segment_hash = helper.generate_text_hash(content)
         max_position = (
             db.session.query(func.max(DocumentSegment.position))
             .where(DocumentSegment.document_id == dataset_document.id)
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from pymochow import MochowClient  # type: ignore
-from pymochow.model.database import Database  # type: ignore
-from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState  # type: ignore
-from pymochow.model.schema import HNSWParams, VectorIndex  # type: ignore
-from pymochow.model.table import Table  # type: ignore
+from pymochow import MochowClient
+from pymochow.model.database import Database
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
+from pymochow.model.schema import HNSWParams, VectorIndex
+from pymochow.model.table import Table


 class AttrDict(UserDict):

@@ -3,15 +3,15 @@ from typing import Any, Union

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tcvectordb import RPCVectorDBClient  # type: ignore
+from tcvectordb import RPCVectorDBClient
 from tcvectordb.model import enum
 from tcvectordb.model.collection import FilterIndexConfig
-from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank  # type: ignore
-from tcvectordb.model.enum import ReadConsistency  # type: ignore
-from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex  # type: ignore
+from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
+from tcvectordb.model.enum import ReadConsistency
+from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
 from tcvectordb.rpc.model.collection import RPCCollection
 from tcvectordb.rpc.model.database import RPCDatabase
-from xinference_client.types import Embedding  # type: ignore
+from xinference_client.types import Embedding


 class MockTcvectordbClass:

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from volcengine.viking_db import (  # type: ignore
+from volcengine.viking_db import (
     Collection,
     Data,
     DistanceType,

@@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
         """Test with None input"""
         # The method signature expects Union[dict, list, Segment], but implementation handles None
         # We'll test the actual behavior by passing an empty dict instead
-        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)  # type: ignore
+        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
         assert result == []

     def test_fetch_files_from_variable_value_with_empty_dict(self):

@@ -235,7 +235,7 @@ class TestIndividualHandlers:
         # Type assertion needed due to union type
         text_content = result.content[0]
         assert hasattr(text_content, "text")
-        assert text_content.text == "test answer"  # type: ignore[attr-defined]
+        assert text_content.text == "test answer"

     def test_handle_call_tool_no_end_user(self):
         """Test call tool handler without end user"""

@@ -212,7 +212,7 @@ class TestValidateResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],

@@ -400,7 +400,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],

@@ -414,7 +414,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],

@@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
     # Test that SystemVariable should forbid extra keys
     with pytest.raises(ValidationError):
         # This should fail because there is an unexpected key.
-        SystemVariable(invalid_key=1)  # type: ignore
+        SystemVariable(invalid_key=1)

@@ -14,36 +14,36 @@ def _create_api_app():
     api = ExternalApi(bp)

     @api.route("/bad-request")
-    class Bad(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Bad(Resource):
+        def get(self):
             raise BadRequest("invalid input")

     @api.route("/unauth")
-    class Unauth(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Unauth(Resource):
+        def get(self):
             raise Unauthorized("auth required")

     @api.route("/value-error")
-    class ValErr(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class ValErr(Resource):
+        def get(self):
             raise ValueError("boom")

     @api.route("/quota")
-    class Quota(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Quota(Resource):
+        def get(self):
             raise AppInvokeQuotaExceededError("quota exceeded")

     @api.route("/general")
-    class Gen(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Gen(Resource):
+        def get(self):
             raise RuntimeError("oops")

     # Note: We avoid altering default_mediatype to keep normal error paths

     # Special 400 message rewrite
     @api.route("/json-empty")
-    class JsonEmpty(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class JsonEmpty(Resource):
+        def get(self):
             e = BadRequest()
             # Force the specific message the handler rewrites
             e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"

@@ -51,11 +51,11 @@ def _create_api_app():

     # 400 mapping payload path
     @api.route("/param-errors")
-    class ParamErrors(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class ParamErrors(Resource):
+        def get(self):
             e = BadRequest()
             # Coerce a mapping description to trigger param error shaping
-            e.description = {"field": "is required"}  # type: ignore[assignment]
+            e.description = {"field": "is required"}
             raise e

     app.register_blueprint(bp, url_prefix="/api")

@@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():

     orig_exc_info = ext.sys.exc_info
     try:
-        ext.sys.exc_info = lambda: (None, None, None)  # type: ignore[assignment]
+        ext.sys.exc_info = lambda: (None, None, None)

         app = _create_api_app()
         client = app.test_client()

@@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
             # without preserve_flask_contexts
             result["user_accessible"] = current_user.is_authenticated
         except Exception as e:
-            result["error"] = str(e)  # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread)

@@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
             else:
                 result["user_accessible"] = False
         except Exception as e:
-            result["error"] = str(e)  # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread_with_manager)

@@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
         oauth.get_raw_user_info("token")

     with pytest.raises(NotImplementedError):
-        oauth._transform_user_info({})  # type: ignore[name-defined]
+        oauth._transform_user_info({})

@@ -3,8 +3,8 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client  # type: ignore
-from qcloud_cos.streambody import StreamBody  # type: ignore
+from qcloud_cos import CosS3Client
+from qcloud_cos.streambody import StreamBody

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,

@@ -4,8 +4,8 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2  # type: ignore
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput  # type: ignore
+from tos import TosClientV2
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,

@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from qcloud_cos import CosConfig  # type: ignore
+from qcloud_cos import CosConfig

 from extensions.storage.tencent_cos_storage import TencentCosStorage
 from tests.unit_tests.oss.__mock.base import (

@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from tos import TosClientV2  # type: ignore
+from tos import TosClientV2

 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
 from tests.unit_tests.oss.__mock.base import (

@@ -125,13 +125,13 @@ class TestApiKeyAuthService:
         mock_session.commit = Mock()

         args_copy = self.mock_args.copy()
-        original_key = args_copy["credentials"]["config"]["api_key"]  # type: ignore
+        original_key = args_copy["credentials"]["config"]["api_key"]

         ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)

         # Verify original key is replaced with encrypted key
-        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key  # type: ignore
-        assert args_copy["credentials"]["config"]["api_key"] != original_key  # type: ignore
+        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
+        assert args_copy["credentials"]["config"]["api_key"] != original_key

         # Verify encryption function is called correctly
         mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)

@@ -268,7 +268,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_credentials(self):
         """Test API key auth args validation - empty credentials"""
         args = self.mock_args.copy()
-        args["credentials"] = None  # type: ignore
+        args["credentials"] = None

         with pytest.raises(ValueError, match="credentials is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -284,7 +284,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_missing_auth_type(self):
         """Test API key auth args validation - missing auth_type"""
         args = self.mock_args.copy()
-        del args["credentials"]["auth_type"]  # type: ignore
+        del args["credentials"]["auth_type"]

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -292,7 +292,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_auth_type(self):
         """Test API key auth args validation - empty auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = ""  # type: ignore
+        args["credentials"]["auth_type"] = ""

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -380,7 +380,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
         """Test API key auth args validation - dict credentials with list auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = ["api_key"]  # type: ignore  # list instead of string
+        args["credentials"]["auth_type"] = ["api_key"]

         # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
         # So this should not raise exception, this test should pass

@@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params(None)  # type: ignore
+            encrypter.encrypt_oauth_params(None)

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params("not_a_dict")  # type: ignore
+            encrypter.encrypt_oauth_params("not_a_dict")

     def test_decrypt_oauth_params_basic(self):
         """Test basic OAuth parameters decryption"""

@@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123)  # type: ignore
+            encrypter.decrypt_oauth_params(123)

         assert "encrypted_data must be a string" in str(exc_info.value)

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(None)  # type: ignore
+            encrypter.decrypt_oauth_params(None)

         assert "encrypted_data must be a string" in str(exc_info.value)

@@ -461,14 +461,14 @@ class TestConvenienceFunctions:
         """Test convenience functions with error conditions"""
         # Test encryption with invalid input
         with pytest.raises(Exception):  # noqa: B017
-            encrypt_system_oauth_params(None)  # type: ignore
+            encrypt_system_oauth_params(None)

         # Test decryption with invalid input
         with pytest.raises(ValueError):
             decrypt_system_oauth_params("")

         with pytest.raises(ValueError):
-            decrypt_system_oauth_params(None)  # type: ignore
+            decrypt_system_oauth_params(None)


 class TestErrorHandling:

@@ -501,7 +501,7 @@ class TestErrorHandling:

         # Test non-string error
         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123)  # type: ignore
+            encrypter.decrypt_oauth_params(123)
         assert "encrypted_data must be a string" in str(exc_info.value)

         # Test invalid format error