import json

import httpx
import yaml
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService

# Endpoint of the external node-prediction service.
# NOTE(review): this is a hard-coded internal host name; it probably belongs in
# configuration rather than in source -- confirm with the service owner.
_DSL_PREDICT_URL = "http://spark-832c:8000/predict"


class DSLPredictRequest(BaseModel):
    """Request body accepted by ``POST /workspaces/current/dsl/predict``."""

    # ID of the app whose draft workflow is sent to the prediction service.
    app_id: str
    # ID of the workflow node the prediction should start from.
    current_node_id: str


@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
    """Console endpoint that asks the prediction service for follow-up nodes."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Export the app's DSL and forward it to the prediction service.

        The JSON request body is validated against ``DSLPredictRequest``.

        Returns:
            A dict ``{"nodes": ...}`` holding the decoded prediction payload.

        Raises:
            Forbidden: the current user is not an admin or owner.
            ValueError: the app or its draft workflow does not exist in the
                current workspace, or a ``PluginPermissionDeniedError`` was
                raised while building the request.
            httpx.HTTPStatusError: the prediction service answered with a
                4xx/5xx status.
        """
        user, current_tenant_id = current_account_with_tenant()
        if not user.is_admin_or_owner:
            raise Forbidden()

        args = DSLPredictRequest.model_validate(request.get_json())
        app_id: str = args.app_id
        current_node_id: str = args.current_node_id

        with Session(db.engine) as session:
            # Scope both lookups to the caller's workspace so an app_id that
            # belongs to another tenant is reported as missing instead of
            # being exported.
            app = session.query(App).filter_by(id=app_id, tenant_id=current_tenant_id).first()
            if not app:
                raise ValueError("App not found")

            workflow = (
                session.query(Workflow)
                .filter_by(
                    app_id=app_id,
                    tenant_id=current_tenant_id,
                    version=Workflow.VERSION_DRAFT,
                )
                .first()
            )
            if not workflow:
                raise ValueError("Workflow not found")

        try:
            # Position of the current node in walk order.
            # NOTE(review): when current_node_id is not part of the draft, the
            # index ends up equal to the number of walked nodes; this is kept
            # as-is -- confirm the prediction service expects that value.
            source_node_index = 0
            for node_id, _node_data in workflow.walk_nodes():
                if node_id == current_node_id:
                    break
                source_node_index += 1

            dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))

            response = httpx.post(
                _DSL_PREDICT_URL,
                json={"graph_data": dsl, "source_node_index": source_node_index},
            )
            # Fail loudly on 4xx/5xx instead of trying to decode an error body
            # as if it were a prediction.
            response.raise_for_status()

            return {
                # NOTE(review): the body is decoded twice, which presumes the
                # service replies with a JSON-encoded *string* that itself
                # contains JSON -- confirm against the service contract.
                "nodes": json.loads(response.json()),
            }
        except PluginPermissionDeniedError as e:
            raise ValueError(e.description) from e