1
0
mirror of synced 2025-12-25 02:09:19 -05:00

Vector DB CDK: Refactor to improve readability (#33255)

Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
This commit is contained in:
Joe Reuter
2023-12-13 12:23:39 +01:00
committed by GitHub
parent c1e428f35c
commit 55d5345bff
4 changed files with 62 additions and 47 deletions

View File

@@ -4,6 +4,7 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Union, cast
from airbyte_cdk.destinations.vector_db_based.config import (
@@ -15,8 +16,8 @@ from airbyte_cdk.destinations.vector_db_based.config import (
OpenAIEmbeddingConfigModel,
ProcessingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception
from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
@@ -24,6 +25,12 @@ from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
@dataclass
class Document:
page_content: str
record: AirbyteRecordMessage
class Embedder(ABC):
"""
Embedder is an abstract class that defines the interface for embedding text.
@@ -41,7 +48,7 @@ class Embedder(ABC):
pass
@abstractmethod
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
Embed the text of each chunk and return the resulting embedding vectors.
If a chunk cannot be embedded or is configured to not be embedded, return None for that chunk.
@@ -72,7 +79,7 @@ class BaseOpenAIEmbedder(Embedder):
return format_exception(e)
return None
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
Embed the text of each chunk and return the resulting embedding vectors.
@@ -80,9 +87,9 @@ class BaseOpenAIEmbedder(Embedder):
It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls,
but the built-in retry mechanism of the OpenAI client handles that.
"""
# Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of chunks that can be embedded at once without exhausting the limit in a single request
# Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request
embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size
batches = create_chunks(chunks, batch_size=embedding_batch_size)
batches = create_chunks(documents, batch_size=embedding_batch_size)
embeddings: List[Optional[List[float]]] = []
for batch in batches:
embeddings.extend(self.embeddings.embed_documents([chunk.page_content for chunk in batch]))
@@ -121,8 +128,8 @@ class CohereEmbedder(Embedder):
return format_exception(e)
return None
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
@property
def embedding_dimensions(self) -> int:
@@ -142,8 +149,8 @@ class FakeEmbedder(Embedder):
return format_exception(e)
return None
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
@property
def embedding_dimensions(self) -> int:
@@ -173,8 +180,8 @@ class OpenAICompatibleEmbedder(Embedder):
return format_exception(e)
return None
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
@property
def embedding_dimensions(self) -> int:
@@ -190,32 +197,32 @@ class FromFieldEmbedder(Embedder):
def check(self) -> Optional[str]:
return None
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
From each chunk, pull the embedding from the field specified in the config.
Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem.
"""
embeddings: List[Optional[List[float]]] = []
for chunk in chunks:
data = chunk.record.data
for document in documents:
data = document.record.data
if self.config.field_name not in data:
raise AirbyteTracedException(
internal_message="Embedding vector field not found",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
)
field = data[self.config.field_name]
if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field):
raise AirbyteTracedException(
internal_message="Embedding vector field not a list of numbers",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
)
if len(field) != self.config.dimensions:
raise AirbyteTracedException(
internal_message="Embedding vector field has wrong length",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
)
embeddings.append(field)

View File

@@ -8,7 +8,7 @@ from typing import Dict, Iterable, List, Tuple
from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor
from airbyte_cdk.destinations.vector_db_based.embedder import Embedder
from airbyte_cdk.destinations.vector_db_based.embedder import Document, Embedder
from airbyte_cdk.destinations.vector_db_based.indexer import Indexer
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
@@ -16,14 +16,14 @@ from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
class Writer:
"""
The Writer class is orchestrating the document processor, the embedder and the indexer:
* Incoming records are passed through the document processor to generate documents
* One the configured batch size is reached, the documents are passed to the embedder to generate embeddings
* The embedder embeds the documents
* The indexer deletes old documents by the associated record id before indexing the new ones
* Incoming records are passed through the document processor to generate chunks
* One the configured batch size is reached, the chunks are passed to the embedder to generate embeddings
* The embedder embeds the chunks
* The indexer deletes old chunks by the associated record id before indexing the new ones
The destination connector is responsible to create a writer instance and pass the input messages iterable to the write method.
The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination.
The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination.
The omit_raw_text parameter can be used to omit the raw text from the chunks. This can be useful if the raw text is very large and not needed for the destination.
"""
def __init__(
@@ -37,21 +37,29 @@ class Writer:
self._init_batch()
def _init_batch(self) -> None:
self.documents: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list)
self.number_of_documents = 0
self.number_of_chunks = 0
def _convert_to_document(self, chunk: Chunk) -> Document:
"""
Convert a chunk to a document for the embedder.
"""
if chunk.page_content is None:
raise ValueError("Cannot embed a chunk without page content")
return Document(page_content=chunk.page_content, record=chunk.record)
def _process_batch(self) -> None:
for (namespace, stream), ids in self.ids_to_delete.items():
self.indexer.delete(ids, namespace, stream)
for (namespace, stream), documents in self.documents.items():
embeddings = self.embedder.embed_chunks(documents)
for i, document in enumerate(documents):
for (namespace, stream), chunks in self.chunks.items():
embeddings = self.embedder.embed_documents([self._convert_to_document(chunk) for chunk in chunks])
for i, document in enumerate(chunks):
document.embedding = embeddings[i]
if self.omit_raw_text:
document.page_content = None
self.indexer.index(documents, namespace, stream)
self.indexer.index(chunks, namespace, stream)
self._init_batch()
@@ -65,12 +73,12 @@ class Writer:
self._process_batch()
yield message
elif message.type == Type.RECORD:
record_documents, record_id_to_delete = self.processor.process(message.record)
self.documents[(message.record.namespace, message.record.stream)].extend(record_documents)
record_chunks, record_id_to_delete = self.processor.process(message.record)
self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks)
if record_id_to_delete is not None:
self.ids_to_delete[(message.record.namespace, message.record.stream)].append(record_id_to_delete)
self.number_of_documents += len(record_documents)
if self.number_of_documents >= self.batch_size:
self.number_of_chunks += len(record_chunks)
if self.number_of_chunks >= self.batch_size:
self._process_batch()
self._process_batch()

View File

@@ -13,12 +13,12 @@ from airbyte_cdk.destinations.vector_db_based.config import (
OpenAICompatibleEmbeddingConfigModel,
OpenAIEmbeddingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
from airbyte_cdk.destinations.vector_db_based.embedder import (
COHERE_VECTOR_SIZE,
OPEN_AI_VECTOR_SIZE,
AzureOpenAIEmbedder,
CohereEmbedder,
Document,
FakeEmbedder,
FromFieldEmbedder,
OpenAICompatibleEmbedder,
@@ -82,10 +82,10 @@ def test_embedder(embedder_class, args, dimensions):
mock_embedding_instance.embed_documents.return_value = [[0] * dimensions] * 2
chunks = [
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Chunk(page_content="b", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Document(page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
]
assert embedder.embed_chunks(chunks) == mock_embedding_instance.embed_documents.return_value
assert embedder.embed_documents(chunks) == mock_embedding_instance.embed_documents.return_value
mock_embedding_instance.embed_documents.assert_called_with(["a", "b"])
@@ -102,12 +102,12 @@ def test_embedder(embedder_class, args, dimensions):
)
def test_from_field_embedder(field_name, dimensions, metadata, expected_embedding, expected_error):
embedder = FromFieldEmbedder(FromFieldEmbeddingConfigModel(mode="from_field", dimensions=dimensions, field_name=field_name))
chunks = [Chunk(page_content="a", metadata=metadata, record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
if expected_error:
with pytest.raises(AirbyteTracedException):
embedder.embed_chunks(chunks)
embedder.embed_documents(chunks)
else:
assert embedder.embed_chunks(chunks) == [expected_embedding]
assert embedder.embed_documents(chunks) == [expected_embedding]
def test_openai_chunking():
@@ -119,7 +119,7 @@ def test_openai_chunking():
mock_embedding_instance.embed_documents.side_effect = lambda texts: [[0] * OPEN_AI_VECTOR_SIZE] * len(texts)
chunks = [
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
]
assert embedder.embed_chunks(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
assert embedder.embed_documents(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
mock_embedding_instance.embed_documents.assert_has_calls([call(["a"] * 1000), call(["a"] * 5)])

View File

@@ -48,8 +48,8 @@ def generate_stream(name: str = "example_stream", namespace: Optional[str] = Non
def generate_mock_embedder():
mock_embedder = MagicMock()
mock_embedder.embed_chunks.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
mock_embedder.embed_chunks.side_effect = lambda chunks: [[0] * 1536] * len(chunks)
mock_embedder.embed_documents.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
mock_embedder.embed_documents.side_effect = lambda chunks: [[0] * 1536] * len(chunks)
return mock_embedder
@@ -88,7 +88,7 @@ def test_write(omit_raw_text: bool):
# 1 batches due to max batch size reached and 1 batch due to state message
assert mock_indexer.index.call_count == 2
assert mock_indexer.delete.call_count == 2
assert mock_embedder.embed_chunks.call_count == 2
assert mock_embedder.embed_documents.call_count == 2
if omit_raw_text:
for call_args in mock_indexer.index.call_args_list:
@@ -110,7 +110,7 @@ def test_write(omit_raw_text: bool):
# 1 batch due to end of message stream
assert mock_indexer.index.call_count == 3
assert mock_indexer.delete.call_count == 3
assert mock_embedder.embed_chunks.call_count == 3
assert mock_embedder.embed_documents.call_count == 3
mock_indexer.post_sync.assert_called()
@@ -169,4 +169,4 @@ def test_write_stream_namespace_split():
call(ANY, None, "example_stream2"),
]
)
assert mock_embedder.embed_chunks.call_count == 4
assert mock_embedder.embed_documents.call_count == 4