feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline:
"""
logger.debug("error: %s", event.error)
e = event.error
err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
@@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline:
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""

View File

@@ -2,6 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)
self._conversation_name_generate_thread = None
self._conversation_name_generate_thread: Optional[Thread] = None
def process(
self,
@@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
self._conversation, self._application_generate_entity.query or ""
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
id=self._message.id,
mode=self._conversation.mode,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
self._task_state.llm_result.message.content
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
yield self._agent_thought_to_stream_response(event)
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
if response:
@@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
self._task_state.llm_result.message.content += delta_text
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(delta_text, self._message.id)
yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
else:
yield self._agent_message_to_stream_response(delta_text, self._message.id)
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if not message:
raise Exception(f"Message {self._message.id} not found")
self._message = message
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found")
self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
@@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
metadata=extras.get("metadata", {}),
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param event: agent thought event
:return:
"""
agent_thought: MessageAgentThought = (
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)

View File

@@ -128,7 +128,7 @@ class MessageCycleManage:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension

View File

@@ -93,7 +93,7 @@ class WorkflowCycleManage:
)
# handle special values
inputs = WorkflowEntry.handle_special_values(inputs)
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -192,7 +192,7 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -500,7 +500,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -545,7 +545,7 @@ class WorkflowCycleManage:
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -553,7 +553,7 @@ class WorkflowCycleManage:
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -655,7 +655,7 @@ class WorkflowCycleManage:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
@@ -838,7 +838,7 @@ class WorkflowCycleManage:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -851,9 +851,11 @@ class WorkflowCycleManage:
# Remove None
files = [file for file in files if file]
# Flatten list
files = [file for sublist in files for file in sublist]
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
return files
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -891,6 +893,8 @@ class WorkflowCycleManage:
elif isinstance(value, File):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run