feat: add end user oauth apis (#29683)

This commit is contained in:
Charles Yao
2025-12-19 15:30:16 +08:00
committed by GitHub
parent d98e70922d
commit 06456afd49
14 changed files with 1970 additions and 27 deletions

View File

@@ -54,6 +54,7 @@ from .app import (
completion,
conversation,
conversation_variables,
enduser_auth,
generator,
mcp_server,
message,
@@ -162,6 +163,7 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"enduser_auth",
"extension",
"external",
"feature",

View File

@@ -0,0 +1,230 @@
"""
End-user authentication API controllers.
Provides API endpoints for managing end-user credentials for tool authentication.
"""
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models import AppMode
from services.tools.app_auth_requirement_service import AppAuthRequirementService
from services.tools.enduser_auth_service import EndUserAuthService
from services.tools.enduser_oauth_service import EndUserOAuthService
@console_ns.route("/apps/<uuid:app_id>/auth/providers")
class AppAuthProvidersApi(Resource):
"""
Get list of authentication providers required for an app.
Returns providers that require end-user authentication based on app configuration.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def get(self, app_model):
"""Get authentication providers required for the app."""
_, tenant_id = current_account_with_tenant()
providers = AppAuthRequirementService.get_required_providers(
tenant_id=tenant_id,
app_id=str(app_model.id),
)
return jsonable_encoder(providers)
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials")
class AppAuthProviderCredentialsApi(Resource):
"""
Manage end-user credentials for a specific provider.
Allows listing, creating, and deleting end-user credentials.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def get(self, app_model, provider_id: str):
"""List end-user's credentials for this provider."""
user, tenant_id = current_account_with_tenant()
# For console API, use the current account user as end_user_id
# In production, this would be the actual end-user ID from the chat/completion request
end_user_id = str(user.id)
credentials = EndUserAuthService.list_credentials(
tenant_id=tenant_id,
end_user_id=end_user_id,
provider_id=provider_id,
)
return jsonable_encoder(credentials)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def post(self, app_model, provider_id: str):
"""Create a new credential (API key only)."""
user, tenant_id = current_account_with_tenant()
end_user_id = str(user.id)
payload = request.get_json()
if not payload:
raise BadRequest("Request body is required")
credential_type = payload.get("credential_type")
credentials = payload.get("credentials")
if not credential_type or not credentials:
raise BadRequest("credential_type and credentials are required")
if credential_type != "api-key":
raise BadRequest(
"Only 'api-key' credential type can be created via this endpoint. "
"Use OAuth flow for OAuth credentials."
)
credential = EndUserAuthService.create_api_key_credential(
tenant_id=tenant_id,
end_user_id=end_user_id,
provider_id=provider_id,
credentials=credentials,
)
return jsonable_encoder(credential)
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials/<string:credential_id>")
class AppAuthProviderCredentialApi(Resource):
"""
Manage a specific end-user credential.
Allows getting, updating, or deleting a credential.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def delete(self, app_model, provider_id: str, credential_id: str):
"""Delete a credential."""
user, tenant_id = current_account_with_tenant()
end_user_id = str(user.id)
EndUserAuthService.delete_credential(
tenant_id=tenant_id,
end_user_id=end_user_id,
provider_id=provider_id,
credential_id=credential_id,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/auth/oauth/<path:provider_id>/authorization-url")
class AppAuthOAuthAuthorizationUrlApi(Resource):
"""
Get OAuth authorization URL for end-user authentication.
Returns the URL where the user should be redirected to authenticate with the provider.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def get(self, app_model, provider_id: str):
"""Get OAuth authorization URL."""
user, tenant_id = current_account_with_tenant()
end_user_id = str(user.id)
result = EndUserOAuthService.get_authorization_url(
end_user_id=end_user_id,
tenant_id=tenant_id,
app_id=str(app_model.id),
provider=provider_id,
)
# Set OAuth context cookie for callback
response = jsonable_encoder({
"authorization_url": result["authorization_url"],
})
# Store context_id in response for frontend to set as cookie
response["context_id"] = result["context_id"]
return response
@console_ns.route("/apps/<uuid:app_id>/auth/oauth/<path:provider_id>/callback")
class AppAuthOAuthCallbackApi(Resource):
"""
Handle OAuth callback for end-user authentication.
This endpoint is called by the OAuth provider after user authorization.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def get(self, app_model, provider_id: str):
"""Handle OAuth callback and store credentials."""
# Get OAuth context ID from cookie
context_id = request.cookies.get("oauth_context_id")
if not context_id:
raise BadRequest("Missing OAuth context")
# Get OAuth error if any
error = request.args.get("error")
error_description = request.args.get("error_description")
if error:
raise BadRequest(f"OAuth error: {error} - {error_description}")
# Handle callback and create credential
result = EndUserOAuthService.handle_oauth_callback(
context_id=context_id,
request=request,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/auth/providers/<path:provider_id>/credentials/<string:credential_id>/refresh")
class AppAuthProviderRefreshApi(Resource):
"""
Manually refresh OAuth token for a credential.
This endpoint allows refreshing an expired OAuth token.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW])
def post(self, app_model, provider_id: str, credential_id: str):
"""Refresh OAuth token."""
user, tenant_id = current_account_with_tenant()
end_user_id = str(user.id)
result = EndUserOAuthService.refresh_oauth_token(
credential_id=credential_id,
end_user_id=end_user_id,
tenant_id=tenant_id,
)
return jsonable_encoder(result)

View File

@@ -3,7 +3,7 @@ from typing import Any, Union
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeMessage, ToolProviderType
class AgentToolEntity(BaseModel):
@@ -17,6 +17,7 @@ class AgentToolEntity(BaseModel):
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
credential_id: str | None = None
auth_type: ToolAuthType = ToolAuthType.WORKSPACE
class AgentPromptEntity(BaseModel):

View File

@@ -124,6 +124,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
expires_at: int = Field(default=-1, description="Unix timestamp when credential expires (-1 for no expiry)")
class ToolProviderCredentialInfoApiEntity(BaseModel):

View File

