
File-based CDK: make full refresh concurrent (#34411)

This commit is contained in:
Catherine Noll
2024-01-29 19:33:50 -05:00
committed by GitHub
parent d2171e48ea
commit eb31e4d2ba
24 changed files with 1042 additions and 195 deletions
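For orientation: the core of this change is that FileBasedSource now extends ConcurrentSourceAdapter and, when the configured catalog marks a stream as full refresh, wraps it in a FileBasedStreamFacade with a no-op cursor so it is read through the concurrent framework; incremental streams keep the existing synchronous path. A minimal sketch of that decision, using a hypothetical wrap_for_concurrency helper (not part of the commit) and the imports this commit adds:

from typing import Optional

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.streams import Stream


def wrap_for_concurrency(source, stream: AbstractFileBasedStream, sync_mode: Optional[SyncMode]) -> Stream:
    # Full-refresh streams get a facade over a concurrent DefaultStream driven by a no-op cursor;
    # anything else keeps the legacy synchronous code path.
    if sync_mode == SyncMode.full_refresh:
        return FileBasedStreamFacade.create_from_stream(
            stream, source, source.logger, None, FileBasedNoopCursor(stream.config)
        )
    return stream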

View File

@@ -10,7 +10,7 @@ from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
class ConcurrentSourceAdapter(AbstractSource, ABC):
@@ -58,6 +58,6 @@ class ConcurrentSourceAdapter(AbstractSource, ABC):
f"The stream {configured_stream.stream.name} no longer exists in the configuration. "
f"Refresh the schema in replication settings and remove this stream from future sync attempts."
)
if isinstance(stream_instance, StreamFacade):
abstract_streams.append(stream_instance._abstract_stream)
if isinstance(stream_instance, AbstractStreamFacade):
abstract_streams.append(stream_instance.get_underlying_stream())
return abstract_streams

View File

@@ -1,4 +1,7 @@
from .abstract_file_based_availability_strategy import AbstractFileBasedAvailabilityStrategy
from .abstract_file_based_availability_strategy import (
AbstractFileBasedAvailabilityStrategy,
AbstractFileBasedAvailabilityStrategyWrapper,
)
from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy
__all__ = ["AbstractFileBasedAvailabilityStrategy", "DefaultFileBasedAvailabilityStrategy"]
__all__ = ["AbstractFileBasedAvailabilityStrategy", "AbstractFileBasedAvailabilityStrategyWrapper", "DefaultFileBasedAvailabilityStrategy"]

View File

@@ -8,6 +8,12 @@ from typing import TYPE_CHECKING, Optional, Tuple
from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AbstractAvailabilityStrategy,
StreamAvailability,
StreamAvailable,
StreamUnavailable,
)
from airbyte_cdk.sources.streams.core import Stream
if TYPE_CHECKING:
@@ -35,3 +41,17 @@ class AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy):
Returns (True, None) if successful, otherwise (False, <error message>).
"""
...
class AbstractFileBasedAvailabilityStrategyWrapper(AbstractAvailabilityStrategy):
def __init__(self, stream: "AbstractFileBasedStream"):
self.stream = stream
def check_availability(self, logger: logging.Logger) -> StreamAvailability:
is_available, reason = self.stream.availability_strategy.check_availability(self.stream, logger, None)
if is_available:
return StreamAvailable()
return StreamUnavailable(reason or "")
def check_availability_and_parsability(self, logger: logging.Logger) -> Tuple[bool, Optional[str]]:
return self.stream.availability_strategy.check_availability_and_parsability(self.stream, logger, None)
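A usage sketch (illustrative only, with a hypothetical log_availability helper): the wrapper adapts the file-based strategy's (bool, reason) result into the StreamAvailability objects the concurrent DefaultStream expects.

import logging

from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategyWrapper
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream


def log_availability(stream: AbstractFileBasedStream, logger: logging.Logger) -> None:
    # check_availability returns StreamAvailable or StreamUnavailable(reason)
    availability = AbstractFileBasedAvailabilityStrategyWrapper(stream).check_availability(logger)
    if not availability.is_available():
        logger.warning("Stream %s is unavailable: %s", stream.name, availability.message())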

View File

@@ -8,8 +8,18 @@ from abc import ABC
from collections import Counter
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Type, Union
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.logger import AirbyteLogFormatter, init_logger
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConnectorSpecification,
FailureType,
Level,
SyncMode,
)
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy
@@ -20,19 +30,33 @@ from airbyte_cdk.sources.file_based.file_types import default_parsers
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic.error_wrappers import ValidationError
DEFAULT_CONCURRENCY = 100
MAX_CONCURRENCY = 100
INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2
class FileBasedSource(ConcurrentSourceAdapter, ABC):
# We make each source override the concurrency level so we can control when each one is upgraded to concurrent syncs.
_concurrency_level = None
class FileBasedSource(AbstractSource, ABC):
def __init__(
self,
stream_reader: AbstractFileBasedStreamReader,
spec_class: Type[AbstractFileBasedSpec],
catalog_path: Optional[str] = None,
catalog: Optional[ConfiguredAirbyteCatalog],
config: Optional[Mapping[str, Any]],
state: Optional[TState],
availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
@@ -41,15 +65,29 @@ class FileBasedSource(AbstractSource, ABC):
):
self.stream_reader = stream_reader
self.spec_class = spec_class
self.config = config
self.catalog = catalog
self.state = state
self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
self.discovery_policy = discovery_policy
self.parsers = parsers
self.validation_policies = validation_policies
catalog = self.read_catalog(catalog_path) if catalog_path else None
self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
self.cursor_cls = cursor_cls
self.logger = logging.getLogger(f"airbyte.{self.name}")
self.logger = init_logger(f"airbyte.{self.name}")
self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector()
self._message_repository: Optional[MessageRepository] = None
concurrent_source = ConcurrentSource.create(
MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository
)
self._state = None
super().__init__(concurrent_source)
@property
def message_repository(self) -> MessageRepository:
if self._message_repository is None:
self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level]))
return self._message_repository
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
"""
@@ -61,7 +99,15 @@ class FileBasedSource(AbstractSource, ABC):
Otherwise, the "error" object should describe what went wrong.
"""
streams = self.streams(config)
try:
streams = self.streams(config)
except Exception as config_exception:
raise AirbyteTracedException(
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
failure_type=FailureType.config_error,
)
if len(streams) == 0:
return (
False,
@@ -80,7 +126,7 @@ class FileBasedSource(AbstractSource, ABC):
reason,
) = stream.availability_strategy.check_availability_and_parsability(stream, logger, self)
except Exception:
errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}")
errors.append(f"Unable to connect to stream {stream.name} - {''.join(traceback.format_exc())}")
else:
if not stream_is_available and reason:
errors.append(reason)
@@ -91,10 +137,26 @@ class FileBasedSource(AbstractSource, ABC):
"""
Return a list of this source's streams.
"""
file_based_streams = self._get_file_based_streams(config)
configured_streams: List[Stream] = []
for stream in file_based_streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None:
configured_streams.append(
FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor(stream.config))
)
else:
configured_streams.append(stream)
return configured_streams
def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
try:
parsed_config = self._get_parsed_config(config)
self.stream_reader.config = parsed_config
streams: List[Stream] = []
streams: List[AbstractFileBasedStream] = []
for stream_config in parsed_config.streams:
self._validate_input_schema(stream_config)
streams.append(
@@ -115,6 +177,14 @@ class FileBasedSource(AbstractSource, ABC):
except ValidationError as exc:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc
def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
if self.catalog:
for catalog_stream in self.catalog.streams:
if stream.name == catalog_stream.stream.name:
return catalog_stream.sync_mode
raise RuntimeError(f"No sync mode was found for {stream.name}.")
return None
def read(
self,
logger: logging.Logger,

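Because the source now receives the configured catalog (instead of a catalog_path) at construction time, it can look up each stream's sync mode to decide which streams run concurrently. A standalone sketch of that lookup, with a hypothetical sync_mode_for helper mirroring _get_sync_mode_from_catalog:

from typing import Optional

from airbyte_cdk.models import (
    AirbyteStream,
    ConfiguredAirbyteCatalog,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    SyncMode,
)


def sync_mode_for(catalog: Optional[ConfiguredAirbyteCatalog], stream_name: str) -> Optional[SyncMode]:
    if catalog is None:
        return None
    for configured_stream in catalog.streams:
        if configured_stream.stream.name == stream_name:
            return configured_stream.sync_mode
    raise RuntimeError(f"No sync mode was found for {stream_name}.")


catalog = ConfiguredAirbyteCatalog(
    streams=[
        ConfiguredAirbyteStream(
            stream=AirbyteStream(name="stream1", json_schema={}, supported_sync_modes=[SyncMode.full_refresh]),
            sync_mode=SyncMode.full_refresh,
            destination_sync_mode=DestinationSyncMode.overwrite,
        )
    ]
)
assert sync_mode_for(catalog, "stream1") == SyncMode.full_refresh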
View File

@@ -10,6 +10,7 @@ from collections import defaultdict
from functools import partial
from io import IOBase
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple
from uuid import uuid4
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType
@@ -38,8 +39,10 @@ class _CsvReader:
# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
# We don't unregister the dialect because we are lazily parsing each csv file to generate records
# This will potentially be a problem if we ever process multiple streams concurrently
dialect_name = config.name + DIALECT_NAME
# Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up
# with a race condition where a thread attempts to use a dialect before a separate thread has finished
# registering it.
dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}"
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,

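A standalone, stdlib-only illustration of the dialect fix above: registering a per-stream, uuid-suffixed dialect name means two concurrently running syncs can never collide on a shared dialect registration (the "_config_dialect" suffix below is illustrative).

import csv
from uuid import uuid4


def register_stream_dialect(stream_name: str, delimiter: str = ",") -> str:
    # Unique per call, so concurrent streams (or repeated syncs) never reuse or
    # unregister a dialect that another thread still needs.
    dialect_name = f"{stream_name}_{str(uuid4())}_config_dialect"
    csv.register_dialect(dialect_name, delimiter=delimiter)
    return dialect_name


name_a = register_stream_dialect("stream1")
name_b = register_stream_dialect("stream1")
assert name_a != name_b  # same stream config, still distinct dialects
csv.unregister_dialect(name_a)
csv.unregister_dialect(name_b)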
View File

@@ -0,0 +1,322 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
import copy
import logging
from functools import lru_cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.file_based.availability_strategy import (
AbstractFileBasedAvailabilityStrategy,
AbstractFileBasedAvailabilityStrategyWrapper,
)
from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from deprecated.classic import deprecated
"""
This module contains adapters to help enable concurrency on File-based Stream objects without needing to migrate to AbstractStream
"""
@deprecated("This class is experimental. Use at your own risk.")
class FileBasedStreamFacade(AbstractStreamFacade[DefaultStream], AbstractFileBasedStream):
@classmethod
def create_from_stream(
cls,
stream: AbstractFileBasedStream,
source: AbstractSource,
logger: logging.Logger,
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
) -> "FileBasedStreamFacade":
"""
Create a concurrent FileBasedStreamFacade from an AbstractFileBasedStream object.
"""
pk = get_primary_key_from_stream(stream.primary_key)
cursor_field = get_cursor_field_from_stream(stream)
if not source.message_repository:
raise ValueError(
"A message repository is required to emit non-record messages. Please set the message repository on the source."
)
message_repository = source.message_repository
return FileBasedStreamFacade(
DefaultStream( # type: ignore
partition_generator=FileBasedStreamPartitionGenerator(
stream,
message_repository,
SyncMode.full_refresh if isinstance(cursor, FileBasedNoopCursor) else SyncMode.incremental,
[cursor_field] if cursor_field is not None else None,
state,
cursor,
),
name=stream.name,
json_schema=stream.get_json_schema(),
availability_strategy=AbstractFileBasedAvailabilityStrategyWrapper(stream),
primary_key=pk,
cursor_field=cursor_field,
logger=logger,
namespace=stream.namespace,
),
stream,
cursor,
logger=logger,
slice_logger=source._slice_logger,
)
def __init__(
self,
stream: DefaultStream,
legacy_stream: AbstractFileBasedStream,
cursor: FileBasedNoopCursor,
slice_logger: SliceLogger,
logger: logging.Logger,
):
"""
:param stream: The underlying AbstractStream
"""
# super().__init__(stream, legacy_stream, cursor, slice_logger, logger)
self._abstract_stream = stream
self._legacy_stream = legacy_stream
self._cursor = cursor
self._slice_logger = slice_logger
self._logger = logger
self.catalog_schema = legacy_stream.catalog_schema
self.config = legacy_stream.config
self.validation_policy = legacy_stream.validation_policy
@property
def cursor_field(self) -> Union[str, List[str]]:
if self._abstract_stream.cursor_field is None:
return []
else:
return self._abstract_stream.cursor_field
@property
def name(self) -> str:
return self._abstract_stream.name
@property
def supports_incremental(self) -> bool:
return self._legacy_stream.supports_incremental
@property
def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
return self._legacy_stream.availability_strategy
@lru_cache(maxsize=None)
def get_json_schema(self) -> Mapping[str, Any]:
return self._abstract_stream.get_json_schema()
@property
def primary_key(self) -> PrimaryKeyType:
return self._legacy_stream.config.primary_key or self.get_parser().get_parser_defined_primary_key(self._legacy_stream.config)
def get_parser(self) -> FileTypeParser:
return self._legacy_stream.get_parser()
def get_files(self) -> Iterable[RemoteFile]:
return self._legacy_stream.get_files()
def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]:
yield from self._legacy_stream.read_records_from_slice(stream_slice)
def compute_slices(self) -> Iterable[Optional[StreamSlice]]:
return self._legacy_stream.compute_slices()
def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
return self._legacy_stream.infer_schema(files)
def get_underlying_stream(self) -> DefaultStream:
return self._abstract_stream
def read_full_refresh(
self,
cursor_field: Optional[List[str]],
logger: logging.Logger,
slice_logger: SliceLogger,
) -> Iterable[StreamData]:
"""
Read full refresh. Delegate to the underlying AbstractStream, ignoring all the parameters
:param cursor_field: (ignored)
:param logger: (ignored)
:param slice_logger: (ignored)
:return: Iterable of StreamData
"""
yield from self._read_records()
def read_incremental(
self,
cursor_field: Optional[List[str]],
logger: logging.Logger,
slice_logger: SliceLogger,
stream_state: MutableMapping[str, Any],
state_manager: ConnectorStateManager,
per_stream_state_enabled: bool,
internal_config: InternalConfig,
) -> Iterable[StreamData]:
yield from self._read_records()
def read_records(
self,
sync_mode: SyncMode,
cursor_field: Optional[List[str]] = None,
stream_slice: Optional[Mapping[str, Any]] = None,
stream_state: Optional[Mapping[str, Any]] = None,
) -> Iterable[StreamData]:
try:
yield from self._read_records()
except Exception as exc:
if hasattr(self._cursor, "state"):
state = str(self._cursor.state)
else:
# This shouldn't happen if the ConcurrentCursor was used
state = "unknown; no state attribute was available on the cursor"
yield AirbyteMessage(
type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=f"Cursor State at time of exception: {state}")
)
raise exc
def _read_records(self) -> Iterable[StreamData]:
for partition in self._abstract_stream.generate_partitions():
if self._slice_logger.should_log_slice_message(self._logger):
yield self._slice_logger.create_slice_log_message(partition.to_slice())
for record in partition.read():
yield record.data
class FileBasedStreamPartition(Partition):
def __init__(
self,
stream: AbstractFileBasedStream,
_slice: Optional[Mapping[str, Any]],
message_repository: MessageRepository,
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
):
self._stream = stream
self._slice = _slice
self._message_repository = message_repository
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False
def read(self) -> Iterable[Record]:
try:
for record_data in self._stream.read_records(
cursor_field=self._cursor_field,
sync_mode=SyncMode.full_refresh,
stream_slice=copy.deepcopy(self._slice),
stream_state=self._state,
):
if isinstance(record_data, Mapping):
data_to_return = dict(record_data)
self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
yield Record(data_to_return, self.stream_name())
else:
self._message_repository.emit_message(record_data)
except Exception as e:
display_message = self._stream.get_error_display_message(e)
if display_message:
raise ExceptionWithDisplayMessage(display_message) from e
else:
raise e
def to_slice(self) -> Optional[Mapping[str, Any]]:
if self._slice is None:
return None
assert (
len(self._slice["files"]) == 1
), f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}"
file = self._slice["files"][0]
return {"files": [file]}
def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True
def is_closed(self) -> bool:
return self._is_closed
def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
if len(self._slice["files"]) != 1:
raise ValueError(
f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. Please contact Support."
)
else:
s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}"
return hash((self._stream.name, s))
else:
return hash(self._stream.name)
def stream_name(self) -> str:
return self._stream.name
def __repr__(self) -> str:
return f"FileBasedStreamPartition({self._stream.name}, {self._slice})"
class FileBasedStreamPartitionGenerator(PartitionGenerator):
def __init__(
self,
stream: AbstractFileBasedStream,
message_repository: MessageRepository,
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: FileBasedNoopCursor,
):
self._stream = stream
self._message_repository = message_repository
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
def generate(self) -> Iterable[FileBasedStreamPartition]:
pending_partitions = []
for _slice in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state):
if _slice is not None:
pending_partitions.extend(
[
FileBasedStreamPartition(
self._stream,
{"files": [copy.deepcopy(f)]},
self._message_repository,
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)
for f in _slice.get("files", [])
]
)
self._cursor.set_pending_partitions(pending_partitions)
yield from pending_partitions
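A minimal sketch of the partitioning behavior introduced here, mirroring the unit tests further down: the generator splits every slice so that each FileBasedStreamPartition covers exactly one file, which is what allows files to be read in parallel.

from datetime import datetime
from unittest.mock import Mock

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamPartitionGenerator
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor

stream = Mock()
stream.stream_slices.return_value = [
    {"files": [RemoteFile(uri="a.csv", last_modified=datetime.now()), RemoteFile(uri="b.csv", last_modified=datetime.now())]}
]
generator = FileBasedStreamPartitionGenerator(
    stream, Mock(), SyncMode.full_refresh, None, None, Mock(spec=FileBasedNoopCursor)
)
partitions = list(generator.generate())
assert len(partitions) == 2  # one partition per file, even though both files share a slice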

View File

@@ -0,0 +1,87 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from abc import abstractmethod
from datetime import datetime
from typing import Any, Iterable, MutableMapping
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamState
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
class AbstractFileBasedConcurrentCursor(Cursor, AbstractFileBasedCursor):
@property
@abstractmethod
def state(self) -> MutableMapping[str, Any]:
...
@abstractmethod
def add_file(self, file: RemoteFile) -> None:
...
@abstractmethod
def set_initial_state(self, value: StreamState) -> None:
...
@abstractmethod
def get_state(self) -> MutableMapping[str, Any]:
...
@abstractmethod
def get_start_time(self) -> datetime:
...
@abstractmethod
def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]:
...
@abstractmethod
def observe(self, record: Record) -> None:
...
@abstractmethod
def close_partition(self, partition: Partition) -> None:
...
@abstractmethod
def set_pending_partitions(self, partitions: Iterable[Partition]) -> None:
...
class FileBasedNoopCursor(AbstractFileBasedConcurrentCursor):
def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any):
pass
@property
def state(self) -> MutableMapping[str, Any]:
return {}
def add_file(self, file: RemoteFile) -> None:
return None
def set_initial_state(self, value: StreamState) -> None:
return None
def get_state(self) -> MutableMapping[str, Any]:
return {}
def get_start_time(self) -> datetime:
return datetime.min
def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]:
return []
def observe(self, record: Record) -> None:
return None
def close_partition(self, partition: Partition) -> None:
return None
def set_pending_partitions(self, partitions: Iterable[Partition]) -> None:
return None

View File

@@ -164,9 +164,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
try:
schema = self._get_raw_json_schema()
except (InvalidSchemaError, NoFilesMatchingError) as config_exception:
self.logger.exception(FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exc_info=config_exception)
raise AirbyteTracedException(
message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exception=config_exception, failure_type=FailureType.config_error
) from config_exception
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
failure_type=FailureType.config_error,
)
except Exception as exc:
raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc
else:

View File

@@ -0,0 +1,37 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
StreamType = TypeVar("StreamType")
class AbstractStreamFacade(Generic[StreamType], ABC):
@abstractmethod
def get_underlying_stream(self) -> StreamType:
"""
Return the underlying stream that this facade wraps.
"""
...
@property
def source_defined_cursor(self) -> bool:
# Streams must be aware of their cursor at instantiation time
return True
def get_error_display_message(self, exception: BaseException) -> Optional[str]:
"""
Retrieves the user-friendly display message that corresponds to an exception.
This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.
A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage.
:param exception: The exception that was raised
:return: A user-friendly message that indicates the cause of the error
"""
if isinstance(exception, ExceptionWithDisplayMessage):
return exception.display_message
else:
return None
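A minimal sketch (hypothetical ExampleFacade) of the contract this base class defines: subclasses expose their wrapped concurrent stream through get_underlying_stream(), which is what ConcurrentSourceAdapter now calls instead of reaching into StreamFacade._abstract_stream (see the first hunk above).

from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream


class ExampleFacade(AbstractStreamFacade[DefaultStream]):
    def __init__(self, underlying: DefaultStream):
        self._underlying = underlying

    def get_underlying_stream(self) -> DefaultStream:
        return self._underlying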

View File

@@ -14,7 +14,7 @@ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AbstractAvailabilityStrategy,
StreamAvailability,
@@ -24,6 +24,7 @@ from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
@@ -38,7 +39,7 @@ This module contains adapters to help enabling concurrency on Stream objects wit
@deprecated("This class is experimental. Use at your own risk.")
class StreamFacade(Stream):
class StreamFacade(AbstractStreamFacade[DefaultStream], Stream):
"""
The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream.
@@ -62,8 +63,8 @@ class StreamFacade(Stream):
:param max_workers: The maximum number of worker threads to use
:return:
"""
pk = cls._get_primary_key_from_stream(stream.primary_key)
cursor_field = cls._get_cursor_field_from_stream(stream)
pk = get_primary_key_from_stream(stream.primary_key)
cursor_field = get_cursor_field_from_stream(stream)
if not source.message_repository:
raise ValueError(
@@ -104,33 +105,7 @@ class StreamFacade(Stream):
if "state" in dir(self._legacy_stream):
self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above
@classmethod
def _get_primary_key_from_stream(cls, stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]:
if stream_primary_key is None:
return []
elif isinstance(stream_primary_key, str):
return [stream_primary_key]
elif isinstance(stream_primary_key, list):
if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key):
return stream_primary_key # type: ignore # We verified all items in the list are strings
else:
raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
else:
raise ValueError(f"Invalid type for primary key: {stream_primary_key}")
@classmethod
def _get_cursor_field_from_stream(cls, stream: Stream) -> Optional[str]:
if isinstance(stream.cursor_field, list):
if len(stream.cursor_field) > 1:
raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}")
elif len(stream.cursor_field) == 0:
return None
else:
return stream.cursor_field[0]
else:
return stream.cursor_field
def __init__(self, stream: AbstractStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger):
def __init__(self, stream: DefaultStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger):
"""
:param stream: The underlying AbstractStream
"""
@@ -178,7 +153,7 @@ class StreamFacade(Stream):
yield from self._read_records()
except Exception as exc:
if hasattr(self._cursor, "state"):
state = self._cursor.state
state = str(self._cursor.state)
else:
# This shouldn't happen if the ConcurrentCursor was used
state = "unknown; no state attribute was available on the cursor"
@@ -210,11 +185,6 @@ class StreamFacade(Stream):
else:
return self._abstract_stream.cursor_field
@property
def source_defined_cursor(self) -> bool:
# Streams must be aware of their cursor at instantiation time
return True
@lru_cache(maxsize=None)
def get_json_schema(self) -> Mapping[str, Any]:
return self._abstract_stream.get_json_schema()
@@ -233,27 +203,15 @@ class StreamFacade(Stream):
availability = self._abstract_stream.check_availability()
return availability.is_available(), availability.message()
def get_error_display_message(self, exception: BaseException) -> Optional[str]:
"""
Retrieves the user-friendly display message that corresponds to an exception.
This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.
A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage.
:param exception: The exception that was raised
:return: A user-friendly message that indicates the cause of the error
"""
if isinstance(exception, ExceptionWithDisplayMessage):
return exception.display_message
else:
return None
def as_airbyte_stream(self) -> AirbyteStream:
return self._abstract_stream.as_airbyte_stream()
def log_stream_sync_configuration(self) -> None:
self._abstract_stream.log_stream_sync_configuration()
def get_underlying_stream(self) -> DefaultStream:
return self._abstract_stream
class StreamPartition(Partition):
"""

View File

@@ -0,0 +1,31 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from typing import List, Optional, Union
from airbyte_cdk.sources.streams import Stream
def get_primary_key_from_stream(stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]:
if stream_primary_key is None:
return []
elif isinstance(stream_primary_key, str):
return [stream_primary_key]
elif isinstance(stream_primary_key, list):
if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key):
return stream_primary_key # type: ignore # We verified all items in the list are strings
else:
raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
else:
raise ValueError(f"Invalid type for primary key: {stream_primary_key}")
def get_cursor_field_from_stream(stream: Stream) -> Optional[str]:
if isinstance(stream.cursor_field, list):
if len(stream.cursor_field) > 1:
raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}")
elif len(stream.cursor_field) == 0:
return None
else:
return stream.cursor_field[0]
else:
return stream.cursor_field
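A short usage sketch for the extracted helpers (moved out of StreamFacade so the new file-based facade can reuse them): primary keys are normalized to a flat list of strings and cursor fields to a single optional string.

from unittest.mock import Mock

from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream

assert get_primary_key_from_stream(None) == []
assert get_primary_key_from_stream("id") == ["id"]
assert get_primary_key_from_stream(["id", "name"]) == ["id", "name"]

stream = Mock()
stream.cursor_field = ["updated_at"]
assert get_cursor_field_from_stream(stream) == "updated_at"  # single-element list collapses to a string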

View File

@@ -447,11 +447,13 @@ class CsvReaderTest(unittest.TestCase):
.build()
)
dialects_before = set(csv.list_dialects())
data_generator = self._read_data()
next(data_generator)
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
[new_dialect] = set(csv.list_dialects()) - dialects_before
assert self._CONFIG_NAME in new_dialect
data_generator.close()
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
assert new_dialect not in csv.list_dialects()
def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None:
self._stream_reader.open_file.return_value = (
@@ -466,13 +468,15 @@ class CsvReaderTest(unittest.TestCase):
.build()
)
dialects_before = set(csv.list_dialects())
data_generator = self._read_data()
next(data_generator)
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
[new_dialect] = set(csv.list_dialects()) - dialects_before
assert self._CONFIG_NAME in new_dialect
with pytest.raises(RecordParseError):
next(data_generator)
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
assert new_dialect not in csv.list_dialects()
def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None:
self._stream_reader.open_file.return_value = (
@@ -487,13 +491,15 @@ class CsvReaderTest(unittest.TestCase):
.build()
)
dialects_before = set(csv.list_dialects())
data_generator = self._read_data()
next(data_generator)
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
[new_dialect] = set(csv.list_dialects()) - dialects_before
assert self._CONFIG_NAME in new_dialect
with pytest.raises(RecordParseError):
next(data_generator)
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
assert new_dialect not in csv.list_dialects()
def _read_data(self) -> Generator[Dict[str, str], None, None]:
data_generator = self._csv_reader.read_data(

View File

@@ -26,11 +26,14 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeP
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor, DefaultFileBasedCursor
from airbyte_cdk.sources.source import TState
from avro import datafile
from pydantic import AnyUrl
class InMemoryFilesSource(FileBasedSource):
_concurrency_level = 10
def __init__(
self,
files: Mapping[str, Any],
@@ -41,6 +44,8 @@ class InMemoryFilesSource(FileBasedSource):
parsers: Mapping[str, FileTypeParser],
stream_reader: Optional[AbstractFileBasedStreamReader],
catalog: Optional[Mapping[str, Any]],
config: Optional[Mapping[str, Any]],
state: Optional[TState],
file_write_options: Mapping[str, Any],
cursor_cls: Optional[AbstractFileBasedCursor],
):
@@ -48,6 +53,9 @@ class InMemoryFilesSource(FileBasedSource):
self.files = files
self.file_type = file_type
self.catalog = catalog
self.configured_catalog = ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None
self.config = config
self.state = state
# Source setup
stream_reader = stream_reader or InMemoryFilesStreamReader(files=files, file_type=file_type, file_write_options=file_write_options)
@@ -55,7 +63,9 @@ class InMemoryFilesSource(FileBasedSource):
super().__init__(
stream_reader,
spec_class=InMemorySpec,
catalog_path="fake_path" if catalog else None,
catalog=self.configured_catalog,
config=self.config,
state=self.state,
availability_strategy=availability_strategy,
discovery_policy=discovery_policy or DefaultDiscoveryPolicy(),
parsers=parsers,
@@ -64,7 +74,7 @@ class InMemoryFilesSource(FileBasedSource):
)
def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None
return self.configured_catalog
class InMemoryFilesStreamReader(AbstractFileBasedStreamReader):

View File

@@ -847,7 +847,7 @@ invalid_csv_scenario: TestScenario[InMemoryFilesSource] = (
"read": [
{
"level": "ERROR",
"message": f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream=stream1 file=a.csv line_no=1 n_skipped=0",
"message": f"{FileBasedSourceError.INVALID_SCHEMA_ERROR.value} stream=stream1 file=a.csv line_no=1 n_skipped=0",
},
]
}
@@ -1471,28 +1471,7 @@ empty_schema_inference_scenario: TestScenario[InMemoryFilesSource] = (
}
)
.set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value)
.set_expected_records(
[
{
"data": {
"col1": "val11",
"col2": "val12",
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.csv",
},
"stream": "stream1",
},
{
"data": {
"col1": "val21",
"col2": "val22",
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.csv",
},
"stream": "stream1",
},
]
)
.set_expected_records([])
).build()
schemaless_csv_scenario: TestScenario[InMemoryFilesSource] = (
@@ -1766,7 +1745,7 @@ schemaless_with_user_input_schema_fails_connection_check_scenario: TestScenario[
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
@@ -1854,7 +1833,7 @@ schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario:
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
@@ -3029,6 +3008,7 @@ earlier_csv_scenario: TestScenario[InMemoryFilesSource] = (
.set_file_type("csv")
)
.set_expected_check_status("FAILED")
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.EMPTY_STREAM.value)
.set_expected_catalog(
{
"streams": [

View File

@@ -14,6 +14,7 @@ from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFile
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.source import TState
from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource
from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder
@@ -29,8 +30,10 @@ class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]):
self._stream_reader: Optional[AbstractFileBasedStreamReader] = None
self._file_write_options: Mapping[str, Any] = {}
self._cursor_cls: Optional[Type[AbstractFileBasedCursor]] = None
self._config: Optional[Mapping[str, Any]] = None
self._state: Optional[TState] = None
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> InMemoryFilesSource:
def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> InMemoryFilesSource:
if self._file_type is None:
raise ValueError("file_type is not set")
return InMemoryFilesSource(
@@ -42,6 +45,8 @@ class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]):
self._parsers,
self._stream_reader,
configured_catalog,
config,
state,
self._file_write_options,
self._cursor_cls,
)

View File

@@ -485,14 +485,10 @@ invalid_jsonl_scenario = (
}
)
.set_expected_records(
[
{
"data": {"col1": "val1", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.jsonl"},
"stream": "stream1",
},
]
[]
)
.set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value)
.set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.")
.set_expected_logs(
{
"read": [

View File

@@ -729,6 +729,7 @@ parquet_with_invalid_config_scenario = (
.set_expected_records([])
.set_expected_logs({"read": [{"level": "ERROR", "message": "Error parsing record"}]})
.set_expected_discover_error(AirbyteTracedException, "Error inferring schema from files")
.set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.")
.set_expected_catalog(
{
"streams": [

View File

@@ -8,6 +8,7 @@ from typing import Any, Generic, List, Mapping, Optional, Set, Tuple, Type, Type
from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.source import TState
from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -26,7 +27,7 @@ class SourceBuilder(ABC, Generic[SourceType]):
"""
@abstractmethod
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> SourceType:
def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> SourceType:
raise NotImplementedError()
@@ -80,11 +81,11 @@ class TestScenario(Generic[SourceType]):
return self.catalog.dict() # type: ignore # dict() is not typed
catalog: Mapping[str, Any] = {"streams": []}
for stream in self.source.streams(self.config):
for stream in catalog["streams"]:
catalog["streams"].append(
{
"stream": {
"name": stream.name,
"name": stream["name"],
"json_schema": {},
"supported_sync_modes": [sync_mode.value],
},
@@ -152,7 +153,7 @@ class TestScenarioBuilder(Generic[SourceType]):
self._expected_logs = expected_logs
return self
def set_expected_records(self, expected_records: List[Mapping[str, Any]]) -> "TestScenarioBuilder[SourceType]":
def set_expected_records(self, expected_records: Optional[List[Mapping[str, Any]]]) -> "TestScenarioBuilder[SourceType]":
self._expected_records = expected_records
return self
@@ -191,7 +192,9 @@ class TestScenarioBuilder(Generic[SourceType]):
if self.source_builder is None:
raise ValueError("source_builder is not set")
source = self.source_builder.build(
self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh)
self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh),
self._config,
self._incremental_scenario_config.input_state if self._incremental_scenario_config else None,
)
return TestScenario(
self._name,

View File

@@ -661,29 +661,7 @@ wait_for_rediscovery_scenario_single_stream = (
]
}
)
.set_expected_records(
[
{
"data": {
"col1": "val_a_11",
"col2": 1,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.csv",
},
"stream": "stream1",
},
{
"data": {
"col1": "val_a_12",
"col2": 2,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.csv",
},
"stream": "stream1",
},
# No records past that because the first record for the second file did not conform to the schema
]
)
.set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops
.set_expected_logs(
{
"read": [
@@ -722,56 +700,7 @@ wait_for_rediscovery_scenario_multi_stream = (
]
}
)
.set_expected_records(
[
{
"data": {
"col1": "val_aa1_11",
"col2": 1,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a/a1.csv",
},
"stream": "stream1",
},
{
"data": {
"col1": "val_aa1_12",
"col2": 2,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a/a1.csv",
},
"stream": "stream1",
},
# {"data": {"col1": "val_aa2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"},
# {"data": {"col1": "val_aa2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"},
# {"data": {"col1": "val_aa3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
# {"data": {"col1": "val_aa3_12", None: "val_aa3_22", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
# {"data": {"col1": "val_aa3_13", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
# {"data": {"col1": "val_aa4_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a4.csv"}, "stream": "stream1"},
{
"data": {
"col1": "val_bb1_11",
"col2": 1,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "b/b1.csv",
},
"stream": "stream2",
},
{
"data": {
"col1": "val_bb1_12",
"col2": 2,
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "b/b1.csv",
},
"stream": "stream2",
},
# {"data": {"col1": "val_bb2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"},
# {"data": {"col1": "val_bb2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"},
# {"data": {"col1": "val_bb3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"},
# {"data": {"col1": "val_bb3_12", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"},
]
)
.set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops
.set_expected_logs(
{
"read": [

View File

@@ -0,0 +1,365 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
import unittest
from datetime import datetime
from unittest.mock import MagicMock, Mock
import pytest
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.file_based.availability_strategy import DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy
from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector
from airbyte_cdk.sources.file_based.file_types import default_parsers
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import EmitRecordPolicy
from airbyte_cdk.sources.file_based.stream import DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import (
FileBasedStreamFacade,
FileBasedStreamPartition,
FileBasedStreamPartitionGenerator,
)
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from freezegun import freeze_time
_ANY_SYNC_MODE = SyncMode.full_refresh
_ANY_STATE = {"state_key": "state_value"}
_ANY_CURSOR_FIELD = ["a", "cursor", "key"]
_STREAM_NAME = "stream"
_ANY_CURSOR = Mock(spec=FileBasedNoopCursor)
@pytest.mark.parametrize(
"sync_mode",
[
pytest.param(SyncMode.full_refresh, id="test_full_refresh"),
pytest.param(SyncMode.incremental, id="test_incremental"),
],
)
def test_file_based_stream_partition_generator(sync_mode):
stream = Mock()
message_repository = Mock()
stream_slices = [{"files": [RemoteFile(uri="1", last_modified=datetime.now())]},
{"files": [RemoteFile(uri="2", last_modified=datetime.now())]}]
stream.stream_slices.return_value = stream_slices
partition_generator = FileBasedStreamPartitionGenerator(stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
partitions = list(partition_generator.generate())
slices = [partition.to_slice() for partition in partitions]
assert slices == stream_slices
stream.stream_slices.assert_called_once_with(sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE)
@pytest.mark.parametrize(
"transformer, expected_records",
[
pytest.param(
TypeTransformer(TransformConfig.NoTransform),
[Record({"data": "1"}, _STREAM_NAME), Record({"data": "2"}, _STREAM_NAME)],
id="test_no_transform",
),
pytest.param(
TypeTransformer(TransformConfig.DefaultSchemaNormalization),
[Record({"data": 1}, _STREAM_NAME), Record({"data": 2}, _STREAM_NAME)],
id="test_default_transform",
),
],
)
def test_file_based_stream_partition(transformer, expected_records):
stream = Mock()
stream.name = _STREAM_NAME
stream.get_json_schema.return_value = {"type": "object", "properties": {"data": {"type": ["integer"]}}}
stream.transformer = transformer
message_repository = InMemoryMessageRepository()
_slice = None
sync_mode = SyncMode.full_refresh
cursor_field = None
state = None
partition = FileBasedStreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR)
a_log_message = AirbyteMessage(
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.INFO,
message='slice:{"partition": 1}',
),
)
stream_data = [a_log_message, {"data": "1"}, {"data": "2"}]
stream.read_records.return_value = stream_data
records = list(partition.read())
messages = list(message_repository.consume_queue())
assert records == expected_records
assert messages == [a_log_message]
@pytest.mark.parametrize(
"exception_type, expected_display_message",
[
pytest.param(Exception, None, id="test_exception_no_display_message"),
pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_with_display_message"),
],
)
def test_file_based_stream_partition_raising_exception(exception_type, expected_display_message):
stream = Mock()
stream.get_error_display_message.return_value = expected_display_message
message_repository = InMemoryMessageRepository()
_slice = None
partition = FileBasedStreamPartition(stream, _slice, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
stream.read_records.side_effect = Exception()
with pytest.raises(exception_type) as e:
list(partition.read())
if exception_type == ExceptionWithDisplayMessage:
assert e.value.display_message == expected_display_message
@freeze_time("2023-06-09T00:00:00Z")
@pytest.mark.parametrize(
"_slice, expected_hash",
[
pytest.param({"files": [RemoteFile(uri="1", last_modified=datetime.strptime("2023-06-09T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ"))]}, hash(("stream", "2023-06-09T00:00:00.000000Z_1")), id="test_hash_with_slice"),
pytest.param(None, hash("stream"), id="test_hash_no_slice"),
],
)
def test_file_based_stream_partition_hash(_slice, expected_hash):
stream = Mock()
stream.name = "stream"
partition = FileBasedStreamPartition(stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)
_hash = partition.__hash__()
assert _hash == expected_hash
class StreamFacadeTest(unittest.TestCase):
def setUp(self):
self._abstract_stream = Mock()
self._abstract_stream.name = "stream"
self._abstract_stream.as_airbyte_stream.return_value = AirbyteStream(
name="stream",
json_schema={"type": "object"},
supported_sync_modes=[SyncMode.full_refresh],
)
self._legacy_stream = DefaultFileBasedStream(
cursor=FileBasedNoopCursor(MagicMock()),
config=FileBasedStreamConfig(name="stream", format=CsvFormat()),
catalog_schema={},
stream_reader=MagicMock(),
availability_strategy=DefaultFileBasedAvailabilityStrategy(MagicMock()),
discovery_policy=DefaultDiscoveryPolicy(),
parsers=default_parsers,
validation_policy=EmitRecordPolicy(),
errors_collector=FileBasedErrorsCollector(),
)
self._cursor = Mock(spec=Cursor)
self._logger = Mock()
self._slice_logger = Mock()
self._slice_logger.should_log_slice_message.return_value = False
self._facade = FileBasedStreamFacade(self._abstract_stream, self._legacy_stream, self._cursor, self._slice_logger, self._logger)
self._source = Mock()
self._stream = Mock()
self._stream.primary_key = "id"
def test_name_is_delegated_to_wrapped_stream(self):
assert self._facade.name == self._abstract_stream.name
def test_cursor_field_is_a_string(self):
self._abstract_stream.cursor_field = "cursor_field"
assert self._facade.cursor_field == "cursor_field"
def test_source_defined_cursor_is_true(self):
assert self._facade.source_defined_cursor
def test_json_schema_is_delegated_to_wrapped_stream(self):
json_schema = {"type": "object"}
self._abstract_stream.get_json_schema.return_value = json_schema
assert self._facade.get_json_schema() == json_schema
self._abstract_stream.get_json_schema.assert_called_once_with()
def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response(self):
assert (
FileBasedStreamFacade(
self._abstract_stream, self._legacy_stream, _ANY_CURSOR, Mock(spec=SliceLogger), Mock(spec=logging.Logger)
).supports_incremental
== self._legacy_stream.supports_incremental
)
def test_given_cursor_is_not_noop_when_supports_incremental_then_return_true(self):
assert FileBasedStreamFacade(
self._abstract_stream, self._legacy_stream, Mock(spec=Cursor), Mock(spec=SliceLogger), Mock(spec=logging.Logger)
).supports_incremental
def test_full_refresh(self):
expected_stream_data = [{"data": 1}, {"data": 2}]
records = [Record(data, "stream") for data in expected_stream_data]
partition = Mock()
partition.read.return_value = records
self._abstract_stream.generate_partitions.return_value = [partition]
actual_stream_data = list(self._facade.read_records(SyncMode.full_refresh, None, {}, None))
assert actual_stream_data == expected_stream_data
def test_read_records_full_refresh(self):
expected_stream_data = [{"data": 1}, {"data": 2}]
records = [Record(data, "stream") for data in expected_stream_data]
partition = Mock()
partition.read.return_value = records
self._abstract_stream.generate_partitions.return_value = [partition]
actual_stream_data = list(self._facade.read_full_refresh(None, None, None))
assert actual_stream_data == expected_stream_data
def test_read_records_incremental(self):
expected_stream_data = [{"data": 1}, {"data": 2}]
records = [Record(data, "stream") for data in expected_stream_data]
partition = Mock()
partition.read.return_value = records
self._abstract_stream.generate_partitions.return_value = [partition]
actual_stream_data = list(self._facade.read_incremental(None, None, None, None, None, None, None))
assert actual_stream_data == expected_stream_data
def test_create_from_stream_stream(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = "id"
stream.cursor_field = "cursor"
facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
assert facade.name == "stream"
assert facade.cursor_field == "cursor"
assert facade._abstract_stream._primary_key == ["id"]
def test_create_from_stream_stream_with_none_primary_key(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = None
stream.cursor_field = []
facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
assert facade._abstract_stream._primary_key == []
def test_create_from_stream_with_composite_primary_key(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = ["id", "name"]
stream.cursor_field = []
facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
assert facade._abstract_stream._primary_key == ["id", "name"]
def test_create_from_stream_with_empty_list_cursor(self):
stream = Mock()
stream.primary_key = "id"
stream.cursor_field = []
facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
assert facade.cursor_field == []
def test_create_from_stream_raises_exception_if_primary_key_is_nested(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = [["field", "id"]]
with self.assertRaises(ValueError):
FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = 123
with self.assertRaises(ValueError):
FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = "id"
stream.cursor_field = ["field", "cursor"]
with self.assertRaises(ValueError):
FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
def test_create_from_stream_with_cursor_field_as_list(self):
stream = Mock()
stream.name = "stream"
stream.primary_key = "id"
stream.cursor_field = ["cursor"]
facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
assert facade.cursor_field == "cursor"
def test_create_from_stream_none_message_repository(self):
self._stream.name = "stream"
self._stream.primary_key = "id"
self._stream.cursor_field = "cursor"
self._source.message_repository = None
with self.assertRaises(ValueError):
FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, {}, self._cursor)
def test_get_error_display_message_no_display_message(self):
self._stream.get_error_display_message.return_value = "display_message"
facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)
expected_display_message = None
e = Exception()
display_message = facade.get_error_display_message(e)
assert expected_display_message == display_message
def test_get_error_display_message_with_display_message(self):
self._stream.get_error_display_message.return_value = "display_message"
facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)
expected_display_message = "display_message"
e = ExceptionWithDisplayMessage("display_message")
display_message = facade.get_error_display_message(e)
assert expected_display_message == display_message
@pytest.mark.parametrize(
"exception, expected_display_message",
[
pytest.param(Exception("message"), None, id="test_no_display_message"),
pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_with_display_message"),
],
)
def test_get_error_display_message(exception, expected_display_message):
stream = Mock()
legacy_stream = Mock()
cursor = Mock(spec=Cursor)
facade = FileBasedStreamFacade(stream, legacy_stream, cursor, Mock().Mock(), Mock())
display_message = facade.get_error_display_message(exception)
assert display_message == expected_display_message

View File

@@ -73,8 +73,21 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
records, log_messages = output.records_and_state_messages, output.logs
logs = [message.log for message in log_messages if message.log.level.value in scenario.log_levels]
expected_records = scenario.expected_records
if expected_records is None:
return
assert len(records) == len(expected_records)
for actual, expected in zip(records, expected_records):
sorted_expected_records = sorted(
filter(lambda e: "data" in e, expected_records),
key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x["data"].items(), key=lambda x: x[0]) if k != "emitted_at"),
)
sorted_records = sorted(
filter(lambda r: r.record, records),
key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x.record.data.items(), key=lambda x: x[0]) if k != "emitted_at"),
)
for actual, expected in zip(sorted_records, sorted_expected_records):
if actual.record:
assert len(actual.record.data) == len(expected["data"])
for key, value in actual.record.data.items():
@@ -83,8 +96,11 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
else:
assert value == expected["data"][key]
assert actual.record.stream == expected["stream"]
elif actual.state:
assert actual.state.data == expected
expected_states = filter(lambda e: "data" not in e, expected_records)
states = filter(lambda r: r.state, records)
for actual, expected in zip(states, expected_states): # states should be emitted in sorted order
assert actual.state.data == expected
if scenario.expected_logs:
read_logs = scenario.expected_logs.get("read")
@@ -129,7 +145,7 @@ def verify_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: Tes
output = check(capsys, tmp_path, scenario)
if expected_msg:
# expected_msg is a plain string, so compare it directly against the returned status message.
assert expected_msg.value in output["message"] # type: ignore
assert expected_msg in output["message"] # type: ignore
assert output["status"] == scenario.expected_check_status
else:

View File

@@ -11,6 +11,7 @@ from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import Conc
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor
@@ -123,6 +124,6 @@ class StreamFacadeSourceBuilder(SourceBuilder[StreamFacadeSource]):
self._input_state = state
return self
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> StreamFacadeSource:
def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> StreamFacadeSource:
threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, self._input_state)
return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, state)

View File

@@ -119,7 +119,7 @@ class ConcurrentSourceBuilder(SourceBuilder[ConcurrentCdkSource]):
self._streams: List[DefaultStream] = []
self._message_repository = None
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> ConcurrentCdkSource:
def build(self, configured_catalog: Optional[Mapping[str, Any]], _, __) -> ConcurrentCdkSource:
return ConcurrentCdkSource(self._streams, self._message_repository, 1, 1)
def set_streams(self, streams: List[DefaultStream]) -> "ConcurrentSourceBuilder":