From 8d99326fb3b334ac7abff538037004a042a7ff12 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 23 May 2026 17:12:09 +0800 Subject: [PATCH] feat(plugin): cache plugin model providers by tenant (#36449) Co-authored-by: WH-2099 --- api/.env.example | 1 + api/commands/plugin.py | 2 +- api/configs/feature/__init__.py | 5 + api/controllers/console/workspace/plugin.py | 2 +- api/core/plugin/impl/model_runtime.py | 64 +-- api/core/plugin/impl/model_runtime_factory.py | 2 + .../plugin/plugin_service.py | 203 +++++++-- api/core/trigger/provider.py | 2 +- api/models/model.py | 2 +- api/services/datasource_provider_service.py | 2 +- api/services/plugin/plugin_migration.py | 28 +- .../rag_pipeline_transform_service.py | 2 +- .../tools/builtin_tools_manage_service.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- .../trigger/trigger_provider_service.py | 2 +- api/services/workflow_app_service.py | 2 +- ...ss_tenant_plugin_autoupgrade_check_task.py | 2 +- .../services/plugin/test_plugin_service.py | 96 ++--- .../tools/test_tools_transform_service.py | 4 +- .../plugin/impl/test_model_runtime_factory.py | 9 + .../core/plugin/test_model_runtime_adapter.py | 207 +++++++-- .../unit_tests/models/test_app_models.py | 50 +++ .../unit_tests/services/plugin/conftest.py | 4 +- .../services/plugin/test_plugin_migration.py | 28 ++ .../services/plugin/test_plugin_service.py | 392 +++++++++++++++++- docker/envs/core-services/shared.env.example | 1 + 26 files changed, 928 insertions(+), 188 deletions(-) rename api/{services => core}/plugin/plugin_service.py (73%) diff --git a/api/.env.example b/api/.env.example index 833d83797d..f645ba7bf0 100644 --- a/api/.env.example +++ b/api/.env.example @@ -657,6 +657,7 @@ PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_MAX_PACKAGE_SIZE=15728640 PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 +PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration diff --git a/api/commands/plugin.py b/api/commands/plugin.py index 8ad2321b07..e1b3cf0fa1 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,6 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant @@ -20,7 +21,6 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from models.tools import ToolOAuthSystemClient from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index a752d9d103..d2e6d144ac 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -265,6 +265,11 @@ class PluginConfig(BaseSettings): default=60 * 60, ) + PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching tenant plugin model providers in Redis", + default=60 * 60 * 24, + ) + PLUGIN_MAX_FILE_SIZE: PositiveInt = Field( description="Maximum allowed size (bytes) for plugin-generated files", default=50 * 1024 * 1024, diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 9e5e766f45..c41bf99563 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError +from core.plugin.plugin_service import PluginService from fields.base import ResponseModel from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required @@ -22,7 +23,6 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService -from services.plugin.plugin_service import PluginService class ParserList(BaseModel): diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index d555f4d965..3d5ba94f2b 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -3,7 +3,6 @@ from __future__ import annotations import hashlib import logging from collections.abc import Generator, Iterable, Sequence -from threading import Lock from typing import IO, Any, Literal, cast, overload, override from pydantic import ValidationError @@ -13,9 +12,9 @@ from configs import dify_config from core.llm_generator.output_parser.structured_output import ( invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper, ) -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient +from core.plugin.plugin_service import PluginService from extensions.ext_redis import redis_client from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -101,35 +100,36 @@ class _PluginStructuredOutputModelInstance: class PluginModelRuntime(ModelRuntime): - """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + """Plugin-backed runtime adapter bound to tenant context and optional caller scope. + + Provider discovery goes through ``PluginService`` so the plugin lifecycle + methods and provider reads share one tenant-scoped cache owner. + """ tenant_id: str user_id: str | None client: PluginModelClient - _provider_entities: tuple[ProviderEntity, ...] | None - _provider_entities_lock: Lock + _plugin_service: type[PluginService] - def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + def __init__( + self, + tenant_id: str, + user_id: str | None, + client: PluginModelClient, + plugin_service: type[PluginService], + ) -> None: if client is None: raise ValueError("client is required.") + if plugin_service is None: + raise ValueError("plugin_service is required.") self.tenant_id = tenant_id self.user_id = user_id self.client = client - self._provider_entities = None - self._provider_entities_lock = Lock() + self._plugin_service = plugin_service @override def fetch_model_providers(self) -> Sequence[ProviderEntity]: - if self._provider_entities is not None: - return self._provider_entities - - with self._provider_entities_lock: - if self._provider_entities is None: - self._provider_entities = tuple( - self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) - ) - - return self._provider_entities + return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client) @override def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: @@ -628,34 +628,6 @@ class PluginModelRuntime(ModelRuntime): text=text, ) - def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: - """ - Expose a bare provider alias only for the canonical provider mapping. - - Multiple plugins can publish the same short provider slug. If every - provider entity keeps that slug in ``provider_name``, callers that still - resolve by short name become order-dependent. Restrict the alias to the - provider selected by ``ModelProviderID`` so legacy short-name lookups - remain deterministic while the runtime surface stays canonical. - """ - try: - canonical_provider_id = ModelProviderID(provider.provider) - except ValueError: - return "" - - if canonical_provider_id.plugin_id != provider.plugin_id: - return "" - if canonical_provider_id.provider_name != provider.provider: - return "" - - return provider.provider - - def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: - declaration = provider.declaration.model_copy(deep=True) - declaration.provider = f"{provider.plugin_id}/{provider.provider}" - declaration.provider_name = self._get_provider_short_name_alias(provider) - return declaration - def _get_provider_schema(self, provider: str) -> ProviderEntity: providers = self.fetch_model_providers() provider_entity = next((item for item in providers if item.provider == provider), None) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index fbe307ea60..32a4c7b89a 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from core.plugin.plugin_service import PluginService from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.model_providers.base.ai_model import AIModel @@ -117,6 +118,7 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) - tenant_id=tenant_id, user_id=user_id, client=PluginModelClient(), + plugin_service=PluginService, ) diff --git a/api/services/plugin/plugin_service.py b/api/core/plugin/plugin_service.py similarity index 73% rename from api/services/plugin/plugin_service.py rename to api/core/plugin/plugin_service.py index 72271c55d8..a88ddb5f3d 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/core/plugin/plugin_service.py @@ -1,8 +1,17 @@ +"""Core plugin service and tenant-scoped plugin metadata cache ownership. + +This module owns plugin daemon management calls that are shared by API services +and core runtimes. Plugin model provider discovery is cached here, alongside +plugin install, uninstall, and upgrade invalidation, so all cache mutations for +plugin-owned provider metadata stay tenant-scoped and in one place. +""" + import logging from collections.abc import Mapping, Sequence from mimetypes import guess_type -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter, ValidationError +from redis import RedisError from sqlalchemy import delete, select, update from sqlalchemy.orm import Session from yarl import URL @@ -22,16 +31,20 @@ from core.plugin.entities.plugin import ( from core.plugin.entities.plugin_daemon import ( PluginDecodeResponse, PluginInstallTask, + PluginInstallTaskStatus, PluginListResponse, + PluginModelProviderEntity, PluginVerification, ) from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient +from core.plugin.impl.model import PluginModelClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import ProviderEntity from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider -from models.provider_ids import GenericProviderID +from models.provider_ids import GenericProviderID, ModelProviderID from services.enterprise.plugin_manager_service import ( PluginManagerService, PreUninstallPluginRequest, @@ -40,6 +53,7 @@ from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope logger = logging.getLogger(__name__) +_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity]) class PluginService: @@ -53,6 +67,102 @@ class PluginService: REDIS_KEY_PREFIX = "plugin_service:latest_plugin:" REDIS_TTL = 60 * 5 # 5 minutes + PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:" + PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed) + + @classmethod + def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str: + return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}" + + @staticmethod + def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + @classmethod + def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = cls._get_provider_short_name_alias(provider) + return declaration + + @classmethod + def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None: + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + try: + cached_providers = redis_client.get(cache_key) + except (RedisError, RuntimeError): + logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True) + return None + + if not cached_providers: + return None + + try: + return tuple(_provider_entities_adapter.validate_json(cached_providers)) + except (TypeError, ValueError, ValidationError): + logger.warning( + "Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True + ) + cls.invalidate_plugin_model_providers_cache(tenant_id) + return None + + @classmethod + def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None: + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + try: + payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8") + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload) + except (RedisError, RuntimeError): + logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True) + + @classmethod + def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None: + """Delete the tenant-scoped plugin model provider list cache.""" + try: + redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id)) + except (RedisError, RuntimeError): + logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True) + + @classmethod + def fetch_plugin_model_providers( + cls, *, tenant_id: str, client: PluginModelClient | None = None + ) -> Sequence[ProviderEntity]: + """ + Fetch plugin model providers through the tenant-scoped plugin cache. + + Plugin daemon provider discovery and plugin lifecycle cache invalidation + are intentionally owned by this service so tenant isolation and cache + expiry are handled in one place. + """ + cached_providers = cls._load_cached_plugin_model_providers(tenant_id) + if cached_providers is not None: + return cached_providers + + model_client = client or PluginModelClient() + providers = tuple( + cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id) + ) + cls._store_cached_plugin_model_providers(tenant_id, providers) + return providers @staticmethod def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: @@ -248,12 +358,18 @@ class PluginService: Fetch plugin installation tasks """ manager = PluginInstaller() - return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) + tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) + if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks): + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return tasks @staticmethod def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask: manager = PluginInstaller() - return manager.fetch_plugin_installation_task(tenant_id, task_id) + task = manager.fetch_plugin_installation_task(tenant_id, task_id) + if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return task @staticmethod def delete_install_task(tenant_id: str, task_id: str) -> bool: @@ -315,7 +431,7 @@ class PluginService: # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) - return manager.upgrade_plugin( + result = manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, new_plugin_unique_identifier, @@ -324,6 +440,8 @@ class PluginService: "plugin_unique_identifier": new_plugin_unique_identifier, }, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def upgrade_plugin_with_github( @@ -339,7 +457,7 @@ class PluginService: """ PluginService._check_marketplace_only_permission() manager = PluginInstaller() - return manager.upgrade_plugin( + result = manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, new_plugin_unique_identifier, @@ -350,6 +468,8 @@ class PluginService: "package": package, }, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: @@ -415,12 +535,14 @@ class PluginService: resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) PluginService._check_plugin_installation_scope(resp.verification) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, PluginInstallationSource.Package, [{}], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str): @@ -434,7 +556,7 @@ class PluginService: plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) PluginService._check_plugin_installation_scope(plugin_decode_response.verification) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, [plugin_unique_identifier], PluginInstallationSource.Github, @@ -446,6 +568,8 @@ class PluginService: } ], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration: @@ -513,12 +637,14 @@ class PluginService: actual_plugin_unique_identifiers.append(response.unique_identifier) metas.append({"plugin_unique_identifier": response.unique_identifier}) - return manager.install_from_identifiers( + result = manager.install_from_identifiers( tenant_id, actual_plugin_unique_identifiers, PluginInstallationSource.Marketplace, metas, ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: @@ -529,7 +655,10 @@ class PluginService: plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) if not plugin: - return manager.uninstall(tenant_id, plugin_installation_id) + result = manager.uninstall(tenant_id, plugin_installation_id) + if result: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result if dify_config.ENTERPRISE_ENABLED: PluginManagerService.try_pre_uninstall_plugin( @@ -559,37 +688,39 @@ class PluginService: if not credential_ids: logger.info("No credentials found for plugin: %s", plugin_id) - return manager.uninstall(tenant_id, plugin_installation_id) + else: + provider_ids = session.scalars( + select(Provider.id).where( + Provider.tenant_id == tenant_id, + Provider.provider_name.like(f"{plugin_id}/%"), + Provider.credential_id.in_(credential_ids), + ) + ).all() - provider_ids = session.scalars( - select(Provider.id).where( - Provider.tenant_id == tenant_id, - Provider.provider_name.like(f"{plugin_id}/%"), - Provider.credential_id.in_(credential_ids), + session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None)) + + for provider_id in provider_ids: + ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=provider_id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ).delete() + + session.execute( + delete(ProviderCredential).where( + ProviderCredential.id.in_(credential_ids), + ) ) - ).all() - session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None)) - - for provider_id in provider_ids: - ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_id, - cache_type=ProviderCredentialsCacheType.PROVIDER, - ).delete() - - session.execute( - delete(ProviderCredential).where( - ProviderCredential.id.in_(credential_ids), + logger.info( + "Completed deleting credentials and cleaning provider associations for plugin: %s", + plugin_id, ) - ) - logger.info( - "Completed deleting credentials and cleaning provider associations for plugin: %s", - plugin_id, - ) - - return manager.uninstall(tenant_id, plugin_installation_id) + result = manager.uninstall(tenant_id, plugin_installation_id) + if result: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) + return result @staticmethod def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]: diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index 10fa31fdfa..893526b5e0 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -16,6 +16,7 @@ from core.plugin.entities.request import ( TriggerSubscriptionResponse, ) from core.plugin.impl.trigger import PluginTriggerClient +from core.plugin.plugin_service import PluginService from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity from core.trigger.entities.entities import ( EventEntity, @@ -30,7 +31,6 @@ from core.trigger.entities.entities import ( ) from core.trigger.errors import TriggerProviderCredentialValidationError from models.provider_ids import TriggerProviderID -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/models/model.py b/api/models/model.py index f7f90465cf..3647fbf6f7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -492,8 +492,8 @@ class App(Base): @property def deleted_tools(self) -> list[DeletedToolInfo]: + from core.plugin.plugin_service import PluginService from core.tools.tool_manager import ToolManager, ToolProviderType - from services.plugin.plugin_service import PluginService # get agent mode tools app_model_config = self.app_model_config diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 2245adb681..e3c6f122ab 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -14,13 +14,13 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler +from core.plugin.plugin_service import PluginService from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 8fa3c3d4ef..88b9eeefa1 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -22,6 +22,7 @@ from core.helper import marketplace from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.tools.entities.tool_entities import ToolProviderType from extensions.ext_database import db from models.account import Tenant @@ -29,7 +30,6 @@ from models.model import App, AppMode, AppModelConfig from models.provider_ids import ModelProviderID, ToolProviderID from models.tools import BuiltinToolProvider from models.workflow import Workflow -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -389,17 +389,19 @@ class PluginMigration: for plugin_id in batch_plugin_ids if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"] ] - manager.install_from_identifiers( - tenant_id, - batch_plugin_identifiers, - PluginInstallationSource.Marketplace, - metas=[ - { - "plugin_unique_identifier": identifier, - } - for identifier in batch_plugin_identifiers - ], - ) + if batch_plugin_identifiers: + manager.install_from_identifiers( + tenant_id, + batch_plugin_identifiers, + PluginInstallationSource.Marketplace, + metas=[ + { + "plugin_unique_identifier": identifier, + } + for identifier in batch_plugin_identifiers + ], + ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) with open(extracted_plugins) as f: """ @@ -595,6 +597,7 @@ class PluginMigration: for identifier in batch_plugin_identifiers ], ) + PluginService.invalidate_plugin_model_providers_cache(tenant_id) except Exception: # add to failed failed.extend(batch_plugin_identifiers) @@ -609,6 +612,7 @@ class PluginMigration: while not done: status = manager.fetch_plugin_installation_task(tenant_id, task_id) if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]: + PluginService.invalidate_plugin_model_providers_cache(tenant_id) for plugin in status.plugins: if plugin.status == PluginInstallTaskStatus.Success: success.append(reverse_map[plugin.plugin_unique_identifier]) diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index f95519fc9e..ca755d0b91 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -12,6 +12,7 @@ from sqlalchemy import select from configs import dify_config from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -22,7 +23,6 @@ from models.model import UploadFile from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 20de1f4058..6b7092e318 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -13,6 +13,7 @@ from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.plugin_service import PluginService from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ( @@ -31,7 +32,6 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient -from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index effbbaad01..c25f120917 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -9,6 +9,7 @@ from configs import dify_config from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity +from core.plugin.plugin_service import PluginService from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -27,7 +28,6 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index b8a76e4945..855f88fb90 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,6 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.plugin.plugin_service import PluginService from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, @@ -37,7 +38,6 @@ from models.trigger import ( TriggerSubscription, WorkflowPluginTrigger, ) -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 59e02ec9b9..8ff9d651a4 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -6,11 +6,11 @@ from typing import Any, TypedDict from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from core.plugin.plugin_service import PluginService from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog -from services.plugin.plugin_service import PluginService from services.workflow.entities import TriggerMetadata diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 48d1774ce3..ab54b9e72e 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -9,9 +9,9 @@ from celery import shared_task from core.plugin.entities.marketplace import MarketplacePluginSnapshot from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller +from core.plugin.plugin_service import PluginService from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy -from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 0fa4d9043b..66ff24f374 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -1,4 +1,4 @@ -"""Tests for services.plugin.plugin_service.PluginService. +"""Tests for core.plugin.plugin_service.PluginService. Covers: version caching with Redis, install permission/scope gates, icon URL construction, asset retrieval with MIME guessing, plugin @@ -17,11 +17,11 @@ from sqlalchemy.orm import Session from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from core.plugin.plugin_service import PluginService from models import ProviderType from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope -from services.plugin.plugin_service import PluginService def _make_features( @@ -35,8 +35,8 @@ def _make_features( class TestFetchLatestPluginVersion: - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_returns_cached_version(self, mock_redis, mock_marketplace): cached_json = PluginService.LatestPluginCache( plugin_id="p1", @@ -53,8 +53,8 @@ class TestFetchLatestPluginVersion: assert result["p1"].version == "1.0.0" mock_marketplace.batch_fetch_plugin_manifests.assert_not_called() - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None manifest = MagicMock() @@ -71,8 +71,8 @@ class TestFetchLatestPluginVersion: assert result["p1"].version == "2.0.0" mock_redis.setex.assert_called_once() - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.return_value = [] @@ -81,8 +81,8 @@ class TestFetchLatestPluginVersion: assert result["unknown"] is None - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.redis_client") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.redis_client") def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace): mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error") @@ -93,14 +93,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) @@ -108,7 +108,7 @@ class TestCheckMarketplaceOnlyPermission: class TestCheckPluginInstallationScope: - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() @@ -116,14 +116,14 @@ class TestCheckPluginInstallationScope: PluginService._check_plugin_installation_scope(verification) # should not raise - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS @@ -133,7 +133,7 @@ class TestCheckPluginInstallationScope: PluginService._check_plugin_installation_scope(verification) # should not raise - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS @@ -142,7 +142,7 @@ class TestCheckPluginInstallationScope: with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() @@ -151,7 +151,7 @@ class TestCheckPluginInstallationScope: with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(verification) - @patch("services.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) @@ -159,7 +159,7 @@ class TestCheckPluginInstallationScope: class TestGetPluginIconUrl: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_constructs_url_with_params(self, mock_config): mock_config.CONSOLE_API_URL = "https://console.example.com" @@ -171,7 +171,7 @@ class TestGetPluginIconUrl: class TestGetAsset: - @patch("services.plugin.plugin_service.PluginAssetManager") + @patch("core.plugin.plugin_service.PluginAssetManager") def test_returns_bytes_and_guessed_mime(self, mock_asset_cls): mock_asset_cls.return_value.fetch_asset.return_value = b"" @@ -180,7 +180,7 @@ class TestGetAsset: assert data == b"" assert "svg" in mime - @patch("services.plugin.plugin_service.PluginAssetManager") + @patch("core.plugin.plugin_service.PluginAssetManager") def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls): mock_asset_cls.return_value.fetch_asset.return_value = b"\x00" @@ -190,13 +190,13 @@ class TestGetAsset: class TestIsPluginVerified: - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_returns_true_when_verified(self, mock_installer_cls): mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True assert PluginService.is_plugin_verified("t1", "uid-1") is True - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_returns_false_on_exception(self, mock_installer_cls): mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found") @@ -204,24 +204,24 @@ class TestIsPluginVerified: class TestUpgradePluginWithMarketplace: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_marketplace_disabled(self, mock_config): mock_config.MARKETPLACE_ENABLED = False with pytest.raises(ValueError, match="marketplace is not enabled"): PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_same_identifier(self, mock_config): mock_config.MARKETPLACE_ENABLED = True with pytest.raises(ValueError, match="same plugin"): PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid") - @patch("services.plugin.plugin_service.marketplace") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.marketplace") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -234,10 +234,10 @@ class TestUpgradePluginWithMarketplace: mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid") installer.upgrade_plugin.assert_called_once() - @patch("services.plugin.plugin_service.download_plugin_pkg") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.download_plugin_pkg") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -256,8 +256,8 @@ class TestUpgradePluginWithMarketplace: class TestUpgradePluginWithGithub: - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value @@ -271,8 +271,8 @@ class TestUpgradePluginWithGithub: class TestUploadPkg: - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() @@ -285,17 +285,17 @@ class TestUploadPkg: class TestInstallFromMarketplacePkg: - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.dify_config") def test_raises_when_marketplace_disabled(self, mock_config): mock_config.MARKETPLACE_ENABLED = False with pytest.raises(ValueError, match="marketplace is not enabled"): PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) - @patch("services.plugin.plugin_service.download_plugin_pkg") - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.download_plugin_pkg") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -315,9 +315,9 @@ class TestInstallFromMarketplacePkg: call_args = installer.install_from_identifiers.call_args[0] assert call_args[1] == ["resolved-uid"] - @patch("services.plugin.plugin_service.FeatureService") - @patch("services.plugin.plugin_service.PluginInstaller") - @patch("services.plugin.plugin_service.dify_config") + @patch("core.plugin.plugin_service.FeatureService") + @patch("core.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True mock_fs.get_system_features.return_value = _make_features() @@ -336,7 +336,7 @@ class TestInstallFromMarketplacePkg: class TestUninstall: - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls): installer = mock_installer_cls.return_value installer.list_plugins.return_value = [] @@ -347,7 +347,7 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.PluginInstaller") + @patch("core.plugin.plugin_service.PluginInstaller") def test_cleans_credentials_when_plugin_found( self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session ): @@ -389,7 +389,7 @@ class TestUninstall: installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - with patch("services.plugin.plugin_service.dify_config") as mock_config: + with patch("core.plugin.plugin_service.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = False result = PluginService.uninstall(tenant_id, "install-1") diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 2dc50cc720..a174f5d69f 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,6 +6,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.plugin.plugin_service import PluginService from core.tools.__base.tool import Tool from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -20,7 +21,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider -from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -31,7 +31,7 @@ class TestToolTransformService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config: - with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config): + with patch("core.plugin.plugin_service.dify_config", new=mock_dify_config): # Setup default mock returns mock_dify_config.CONSOLE_API_URL = "https://console.example.com" diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py index 52da674f06..c2dfe3be30 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from core.plugin.plugin_service import PluginService def test_plugin_model_assembly_reuses_single_runtime_across_views(): @@ -34,3 +35,11 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views(): mock_provider_factory_cls.assert_called_once_with(runtime=runtime) mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) + + +def test_create_plugin_model_runtime_injects_plugin_service(): + from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime + + runtime = create_plugin_model_runtime(tenant_id="tenant-1", user_id="user-1") + + assert runtime._plugin_service is PluginService diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index b1ecaa4ead..f9abc7d02a 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -12,6 +12,7 @@ from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.plugin.plugin_service import PluginService from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.message_entities import AssistantPromptMessage @@ -19,6 +20,22 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFr from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +class _FakeRedis: + def __init__(self) -> None: + self._values: dict[str, str] = {} + self.setex_calls: list[tuple[str, int, str]] = [] + + def get(self, key: str) -> str | None: + return self._values.get(key) + + def setex(self, key: str, ttl: int, value: str) -> None: + self._values[key] = value + self.setex_calls.append((key, ttl, value)) + + def delete(self, key: str) -> None: + self._values.pop(key, None) + + def _build_model_schema() -> AIModelEntity: return AIModelEntity( model="gpt-4o-mini", @@ -29,6 +46,24 @@ def _build_model_schema() -> AIModelEntity: ) +def _build_plugin_model_provider(*, tenant_id: str, provider: str = "openai") -> PluginModelProviderEntity: + return PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider=provider, + tenant_id=tenant_id, + plugin_unique_identifier=f"langgenius/{provider}/{provider}", + plugin_id=f"langgenius/{provider}", + declaration=ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + + class TestPluginModelRuntime: """Validate the adapter keeps plugin-specific routing out of the runtime port.""" @@ -51,7 +86,7 @@ class TestPluginModelRuntime: ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -95,7 +130,7 @@ class TestPluginModelRuntime: ), ), ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -122,7 +157,7 @@ class TestPluginModelRuntime: ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) providers = runtime.fetch_model_providers() @@ -131,7 +166,7 @@ class TestPluginModelRuntime: def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.validate_provider_credentials( provider="langgenius/openai/openai", @@ -173,7 +208,7 @@ class TestPluginModelRuntime: ), ] ) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.invoke_llm( provider="langgenius/openai/openai", @@ -209,7 +244,7 @@ class TestPluginModelRuntime: client = Mock(spec=PluginModelClient) stream_result = iter([]) client.invoke_llm.return_value = stream_result - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.invoke_llm( provider="langgenius/openai/openai", @@ -240,7 +275,9 @@ class TestPluginModelRuntime: def test_invoke_llm_rejects_per_call_user_override(self) -> None: client = Mock(spec=PluginModelClient) client.invoke_llm.return_value = sentinel.result - runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="bound-user", client=client, plugin_service=PluginService + ) with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): runtime.invoke_llm( # type: ignore[call-arg] @@ -260,7 +297,7 @@ class TestPluginModelRuntime: def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: client = Mock(spec=PluginModelClient) client.invoke_tts.return_value = iter([b"chunk"]) - runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client, plugin_service=PluginService) result = runtime.invoke_tts( provider="langgenius/openai/openai", @@ -282,15 +319,107 @@ class TestPluginModelRuntime: voice="alloy", ) - def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + def test_fetch_model_providers_does_not_keep_bound_runtime_cache(self, monkeypatch: pytest.MonkeyPatch) -> None: client = Mock(spec=PluginModelClient) client.fetch_model_providers.return_value = [] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr( + plugin_service_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=None), + delete=Mock(), + setex=Mock(), + ), + ) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.fetch_model_providers() runtime.fetch_model_providers() - client.fetch_model_providers.assert_called_once_with("tenant") + assert client.fetch_model_providers.call_count == 2 + + def test_fetch_model_providers_uses_tenant_ttl_cache_across_runtime_instances( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + redis = _FakeRedis() + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr(plugin_service_module, "redis_client", redis) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + first_client = Mock(spec=PluginModelClient) + first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant")] + second_client = Mock(spec=PluginModelClient) + first_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=first_client, plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-b", client=second_client, plugin_service=PluginService + ) + + first_providers = first_runtime.fetch_model_providers() + second_providers = second_runtime.fetch_model_providers() + + assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"] + assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"] + first_client.fetch_model_providers.assert_called_once_with("tenant") + second_client.fetch_model_providers.assert_not_called() + assert redis.setex_calls[0][1] == 300 + + def test_fetch_model_providers_cache_is_tenant_isolated(self, monkeypatch: pytest.MonkeyPatch) -> None: + redis = _FakeRedis() + from core.plugin import plugin_service as plugin_service_module + + monkeypatch.setattr(plugin_service_module, "redis_client", redis) + monkeypatch.setattr(plugin_service_module.dify_config, "PLUGIN_MODEL_PROVIDERS_CACHE_TTL", 300) + first_client = Mock(spec=PluginModelClient) + first_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-a")] + second_client = Mock(spec=PluginModelClient) + second_client.fetch_model_providers.return_value = [_build_plugin_model_provider(tenant_id="tenant-b")] + first_runtime = PluginModelRuntime( + tenant_id="tenant-a", user_id="user", client=first_client, plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant-b", user_id="user", client=second_client, plugin_service=PluginService + ) + + first_providers = first_runtime.fetch_model_providers() + second_providers = second_runtime.fetch_model_providers() + + assert [provider.provider for provider in first_providers] == ["langgenius/openai/openai"] + assert [provider.provider for provider in second_providers] == ["langgenius/openai/openai"] + first_client.fetch_model_providers.assert_called_once_with("tenant-a") + second_client.fetch_model_providers.assert_called_once_with("tenant-b") + assert len(redis.setex_calls) == 2 + + def test_fetch_model_providers_delegates_cache_to_injected_plugin_service(self) -> None: + client = Mock(spec=PluginModelClient) + service_result = [ + ProviderEntity( + provider="langgenius/openai/openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + ] + fetch_plugin_model_providers = Mock(return_value=service_result) + + class TestPluginService(PluginService): + pass + + TestPluginService.fetch_plugin_model_providers = fetch_plugin_model_providers + + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user", client=client, plugin_service=TestPluginService + ) + + result = runtime.fetch_model_providers() + + assert result is service_result + fetch_plugin_model_providers.assert_called_once_with(tenant_id="tenant", client=client) + client.fetch_model_providers.assert_not_called() def test_create_plugin_model_runtime_without_user_context() -> None: @@ -301,7 +430,17 @@ def test_create_plugin_model_runtime_without_user_context() -> None: def test_plugin_model_runtime_requires_client() -> None: with pytest.raises(ValueError, match="client is required"): - PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None, plugin_service=PluginService) # type: ignore[arg-type] + + +def test_plugin_model_runtime_requires_plugin_service() -> None: + with pytest.raises(ValueError, match="plugin_service is required"): + PluginModelRuntime( + tenant_id="tenant", + user_id="user", + client=Mock(spec=PluginModelClient), + plugin_service=None, # type: ignore[arg-type] + ) def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: @@ -317,7 +456,7 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: ), ) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.get_model_schema( provider="langgenius/openai/openai", model_type=ModelType.LLM, @@ -395,7 +534,7 @@ def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) schema = _build_model_schema() runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign] @@ -436,7 +575,7 @@ def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> Non def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None: client = Mock(spec=PluginModelClient) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign] with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"): @@ -468,7 +607,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes ) monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) client.get_model_schema.return_value = schema - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) result = runtime.get_model_schema( provider="langgenius/openai/openai", @@ -494,7 +633,7 @@ def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytes def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: client = Mock(spec=PluginModelClient) monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) assert ( runtime.get_llm_num_tokens( @@ -533,7 +672,7 @@ def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypa ] fetch_asset = Mock(return_value=b"") monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) icon_bytes, mime_type = runtime.get_provider_icon( provider="langgenius/openai/openai", @@ -565,7 +704,7 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) with pytest.raises(ValueError, match="does not have small dark icon"): runtime.get_provider_icon( @@ -583,7 +722,9 @@ def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> N def test_get_schema_cache_key_is_stable_across_credential_order() -> None: - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) first = runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -602,8 +743,12 @@ def test_get_schema_cache_key_is_stable_across_credential_order() -> None: def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: - first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) - second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + first_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + second_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) first = first_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -622,8 +767,12 @@ def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: - tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) - user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + tenant_runtime = PluginModelRuntime( + tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + user_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) tenant_key = tenant_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -643,8 +792,12 @@ def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: - tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) - empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + tenant_runtime = PluginModelRuntime( + tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) + empty_user_runtime = PluginModelRuntime( + tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient), plugin_service=PluginService + ) tenant_key = tenant_runtime._get_schema_cache_key( provider="langgenius/openai/openai", @@ -683,7 +836,7 @@ def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() ), ) ] - runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client, plugin_service=PluginService) assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e3b8269e15..8850137eb9 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -11,6 +11,7 @@ This test suite covers: import json from datetime import UTC, datetime from decimal import Decimal +from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -197,6 +198,55 @@ class TestAppModelValidation: # Assert assert result == AppMode.CHAT + def test_deleted_tools_checks_plugin_builtin_providers_through_core_plugin_service(self): + """Plugin-backed built-in tools are checked through core PluginService.""" + # Arrange + app = App( + tenant_id="tenant-1", + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + created_by=str(uuid4()), + ) + app_model_config = AppModelConfig( + app_id=str(uuid4()), + agent_mode=json.dumps( + { + "enabled": True, + "strategy": "function_call", + "tools": [ + { + "provider_type": "builtin", + "provider_id": "langgenius/openai/openai", + "tool_name": "chat", + "tool_parameters": {}, + } + ], + "prompt": None, + } + ), + ) + session_context = MagicMock() + session_context.__enter__.return_value = MagicMock() + session_factory = SimpleNamespace(begin=MagicMock(return_value=session_context)) + + # Act + with ( + patch.object(App, "app_model_config", new_callable=lambda: property(lambda self: app_model_config)), + patch("models.model.db", SimpleNamespace(engine=object())), + patch("models.model.sessionmaker", return_value=session_factory), + patch("core.tools.tool_manager.ToolManager.get_hardcoded_provider", side_effect=Exception), + patch("core.plugin.plugin_service.PluginService.check_tools_existence", return_value=[False]) as exists, + ): + result = app.deleted_tools + + # Assert + assert result == [{"type": "builtin", "tool_name": "chat", "provider_id": "langgenius/openai/openai"}] + exists.assert_called_once() + assert exists.call_args.args[0] == "tenant-1" + assert [str(provider_id) for provider_id in exists.call_args.args[1]] == ["langgenius/openai/openai"] + class TestAppModelConfig: """Test suite for AppModelConfig model.""" diff --git a/api/tests/unit_tests/services/plugin/conftest.py b/api/tests/unit_tests/services/plugin/conftest.py index 9dc4fa0390..cb30058494 100644 --- a/api/tests/unit_tests/services/plugin/conftest.py +++ b/api/tests/unit_tests/services/plugin/conftest.py @@ -24,7 +24,7 @@ def make_features( def mock_installer(monkeypatch: pytest.MonkeyPatch): """Patch PluginInstaller at the service import site.""" mock = MagicMock() - monkeypatch.setattr("services.plugin.plugin_service.PluginInstaller", lambda: mock) + monkeypatch.setattr("core.plugin.plugin_service.PluginInstaller", lambda: mock) return mock @@ -34,6 +34,6 @@ def mock_features(): from unittest.mock import patch features = make_features() - with patch("services.plugin.plugin_service.FeatureService") as mock_fs: + with patch("core.plugin.plugin_service.FeatureService") as mock_fs: mock_fs.get_system_features.return_value = features yield features diff --git a/api/tests/unit_tests/services/plugin/test_plugin_migration.py b/api/tests/unit_tests/services/plugin/test_plugin_migration.py index 12b6ea23a1..8f730d4ed3 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_migration.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_migration.py @@ -61,6 +61,7 @@ class TestHandlePluginInstanceInstall: patch(f"{MIGRATION_MODULE}.dify_config") as mock_cfg, patch(f"{MIGRATION_MODULE}.marketplace") as mock_marketplace, patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls, + patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, ): mock_cfg.MARKETPLACE_ENABLED = True mock_marketplace.download_plugin_pkg.return_value = b"pkg_data" @@ -73,4 +74,31 @@ class TestHandlePluginInstanceInstall: ) mock_marketplace.download_plugin_pkg.assert_called_once() + invalidate_cache.assert_called_once_with("tenant1") assert "success" in result or "failed" in result + + def test_install_plugins_invalidates_cache_after_direct_tenant_install(self, tmp_path) -> None: + extracted_plugins = tmp_path / "plugins.jsonl" + output_file = tmp_path / "output.json" + extracted_plugins.write_text('{"tenant_id":"tenant1","plugins":["langgenius/openai"]}\n') + + with ( + patch( + f"{MIGRATION_MODULE}.PluginMigration.extract_unique_plugins", + return_value={ + "plugins": {"langgenius/openai": "langgenius/openai:1.0.0@abc"}, + "plugin_not_exist": [], + }, + ), + patch(f"{MIGRATION_MODULE}.PluginMigration.handle_plugin_instance_install", return_value={}), + patch(f"{MIGRATION_MODULE}.PluginInstaller") as mock_installer_cls, + patch(f"{MIGRATION_MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_installer = MagicMock() + mock_installer.list_plugins.return_value = [] + mock_installer_cls.return_value = mock_installer + + PluginMigration.install_plugins(str(extracted_plugins), str(output_file), workers=1) + + mock_installer.install_from_identifiers.assert_called_once() + invalidate_cache.assert_called_once_with("tenant1") diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/unit_tests/services/plugin/test_plugin_service.py index 05bb3b65c0..42edd5582b 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_service.py @@ -1,6 +1,71 @@ -from unittest.mock import MagicMock, patch +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch -MODULE = "services.plugin.plugin_service" +from pydantic import TypeAdapter +from redis import RedisError + +from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStatus, PluginModelProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +MODULE = "core.plugin.plugin_service" + + +class _FakeSession: + def __init__(self) -> None: + self.execute = Mock() + self.scalars = Mock(return_value=SimpleNamespace(all=Mock(return_value=[]))) + + def __enter__(self) -> "_FakeSession": + return self + + def __exit__(self, exc_type, exc, traceback) -> None: + return None + + def begin(self) -> "_FakeSession": + return self + + +def _build_provider_entity(provider: str = "openai") -> ProviderEntity: + return ProviderEntity( + provider=f"langgenius/{provider}/{provider}", + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + +def _build_plugin_model_provider(*, tenant_id: str = "tenant-1", provider: str = "openai") -> PluginModelProviderEntity: + return PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider=provider, + tenant_id=tenant_id, + plugin_unique_identifier=f"langgenius/{provider}/{provider}", + plugin_id=f"langgenius/{provider}", + declaration=ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider.title()), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + + +def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskStatus) -> PluginInstallTask: + now = datetime.datetime.now() + return PluginInstallTask( + id=task_id, + created_at=now, + updated_at=now, + status=status, + total_plugins=1, + completed_plugins=1 if status != PluginInstallTaskStatus.Pending else 0, + plugins=[], + ) class TestFetchLatestPluginVersion: @@ -14,7 +79,7 @@ class TestFetchLatestPluginVersion: mock_cfg.MARKETPLACE_ENABLED = False mock_redis.get.return_value = None # all cache misses - from services.plugin.plugin_service import PluginService + from core.plugin.plugin_service import PluginService result = PluginService.fetch_latest_plugin_version(["langgenius/openai", "langgenius/anthropic"]) @@ -40,7 +105,7 @@ class TestFetchLatestPluginVersion: mock_redis.get.return_value = None mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] - from services.plugin.plugin_service import PluginService + from core.plugin.plugin_service import PluginService result = PluginService.fetch_latest_plugin_version(["langgenius/openai"]) @@ -48,3 +113,322 @@ class TestFetchLatestPluginVersion: mock_marketplace.batch_fetch_plugin_manifests.assert_called_once() assert result["langgenius/openai"] is not None assert result["langgenius/openai"].version == "1.0.0" + + +class TestPluginModelProviderCache: + def test_fetch_plugin_model_providers_returns_cached_provider_without_calling_daemon(self) -> None: + """A valid tenant cache entry is reused across runtime calls without plugin daemon access.""" + cached_provider = _build_provider_entity() + cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8") + + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.return_value = cached_payload + + from core.plugin.plugin_service import PluginService + + client = Mock() + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + client.fetch_model_providers.assert_not_called() + redis_client.setex.assert_not_called() + + def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None: + """Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon.""" + with ( + patch(f"{MODULE}.redis_client") as redis_client, + patch(f"{MODULE}.dify_config") as mock_config, + ): + redis_client.get.return_value = "not-json" + mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400 + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + cache_key = "plugin_model_providers:tenant_id:tenant-1" + redis_client.delete.assert_called_once_with(cache_key) + redis_client.setex.assert_called_once() + assert redis_client.setex.call_args.args[0] == cache_key + assert redis_client.setex.call_args.args[1] == 86400 + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None: + """Redis read failures do not block provider discovery for the tenant.""" + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.side_effect = RedisError("redis unavailable") + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None: + """Redis write failures are non-fatal after fresh provider data has been fetched.""" + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.return_value = None + redis_client.setex.side_effect = RedisError("redis unavailable") + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_fetch_plugin_model_providers_creates_default_client_on_cache_miss(self) -> None: + """The service owns plugin daemon access when no runtime-provided client is injected.""" + with ( + patch(f"{MODULE}.redis_client") as redis_client, + patch(f"{MODULE}.PluginModelClient") as client_cls, + ): + redis_client.get.return_value = None + client = client_cls.return_value + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1") + + client_cls.assert_called_once_with() + client.fetch_model_providers.assert_called_once_with("tenant-1") + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + + def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + from core.plugin.plugin_service import PluginService + + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + + def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.delete.side_effect = RedisError("redis unavailable") + + from core.plugin.plugin_service import PluginService + + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + + +class TestPluginModelProviderCacheInvalidation: + def test_fetch_install_task_invalidates_model_provider_cache_when_finished(self) -> None: + """Finished plugin install tasks invalidate tenant provider cache.""" + task = _build_install_task(status=PluginInstallTaskStatus.Success) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_task.return_value = task + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_task("tenant-1", "task-1") + + assert result is task + invalidate_cache.assert_called_once_with("tenant-1") + + def test_fetch_install_tasks_invalidates_model_provider_cache_for_finished_tasks(self) -> None: + """Finished tasks from task list polling also invalidate tenant provider cache.""" + task = _build_install_task(status=PluginInstallTaskStatus.Success) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_tasks("tenant-1", 1, 256) + + assert result == [task] + invalidate_cache.assert_called_once_with("tenant-1") + + def test_fetch_install_tasks_ignores_running_model_provider_cache_tasks(self) -> None: + """Running plugin install tasks do not invalidate provider cache until they reach a terminal state.""" + task = _build_install_task(status=PluginInstallTaskStatus.Running) + + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer_cls.return_value.fetch_plugin_installation_tasks.return_value = [task] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_install_tasks("tenant-1", 1, 256) + + assert result == [task] + invalidate_cache.assert_not_called() + + def test_upgrade_plugin_with_marketplace_invalidates_model_provider_cache_for_tenant(self) -> None: + """Marketplace upgrades invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.FeatureService") as feature_service, + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.marketplace") as marketplace, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.MARKETPLACE_ENABLED = True + feature_service.get_system_features.return_value = SimpleNamespace( + plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False) + ) + installer = installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() + installer.upgrade_plugin.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.upgrade_plugin_with_marketplace("tenant-1", "old-uid", "new-uid") + + assert result == "task-id" + marketplace.record_install_plugin_event.assert_called_once_with("new-uid") + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_local_pkg_invalidates_model_provider_cache_for_tenant(self) -> None: + """Starting a plugin install invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_local_pkg("tenant-1", ["langgenius/openai:1.0.0"]) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_upgrade_plugin_with_github_invalidates_model_provider_cache_for_tenant(self) -> None: + """Starting a plugin upgrade invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + installer.upgrade_plugin.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.upgrade_plugin_with_github( + "tenant-1", "old-uid", "new-uid", "langgenius/openai", "1.0.0", "openai.difypkg" + ) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_github_invalidates_model_provider_cache_for_tenant(self) -> None: + """GitHub installs invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginService._check_marketplace_only_permission"), + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_github( + "tenant-1", "langgenius/openai:1.0.0", "langgenius/openai", "1.0.0", "openai.difypkg" + ) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_install_from_marketplace_pkg_invalidates_model_provider_cache_for_tenant(self) -> None: + """Marketplace package installs invalidate only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.FeatureService") as feature_service, + patch(f"{MODULE}.PluginService._check_plugin_installation_scope"), + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.MARKETPLACE_ENABLED = True + feature_service.get_system_features.return_value = SimpleNamespace( + plugin_installation_permission=SimpleNamespace(restrict_to_marketplace_only=False) + ) + installer = installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() + decode_response = MagicMock() + decode_response.verification = None + installer.decode_plugin_from_identifier.return_value = decode_response + installer.install_from_identifiers.return_value = "task-id" + + from core.plugin.plugin_service import PluginService + + result = PluginService.install_from_marketplace_pkg("tenant-1", ["langgenius/openai:1.0.0"]) + + assert result == "task-id" + invalidate_cache.assert_called_once_with("tenant-1") + + def test_uninstall_invalidates_model_provider_cache_for_tenant(self) -> None: + """Successful uninstall invalidates only the mutated tenant provider cache.""" + with ( + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + installer = installer_cls.return_value + installer.list_plugins.return_value = [] + installer.uninstall.return_value = True + + from core.plugin.plugin_service import PluginService + + result = PluginService.uninstall("tenant-1", "installation-1") + + assert result is True + invalidate_cache.assert_called_once_with("tenant-1") + + def test_uninstall_existing_plugin_invalidates_cache_after_credential_cleanup(self) -> None: + """Successful uninstall with plugin metadata also invalidates the mutated tenant provider cache.""" + plugin = SimpleNamespace( + installation_id="installation-1", + plugin_id="langgenius/openai", + plugin_unique_identifier="langgenius/openai:1.0.0", + ) + session = _FakeSession() + with ( + patch(f"{MODULE}.db", SimpleNamespace(engine=object())), + patch(f"{MODULE}.dify_config") as mock_config, + patch(f"{MODULE}.PluginInstaller") as installer_cls, + patch(f"{MODULE}.Session", return_value=session), + patch(f"{MODULE}.PluginService.invalidate_plugin_model_providers_cache") as invalidate_cache, + ): + mock_config.ENTERPRISE_ENABLED = False + installer = installer_cls.return_value + installer.list_plugins.return_value = [plugin] + installer.uninstall.return_value = True + + from core.plugin.plugin_service import PluginService + + result = PluginService.uninstall("tenant-1", "installation-1") + + assert result is True + installer.uninstall.assert_called_once_with("tenant-1", "installation-1") + invalidate_cache.assert_called_once_with("tenant-1") diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index e25d8daacf..df4c471d11 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -60,6 +60,7 @@ SSRF_PROXY_HTTPS_URL=http://ssrf_proxy:3128 PGDATA=/var/lib/postgresql/data/pgdata PLUGIN_MAX_PACKAGE_SIZE=52428800 PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 +PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400 ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} LOG_LEVEL=INFO LOG_OUTPUT_FORMAT=text