@@ -49,6 +49,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
@@ -58,7 +59,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, EndUserAuthenticationProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@@ -78,6 +79,54 @@ class ToolManager:
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def _refresh_oauth_credentials(
cls,
tenant_id: str,
provider_id: str,
user_id: str,
decrypted_credentials: Mapping[str, Any],
) -> tuple[dict[str, Any], int]:
"""
Refresh OAuth credentials for a provider.
This is a helper method to centralize the OAuth token refresh logic
used by both end-user and workspace authentication flows.
:param tenant_id: the tenant id
:param provider_id: the provider id
:param user_id: the user id (end_user_id or workspace user_id)
:param decrypted_credentials: the current decrypted credentials
:return: tuple of (refreshed credentials dict, expires_at timestamp)
"""
from core.plugin.impl.oauth import OAuthHandler
# Local import to avoid circular dependency at module level
# This import is necessary but creates a cycle: tool_manager -> builtin_tools_manage_service -> tool_manager
# TODO: Break the circular dependency by refactoring service layer
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# Parse provider ID and build OAuth configuration
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
# Refresh the credentials using OAuth handler
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
return refreshed_credentials.credentials, refreshed_credentials.expires_at
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
@@ -165,6 +214,8 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
auth_type: ToolAuthType = ToolAuthType.WORKSPACE,
end_user_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -176,6 +227,8 @@ class ToolManager:
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:param auth_type: the authentication type (workspace or end_user)
:param end_user_id: the end user id (required when auth_type is END_USER)
:return: the tool
"""
@@ -200,6 +253,75 @@ class ToolManager:
)
),
)
# Handle end-user authentication
if auth_type == ToolAuthType.END_USER:
if not end_user_id:
raise ToolProviderNotFoundError("end_user_id is required for END_USER auth_type")
# Query end-user credentials
enduser_provider = (
db.session.query(EndUserAuthenticationProvider)
.where(
EndUserAuthenticationProvider.tenant_id == tenant_id,
EndUserAuthenticationProvider.end_user_id == end_user_id,
EndUserAuthenticationProvider.provider == provider_id,
)
.order_by(EndUserAuthenticationProvider.created_at.asc())
.first()
)
if enduser_provider is None:
raise ToolProviderNotFoundError(
f"No end-user credentials found for provider {provider_id}"
)
# Decrypt end-user credentials
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(enduser_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=enduser_provider.id
),
)
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(enduser_provider.credentials)
# Handle OAuth token refresh for end-users if expired
if enduser_provider.expires_at != -1 and (enduser_provider.expires_at - 60) < int(time.time()):
# Refresh credentials using the centralized helper method
refreshed_credentials, expires_at = cls._refresh_oauth_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
user_id=end_user_id,
decrypted_credentials=decrypted_credentials,
)
# Update the provider with refreshed credentials
enduser_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
enduser_provider.expires_at = expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials
cache.delete()
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(enduser_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
# Handle workspace authentication (existing logic)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
@@ -270,34 +392,19 @@ class ToolManager:
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
# Refresh credentials using the centralized helper method
refreshed_credentials, expires_at = cls._refresh_oauth_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
decrypted_credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
# Update the provider with refreshed credentials
builtin_provider.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
builtin_provider.expires_at = expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
decrypted_credentials = refreshed_credentials
cache.delete()
return cast(
@@ -368,6 +475,7 @@ class ToolManager:
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
end_user_id: str | None = None,
) -> Tool:
"""
get the agent tool runtime
@@ -380,6 +488,8 @@ class ToolManager:
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
auth_type=agent_tool.auth_type,
end_user_id=end_user_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -410,6 +520,7 @@ class ToolManager:
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
end_user_id: str | None = None,
) -> Tool:
"""
get the workflow tool runtime
@@ -423,6 +534,8 @@ class ToolManager:
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
auth_type=workflow_tool.auth_type,
end_user_id=end_user_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()

View File

@@ -66,6 +66,7 @@ class ToolNode(Node[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
from models.enums import UserFrom
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
@@ -74,8 +75,20 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: VariablePool | None = None
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
# Determine end_user_id based on user_from
end_user_id = None
if self.user_from == UserFrom.END_USER:
end_user_id = self.user_id
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
self.tenant_id,
self.app_id,
self._node_id,
self.node_data,
self.invoke_from,
variable_pool,
end_user_id,
)
except ToolNodeError as e:
yield StreamCompletedEvent(

View File

@@ -0,0 +1,25 @@
"""merge_heads
Revision ID: 3134f4e0620d
Revises: d57accd375ae, a7b4e8f2c9d1
Create Date: 2025-12-14 14:18:19.393720
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3134f4e0620d'
down_revision = ('d57accd375ae', 'a7b4e8f2c9d1')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

View File

