mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 12:00:26 -04:00
refactor(api): tighten types for Tenant.custom_config_dict and MCPToolProvider.headers (#34698)
This commit is contained in:
@@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
|
||||
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
||||
@@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy import select, update
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -9,7 +9,7 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
from core.entities.plugin_credential_type import PluginCredentialType
|
||||
|
||||
DEFAULT_PLUGIN_ID = "langgenius"
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_PLUGIN_ID",
|
||||
"PluginCredentialType",
|
||||
]
|
||||
|
||||
9
api/core/entities/plugin_credential_type.py
Normal file
9
api/core/entities/plugin_credential_type.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import enum
|
||||
|
||||
|
||||
class PluginCredentialType(enum.Enum):
|
||||
MODEL = 0 # must be 0 for API contract compatibility
|
||||
TOOL = 1 # must be 1 for API contract compatibility
|
||||
|
||||
def to_number(self):
|
||||
return self.value
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import PluginCredentialType
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@@ -46,7 +47,6 @@ from models.provider import (
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Credential utility functions for checking credential existence and policy compliance.
|
||||
"""
|
||||
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from core.entities import PluginCredentialType
|
||||
|
||||
|
||||
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
||||
|
||||
@@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
@@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import ProviderType
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
5
api/core/plugin/entities/__init__.py
Normal file
5
api/core/plugin/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
]
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
@@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
|
||||
@@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import Condition
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
|
||||
@@ -19,6 +19,7 @@ from pydantic import (
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
MCPServerParameterType,
|
||||
PluginParameter,
|
||||
@@ -28,7 +29,7 @@ from core.plugin.entities.parameters import (
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
@@ -428,15 +429,6 @@ class ToolEntity(BaseModel):
|
||||
return value or {}
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||
)
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: str | None = None
|
||||
|
||||
@@ -17,6 +17,7 @@ from yarl import URL
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@@ -29,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
|
||||
@@ -6,8 +6,7 @@ from sqlalchemy import select
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameterAutoGenerate,
|
||||
PluginParameterOption,
|
||||
@@ -108,13 +109,6 @@ class EventEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionConstructor(BaseModel):
|
||||
"""
|
||||
The subscription constructor of the trigger provider
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
from services.enterprise.base import EnterprisePluginManagerRequest
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginCredentialType(enum.Enum):
|
||||
MODEL = 0 # must be 0 for API contract compatibility
|
||||
TOOL = 1 # must be 1 for API contract compatibility
|
||||
|
||||
def to_number(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class CheckCredentialPolicyComplianceRequest(BaseModel):
|
||||
dify_credential_id: str
|
||||
provider: str
|
||||
|
||||
@@ -38,6 +38,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.entities import PluginCredentialType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
@@ -66,7 +67,6 @@ from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
TriggerNodeLimitExceededError,
|
||||
|
||||
@@ -9,7 +9,7 @@ from flask import Flask, current_app
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent
|
||||
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import cast
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities import PluginCredentialType
|
||||
from core.helper.credential_utils import check_credential_policy_compliance, is_credential_exists
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
|
||||
def test_check_credential_policy_compliance_returns_when_feature_disabled(
|
||||
|
||||
Reference in New Issue
Block a user