fix: sync 35528 (#35539)

This commit is contained in:
Yunlu Wen
2026-04-24 11:59:33 +08:00
committed by GitHub
parent 38fc2a6574
commit 48e13f65dc
4 changed files with 89 additions and 8 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@@ -30,8 +44,12 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@@ -46,6 +64,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@@ -65,7 +84,8 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager
@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@@ -434,8 +437,30 @@ class ModelInstance:
class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@@ -463,8 +488,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")

View File

@@ -5,7 +5,7 @@ import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_manager import LBModelManager, ModelManager
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelType
@@ -40,6 +40,29 @@ def lb_model_manager():
return lb_model_manager
def test_model_manager_with_cache_enabled_reuses_stored_credentials():
"""With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
provider_manager = MagicMock()
bundle = MagicMock()
bundle.configuration.provider.provider = "openai"
bundle.configuration.tenant_id = "tenant-1"
bundle.configuration.model_settings = []
bundle.model_type_instance.model_type = ModelType.LLM
get_creds = MagicMock(return_value={"api_key": "first"})
bundle.configuration.get_current_credentials = get_creds
provider_manager.get_provider_model_bundle.return_value = bundle
manager = ModelManager(provider_manager, enable_credentials_cache=True)
first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
assert first.credentials == {"api_key": "first"}
get_creds.assert_called_once()
get_creds.return_value = {"api_key": "second"}
second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
assert second.credentials == {"api_key": "first"}
get_creds.assert_called_once()
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())