Files
dify/api/services/agent/roster_service.py
zyssyz123 d9e90d0fa0 feat: add new agent (#36284)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-19 10:43:23 +00:00

321 lines
13 KiB
Python

from typing import Any
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from libs.datetime_utils import naive_utc_now
from models.agent import (
Agent,
AgentConfigRevision,
AgentConfigRevisionOperation,
AgentConfigSnapshot,
AgentKind,
AgentScope,
AgentSource,
AgentStatus,
WorkflowAgentNodeBinding,
)
from models.workflow import Workflow
from services.agent.composer_validator import ComposerConfigValidator
from services.agent.errors import (
AgentArchivedError,
AgentNameConflictError,
AgentNotFoundError,
AgentVersionNotFoundError,
)
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload
class AgentRosterService:
    """Service-layer operations on roster agents and their config snapshots."""

    def __init__(self, session: Any):
        """Remember the database session that later calls will use.

        Args:
            session: Session-like object. It is stored as-is; nothing is
                queried or written at construction time.
        """
        self._session = session
@staticmethod
def serialize_agent(agent: Agent, active_version: AgentConfigSnapshot | None = None) -> dict[str, Any]:
    """Flatten an agent into a dict of plain values.

    Enum-valued attributes (``agent_kind``, ``scope``, ``source``, ``status``
    and, when set, ``icon_type``) are emitted as their ``.value``. The three
    timestamps are emitted as ISO-8601 strings, or ``None`` when unset.

    Args:
        agent: The agent to serialize.
        active_version: Snapshot to embed under ``active_config_snapshot``.
            When it is missing (or falsy) that key is ``None``.

    Returns:
        The serialized agent; keys are inserted in a fixed order.
    """

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    out: dict[str, Any] = {}
    for field in ("id", "name", "description"):
        out[field] = getattr(agent, field)

    out["icon_type"] = agent.icon_type.value if agent.icon_type else None
    out["icon"] = agent.icon
    out["icon_background"] = agent.icon_background

    for enum_field in ("agent_kind", "scope", "source"):
        out[enum_field] = getattr(agent, enum_field).value

    for field in ("app_id", "workflow_id", "workflow_node_id", "active_config_snapshot_id"):
        out[field] = getattr(agent, field)

    if active_version:
        out["active_config_snapshot"] = AgentRosterService.serialize_version(active_version)
    else:
        out["active_config_snapshot"] = None

    out["status"] = agent.status.value

    for field in ("created_by", "updated_by", "archived_by"):
        out[field] = getattr(agent, field)
    for field in ("archived_at", "created_at", "updated_at"):
        out[field] = iso_or_none(getattr(agent, field))
    return out
@staticmethod
def serialize_version(version: AgentConfigSnapshot | None) -> dict[str, Any] | None:
    """Summarize a config snapshot as a dict of plain values.

    The snapshot's configuration body is not included; only identifying and
    audit fields are.

    Args:
        version: The snapshot to serialize, or ``None``.

    Returns:
        ``None`` when ``version`` is ``None``; otherwise a dict whose
        ``created_at`` is an ISO-8601 string (or ``None`` when unset).
    """
    if version is None:
        return None

    copied_fields = ("id", "agent_id", "version", "summary", "version_note", "created_by")
    out: dict[str, Any] = {field: getattr(version, field) for field in copied_fields}

    created_at = version.created_at
    out["created_at"] = created_at.isoformat() if created_at else None
    return out
def list_roster_agents(
    self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None
) -> dict[str, Any]:
    """Return one page of a tenant's active roster agents, most recently updated first.

    Args:
        tenant_id: Only agents of this tenant are listed.
        page: 1-based page number; values below 1 are treated as 1.
        limit: Page size; negative values are treated as 0.
        keyword: Optional case-insensitive substring filter on the agent
            name. It is passed through ``escape_like_pattern`` before being
            wrapped in ``%`` wildcards.

    Returns:
        A dict with ``data`` (serialized agents, each carrying its active
        snapshot summary when one exists), ``page``, ``limit``, ``total``
        (count of all matching agents) and ``has_more``.
    """
    # A negative OFFSET or LIMIT is rejected by some databases, so clamp
    # out-of-range paging input instead of forwarding it.
    page = max(page, 1)
    limit = max(limit, 0)

    stmt = select(Agent).where(
        Agent.tenant_id == tenant_id,
        Agent.scope == AgentScope.ROSTER,
        Agent.status == AgentStatus.ACTIVE,
    )
    if keyword:
        # Imported here rather than at module level, as in the original code.
        from libs.helper import escape_like_pattern

        escaped_keyword = escape_like_pattern(keyword)
        stmt = stmt.where(Agent.name.ilike(f"%{escaped_keyword}%", escape="\\"))

    # Count on the unordered statement: an ORDER BY inside a COUNT subquery
    # adds work without affecting the result.
    total = self._session.scalar(select(func.count()).select_from(stmt.subquery())) or 0

    # ``updated_at`` alone is not unique; the id tie-breaker keeps rows with
    # equal timestamps in a stable order so consecutive pages neither repeat
    # nor skip agents.
    ordered = stmt.order_by(Agent.updated_at.desc(), Agent.id.desc())
    agents = list(self._session.scalars(ordered.offset((page - 1) * limit).limit(limit)).all())

    # Load every referenced snapshot with a single query.
    versions_by_id = self._load_versions_by_id(
        [agent.active_config_snapshot_id for agent in agents if agent.active_config_snapshot_id]
    )
    # ``.get`` already yields None for agents without an active snapshot id.
    data = [
        self.serialize_agent(agent, versions_by_id.get(agent.active_config_snapshot_id)) for agent in agents
    ]
    return {
        "data": data,
        "page": page,
        "limit": limit,
        "total": total,
        "has_more": page * limit < total,
    }
def list_invite_options(
    self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None, app_id: str | None = None
) -> dict[str, Any]:
    """List roster agents, annotated with how each is used in an app's draft workflow.

    The page itself comes from ``list_roster_agents``. When ``app_id`` is
    given and that app has a draft workflow, node bindings of the listed
    agents in that workflow are looked up.

    Args:
        tenant_id: Tenant whose agents and workflow are considered.
        page: Forwarded to ``list_roster_agents``.
        limit: Forwarded to ``list_roster_agents``.
        keyword: Forwarded to ``list_roster_agents``.
        app_id: App whose draft workflow is inspected; without it every
            agent is reported as unused.

    Returns:
        The ``list_roster_agents`` result, with ``is_in_current_workflow``,
        ``in_current_workflow_count`` and ``existing_node_ids`` added to
        every item in ``data``.
    """
    result = self.list_roster_agents(tenant_id=tenant_id, page=page, limit=limit, keyword=keyword)
    items = result["data"]

    draft_workflow = None
    if app_id:
        draft_query = (
            select(Workflow)
            .where(
                Workflow.tenant_id == tenant_id,
                Workflow.app_id == app_id,
                Workflow.version == Workflow.VERSION_DRAFT,
            )
            .limit(1)
        )
        draft_workflow = self._session.scalar(draft_query)

    # Bindings are only worth querying when there is a draft and a non-empty page.
    listed_ids = [item["id"] for item in items] if draft_workflow else []

    node_ids_by_agent: dict[str, list[str]] = {}
    if listed_ids:
        binding_query = select(WorkflowAgentNodeBinding).where(
            WorkflowAgentNodeBinding.tenant_id == tenant_id,
            WorkflowAgentNodeBinding.workflow_id == draft_workflow.id,
            WorkflowAgentNodeBinding.agent_id.in_(listed_ids),
        )
        for binding in self._session.scalars(binding_query).all():
            if not binding.agent_id:
                continue
            node_ids_by_agent.setdefault(binding.agent_id, []).append(binding.node_id)

    for item in items:
        node_ids = node_ids_by_agent.get(item["id"], [])
        item.update(
            is_in_current_workflow=bool(node_ids),
            in_current_workflow_count=len(node_ids),
            existing_node_ids=node_ids,
        )
    return result
def create_roster_agent(
    self,
    *,
    tenant_id: str,
    account_id: str,
    payload: RosterAgentCreatePayload,
    source: AgentSource = AgentSource.AGENT_APP,
) -> Agent:
    """Create an active roster agent together with its first config snapshot.

    Three rows are added and committed together at the end: the agent,
    snapshot version 1 holding ``payload.agent_soul``, and revision 1
    recording a ``CREATE_VERSION`` operation. The new snapshot is set as
    the agent's active snapshot before the commit.

    Args:
        tenant_id: Tenant that owns the new agent.
        account_id: Account recorded as creator (and updater) of all rows.
        payload: Name, description, icon fields, ``agent_soul`` and
            ``version_note`` for the new agent.
        source: Stored on the agent as its ``source``.

    Returns:
        The committed agent.

    Raises:
        AgentNameConflictError: If flushing the agent row or the final
            commit raises ``IntegrityError``.
        IntegrityError: If flushing the snapshot row violates a constraint;
            the session is rolled back before the error propagates.

    Anything raised by ``ComposerConfigValidator.validate_agent_soul``
    propagates before any row is added.
    """
    ComposerConfigValidator.validate_agent_soul(payload.agent_soul)

    agent = Agent(
        tenant_id=tenant_id,
        name=payload.name,
        description=payload.description,
        icon_type=payload.icon_type,
        icon=payload.icon,
        icon_background=payload.icon_background,
        agent_kind=AgentKind.DIFY_AGENT,
        scope=AgentScope.ROSTER,
        source=source,
        status=AgentStatus.ACTIVE,
        created_by=account_id,
        updated_by=account_id,
    )
    self._session.add(agent)
    try:
        # Flush early: ``agent.id`` is needed by the rows created below.
        self._session.flush()
    except IntegrityError as exc:
        self._session.rollback()
        raise AgentNameConflictError() from exc

    version = AgentConfigSnapshot(
        tenant_id=tenant_id,
        agent_id=agent.id,
        version=1,
        config_snapshot=payload.agent_soul,
        version_note=payload.version_note,
        created_by=account_id,
    )
    self._session.add(version)
    try:
        # ``version.id`` is referenced by the revision and by the agent below.
        self._session.flush()
    except IntegrityError:
        # Only the snapshot row is pending at this flush, so this is not
        # reported as a name conflict. Roll back like the other write
        # points so the session is not left in a failed transaction,
        # then surface the original error.
        self._session.rollback()
        raise

    revision = AgentConfigRevision(
        tenant_id=tenant_id,
        agent_id=agent.id,
        current_snapshot_id=version.id,
        revision=1,
        operation=AgentConfigRevisionOperation.CREATE_VERSION,
        version_note=payload.version_note,
        created_by=account_id,
    )
    self._session.add(revision)
    agent.active_config_snapshot_id = version.id
    try:
        self._session.commit()
    except IntegrityError as exc:
        self._session.rollback()
        raise AgentNameConflictError() from exc
    return agent
def get_roster_agent_detail(self, *, tenant_id: str, agent_id: str) -> dict[str, Any]:
    """Return a roster agent serialized together with its active snapshot.

    Args:
        tenant_id: Tenant the agent must belong to.
        agent_id: Id of the roster agent.

    Returns:
        The ``serialize_agent`` output for the agent and its active snapshot.

    Errors raised by ``_get_agent`` and ``_get_version`` propagate
    unchanged, including the one for an agent with no active snapshot id.
    """
    roster_agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
    snapshot_id = roster_agent.active_config_snapshot_id
    snapshot = self._get_version(tenant_id=tenant_id, agent_id=roster_agent.id, version_id=snapshot_id)
    return self.serialize_agent(roster_agent, snapshot)
def update_roster_agent(
    self, *, tenant_id: str, agent_id: str, account_id: str, payload: RosterAgentUpdatePayload
) -> dict[str, Any]:
    """Apply the explicitly-set payload fields to a roster agent and return its detail.

    Args:
        tenant_id: Tenant the agent must belong to.
        agent_id: Id of the roster agent to update.
        account_id: Account recorded as ``updated_by``.
        payload: Update payload; only fields the caller actually set are applied.

    Returns:
        The refreshed agent detail, as produced by ``get_roster_agent_detail``.

    Raises:
        AgentArchivedError: If the agent is archived.
        AgentNameConflictError: If the commit raises ``IntegrityError``.
    """
    target = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
    if target.status == AgentStatus.ARCHIVED:
        raise AgentArchivedError()

    # NOTE(review): every dumped field is written onto the agent with no
    # allow-list, so safety relies on RosterAgentUpdatePayload exposing only
    # updatable attributes - confirm against the payload model.
    changes = payload.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])
    target.updated_by = account_id

    try:
        self._session.commit()
    except IntegrityError as exc:
        self._session.rollback()
        raise AgentNameConflictError() from exc

    return self.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent_id)
def archive_roster_agent(self, *, tenant_id: str, agent_id: str, account_id: str) -> None:
    """Mark a roster agent as archived; archiving an already-archived agent does nothing.

    Args:
        tenant_id: Tenant the agent must belong to.
        agent_id: Id of the roster agent to archive.
        account_id: Account recorded as ``archived_by`` and ``updated_by``.
    """
    target = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)

    already_archived = target.status == AgentStatus.ARCHIVED
    if already_archived:
        # Leave the original archive stamp untouched and skip the commit.
        return

    target.status = AgentStatus.ARCHIVED
    target.archived_at = naive_utc_now()
    target.archived_by = account_id
    target.updated_by = account_id
    self._session.commit()
def list_agent_versions(self, *, tenant_id: str, agent_id: str) -> list[dict[str, Any]]:
    """Return summaries of every config snapshot of a roster agent, highest version first.

    Args:
        tenant_id: Tenant the agent and snapshots must belong to.
        agent_id: Id of the roster agent.

    Returns:
        One ``serialize_version`` dict per snapshot.
    """
    # Called only for its check: it raises when the roster agent does not exist.
    self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)

    query = (
        select(AgentConfigSnapshot)
        .where(AgentConfigSnapshot.tenant_id == tenant_id, AgentConfigSnapshot.agent_id == agent_id)
        .order_by(AgentConfigSnapshot.version.desc())
    )
    snapshots = list(self._session.scalars(query).all())

    summaries: list[dict[str, Any]] = []
    for snapshot in snapshots:
        summary = self.serialize_version(snapshot)
        if summary is not None:
            summaries.append(summary)
    return summaries