@@ -0,0 +1,227 @@
import logging
from typing import Any
from sqlalchemy.orm import Session
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.workflow import Workflow
from services.tools.enduser_auth_service import EndUserAuthService
logger = logging.getLogger(__name__)
class AppAuthRequirementService:
"""
Service for analyzing authentication requirements in apps.
Examines workflow DSL to identify which providers need end-user authentication.
"""
@staticmethod
def get_tool_auth_requirements(
app_id: str,
tenant_id: str,
provider_type: str | None = None,
) -> list[dict[str, Any]]:
"""
Get all authentication requirements for tools in an app.
:param app_id: The application ID
:param tenant_id: The tenant ID
:param provider_type: Optional filter by provider type (e.g., "tool")
:return: List of provider requirements
"""
try:
# Get latest published workflow for the app
with Session(db.engine, autoflush=False) as session:
workflow = (
session.query(Workflow)
.filter_by(app_id=app_id, tenant_id=tenant_id)
.order_by(Workflow.created_at.desc())
.first()
)
if not workflow:
return []
# Parse workflow graph to find tool nodes
graph = workflow.graph_dict
if not graph or "nodes" not in graph:
return []
providers = []
seen_providers = set()
# Iterate through workflow nodes
for node in graph.get("nodes", []):
node_data = node.get("data", {})
node_type = node_data.get("type")
# Check if it's a tool node
if node_type == "tool":
provider_id = node_data.get("provider_id")
provider_name = node_data.get("provider_name")
tool_name = node_data.get("tool_name")
if not provider_id:
continue
# Avoid duplicates
if provider_id in seen_providers:
continue
seen_providers.add(provider_id)
# Get provider controller to check authentication requirements
try:
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
# Check if provider needs credentials
if not provider_controller.need_credentials:
continue
# Get supported credential types
supported_types = provider_controller.get_supported_credential_types()
# Determine required credential type (prefer OAuth if supported)
required_type = None
if supported_types:
# Prefer OAuth2, then API key
from core.plugin.entities.plugin_daemon import CredentialType
if CredentialType.OAUTH2 in supported_types:
required_type = "oauth2"
elif CredentialType.API_KEY in supported_types:
required_type = "api-key"
else:
required_type = supported_types[0].value
providers.append(
{
"provider_id": provider_id,
"provider_name": provider_name or provider_id,
"supported_credential_types": [ct.value for ct in supported_types],
"required_credential_type": required_type,
"provider_type": "tool",
"feature_context": {
"node_ids": [node.get("id")],
"tool_names": [tool_name] if tool_name else [],
},
}
)
except Exception as e:
logger.warning("Error getting provider info for %s: %s", provider_id, e)
continue
# Filter by provider_type if specified
if provider_type:
providers = [p for p in providers if p.get("provider_type") == provider_type]
return providers
except Exception:
logger.exception("Error getting auth requirements for app %s", app_id)
return []
@staticmethod
def get_required_providers(
tenant_id: str,
app_id: str,
) -> list[dict[str, Any]]:
"""
Get list of providers that require end-user authentication for an app.
Simplified version of get_tool_auth_requirements for API use.
:param tenant_id: The tenant ID
:param app_id: The application ID
:return: List of provider information dictionaries
"""
requirements = AppAuthRequirementService.get_tool_auth_requirements(
app_id=app_id,
tenant_id=tenant_id,
)
# Transform to simpler format for API response
return [
{
"provider_id": req["provider_id"],
"provider_name": req["provider_name"],
"credential_type": req["required_credential_type"],
"is_required": True,
"oauth_config": None if req["required_credential_type"] != "oauth2" else {
"supported_types": req["supported_credential_types"],
},
}
for req in requirements
]
@staticmethod
def get_auth_status(
app_id: str,
end_user_id: str,
tenant_id: str,
) -> dict[str, Any]:
"""
Get overall authentication status for an app and end user.
Shows which providers are authenticated and which need authentication.
:param app_id: The application ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:return: Dict with authentication status for all providers
"""
try:
# Get required providers for this app
required_providers = AppAuthRequirementService.get_tool_auth_requirements(app_id, tenant_id)
# Check authentication status for each provider
provider_statuses = []
for provider_info in required_providers:
provider_id = provider_info["provider_id"]
# Get credentials for this provider
credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id)
# Build status
provider_status = {
"provider_id": provider_id,
"provider_name": provider_info["provider_name"],
"provider_type": provider_info["provider_type"],
"authenticated": len(credentials) > 0,
"credentials": [
{
"credential_id": cred.id,
"name": cred.name,
"type": cred.credential_type.value,
"is_default": cred.is_default,
"expires_at": cred.expires_at,
}
for cred in credentials
],
}
provider_statuses.append(provider_status)
return {"providers": provider_statuses}
except Exception:
logger.exception("Error getting auth status for app %s", app_id)
return {"providers": []}
@staticmethod
def is_provider_authenticated(
provider_id: str,
end_user_id: str,
tenant_id: str,
) -> bool:
"""
Check if a specific provider is authenticated for an end user.
:param provider_id: The provider identifier
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:return: True if authenticated, False otherwise
"""
try:
credentials = EndUserAuthService.list_credentials(tenant_id, end_user_id, provider_id)
return len(credentials) > 0
except Exception:
return False

View File

