mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 03:00:42 -04:00
refactor: use sessionmaker in api_tools_manage_service.py (#34892)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ApiToolProviderNotFoundError(ValueError):
|
||||
error_code = "api_tool_provider_not_found"
|
||||
provider_name: str
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, provider_name: str, tenant_id: str):
|
||||
self.provider_name = provider_name
|
||||
self.tenant_id = tenant_id
|
||||
super().__init__(f"api provider {provider_name} does not exist")
|
||||
|
||||
|
||||
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
|
||||
error_code = "workflow_tool_human_input_not_supported"
|
||||
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@@ -15,6 +16,7 @@ from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ApiProviderSchemaType,
|
||||
)
|
||||
from core.tools.errors import ApiToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
@@ -116,71 +118,85 @@ class ApiToolManageService:
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
create api tool provider
|
||||
Create a new API tool provider.
|
||||
|
||||
:param user_id: The ID of the user creating the provider.
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:param provider_name: The name of the API tool provider.
|
||||
:param icon: The icon configuration for the provider.
|
||||
:param credentials: The credentials for the provider.
|
||||
:param schema_type: The type of schema (e.g., OpenAPI).
|
||||
:param schema: The raw schema string.
|
||||
:param privacy_policy: The privacy policy URL or text.
|
||||
:param custom_disclaimer: Custom disclaimer text.
|
||||
:param labels: A list of labels for the provider.
|
||||
:return: A dictionary indicating the result status.
|
||||
"""
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
# Create new session with automatic transaction management
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider: ApiToolProvider | None = _session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
if provider is not None:
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
# create API tool provider
|
||||
api_tool_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
|
||||
# encrypt credentials
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
_session.add(api_tool_provider)
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels, _session)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -212,16 +228,25 @@ class ApiToolManageService:
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
List tools provided by a specific API tool provider.
|
||||
|
||||
:param user_id: The ID of the user requesting the list.
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:param provider_name: The name of the API tool provider.
|
||||
:return: A list of ToolApiEntity objects.
|
||||
"""
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
|
||||
# create new session with automatic transaction management
|
||||
provider: ApiToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider = _session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
@@ -251,103 +276,133 @@ class ApiToolManageService:
|
||||
privacy_policy: str | None,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
update api tool provider
|
||||
Update an existing API tool provider.
|
||||
|
||||
:param user_id: The ID of the user updating the provider.
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:param provider_name: The new name of the API tool provider.
|
||||
:param original_provider: The original name of the API tool provider.
|
||||
:param icon: The icon configuration for the provider.
|
||||
:param credentials: The credentials for the provider.
|
||||
:param _schema_type: The type of schema (e.g., OpenAPI).
|
||||
:param schema: The raw schema string.
|
||||
:param privacy_policy: The privacy policy URL or text.
|
||||
:param custom_disclaimer: Custom disclaimer text.
|
||||
:param labels: A list of labels for the provider.
|
||||
:return: A dictionary indicating the result status.
|
||||
"""
|
||||
|
||||
provider_name = provider_name.strip()
|
||||
|
||||
# check if the provider exists
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
# create new session with automatic transaction management
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider: ApiToolProvider | None = _session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
if provider is None:
|
||||
raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id)
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = schema_type
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
# parse openapi to tool bundle
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = schema_type
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# get original credentials if exists
|
||||
encrypter, cache = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# get original credentials if exists
|
||||
encrypter, cache = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
credentials = dict(encrypter.encrypt(credentials))
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = dict(encrypter.encrypt(credentials))
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
_session.add(provider)
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels, _session)
|
||||
|
||||
# delete cache
|
||||
cache.delete()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
Delete an API tool provider.
|
||||
|
||||
:param user_id: The ID of the user performing the deletion operation.
|
||||
:param tenant_id: The ID of the workspace/tenant where the provider belongs.
|
||||
:param provider_name: The unique name of the API tool provider to be deleted.
|
||||
:raises ValueError: If the specified provider does not exist in the tenant.
|
||||
:return: A dictionary indicating the result status.
|
||||
"""
|
||||
provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
|
||||
# create new session with automatic transaction management
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider: ApiToolProvider | None = _session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
_session.delete(provider)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]:
|
||||
"""
|
||||
get api tool provider
|
||||
Get API tool provider details.
|
||||
|
||||
:param user_id: The ID of the user requesting the provider.
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:param provider: The name of the API tool provider.
|
||||
:return: A dictionary containing the provider details.
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
@@ -360,10 +415,20 @@ class ApiToolManageService:
|
||||
parameters: dict[str, Any],
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
Test an API tool before adding the API tool provider.
|
||||
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:param provider_name: The name of the API tool provider.
|
||||
:param tool_name: The name of the specific tool to test.
|
||||
:param credentials: The credentials for the provider.
|
||||
:param parameters: The parameters to pass to the tool.
|
||||
:param schema_type: The type of schema (e.g., OpenAPI).
|
||||
:param schema: The raw schema string.
|
||||
:return: A dictionary containing the result or error message.
|
||||
"""
|
||||
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f"invalid schema type {schema_type}")
|
||||
|
||||
@@ -377,18 +442,21 @@ class ApiToolManageService:
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
# create new session with automatic transaction management to get the provider
|
||||
provider: ApiToolProvider | None = None
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
provider = _session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
if provider is None:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
provider = ApiToolProvider(
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
name="",
|
||||
@@ -407,12 +475,12 @@ class ApiToolManageService:
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
# load tools into provider entity
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
if provider.id:
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
@@ -443,14 +511,21 @@ class ApiToolManageService:
|
||||
@staticmethod
|
||||
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
List all API tools for a specific tenant.
|
||||
|
||||
:param tenant_id: The ID of the workspace/tenant.
|
||||
:return: A list of ToolProviderApiEntity objects.
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
# create new session with automatic transaction management
|
||||
providers: list[ApiToolProvider] = []
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
|
||||
providers = list(
|
||||
_session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
)
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
for provider in providers:
|
||||
# convert provider controller to user provider
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import inspect
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from core.tools.errors import ApiToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from models import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
@@ -590,30 +594,204 @@ class TestApiToolManageService:
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
def test_update_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when original provider not found."""
|
||||
fake = Faker()
|
||||
|
||||
# Firmware fix for cache.delete() in update flow
|
||||
mock_encrypter = mock_external_service_dependencies["encrypter"]
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.delete.return_value = None
|
||||
mock_encrypter.return_value = (mock_encrypter, mock_cache)
|
||||
|
||||
# Get fake account and tenant
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
# original provider name
|
||||
original_name = "original-provider"
|
||||
|
||||
# Create original provider
|
||||
_ = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=original_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=["old-label"],
|
||||
)
|
||||
|
||||
# new provide name and new labels for update
|
||||
new_name = "updated-provider"
|
||||
new_labels = ["new-label-1", "new-label-2"]
|
||||
|
||||
# Reset mock history so assertions focus on update path only
|
||||
mock_external_service_dependencies["encrypter"].reset_mock()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
|
||||
|
||||
# Act: Update the provider with new values
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
# new provider name - changed 1
|
||||
provider_name=new_name,
|
||||
original_provider=original_name,
|
||||
# new icon - changed 2
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
# new privacy policy - changed 3
|
||||
privacy_policy="https://new-policy.com",
|
||||
# new custom disclaimer - changed 4
|
||||
custom_disclaimer="New disclaimer",
|
||||
# new labels - changed 5 (However, we will not verify this, not this layer responsibility.)
|
||||
labels=new_labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Get the updated provider from the database
|
||||
updated_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Verify the provider was updated successfully
|
||||
assert updated_provider is not None
|
||||
|
||||
# Manually refresh to keep object detachment
|
||||
db_session_with_containers.refresh(updated_provider)
|
||||
# Verify all the updated fields
|
||||
# - changed 1
|
||||
assert updated_provider.name == new_name
|
||||
# - changed 2
|
||||
icon_data = json.loads(updated_provider.icon)
|
||||
assert icon_data["type"] == "emoji"
|
||||
assert icon_data["value"] == "🚀"
|
||||
# - changed 3
|
||||
assert updated_provider.privacy_policy == "https://new-policy.com"
|
||||
# - changed 4
|
||||
assert updated_provider.custom_disclaimer == "New disclaimer"
|
||||
|
||||
# Verify old provider name no longer exists after rename
|
||||
original_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name)
|
||||
.first()
|
||||
)
|
||||
assert original_provider is None
|
||||
|
||||
# Verify update flow calls critical collaborators
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
# Deeply verify on session propagation of labels update logics:
|
||||
# Since in refactoring, we pass session down to label manager to keep atomicity.
|
||||
# The assertion here is to verify this.
|
||||
sig = inspect.signature(ToolLabelManager.update_tool_labels)
|
||||
args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
passed_session = bound_args.arguments.get("session")
|
||||
# Ensure the type: Session
|
||||
assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}"
|
||||
assert passed_session is not None, (
|
||||
"Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider"
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test update raises ValueError when original provider not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error when trying to update a non-existing original provider
|
||||
- No accidental upsert/new provider creation
|
||||
- No external dependency invocation on early failure path
|
||||
"""
|
||||
# Arrange: Create test account and tenant
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Keep an existing provider in DB to ensure unrelated data remains unchanged
|
||||
existing_provider_name = "existing-provider"
|
||||
_ = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=existing_provider_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy="https://existing-policy.com",
|
||||
custom_disclaimer="Existing disclaimer",
|
||||
labels=["existing-label"],
|
||||
)
|
||||
|
||||
# Reset mock history so assertions focus on update failure path only
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
|
||||
mock_external_service_dependencies["encrypter"].reset_mock()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
|
||||
|
||||
# Act & Assert: Verify update fails with clear error message
|
||||
target_new_name = "new-provider-name"
|
||||
missing_original_name = "missing-original-provider"
|
||||
with pytest.raises(ApiToolProviderNotFoundError) as exc_info:
|
||||
_ = ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name="new-name",
|
||||
original_provider="nonexistent",
|
||||
icon={},
|
||||
provider_name=target_new_name,
|
||||
original_provider=missing_original_name,
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
privacy_policy="https://new-policy.com",
|
||||
custom_disclaimer="New disclaimer",
|
||||
labels=["new-label"],
|
||||
)
|
||||
|
||||
error = exc_info.value
|
||||
assert error.provider_name == missing_original_name
|
||||
assert error.tenant_id == tenant.id
|
||||
assert error.error_code == "api_tool_provider_not_found"
|
||||
|
||||
# Assert: Existing provider should remain unchanged
|
||||
existing_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name)
|
||||
.first()
|
||||
)
|
||||
assert existing_provider is not None
|
||||
assert existing_provider.name == existing_provider_name
|
||||
|
||||
# Assert: No new provider should be created
|
||||
unexpected_new_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name)
|
||||
.first()
|
||||
)
|
||||
assert unexpected_new_provider is None
|
||||
|
||||
# Assert: Early failure should skip all downstream external interactions
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called()
|
||||
mock_external_service_dependencies["encrypter"].assert_not_called()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_not_called()
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user