def get_agent_version_detail(self, *, tenant_id: str, agent_id: str, version_id: str) -> dict[str, Any]:
    """Return one snapshot's summary, its configuration and the revisions that point at it.

    Args:
        tenant_id: Tenant the agent and snapshot must belong to.
        agent_id: Id of the roster agent.
        version_id: Id of the snapshot to describe.

    Returns:
        The ``serialize_version`` dict extended with ``config_snapshot``
        (the snapshot's ``config_snapshot_dict``) and ``revisions`` (the
        revisions whose current snapshot is this one, highest revision first).
    """

    def describe(revision) -> dict[str, Any]:
        created_at = revision.created_at
        return {
            "id": revision.id,
            "previous_snapshot_id": revision.previous_snapshot_id,
            "current_snapshot_id": revision.current_snapshot_id,
            "revision": revision.revision,
            "operation": revision.operation.value,
            "summary": revision.summary,
            "version_note": revision.version_note,
            "created_by": revision.created_by,
            "created_at": created_at.isoformat() if created_at else None,
        }

    # Called only for its check: it raises when the roster agent does not exist.
    self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
    version = self._get_version(tenant_id=tenant_id, agent_id=agent_id, version_id=version_id)

    revision_query = (
        select(AgentConfigRevision)
        .where(
            AgentConfigRevision.tenant_id == tenant_id,
            AgentConfigRevision.agent_id == agent_id,
            AgentConfigRevision.current_snapshot_id == version_id,
        )
        .order_by(AgentConfigRevision.revision.desc())
    )
    revisions = list(self._session.scalars(revision_query).all())

    detail = self.serialize_version(version) or {}
    detail["config_snapshot"] = version.config_snapshot_dict
    detail["revisions"] = [describe(revision) for revision in revisions]
    return detail
