diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index a4d2de3b1c..567275531e 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Optional +from typing import Optional from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict @@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def _run(self, query: str) -> str: """Use the tool. Add run_manager: Optional[CallbackManagerForToolRun] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5f092dc2f1..be8fa4d226 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -462,7 +462,7 @@ class KnowledgeRetrievalNode(BaseNode): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type == "number": # type: ignore + if expected_value.value_type in {"number", "integer", "float"}: # type: ignore expected_value = expected_value.value # type: ignore elif expected_value.value_type == "string": # type: ignore expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 91e7312805..90a0397b67 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -565,7 +565,7 @@ class LLMNode(BaseNode): retriever_resources=original_retriever_resource, context=context_str.strip() ) - def _convert_to_original_retriever_resource(self, context_dict: dict): + def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: if ( "metadata" in context_dict and "_source" in context_dict["metadata"]