fix: sync 35528 (#35539)

@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any

from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider


class DifyCredentialsProvider:
    """Resolves and returns LLM credentials for a given provider and model.

    Fetched credentials are stored in :attr:`credentials_cache` and reused for
    subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
    Because of that cache, a single instance can return stale credentials after
    the tenant or provider configuration changes (e.g. API key rotation).

    Do **not** keep one instance for the lifetime of a process or across
    unrelated invocations. Create a new provider per request, workflow run, or
    other bounded scope where up-to-date credentials matter.
    """

    tenant_id: str
    provider_manager: ProviderManager
    credentials_cache: dict[tuple[str, str], dict[str, Any]]

    def __init__(
        self,
@@ -30,8 +44,12 @@
            user_id=run_context.user_id,
        )
        self.provider_manager = provider_manager
        self.credentials_cache = {}

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        if (provider_name, model_name) in self.credentials_cache:
            return deepcopy(self.credentials_cache[(provider_name, model_name)])

        provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
        provider_configuration = provider_configurations.get(provider_name)
        if not provider_configuration:
@@ -46,6 +64,7 @@
        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

        self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
        return credentials
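
A minimal usage sketch of the class above, assuming a ``run_context: DifyRunContext`` and a ``provider_manager`` like the ones wired together in ``build_dify_model_access`` further below, with an arbitrary provider/model pair:

    provider = DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager)
    creds = provider.fetch("openai", "gpt-4")        # resolved from provider config, then cached
    creds_again = provider.fetch("openai", "gpt-4")  # served from credentials_cache as a deepcopy
    # Per the docstring, drop the provider at the end of the request or workflow run
    # so a rotated API key is fetched fresh next time.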


@@ -65,7 +84,8 @@ class DifyModelFactory:
                provider_manager=create_plugin_provider_manager(
                    tenant_id=run_context.tenant_id,
                    user_id=run_context.user_id,
                )
            ),
            enable_credentials_cache=True,
        )
        self.model_manager = model_manager
@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
        tenant_id=run_context.tenant_id,
        user_id=run_context.user_id,
    )
    model_manager = ModelManager(provider_manager=provider_manager)
    model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)

    return (
        DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

@@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

from configs import dify_config
@@ -36,11 +37,13 @@ class ModelInstance:
    Model instance class.
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
        self.provider_model_bundle = provider_model_bundle
        self.model_name = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        if credentials is None:
            credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.credentials = credentials
        # Runtime LLM invocation fields.
        self.parameters: Mapping[str, Any] = {}
        self.stop: Sequence[str] = ()
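
A small sketch of the widened constructor, assuming ``bundle`` is a ``ProviderModelBundle`` resolved elsewhere and using a placeholder credential dict:

    reused = ModelInstance(bundle, "gpt-4", {"api_key": "example-key"})  # caller supplies pre-resolved credentials
    fresh = ModelInstance(bundle, "gpt-4")  # falls back to _fetch_credentials_from_bundle as before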
@@ -434,8 +437,30 @@


class ModelManager:
    def __init__(self, provider_manager: ProviderManager):
    """Resolves :class:`ModelInstance` objects for a tenant and provider.

    When ``enable_credentials_cache`` is ``True``, resolved credentials for each
    ``(tenant_id, provider, model_type, model)`` are stored in
    ``_credentials_cache`` and reused. That can return **stale** credentials after
    API keys or provider settings change, so a manager constructed with
    ``enable_credentials_cache=True`` should not be kept for the lifetime of a
    process or shared across unrelated work. Prefer a new manager per request,
    workflow run, or similar bounded scope.

    The default is ``enable_credentials_cache=False``; in that mode the internal
    credential cache is not populated, and each ``get_model_instance`` call
    loads credentials from the current provider configuration.
    """

    def __init__(
        self,
        provider_manager: ProviderManager,
        *,
        enable_credentials_cache: bool = False,
    ) -> None:
        self._provider_manager = provider_manager
        self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
        self._enable_credentials_cache = enable_credentials_cache

    @classmethod
    def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@@ -463,8 +488,19 @@ class ModelManager:
            tenant_id=tenant_id, provider=provider, model_type=model_type
        )

        model_instance = ModelInstance(provider_model_bundle, model)
        return model_instance
        cred_cache_key = (tenant_id, provider, model_type.value, model)

        if cred_cache_key in self._credentials_cache:
            return ModelInstance(
                provider_model_bundle,
                model,
                deepcopy(self._credentials_cache[cred_cache_key]),
            )

        ret = ModelInstance(provider_model_bundle, model)
        if self._enable_credentials_cache:
            self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
        return ret
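
A short sketch of the two modes described in the docstring, assuming ``provider_manager`` is a ``ProviderManager`` for the tenant and using placeholder tenant/provider/model names:

    scoped = ModelManager(provider_manager, enable_credentials_cache=True)
    # Repeated lookups of the same (tenant, provider, model_type, model) within this
    # manager reuse a deepcopy of the first resolved credentials.
    llm = scoped.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")

    plain = ModelManager(provider_manager)  # cache disabled by default
    # Each call re-reads credentials from the current provider configuration.
    llm_again = plain.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")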

    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
        """

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING

from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator

from configs import dify_config
@@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:

class EnterpriseService:
    @classmethod
    @ttl_cache(ttl=5)
    def get_info(cls):
        return EnterpriseRequest.send_request("GET", "/info")
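
``cachetools.func.ttl_cache`` memoizes the decorated call for ``ttl`` seconds, so repeated ``EnterpriseService.get_info()`` calls within a 5-second window should hit the enterprise endpoint only once. A standalone sketch of the same pattern, with a made-up function standing in for the real request:

    from cachetools.func import ttl_cache

    @ttl_cache(ttl=5)
    def fetch_info():
        # Stands in for EnterpriseRequest.send_request("GET", "/info").
        return {"status": "ok"}

    fetch_info()  # executes the body
    fetch_info()  # served from the TTL cache for roughly the next 5 seconds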
@@ -5,7 +5,7 @@ import redis
from pytest_mock import MockerFixture

from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_manager import LBModelManager, ModelManager
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelType

@@ -40,6 +40,29 @@ def lb_model_manager():
    return lb_model_manager


def test_model_manager_with_cache_enabled_reuses_stored_credentials():
    """With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
    provider_manager = MagicMock()
    bundle = MagicMock()
    bundle.configuration.provider.provider = "openai"
    bundle.configuration.tenant_id = "tenant-1"
    bundle.configuration.model_settings = []
    bundle.model_type_instance.model_type = ModelType.LLM
    get_creds = MagicMock(return_value={"api_key": "first"})
    bundle.configuration.get_current_credentials = get_creds
    provider_manager.get_provider_model_bundle.return_value = bundle

    manager = ModelManager(provider_manager, enable_credentials_cache=True)
    first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
    assert first.credentials == {"api_key": "first"}
    get_creds.assert_called_once()

    get_creds.return_value = {"api_key": "second"}
    second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
    assert second.credentials == {"api_key": "first"}
    get_creds.assert_called_once()


def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
    # initialize redis client
    redis_client.initialize(redis.Redis())