feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
+if not self.runtime:
+raise ValueError("Runtime is required")
+runtime = self.runtime
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO: # type: ignore
yield self.create_text_message("not a valid audio file")
@@ -29,20 +32,19 @@ class ASRTool(BuiltinTool):
audio_binary = io.BytesIO(download(file)) # type: ignore
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#") # type: ignore
-model_manager = ModelManager()
+model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
-tenant_id=self.runtime.tenant_id,
+tenant_id=runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
-text = model_instance.invoke_speech2text(
-file=audio_binary,
-user=user_id,
-)
+text = model_instance.invoke_speech2text(file=audio_binary)
yield self.create_text_message(text)
def get_available_models(self) -> list[tuple[str, str]]:
+if not self.runtime:
+raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"

View File

@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
-provider, model = tool_parameters.get("model").split("#") # type: ignore
-voice = tool_parameters.get(f"voice#{provider}#{model}")
-model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
+runtime = self.runtime
+provider, model = tool_parameters.get("model").split("#") # type: ignore
+voice = tool_parameters.get(f"voice#{provider}#{model}")
+model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
model_instance = model_manager.get_model_instance(
-tenant_id=self.runtime.tenant_id or "",
+tenant_id=runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
@@ -39,12 +40,7 @@ class TTSTool(BuiltinTool):
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
-tts = model_instance.invoke_tts(
-content_text=tool_parameters.get("text"), # type: ignore
-user=user_id,
-tenant_id=self.runtime.tenant_id,
-voice=voice,
-)
+tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type]
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)

View File

@@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
for thread in threads:
thread.join()
# do rerank for searched documents
-model_manager = ModelManager()
+model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,

View File

@@ -37,7 +37,7 @@ class ModelInvocationUtils:
"""
get max llm context tokens of the model
"""
-model_manager = ModelManager()
+model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@@ -65,7 +65,7 @@ class ModelInvocationUtils:
"""
# get model instance
-model_manager = ModelManager()
+model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
@@ -92,7 +92,7 @@ class ModelInvocationUtils:
"""
# get model manager
-model_manager = ModelManager()
+model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
@@ -136,7 +136,6 @@ class ModelInvocationUtils:
tools=[],
stop=[],
stream=False,
-user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e: