mirror of
https://github.com/langgenius/dify.git
synced 2026-06-01 22:01:06 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
runtime = self.runtime
|
||||
file = tool_parameters.get("audio_file")
|
||||
if file.type != FileType.AUDIO: # type: ignore
|
||||
yield self.create_text_message("not a valid audio file")
|
||||
@@ -29,20 +32,19 @@ class ASRTool(BuiltinTool):
|
||||
audio_binary = io.BytesIO(download(file)) # type: ignore
|
||||
audio_binary.name = "temp.mp3"
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=runtime.tenant_id,
|
||||
provider=provider,
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
model=model,
|
||||
)
|
||||
text = model_instance.invoke_speech2text(
|
||||
file=audio_binary,
|
||||
user=user_id,
|
||||
)
|
||||
text = model_instance.invoke_speech2text(file=audio_binary)
|
||||
yield self.create_text_message(text)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(
|
||||
tenant_id=self.runtime.tenant_id, model_type="speech2text"
|
||||
|
||||
@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
runtime = self.runtime
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tenant_id=runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
@@ -39,12 +40,7 @@ class TTSTool(BuiltinTool):
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
voice=voice,
|
||||
)
|
||||
tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type]
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts:
|
||||
buffer.write(chunk)
|
||||
|
||||
@@ -65,7 +65,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=self.reranking_provider_name,
|
||||
|
||||
@@ -37,7 +37,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@@ -65,7 +65,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
@@ -92,7 +92,7 @@ class ModelInvocationUtils:
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
@@ -136,7 +136,6 @@ class ModelInvocationUtils:
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
|
||||
Reference in New Issue
Block a user