@@ -0,0 +1,558 @@
import json
import logging
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.api_entities import ToolProviderCredentialApiEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.tools import EndUserAuthenticationProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class EndUserAuthService:
"""
Service for managing end-user authentication credentials.
Follows similar patterns to BuiltinToolManageService but for end users.
"""
__MAX_CREDENTIALS_PER_PROVIDER__ = 100
@staticmethod
def list_credentials(
tenant_id: str, end_user_id: str, provider_id: str
) -> list[ToolProviderCredentialApiEntity]:
"""
List all credentials for a specific provider and end user.
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param provider_id: The provider identifier
:return: List of credential entities
"""
with Session(db.engine, autoflush=False) as session:
credentials = (
session.query(EndUserAuthenticationProvider)
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id)
.order_by(EndUserAuthenticationProvider.created_at.asc())
.all()
)
if not credentials:
return []
# Get provider controller to access credential schema
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
result: list[ToolProviderCredentialApiEntity] = []
for credential in credentials:
try:
# Create encrypter for masking credentials
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, credential.credential_type
)
# Decrypt and mask credentials
decrypted = encrypter.decrypt(credential.credentials)
masked_credentials = encrypter.mask_plugin_credentials(decrypted)
# Convert to API entity
credential_entity = ToolTransformService.convert_enduser_provider_to_credential_entity(
provider=credential,
credentials=dict(masked_credentials),
)
result.append(credential_entity)
except Exception:
logger.exception("Error processing credential %s", credential.id)
continue
return result
@staticmethod
def get_credential(
credential_id: str, end_user_id: str, tenant_id: str, mask_credentials: bool = True
) -> ToolProviderCredentialApiEntity | None:
"""
Get a specific credential by ID.
:param credential_id: The credential ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param mask_credentials: Whether to mask secret fields
:return: Credential entity or None
"""
with Session(db.engine, autoflush=False) as session:
credential = (
session.query(EndUserAuthenticationProvider)
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
.first()
)
if not credential:
return None
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id)
# Create encrypter
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, credential.credential_type
)
# Decrypt credentials
decrypted = encrypter.decrypt(credential.credentials)
# Mask if requested
if mask_credentials:
decrypted = encrypter.mask_plugin_credentials(decrypted)
# Convert to API entity
return ToolTransformService.convert_enduser_provider_to_credential_entity(
provider=credential,
credentials=dict(decrypted),
)
@staticmethod
def create_api_key_credential(
tenant_id: str,
end_user_id: str,
provider_id: str,
credentials: dict,
name: str | None = None,
) -> ToolProviderCredentialApiEntity:
"""
Create a new API key credential for an end user.
:param tenant_id: The tenant ID
:param end_user_id: The end user ID
:param provider_id: The provider identifier
:param credentials: The credential data
:param name: Optional custom name
:return: Created credential entity
"""
with Session(db.engine) as session:
try:
lock = f"enduser_credential_create_lock:{end_user_id}_{provider_id}"
with redis_client.lock(lock, timeout=20):
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(provider_id, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"Provider {provider_id} does not need credentials")
# Check credential count
credential_count = (
session.query(EndUserAuthenticationProvider)
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider_id)
.count()
)
if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__:
raise ValueError(
f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) "
f"reached for provider {provider_id}"
)
# Validate credentials
credential_type = CredentialType.API_KEY
if CredentialType.of(credential_type).is_validate_allowed():
provider_controller.validate_credentials(end_user_id, credentials)
# Generate name if not provided
if name is None or name == "":
name = EndUserAuthService._generate_credential_name(
session=session,
end_user_id=end_user_id,
tenant_id=tenant_id,
provider=provider_id,
credential_type=credential_type,
)
else:
# Validate name length
if len(name) > 30:
raise ValueError("Credential name must be 30 characters or less")
# Check if name is already used
if session.scalar(
select(
exists().where(
EndUserAuthenticationProvider.end_user_id == end_user_id,
EndUserAuthenticationProvider.tenant_id == tenant_id,
EndUserAuthenticationProvider.provider == provider_id,
EndUserAuthenticationProvider.name == name,
)
)
):
raise ValueError(f"The credential name '{name}' is already used")
# Create encrypter
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, credential_type
)
# Create credential record
db_credential = EndUserAuthenticationProvider(
tenant_id=tenant_id,
end_user_id=end_user_id,
provider=provider_id,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=credential_type,
name=name,
expires_at=-1, # API keys don't expire
)
session.add(db_credential)
session.commit()
session.refresh(db_credential)
# Return masked credentials
masked_credentials = encrypter.mask_plugin_credentials(credentials)
return ToolTransformService.convert_enduser_provider_to_credential_entity(
provider=db_credential,
credentials=dict(masked_credentials),
)
except Exception as e:
session.rollback()
logger.exception("Error creating API key credential")
raise ValueError(str(e))
@staticmethod
def create_oauth_credential(
end_user_id: str,
tenant_id: str,
provider: str,
credentials: dict,
expires_at: int = -1,
name: str | None = None,
) -> EndUserAuthenticationProvider:
"""
Create a new OAuth credential for an end user.
Used internally by OAuth callback handler.
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param provider: The provider identifier
:param credentials: The OAuth credentials (access_token, refresh_token, etc.)
:param expires_at: Unix timestamp when token expires (-1 for no expiry)
:param name: Optional custom name
:return: Created credential record
"""
with Session(db.engine) as session:
try:
lock = f"enduser_credential_create_lock:{end_user_id}_{provider}"
with redis_client.lock(lock, timeout=20):
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# Check credential count
credential_count = (
session.query(EndUserAuthenticationProvider)
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider)
.count()
)
if credential_count >= EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__:
raise ValueError(
f"Maximum number of credentials ({EndUserAuthService.__MAX_CREDENTIALS_PER_PROVIDER__}) "
f"reached for provider {provider}"
)
# Generate name if not provided
credential_type = CredentialType.OAUTH2
if name is None or name == "":
name = EndUserAuthService._generate_credential_name(
session=session,
end_user_id=end_user_id,
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type,
)
# Create encrypter
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, credential_type
)
# Create credential record
db_credential = EndUserAuthenticationProvider(
tenant_id=tenant_id,
end_user_id=end_user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=credential_type,
name=name,
expires_at=expires_at,
)
session.add(db_credential)
session.commit()
session.refresh(db_credential)
return db_credential
except Exception as e:
session.rollback()
logger.exception("Error creating OAuth credential")
raise ValueError(str(e))
@staticmethod
def update_credential(
credential_id: str,
end_user_id: str,
tenant_id: str,
credentials: dict | None = None,
name: str | None = None,
) -> ToolProviderCredentialApiEntity:
"""
Update an existing credential (API key only).
:param credential_id: The credential ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param credentials: Updated credentials (optional)
:param name: Updated name (optional)
:return: Updated credential entity
"""
with Session(db.engine) as session:
try:
# Get credential
db_credential = (
session.query(EndUserAuthenticationProvider)
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
.first()
)
if not db_credential:
raise ValueError(f"Credential {credential_id} not found")
# Only API key credentials can be updated
if not CredentialType.of(db_credential.credential_type).is_editable():
raise ValueError("Only API key credentials can be updated via this endpoint")
# At least one field must be provided
if credentials is None and name is None:
raise ValueError("At least one field (credentials or name) must be provided")
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(db_credential.provider, tenant_id)
# Create encrypter
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, db_credential.credential_type
)
# Update credentials if provided
if credentials:
# Decrypt original credentials
original_credentials = encrypter.decrypt(db_credential.credentials)
# Merge with new credentials, keeping hidden values
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
# Validate new credentials
if CredentialType.of(db_credential.credential_type).is_validate_allowed():
provider_controller.validate_credentials(end_user_id, new_credentials)
# Encrypt and save
db_credential.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
# Update name if provided
if name and name != db_credential.name:
# Validate name length
if len(name) > 30:
raise ValueError("Credential name must be 30 characters or less")
# Check if name is already used
if session.scalar(
select(
exists().where(
EndUserAuthenticationProvider.end_user_id == end_user_id,
EndUserAuthenticationProvider.tenant_id == tenant_id,
EndUserAuthenticationProvider.provider == db_credential.provider,
EndUserAuthenticationProvider.name == name,
)
)
):
raise ValueError(f"The credential name '{name}' is already used")
db_credential.name = name
session.commit()
session.refresh(db_credential)
# Return masked credentials
decrypted = encrypter.decrypt(db_credential.credentials)
masked_credentials = encrypter.mask_plugin_credentials(decrypted)
return ToolTransformService.convert_enduser_provider_to_credential_entity(
provider=db_credential,
credentials=dict(masked_credentials),
)
except Exception as e:
session.rollback()
logger.exception("Error updating credential")
raise ValueError(str(e))
@staticmethod
def delete_credential(
tenant_id: str, end_user_id: str, provider_id: str, credential_id: str
) -> bool:
"""
Delete a credential.
:param credential_id: The credential ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:return: True if deleted successfully
"""
with Session(db.engine) as session:
credential = (
session.query(EndUserAuthenticationProvider)
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
.first()
)
if not credential:
raise ValueError(f"Credential {credential_id} not found")
session.delete(credential)
session.commit()
return True
@staticmethod
def refresh_oauth_token(
credential_id: str, end_user_id: str, tenant_id: str, refreshed_credentials: dict, expires_at: int
) -> EndUserAuthenticationProvider:
"""
Update OAuth credentials after token refresh.
:param credential_id: The credential ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param refreshed_credentials: New credentials from OAuth refresh
:param expires_at: New expiration timestamp
:return: Updated credential record
"""
with Session(db.engine) as session:
try:
credential = (
session.query(EndUserAuthenticationProvider)
.filter_by(id=credential_id, end_user_id=end_user_id, tenant_id=tenant_id)
.first()
)
if not credential:
raise ValueError(f"Credential {credential_id} not found")
if credential.credential_type != CredentialType.OAUTH2:
raise ValueError("Only OAuth credentials can be refreshed")
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(credential.provider, tenant_id)
# Create encrypter
encrypter, _ = EndUserAuthService._create_encrypter(
tenant_id, provider_controller, credential.credential_type
)
# Encrypt and save new credentials
credential.encrypted_credentials = json.dumps(encrypter.encrypt(refreshed_credentials))
credential.expires_at = expires_at
session.commit()
session.refresh(credential)
return credential
except Exception as e:
session.rollback()
logger.exception("Error refreshing OAuth token")
raise ValueError(str(e))
@staticmethod
def get_default_credential(
end_user_id: str, tenant_id: str, provider: str
) -> EndUserAuthenticationProvider | None:
"""
Get the default (oldest) credential for a provider.
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param provider: The provider identifier
:return: Credential record or None
"""
with Session(db.engine, autoflush=False) as session:
return (
session.query(EndUserAuthenticationProvider)
.filter_by(end_user_id=end_user_id, tenant_id=tenant_id, provider=provider)
.order_by(EndUserAuthenticationProvider.created_at.asc())
.first()
)
@staticmethod
def _generate_credential_name(
session: Session,
end_user_id: str,
tenant_id: str,
provider: str,
credential_type: CredentialType,
) -> str:
"""
Generate a unique credential name.
:param session: Database session
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param provider: The provider identifier
:param credential_type: The credential type
:return: Generated name (e.g., "API KEY 1", "AUTH 1")
"""
existing_credentials = (
session.query(EndUserAuthenticationProvider)
.filter_by(
end_user_id=end_user_id,
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type,
)
.order_by(EndUserAuthenticationProvider.created_at.desc())
.all()
)
return generate_incremental_name(
[credential.name for credential in existing_credentials],
f"{credential_type.get_name()}",
)
@staticmethod
def _create_encrypter(
tenant_id: str, provider_controller, credential_type: CredentialType | str
) -> tuple:
"""
Create an encrypter for credential encryption/decryption.
:param tenant_id: The tenant ID
:param provider_controller: The provider controller
:param credential_type: The credential type
:return: Tuple of (encrypter, cache)
"""
if isinstance(credential_type, str):
credential_type = CredentialType.of(credential_type)
return create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(credential_type)
],
cache=NoOpProviderCredentialCache(),
)

View File

@@ -0,0 +1,264 @@
import logging
from typing import Any
from werkzeug import Request
from configs import dify_config
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.tool_manager import ToolManager
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.enduser_auth_service import EndUserAuthService
logger = logging.getLogger(__name__)
class EndUserOAuthService:
"""
Service for managing end-user OAuth authentication flows.
Reuses existing OAuthProxyService and OAuthHandler infrastructure.
"""
@staticmethod
def get_authorization_url(
end_user_id: str,
tenant_id: str,
app_id: str,
provider: str,
) -> dict[str, str]:
"""
Initiate OAuth authorization flow for an end user.
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:param app_id: The application ID
:param provider: The provider identifier
:return: Dict with authorization_url
"""
try:
# Get OAuth client configuration (reuse workspace-level logic)
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if not oauth_client:
raise ValueError(f"OAuth client not configured for provider {provider}")
# Get provider controller
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_id = ToolProviderID(provider)
# Create OAuth context with end-user specific data
context_id = OAuthProxyService.create_proxy_context(
user_id=end_user_id, # Using end_user_id as user_id
tenant_id=tenant_id,
plugin_id=tool_provider_id.plugin_id,
provider=tool_provider_id.provider_name,
extra_data={
"app_id": app_id,
"provider_type": "tool", # For now, only tools support end-user auth
},
)
# Use the same redirect URI as workspace OAuth to reuse the same OAuth client
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{tool_provider_id}/tool/callback"
# Get authorization URL from OAuth handler
oauth_handler = OAuthHandler()
response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=end_user_id,
plugin_id=tool_provider_id.plugin_id,
provider=tool_provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client,
)
return {
"authorization_url": response.authorization_url,
"context_id": context_id, # Return for setting cookie
}
except Exception as e:
logger.exception("Error getting authorization URL for end user")
raise ValueError(f"Failed to initiate OAuth flow: {str(e)}")
@staticmethod
def handle_oauth_callback(
context_id: str,
request: Request,
) -> dict[str, Any]:
"""
Handle OAuth callback and create credential.
:param context_id: The OAuth context ID from cookie
:param request: The callback request with authorization code
:return: Dict with credential information
"""
try:
# Validate and retrieve context
context = OAuthProxyService.use_proxy_context(context_id)
# Extract context data
end_user_id = context.get("user_id") # user_id is actually end_user_id
tenant_id = context.get("tenant_id")
app_id = context.get("app_id")
plugin_id = context.get("plugin_id")
provider = context.get("provider")
if not all([end_user_id, tenant_id, app_id, plugin_id, provider]):
raise ValueError("Invalid OAuth context: missing required fields")
# Reconstruct full provider ID
full_provider = f"{plugin_id}/{provider}" if plugin_id != "langgenius" else provider
# Get OAuth client configuration
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, full_provider)
if not oauth_client:
raise ValueError(f"OAuth client not configured for provider {full_provider}")
# Use the same redirect URI as workspace OAuth (must match authorization request)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{full_provider}/tool/callback"
# Exchange authorization code for credentials
oauth_handler = OAuthHandler()
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=end_user_id,
plugin_id=plugin_id,
provider=provider,
redirect_uri=redirect_uri,
system_credentials=oauth_client,
request=request,
)
# Calculate expiration timestamp
expires_at = -1
if credentials_response.expires_in and credentials_response.expires_in > 0:
import time
expires_at = int(time.time()) + credentials_response.expires_in
# Create credential in database
credential = EndUserAuthService.create_oauth_credential(
end_user_id=end_user_id,
tenant_id=tenant_id,
provider=full_provider,
credentials=credentials_response.credentials,
expires_at=expires_at,
)
return {
"success": True,
"credential_id": credential.id,
"provider": full_provider,
"app_id": app_id,
}
except Exception as e:
logger.exception("Error handling OAuth callback for end user")
raise ValueError(f"Failed to complete OAuth flow: {str(e)}")
@staticmethod
def refresh_oauth_token(
credential_id: str,
end_user_id: str,
tenant_id: str,
) -> dict[str, Any]:
"""
Refresh an expired OAuth token.
:param credential_id: The credential ID
:param end_user_id: The end user ID
:param tenant_id: The tenant ID
:return: Dict with refresh status
"""
try:
# Get existing credential
credential = EndUserAuthService.get_credential(
credential_id=credential_id,
end_user_id=end_user_id,
tenant_id=tenant_id,
mask_credentials=False, # Need full credentials for refresh
)
if not credential:
raise ValueError(f"Credential {credential_id} not found")
if credential.credential_type != CredentialType.OAUTH2:
raise ValueError("Only OAuth credentials can be refreshed")
# Get OAuth client configuration
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, credential.provider)
if not oauth_client:
raise ValueError(f"OAuth client not configured for provider {credential.provider}")
# Get provider info
tool_provider_id = ToolProviderID(credential.provider)
# Use the same redirect URI as workspace OAuth to reuse the same OAuth client
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{credential.provider}/tool/callback"
# Refresh credentials via OAuth handler
oauth_handler = OAuthHandler()
refreshed_response = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=end_user_id,
plugin_id=tool_provider_id.plugin_id,
provider=tool_provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client,
credentials=credential.credentials,
)
# Calculate new expiration timestamp
expires_at = -1
if refreshed_response.expires_in and refreshed_response.expires_in > 0:
import time
expires_at = int(time.time()) + refreshed_response.expires_in
# Update credential in database
updated_credential = EndUserAuthService.refresh_oauth_token(
credential_id=credential_id,
end_user_id=end_user_id,
tenant_id=tenant_id,
refreshed_credentials=refreshed_response.credentials,
expires_at=expires_at,
)
return {
"success": True,
"credential_id": updated_credential.id,
"expires_at": expires_at,
"refreshed_at": int(updated_credential.updated_at.timestamp()),
}
except Exception:
logger.exception("Error refreshing OAuth token for end user")
return {
"success": False,
"error": "Failed to refresh token",
}
@staticmethod
def get_oauth_client_info(tenant_id: str, provider: str) -> dict[str, Any]:
"""
Get OAuth client information for a provider.
Used to check if OAuth is available and configured.
:param tenant_id: The tenant ID
:param provider: The provider identifier
:return: Dict with OAuth client info
"""
try:
# Check if OAuth client exists (either system or custom)
oauth_client = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
return {
"configured": oauth_client is not None,
"system_configured": BuiltinToolManageService.is_oauth_system_client_exists(provider),
"custom_configured": BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
}
except Exception:
logger.exception("Error getting OAuth client info")
return {
"configured": False,
"system_configured": False,
"custom_configured": False,
}

View File

@@ -426,6 +426,27 @@ class ToolTransformService:
credentials=credentials,
)
@staticmethod
def convert_enduser_provider_to_credential_entity(
provider, credentials: dict
) -> ToolProviderCredentialApiEntity:
"""
Convert EndUserAuthenticationProvider to ToolProviderCredentialApiEntity.
:param provider: EndUserAuthenticationProvider instance
:param credentials: Decrypted/masked credentials dict
:return: ToolProviderCredentialApiEntity
"""
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=False, # End-user credentials don't have default flag (use oldest)
credentials=credentials,
expires_at=provider.expires_at,
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""

View File

@@ -0,0 +1,486 @@
"""
Unit tests for end-user tool authentication.
Tests the integration of end-user authentication with tool runtime resolution.
"""
import time
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolAuthType, ToolInvokeFrom, ToolProviderType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_manager import ToolManager
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("core.tools.tool_manager.db") as mock_db:
yield mock_db.session
@pytest.fixture
def mock_provider_controller():
"""Mock builtin provider controller."""
controller = MagicMock()
controller.need_credentials = True
controller.get_tool = MagicMock(return_value=MagicMock())
controller.get_credentials_schema_by_type = MagicMock(return_value=[])
return controller
class TestEndUserToolAuthentication:
"""Test suite for end-user tool authentication."""
def test_end_user_auth_requires_end_user_id(self, mock_db_session, mock_provider_controller):
"""
Test that END_USER auth_type requires end_user_id parameter.
When auth_type is END_USER but end_user_id is None, should raise error.
"""
with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller):
with pytest.raises(ToolProviderNotFoundError, match="end_user_id is required"):
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="test_provider",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id=None, # Missing!
)
def test_end_user_auth_missing_credentials(self, mock_db_session, mock_provider_controller):
"""
Test that error is raised when end-user has no credentials for provider.
When auth_type is END_USER but no credentials exist, should raise error.
"""
# Mock no credentials found
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
with patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller):
with pytest.raises(ToolProviderNotFoundError, match="No end-user credentials found"):
ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="test_provider",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id="end_user_123",
)
def test_end_user_auth_with_credentials(self, mock_db_session, mock_provider_controller):
"""
Test successful end-user credential resolution.
When auth_type is END_USER and credentials exist, should return tool runtime.
"""
# Mock end-user provider
mock_enduser_provider = MagicMock()
mock_enduser_provider.id = "cred_123"
mock_enduser_provider.credential_type = "api-key"
mock_enduser_provider.credentials = '{"api_key": "encrypted"}'
mock_enduser_provider.expires_at = -1 # No expiry
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_enduser_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"api_key": "decrypted_key"}
mock_cache = MagicMock()
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="test_provider",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id="end_user_123",
)
# Verify tool runtime was created
assert tool_runtime is not None
# Verify encrypter was called with decrypted credentials
mock_encrypter.decrypt.assert_called_once()
def test_workspace_auth_backward_compatibility(self, mock_db_session, mock_provider_controller):
"""
Test that workspace authentication still works (backward compatibility).
When auth_type is WORKSPACE (default), should use workspace credentials.
"""
# Mock workspace provider
mock_workspace_provider = MagicMock()
mock_workspace_provider.id = "workspace_cred_123"
mock_workspace_provider.credential_type = "api-key"
mock_workspace_provider.credentials = '{"api_key": "workspace_encrypted"}'
mock_workspace_provider.expires_at = -1
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_workspace_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"api_key": "workspace_decrypted"}
mock_cache = MagicMock()
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
patch("core.helper.credential_utils.check_credential_policy_compliance"),
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="test_provider",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.WORKSPACE, # Workspace auth
end_user_id=None, # Not needed for workspace auth
)
# Verify tool runtime was created
assert tool_runtime is not None
def test_workflow_tool_runtime_passes_end_user_id(self, mock_db_session, mock_provider_controller):
"""
Test that get_workflow_tool_runtime correctly passes end_user_id to get_tool_runtime.
"""
from core.workflow.nodes.tool.entities import ToolEntity
# Create a mock ToolEntity with END_USER auth_type
workflow_tool = MagicMock(spec=ToolEntity)
workflow_tool.provider_type = ToolProviderType.BUILT_IN
workflow_tool.provider_id = "test_provider"
workflow_tool.tool_name = "test_tool"
workflow_tool.credential_id = None
workflow_tool.auth_type = ToolAuthType.END_USER
workflow_tool.tool_configurations = {}
# Mock end-user credentials
mock_enduser_provider = MagicMock()
mock_enduser_provider.id = "cred_123"
mock_enduser_provider.credential_type = "api-key"
mock_enduser_provider.credentials = '{"api_key": "encrypted"}'
mock_enduser_provider.expires_at = -1
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_enduser_provider
)
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"api_key": "decrypted"}
mock_cache = MagicMock()
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
):
tool_runtime = ToolManager.get_workflow_tool_runtime(
tenant_id="test_tenant",
app_id="test_app",
node_id="test_node",
workflow_tool=workflow_tool,
invoke_from=InvokeFrom.SERVICE_API,
variable_pool=None,
end_user_id="end_user_123", # Pass end_user_id
)
# Verify tool runtime was created
assert tool_runtime is not None
class TestOAuthTokenRefresh:
"""Test suite for OAuth token refresh functionality."""
def test_enduser_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller):
"""
Test that end-user OAuth tokens are automatically refreshed when expired.
When an OAuth token is expired (expires_at < current_time + 60s buffer),
the system should automatically refresh it before using.
"""
# Mock end-user provider with expired OAuth token
mock_enduser_provider = MagicMock()
mock_enduser_provider.id = "cred_123"
mock_enduser_provider.credential_type = "oauth2"
mock_enduser_provider.credentials = '{"access_token": "old_token", "refresh_token": "refresh"}'
# Set expiry to past (token expired)
mock_enduser_provider.expires_at = int(time.time()) - 100
mock_enduser_provider.encrypted_credentials = '{"access_token": "old_token", "refresh_token": "refresh"}'
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_enduser_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"access_token": "old_token", "refresh_token": "refresh"}
mock_encrypter.encrypt.return_value = {"access_token": "new_token", "refresh_token": "refresh"}
mock_cache = MagicMock()
# Mock OAuth refresh response
mock_refreshed_credentials = MagicMock()
mock_refreshed_credentials.credentials = {"access_token": "new_token", "refresh_token": "refresh"}
mock_refreshed_credentials.expires_at = int(time.time()) + 3600 # New expiry 1 hour from now
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
patch.object(
ToolManager,
"_refresh_oauth_credentials",
return_value=(
mock_refreshed_credentials.credentials,
mock_refreshed_credentials.expires_at,
),
) as mock_refresh,
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="github",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id="end_user_123",
)
# Verify refresh was called
mock_refresh.assert_called_once_with(
tenant_id="test_tenant",
provider_id="github",
user_id="end_user_123",
decrypted_credentials={"access_token": "old_token", "refresh_token": "refresh"},
)
# Verify provider was updated with new credentials
assert mock_enduser_provider.expires_at == mock_refreshed_credentials.expires_at
mock_db_session.commit.assert_called_once()
mock_cache.delete.assert_called_once()
# Verify tool runtime was created
assert tool_runtime is not None
def test_enduser_oauth_token_not_refreshed_when_valid(self, mock_db_session, mock_provider_controller):
"""
Test that valid OAuth tokens are NOT refreshed.
When an OAuth token is still valid (expires_at > current_time + 60s buffer),
the system should use it without refreshing.
"""
# Mock end-user provider with valid OAuth token
mock_enduser_provider = MagicMock()
mock_enduser_provider.id = "cred_123"
mock_enduser_provider.credential_type = "oauth2"
mock_enduser_provider.credentials = '{"access_token": "valid_token", "refresh_token": "refresh"}'
# Set expiry to future (token still valid with buffer)
mock_enduser_provider.expires_at = int(time.time()) + 3600 # Valid for 1 hour
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_enduser_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"access_token": "valid_token", "refresh_token": "refresh"}
mock_cache = MagicMock()
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh,
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="github",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id="end_user_123",
)
# Verify refresh was NOT called (token still valid)
mock_refresh.assert_not_called()
# Verify tool runtime was created with original credentials
assert tool_runtime is not None
def test_workspace_oauth_token_refresh_when_expired(self, mock_db_session, mock_provider_controller):
"""
Test that workspace OAuth tokens are automatically refreshed when expired.
This ensures the refactored _refresh_oauth_credentials helper works
for both end-user and workspace authentication flows.
"""
# Mock workspace provider with expired OAuth token
mock_workspace_provider = MagicMock()
mock_workspace_provider.id = "workspace_cred_123"
mock_workspace_provider.user_id = "workspace_user_456"
mock_workspace_provider.credential_type = "oauth2"
mock_workspace_provider.credentials = '{"access_token": "old_workspace_token"}'
mock_workspace_provider.expires_at = int(time.time()) - 100 # Expired
mock_workspace_provider.encrypted_credentials = '{"access_token": "old_workspace_token"}'
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_workspace_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"access_token": "old_workspace_token"}
mock_encrypter.encrypt.return_value = {"access_token": "new_workspace_token"}
mock_cache = MagicMock()
# Mock OAuth refresh response
refreshed_creds = {"access_token": "new_workspace_token"}
new_expires_at = int(time.time()) + 3600
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
patch("core.helper.credential_utils.check_credential_policy_compliance"),
patch.object(
ToolManager, "_refresh_oauth_credentials", return_value=(refreshed_creds, new_expires_at)
) as mock_refresh,
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="github",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.WORKSPACE,
)
# Verify refresh was called with workspace user_id
mock_refresh.assert_called_once_with(
tenant_id="test_tenant",
provider_id="github",
user_id="workspace_user_456",
decrypted_credentials={"access_token": "old_workspace_token"},
)
# Verify provider was updated
assert mock_workspace_provider.expires_at == new_expires_at
mock_db_session.commit.assert_called_once()
mock_cache.delete.assert_called_once()
# Verify tool runtime was created
assert tool_runtime is not None
def test_oauth_token_no_refresh_for_non_oauth_credentials(self, mock_db_session, mock_provider_controller):
"""
Test that non-OAuth credentials (API keys) are never refreshed.
API keys with expires_at = -1 should not trigger refresh logic.
"""
# Mock end-user provider with API key (no expiry)
mock_enduser_provider = MagicMock()
mock_enduser_provider.id = "cred_123"
mock_enduser_provider.credential_type = "api-key"
mock_enduser_provider.credentials = '{"api_key": "sk-1234567890"}'
mock_enduser_provider.expires_at = -1 # API keys don't expire
mock_db_session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
mock_enduser_provider
)
# Mock encrypter
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"api_key": "sk-1234567890"}
mock_cache = MagicMock()
with (
patch.object(ToolManager, "get_builtin_provider", return_value=mock_provider_controller),
patch("core.tools.tool_manager.create_provider_encrypter", return_value=(mock_encrypter, mock_cache)),
patch.object(ToolManager, "_refresh_oauth_credentials") as mock_refresh,
):
tool_runtime = ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id="openai",
tool_name="test_tool",
tenant_id="test_tenant",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
auth_type=ToolAuthType.END_USER,
end_user_id="end_user_123",
)
# Verify refresh was NOT called (API key doesn't need refresh)
mock_refresh.assert_not_called()
# Verify tool runtime was created
assert tool_runtime is not None
def test_refresh_oauth_credentials_helper_method(self):
"""
Test the _refresh_oauth_credentials helper method directly.
This tests the centralized OAuth refresh logic that is used by both
end-user and workspace authentication flows.
"""
# Mock dependencies
mock_oauth_handler = MagicMock()
mock_refreshed = MagicMock()
mock_refreshed.credentials = {"access_token": "new_token", "refresh_token": "new_refresh"}
mock_refreshed.expires_at = int(time.time()) + 7200
mock_oauth_handler.refresh_credentials.return_value = mock_refreshed
with (
# Patch OAuthHandler where it's imported (inside the method)
patch("core.plugin.impl.oauth.OAuthHandler", return_value=mock_oauth_handler),
patch("core.tools.tool_manager.ToolProviderID") as mock_provider_id,
patch(
"services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client",
return_value={"client_id": "test"},
),
patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "http://localhost:5001"),
):
# Setup provider ID mock
mock_provider_id.return_value.provider_name = "github"
mock_provider_id.return_value.plugin_id = "builtin"
# Call the helper method
credentials, expires_at = ToolManager._refresh_oauth_credentials(
tenant_id="test_tenant",
provider_id="langgenius/github/github",
user_id="user_123",
decrypted_credentials={"access_token": "old_token", "refresh_token": "old_refresh"},
)
# Verify OAuth handler was called correctly
mock_oauth_handler.refresh_credentials.assert_called_once_with(
tenant_id="test_tenant",
user_id="user_123",
plugin_id="builtin",
provider="github",
redirect_uri="http://localhost:5001/console/api/oauth/plugin/langgenius/github/github/tool/callback",
system_credentials={"client_id": "test"},
credentials={"access_token": "old_token", "refresh_token": "old_refresh"},
)
# Verify returned values
assert credentials == {"access_token": "new_token", "refresh_token": "new_refresh"}
assert expires_at == mock_refreshed.expires_at

View File

@@ -214,6 +214,7 @@ const ToolSelector: FC<Props> = ({
onSelect({
...value,
use_end_user_credentials: enabled,
auth_type: enabled ? 'end_user' : 'workspace',
} as any)
}
const handleEndUserCredentialTypeChange = (type: string) => {

View File

@@ -24,4 +24,5 @@ export type ToolNodeType = CommonNodeType & {
provider_icon?: Collection['icon']
provider_icon_dark?: Collection['icon_dark']
plugin_unique_identifier?: string
auth_type?: 'workspace' | 'end_user'
}