chore(api): upgrade graphon to v0.3.0 (#35469)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
@@ -4,23 +4,32 @@ import hashlib
 import logging
 from collections.abc import Generator, Iterable, Sequence
 from threading import Lock
-from typing import IO, Any, Union
+from typing import IO, Any, Literal, cast, overload
 
 from pydantic import ValidationError
 from redis import RedisError
 
 from configs import dify_config
+from core.llm_generator.output_parser.structured_output import (
+    invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
+)
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.impl.asset import PluginAssetManager
 from core.plugin.impl.model import PluginModelClient
 from extensions.ext_redis import redis_client
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from graphon.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
-from graphon.model_runtime.runtime import ModelRuntime
+from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
+from graphon.model_runtime.protocols.runtime import ModelRuntime
 from models.provider_ids import ModelProviderID
 
 logger = logging.getLogger(__name__)
@@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
 TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
 
 
+# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
+# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
+class _PluginStructuredOutputModelInstance:
+    """Bind plugin model identity to the shared structured-output helper.
+
+    The structured-output parser is shared with legacy ``ModelInstance`` flows
+    and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
+    intentionally exposes a lower-level API where provider, model, and
+    credentials are passed per call. This adapter supplies the small bound
+    ``invoke_llm`` surface the helper needs without constructing a full
+    ``ModelInstance`` or reintroducing model-manager dependencies into the
+    plugin runtime path.
+    """
+
+    def __init__(
+        self,
+        *,
+        runtime: PluginModelRuntime,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+    ) -> None:
+        self._runtime = runtime
+        self._provider = provider
+        self._model = model
+        self._credentials = credentials
+
+    def invoke_llm(
+        self,
+        *,
+        prompt_messages: Sequence[PromptMessage],
+        model_parameters: dict[str, Any] | None = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Sequence[str] | None = None,
+        stream: bool = True,
+        callbacks: object | None = None,
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
+        del callbacks
+        if stream:
+            return self._runtime.invoke_llm(
+                provider=self._provider,
+                model=self._model,
+                credentials=self._credentials,
+                model_parameters=model_parameters or {},
+                prompt_messages=prompt_messages,
+                tools=list(tools) if tools else None,
+                stop=stop,
+                stream=True,
+            )
+
+        return self._runtime.invoke_llm(
+            provider=self._provider,
+            model=self._model,
+            credentials=self._credentials,
+            model_parameters=model_parameters or {},
+            prompt_messages=prompt_messages,
+            tools=list(tools) if tools else None,
+            stop=stop,
+            stream=False,
+        )
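
To make the adapter's contract concrete, here is a minimal, self-contained sketch of the duck-typing it relies on. Everything below is illustrative only and not part of this commit; _FakeRuntime is a hypothetical stand-in for a real PluginModelRuntime.

# Hypothetical stand-in: the helper never needs more than invoke_llm(...).
class _FakeRuntime:
    def invoke_llm(self, **kwargs):
        # A real PluginModelRuntime forwards to the plugin daemon here.
        return kwargs["stream"]

adapter = _PluginStructuredOutputModelInstance(
    runtime=_FakeRuntime(),  # type: ignore[arg-type]  # illustrative stand-in
    provider="langgenius/openai/openai",
    model="gpt-4o",
    credentials={"api_key": "sk-..."},
)
# Provider, model, and credentials are already bound, so the shared
# structured-output helper can call adapter.invoke_llm(...) without a
# full ModelInstance or its model-manager dependencies.
assert adapter.invoke_llm(prompt_messages=[], stream=False) is False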
+
+
 class PluginModelRuntime(ModelRuntime):
     """Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
 
@@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):
 
         return schema
 
+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResult: ...
+
+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunk, None, None]: ...
+
     def invoke_llm(
         self,
         *,
@@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
         tools: list[PromptMessageTool] | None,
         stop: Sequence[str] | None,
         stream: bool,
-    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
         plugin_id, provider_name = self._split_provider(provider)
-        return self.client.invoke_llm(
+        result = self.client.invoke_llm(
             tenant_id=self.tenant_id,
             user_id=self.user_id,
             plugin_id=plugin_id,
@@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
             stop=list(stop) if stop else None,
             stream=stream,
         )
+        if stream:
+            return result
+
+        return normalize_non_stream_runtime_result(
+            model=model,
+            prompt_messages=prompt_messages,
+            result=result,
+        )
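
The Literal overloads above let type checkers pick the return type from the stream argument, so non-streaming callers get LLMResult without casts. A self-contained sketch of the same typing pattern, using hypothetical names rather than anything from this commit:

from collections.abc import Generator
from typing import Literal, overload

@overload
def fetch(*, stream: Literal[False]) -> str: ...
@overload
def fetch(*, stream: Literal[True]) -> Generator[str, None, None]: ...
def fetch(*, stream: bool = False) -> str | Generator[str, None, None]:
    if stream:
        return (c for c in "abc")  # streaming branch yields chunks
    return "abc"  # non-streaming branch returns the full value

text: str = fetch(stream=False)  # checkers narrow to str; no cast needed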
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResultWithStructuredOutput: ...
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: bool,
+    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+        model_schema = self.get_model_schema(
+            provider=provider,
+            model_type=ModelType.LLM,
+            model=model,
+            credentials=credentials,
+        )
+        if model_schema is None:
+            raise ValueError(f"Model schema not found for {model}")
+
+        adapter = _PluginStructuredOutputModelInstance(
+            runtime=self,
+            provider=provider,
+            model=model,
+            credentials=credentials,
+        )
+        return invoke_llm_with_structured_output_helper(
+            provider=provider,
+            model_schema=model_schema,
+            model_instance=cast(Any, adapter),
+            prompt_messages=prompt_messages,
+            json_schema=json_schema,
+            model_parameters=model_parameters,
+            tools=None,
+            stop=list(stop) if stop else None,
+            stream=stream,
+        )
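
For illustration, the call shape of the new method under the non-streaming overload. All values are hypothetical; runtime stands in for an existing PluginModelRuntime, and the structured_output field name is an assumption about LLMResultWithStructuredOutput:

# Hypothetical values throughout.
structured = runtime.invoke_llm_with_structured_output(
    provider="langgenius/openai/openai",
    model="gpt-4o",
    credentials={"api_key": "sk-..."},
    json_schema={
        "type": "object",
        "properties": {"sentiment": {"type": "string"}},
        "required": ["sentiment"],
    },
    model_parameters={"temperature": 0},
    prompt_messages=prompts,  # a Sequence[PromptMessage] built elsewhere
    stop=None,
    stream=False,  # Literal[False] -> LLMResultWithStructuredOutput
)
# structured.structured_output is assumed to carry the parsed object.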
 
     def get_llm_num_tokens(
         self,
@@ -3,13 +3,46 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 from core.plugin.impl.model import PluginModelClient
 from graphon.model_runtime.entities.model_entities import ModelType
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
 from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from graphon.model_runtime.protocols.runtime import ModelRuntime
 
 if TYPE_CHECKING:
     from core.model_manager import ModelManager
+    from core.plugin.impl.model_runtime import PluginModelRuntime
     from core.provider_manager import ProviderManager
 
+_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
+    ModelType.LLM: LargeLanguageModel,
+    ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
+    ModelType.RERANK: RerankModel,
+    ModelType.SPEECH2TEXT: Speech2TextModel,
+    ModelType.MODERATION: ModerationModel,
+    ModelType.TTS: TTSModel,
+}
+
+
+def create_model_type_instance(
+    *,
+    runtime: ModelRuntime,
+    provider_schema: ProviderEntity,
+    model_type: ModelType,
+) -> AIModel:
+    """Build the graphon model wrapper explicitly against the request runtime."""
+    model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
+    if model_class is None:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+    return model_class(provider_schema=provider_schema, model_runtime=runtime)
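
A hedged usage sketch of the explicit factory (runtime and schema are hypothetical, pre-built objects):

llm_model = create_model_type_instance(
    runtime=runtime,          # any ModelRuntime implementation
    provider_schema=schema,   # a ProviderEntity for the target provider
    model_type=ModelType.LLM,
)
# Dict dispatch via _MODEL_CLASS_BY_TYPE keeps the mapping declarative:
# supporting a new model type is one entry, and an unknown type fails
# fast with ValueError instead of silently returning nothing.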
+
 
 class PluginModelAssembly:
     """Compose request-scoped model views on top of a single plugin runtime."""
@@ -38,9 +71,22 @@
     @property
     def model_provider_factory(self) -> ModelProviderFactory:
         if self._model_provider_factory is None:
-            self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
+            self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
         return self._model_provider_factory
 
+    def create_model_type_instance(
+        self,
+        *,
+        provider: str,
+        model_type: ModelType,
+    ) -> AIModel:
+        provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
+        return create_model_type_instance(
+            runtime=self.model_runtime,
+            provider_schema=provider_schema,
+            model_type=model_type,
+        )
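
And a hedged sketch of the request-scoped flow through the assembly (assembly is a hypothetical, already-constructed PluginModelAssembly):

embedding_model = assembly.create_model_type_instance(
    provider="langgenius/openai/openai",
    model_type=ModelType.TEXT_EMBEDDING,
)
# The provider schema is resolved through the lazily cached
# model_provider_factory property, and the returned wrapper is bound to
# the same runtime the assembly holds, keeping the view request-scoped.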
+
     @property
     def provider_manager(self) -> ProviderManager:
         if self._provider_manager is None: