mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
feat: add end user oauth apis (#29683)
This commit is contained in:
@@ -54,6 +54,7 @@ from .app import (
|
||||
completion,
|
||||
conversation,
|
||||
conversation_variables,
|
||||
enduser_auth,
|
||||
generator,
|
||||
mcp_server,
|
||||
message,
|
||||
@@ -162,6 +163,7 @@ __all__ = [
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"enduser_auth",
|
||||
"extension",
|
||||
"external",
|
||||
"feature",
|
||||
|
||||
230
api/controllers/console/app/enduser_auth.py
Normal file
230
api/controllers/console/app/enduser_auth.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
End-user authentication API controllers.
|
||||
|
||||
Provides API endpoints for managing end-user credentials for tool authentication.
|
||||
"""
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
from services.tools.app_auth_requirement_service import AppAuthRequirementService
|
||||
from services.tools.enduser_auth_service import EndUserAuthService
|
||||
from services.tools.enduser_oauth_service import EndUserOAuthService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/providers")
|
||||
class AppAuthProvidersApi(Resource):
|
||||
"""
|
||||
Get list of authentication providers required for an app.
|
||||
|
||||
Returns providers that require end-user authentication based on app configuration.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
"""Get authentication providers required for the app."""
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
providers = AppAuthRequirementService.get_required_providers(
|
||||
tenant_id=tenant_id,
|
||||
app_id=str(app_model.id),
|
||||
)
|
||||
|
||||
return jsonable_encoder(providers)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials")
|
||||
class AppAuthProviderCredentialsApi(Resource):
|
||||
"""
|
||||
Manage end-user credentials for a specific provider.
|
||||
|
||||
Allows listing, creating, and deleting end-user credentials.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def get(self, app_model, provider_id: str):
|
||||
"""List end-user's credentials for this provider."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# For console API, use the current account user as end_user_id
|
||||
# In production, this would be the actual end-user ID from the chat/completion request
|
||||
end_user_id = str(user.id)
|
||||
|
||||
credentials = EndUserAuthService.list_credentials(
|
||||
tenant_id=tenant_id,
|
||||
end_user_id=end_user_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return jsonable_encoder(credentials)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def post(self, app_model, provider_id: str):
|
||||
"""Create a new credential (API key only)."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
end_user_id = str(user.id)
|
||||
|
||||
payload = request.get_json()
|
||||
if not payload:
|
||||
raise BadRequest("Request body is required")
|
||||
|
||||
credential_type = payload.get("credential_type")
|
||||
credentials = payload.get("credentials")
|
||||
|
||||
if not credential_type or not credentials:
|
||||
raise BadRequest("credential_type and credentials are required")
|
||||
|
||||
if credential_type != "api-key":
|
||||
raise BadRequest(
|
||||
"Only 'api-key' credential type can be created via this endpoint. "
|
||||
"Use OAuth flow for OAuth credentials."
|
||||
)
|
||||
|
||||
credential = EndUserAuthService.create_api_key_credential(
|
||||
tenant_id=tenant_id,
|
||||
end_user_id=end_user_id,
|
||||
provider_id=provider_id,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return jsonable_encoder(credential)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials/<string:credential_id>")
|
||||
class AppAuthProviderCredentialApi(Resource):
|
||||
"""
|
||||
Manage a specific end-user credential.
|
||||
|
||||
Allows getting, updating, or deleting a credential.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def delete(self, app_model, provider_id: str, credential_id: str):
|
||||
"""Delete a credential."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
end_user_id = str(user.id)
|
||||
|
||||
EndUserAuthService.delete_credential(
|
||||
tenant_id=tenant_id,
|
||||
end_user_id=end_user_id,
|
||||
provider_id=provider_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/oauth/<path:provider_id>/authorization-url")
|
||||
class AppAuthOAuthAuthorizationUrlApi(Resource):
|
||||
"""
|
||||
Get OAuth authorization URL for end-user authentication.
|
||||
|
||||
Returns the URL where the user should be redirected to authenticate with the provider.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def get(self, app_model, provider_id: str):
|
||||
"""Get OAuth authorization URL."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
end_user_id = str(user.id)
|
||||
|
||||
result = EndUserOAuthService.get_authorization_url(
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=str(app_model.id),
|
||||
provider=provider_id,
|
||||
)
|
||||
|
||||
# Set OAuth context cookie for callback
|
||||
response = jsonable_encoder({
|
||||
"authorization_url": result["authorization_url"],
|
||||
})
|
||||
|
||||
# Store context_id in response for frontend to set as cookie
|
||||
response["context_id"] = result["context_id"]
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/oauth/<path:provider_id>/callback")
|
||||
class AppAuthOAuthCallbackApi(Resource):
|
||||
"""
|
||||
Handle OAuth callback for end-user authentication.
|
||||
|
||||
This endpoint is called by the OAuth provider after user authorization.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def get(self, app_model, provider_id: str):
|
||||
"""Handle OAuth callback and store credentials."""
|
||||
# Get OAuth context ID from cookie
|
||||
context_id = request.cookies.get("oauth_context_id")
|
||||
|
||||
if not context_id:
|
||||
raise BadRequest("Missing OAuth context")
|
||||
|
||||
# Get OAuth error if any
|
||||
error = request.args.get("error")
|
||||
error_description = request.args.get("error_description")
|
||||
|
||||
if error:
|
||||
raise BadRequest(f"OAuth error: {error} - {error_description}")
|
||||
|
||||
# Handle callback and create credential
|
||||
result = EndUserOAuthService.handle_oauth_callback(
|
||||
context_id=context_id,
|
||||
request=request,
|
||||
)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials/<string:credential_id>/refresh")
|
||||
class AppAuthProviderRefreshApi(Resource):
|
||||
"""
|
||||
Manually refresh OAuth token for a credential.
|
||||
|
||||
This endpoint allows refreshing an expired OAuth token.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
|
||||
def post(self, app_model, provider_id: str, credential_id: str):
|
||||
"""Refresh OAuth token."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
end_user_id = str(user.id)
|
||||
|
||||
result = EndUserOAuthService.refresh_oauth_token(
|
||||
credential_id=credential_id,
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
@@ -17,6 +17,7 @@ class AgentToolEntity(BaseModel):
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
plugin_unique_identifier: str | None = None
|
||||
credential_id: str | None = None
|
||||
auth_type: ToolAuthType = ToolAuthType.WORKSPACE
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
|
||||
@@ -124,6 +124,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
|
||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||
)
|
||||
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
|
||||
expires_at: int = Field(default=-1, description="Unix timestamp when credential expires (-1 for no expiry)")
|
||||
|
||||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
|
||||
@@ -49,6 +49,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
@@ -58,7 +59,7 @@ from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, EndUserAuthenticationProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -78,6 +79,54 @@ class ToolManager:
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def _refresh_oauth_credentials(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
user_id: str,
|
||||
decrypted_credentials: Mapping[str, Any],
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
"""
|
||||
Refresh OAuth credentials for a provider.
|
||||
|
||||
This is a helper method to centralize the OAuth token refresh logic
|
||||
used by both end-user and workspace authentication flows.
|
||||
|
||||
:param tenant_id: the tenant id
|
||||
:param provider_id: the provider id
|
||||
:param user_id: the user id (end_user_id or workspace user_id)
|
||||
:param decrypted_credentials: the current decrypted credentials
|
||||
|
||||
:return: tuple of (refreshed credentials dict, expires_at timestamp)
|
||||
"""
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
|
||||
# Local import to avoid circular dependency at module level
|
||||
# This import is necessary but creates a cycle: tool_manager -> builtin_tools_manage_service -> tool_manager
|
||||
# TODO: Break the circular dependency by refactoring service layer
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# Parse provider ID and build OAuth configuration
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
# Refresh the credentials using OAuth handler
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
|
||||
return refreshed_credentials.credentials, refreshed_credentials.expires_at
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
@@ -165,6 +214,8 @@ class ToolManager:
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: str | None = None,
|
||||
auth_type: ToolAuthType = ToolAuthType.WORKSPACE,
|
||||
end_user_id: str | None = None,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@@ -176,6 +227,8 @@ class ToolManager:
|
||||
:param invoke_from: invoke from
|
||||
:param tool_invoke_from: the tool invoke from
|
||||
:param credential_id: the credential id
|
||||
:param auth_type: the authentication type (workspace or end_user)
|
||||
:param end_user_id: the end user id (required when auth_type is END_USER)
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
@@ -200,6 +253,75 @@ class ToolManager:
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Handle end-user authentication
|
||||
if auth_type == ToolAuthType.END_USER:
|
||||
if not end_user_id:
|
||||
raise ToolProviderNotFoundError("end_user_id is required for END_USER auth_type")
|
||||
|
||||
# Query end-user credentials
|
||||
enduser_provider = (
|
||||
db.session.query(EndUserAuthenticationProvider)
|
||||
.where(
|
||||
EndUserAuthenticationProvider.tenant_id == tenant_id,
|
||||
EndUserAuthenticationProvider.end_user_id == end_user_id,
|
||||
EndUserAuthenticationProvider.provider == provider_id,
|
||||
)
|
||||
.order_by(EndUserAuthenticationProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if enduser_provider is None:
|
||||
raise ToolProviderNotFoundError(
|
||||
f"No end-user credentials found for provider {provider_id}"
|
||||
)
|
||||
|
||||
# Decrypt end-user credentials
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(enduser_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=enduser_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(enduser_provider.credentials)
|
||||
|
||||
# Handle OAuth token refresh for end-users if expired
|
||||
if enduser_provider.expires_at != -1 and (enduser_provider.expires_at - 60) < int(time.time()):
|
||||
# Refresh credentials using the centralized helper method
|
||||
refreshed_credentials, expires_at = cls._refresh_oauth_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
user_id=end_user_id,
|
||||
decrypted_credentials=decrypted_credentials,
|
||||
)
|
||||
|
||||
# Update the provider with refreshed credentials
|
||||
enduser_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
|
||||
enduser_provider.expires_at = expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials
|
||||
cache.delete()
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(enduser_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Handle workspace authentication (existing logic)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
@@ -270,34 +392,19 @@ class ToolManager:
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
# Refresh credentials using the centralized helper method
|
||||
refreshed_credentials, expires_at = cls._refresh_oauth_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
decrypted_credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
|
||||
# Update the provider with refreshed credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
|
||||
builtin_provider.expires_at = expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
decrypted_credentials = refreshed_credentials
|
||||
cache.delete()
|
||||
|
||||
return cast(
|
||||
@@ -368,6 +475,7 @@ class ToolManager:
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
end_user_id: str | None = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@@ -380,6 +488,8 @@ class ToolManager:
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=agent_tool.credential_id,
|
||||
auth_type=agent_tool.auth_type,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
@@ -410,6 +520,7 @@ class ToolManager:
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
end_user_id: str | None = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
@@ -423,6 +534,8 @@ class ToolManager:
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=workflow_tool.credential_id,
|
||||
auth_type=workflow_tool.auth_type,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
|
||||
@@ -66,6 +66,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from models.enums import UserFrom
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
@@ -74,8 +75,20 @@ class ToolNode(Node[ToolNodeData]):
|
||||
variable_pool: VariablePool | None = None
|
||||
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# Determine end_user_id based on user_from
|
||||
end_user_id = None
|
||||
if self.user_from == UserFrom.END_USER:
|
||||
end_user_id = self.user_id
|
||||
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
|
||||
self.tenant_id,
|
||||
self.app_id,
|
||||
self._node_id,
|
||||
self.node_data,
|
||||
self.invoke_from,
|
||||
variable_pool,
|
||||
end_user_id,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
"""merge_heads
|
||||
|
||||
Revision ID: 3134f4e0620d
|
||||
Revises: d57accd375ae, a7b4e8f2c9d1
|
||||
Create Date: 2025-12-14 14:18:19.393720
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '3134f4e0620d'
|
||||
down_revision = ('d57accd375ae', 'a7b4e8f2c9d1')
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
pass
|
||||
|
||||
|
||||
def downgrade():
|
||||
pass
|
||||
227
api/services/tools/app_auth_requirement_service.py
Normal file
227
api/services/tools/app_auth_requirement_service.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import Workflow
|
||||
from services.tools.enduser_auth_service import EndUserAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppAuthRequirementService:
|
||||
"""
|
||||
Service for analyzing authentication requirements in apps.
|
||||
Examines workflow DSL to identify which providers need end-user authentication.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_tool_auth_requirements(
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
provider_type: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all authentication requirements for tools in an app.
|
||||
|
||||
:param app_id: The application ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider_type: Optional filter by provider type (e.g., "tool")
|
||||
:return: List of provider requirements
|
||||
"""
|
||||
try:
|
||||
# Get latest published workflow for the app
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
workflow = (
|
||||
session.query(Workflow)
|
||||
.filter_by(app_id=app_id, tenant_id=tenant_id)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
return []
|
||||
|
||||
# Parse workflow graph to find tool nodes
|
||||
graph = workflow.graph_dict
|
||||
if not graph or "nodes" not in graph:
|
||||
return []
|
||||
|
||||
providers = []
|
||||
seen_providers = set()
|
||||
|
||||
# Iterate through workflow nodes
|
||||
for node in graph.get("nodes", []):
|
||||
node_data = node.get("data", {})
|
||||
node_type = node_data.get("type")
|
||||
|
||||
# Check if it's a tool node
|
||||
if node_type == "tool":
|
||||
provider_id = node_data.get("provider_id")
|
||||
provider_name = node_data.get("provider_name")
|
||||
tool_name = node_data.get("tool_name")
|
||||
|
||||
if not provider_id:
|
||||
continue
|
||||
|
||||
# Avoid duplicates
|
||||
if provider_id in seen_providers:
|
||||
continue
|
||||
|
||||
seen_providers.add(provider_id)
|
||||
|
||||
# Get provider controller to check authentication requirements
|
||||
try:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
# Check if provider needs credentials
|
||||
if not provider_controller.need_credentials:
|
||||
continue
|
||||
|
||||
# Get supported credential types
|
||||
supported_types = provider_controller.get_supported_credential_types()
|
||||
|
||||
# Determine required credential type (prefer OAuth if supported)
|
||||
required_type = None
|
||||
if supported_types:
|
||||
# Prefer OAuth2, then API key
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
||||
if CredentialType.OAUTH2 in supported_types:
|
||||
required_type = "oauth2"
|
||||
elif CredentialType.API_KEY in supported_types:
|
||||
required_type = "api-key"
|
||||
else:
|
||||
required_type = supported_types[0].value
|
||||
|
||||
providers.append(
|
||||
{
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name or provider_id,
|
||||
"supported_credential_types": [ct.value for ct in supported_types],
|
||||
"required_credential_type": required_type,
|
||||
"provider_type": "tool",
|
||||
"feature_context": {
|
||||
"node_ids": [node.get("id")],
|
||||
"tool_names": [tool_name] if tool_name else [],
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Error getting provider info for %s: %s", provider_id, e)
|
||||
continue
|
||||
|
||||
# Filter by provider_type if specified
|
||||
if provider_type:
|
||||
providers = [p for p in providers if p.get("provider_type") == provider_type]
|
||||
|
||||
return providers
|
||||
except Exception:
|
||||
logger.exception("Error getting auth requirements for app %s", app_id)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_required_providers(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get list of providers that require end-user authentication for an app.
|
||||
Simplified version of get_tool_auth_requirements for API use.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param app_id: The application ID
|
||||
:return: List of provider information dictionaries
|
||||
"""
|
||||
requirements = AppAuthRequirementService.get_tool_auth_requirements(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Transform to simpler format for API response
|
||||
return [
|
||||
{
|
||||
"provider_id": req["provider_id"],
|
||||
"provider_name": req["provider_name"],
|
||||
"credential_type": req["required_credential_type"],
|
||||
"is_required": True,
|
||||
"oauth_config": None if req["required_credential_type"] != "oauth2" else {
|
||||
"supported_types": req["supported_credential_types"],
|
||||
},
|
||||
}
|
||||
for req in requirements
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_auth_status(
|
||||
app_id: str,
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get overall authentication status for an app and end user.
|
||||
Shows which providers are authenticated and which need authentication.
|
||||
|
||||
:param app_id: The application ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:return: Dict with authentication status for all providers
|
||||
"""
|
||||
try:
|
||||
# Get required providers for this app
|
||||
required_providers = AppAuthRequirementService.get_tool_auth_requirements(app_id, tenant_id)
|
||||
|
||||
# Check authentication status for each provider
|
||||
provider_statuses = []
|
||||
for provider_info in required_providers:
|
||||
provider_id = provider_info["provider_id"]
|
||||
|
||||
# Get credentials for this provider
|
||||
credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id)
|
||||
|
||||
# Build status
|
||||
provider_status = {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_info["provider_name"],
|
||||
"provider_type": provider_info["provider_type"],
|
||||
"authenticated": len(credentials) > 0,
|
||||
"credentials": [
|
||||
{
|
||||
"credential_id": cred.id,
|
||||
"name": cred.name,
|
||||
"type": cred.credential_type.value,
|
||||
"is_default": cred.is_default,
|
||||
"expires_at": cred.expires_at,
|
||||
}
|
||||
for cred in credentials
|
||||
],
|
||||
}
|
||||
|
||||
provider_statuses.append(provider_status)
|
||||
|
||||
return {"providers": provider_statuses}
|
||||
except Exception:
|
||||
logger.exception("Error getting auth status for app %s", app_id)
|
||||
return {"providers": []}
|
||||
|
||||
@staticmethod
|
||||
def is_provider_authenticated(
|
||||
provider_id: str,
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a specific provider is authenticated for an end user.
|
||||
|
||||
:param provider_id: The provider identifier
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:return: True if authenticated, False otherwise
|
||||
"""
|
||||
try:
|
||||
credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id)
|
||||
return len(credentials) > 0
|
||||
except Exception:
|
||||
return False
|
||||
558
api/services/tools/enduser_auth_service.py
Normal file
558
api/services/tools/enduser_auth_service.py
Normal file
@@ -0,0 +1,558 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.api_entities import ToolProviderCredentialApiEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.tools import EndUserAuthenticationProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndUserAuthService:
|
||||
"""
|
||||
Service for managing end-user authentication credentials.
|
||||
Follows similar patterns to BuiltinToolManageService but for end users.
|
||||
"""
|
||||
|
||||
__MAX_CREDENTIALS_PER_PROVIDER__ = 100
|
||||
|
||||
@staticmethod
|
||||
def list_credentials(
|
||||
tenant_id: str, end_user_id: str, provider_id: str
|
||||
) -> list[ToolProviderCredentialApiEntity]:
|
||||
"""
|
||||
List all credentials for a specific provider and end user.
|
||||
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider_id: The provider identifier
|
||||
:return: List of credential entities
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
credentials = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id)
|
||||
.order_by(EndUserAuthenticationProvider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if not credentials:
|
||||
return []
|
||||
|
||||
# Get provider controller to access credential schema
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
result: list[ToolProviderCredentialApiEntity] = []
|
||||
for credential in credentials:
|
||||
try:
|
||||
# Create encrypter for masking credentials
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, credential.credential_type
|
||||
)
|
||||
|
||||
# Decrypt and mask credentials
|
||||
decrypted = encrypter.decrypt(credential.credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(decrypted)
|
||||
|
||||
# Convert to API entity
|
||||
credential_entity = ToolTransformService.convert_enduser_provider_to_credential_entity(
|
||||
provider=credential,
|
||||
credentials=dict(masked_credentials),
|
||||
)
|
||||
result.append(credential_entity)
|
||||
except Exception:
|
||||
logger.exception("Error processing credential %s", credential.id)
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_credential(
|
||||
credential_id: str, end_user_id: str, tenant_id: str, mask_credentials: bool = True
|
||||
) -> ToolProviderCredentialApiEntity | None:
|
||||
"""
|
||||
Get a specific credential by ID.
|
||||
|
||||
:param credential_id: The credential ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param mask_credentials: Whether to mask secret fields
|
||||
:return: Credential entity or None
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
credential = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not credential:
|
||||
return None
|
||||
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id)
|
||||
|
||||
# Create encrypter
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, credential.credential_type
|
||||
)
|
||||
|
||||
# Decrypt credentials
|
||||
decrypted = encrypter.decrypt(credential.credentials)
|
||||
|
||||
# Mask if requested
|
||||
if mask_credentials:
|
||||
decrypted = encrypter.mask_plugin_credentials(decrypted)
|
||||
|
||||
# Convert to API entity
|
||||
return ToolTransformService.convert_enduser_provider_to_credential_entity(
|
||||
provider=credential,
|
||||
credentials=dict(decrypted),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_api_key_credential(
|
||||
tenant_id: str,
|
||||
end_user_id: str,
|
||||
provider_id: str,
|
||||
credentials: dict,
|
||||
name: str | None = None,
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
"""
|
||||
Create a new API key credential for an end user.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param end_user_id: The end user ID
|
||||
:param provider_id: The provider identifier
|
||||
:param credentials: The credential data
|
||||
:param name: Optional custom name
|
||||
:return: Created credential entity
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"enduser_credential_create_lock:{end_user_id}_{provider_id}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"Provider {provider_id} does not need credentials")
|
||||
|
||||
# Check credential count
|
||||
credential_count = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__:
|
||||
raise ValueError(
|
||||
f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) "
|
||||
f"reached for provider {provider_id}"
|
||||
)
|
||||
|
||||
# Validate credentials
|
||||
credential_type = CredentialType.API_KEY
|
||||
if CredentialType.of(credential_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(end_user_id, credentials)
|
||||
|
||||
# Generate name if not provided
|
||||
if name is None or name == "":
|
||||
name = EndUserAuthService._generate_credential_name(
|
||||
session=session,
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# Validate name length
|
||||
if len(name) > 30:
|
||||
raise ValueError("Credential name must be 30 characters or less")
|
||||
|
||||
# Check if name is already used
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
EndUserAuthenticationProvider.end_user_id == end_user_id,
|
||||
EndUserAuthenticationProvider.tenant_id == tenant_id,
|
||||
EndUserAuthenticationProvider.provider == provider_id,
|
||||
EndUserAuthenticationProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"The credential name '{name}' is already used")
|
||||
|
||||
# Create encrypter
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, credential_type
|
||||
)
|
||||
|
||||
# Create credential record
|
||||
db_credential = EndUserAuthenticationProvider(
|
||||
tenant_id=tenant_id,
|
||||
end_user_id=end_user_id,
|
||||
provider=provider_id,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credential_type=credential_type,
|
||||
name=name,
|
||||
expires_at=-1, # API keys don't expire
|
||||
)
|
||||
|
||||
session.add(db_credential)
|
||||
session.commit()
|
||||
session.refresh(db_credential)
|
||||
|
||||
# Return masked credentials
|
||||
masked_credentials = encrypter.mask_plugin_credentials(credentials)
|
||||
return ToolTransformService.convert_enduser_provider_to_credential_entity(
|
||||
provider=db_credential,
|
||||
credentials=dict(masked_credentials),
|
||||
)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception("Error creating API key credential")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def create_oauth_credential(
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
expires_at: int = -1,
|
||||
name: str | None = None,
|
||||
) -> EndUserAuthenticationProvider:
|
||||
"""
|
||||
Create a new OAuth credential for an end user.
|
||||
Used internally by OAuth callback handler.
|
||||
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider identifier
|
||||
:param credentials: The OAuth credentials (access_token, refresh_token, etc.)
|
||||
:param expires_at: Unix timestamp when token expires (-1 for no expiry)
|
||||
:param name: Optional custom name
|
||||
:return: Created credential record
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"enduser_credential_create_lock:{end_user_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
|
||||
# Check credential count
|
||||
credential_count = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider)
|
||||
.count()
|
||||
)
|
||||
|
||||
if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__:
|
||||
raise ValueError(
|
||||
f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) "
|
||||
f"reached for provider {provider}"
|
||||
)
|
||||
|
||||
# Generate name if not provided
|
||||
credential_type = CredentialType.OAUTH2
|
||||
if name is None or name == "":
|
||||
name = EndUserAuthService._generate_credential_name(
|
||||
session=session,
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# Create encrypter
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, credential_type
|
||||
)
|
||||
|
||||
# Create credential record
|
||||
db_credential = EndUserAuthenticationProvider(
|
||||
tenant_id=tenant_id,
|
||||
end_user_id=end_user_id,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credential_type=credential_type,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
session.add(db_credential)
|
||||
session.commit()
|
||||
session.refresh(db_credential)
|
||||
|
||||
return db_credential
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception("Error creating OAuth credential")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def update_credential(
|
||||
credential_id: str,
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
name: str | None = None,
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
"""
|
||||
Update an existing credential (API key only).
|
||||
|
||||
:param credential_id: The credential ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param credentials: Updated credentials (optional)
|
||||
:param name: Updated name (optional)
|
||||
:return: Updated credential entity
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
# Get credential
|
||||
db_credential = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_credential:
|
||||
raise ValueError(f"Credential {credential_id} not found")
|
||||
|
||||
# Only API key credentials can be updated
|
||||
if not CredentialType.of(db_credential.credential_type).is_editable():
|
||||
raise ValueError("Only API key credentials can be updated via this endpoint")
|
||||
|
||||
# At least one field must be provided
|
||||
if credentials is None and name is None:
|
||||
raise ValueError("At least one field (credentials or name) must be provided")
|
||||
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(db_credential.provider, tenant_id)
|
||||
|
||||
# Create encrypter
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, db_credential.credential_type
|
||||
)
|
||||
|
||||
# Update credentials if provided
|
||||
if credentials:
|
||||
# Decrypt original credentials
|
||||
original_credentials = encrypter.decrypt(db_credential.credentials)
|
||||
|
||||
# Merge with new credentials, keeping hidden values
|
||||
new_credentials: dict = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
# Validate new credentials
|
||||
if CredentialType.of(db_credential.credential_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(end_user_id, new_credentials)
|
||||
|
||||
# Encrypt and save
|
||||
db_credential.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
||||
|
||||
# Update name if provided
|
||||
if name and name != db_credential.name:
|
||||
# Validate name length
|
||||
if len(name) > 30:
|
||||
raise ValueError("Credential name must be 30 characters or less")
|
||||
|
||||
# Check if name is already used
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
EndUserAuthenticationProvider.end_user_id == end_user_id,
|
||||
EndUserAuthenticationProvider.tenant_id == tenant_id,
|
||||
EndUserAuthenticationProvider.provider == db_credential.provider,
|
||||
EndUserAuthenticationProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"The credential name '{name}' is already used")
|
||||
|
||||
db_credential.name = name
|
||||
|
||||
session.commit()
|
||||
session.refresh(db_credential)
|
||||
|
||||
# Return masked credentials
|
||||
decrypted = encrypter.decrypt(db_credential.credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(decrypted)
|
||||
|
||||
return ToolTransformService.convert_enduser_provider_to_credential_entity(
|
||||
provider=db_credential,
|
||||
credentials=dict(masked_credentials),
|
||||
)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception("Error updating credential")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def delete_credential(
|
||||
tenant_id: str, end_user_id: str, provider_id: str, credential_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a credential.
|
||||
|
||||
:param credential_id: The credential ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:return: True if deleted successfully
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
credential = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise ValueError(f"Credential {credential_id} not found")
|
||||
|
||||
session.delete(credential)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def refresh_oauth_token(
|
||||
credential_id: str, end_user_id: str, tenant_id: str, refreshed_credentials: dict, expires_at: int
|
||||
) -> EndUserAuthenticationProvider:
|
||||
"""
|
||||
Update OAuth credentials after token refresh.
|
||||
|
||||
:param credential_id: The credential ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param refreshed_credentials: New credentials from OAuth refresh
|
||||
:param expires_at: New expiration timestamp
|
||||
:return: Updated credential record
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
credential = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise ValueError(f"Credential {credential_id} not found")
|
||||
|
||||
if credential.credential_type != CredentialType.OAUTH2:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id)
|
||||
|
||||
# Create encrypter
|
||||
encrypter, _ = EndUserAuthService._create_encrypter(
|
||||
tenant_id, provider_controller, credential.credential_type
|
||||
)
|
||||
|
||||
# Encrypt and save new credentials
|
||||
credential.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
|
||||
credential.expires_at = expires_at
|
||||
|
||||
session.commit()
|
||||
session.refresh(credential)
|
||||
|
||||
return credential
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception("Error refreshing OAuth token")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def get_default_credential(
|
||||
end_user_id: str, tenant_id: str, provider: str
|
||||
) -> EndUserAuthenticationProvider | None:
|
||||
"""
|
||||
Get the default (oldest) credential for a provider.
|
||||
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider identifier
|
||||
:return: Credential record or None
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
return (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider)
|
||||
.order_by(EndUserAuthenticationProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_credential_name(
|
||||
session: Session,
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credential_type: CredentialType,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique credential name.
|
||||
|
||||
:param session: Database session
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider identifier
|
||||
:param credential_type: The credential type
|
||||
:return: Generated name (e.g., "API KEY 1", "AUTH 1")
|
||||
"""
|
||||
existing_credentials = (
|
||||
session.query(EndUserAuthenticationProvider)
|
||||
.filter_by(
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
.order_by(EndUserAuthenticationProvider.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return generate_incremental_name(
|
||||
[credential.name for credential in existing_credentials],
|
||||
f"{credential_type.get_name()}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_encrypter(
|
||||
tenant_id: str, provider_controller, credential_type: CredentialType | str
|
||||
) -> tuple:
|
||||
"""
|
||||
Create an encrypter for credential encryption/decryption.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider_controller: The provider controller
|
||||
:param credential_type: The credential type
|
||||
:return: Tuple of (encrypter, cache)
|
||||
"""
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(credential_type)
|
||||
],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
264
api/services/tools/enduser_oauth_service.py
Normal file
264
api/services/tools/enduser_oauth_service.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from werkzeug import Request
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.enduser_auth_service import EndUserAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndUserOAuthService:
|
||||
"""
|
||||
Service for managing end-user OAuth authentication flows.
|
||||
Reuses existing OAuthProxyService and OAuthHandler infrastructure.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_authorization_url(
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
provider: str,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Initiate OAuth authorization flow for an end user.
|
||||
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:param app_id: The application ID
|
||||
:param provider: The provider identifier
|
||||
:return: Dict with authorization_url
|
||||
"""
|
||||
try:
|
||||
# Get OAuth client configuration (reuse workspace-level logic)
|
||||
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
if not oauth_client:
|
||||
raise ValueError(f"OAuth client not configured for provider {provider}")
|
||||
|
||||
# Get provider controller
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
# Create OAuth context with end-user specific data
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=end_user_id, # Using end_user_id as user_id
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider_id.plugin_id,
|
||||
provider=tool_provider_id.provider_name,
|
||||
extra_data={
|
||||
"app_id": app_id,
|
||||
"provider_type": "tool", # For now, only tools support end-user auth
|
||||
},
|
||||
)
|
||||
|
||||
# Use the same redirect URI as workspace OAuth to reuse the same OAuth client
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{tool_provider_id}/tool/callback"
|
||||
|
||||
# Get authorization URL from OAuth handler
|
||||
oauth_handler = OAuthHandler()
|
||||
response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=end_user_id,
|
||||
plugin_id=tool_provider_id.plugin_id,
|
||||
provider=tool_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client,
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": response.authorization_url,
|
||||
"context_id": context_id, # Return for setting cookie
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Error getting authorization URL for end user")
|
||||
raise ValueError(f"Failed to initiate OAuth flow: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def handle_oauth_callback(
|
||||
context_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Handle OAuth callback and create credential.
|
||||
|
||||
:param context_id: The OAuth context ID from cookie
|
||||
:param request: The callback request with authorization code
|
||||
:return: Dict with credential information
|
||||
"""
|
||||
try:
|
||||
# Validate and retrieve context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
|
||||
# Extract context data
|
||||
end_user_id = context.get("user_id") # user_id is actually end_user_id
|
||||
tenant_id = context.get("tenant_id")
|
||||
app_id = context.get("app_id")
|
||||
plugin_id = context.get("plugin_id")
|
||||
provider = context.get("provider")
|
||||
|
||||
if not all([end_user_id, tenant_id, app_id, plugin_id, provider]):
|
||||
raise ValueError("Invalid OAuth context: missing required fields")
|
||||
|
||||
# Reconstruct full provider ID
|
||||
full_provider = f"{plugin_id}/{provider}" if plugin_id != "langgenius" else provider
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, full_provider)
|
||||
if not oauth_client:
|
||||
raise ValueError(f"OAuth client not configured for provider {full_provider}")
|
||||
|
||||
# Use the same redirect URI as workspace OAuth (must match authorization request)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{full_provider}/tool/callback"
|
||||
|
||||
# Exchange authorization code for credentials
|
||||
oauth_handler = OAuthHandler()
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=end_user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client,
|
||||
request=request,
|
||||
)
|
||||
|
||||
# Calculate expiration timestamp
|
||||
expires_at = -1
|
||||
if credentials_response.expires_in and credentials_response.expires_in > 0:
|
||||
import time
|
||||
|
||||
expires_at = int(time.time()) + credentials_response.expires_in
|
||||
|
||||
# Create credential in database
|
||||
credential = EndUserAuthService.create_oauth_credential(
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=full_provider,
|
||||
credentials=credentials_response.credentials,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"credential_id": credential.id,
|
||||
"provider": full_provider,
|
||||
"app_id": app_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("Error handling OAuth callback for end user")
|
||||
raise ValueError(f"Failed to complete OAuth flow: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def refresh_oauth_token(
|
||||
credential_id: str,
|
||||
end_user_id: str,
|
||||
tenant_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Refresh an expired OAuth token.
|
||||
|
||||
:param credential_id: The credential ID
|
||||
:param end_user_id: The end user ID
|
||||
:param tenant_id: The tenant ID
|
||||
:return: Dict with refresh status
|
||||
"""
|
||||
try:
|
||||
# Get existing credential
|
||||
credential = EndUserAuthService.get_credential(
|
||||
credential_id=credential_id,
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
mask_credentials=False, # Need full credentials for refresh
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise ValueError(f"Credential {credential_id} not found")
|
||||
|
||||
if credential.credential_type != CredentialType.OAUTH2:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, credential.provider)
|
||||
if not oauth_client:
|
||||
raise ValueError(f"OAuth client not configured for provider {credential.provider}")
|
||||
|
||||
# Get provider info
|
||||
tool_provider_id = ToolProviderID(credential.provider)
|
||||
# Use the same redirect URI as workspace OAuth to reuse the same OAuth client
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{credential.provider}/tool/callback"
|
||||
|
||||
# Refresh credentials via OAuth handler
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_response = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=end_user_id,
|
||||
plugin_id=tool_provider_id.plugin_id,
|
||||
provider=tool_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client,
|
||||
credentials=credential.credentials,
|
||||
)
|
||||
|
||||
# Calculate new expiration timestamp
|
||||
expires_at = -1
|
||||
if refreshed_response.expires_in and refreshed_response.expires_in > 0:
|
||||
import time
|
||||
|
||||
expires_at = int(time.time()) + refreshed_response.expires_in
|
||||
|
||||
# Update credential in database
|
||||
updated_credential = EndUserAuthService.refresh_oauth_token(
|
||||
credential_id=credential_id,
|
||||
end_user_id=end_user_id,
|
||||
tenant_id=tenant_id,
|
||||
refreshed_credentials=refreshed_response.credentials,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"credential_id": updated_credential.id,
|
||||
"expires_at": expires_at,
|
||||
"refreshed_at": int(updated_credential.updated_at.timestamp()),
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Error refreshing OAuth token for end user")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Failed to refresh token",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_oauth_client_info(tenant_id: str, provider: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get OAuth client information for a provider.
|
||||
Used to check if OAuth is available and configured.
|
||||
|
||||
:param tenant_id: The tenant ID
|
||||
:param provider: The provider identifier
|
||||
:return: Dict with OAuth client info
|
||||
"""
|
||||
try:
|
||||
# Check if OAuth client exists (either system or custom)
|
||||
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
|
||||
return {
|
||||
"configured": oauth_client is not None,
|
||||
"system_configured": BuiltinToolManageService.is_oauth_system_client_exists(provider),
|
||||
"custom_configured": BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Error getting OAuth client info")
|
||||
return {
|
||||
"configured": False,
|
||||
"system_configured": False,
|
||||
"custom_configured": False,
|
||||
}
|
||||
@@ -426,6 +426,27 @@ class ToolTransformService:
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_enduser_provider_to_credential_entity(
|
||||
provider, credentials: dict
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
"""
|
||||
Convert EndUserAuthenticationProvider to ToolProviderCredentialApiEntity.
|
||||
|
||||
:param provider: EndUserAuthenticationProvider instance
|
||||
:param credentials: Decrypted/masked credentials dict
|
||||
:return: ToolProviderCredentialApiEntity
|
||||
"""
|
||||
return ToolProviderCredentialApiEntity(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider=provider.provider,
|
||||
credential_type=CredentialType.of(provider.credential_type),
|
||||
is_default=False, # End-user credentials don't have default flag (use oldest)
|
||||
credentials=credentials,
|
||||
expires_at=provider.expires_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||
"""
|
||||
|
||||
486
api/tests/unit_tests/core/tools/test_enduser_tool_auth.py
Normal file
486
api/tests/unit_tests/core/tools/test_enduser_tool_auth.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Unit tests for end-user tool authentication.
|
||||
|
||||
Tests the integration of end-user authentication with tool runtime resolution.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeFrom, ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session."""
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
yield mock_db.session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_controller():
|
||||
"""Mock builtin provider controller."""
|
||||
controller = MagicMock()
|
||||
controller.need_credentials = True
|
||||
controller.get_tool = MagicMock(return_value=MagicMock())
|
||||
controller.get_credentials_schema_by_type = MagicMock(return_value=[])
|
||||
return controller
|
||||
|
||||
|
||||
class TestEndUserToolAuthentication:
|
||||
"""Test suite for end-user tool authentication."""
|
||||
|
||||
def test_end_user_auth_requires_end_user_id(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that END_USER auth_type requires end_user_id parameter.
|
||||
|
||||
When auth_type is END_USER but end_user_id is None, should raise error.
|
||||
"""
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="end_user_id is required"):
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="test_provider",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id=None, # Missing!
|
||||
)
|
||||
|
||||
def test_end_user_auth_missing_credentials(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that error is raised when end-user has no credentials for provider.
|
||||
|
||||
When auth_type is END_USER but no credentials exist, should raise error.
|
||||
"""
|
||||
# Mock no credentials found
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller):
|
||||
with pytest.raises(ToolProviderNotFoundError, match="No end-user credentials found"):
|
||||
ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="test_provider",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id="end_user_123",
|
||||
)
|
||||
|
||||
def test_end_user_auth_with_credentials(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test successful end-user credential resolution.
|
||||
|
||||
When auth_type is END_USER and credentials exist, should return tool runtime.
|
||||
"""
|
||||
# Mock end-user provider
|
||||
mock_enduser_provider = MagicMock()
|
||||
mock_enduser_provider.id = "cred_123"
|
||||
mock_enduser_provider.credential_type = "api-key"
|
||||
mock_enduser_provider.credentials = '{"api_key": "encrypted"}'
|
||||
mock_enduser_provider.expires_at = -1 # No expiry
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_enduser_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"api_key": "decrypted_key"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="test_provider",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id="end_user_123",
|
||||
)
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
# Verify encrypter was called with decrypted credentials
|
||||
mock_encrypter.decrypt.assert_called_once()
|
||||
|
||||
def test_workspace_auth_backward_compatibility(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that workspace authentication still works (backward compatibility).
|
||||
|
||||
When auth_type is WORKSPACE (default), should use workspace credentials.
|
||||
"""
|
||||
# Mock workspace provider
|
||||
mock_workspace_provider = MagicMock()
|
||||
mock_workspace_provider.id = "workspace_cred_123"
|
||||
mock_workspace_provider.credential_type = "api-key"
|
||||
mock_workspace_provider.credentials = '{"api_key": "workspace_encrypted"}'
|
||||
mock_workspace_provider.expires_at = -1
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_workspace_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"api_key": "workspace_decrypted"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
patch("core.helper.credential_utils.check_credential_policy_compliance"),
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="test_provider",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.WORKSPACE, # Workspace auth
|
||||
end_user_id=None, # Not needed for workspace auth
|
||||
)
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
|
||||
def test_workflow_tool_runtime_passes_end_user_id(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that get_workflow_tool_runtime correctly passes end_user_id to get_tool_runtime.
|
||||
"""
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
# Create a mock ToolEntity with END_USER auth_type
|
||||
workflow_tool = MagicMock(spec=ToolEntity)
|
||||
workflow_tool.provider_type = ToolProviderType.BUILT_IN
|
||||
workflow_tool.provider_id = "test_provider"
|
||||
workflow_tool.tool_name = "test_tool"
|
||||
workflow_tool.credential_id = None
|
||||
workflow_tool.auth_type = ToolAuthType.END_USER
|
||||
workflow_tool.tool_configurations = {}
|
||||
|
||||
# Mock end-user credentials
|
||||
mock_enduser_provider = MagicMock()
|
||||
mock_enduser_provider.id = "cred_123"
|
||||
mock_enduser_provider.credential_type = "api-key"
|
||||
mock_enduser_provider.credentials = '{"api_key": "encrypted"}'
|
||||
mock_enduser_provider.expires_at = -1
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_enduser_provider
|
||||
)
|
||||
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"api_key": "decrypted"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
):
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
node_id="test_node",
|
||||
workflow_tool=workflow_tool,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
variable_pool=None,
|
||||
end_user_id="end_user_123", # Pass end_user_id
|
||||
)
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
|
||||
|
||||
class TestOAuthTokenRefresh:
|
||||
"""Test suite for OAuth token refresh functionality."""
|
||||
|
||||
def test_enduser_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that end-user OAuth tokens are automatically refreshed when expired.
|
||||
|
||||
When an OAuth token is expired (expires_at < current_time + 60s buffer),
|
||||
the system should automatically refresh it before using.
|
||||
"""
|
||||
# Mock end-user provider with expired OAuth token
|
||||
mock_enduser_provider = MagicMock()
|
||||
mock_enduser_provider.id = "cred_123"
|
||||
mock_enduser_provider.credential_type = "oauth2"
|
||||
mock_enduser_provider.credentials = '{"access_token": "old_token", "refresh_token": "refresh"}'
|
||||
# Set expiry to past (token expired)
|
||||
mock_enduser_provider.expires_at = int(time.time()) - 100
|
||||
mock_enduser_provider.encrypted_credentials = '{"access_token": "old_token", "refresh_token": "refresh"}'
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_enduser_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"access_token": "old_token", "refresh_token": "refresh"}
|
||||
mock_encrypter.encrypt.return_value = {"access_token": "new_token", "refresh_token": "refresh"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
# Mock OAuth refresh response
|
||||
mock_refreshed_credentials = MagicMock()
|
||||
mock_refreshed_credentials.credentials = {"access_token": "new_token", "refresh_token": "refresh"}
|
||||
mock_refreshed_credentials.expires_at = int(time.time()) + 3600 # New expiry 1 hour from now
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
patch.object(
|
||||
ToolManager,
|
||||
"_refresh_oauth_credentials",
|
||||
return_value=(
|
||||
mock_refreshed_credentials.credentials,
|
||||
mock_refreshed_credentials.expires_at,
|
||||
),
|
||||
) as mock_refresh,
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="github",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id="end_user_123",
|
||||
)
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once_with(
|
||||
tenant_id="test_tenant",
|
||||
provider_id="github",
|
||||
user_id="end_user_123",
|
||||
decrypted_credentials={"access_token": "old_token", "refresh_token": "refresh"},
|
||||
)
|
||||
|
||||
# Verify provider was updated with new credentials
|
||||
assert mock_enduser_provider.expires_at == mock_refreshed_credentials.expires_at
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
|
||||
def test_enduser_oauth_token_not_refreshed_when_valid(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that valid OAuth tokens are NOT refreshed.
|
||||
|
||||
When an OAuth token is still valid (expires_at > current_time + 60s buffer),
|
||||
the system should use it without refreshing.
|
||||
"""
|
||||
# Mock end-user provider with valid OAuth token
|
||||
mock_enduser_provider = MagicMock()
|
||||
mock_enduser_provider.id = "cred_123"
|
||||
mock_enduser_provider.credential_type = "oauth2"
|
||||
mock_enduser_provider.credentials = '{"access_token": "valid_token", "refresh_token": "refresh"}'
|
||||
# Set expiry to future (token still valid with buffer)
|
||||
mock_enduser_provider.expires_at = int(time.time()) + 3600 # Valid for 1 hour
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_enduser_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"access_token": "valid_token", "refresh_token": "refresh"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh,
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="github",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id="end_user_123",
|
||||
)
|
||||
|
||||
# Verify refresh was NOT called (token still valid)
|
||||
mock_refresh.assert_not_called()
|
||||
|
||||
# Verify tool runtime was created with original credentials
|
||||
assert tool_runtime is not None
|
||||
|
||||
def test_workspace_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that workspace OAuth tokens are automatically refreshed when expired.
|
||||
|
||||
This ensures the refactored _refresh_oauth_credentials helper works
|
||||
for both end-user and workspace authentication flows.
|
||||
"""
|
||||
# Mock workspace provider with expired OAuth token
|
||||
mock_workspace_provider = MagicMock()
|
||||
mock_workspace_provider.id = "workspace_cred_123"
|
||||
mock_workspace_provider.user_id = "workspace_user_456"
|
||||
mock_workspace_provider.credential_type = "oauth2"
|
||||
mock_workspace_provider.credentials = '{"access_token": "old_workspace_token"}'
|
||||
mock_workspace_provider.expires_at = int(time.time()) - 100 # Expired
|
||||
mock_workspace_provider.encrypted_credentials = '{"access_token": "old_workspace_token"}'
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_workspace_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"access_token": "old_workspace_token"}
|
||||
mock_encrypter.encrypt.return_value = {"access_token": "new_workspace_token"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
# Mock OAuth refresh response
|
||||
refreshed_creds = {"access_token": "new_workspace_token"}
|
||||
new_expires_at = int(time.time()) + 3600
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
patch("core.helper.credential_utils.check_credential_policy_compliance"),
|
||||
patch.object(
|
||||
ToolManager, "_refresh_oauth_credentials", return_value=(refreshed_creds, new_expires_at)
|
||||
) as mock_refresh,
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="github",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.WORKSPACE,
|
||||
)
|
||||
|
||||
# Verify refresh was called with workspace user_id
|
||||
mock_refresh.assert_called_once_with(
|
||||
tenant_id="test_tenant",
|
||||
provider_id="github",
|
||||
user_id="workspace_user_456",
|
||||
decrypted_credentials={"access_token": "old_workspace_token"},
|
||||
)
|
||||
|
||||
# Verify provider was updated
|
||||
assert mock_workspace_provider.expires_at == new_expires_at
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
|
||||
def test_oauth_token_no_refresh_for_non_oauth_credentials(self, mock_db_session, mock_provider_controller):
|
||||
"""
|
||||
Test that non-OAuth credentials (API keys) are never refreshed.
|
||||
|
||||
API keys with expires_at = -1 should not trigger refresh logic.
|
||||
"""
|
||||
# Mock end-user provider with API key (no expiry)
|
||||
mock_enduser_provider = MagicMock()
|
||||
mock_enduser_provider.id = "cred_123"
|
||||
mock_enduser_provider.credential_type = "api-key"
|
||||
mock_enduser_provider.credentials = '{"api_key": "sk-1234567890"}'
|
||||
mock_enduser_provider.expires_at = -1 # API keys don't expire
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
mock_enduser_provider
|
||||
)
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"api_key": "sk-1234567890"}
|
||||
mock_cache = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
|
||||
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
|
||||
patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh,
|
||||
):
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id="openai",
|
||||
tool_name="test_tool",
|
||||
tenant_id="test_tenant",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
auth_type=ToolAuthType.END_USER,
|
||||
end_user_id="end_user_123",
|
||||
)
|
||||
|
||||
# Verify refresh was NOT called (API key doesn't need refresh)
|
||||
mock_refresh.assert_not_called()
|
||||
|
||||
# Verify tool runtime was created
|
||||
assert tool_runtime is not None
|
||||
|
||||
def test_refresh_oauth_credentials_helper_method(self):
|
||||
"""
|
||||
Test the _refresh_oauth_credentials helper method directly.
|
||||
|
||||
This tests the centralized OAuth refresh logic that is used by both
|
||||
end-user and workspace authentication flows.
|
||||
"""
|
||||
# Mock dependencies
|
||||
mock_oauth_handler = MagicMock()
|
||||
mock_refreshed = MagicMock()
|
||||
mock_refreshed.credentials = {"access_token": "new_token", "refresh_token": "new_refresh"}
|
||||
mock_refreshed.expires_at = int(time.time()) + 7200
|
||||
mock_oauth_handler.refresh_credentials.return_value = mock_refreshed
|
||||
|
||||
with (
|
||||
# Patch OAuthHandler where it's imported (inside the method)
|
||||
patch("core.plugin.impl.oauth.OAuthHandler", return_value=mock_oauth_handler),
|
||||
patch("core.tools.tool_manager.ToolProviderID") as mock_provider_id,
|
||||
patch(
|
||||
"services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client",
|
||||
return_value={"client_id": "test"},
|
||||
),
|
||||
patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "http://localhost:5001"),
|
||||
):
|
||||
# Setup provider ID mock
|
||||
mock_provider_id.return_value.provider_name = "github"
|
||||
mock_provider_id.return_value.plugin_id = "builtin"
|
||||
|
||||
# Call the helper method
|
||||
credentials, expires_at = ToolManager._refresh_oauth_credentials(
|
||||
tenant_id="test_tenant",
|
||||
provider_id="langgenius/github/github",
|
||||
user_id="user_123",
|
||||
decrypted_credentials={"access_token": "old_token", "refresh_token": "old_refresh"},
|
||||
)
|
||||
|
||||
# Verify OAuth handler was called correctly
|
||||
mock_oauth_handler.refresh_credentials.assert_called_once_with(
|
||||
tenant_id="test_tenant",
|
||||
user_id="user_123",
|
||||
plugin_id="builtin",
|
||||
provider="github",
|
||||
redirect_uri="http://localhost:5001/console/api/oauth/plugin/langgenius/github/github/tool/callback",
|
||||
system_credentials={"client_id": "test"},
|
||||
credentials={"access_token": "old_token", "refresh_token": "old_refresh"},
|
||||
)
|
||||
|
||||
# Verify returned values
|
||||
assert credentials == {"access_token": "new_token", "refresh_token": "new_refresh"}
|
||||
assert expires_at == mock_refreshed.expires_at
|
||||
@@ -214,6 +214,7 @@ const ToolSelector: FC<Props> = ({
|
||||
onSelect({
|
||||
...value,
|
||||
use_end_user_credentials: enabled,
|
||||
auth_type: enabled ? 'end_user' : 'workspace',
|
||||
} as any)
|
||||
}
|
||||
const handleEndUserCredentialTypeChange = (type: string) => {
|
||||
|
||||
@@ -24,4 +24,5 @@ export type ToolNodeType = CommonNodeType & {
|
||||
provider_icon?: Collection['icon']
|
||||
provider_icon_dark?: Collection['icon_dark']
|
||||
plugin_unique_identifier?: string
|
||||
auth_type?: 'workspace' | 'end_user'
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user