chore(api): upgrade graphon to v0.3.0 (#35469)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
@@ -1,5 +1,15 @@
 """LLM-related application services."""

-from .quota import deduct_llm_quota, ensure_llm_quota_available
+from .quota import (
+    deduct_llm_quota,
+    deduct_llm_quota_for_model,
+    ensure_llm_quota_available,
+    ensure_llm_quota_available_for_model,
+)

-__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
+__all__ = [
+    "deduct_llm_quota",
+    "deduct_llm_quota_for_model",
+    "ensure_llm_quota_available",
+    "ensure_llm_quota_available_for_model",
+]
@@ -1,4 +1,14 @@
-from sqlalchemy import update
+"""Tenant-scoped helpers for checking and deducting LLM provider quota.
+
+System-hosted quota accounting is currently defined only for LLM models. Keep
+the public helpers LLM-specific so callers do not carry unused model-type
+plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
+with a non-LLM model.
+"""
+
+import warnings
+
+from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker

 from configs import dify_config
@@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
+from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
 from extensions.ext_database import db
 from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.datetime_utils import naive_utc_now
 from models.provider import Provider, ProviderType
 from models.provider_ids import ModelProviderID


-def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
-    provider_model_bundle = model_instance.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
+def _get_provider_configuration(*, tenant_id: str, provider: str):
+    """Resolve the tenant-bound provider configuration for quota decisions."""
+    provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
+    provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
+    if provider_configuration is None:
+        raise ValueError(f"Provider {provider} does not exist.")
+    return provider_configuration
+
+
+def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
+    """Raise when a tenant-bound LLM model is already out of quota."""
+    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
     if provider_configuration.using_provider_type != ProviderType.SYSTEM:
         return

     provider_model = provider_configuration.get_provider_model(
-        model_type=model_instance.model_type_instance.model_type,
-        model=model_instance.model_name,
+        model_type=ModelType.LLM,
+        model=model,
     )
     if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-        raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
+        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")


-def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-    provider_model_bundle = model_instance.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
-
-    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-        return
-
-    system_configuration = provider_configuration.system_configuration
-
+def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
+    """Compute the quota impact for an LLM invocation under the current quota mode."""
     quota_unit = None
     for quota_configuration in system_configuration.quota_configurations:
         if quota_configuration.quota_type == system_configuration.current_quota_type:
             quota_unit = quota_configuration.quota_unit

             if quota_configuration.quota_limit == -1:
-                return
+                return None

             break
@@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
     if quota_unit == QuotaUnit.TOKENS:
         used_quota = usage.total_tokens
     elif quota_unit == QuotaUnit.CREDITS:
-        used_quota = dify_config.get_model_credits(model_instance.model_name)
+        used_quota = dify_config.get_model_credits(model)
     else:
         used_quota = 1

+    return used_quota
+
+
+def _deduct_free_llm_quota(
+    *,
+    tenant_id: str,
+    provider: str,
+    quota_type: ProviderQuotaType,
+    used_quota: int,
+) -> None:
+    """Deduct FREE provider quota, capping at the limit before reporting exhaustion."""
+    quota_exceeded = False
+    with sessionmaker(bind=db.engine).begin() as session:
+        provider_record = session.scalar(
+            select(Provider)
+            .where(
+                Provider.tenant_id == tenant_id,
+                # TODO: Use provider name with prefix after the data migration.
+                Provider.provider_name == ModelProviderID(provider).provider_name,
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == quota_type,
+            )
+            .with_for_update()
+        )
+        if (
+            provider_record is None
+            or provider_record.quota_limit is None
+            or provider_record.quota_used is None
+            or provider_record.quota_limit <= provider_record.quota_used
+        ):
+            quota_exceeded = True
+        else:
+            available_quota = provider_record.quota_limit - provider_record.quota_used
+            deducted_quota = min(used_quota, available_quota)
+            provider_record.quota_used += deducted_quota
+            provider_record.last_used = naive_utc_now()
+            quota_exceeded = deducted_quota < used_quota
+
+    if quota_exceeded:
+        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
+
+
+def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
+    """Apply a resolved LLM quota charge against the current provider quota bucket."""
+    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+        return
+
+    system_configuration = provider_configuration.system_configuration
     if used_quota is not None and system_configuration.current_quota_type is not None:
         match system_configuration.current_quota_type:
             case ProviderQuotaType.TRIAL:
                 from services.credit_pool_service import CreditPoolService

-                CreditPoolService.check_and_deduct_credits(
+                CreditPoolService.deduct_credits_capped(
                     tenant_id=tenant_id,
                     credits_required=used_quota,
                 )
             case ProviderQuotaType.PAID:
                 from services.credit_pool_service import CreditPoolService

-                CreditPoolService.check_and_deduct_credits(
+                CreditPoolService.deduct_credits_capped(
                     tenant_id=tenant_id,
                     credits_required=used_quota,
                     pool_type="paid",
                 )
             case ProviderQuotaType.FREE:
-                with sessionmaker(bind=db.engine).begin() as session:
-                    stmt = (
-                        update(Provider)
-                        .where(
-                            Provider.tenant_id == tenant_id,
-                            # TODO: Use provider name with prefix after the data migration.
-                            Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                            Provider.provider_type == ProviderType.SYSTEM.value,
-                            Provider.quota_type == system_configuration.current_quota_type,
-                            Provider.quota_limit > Provider.quota_used,
-                        )
-                        .values(
-                            quota_used=Provider.quota_used + used_quota,
-                            last_used=naive_utc_now(),
-                        )
-                    )
-                    session.execute(stmt)
+                _deduct_free_llm_quota(
+                    tenant_id=tenant_id,
+                    provider=provider,
+                    quota_type=system_configuration.current_quota_type,
+                    used_quota=used_quota,
+                )
             case _:
                 return


+def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
+    """Deduct tenant-bound quota for the resolved LLM model identity."""
+    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
+    used_quota = _resolve_llm_used_quota(
+        system_configuration=provider_configuration.system_configuration,
+        model=model,
+        usage=usage,
+    )
+    _deduct_used_llm_quota(
+        tenant_id=tenant_id,
+        provider=provider,
+        provider_configuration=provider_configuration,
+        used_quota=used_quota,
+    )
+
+
+def _require_llm_model_instance(model_instance: ModelInstance) -> None:
+    """Reject deprecated wrapper calls that pass a non-LLM model instance."""
+    if model_instance.model_type_instance.model_type != ModelType.LLM:
+        raise ValueError("LLM quota helpers only support LLM model instances.")
+
+
+def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
+    """Deprecated compatibility wrapper for callers that still pass ModelInstance."""
+    warnings.warn(
+        "ensure_llm_quota_available(model_instance=...) is deprecated; "
+        "use ensure_llm_quota_available_for_model(...) instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    _require_llm_model_instance(model_instance)
+    ensure_llm_quota_available_for_model(
+        tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
+        provider=model_instance.provider,
+        model=model_instance.model_name,
+    )
+
+
+def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+    """Deprecated compatibility wrapper for callers that still pass ModelInstance."""
+    warnings.warn(
+        "deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
+        "use deduct_llm_quota_for_model(...) instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    _require_llm_model_instance(model_instance)
+    deduct_llm_quota_for_model(
+        tenant_id=tenant_id,
+        provider=model_instance.provider,
+        model=model_instance.model_name,
+        usage=usage,
+    )
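
For reference, a minimal caller-side sketch of the migration this file implies. The helper names and keyword signatures come from the diff above; `run_llm` and its `usage` attribute are hypothetical stand-ins for a real LLM invocation.

    # Hedged sketch: the old wrappers took a ModelInstance; the new helpers
    # take the already-resolved tenant/provider/model identity.
    from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model


    def invoke_with_quota(tenant_id: str, provider: str, model: str, run_llm):
        # Fail fast when the tenant's system-hosted quota is already exhausted.
        ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider=provider, model=model)
        result = run_llm()  # hypothetical call returning an object with an LLMUsage `.usage`
        # Charge the same identity afterwards; FREE-pool deduction caps at the
        # limit and raises QuotaExceededError once the pool is drained.
        deduct_llm_quota_for_model(tenant_id=tenant_id, provider=provider, model=model, usage=result.usage)
        return result
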
@@ -1,36 +1,48 @@
 """
 LLM quota deduction layer for GraphEngine.

-This layer centralizes model-quota deduction outside node implementations.
+This layer centralizes model-quota handling outside node implementations.
+
+Graphon LLM-backed nodes expose provider/model identity through public node
+configuration and, after execution, through ``node_run_result.inputs``. Resolve
+quota billing from that public identity instead of depending on
+``ModelInstance`` reconstruction inside the workflow layer. Missing identity on
+quota-tracked nodes is treated as a workflow bug and aborts execution so quota
+handling is never silently skipped.
 """

 import logging
-from typing import TYPE_CHECKING, cast, final, override
+from typing import final, override

-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
-from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
+from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
 from core.errors.error import QuotaExceededError
-from core.model_manager import ModelInstance
-from graphon.enums import BuiltinNodeTypes
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.graph_engine.entities.commands import AbortCommand, CommandType
 from graphon.graph_engine.layers import GraphEngineLayer
 from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
+from graphon.node_events import NodeRunResult
 from graphon.nodes.base.node import Node

-if TYPE_CHECKING:
-    from graphon.nodes.llm.node import LLMNode
-    from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-    from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
-
 logger = logging.getLogger(__name__)
+_QUOTA_NODE_TYPES = frozenset(
+    [
+        BuiltinNodeTypes.LLM,
+        BuiltinNodeTypes.PARAMETER_EXTRACTOR,
+        BuiltinNodeTypes.QUESTION_CLASSIFIER,
+    ]
+)


 @final
 class LLMQuotaLayer(GraphEngineLayer):
-    """Graph layer that applies LLM quota deduction after node execution."""
+    """Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""

-    def __init__(self) -> None:
+    tenant_id: str
+    _abort_sent: bool
+
+    def __init__(self, tenant_id: str) -> None:
         super().__init__()
+        self.tenant_id = tenant_id
         self._abort_sent = False

     @override
@@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer):
         if self._abort_sent:
             return

-        model_instance = self._extract_model_instance(node)
-        if model_instance is None:
+        if not self._supports_quota(node):
+            return
+
+        model_identity = self._extract_model_identity_from_node(node)
+        if model_identity is None:
+            reason = "LLM quota check requires public node model identity before execution."
+            self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
+            logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
             return

+        provider, model_name = model_identity
         try:
-            ensure_llm_quota_available(model_instance=model_instance)
+            ensure_llm_quota_available_for_model(
+                tenant_id=self.tenant_id,
+                provider=provider,
+                model=model_name,
+            )
         except QuotaExceededError as exc:
-            self._set_stop_event(node)
-            self._send_abort_command(reason=str(exc))
+            self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
             logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)

     @override
     def on_node_run_end(
         self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
     ) -> None:
-        if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
+        if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
             return

-        model_instance = self._extract_model_instance(node)
-        if model_instance is None:
+        model_identity = self._extract_model_identity_from_result_event(result_event)
+        if model_identity is None:
+            self._abort_for_missing_model_identity(
+                node=node,
+                reason="LLM quota deduction requires model identity in the node result event.",
+            )
             return

+        provider, model_name = model_identity
+
         try:
-            dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
-            deduct_llm_quota(
-                tenant_id=dify_ctx.tenant_id,
-                model_instance=model_instance,
+            deduct_llm_quota_for_model(
+                tenant_id=self.tenant_id,
+                provider=provider,
+                model=model_name,
                 usage=result_event.node_run_result.llm_usage,
             )
         except QuotaExceededError as exc:
@@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer):
         if stop_event is not None:
             stop_event.set()

+    def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
+        self._set_stop_event(node)
+        node.node_data.error_strategy = None
+        node.node_data.retry_config.retry_enabled = False
+
+        def quota_aborted_run() -> NodeRunResult:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=reason,
+                error_type=error_type,
+            )
+
+        # TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
+        node._run = quota_aborted_run  # type: ignore[method-assign]
+        self._send_abort_command(reason=reason)
+
+    def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
+        self._set_stop_event(node)
+        self._send_abort_command(reason=reason)
+        logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
+
     def _send_abort_command(self, *, reason: str) -> None:
         if not self.command_channel or self._abort_sent:
             return
@@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer):
             logger.exception("Failed to send quota abort command")

     @staticmethod
-    def _extract_model_instance(node: Node) -> ModelInstance | None:
-        try:
-            match node.node_type:
-                case BuiltinNodeTypes.LLM:
-                    model_instance = cast("LLMNode", node).model_instance
-                case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
-                    model_instance = cast("ParameterExtractorNode", node).model_instance
-                case BuiltinNodeTypes.QUESTION_CLASSIFIER:
-                    model_instance = cast("QuestionClassifierNode", node).model_instance
-                case _:
-                    return None
-        except AttributeError:
+    def _supports_quota(node: Node) -> bool:
+        return node.node_type in _QUOTA_NODE_TYPES
+
+    @staticmethod
+    def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
+        provider = result_event.node_run_result.inputs.get("model_provider")
+        model_name = result_event.node_run_result.inputs.get("model_name")
+        if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
+            return provider, model_name
+        return None
+
+    @staticmethod
+    def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
+        node_data = getattr(node, "node_data", None)
+        if node_data is None:
+            node_data = getattr(node, "data", None)
+
+        model_config = getattr(node_data, "model", None)
+        if model_config is None:
             logger.warning(
-                "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
+                "LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
                 node.id,
             )
             return None

-        if isinstance(model_instance, ModelInstance):
-            return model_instance
-
-        raw_model_instance = getattr(model_instance, "_model_instance", None)
-        if isinstance(raw_model_instance, ModelInstance):
-            return raw_model_instance
+        provider = getattr(model_config, "provider", None)
+        model_name = getattr(model_config, "name", None)
+        if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
+            return provider, model_name

+        logger.warning(
+            "LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
+            node.id,
+        )
         return None
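
The layer's contract can be summarized with a stand-in sketch: before a node runs, identity must be readable from public node configuration; after a successful run, it must appear in the result inputs. `SimpleNamespace` replaces Graphon's real node and config types here, which this diff does not show.

    # Stand-in sketch of the identity contract used by the layer above.
    from types import SimpleNamespace


    def identity_from_node(node) -> tuple[str, str] | None:
        # Mirrors _extract_model_identity_from_node: public config only,
        # no ModelInstance reconstruction.
        model = getattr(getattr(node, "node_data", None), "model", None)
        provider = getattr(model, "provider", None)
        name = getattr(model, "name", None)
        if isinstance(provider, str) and provider and isinstance(name, str) and name:
            return provider, name
        return None


    node = SimpleNamespace(node_data=SimpleNamespace(model=SimpleNamespace(provider="openai", name="gpt-4o-mini")))
    assert identity_from_node(node) == ("openai", "gpt-4o-mini")
    assert identity_from_node(SimpleNamespace(node_data=None)) is None  # the layer aborts instead of silently skipping
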
@@ -23,7 +23,7 @@ from core.entities.provider_entities import (
 )
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
-from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
+from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
 from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
 from graphon.model_runtime.entities.provider_entities import (
     ConfigurateMethod,
@@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
 )
 from graphon.model_runtime.model_providers.base.ai_model import AIModel
 from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from graphon.model_runtime.runtime import ModelRuntime
+from graphon.model_runtime.protocols.runtime import ModelRuntime
 from libs.datetime_utils import naive_utc_now
 from models.engine import db
 from models.enums import CredentialSourceType
@@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
         """Attach the already-composed runtime for request-bound call chains."""
         self._bound_model_runtime = model_runtime

+    def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
+        """Resolve a provider factory that stays aligned with the runtime used by the caller."""
+        if self._bound_model_runtime is not None:
+            return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime)
+
+        model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
+        return model_assembly.model_runtime, model_assembly.model_provider_factory
+
     def get_model_provider_factory(self) -> ModelProviderFactory:
         """Return a provider factory that preserves any request-bound runtime."""
-        if self._bound_model_runtime is not None:
-            return ModelProviderFactory(model_runtime=self._bound_model_runtime)
-        return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
+        _, model_provider_factory = self._get_runtime_and_provider_factory()
+        return model_provider_factory

     def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
         """
@@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
         :param model_type: model type
         :return:
         """
-        model_provider_factory = self.get_model_provider_factory()
-
-        # Get model instance of LLM
-        return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
+        model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
+        provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
+        return create_model_type_instance(
+            runtime=model_runtime,
+            provider_schema=provider_schema,
+            model_type=model_type,
+        )

     def get_model_schema(
         self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
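
The point of `_get_runtime_and_provider_factory` is that the provider factory and the model wrapper must be built against the same `ModelRuntime` that will actually serve the call, whether that is a request-bound runtime or a freshly assembled one. A short usage sketch, assuming `config` is an existing `ProviderConfiguration` for the tenant:

    # Sketch only: `config` is assumed to exist. The public method below now
    # resolves runtime and factory together, so the returned wrapper is
    # guaranteed to be backed by the runtime in use for this request.
    from graphon.model_runtime.entities.model_entities import ModelType

    llm_model = config.get_model_type_instance(model_type=ModelType.LLM)
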
@@ -4,7 +4,7 @@ from typing import cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities import DEFAULT_PLUGIN_ID
-from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
+from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
 from extensions.ext_hosting_provider import hosting_configuration
 from graphon.model_runtime.entities.model_entities import ModelType
 from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
     text_chunk = secrets.choice(text_chunks)

     try:
-        model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
-
-        # Get model instance of LLM
-        model_type_instance = model_provider_factory.get_model_type_instance(
+        model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
+        model_type_instance = model_assembly.create_model_type_instance(
             provider=openai_provider_name, model_type=ModelType.MODERATION
         )
         model_type_instance = cast(ModerationModel, model_type_instance)
@@ -4,23 +4,32 @@ import hashlib
 import logging
 from collections.abc import Generator, Iterable, Sequence
 from threading import Lock
-from typing import IO, Any, Union
+from typing import IO, Any, Literal, cast, overload

 from pydantic import ValidationError
 from redis import RedisError

 from configs import dify_config
+from core.llm_generator.output_parser.structured_output import (
+    invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
+)
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.impl.asset import PluginAssetManager
 from core.plugin.impl.model import PluginModelClient
 from extensions.ext_redis import redis_client
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from graphon.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
-from graphon.model_runtime.runtime import ModelRuntime
+from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
+from graphon.model_runtime.protocols.runtime import ModelRuntime
 from models.provider_ids import ModelProviderID

 logger = logging.getLogger(__name__)
@@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
 TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"


+# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
+# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
+class _PluginStructuredOutputModelInstance:
+    """Bind plugin model identity to the shared structured-output helper.
+
+    The structured-output parser is shared with legacy ``ModelInstance`` flows
+    and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
+    intentionally exposes a lower-level API where provider, model, and
+    credentials are passed per call. This adapter supplies the small bound
+    ``invoke_llm`` surface the helper needs without constructing a full
+    ``ModelInstance`` or reintroducing model-manager dependencies into the
+    plugin runtime path.
+    """
+
+    def __init__(
+        self,
+        *,
+        runtime: PluginModelRuntime,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+    ) -> None:
+        self._runtime = runtime
+        self._provider = provider
+        self._model = model
+        self._credentials = credentials
+
+    def invoke_llm(
+        self,
+        *,
+        prompt_messages: Sequence[PromptMessage],
+        model_parameters: dict[str, Any] | None = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Sequence[str] | None = None,
+        stream: bool = True,
+        callbacks: object | None = None,
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
+        del callbacks
+        if stream:
+            return self._runtime.invoke_llm(
+                provider=self._provider,
+                model=self._model,
+                credentials=self._credentials,
+                model_parameters=model_parameters or {},
+                prompt_messages=prompt_messages,
+                tools=list(tools) if tools else None,
+                stop=stop,
+                stream=True,
+            )
+
+        return self._runtime.invoke_llm(
+            provider=self._provider,
+            model=self._model,
+            credentials=self._credentials,
+            model_parameters=model_parameters or {},
+            prompt_messages=prompt_messages,
+            tools=list(tools) if tools else None,
+            stop=stop,
+            stream=False,
+        )
+
+
 class PluginModelRuntime(ModelRuntime):
     """Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
@@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):

         return schema

+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResult: ...
+
+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunk, None, None]: ...
+
     def invoke_llm(
         self,
         *,
@@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
         tools: list[PromptMessageTool] | None,
         stop: Sequence[str] | None,
         stream: bool,
-    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
         plugin_id, provider_name = self._split_provider(provider)
-        return self.client.invoke_llm(
+        result = self.client.invoke_llm(
             tenant_id=self.tenant_id,
             user_id=self.user_id,
             plugin_id=plugin_id,
@@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
             stop=list(stop) if stop else None,
             stream=stream,
         )
+        if stream:
+            return result
+
+        return normalize_non_stream_runtime_result(
+            model=model,
+            prompt_messages=prompt_messages,
+            result=result,
+        )
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResultWithStructuredOutput: ...
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: bool,
+    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+        model_schema = self.get_model_schema(
+            provider=provider,
+            model_type=ModelType.LLM,
+            model=model,
+            credentials=credentials,
+        )
+        if model_schema is None:
+            raise ValueError(f"Model schema not found for {model}")
+
+        adapter = _PluginStructuredOutputModelInstance(
+            runtime=self,
+            provider=provider,
+            model=model,
+            credentials=credentials,
+        )
+        return invoke_llm_with_structured_output_helper(
+            provider=provider,
+            model_schema=model_schema,
+            model_instance=cast(Any, adapter),
+            prompt_messages=prompt_messages,
+            json_schema=json_schema,
+            model_parameters=model_parameters,
+            tools=None,
+            stop=list(stop) if stop else None,
+            stream=stream,
+        )

     def get_llm_num_tokens(
         self,
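
The `Literal` overloads let a type checker narrow `invoke_llm`'s return type from the `stream` flag, and the non-stream path is additionally normalized so callers always receive a single `LLMResult`. A hedged sketch, assuming `runtime` is a constructed `PluginModelRuntime`, `credentials` a valid credential dict, `UserPromptMessage` the usual `PromptMessage` subclass, and illustrative provider/model values:

    from graphon.model_runtime.entities.message_entities import UserPromptMessage

    # stream=False is typed as LLMResult (normalize_non_stream_runtime_result
    # collapses any chunk generator the plugin daemon may return).
    result = runtime.invoke_llm(
        provider="langgenius/openai/openai",  # example plugin-prefixed provider
        model="gpt-4o-mini",                  # example model name
        credentials=credentials,
        model_parameters={"temperature": 0.1},
        prompt_messages=[UserPromptMessage(content="ping")],
        tools=None,
        stop=None,
        stream=False,
    )

    # stream=True is typed as Generator[LLMResultChunk, None, None].
    for chunk in runtime.invoke_llm(
        provider="langgenius/openai/openai",
        model="gpt-4o-mini",
        credentials=credentials,
        model_parameters={},
        prompt_messages=[UserPromptMessage(content="ping")],
        tools=None,
        stop=None,
        stream=True,
    ):
        ...
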
@@ -3,13 +3,46 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 from core.plugin.impl.model import PluginModelClient
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.entities.provider_entities import ProviderEntity
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
 from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from graphon.model_runtime.protocols.runtime import ModelRuntime

 if TYPE_CHECKING:
     from core.model_manager import ModelManager
     from core.plugin.impl.model_runtime import PluginModelRuntime
     from core.provider_manager import ProviderManager

+_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
+    ModelType.LLM: LargeLanguageModel,
+    ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
+    ModelType.RERANK: RerankModel,
+    ModelType.SPEECH2TEXT: Speech2TextModel,
+    ModelType.MODERATION: ModerationModel,
+    ModelType.TTS: TTSModel,
+}
+
+
+def create_model_type_instance(
+    *,
+    runtime: ModelRuntime,
+    provider_schema: ProviderEntity,
+    model_type: ModelType,
+) -> AIModel:
+    """Build the graphon model wrapper explicitly against the request runtime."""
+    model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
+    if model_class is None:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+    return model_class(provider_schema=provider_schema, model_runtime=runtime)
+
+
 class PluginModelAssembly:
     """Compose request-scoped model views on top of a single plugin runtime."""
@@ -38,9 +71,22 @@ class PluginModelAssembly:
     @property
     def model_provider_factory(self) -> ModelProviderFactory:
         if self._model_provider_factory is None:
-            self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
+            self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
         return self._model_provider_factory

+    def create_model_type_instance(
+        self,
+        *,
+        provider: str,
+        model_type: ModelType,
+    ) -> AIModel:
+        provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
+        return create_model_type_instance(
+            runtime=self.model_runtime,
+            provider_schema=provider_schema,
+            model_type=model_type,
+        )
+
     @property
     def provider_manager(self) -> ProviderManager:
         if self._provider_manager is None:
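
A usage sketch mirroring the `check_moderation` change earlier in this commit: one assembly per tenant-scoped request, with the provider schema resolved once and the wrapper bound to the assembly's own runtime. The tenant and provider values are illustrative only.

    from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
    from graphon.model_runtime.entities.model_entities import ModelType

    assembly = create_plugin_model_assembly(tenant_id="tenant-123")  # example tenant id
    moderation_model = assembly.create_model_type_instance(
        provider="langgenius/openai/openai",  # example plugin-prefixed provider
        model_type=ModelType.MODERATION,
    )
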
@@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
 from services.feature_service import FeatureService

 if TYPE_CHECKING:
-    from graphon.model_runtime.runtime import ModelRuntime
+    from graphon.model_runtime.protocols.runtime import ModelRuntime

 _credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
@@ -165,7 +165,7 @@ class ProviderManager:
         )

         # Get all provider entities
-        model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
+        model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
         provider_entities = model_provider_factory.get_providers()

         # Get All preferred provider types of the workspace
@@ -362,7 +362,7 @@ class ProviderManager:
         if not default_model:
             return None

-        model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
+        model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
         provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)

         return DefaultModelEntity(
@@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
         # Re-validate using the resolved node class so workflow-local node schemas
         # stay explicit and constructors receive the concrete typed payload.
         resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
-        config_for_node_init: BaseNodeData | dict[str, Any]
-        if isinstance(resolved_node_data, BaseNodeData):
-            config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
-        else:
-            config_for_node_init = resolved_node_data
         node_type = node_data.type
         node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
             BuiltinNodeTypes.CODE: lambda: {
@@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory):
             },
         }
         node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
+        constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
         return node_class(
             node_id=node_id,
-            config=config_for_node_init,
+            data=constructor_node_data,
             graph_init_params=self.graph_init_params,
             graph_runtime_state=self.graph_runtime_state,
             **node_init_kwargs,
@@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
     def __init__(
         self,
         node_id: str,
-        config: AgentNodeData,
+        data: AgentNodeData,
         *,
         graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
@@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
     ) -> None:
         super().__init__(
             node_id=node_id,
-            config=config,
+            data=data,
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
@@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
     def __init__(
         self,
         node_id: str,
-        config: DatasourceNodeData,
+        data: DatasourceNodeData,
         *,
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
     ) -> None:
         super().__init__(
             node_id=node_id,
-            config=config,
+            data=data,
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
@@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
     def __init__(
         self,
         node_id: str,
-        config: KnowledgeIndexNodeData,
+        data: KnowledgeIndexNodeData,
        *,
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
     ) -> None:
         super().__init__(
             node_id=node_id,
-            config=config,
+            data=data,
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
@@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
     def __init__(
         self,
         node_id: str,
-        config: KnowledgeRetrievalNodeData,
+        data: KnowledgeRetrievalNodeData,
         *,
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
     ) -> None:
         super().__init__(
             node_id=node_id,
-            config=config,
+            data=data,
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
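
The `config=` to `data=` rename is mechanical but touches every `Node` subclass constructor. A hedged sketch of the post-rename pattern for an out-of-tree subclass; `MyNode`, `MyNodeData`, and the quoted parameter types are illustrative placeholders following the hunks above:

    class MyNode(Node[MyNodeData]):
        def __init__(
            self,
            node_id: str,
            data: MyNodeData,  # was `config` before this commit
            *,
            graph_init_params: "GraphInitParams",
            graph_runtime_state: "GraphRuntimeState",
        ) -> None:
            super().__init__(
                node_id=node_id,
                data=data,  # forwarded under the new keyword
                graph_init_params=graph_init_params,
                graph_runtime_state=graph_runtime_state,
            )
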
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import Any, Protocol, cast
+from typing import Any, Protocol
 from uuid import uuid4

 from graphon.enums import BuiltinNodeTypes
@@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
     normalized = _normalize_system_variable_values(values, **kwargs)

     return [
-        cast(
-            Variable,
-            segment_to_variable(
-                segment=build_segment(value),
-                selector=system_variable_selector(key),
-                name=key,
-            ),
+        segment_to_variable(
+            segment=build_segment(value),
+            selector=system_variable_selector(key),
+            name=key,
         )
         for key, value in normalized.items()
     ]
@@ -130,13 +127,10 @@ def build_bootstrap_variables(

     for node_id, value in rag_pipeline_variables_map.items():
         variables.append(
-            cast(
-                Variable,
-                segment_to_variable(
-                    segment=build_segment(value),
-                    selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
-                    name=node_id,
-                ),
+            segment_to_variable(
+                segment=build_segment(value),
+                selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
+                name=node_id,
             )
         )
@@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()


 class _WorkflowChildEngineBuilder:
+    tenant_id: str
+
+    def __init__(self, *, tenant_id: str) -> None:
+        self.tenant_id = tenant_id
+
     @staticmethod
     def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
         """
@@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
             config=config,
             child_engine_builder=self,
         )
-        child_engine.layer(LLMQuotaLayer())
+        child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
         return child_engine
@@ -176,7 +181,7 @@ class WorkflowEntry:
         self.command_channel = command_channel
         execution_context = capture_current_context()
         graph_runtime_state.execution_context = execution_context
-        self._child_engine_builder = _WorkflowChildEngineBuilder()
+        self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
         self.graph_engine = GraphEngine(
             workflow_id=workflow_id,
             graph=graph,
@@ -208,7 +213,7 @@ class WorkflowEntry:
             max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
         )
         self.graph_engine.layer(limits_layer)
-        self.graph_engine.layer(LLMQuotaLayer())
+        self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))

         # Add observability layer when OTel is enabled
         if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():