mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
fix(workflow_tool): loss current user
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models.model import App
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -79,11 +81,13 @@ class WorkflowTool(Tool):
|
||||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
assert current_user is not None
|
||||
user = self._resolve_user(user_id)
|
||||
if user is None:
|
||||
raise ToolInvokeError("workflow tool invoke missing user context")
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=current_user,
|
||||
user=user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
@@ -227,3 +231,26 @@ class WorkflowTool(Tool):
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_dict.get("related_id")
|
||||
return file_dict
|
||||
|
||||
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
|
||||
runtime = self.runtime
|
||||
try:
|
||||
user_candidate = current_user
|
||||
except RuntimeError:
|
||||
user_candidate = None
|
||||
|
||||
if user_candidate is not None and getattr(user_candidate, "is_authenticated", False):
|
||||
return user_candidate
|
||||
|
||||
if not user_id or runtime is None:
|
||||
return None
|
||||
|
||||
invoke_from = runtime.invoke_from
|
||||
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.PUBLISHED}:
|
||||
end_user = (
|
||||
db.session.query(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == runtime.tenant_id).first()
|
||||
)
|
||||
if end_user:
|
||||
return end_user
|
||||
|
||||
return db.session.query(Account).where(Account.id == user_id).first()
|
||||
|
||||
@@ -40,9 +40,64 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||
)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
WorkflowTool,
|
||||
"_resolve_user",
|
||||
lambda self, _user_id: type("DummyUser", (), {"id": _user_id, "is_authenticated": True})(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ToolInvokeError) as exc_info:
|
||||
# WorkflowTool always returns a generator, so we need to iterate to
|
||||
# actually `run` the tool.
|
||||
list(tool.invoke("test_user", {}))
|
||||
assert exc_info.value.args == ("oops",)
|
||||
|
||||
|
||||
def test_workflow_tool_falls_back_to_user_resolver_when_no_current_user(monkeypatch: pytest.MonkeyPatch):
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="tester", name="work", label=I18nObject(en_US="work"), provider="prv"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-id", invoke_from=InvokeFrom.SERVICE_API)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="app-id",
|
||||
workflow_as_tool_id="tool-id",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=0,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# keep tool internals simple for the test
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *_args, **_kwargs: object())
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *_args, **_kwargs: object())
|
||||
monkeypatch.setattr(tool, "_transform_args", lambda tool_parameters, **_: (tool_parameters, []))
|
||||
|
||||
captured: dict[str, str] = {}
|
||||
|
||||
class DummyUser:
|
||||
id = "dummy-user"
|
||||
is_authenticated = True
|
||||
|
||||
dummy_user = DummyUser()
|
||||
|
||||
def fake_resolver(self, user_id: str):
|
||||
captured["user_id"] = user_id
|
||||
return dummy_user
|
||||
|
||||
def fake_generate(self, *, user, **_kwargs):
|
||||
assert user is dummy_user
|
||||
return {"data": {"outputs": {}}}
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", fake_generate)
|
||||
monkeypatch.setattr("core.tools.workflow_as_tool.tool.current_user", None)
|
||||
monkeypatch.setattr(WorkflowTool, "_resolve_user", fake_resolver, raising=False)
|
||||
|
||||
result = list(tool.invoke("user-123", {}))
|
||||
|
||||
assert captured["user_id"] == "user-123"
|
||||
assert len(result) == 2 # text + json outputs
|
||||
|
||||
Reference in New Issue
Block a user