chore(api): upgrade graphon to v0.3.0 (#35469)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
-LAN-
2026-05-09 15:30:03 +08:00
committed by GitHub
parent f3eb3ab4dd
commit 19476109da
80 changed files with 2526 additions and 673 deletions

View File

@@ -1,5 +1,15 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
from .quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
__all__ = [
"deduct_llm_quota",
"deduct_llm_quota_for_model",
"ensure_llm_quota_available",
"ensure_llm_quota_available_for_model",
]

View File

@@ -1,4 +1,14 @@
from sqlalchemy import update
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
System-hosted quota accounting is currently defined only for LLM models. Keep
the public helpers LLM-specific so callers do not carry unused model-type
plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
with a non-LLM model.
"""
import warnings
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
def _get_provider_configuration(*, tenant_id: str, provider: str):
"""Resolve the tenant-bound provider configuration for quota decisions."""
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
if provider_configuration is None:
raise ValueError(f"Provider {provider} does not exist.")
return provider_configuration
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
"""Raise when a tenant-bound LLM model is already out of quota."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
model_type=ModelType.LLM,
model=model,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
"""Compute the quota impact for an LLM invocation under the current quota mode."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
return None
break
@@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
used_quota = dify_config.get_model_credits(model)
else:
used_quota = 1
return used_quota
def _deduct_free_llm_quota(
*,
tenant_id: str,
provider: str,
quota_type: ProviderQuotaType,
used_quota: int,
) -> None:
"""Deduct FREE provider quota, capping at the limit before reporting exhaustion."""
quota_exceeded = False
with sessionmaker(bind=db.engine).begin() as session:
provider_record = session.scalar(
select(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota_type,
)
.with_for_update()
)
if (
provider_record is None
or provider_record.quota_limit is None
or provider_record.quota_used is None
or provider_record.quota_limit <= provider_record.quota_used
):
quota_exceeded = True
else:
available_quota = provider_record.quota_limit - provider_record.quota_used
deducted_quota = min(used_quota, available_quota)
provider_record.quota_used += deducted_quota
provider_record.last_used = naive_utc_now()
quota_exceeded = deducted_quota < used_quota
if quota_exceeded:
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
if used_quota is not None and system_configuration.current_quota_type is not None:
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
_deduct_free_llm_quota(
tenant_id=tenant_id,
provider=provider,
quota_type=system_configuration.current_quota_type,
used_quota=used_quota,
)
case _:
return
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
"""Deduct tenant-bound quota for the resolved LLM model identity."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
used_quota = _resolve_llm_used_quota(
system_configuration=provider_configuration.system_configuration,
model=model,
usage=usage,
)
_deduct_used_llm_quota(
tenant_id=tenant_id,
provider=provider,
provider_configuration=provider_configuration,
used_quota=used_quota,
)
def _require_llm_model_instance(model_instance: ModelInstance) -> None:
"""Reject deprecated wrapper calls that pass a non-LLM model instance."""
if model_instance.model_type_instance.model_type != ModelType.LLM:
raise ValueError("LLM quota helpers only support LLM model instances.")
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"ensure_llm_quota_available(model_instance=...) is deprecated; "
"use ensure_llm_quota_available_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
_require_llm_model_instance(model_instance)
ensure_llm_quota_available_for_model(
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
)
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
"use deduct_llm_quota_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
_require_llm_model_instance(model_instance)
deduct_llm_quota_for_model(
tenant_id=tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
usage=usage,
)
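
For callers, the migration away from ``ModelInstance`` reduces to two tenant-scoped calls around an invocation; a minimal sketch, assuming only the imports shown in the hunks above (the ``invoke_with_quota`` wrapper itself is hypothetical):

from core.app.llm import (
    deduct_llm_quota_for_model,
    ensure_llm_quota_available_for_model,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage


def invoke_with_quota(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
    # Pre-check: raises QuotaExceededError when the system-hosted model is
    # already marked QUOTA_EXCEEDED for this tenant.
    ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider=provider, model=model)
    # ... the model invocation that produced `usage` would run here ...
    # Settle: charges tokens/credits against the current quota bucket; the
    # FREE bucket caps the deduction at the limit and raises once exhausted.
    deduct_llm_quota_for_model(tenant_id=tenant_id, provider=provider, model=model, usage=usage)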

View File

@@ -1,36 +1,48 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
This layer centralizes model-quota handling outside node implementations.
Graphon LLM-backed nodes expose provider/model identity through public node
configuration and, after execution, through ``node_run_result.inputs``. Resolve
quota billing from that public identity instead of depending on
``ModelInstance`` reconstruction inside the workflow layer. Missing identity on
quota-tracked nodes is treated as a workflow bug and aborts execution so quota
handling is never silently skipped.
"""
import logging
from typing import TYPE_CHECKING, cast, final, override
from typing import final, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
if TYPE_CHECKING:
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
_QUOTA_NODE_TYPES = frozenset(
[
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
]
)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
"""Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""
def __init__(self) -> None:
tenant_id: str
_abort_sent: bool
def __init__(self, tenant_id: str) -> None:
super().__init__()
self.tenant_id = tenant_id
self._abort_sent = False
@override
@@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer):
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
if not self._supports_quota(node):
return
model_identity = self._extract_model_identity_from_node(node)
if model_identity is None:
reason = "LLM quota check requires public node model identity before execution."
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
return
provider, model_name = model_identity
try:
ensure_llm_quota_available(model_instance=model_instance)
ensure_llm_quota_available_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
model_identity = self._extract_model_identity_from_result_event(result_event)
if model_identity is None:
self._abort_for_missing_model_identity(
node=node,
reason="LLM quota deduction requires model identity in the node result event.",
)
return
provider, model_name = model_identity
try:
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,
deduct_llm_quota_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
@@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer):
if stop_event is not None:
stop_event.set()
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
self._set_stop_event(node)
node.node_data.error_strategy = None
node.node_data.retry_config.retry_enabled = False
def quota_aborted_run() -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=reason,
error_type=error_type,
)
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
node._run = quota_aborted_run # type: ignore[method-assign]
self._send_abort_command(reason=reason)
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
self._set_stop_event(node)
self._send_abort_command(reason=reason)
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
@@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer):
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
model_instance = cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
def _supports_quota(node: Node) -> bool:
return node.node_type in _QUOTA_NODE_TYPES
@staticmethod
def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
provider = result_event.node_run_result.inputs.get("model_provider")
model_name = result_event.node_run_result.inputs.get("model_name")
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
return None
@staticmethod
def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
node_data = getattr(node, "node_data", None)
if node_data is None:
node_data = getattr(node, "data", None)
model_config = getattr(node_data, "model", None)
if model_config is None:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
"LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
node.id,
)
return None
if isinstance(model_instance, ModelInstance):
return model_instance
raw_model_instance = getattr(model_instance, "_model_instance", None)
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
provider = getattr(model_config, "provider", None)
model_name = getattr(model_config, "name", None)
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
logger.warning(
"LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
node.id,
)
return None
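
The contract the layer imposes on quota-tracked nodes is deliberately small: two non-empty string keys in the result-event inputs. A standalone sketch of that check, mirroring ``_extract_model_identity_from_result_event`` (the sample payloads are hypothetical):

def extract_model_identity(inputs: dict[str, object]) -> tuple[str, str] | None:
    # Both keys must be present and non-empty strings; anything else is
    # treated as missing identity, which triggers the abort path above.
    provider = inputs.get("model_provider")
    model_name = inputs.get("model_name")
    if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
        return provider, model_name
    return None


assert extract_model_identity({"model_provider": "openai", "model_name": "gpt-4o"}) == ("openai", "gpt-4o")
assert extract_model_identity({"model_name": "gpt-4o"}) is None  # missing provider -> abort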

View File

@@ -23,7 +23,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
if self._bound_model_runtime is not None:
return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime)
model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
return model_assembly.model_runtime, model_assembly.model_provider_factory
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
_, model_provider_factory = self._get_runtime_and_provider_factory()
return model_provider_factory
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
@@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = self.get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
return create_model_type_instance(
runtime=model_runtime,
provider_schema=provider_schema,
model_type=model_type,
)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
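
The point of ``_get_runtime_and_provider_factory`` is that the runtime and factory are always produced together; a minimal standalone sketch of the pattern, with illustrative stand-ins rather than real Graphon types:

from dataclasses import dataclass


@dataclass
class Runtime:
    tenant_id: str


@dataclass
class Factory:
    runtime: Runtime


class Configuration:
    def __init__(self, tenant_id: str) -> None:
        self.tenant_id = tenant_id
        self._bound_runtime: Runtime | None = None

    def bind_runtime(self, runtime: Runtime) -> None:
        # Attach the already-composed runtime for request-bound call chains.
        self._bound_runtime = runtime

    def runtime_and_factory(self) -> tuple[Runtime, Factory]:
        # The factory is always built from the runtime it is returned with,
        # so callers can never pair a bound runtime with a mismatched factory.
        runtime = self._bound_runtime or Runtime(tenant_id=self.tenant_id)
        return runtime, Factory(runtime=runtime)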

View File

@@ -4,7 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
model_type_instance = model_assembly.create_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)

View File

@@ -4,23 +4,32 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from typing import IO, Any, Literal, cast, overload
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
from graphon.model_runtime.protocols.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
class _PluginStructuredOutputModelInstance:
"""Bind plugin model identity to the shared structured-output helper.
The structured-output parser is shared with legacy ``ModelInstance`` flows
and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
intentionally exposes a lower-level API where provider, model, and
credentials are passed per call. This adapter supplies the small bound
``invoke_llm`` surface the helper needs without constructing a full
``ModelInstance`` or reintroducing model-manager dependencies into the
plugin runtime path.
"""
def __init__(
self,
*,
runtime: PluginModelRuntime,
provider: str,
model: str,
credentials: dict[str, Any],
) -> None:
self._runtime = runtime
self._provider = provider
self._model = model
self._credentials = credentials
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
callbacks: object | None = None,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
del callbacks
if stream:
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=True,
)
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=False,
)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
@@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
) -> LLMResult | Generator[LLMResultChunk, None, None]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
result = self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
@@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
stop=list(stop) if stop else None,
stream=stream,
)
if stream:
return result
return normalize_non_stream_runtime_result(
model=model,
prompt_messages=prompt_messages,
result=result,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
model_schema = self.get_model_schema(
provider=provider,
model_type=ModelType.LLM,
model=model,
credentials=credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {model}")
adapter = _PluginStructuredOutputModelInstance(
runtime=self,
provider=provider,
model=model,
credentials=credentials,
)
return invoke_llm_with_structured_output_helper(
provider=provider,
model_schema=model_schema,
model_instance=cast(Any, adapter),
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=None,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,
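
The ``Literal``-typed ``stream`` overloads are what let call sites keep precise return types without casts; a standalone sketch of the same pattern (the names are illustrative, not Graphon APIs):

from collections.abc import Generator
from typing import Literal, overload


@overload
def invoke(*, stream: Literal[False]) -> str: ...
@overload
def invoke(*, stream: Literal[True]) -> Generator[str, None, None]: ...
def invoke(*, stream: bool) -> str | Generator[str, None, None]:
    if stream:
        def chunks() -> Generator[str, None, None]:
            yield "partial "
            yield "result"
        return chunks()
    return "full result"


full: str = invoke(stream=False)  # type checker sees str
parts: Generator[str, None, None] = invoke(stream=True)  # type checker sees a generator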

View File

@@ -3,13 +3,46 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.protocols.runtime import ModelRuntime
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
ModelType.LLM: LargeLanguageModel,
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
ModelType.RERANK: RerankModel,
ModelType.SPEECH2TEXT: Speech2TextModel,
ModelType.MODERATION: ModerationModel,
ModelType.TTS: TTSModel,
}
def create_model_type_instance(
*,
runtime: ModelRuntime,
provider_schema: ProviderEntity,
model_type: ModelType,
) -> AIModel:
"""Build the graphon model wrapper explicitly against the request runtime."""
model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
if model_class is None:
raise ValueError(f"Unsupported model type: {model_type}")
return model_class(provider_schema=provider_schema, model_runtime=runtime)
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
@@ -38,9 +71,22 @@ class PluginModelAssembly:
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
return self._model_provider_factory
def create_model_type_instance(
self,
*,
provider: str,
model_type: ModelType,
) -> AIModel:
provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
return create_model_type_instance(
runtime=self.model_runtime,
provider_schema=provider_schema,
model_type=model_type,
)
@property
def provider_manager(self) -> ProviderManager:
if self._provider_manager is None:
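
Downstream call sites (the moderation check above is one) now build model-type wrappers through the assembly rather than a bare provider factory; a hedged usage sketch (tenant and provider identifiers are hypothetical):

from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import ModelType

# One assembly per tenant-scoped request: its runtime, provider factory, and
# model-type wrappers all share the same underlying plugin runtime.
assembly = create_plugin_model_assembly(tenant_id="tenant-123")
moderation_model = assembly.create_model_type_instance(
    provider="openai",
    model_type=ModelType.MODERATION,
)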

View File

@@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
@@ -165,7 +165,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@@ -362,7 +362,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
config_for_node_init: BaseNodeData | dict[str, Any]
if isinstance(resolved_node_data, BaseNodeData):
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
else:
config_for_node_init = resolved_node_data
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory):
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
return node_class(
node_id=node_id,
config=config_for_node_init,
data=constructor_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
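
The constructor change leans on one Pydantic detail: dumping the validated payload with ``by_alias=True`` restores the wire-format keys that node classes re-validate against. A minimal standalone illustration (the model is hypothetical, not a Dify node schema):

from pydantic import BaseModel, Field


class NodeData(BaseModel):
    node_type: str = Field(alias="type")
    title: str


validated = NodeData.model_validate({"type": "llm", "title": "Generate answer"})
# mode="python" keeps native Python objects; by_alias=True restores the
# original key names, so the dict can be re-validated by the node class.
dumped = validated.model_dump(mode="python", by_alias=True)
assert dumped == {"type": "llm", "title": "Generate answer"}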

View File

@@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
node_id: str,
config: AgentNodeData,
data: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
@@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
node_id: str,
config: DatasourceNodeData,
data: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
node_id: str,
config: KnowledgeIndexNodeData,
data: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
node_id: str,
config: KnowledgeRetrievalNodeData,
data: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import uuid4
from graphon.enums import BuiltinNodeTypes
@@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
normalized = _normalize_system_variable_values(values, **kwargs)
return [
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
),
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
)
for key, value in normalized.items()
]
@@ -130,13 +127,10 @@ def build_bootstrap_variables(
for node_id, value in rag_pipeline_variables_map.items():
variables.append(
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
),
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
)
)

View File

@@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()
class _WorkflowChildEngineBuilder:
tenant_id: str
def __init__(self, *, tenant_id: str) -> None:
self.tenant_id = tenant_id
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
@@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
config=config,
child_engine_builder=self,
)
child_engine.layer(LLMQuotaLayer())
child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
return child_engine
@@ -176,7 +181,7 @@ class WorkflowEntry:
self.command_channel = command_channel
execution_context = capture_current_context()
graph_runtime_state.execution_context = execution_context
self._child_engine_builder = _WorkflowChildEngineBuilder()
self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@@ -208,7 +213,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():