feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if not self.root_path:
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)
return {"path": path}
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
# Filter out None values from metadatas list to match expected type
filtered_metadatas = [m for m in metadatas if m is not None]
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
),
)
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.config.root_path,
root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,