refactor(api): tighten types for Tenant.custom_config_dict and MCPToolProvider.headers (#34698)

This commit is contained in:
corevibe555
2026-04-08 04:36:42 +03:00
committed by GitHub
parent 2127d5850f
commit c825d5dcf6
22 changed files with 42 additions and 48 deletions

View File

@@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
build_bootstrap_variables,

View File

@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
class QueueEvent(StrEnum):

View File

@@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel):

View File

@@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db

View File

@@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,

View File

@@ -1 +1,8 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@@ -0,0 +1,9 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@@ -22,6 +22,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@@ -46,7 +47,6 @@ from models.provider import (
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance.
"""
from services.enterprise.plugin_manager_service import PluginCredentialType
from core.entities import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
@@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,5 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@@ -1,5 +1,3 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
@@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
OAuth schema
"""
client_schema: Sequence[ProviderConfig] = Field(
client_schema: list[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: Sequence[ProviderConfig] = Field(
credentials_schema: list[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities import Condition
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType

View File

@@ -19,6 +19,7 @@ from pydantic import (
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import (
MCPServerParameterType,
PluginParameter,
@@ -28,7 +29,7 @@ from core.plugin.entities.parameters import (
cast_parameter_value,
init_frontend_parameter,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
@@ -428,15 +429,6 @@ class ToolEntity(BaseModel):
return value or {}
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: str | None = None

View File

@@ -17,6 +17,7 @@ from yarl import URL
import contexts
from configs import dify_config
from core.entities import PluginCredentialType
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -29,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:

View File

@@ -8,7 +8,7 @@ from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner

View File

@@ -6,8 +6,7 @@ from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval

View File

@@ -6,6 +6,7 @@ from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameterAutoGenerate,
PluginParameterOption,
@@ -108,13 +109,6 @@ class EventEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class SubscriptionConstructor(BaseModel):
"""
The subscription constructor of the trigger provider

View File

@@ -1,23 +1,15 @@
import enum
import logging
from pydantic import BaseModel
from configs import dify_config
from core.entities import PluginCredentialType
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value
class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str

View File

@@ -38,6 +38,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.entities import PluginCredentialType
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
@@ -66,7 +67,6 @@ from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import (
IsDraftWorkflowError,
TriggerNodeLimitExceededError,

View File

@@ -9,7 +9,7 @@ from flask import Flask, current_app
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities import RetrievalSourceMetadata
from models.model import AppMode

View File

@@ -4,8 +4,8 @@ from typing import cast
import pytest
from pytest_mock import MockerFixture
from core.entities import PluginCredentialType
from core.helper.credential_utils import check_credential_policy_compliance, is_credential_exists
from services.enterprise.plugin_manager_service import PluginCredentialType
def test_check_credential_policy_compliance_returns_when_feature_disabled(