mirror of
https://github.com/langgenius/dify.git
synced 2025-12-23 11:13:07 -05:00
fix 27003 (#27005)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from core.errors.error import (
|
|||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.login import current_user
|
from libs.login import current_user as current_user_
|
||||||
from models.model import AppMode, InstalledApp
|
from models.model import AppMode, InstalledApp
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
@@ -31,6 +31,8 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
current_user = current_user_._get_current_object() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||||
|
|||||||
@@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]):
|
|||||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
current_user, _ = current_account_with_tenant()
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
|
|
||||||
|
user = current_user._get_current_object() # type: ignore
|
||||||
|
if not isinstance(user, Account):
|
||||||
|
raise Forbidden()
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Union, cast
|
from typing import Any
|
||||||
|
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import current_app, g, has_request_context, request
|
||||||
from flask_login.config import EXEMPT_METHODS # type: ignore
|
from flask_login.config import EXEMPT_METHODS # type: ignore
|
||||||
@@ -10,16 +10,21 @@ from configs import dify_config
|
|||||||
from models import Account
|
from models import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
#: A proxy for the current user. If no user is logged in, this will be an
|
|
||||||
#: anonymous user
|
|
||||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
|
||||||
|
|
||||||
|
|
||||||
def current_account_with_tenant():
|
def current_account_with_tenant():
|
||||||
if not isinstance(current_user, Account):
|
"""
|
||||||
|
Resolve the underlying account for the current user proxy and ensure tenant context exists.
|
||||||
|
Allows tests to supply plain Account mocks without the LocalProxy helper.
|
||||||
|
"""
|
||||||
|
user_proxy = current_user
|
||||||
|
|
||||||
|
get_current_object = getattr(user_proxy, "_get_current_object", None)
|
||||||
|
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
||||||
|
|
||||||
|
if not isinstance(user, Account):
|
||||||
raise ValueError("current_user must be an Account instance")
|
raise ValueError("current_user must be an Account instance")
|
||||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
assert user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||||
return current_user, current_user.current_tenant_id
|
return user, user.current_tenant_id
|
||||||
|
|
||||||
|
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
@@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None:
|
|||||||
return g._login_user # type: ignore
|
return g._login_user # type: ignore
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
#: A proxy for the current user. If no user is logged in, this will be an
|
||||||
|
#: anonymous user
|
||||||
|
# NOTE: Any here, but use _get_current_object to check the fields
|
||||||
|
current_user: Any = LocalProxy(lambda: _get_user())
|
||||||
|
|||||||
@@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin):
|
|||||||
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
app_id = mapped_column(StringUUID, nullable=True)
|
app_id = mapped_column(StringUUID, nullable=True)
|
||||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from core.tools.entities.tool_entities import CredentialType
|
|||||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.login import current_account_with_tenant
|
|
||||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.plugin.plugin_service import PluginService
|
from services.plugin.plugin_service import PluginService
|
||||||
@@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user():
|
||||||
|
from libs.login import current_user
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import EndUser
|
||||||
|
|
||||||
|
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
|
||||||
|
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
class DatasourceProviderService:
|
class DatasourceProviderService:
|
||||||
"""
|
"""
|
||||||
Model Provider Service
|
Model Provider Service
|
||||||
@@ -93,8 +102,6 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
get credential by id
|
get credential by id
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if credential_id:
|
if credential_id:
|
||||||
datasource_provider = (
|
datasource_provider = (
|
||||||
@@ -111,6 +118,7 @@ class DatasourceProviderService:
|
|||||||
return {}
|
return {}
|
||||||
# refresh the credentials
|
# refresh the credentials
|
||||||
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
|
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
|
||||||
|
current_user = get_current_user()
|
||||||
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
datasource_provider=datasource_provider,
|
datasource_provider=datasource_provider,
|
||||||
@@ -159,8 +167,6 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
get all datasource credentials by provider
|
get all datasource credentials by provider
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
datasource_providers = (
|
datasource_providers = (
|
||||||
session.query(DatasourceProvider)
|
session.query(DatasourceProvider)
|
||||||
@@ -170,6 +176,7 @@ class DatasourceProviderService:
|
|||||||
)
|
)
|
||||||
if not datasource_providers:
|
if not datasource_providers:
|
||||||
return []
|
return []
|
||||||
|
current_user = get_current_user()
|
||||||
# refresh the credentials
|
# refresh the credentials
|
||||||
real_credentials_list = []
|
real_credentials_list = []
|
||||||
for datasource_provider in datasource_providers:
|
for datasource_provider in datasource_providers:
|
||||||
@@ -608,7 +615,6 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
provider_name = provider_id.provider_name
|
provider_name = provider_id.provider_name
|
||||||
plugin_id = provider_id.plugin_id
|
plugin_id = provider_id.plugin_id
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||||
@@ -630,6 +636,7 @@ class DatasourceProviderService:
|
|||||||
raise ValueError("Authorization name is already exists")
|
raise ValueError("Authorization name is already exists")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
current_user = get_current_user()
|
||||||
self.provider_manager.validate_provider_credentials(
|
self.provider_manager.validate_provider_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -907,7 +914,6 @@ class DatasourceProviderService:
|
|||||||
"""
|
"""
|
||||||
update datasource credentials.
|
update datasource credentials.
|
||||||
"""
|
"""
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
datasource_provider = (
|
datasource_provider = (
|
||||||
@@ -944,6 +950,7 @@ class DatasourceProviderService:
|
|||||||
for key, value in credentials.items()
|
for key, value in credentials.items()
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
|
current_user = get_current_user()
|
||||||
self.provider_manager.validate_provider_credentials(
|
self.provider_manager.validate_provider_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
|
|||||||
Reference in New Issue
Block a user