def _get_agent(self, *, tenant_id: str, agent_id: str, roster_only: bool = False) -> Agent:
    """Fetch one agent of a tenant by id.

    Args:
        tenant_id: Tenant the agent must belong to.
        agent_id: Id of the agent.
        roster_only: When true, the agent must also have roster scope.

    Returns:
        The matching agent.

    Raises:
        AgentNotFoundError: If no agent matches.
    """
    conditions = [Agent.tenant_id == tenant_id, Agent.id == agent_id]
    if roster_only:
        conditions.append(Agent.scope == AgentScope.ROSTER)

    found = self._session.scalar(select(Agent).where(*conditions).limit(1))
    if found:
        return found
    raise AgentNotFoundError()
def _get_version(self, *, tenant_id: str, agent_id: str, version_id: str | None) -> AgentConfigSnapshot:
    """Fetch one config snapshot of an agent by id.

    Args:
        tenant_id: Tenant the snapshot must belong to.
        agent_id: Agent the snapshot must belong to.
        version_id: Id of the snapshot; ``None`` or empty counts as not found.

    Returns:
        The matching snapshot.

    Raises:
        AgentVersionNotFoundError: If ``version_id`` is empty or nothing matches.
    """
    if version_id:
        query = (
            select(AgentConfigSnapshot)
            .where(
                AgentConfigSnapshot.tenant_id == tenant_id,
                AgentConfigSnapshot.agent_id == agent_id,
                AgentConfigSnapshot.id == version_id,
            )
            .limit(1)
        )
        found = self._session.scalar(query)
        if found:
            return found
    raise AgentVersionNotFoundError()
def _load_versions_by_id(self, version_ids: list[str]) -> dict[str, AgentConfigSnapshot]:
if not version_ids:
return {}
versions = self._session.scalars(
select(AgentConfigSnapshot).where(AgentConfigSnapshot.id.in_(version_ids))
).all()
return {version.id: version for version in versions}