mirror of
https://github.com/langgenius/dify.git
synced 2026-05-25 19:00:43 -04:00
321 lines
13 KiB
Python
321 lines
13 KiB
Python
from typing import Any
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from libs.datetime_utils import naive_utc_now
|
|
from models.agent import (
|
|
Agent,
|
|
AgentConfigRevision,
|
|
AgentConfigRevisionOperation,
|
|
AgentConfigSnapshot,
|
|
AgentKind,
|
|
AgentScope,
|
|
AgentSource,
|
|
AgentStatus,
|
|
WorkflowAgentNodeBinding,
|
|
)
|
|
from models.workflow import Workflow
|
|
from services.agent.composer_validator import ComposerConfigValidator
|
|
from services.agent.errors import (
|
|
AgentArchivedError,
|
|
AgentNameConflictError,
|
|
AgentNotFoundError,
|
|
AgentVersionNotFoundError,
|
|
)
|
|
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload
|
|
|
|
|
|
class AgentRosterService:
    """Tenant-scoped listing, CRUD and version lookup for roster agents.

    All queries go through the session passed to the constructor. The mutating
    methods (``create_roster_agent``, ``update_roster_agent``,
    ``archive_roster_agent``) commit that session themselves.
    """

    def __init__(self, session: Any):
        # NOTE(review): typed as Any; used like a SQLAlchemy Session
        # (scalar/scalars/add/flush/commit/rollback) — confirm the concrete type.
        self._session = session

    @staticmethod
    def serialize_agent(agent: Agent, active_version: AgentConfigSnapshot | None = None) -> dict[str, Any]:
        """Convert an ``Agent`` row into a JSON-friendly dict.

        Enum columns are emitted as their ``.value``. Datetimes are emitted as
        ISO-8601 strings, or ``None`` when unset. When ``active_version`` is
        given, it is embedded under ``active_config_snapshot`` via
        :meth:`serialize_version`; otherwise that key is ``None``.
        """
        return {
            "id": agent.id,
            "name": agent.name,
            "description": agent.description,
            "icon_type": agent.icon_type.value if agent.icon_type else None,
            "icon": agent.icon,
            "icon_background": agent.icon_background,
            "agent_kind": agent.agent_kind.value,
            "scope": agent.scope.value,
            "source": agent.source.value,
            "app_id": agent.app_id,
            "workflow_id": agent.workflow_id,
            "workflow_node_id": agent.workflow_node_id,
            "active_config_snapshot_id": agent.active_config_snapshot_id,
            "active_config_snapshot": AgentRosterService.serialize_version(active_version) if active_version else None,
            "status": agent.status.value,
            "created_by": agent.created_by,
            "updated_by": agent.updated_by,
            "archived_by": agent.archived_by,
            "archived_at": agent.archived_at.isoformat() if agent.archived_at else None,
            "created_at": agent.created_at.isoformat() if agent.created_at else None,
            "updated_at": agent.updated_at.isoformat() if agent.updated_at else None,
        }

    @staticmethod
    def serialize_version(version: AgentConfigSnapshot | None) -> dict[str, Any] | None:
        """Convert a config snapshot into a summary dict, or return ``None`` for ``None``.

        The snapshot payload itself is not included; see
        :meth:`get_agent_version_detail` for that.
        """
        if version is None:
            return None
        return {
            "id": version.id,
            "agent_id": version.agent_id,
            "version": version.version,
            "summary": version.summary,
            "version_note": version.version_note,
            "created_by": version.created_by,
            "created_at": version.created_at.isoformat() if version.created_at else None,
        }

    def list_roster_agents(
        self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None
    ) -> dict[str, Any]:
        """Return one page of a tenant's active roster agents, most recently updated first.

        Args:
            tenant_id: Tenant whose agents are listed.
            page: 1-based page number. Values below 1 are treated as 1.
            limit: Page size. Values below 1 are treated as 1.
            keyword: Optional case-insensitive substring filter on the agent name.
                LIKE wildcards inside the keyword are escaped.

        Returns:
            A dict with ``data`` (:meth:`serialize_agent` dicts including each
            agent's active snapshot), ``page``, ``limit`` (the effective values
            after clamping), ``total`` and ``has_more``.
        """
        # A page below 1 would yield a negative OFFSET (rejected by PostgreSQL).
        # A non-positive limit would return nothing while reporting has_more forever.
        page = max(page, 1)
        limit = max(limit, 1)

        filtered = select(Agent).where(
            Agent.tenant_id == tenant_id,
            Agent.scope == AgentScope.ROSTER,
            Agent.status == AgentStatus.ACTIVE,
        )
        if keyword:
            # NOTE(review): function-scope import kept from the original;
            # presumably avoids an import cycle — confirm before hoisting.
            from libs.helper import escape_like_pattern

            escaped_keyword = escape_like_pattern(keyword)
            filtered = filtered.where(Agent.name.ilike(f"%{escaped_keyword}%", escape="\\"))

        # Count the unordered statement: an ORDER BY inside a COUNT subquery is wasted work.
        total = self._session.scalar(select(func.count()).select_from(filtered.subquery())) or 0

        # ``id`` breaks ties between equal ``updated_at`` values so the order is
        # total. Otherwise rows could repeat or go missing across pages.
        page_stmt = (
            filtered.order_by(Agent.updated_at.desc(), Agent.id.desc())
            .offset((page - 1) * limit)
            .limit(limit)
        )
        agents = list(self._session.scalars(page_stmt).all())

        # Load all referenced active snapshots in one query instead of one per agent.
        versions_by_id = self._load_versions_by_id(
            [agent.active_config_snapshot_id for agent in agents if agent.active_config_snapshot_id]
        )

        data = []
        for agent in agents:
            active_version = (
                versions_by_id.get(agent.active_config_snapshot_id) if agent.active_config_snapshot_id else None
            )
            data.append(self.serialize_agent(agent, active_version))

        return {
            "data": data,
            "page": page,
            "limit": limit,
            "total": total,
            "has_more": page * limit < total,
        }

    def list_invite_options(
        self, *, tenant_id: str, page: int = 1, limit: int = 20, keyword: str | None = None, app_id: str | None = None
    ) -> dict[str, Any]:
        """List roster agents annotated with their usage in an app's draft workflow.

        The page is the one returned by :meth:`list_roster_agents`. Every item
        additionally gets these keys:
        ``is_in_current_workflow``, ``in_current_workflow_count`` and
        ``existing_node_ids``. Usage is counted only when ``app_id`` is given and
        the app has a draft workflow; otherwise the values are ``False``/``0``/``[]``.
        """
        result = self.list_roster_agents(tenant_id=tenant_id, page=page, limit=limit, keyword=keyword)
        usage_by_agent_id: dict[str, list[str]] = {}
        if app_id:
            draft_workflow = self._session.scalar(
                select(Workflow)
                .where(
                    Workflow.tenant_id == tenant_id,
                    Workflow.app_id == app_id,
                    Workflow.version == Workflow.VERSION_DRAFT,
                )
                .limit(1)
            )
            if draft_workflow:
                agent_ids = [item["id"] for item in result["data"]]
                if agent_ids:
                    # Only look up bindings for agents on the current page.
                    bindings = self._session.scalars(
                        select(WorkflowAgentNodeBinding).where(
                            WorkflowAgentNodeBinding.tenant_id == tenant_id,
                            WorkflowAgentNodeBinding.workflow_id == draft_workflow.id,
                            WorkflowAgentNodeBinding.agent_id.in_(agent_ids),
                        )
                    ).all()
                    for binding in bindings:
                        if binding.agent_id:
                            usage_by_agent_id.setdefault(binding.agent_id, []).append(binding.node_id)

        for item in result["data"]:
            existing_node_ids = usage_by_agent_id.get(item["id"], [])
            item["is_in_current_workflow"] = bool(existing_node_ids)
            item["in_current_workflow_count"] = len(existing_node_ids)
            item["existing_node_ids"] = existing_node_ids
        return result

    def create_roster_agent(
        self,
        *,
        tenant_id: str,
        account_id: str,
        payload: RosterAgentCreatePayload,
        source: AgentSource = AgentSource.AGENT_APP,
    ) -> Agent:
        """Create an active roster agent with its first config snapshot and revision, then commit.

        The new snapshot is ``version=1`` built from ``payload.agent_soul``.
        The new revision is ``revision=1`` with the ``CREATE_VERSION``
        operation. The snapshot becomes the agent's active snapshot.

        Raises:
            AgentNameConflictError: if an ``IntegrityError`` occurs while flushing
                the new agent or while committing. The session is rolled back first.
                NOTE(review): any integrity error at those points is reported as a
                name conflict — confirm that name uniqueness is the only constraint
                that can fire there.
            Whatever ``ComposerConfigValidator.validate_agent_soul`` raises for an
            invalid ``agent_soul``.
        """
        ComposerConfigValidator.validate_agent_soul(payload.agent_soul)

        agent = Agent(
            tenant_id=tenant_id,
            name=payload.name,
            description=payload.description,
            icon_type=payload.icon_type,
            icon=payload.icon,
            icon_background=payload.icon_background,
            agent_kind=AgentKind.DIFY_AGENT,
            scope=AgentScope.ROSTER,
            source=source,
            status=AgentStatus.ACTIVE,
            created_by=account_id,
            updated_by=account_id,
        )
        self._session.add(agent)
        # Flush now so constraint violations surface here and agent.id is
        # available for the snapshot and revision rows below.
        try:
            self._session.flush()
        except IntegrityError as exc:
            self._session.rollback()
            raise AgentNameConflictError() from exc

        version = AgentConfigSnapshot(
            tenant_id=tenant_id,
            agent_id=agent.id,
            version=1,
            config_snapshot=payload.agent_soul,
            version_note=payload.version_note,
            created_by=account_id,
        )
        self._session.add(version)
        # Flush so version.id is available for the revision and the active pointer.
        self._session.flush()

        revision = AgentConfigRevision(
            tenant_id=tenant_id,
            agent_id=agent.id,
            current_snapshot_id=version.id,
            revision=1,
            operation=AgentConfigRevisionOperation.CREATE_VERSION,
            version_note=payload.version_note,
            created_by=account_id,
        )
        self._session.add(revision)
        agent.active_config_snapshot_id = version.id

        try:
            self._session.commit()
        except IntegrityError as exc:
            self._session.rollback()
            raise AgentNameConflictError() from exc
        return agent

    def get_roster_agent_detail(self, *, tenant_id: str, agent_id: str) -> dict[str, Any]:
        """Return a serialized roster agent (archived or not) with its active snapshot.

        Raises:
            AgentNotFoundError: if no roster agent with this id exists for the tenant.
            AgentVersionNotFoundError: if the agent has no active snapshot id, or
                that snapshot cannot be found.
        """
        agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
        active_version = self._get_version(
            tenant_id=tenant_id, agent_id=agent.id, version_id=agent.active_config_snapshot_id
        )
        return self.serialize_agent(agent, active_version)

    def update_roster_agent(
        self, *, tenant_id: str, agent_id: str, account_id: str, payload: RosterAgentUpdatePayload
    ) -> dict[str, Any]:
        """Apply the explicitly-set fields of ``payload`` to a roster agent and commit.

        Only fields present in ``payload.model_dump(exclude_unset=True)`` are
        written. Each one is assigned onto the agent with ``setattr``.

        Returns:
            The refreshed agent detail, as produced by :meth:`get_roster_agent_detail`.

        Raises:
            AgentNotFoundError: if the roster agent does not exist for the tenant.
            AgentArchivedError: if the agent is archived.
            AgentNameConflictError: if the commit raises ``IntegrityError``; the
                session is rolled back first.
        """
        agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
        if agent.status == AgentStatus.ARCHIVED:
            raise AgentArchivedError()

        # NOTE(review): field names map 1:1 onto Agent attributes via setattr —
        # confirm that RosterAgentUpdatePayload only declares Agent column fields.
        update_data = payload.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(agent, key, value)
        agent.updated_by = account_id

        try:
            self._session.commit()
        except IntegrityError as exc:
            self._session.rollback()
            raise AgentNameConflictError() from exc
        return self.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent_id)

    def archive_roster_agent(self, *, tenant_id: str, agent_id: str, account_id: str) -> None:
        """Mark a roster agent as archived and commit. Archiving an already-archived agent is a no-op.

        Raises:
            AgentNotFoundError: if the roster agent does not exist for the tenant.
        """
        agent = self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
        if agent.status == AgentStatus.ARCHIVED:
            return
        agent.status = AgentStatus.ARCHIVED
        agent.archived_by = account_id
        agent.archived_at = naive_utc_now()
        agent.updated_by = account_id
        self._session.commit()

    def list_agent_versions(self, *, tenant_id: str, agent_id: str) -> list[dict[str, Any]]:
        """Return summaries of every config snapshot of a roster agent, highest version first.

        Raises:
            AgentNotFoundError: if the roster agent does not exist for the tenant.
        """
        self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
        versions = list(
            self._session.scalars(
                select(AgentConfigSnapshot)
                .where(AgentConfigSnapshot.tenant_id == tenant_id, AgentConfigSnapshot.agent_id == agent_id)
                .order_by(AgentConfigSnapshot.version.desc())
            ).all()
        )
        # serialize_version only returns None for a None input. The filter
        # narrows the element type to dict for type checkers.
        return [
            serialized_version
            for version in versions
            if (serialized_version := self.serialize_version(version)) is not None
        ]

    def get_agent_version_detail(self, *, tenant_id: str, agent_id: str, version_id: str) -> dict[str, Any]:
        """Return one snapshot of a roster agent, including its full payload and revision history.

        The result is the :meth:`serialize_version` dict with two extra keys.
        ``config_snapshot`` holds the snapshot's ``config_snapshot_dict``.
        ``revisions`` lists every revision whose ``current_snapshot_id`` is this
        snapshot, highest revision first.

        Raises:
            AgentNotFoundError: if the roster agent does not exist for the tenant.
            AgentVersionNotFoundError: if the snapshot does not belong to that agent.
        """
        self._get_agent(tenant_id=tenant_id, agent_id=agent_id, roster_only=True)
        version = self._get_version(tenant_id=tenant_id, agent_id=agent_id, version_id=version_id)
        revisions = list(
            self._session.scalars(
                select(AgentConfigRevision)
                .where(
                    AgentConfigRevision.tenant_id == tenant_id,
                    AgentConfigRevision.agent_id == agent_id,
                    AgentConfigRevision.current_snapshot_id == version_id,
                )
                .order_by(AgentConfigRevision.revision.desc())
            ).all()
        )
        result = self.serialize_version(version) or {}
        result["config_snapshot"] = version.config_snapshot_dict
        result["revisions"] = [
            {
                "id": revision.id,
                "previous_snapshot_id": revision.previous_snapshot_id,
                "current_snapshot_id": revision.current_snapshot_id,
                "revision": revision.revision,
                "operation": revision.operation.value,
                "summary": revision.summary,
                "version_note": revision.version_note,
                "created_by": revision.created_by,
                "created_at": revision.created_at.isoformat() if revision.created_at else None,
            }
            for revision in revisions
        ]
        return result

    def _get_agent(self, *, tenant_id: str, agent_id: str, roster_only: bool = False) -> Agent:
        """Fetch an agent of the tenant by id, regardless of status.

        When ``roster_only`` is true, only agents with ``AgentScope.ROSTER``
        match. Raises ``AgentNotFoundError`` when nothing matches.
        """
        stmt = select(Agent).where(Agent.tenant_id == tenant_id, Agent.id == agent_id)
        if roster_only:
            stmt = stmt.where(Agent.scope == AgentScope.ROSTER)
        agent = self._session.scalar(stmt.limit(1))
        if not agent:
            raise AgentNotFoundError()
        return agent

    def _get_version(self, *, tenant_id: str, agent_id: str, version_id: str | None) -> AgentConfigSnapshot:
        """Fetch a snapshot by id, scoped to the tenant and agent.

        Raises ``AgentVersionNotFoundError`` when ``version_id`` is empty/None
        or no matching snapshot exists.
        """
        if not version_id:
            raise AgentVersionNotFoundError()
        version = self._session.scalar(
            select(AgentConfigSnapshot)
            .where(
                AgentConfigSnapshot.tenant_id == tenant_id,
                AgentConfigSnapshot.agent_id == agent_id,
                AgentConfigSnapshot.id == version_id,
            )
            .limit(1)
        )
        if not version:
            raise AgentVersionNotFoundError()
        return version

    def _load_versions_by_id(self, version_ids: list[str]) -> dict[str, AgentConfigSnapshot]:
        """Bulk-load snapshots by id in one query and index them by id. Returns {} for no ids.

        NOTE(review): not tenant-filtered. Callers pass ids taken from
        tenant-filtered agents.
        """
        if not version_ids:
            return {}
        versions = self._session.scalars(
            select(AgentConfigSnapshot).where(AgentConfigSnapshot.id.in_(version_ids))
        ).all()
        return {version.id: version for version in versions}
|