mirror of
https://github.com/langgenius/dify.git
synced 2026-01-10 02:00:33 -05:00
feat: mypy for all type check (#10921)
This commit is contained in:
@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
if not self.root_path:
|
||||
raise ValueError("Root path is not set")
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
# Filter out None values from metadatas list to match expected type
|
||||
filtered_metadatas = [m for m in metadatas if m is not None]
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
group_id or "", # Ensure group_id is never None
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
if result.payload is None:
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
endpoint=dify_config.QDRANT_URL or "",
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
root_path=current_app.config.root_path,
|
||||
root_path=str(current_app.config.root_path),
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
|
||||
Reference in New Issue
Block a user