Vector DB CDK: Refactor to improve readability (#33255)
Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
from airbyte_cdk.destinations.vector_db_based.config import (
|
||||
@@ -15,8 +16,8 @@ from airbyte_cdk.destinations.vector_db_based.config import (
|
||||
OpenAIEmbeddingConfigModel,
|
||||
ProcessingConfigModel,
|
||||
)
|
||||
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
|
||||
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception
|
||||
from airbyte_cdk.models import AirbyteRecordMessage
|
||||
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType
|
||||
from langchain.embeddings.cohere import CohereEmbeddings
|
||||
from langchain.embeddings.fake import FakeEmbeddings
|
||||
@@ -24,6 +25,12 @@ from langchain.embeddings.localai import LocalAIEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
@dataclass
class Document:
    """
    Input unit handed to an Embedder: the text to embed together with the record it came from.
    """

    # Text that the embedders forward to the underlying embedding client.
    page_content: str
    # Originating record; FromFieldEmbedder reads `record.data` and `record.stream` from it.
    record: AirbyteRecordMessage
|
||||
|
||||
|
||||
class Embedder(ABC):
|
||||
"""
|
||||
Embedder is an abstract class that defines the interface for embedding text.
|
||||
@@ -41,7 +48,7 @@ class Embedder(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
"""
|
||||
Embed the text of each document and return the resulting embedding vectors.
|
||||
If a document cannot be embedded or is configured to not be embedded, return None for that document.
|
||||
@@ -72,7 +79,7 @@ class BaseOpenAIEmbedder(Embedder):
|
||||
return format_exception(e)
|
||||
return None
|
||||
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
"""
|
||||
Embed the text of each document and return the resulting embedding vectors.
|
||||
|
||||
@@ -80,9 +87,9 @@ class BaseOpenAIEmbedder(Embedder):
|
||||
It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls,
|
||||
but the built-in retry mechanism of the OpenAI client handles that.
|
||||
"""
|
||||
# Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of chunks that can be embedded at once without exhausting the limit in a single request
|
||||
# Each chunk can hold at most self.chunk_size tokens, so the tokens-per-minute limit divided by the maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request
|
||||
embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size
|
||||
batches = create_chunks(chunks, batch_size=embedding_batch_size)
|
||||
batches = create_chunks(documents, batch_size=embedding_batch_size)
|
||||
embeddings: List[Optional[List[float]]] = []
|
||||
for batch in batches:
|
||||
embeddings.extend(self.embeddings.embed_documents([chunk.page_content for chunk in batch]))
|
||||
@@ -121,8 +128,8 @@ class CohereEmbedder(Embedder):
|
||||
return format_exception(e)
|
||||
return None
|
||||
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
|
||||
|
||||
@property
|
||||
def embedding_dimensions(self) -> int:
|
||||
@@ -142,8 +149,8 @@ class FakeEmbedder(Embedder):
|
||||
return format_exception(e)
|
||||
return None
|
||||
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
|
||||
|
||||
@property
|
||||
def embedding_dimensions(self) -> int:
|
||||
@@ -173,8 +180,8 @@ class OpenAICompatibleEmbedder(Embedder):
|
||||
return format_exception(e)
|
||||
return None
|
||||
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))
|
||||
|
||||
@property
|
||||
def embedding_dimensions(self) -> int:
|
||||
@@ -190,32 +197,32 @@ class FromFieldEmbedder(Embedder):
|
||||
def check(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
|
||||
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
|
||||
"""
|
||||
From each document's record, pull the embedding from the field specified in the config.
|
||||
Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem.
|
||||
"""
|
||||
embeddings: List[Optional[List[float]]] = []
|
||||
for chunk in chunks:
|
||||
data = chunk.record.data
|
||||
for document in documents:
|
||||
data = document.record.data
|
||||
if self.config.field_name not in data:
|
||||
raise AirbyteTracedException(
|
||||
internal_message="Embedding vector field not found",
|
||||
failure_type=FailureType.config_error,
|
||||
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
|
||||
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
|
||||
)
|
||||
field = data[self.config.field_name]
|
||||
if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field):
|
||||
raise AirbyteTracedException(
|
||||
internal_message="Embedding vector field not a list of numbers",
|
||||
failure_type=FailureType.config_error,
|
||||
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
|
||||
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
|
||||
)
|
||||
if len(field) != self.config.dimensions:
|
||||
raise AirbyteTracedException(
|
||||
internal_message="Embedding vector field has wrong length",
|
||||
failure_type=FailureType.config_error,
|
||||
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
|
||||
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
|
||||
)
|
||||
embeddings.append(field)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
|
||||
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor
|
||||
from airbyte_cdk.destinations.vector_db_based.embedder import Embedder
|
||||
from airbyte_cdk.destinations.vector_db_based.embedder import Document, Embedder
|
||||
from airbyte_cdk.destinations.vector_db_based.indexer import Indexer
|
||||
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
|
||||
|
||||
@@ -16,14 +16,14 @@ from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
|
||||
class Writer:
|
||||
"""
|
||||
The Writer class is orchestrating the document processor, the embedder and the indexer:
|
||||
* Incoming records are passed through the document processor to generate documents
|
||||
* Once the configured batch size is reached, the documents are passed to the embedder to generate embeddings
|
||||
* The embedder embeds the documents
|
||||
* The indexer deletes old documents by the associated record id before indexing the new ones
|
||||
* Incoming records are passed through the document processor to generate chunks
|
||||
* Once the configured batch size is reached, the chunks are passed to the embedder to generate embeddings
|
||||
* The embedder embeds the chunks
|
||||
* The indexer deletes old chunks by the associated record id before indexing the new ones
|
||||
|
||||
The destination connector is responsible for creating a writer instance and passing the input messages iterable to the write method.
|
||||
The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination.
|
||||
The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination.
|
||||
The omit_raw_text parameter can be used to omit the raw text from the chunks. This can be useful if the raw text is very large and not needed for the destination.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -37,21 +37,29 @@ class Writer:
|
||||
self._init_batch()
|
||||
|
||||
def _init_batch(self) -> None:
|
||||
self.documents: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
|
||||
self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
|
||||
self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list)
|
||||
self.number_of_documents = 0
|
||||
self.number_of_chunks = 0
|
||||
|
||||
def _convert_to_document(self, chunk: Chunk) -> Document:
    """
    Build the Document passed to the embedder from a processed chunk.

    The chunk's page content and record are copied over unchanged.

    :param chunk: the chunk to convert
    :return: a Document carrying the chunk's page content and record
    :raises ValueError: if the chunk has no page content
    """
    text = chunk.page_content
    # Document.page_content is a plain str, so a chunk without text cannot be converted.
    if text is None:
        raise ValueError("Cannot embed a chunk without page content")
    return Document(page_content=text, record=chunk.record)
|
||||
|
||||
def _process_batch(self) -> None:
|
||||
for (namespace, stream), ids in self.ids_to_delete.items():
|
||||
self.indexer.delete(ids, namespace, stream)
|
||||
|
||||
for (namespace, stream), documents in self.documents.items():
|
||||
embeddings = self.embedder.embed_chunks(documents)
|
||||
for i, document in enumerate(documents):
|
||||
for (namespace, stream), chunks in self.chunks.items():
|
||||
embeddings = self.embedder.embed_documents([self._convert_to_document(chunk) for chunk in chunks])
|
||||
for i, document in enumerate(chunks):
|
||||
document.embedding = embeddings[i]
|
||||
if self.omit_raw_text:
|
||||
document.page_content = None
|
||||
self.indexer.index(documents, namespace, stream)
|
||||
self.indexer.index(chunks, namespace, stream)
|
||||
|
||||
self._init_batch()
|
||||
|
||||
@@ -65,12 +73,12 @@ class Writer:
|
||||
self._process_batch()
|
||||
yield message
|
||||
elif message.type == Type.RECORD:
|
||||
record_documents, record_id_to_delete = self.processor.process(message.record)
|
||||
self.documents[(message.record.namespace, message.record.stream)].extend(record_documents)
|
||||
record_chunks, record_id_to_delete = self.processor.process(message.record)
|
||||
self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks)
|
||||
if record_id_to_delete is not None:
|
||||
self.ids_to_delete[(message.record.namespace, message.record.stream)].append(record_id_to_delete)
|
||||
self.number_of_documents += len(record_documents)
|
||||
if self.number_of_documents >= self.batch_size:
|
||||
self.number_of_chunks += len(record_chunks)
|
||||
if self.number_of_chunks >= self.batch_size:
|
||||
self._process_batch()
|
||||
|
||||
self._process_batch()
|
||||
|
||||
@@ -13,12 +13,12 @@ from airbyte_cdk.destinations.vector_db_based.config import (
|
||||
OpenAICompatibleEmbeddingConfigModel,
|
||||
OpenAIEmbeddingConfigModel,
|
||||
)
|
||||
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
|
||||
from airbyte_cdk.destinations.vector_db_based.embedder import (
|
||||
COHERE_VECTOR_SIZE,
|
||||
OPEN_AI_VECTOR_SIZE,
|
||||
AzureOpenAIEmbedder,
|
||||
CohereEmbedder,
|
||||
Document,
|
||||
FakeEmbedder,
|
||||
FromFieldEmbedder,
|
||||
OpenAICompatibleEmbedder,
|
||||
@@ -82,10 +82,10 @@ def test_embedder(embedder_class, args, dimensions):
|
||||
mock_embedding_instance.embed_documents.return_value = [[0] * dimensions] * 2
|
||||
|
||||
chunks = [
|
||||
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
|
||||
Chunk(page_content="b", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
|
||||
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
|
||||
Document(page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
|
||||
]
|
||||
assert embedder.embed_chunks(chunks) == mock_embedding_instance.embed_documents.return_value
|
||||
assert embedder.embed_documents(chunks) == mock_embedding_instance.embed_documents.return_value
|
||||
mock_embedding_instance.embed_documents.assert_called_with(["a", "b"])
|
||||
|
||||
|
||||
@@ -102,12 +102,12 @@ def test_embedder(embedder_class, args, dimensions):
|
||||
)
|
||||
def test_from_field_embedder(field_name, dimensions, metadata, expected_embedding, expected_error):
|
||||
embedder = FromFieldEmbedder(FromFieldEmbeddingConfigModel(mode="from_field", dimensions=dimensions, field_name=field_name))
|
||||
chunks = [Chunk(page_content="a", metadata=metadata, record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
|
||||
chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
|
||||
if expected_error:
|
||||
with pytest.raises(AirbyteTracedException):
|
||||
embedder.embed_chunks(chunks)
|
||||
embedder.embed_documents(chunks)
|
||||
else:
|
||||
assert embedder.embed_chunks(chunks) == [expected_embedding]
|
||||
assert embedder.embed_documents(chunks) == [expected_embedding]
|
||||
|
||||
|
||||
def test_openai_chunking():
|
||||
@@ -119,7 +119,7 @@ def test_openai_chunking():
|
||||
mock_embedding_instance.embed_documents.side_effect = lambda texts: [[0] * OPEN_AI_VECTOR_SIZE] * len(texts)
|
||||
|
||||
chunks = [
|
||||
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
|
||||
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
|
||||
]
|
||||
assert embedder.embed_chunks(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
|
||||
assert embedder.embed_documents(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
|
||||
mock_embedding_instance.embed_documents.assert_has_calls([call(["a"] * 1000), call(["a"] * 5)])
|
||||
|
||||
@@ -48,8 +48,8 @@ def generate_stream(name: str = "example_stream", namespace: Optional[str] = Non
|
||||
|
||||
def generate_mock_embedder():
|
||||
mock_embedder = MagicMock()
|
||||
mock_embedder.embed_chunks.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
|
||||
mock_embedder.embed_chunks.side_effect = lambda chunks: [[0] * 1536] * len(chunks)
|
||||
mock_embedder.embed_documents.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
|
||||
mock_embedder.embed_documents.side_effect = lambda chunks: [[0] * 1536] * len(chunks)
|
||||
|
||||
return mock_embedder
|
||||
|
||||
@@ -88,7 +88,7 @@ def test_write(omit_raw_text: bool):
|
||||
# 1 batch due to max batch size reached and 1 batch due to state message
|
||||
assert mock_indexer.index.call_count == 2
|
||||
assert mock_indexer.delete.call_count == 2
|
||||
assert mock_embedder.embed_chunks.call_count == 2
|
||||
assert mock_embedder.embed_documents.call_count == 2
|
||||
|
||||
if omit_raw_text:
|
||||
for call_args in mock_indexer.index.call_args_list:
|
||||
@@ -110,7 +110,7 @@ def test_write(omit_raw_text: bool):
|
||||
# 1 batch due to end of message stream
|
||||
assert mock_indexer.index.call_count == 3
|
||||
assert mock_indexer.delete.call_count == 3
|
||||
assert mock_embedder.embed_chunks.call_count == 3
|
||||
assert mock_embedder.embed_documents.call_count == 3
|
||||
|
||||
mock_indexer.post_sync.assert_called()
|
||||
|
||||
@@ -169,4 +169,4 @@ def test_write_stream_namespace_split():
|
||||
call(ANY, None, "example_stream2"),
|
||||
]
|
||||
)
|
||||
assert mock_embedder.embed_chunks.call_count == 4
|
||||
assert mock_embedder.embed_documents.call_count == 4
|
||||
|
||||
Reference in New Issue
Block a user