Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-10-17 10:10:16 +09:00
committed by GitHub
parent d7f0a31e24
commit 19cc6ea993
5 changed files with 43 additions and 18 deletions

View File

@@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper from libs import helper
from libs.login import current_user from libs.login import current_user as current_user_
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@@ -31,6 +31,8 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):

View File

@@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]):
def decorated_function(*args: P.args, **kwargs: P.kwargs): def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
current_user, _ = current_account_with_tenant() from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
raise Forbidden()
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Union, cast from typing import Any
from flask import current_app, g, has_request_context, request from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore from flask_login.config import EXEMPT_METHODS # type: ignore
@@ -10,16 +10,21 @@ from configs import dify_config
from models import Account from models import Account
from models.model import EndUser from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
def current_account_with_tenant(): def current_account_with_tenant():
if not isinstance(current_user, Account): """
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
assert current_user.current_tenant_id is not None, "The tenant information should be loaded." assert user.current_tenant_id is not None, "The tenant information should be loaded."
return current_user, current_user.current_tenant_id return user, user.current_tenant_id
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
@@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None:
return g._login_user # type: ignore return g._login_user # type: ignore
return None return None
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
# NOTE: Any here, but use _get_current_object to check the fields
current_user: Any = LocalProxy(lambda: _get_user())

View File

@@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin):
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
) )
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=True) app_id = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@@ -17,7 +17,6 @@ from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
@@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_current_user():
from libs.login import current_user
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
return current_user
class DatasourceProviderService: class DatasourceProviderService:
""" """
Model Provider Service Model Provider Service
@@ -93,8 +102,6 @@ class DatasourceProviderService:
""" """
get credential by id get credential by id
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
if credential_id: if credential_id:
datasource_provider = ( datasource_provider = (
@@ -111,6 +118,7 @@ class DatasourceProviderService:
return {} return {}
# refresh the credentials # refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
current_user = get_current_user()
decrypted_credentials = self.decrypt_datasource_provider_credentials( decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
datasource_provider=datasource_provider, datasource_provider=datasource_provider,
@@ -159,8 +167,6 @@ class DatasourceProviderService:
""" """
get all datasource credentials by provider get all datasource credentials by provider
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
datasource_providers = ( datasource_providers = (
session.query(DatasourceProvider) session.query(DatasourceProvider)
@@ -170,6 +176,7 @@ class DatasourceProviderService:
) )
if not datasource_providers: if not datasource_providers:
return [] return []
current_user = get_current_user()
# refresh the credentials # refresh the credentials
real_credentials_list = [] real_credentials_list = []
for datasource_provider in datasource_providers: for datasource_provider in datasource_providers:
@@ -608,7 +615,6 @@ class DatasourceProviderService:
""" """
provider_name = provider_id.provider_name provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id plugin_id = provider_id.plugin_id
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
@@ -630,6 +636,7 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists") raise ValueError("Authorization name is already exists")
try: try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials( self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user.id, user_id=current_user.id,
@@ -907,7 +914,6 @@ class DatasourceProviderService:
""" """
update datasource credentials. update datasource credentials.
""" """
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session: with Session(db.engine) as session:
datasource_provider = ( datasource_provider = (
@@ -944,6 +950,7 @@ class DatasourceProviderService:
for key, value in credentials.items() for key, value in credentials.items()
} }
try: try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials( self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user.id, user_id=current_user.id,