mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 19:22:10 -05:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
|
||||
"""
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
err: Exception
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
|
||||
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
|
||||
queue_manager=self._queue_manager,
|
||||
)
|
||||
return None
|
||||
|
||||
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
)
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation, self._application_generate_entity.query
|
||||
self._conversation, self._application_generate_entity.query or ""
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation.mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
id=self._message.id,
|
||||
mode=self._conversation.mode,
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras,
|
||||
),
|
||||
@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
mode=self._conversation.mode,
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras,
|
||||
),
|
||||
@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
|
||||
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
if event.llm_result:
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(
|
||||
self._task_state.llm_result.message.content
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
yield self._agent_thought_to_stream_response(event)
|
||||
agent_thought_response = self._agent_thought_to_stream_response(event)
|
||||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
if response:
|
||||
@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.llm_result.message.content += delta_text
|
||||
current_content = cast(str, self._task_state.llm_result.message.content)
|
||||
current_content += cast(str, delta_text)
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(delta_text, self._message.id)
|
||||
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
llm_result = self._task_state.llm_result
|
||||
usage = llm_result.usage
|
||||
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
if not message:
|
||||
raise Exception(f"Message {self._message.id} not found")
|
||||
self._message = message
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
if not conversation:
|
||||
raise Exception(f"Conversation {self._conversation.id} not found")
|
||||
self._conversation = conversation
|
||||
|
||||
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
self._message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
|
||||
and hasattr(self._application_generate_entity, "conversation_id")
|
||||
and self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras,
|
||||
)
|
||||
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
extras["metadata"] = self._task_state.metadata
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
metadata=extras.get("metadata", {}),
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
:param event: agent thought event
|
||||
:return:
|
||||
"""
|
||||
agent_thought: MessageAgentThought = (
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
|
||||
@@ -128,7 +128,7 @@ class MessageCycleManage:
|
||||
"""
|
||||
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
|
||||
|
||||
if message_file:
|
||||
if message_file and message_file.url is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
|
||||
@@ -93,7 +93,7 @@ class WorkflowCycleManage:
|
||||
)
|
||||
|
||||
# handle special values
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@@ -192,7 +192,7 @@ class WorkflowCycleManage:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
@@ -500,7 +500,7 @@ class WorkflowCycleManage:
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict,
|
||||
inputs=dict(workflow_run.inputs_dict or {}),
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
@@ -545,7 +545,7 @@ class WorkflowCycleManage:
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
status=workflow_run.status,
|
||||
outputs=workflow_run.outputs_dict,
|
||||
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
@@ -553,7 +553,7 @@ class WorkflowCycleManage:
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
|
||||
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
@@ -655,7 +655,7 @@ class WorkflowCycleManage:
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param event: queue node succeeded or failed event
|
||||
@@ -838,7 +838,7 @@ class WorkflowCycleManage:
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
@@ -851,9 +851,11 @@ class WorkflowCycleManage:
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
files = [file for sublist in files for file in sublist]
|
||||
# Flatten the list of sequences into a single list of mappings
|
||||
flattened_files = [file for sublist in files if sublist for file in sublist]
|
||||
|
||||
return files
|
||||
# Convert to tuple to match Sequence type
|
||||
return tuple(flattened_files)
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
@@ -891,6 +893,8 @@ class WorkflowCycleManage:
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
|
||||
Reference in New Issue
Block a user