File-based CDK: make full refresh concurrent (#34411)
@@ -10,7 +10,7 @@ from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade


class ConcurrentSourceAdapter(AbstractSource, ABC):

@@ -58,6 +58,6 @@ class ConcurrentSourceAdapter(AbstractSource, ABC):
                        f"The stream {configured_stream.stream.name} no longer exists in the configuration. "
                        f"Refresh the schema in replication settings and remove this stream from future sync attempts."
                    )
            if isinstance(stream_instance, StreamFacade):
                abstract_streams.append(stream_instance._abstract_stream)
            if isinstance(stream_instance, AbstractStreamFacade):
                abstract_streams.append(stream_instance.get_underlying_stream())
        return abstract_streams
@@ -1,4 +1,7 @@
from .abstract_file_based_availability_strategy import AbstractFileBasedAvailabilityStrategy
from .abstract_file_based_availability_strategy import (
    AbstractFileBasedAvailabilityStrategy,
    AbstractFileBasedAvailabilityStrategyWrapper,
)
from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy

__all__ = ["AbstractFileBasedAvailabilityStrategy", "DefaultFileBasedAvailabilityStrategy"]
__all__ = ["AbstractFileBasedAvailabilityStrategy", "AbstractFileBasedAvailabilityStrategyWrapper", "DefaultFileBasedAvailabilityStrategy"]
@@ -8,6 +8,12 @@ from typing import TYPE_CHECKING, Optional, Tuple

from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
    AbstractAvailabilityStrategy,
    StreamAvailability,
    StreamAvailable,
    StreamUnavailable,
)
from airbyte_cdk.sources.streams.core import Stream

if TYPE_CHECKING:

@@ -35,3 +41,17 @@ class AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy):
        Returns (True, None) if successful, otherwise (False, <error message>).
        """
        ...


class AbstractFileBasedAvailabilityStrategyWrapper(AbstractAvailabilityStrategy):
    def __init__(self, stream: "AbstractFileBasedStream"):
        self.stream = stream

    def check_availability(self, logger: logging.Logger) -> StreamAvailability:
        is_available, reason = self.stream.availability_strategy.check_availability(self.stream, logger, None)
        if is_available:
            return StreamAvailable()
        return StreamUnavailable(reason or "")

    def check_availability_and_parsability(self, logger: logging.Logger) -> Tuple[bool, Optional[str]]:
        return self.stream.availability_strategy.check_availability_and_parsability(self.stream, logger, None)
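A usage sketch (not from this commit; `my_file_based_stream` is a placeholder for any AbstractFileBasedStream instance): the wrapper adapts the file-based strategy's (stream, logger, source) check to the logger-only interface the concurrent framework expects.

import logging

from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategyWrapper

logger = logging.getLogger("airbyte")
# Placeholder stream; any AbstractFileBasedStream built by a file-based source would do.
availability = AbstractFileBasedAvailabilityStrategyWrapper(my_file_based_stream).check_availability(logger)
if not availability.is_available():
    logger.error(availability.message())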
@@ -8,8 +8,18 @@ from abc import ABC
from collections import Counter
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.logger import AirbyteLogFormatter, init_logger
from airbyte_cdk.models import (
    AirbyteMessage,
    AirbyteStateMessage,
    ConfiguredAirbyteCatalog,
    ConnectorSpecification,
    FailureType,
    Level,
    SyncMode,
)
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy

@@ -20,19 +30,33 @@ from airbyte_cdk.sources.file_based.file_types import default_parsers
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic.error_wrappers import ValidationError

DEFAULT_CONCURRENCY = 100
MAX_CONCURRENCY = 100
INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2


class FileBasedSource(ConcurrentSourceAdapter, ABC):
    # We make each source override the concurrency level to give control over when they are upgraded.
    _concurrency_level = None

class FileBasedSource(AbstractSource, ABC):
    def __init__(
        self,
        stream_reader: AbstractFileBasedStreamReader,
        spec_class: Type[AbstractFileBasedSpec],
        catalog_path: Optional[str] = None,
        catalog: Optional[ConfiguredAirbyteCatalog],
        config: Optional[Mapping[str, Any]],
        state: Optional[TState],
        availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
        discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
        parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,

@@ -41,15 +65,29 @@ class FileBasedSource(AbstractSource, ABC):
    ):
        self.stream_reader = stream_reader
        self.spec_class = spec_class
        self.config = config
        self.catalog = catalog
        self.state = state
        self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
        self.discovery_policy = discovery_policy
        self.parsers = parsers
        self.validation_policies = validation_policies
        catalog = self.read_catalog(catalog_path) if catalog_path else None
        self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
        self.cursor_cls = cursor_cls
        self.logger = logging.getLogger(f"airbyte.{self.name}")
        self.logger = init_logger(f"airbyte.{self.name}")
        self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector()
        self._message_repository: Optional[MessageRepository] = None
        concurrent_source = ConcurrentSource.create(
            MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository
        )
        self._state = None
        super().__init__(concurrent_source)

    @property
    def message_repository(self) -> MessageRepository:
        if self._message_repository is None:
            self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level]))
        return self._message_repository

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        """

@@ -61,7 +99,15 @@ class FileBasedSource(AbstractSource, ABC):

        Otherwise, the "error" object should describe what went wrong.
        """
        streams = self.streams(config)
        try:
            streams = self.streams(config)
        except Exception as config_exception:
            raise AirbyteTracedException(
                internal_message="Please check the logged errors for more information.",
                message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value,
                exception=AirbyteTracedException(exception=config_exception),
                failure_type=FailureType.config_error,
            )
        if len(streams) == 0:
            return (
                False,

@@ -80,7 +126,7 @@ class FileBasedSource(AbstractSource, ABC):
                    reason,
                ) = stream.availability_strategy.check_availability_and_parsability(stream, logger, self)
            except Exception:
                errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}")
                errors.append(f"Unable to connect to stream {stream.name} - {''.join(traceback.format_exc())}")
            else:
                if not stream_is_available and reason:
                    errors.append(reason)

@@ -91,10 +137,26 @@ class FileBasedSource(AbstractSource, ABC):
        """
        Return a list of this source's streams.
        """
        file_based_streams = self._get_file_based_streams(config)

        configured_streams: List[Stream] = []

        for stream in file_based_streams:
            sync_mode = self._get_sync_mode_from_catalog(stream)
            if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None:
                configured_streams.append(
                    FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor(stream.config))
                )
            else:
                configured_streams.append(stream)

        return configured_streams

    def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
        try:
            parsed_config = self._get_parsed_config(config)
            self.stream_reader.config = parsed_config
            streams: List[Stream] = []
            streams: List[AbstractFileBasedStream] = []
            for stream_config in parsed_config.streams:
                self._validate_input_schema(stream_config)
                streams.append(

@@ -115,6 +177,14 @@ class FileBasedSource(AbstractSource, ABC):
        except ValidationError as exc:
            raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc

    def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
        if self.catalog:
            for catalog_stream in self.catalog.streams:
                if stream.name == catalog_stream.stream.name:
                    return catalog_stream.sync_mode
            raise RuntimeError(f"No sync mode was found for {stream.name}.")
        return None

    def read(
        self,
        logger: logging.Logger,
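As the comment on FileBasedSource._concurrency_level notes, each connector opts into the concurrent path by overriding the concurrency level. A minimal hypothetical subclass (the class name and the value 10 are illustrative; the in-memory test source later in this commit does the same):

# Hypothetical connector opting into concurrent full refresh.
class MyFileBasedSource(FileBasedSource):
    # Any non-None value makes streams() wrap full-refresh streams in FileBasedStreamFacade.
    _concurrency_level = 10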
@@ -10,6 +10,7 @@ from collections import defaultdict
from functools import partial
from io import IOBase
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple
from uuid import uuid4

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType

@@ -38,8 +39,10 @@ class _CsvReader:

        # Formats are configured individually per-stream so a unique dialect should be registered for each stream.
        # We don't unregister the dialect because we are lazily parsing each csv file to generate records
        # This will potentially be a problem if we ever process multiple streams concurrently
        dialect_name = config.name + DIALECT_NAME
        # Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up
        # with a race condition where a thread attempts to use a dialect before a separate thread has finished
        # registering it.
        dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}"
        csv.register_dialect(
            dialect_name,
            delimiter=config_format.delimiter,
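A standalone sketch of the naming scheme (not from this commit; the stream name, delimiter, and DIALECT_NAME value are assumptions standing in for the parser's own constants): each reader registers a dialect under a UUID-qualified name, uses it, and unregisters it when done, so two streams parsed concurrently never collide on the same dialect entry.

import csv
from uuid import uuid4

DIALECT_NAME = "_config_dialect"  # assumption: stands in for the csv parser's suffix constant

def make_stream_dialect(stream_name: str, delimiter: str) -> str:
    # Unique per reader, so concurrent streams cannot race on registration.
    dialect_name = f"{stream_name}_{uuid4()}_{DIALECT_NAME}"
    csv.register_dialect(dialect_name, delimiter=delimiter)
    return dialect_name

name = make_stream_dialect("stream1", ";")
assert name in csv.list_dialects()
csv.unregister_dialect(name)  # done once the lazy record generator is exhausted or closed
assert name not in csv.list_dialects()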
@@ -0,0 +1,322 @@
#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#

import copy
import logging
from functools import lru_cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.file_based.availability_strategy import (
    AbstractFileBasedAvailabilityStrategy,
    AbstractFileBasedAvailabilityStrategyWrapper,
)
from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from deprecated.classic import deprecated

"""
This module contains adapters to help enabling concurrency on File-based Stream objects without needing to migrate to AbstractStream
"""


@deprecated("This class is experimental. Use at your own risk.")
class FileBasedStreamFacade(AbstractStreamFacade[DefaultStream], AbstractFileBasedStream):
    @classmethod
    def create_from_stream(
        cls,
        stream: AbstractFileBasedStream,
        source: AbstractSource,
        logger: logging.Logger,
        state: Optional[MutableMapping[str, Any]],
        cursor: FileBasedNoopCursor,
    ) -> "FileBasedStreamFacade":
        """
        Create a ConcurrentStream from a FileBasedStream object.
        """
        pk = get_primary_key_from_stream(stream.primary_key)
        cursor_field = get_cursor_field_from_stream(stream)

        if not source.message_repository:
            raise ValueError(
                "A message repository is required to emit non-record messages. Please set the message repository on the source."
            )

        message_repository = source.message_repository
        return FileBasedStreamFacade(
            DefaultStream(  # type: ignore
                partition_generator=FileBasedStreamPartitionGenerator(
                    stream,
                    message_repository,
                    SyncMode.full_refresh if isinstance(cursor, FileBasedNoopCursor) else SyncMode.incremental,
                    [cursor_field] if cursor_field is not None else None,
                    state,
                    cursor,
                ),
                name=stream.name,
                json_schema=stream.get_json_schema(),
                availability_strategy=AbstractFileBasedAvailabilityStrategyWrapper(stream),
                primary_key=pk,
                cursor_field=cursor_field,
                logger=logger,
                namespace=stream.namespace,
            ),
            stream,
            cursor,
            logger=logger,
            slice_logger=source._slice_logger,
        )

    def __init__(
        self,
        stream: DefaultStream,
        legacy_stream: AbstractFileBasedStream,
        cursor: FileBasedNoopCursor,
        slice_logger: SliceLogger,
        logger: logging.Logger,
    ):
        """
        :param stream: The underlying AbstractStream
        """
        # super().__init__(stream, legacy_stream, cursor, slice_logger, logger)
        self._abstract_stream = stream
        self._legacy_stream = legacy_stream
        self._cursor = cursor
        self._slice_logger = slice_logger
        self._logger = logger
        self.catalog_schema = legacy_stream.catalog_schema
        self.config = legacy_stream.config
        self.validation_policy = legacy_stream.validation_policy

    @property
    def cursor_field(self) -> Union[str, List[str]]:
        if self._abstract_stream.cursor_field is None:
            return []
        else:
            return self._abstract_stream.cursor_field

    @property
    def name(self) -> str:
        return self._abstract_stream.name

    @property
    def supports_incremental(self) -> bool:
        return self._legacy_stream.supports_incremental

    @property
    def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
        return self._legacy_stream.availability_strategy

    @lru_cache(maxsize=None)
    def get_json_schema(self) -> Mapping[str, Any]:
        return self._abstract_stream.get_json_schema()

    @property
    def primary_key(self) -> PrimaryKeyType:
        return self._legacy_stream.config.primary_key or self.get_parser().get_parser_defined_primary_key(self._legacy_stream.config)

    def get_parser(self) -> FileTypeParser:
        return self._legacy_stream.get_parser()

    def get_files(self) -> Iterable[RemoteFile]:
        return self._legacy_stream.get_files()

    def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping[str, Any]]:
        yield from self._legacy_stream.read_records_from_slice(stream_slice)

    def compute_slices(self) -> Iterable[Optional[StreamSlice]]:
        return self._legacy_stream.compute_slices()

    def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
        return self._legacy_stream.infer_schema(files)

    def get_underlying_stream(self) -> DefaultStream:
        return self._abstract_stream

    def read_full_refresh(
        self,
        cursor_field: Optional[List[str]],
        logger: logging.Logger,
        slice_logger: SliceLogger,
    ) -> Iterable[StreamData]:
        """
        Read full refresh. Delegate to the underlying AbstractStream, ignoring all the parameters
        :param cursor_field: (ignored)
        :param logger: (ignored)
        :param slice_logger: (ignored)
        :return: Iterable of StreamData
        """
        yield from self._read_records()

    def read_incremental(
        self,
        cursor_field: Optional[List[str]],
        logger: logging.Logger,
        slice_logger: SliceLogger,
        stream_state: MutableMapping[str, Any],
        state_manager: ConnectorStateManager,
        per_stream_state_enabled: bool,
        internal_config: InternalConfig,
    ) -> Iterable[StreamData]:
        yield from self._read_records()

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        try:
            yield from self._read_records()
        except Exception as exc:
            if hasattr(self._cursor, "state"):
                state = str(self._cursor.state)
            else:
                # This shouldn't happen if the ConcurrentCursor was used
                state = "unknown; no state attribute was available on the cursor"
            yield AirbyteMessage(
                type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=f"Cursor State at time of exception: {state}")
            )
            raise exc

    def _read_records(self) -> Iterable[StreamData]:
        for partition in self._abstract_stream.generate_partitions():
            if self._slice_logger.should_log_slice_message(self._logger):
                yield self._slice_logger.create_slice_log_message(partition.to_slice())
            for record in partition.read():
                yield record.data


class FileBasedStreamPartition(Partition):
    def __init__(
        self,
        stream: AbstractFileBasedStream,
        _slice: Optional[Mapping[str, Any]],
        message_repository: MessageRepository,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]],
        state: Optional[MutableMapping[str, Any]],
        cursor: FileBasedNoopCursor,
    ):
        self._stream = stream
        self._slice = _slice
        self._message_repository = message_repository
        self._sync_mode = sync_mode
        self._cursor_field = cursor_field
        self._state = state
        self._cursor = cursor
        self._is_closed = False

    def read(self) -> Iterable[Record]:
        try:
            for record_data in self._stream.read_records(
                cursor_field=self._cursor_field,
                sync_mode=SyncMode.full_refresh,
                stream_slice=copy.deepcopy(self._slice),
                stream_state=self._state,
            ):
                if isinstance(record_data, Mapping):
                    data_to_return = dict(record_data)
                    self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
                    yield Record(data_to_return, self.stream_name())
                else:
                    self._message_repository.emit_message(record_data)
        except Exception as e:
            display_message = self._stream.get_error_display_message(e)
            if display_message:
                raise ExceptionWithDisplayMessage(display_message) from e
            else:
                raise e

    def to_slice(self) -> Optional[Mapping[str, Any]]:
        if self._slice is None:
            return None
        assert (
            len(self._slice["files"]) == 1
        ), f"Expected 1 file per partition but got {len(self._slice['files'])} for stream {self.stream_name()}"
        file = self._slice["files"][0]
        return {"files": [file]}

    def close(self) -> None:
        self._cursor.close_partition(self)
        self._is_closed = True

    def is_closed(self) -> bool:
        return self._is_closed

    def __hash__(self) -> int:
        if self._slice:
            # Convert the slice to a string so that it can be hashed
            if len(self._slice["files"]) != 1:
                raise ValueError(
                    f"Slices for file-based streams should be of length 1, but got {len(self._slice['files'])}. This is unexpected. Please contact Support."
                )
            else:
                s = f"{self._slice['files'][0].last_modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}_{self._slice['files'][0].uri}"
            return hash((self._stream.name, s))
        else:
            return hash(self._stream.name)

    def stream_name(self) -> str:
        return self._stream.name

    def __repr__(self) -> str:
        return f"FileBasedStreamPartition({self._stream.name}, {self._slice})"


class FileBasedStreamPartitionGenerator(PartitionGenerator):
    def __init__(
        self,
        stream: AbstractFileBasedStream,
        message_repository: MessageRepository,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]],
        state: Optional[MutableMapping[str, Any]],
        cursor: FileBasedNoopCursor,
    ):
        self._stream = stream
        self._message_repository = message_repository
        self._sync_mode = sync_mode
        self._cursor_field = cursor_field
        self._state = state
        self._cursor = cursor

    def generate(self) -> Iterable[FileBasedStreamPartition]:
        pending_partitions = []
        for _slice in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state):
            if _slice is not None:
                pending_partitions.extend(
                    [
                        FileBasedStreamPartition(
                            self._stream,
                            {"files": [copy.deepcopy(f)]},
                            self._message_repository,
                            self._sync_mode,
                            self._cursor_field,
                            self._state,
                            self._cursor,
                        )
                        for f in _slice.get("files", [])
                    ]
                )
        self._cursor.set_pending_partitions(pending_partitions)
        yield from pending_partitions
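For orientation, this mirrors how FileBasedSource.streams() (shown earlier in this diff) hands a legacy stream to the facade; `stream` and `self` here are placeholders for the file-based stream and its source.

# Sketch only: `stream` is an AbstractFileBasedStream built by the source, `self` is the FileBasedSource.
facade = FileBasedStreamFacade.create_from_stream(
    stream, self, self.logger, None, FileBasedNoopCursor(stream.config)
)
# The concurrent framework later retrieves the wrapped DefaultStream:
abstract_stream = facade.get_underlying_stream()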
@@ -0,0 +1,87 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from abc import abstractmethod
from datetime import datetime
from typing import Any, Iterable, MutableMapping

from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamState
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record


class AbstractFileBasedConcurrentCursor(Cursor, AbstractFileBasedCursor):
    @property
    @abstractmethod
    def state(self) -> MutableMapping[str, Any]:
        ...

    @abstractmethod
    def add_file(self, file: RemoteFile) -> None:
        ...

    @abstractmethod
    def set_initial_state(self, value: StreamState) -> None:
        ...

    @abstractmethod
    def get_state(self) -> MutableMapping[str, Any]:
        ...

    @abstractmethod
    def get_start_time(self) -> datetime:
        ...

    @abstractmethod
    def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]:
        ...

    @abstractmethod
    def observe(self, record: Record) -> None:
        ...

    @abstractmethod
    def close_partition(self, partition: Partition) -> None:
        ...

    @abstractmethod
    def set_pending_partitions(self, partitions: Iterable[Partition]) -> None:
        ...


class FileBasedNoopCursor(AbstractFileBasedConcurrentCursor):
    def __init__(self, stream_config: FileBasedStreamConfig, **kwargs: Any):
        pass

    @property
    def state(self) -> MutableMapping[str, Any]:
        return {}

    def add_file(self, file: RemoteFile) -> None:
        return None

    def set_initial_state(self, value: StreamState) -> None:
        return None

    def get_state(self) -> MutableMapping[str, Any]:
        return {}

    def get_start_time(self) -> datetime:
        return datetime.min

    def get_files_to_sync(self, all_files: Iterable[RemoteFile], logger: logging.Logger) -> Iterable[RemoteFile]:
        return []

    def observe(self, record: Record) -> None:
        return None

    def close_partition(self, partition: Partition) -> None:
        return None

    def set_pending_partitions(self, partitions: Iterable[Partition]) -> None:
        return None
@@ -164,9 +164,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
        try:
            schema = self._get_raw_json_schema()
        except (InvalidSchemaError, NoFilesMatchingError) as config_exception:
            self.logger.exception(FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exc_info=config_exception)
            raise AirbyteTracedException(
                message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value, exception=config_exception, failure_type=FailureType.config_error
            ) from config_exception
                internal_message="Please check the logged errors for more information.",
                message=FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value,
                exception=AirbyteTracedException(exception=config_exception),
                failure_type=FailureType.config_error,
            )
        except Exception as exc:
            raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc
        else:
@@ -0,0 +1,37 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar

from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage

StreamType = TypeVar("StreamType")


class AbstractStreamFacade(Generic[StreamType], ABC):
    @abstractmethod
    def get_underlying_stream(self) -> StreamType:
        """
        Return the underlying stream facade object.
        """
        ...

    @property
    def source_defined_cursor(self) -> bool:
        # Streams must be aware of their cursor at instantiation time
        return True

    def get_error_display_message(self, exception: BaseException) -> Optional[str]:
        """
        Retrieves the user-friendly display message that corresponds to an exception.
        This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.

        A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage.

        :param exception: The exception that was raised
        :return: A user-friendly message that indicates the cause of the error
        """
        if isinstance(exception, ExceptionWithDisplayMessage):
            return exception.display_message
        else:
            return None
@@ -14,7 +14,7 @@ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
    AbstractAvailabilityStrategy,
    StreamAvailability,

@@ -24,6 +24,7 @@ from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.helpers import get_cursor_field_from_stream, get_primary_key_from_stream
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record

@@ -38,7 +39,7 @@ This module contains adapters to help enabling concurrency on Stream objects wit


@deprecated("This class is experimental. Use at your own risk.")
class StreamFacade(Stream):
class StreamFacade(AbstractStreamFacade[DefaultStream], Stream):
    """
    The StreamFacade is a Stream that wraps an AbstractStream and exposes it as a Stream.

@@ -62,8 +63,8 @@ class StreamFacade(Stream):
        :param max_workers: The maximum number of worker thread to use
        :return:
        """
        pk = cls._get_primary_key_from_stream(stream.primary_key)
        cursor_field = cls._get_cursor_field_from_stream(stream)
        pk = get_primary_key_from_stream(stream.primary_key)
        cursor_field = get_cursor_field_from_stream(stream)

        if not source.message_repository:
            raise ValueError(

@@ -104,33 +105,7 @@ class StreamFacade(Stream):
        if "state" in dir(self._legacy_stream):
            self._legacy_stream.state = value  # type: ignore # validating `state` is attribute of stream using `if` above

    @classmethod
    def _get_primary_key_from_stream(cls, stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]:
        if stream_primary_key is None:
            return []
        elif isinstance(stream_primary_key, str):
            return [stream_primary_key]
        elif isinstance(stream_primary_key, list):
            if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key):
                return stream_primary_key  # type: ignore # We verified all items in the list are strings
            else:
                raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
        else:
            raise ValueError(f"Invalid type for primary key: {stream_primary_key}")

    @classmethod
    def _get_cursor_field_from_stream(cls, stream: Stream) -> Optional[str]:
        if isinstance(stream.cursor_field, list):
            if len(stream.cursor_field) > 1:
                raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}")
            elif len(stream.cursor_field) == 0:
                return None
            else:
                return stream.cursor_field[0]
        else:
            return stream.cursor_field

    def __init__(self, stream: AbstractStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger):
    def __init__(self, stream: DefaultStream, legacy_stream: Stream, cursor: Cursor, slice_logger: SliceLogger, logger: logging.Logger):
        """
        :param stream: The underlying AbstractStream
        """

@@ -178,7 +153,7 @@ class StreamFacade(Stream):
            yield from self._read_records()
        except Exception as exc:
            if hasattr(self._cursor, "state"):
                state = self._cursor.state
                state = str(self._cursor.state)
            else:
                # This shouldn't happen if the ConcurrentCursor was used
                state = "unknown; no state attribute was available on the cursor"

@@ -210,11 +185,6 @@ class StreamFacade(Stream):
        else:
            return self._abstract_stream.cursor_field

    @property
    def source_defined_cursor(self) -> bool:
        # Streams must be aware of their cursor at instantiation time
        return True

    @lru_cache(maxsize=None)
    def get_json_schema(self) -> Mapping[str, Any]:
        return self._abstract_stream.get_json_schema()

@@ -233,27 +203,15 @@ class StreamFacade(Stream):
        availability = self._abstract_stream.check_availability()
        return availability.is_available(), availability.message()

    def get_error_display_message(self, exception: BaseException) -> Optional[str]:
        """
        Retrieves the user-friendly display message that corresponds to an exception.
        This will be called when encountering an exception while reading records from the stream, and used to build the AirbyteTraceMessage.

        A display message will be returned if the exception is an instance of ExceptionWithDisplayMessage.

        :param exception: The exception that was raised
        :return: A user-friendly message that indicates the cause of the error
        """
        if isinstance(exception, ExceptionWithDisplayMessage):
            return exception.display_message
        else:
            return None

    def as_airbyte_stream(self) -> AirbyteStream:
        return self._abstract_stream.as_airbyte_stream()

    def log_stream_sync_configuration(self) -> None:
        self._abstract_stream.log_stream_sync_configuration()

    def get_underlying_stream(self) -> DefaultStream:
        return self._abstract_stream


class StreamPartition(Partition):
    """
@@ -0,0 +1,31 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from typing import List, Optional, Union

from airbyte_cdk.sources.streams import Stream


def get_primary_key_from_stream(stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]:
    if stream_primary_key is None:
        return []
    elif isinstance(stream_primary_key, str):
        return [stream_primary_key]
    elif isinstance(stream_primary_key, list):
        if len(stream_primary_key) > 0 and all(isinstance(k, str) for k in stream_primary_key):
            return stream_primary_key  # type: ignore # We verified all items in the list are strings
        else:
            raise ValueError(f"Nested primary keys are not supported. Found {stream_primary_key}")
    else:
        raise ValueError(f"Invalid type for primary key: {stream_primary_key}")


def get_cursor_field_from_stream(stream: Stream) -> Optional[str]:
    if isinstance(stream.cursor_field, list):
        if len(stream.cursor_field) > 1:
            raise ValueError(f"Nested cursor fields are not supported. Got {stream.cursor_field} for {stream.name}")
        elif len(stream.cursor_field) == 0:
            return None
        else:
            return stream.cursor_field[0]
    else:
        return stream.cursor_field
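These helpers normalize the legacy Stream declarations for the concurrent model; a quick sketch of the expected normalization (the key values below are illustrative, not from this commit):

# Illustrative inputs/outputs for the helpers above.
assert get_primary_key_from_stream(None) == []
assert get_primary_key_from_stream("id") == ["id"]
assert get_primary_key_from_stream(["id", "updated_at"]) == ["id", "updated_at"]
# A nested key such as [["id"], ["updated_at"]] raises ValueError.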
@@ -447,11 +447,13 @@ class CsvReaderTest(unittest.TestCase):
|
||||
.build()
|
||||
)
|
||||
|
||||
dialects_before = set(csv.list_dialects())
|
||||
data_generator = self._read_data()
|
||||
next(data_generator)
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
|
||||
[new_dialect] = set(csv.list_dialects()) - dialects_before
|
||||
assert self._CONFIG_NAME in new_dialect
|
||||
data_generator.close()
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
|
||||
assert new_dialect not in csv.list_dialects()
|
||||
|
||||
def test_given_too_many_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None:
|
||||
self._stream_reader.open_file.return_value = (
|
||||
@@ -466,13 +468,15 @@ class CsvReaderTest(unittest.TestCase):
|
||||
.build()
|
||||
)
|
||||
|
||||
dialects_before = set(csv.list_dialects())
|
||||
data_generator = self._read_data()
|
||||
next(data_generator)
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
|
||||
[new_dialect] = set(csv.list_dialects()) - dialects_before
|
||||
assert self._CONFIG_NAME in new_dialect
|
||||
|
||||
with pytest.raises(RecordParseError):
|
||||
next(data_generator)
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
|
||||
assert new_dialect not in csv.list_dialects()
|
||||
|
||||
def test_given_too_few_values_for_columns_when_read_data_then_raise_exception_and_unregister_dialect(self) -> None:
|
||||
self._stream_reader.open_file.return_value = (
|
||||
@@ -487,13 +491,15 @@ class CsvReaderTest(unittest.TestCase):
|
||||
.build()
|
||||
)
|
||||
|
||||
dialects_before = set(csv.list_dialects())
|
||||
data_generator = self._read_data()
|
||||
next(data_generator)
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" in csv.list_dialects()
|
||||
[new_dialect] = set(csv.list_dialects()) - dialects_before
|
||||
assert self._CONFIG_NAME in new_dialect
|
||||
|
||||
with pytest.raises(RecordParseError):
|
||||
next(data_generator)
|
||||
assert f"{self._CONFIG_NAME}_config_dialect" not in csv.list_dialects()
|
||||
assert new_dialect not in csv.list_dialects()
|
||||
|
||||
def _read_data(self) -> Generator[Dict[str, str], None, None]:
|
||||
data_generator = self._csv_reader.read_data(
|
||||
|
||||
@@ -26,11 +26,14 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeP
|
||||
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
|
||||
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
|
||||
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor, DefaultFileBasedCursor
|
||||
from airbyte_cdk.sources.source import TState
|
||||
from avro import datafile
|
||||
from pydantic import AnyUrl
|
||||
|
||||
|
||||
class InMemoryFilesSource(FileBasedSource):
|
||||
_concurrency_level = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
files: Mapping[str, Any],
|
||||
@@ -41,6 +44,8 @@ class InMemoryFilesSource(FileBasedSource):
|
||||
parsers: Mapping[str, FileTypeParser],
|
||||
stream_reader: Optional[AbstractFileBasedStreamReader],
|
||||
catalog: Optional[Mapping[str, Any]],
|
||||
config: Optional[Mapping[str, Any]],
|
||||
state: Optional[TState],
|
||||
file_write_options: Mapping[str, Any],
|
||||
cursor_cls: Optional[AbstractFileBasedCursor],
|
||||
):
|
||||
@@ -48,6 +53,9 @@ class InMemoryFilesSource(FileBasedSource):
|
||||
self.files = files
|
||||
self.file_type = file_type
|
||||
self.catalog = catalog
|
||||
self.configured_catalog = ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None
|
||||
self.config = config
|
||||
self.state = state
|
||||
|
||||
# Source setup
|
||||
stream_reader = stream_reader or InMemoryFilesStreamReader(files=files, file_type=file_type, file_write_options=file_write_options)
|
||||
@@ -55,7 +63,9 @@ class InMemoryFilesSource(FileBasedSource):
|
||||
super().__init__(
|
||||
stream_reader,
|
||||
spec_class=InMemorySpec,
|
||||
catalog_path="fake_path" if catalog else None,
|
||||
catalog=self.configured_catalog,
|
||||
config=self.config,
|
||||
state=self.state,
|
||||
availability_strategy=availability_strategy,
|
||||
discovery_policy=discovery_policy or DefaultDiscoveryPolicy(),
|
||||
parsers=parsers,
|
||||
@@ -64,7 +74,7 @@ class InMemoryFilesSource(FileBasedSource):
|
||||
)
|
||||
|
||||
def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
|
||||
return ConfiguredAirbyteCatalog(streams=self.catalog["streams"]) if self.catalog else None
|
||||
return self.configured_catalog
|
||||
|
||||
|
||||
class InMemoryFilesStreamReader(AbstractFileBasedStreamReader):
|
||||
|
||||
@@ -847,7 +847,7 @@ invalid_csv_scenario: TestScenario[InMemoryFilesSource] = (
|
||||
"read": [
|
||||
{
|
||||
"level": "ERROR",
|
||||
"message": f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream=stream1 file=a.csv line_no=1 n_skipped=0",
|
||||
"message": f"{FileBasedSourceError.INVALID_SCHEMA_ERROR.value} stream=stream1 file=a.csv line_no=1 n_skipped=0",
|
||||
},
|
||||
]
|
||||
}
|
||||
@@ -1471,28 +1471,7 @@ empty_schema_inference_scenario: TestScenario[InMemoryFilesSource] = (
|
||||
}
|
||||
)
|
||||
.set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value)
|
||||
.set_expected_records(
|
||||
[
|
||||
{
|
||||
"data": {
|
||||
"col1": "val11",
|
||||
"col2": "val12",
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"col1": "val21",
|
||||
"col2": "val22",
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
]
|
||||
)
|
||||
.set_expected_records([])
|
||||
).build()
|
||||
|
||||
schemaless_csv_scenario: TestScenario[InMemoryFilesSource] = (
|
||||
@@ -1766,7 +1745,7 @@ schemaless_with_user_input_schema_fails_connection_check_scenario: TestScenario[
|
||||
}
|
||||
)
|
||||
.set_expected_check_status("FAILED")
|
||||
.set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
).build()
|
||||
@@ -1854,7 +1833,7 @@ schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario:
|
||||
}
|
||||
)
|
||||
.set_expected_check_status("FAILED")
|
||||
.set_expected_check_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
|
||||
).build()
|
||||
@@ -3029,6 +3008,7 @@ earlier_csv_scenario: TestScenario[InMemoryFilesSource] = (
|
||||
.set_file_type("csv")
|
||||
)
|
||||
.set_expected_check_status("FAILED")
|
||||
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.EMPTY_STREAM.value)
|
||||
.set_expected_catalog(
|
||||
{
|
||||
"streams": [
|
||||
|
||||
@@ -14,6 +14,7 @@ from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFile
|
||||
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
|
||||
from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
|
||||
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
|
||||
from airbyte_cdk.sources.source import TState
|
||||
from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource
|
||||
from unit_tests.sources.file_based.scenarios.scenario_builder import SourceBuilder
|
||||
|
||||
@@ -29,8 +30,10 @@ class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]):
|
||||
self._stream_reader: Optional[AbstractFileBasedStreamReader] = None
|
||||
self._file_write_options: Mapping[str, Any] = {}
|
||||
self._cursor_cls: Optional[Type[AbstractFileBasedCursor]] = None
|
||||
self._config: Optional[Mapping[str, Any]] = None
|
||||
self._state: Optional[TState] = None
|
||||
|
||||
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> InMemoryFilesSource:
|
||||
def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> InMemoryFilesSource:
|
||||
if self._file_type is None:
|
||||
raise ValueError("file_type is not set")
|
||||
return InMemoryFilesSource(
|
||||
@@ -42,6 +45,8 @@ class FileBasedSourceBuilder(SourceBuilder[InMemoryFilesSource]):
|
||||
self._parsers,
|
||||
self._stream_reader,
|
||||
configured_catalog,
|
||||
config,
|
||||
state,
|
||||
self._file_write_options,
|
||||
self._cursor_cls,
|
||||
)
|
||||
|
||||
@@ -485,14 +485,10 @@ invalid_jsonl_scenario = (
|
||||
}
|
||||
)
|
||||
.set_expected_records(
|
||||
[
|
||||
{
|
||||
"data": {"col1": "val1", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a.jsonl"},
|
||||
"stream": "stream1",
|
||||
},
|
||||
]
|
||||
[]
|
||||
)
|
||||
.set_expected_discover_error(AirbyteTracedException, FileBasedSourceError.SCHEMA_INFERENCE_ERROR.value)
|
||||
.set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.")
|
||||
.set_expected_logs(
|
||||
{
|
||||
"read": [
|
||||
|
||||
@@ -729,6 +729,7 @@ parquet_with_invalid_config_scenario = (
|
||||
.set_expected_records([])
|
||||
.set_expected_logs({"read": [{"level": "ERROR", "message": "Error parsing record"}]})
|
||||
.set_expected_discover_error(AirbyteTracedException, "Error inferring schema from files")
|
||||
.set_expected_read_error(AirbyteTracedException, "Please check the logged errors for more information.")
|
||||
.set_expected_catalog(
|
||||
{
|
||||
"streams": [
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Generic, List, Mapping, Optional, Set, Tuple, Type, Type
|
||||
|
||||
from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode
|
||||
from airbyte_cdk.sources import AbstractSource
|
||||
from airbyte_cdk.sources.source import TState
|
||||
from airbyte_protocol.models import ConfiguredAirbyteCatalog
|
||||
|
||||
|
||||
@@ -26,7 +27,7 @@ class SourceBuilder(ABC, Generic[SourceType]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> SourceType:
|
||||
def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> SourceType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -80,11 +81,11 @@ class TestScenario(Generic[SourceType]):
|
||||
return self.catalog.dict() # type: ignore # dict() is not typed
|
||||
|
||||
catalog: Mapping[str, Any] = {"streams": []}
|
||||
for stream in self.source.streams(self.config):
|
||||
for stream in catalog["streams"]:
|
||||
catalog["streams"].append(
|
||||
{
|
||||
"stream": {
|
||||
"name": stream.name,
|
||||
"name": stream["name"],
|
||||
"json_schema": {},
|
||||
"supported_sync_modes": [sync_mode.value],
|
||||
},
|
||||
@@ -152,7 +153,7 @@ class TestScenarioBuilder(Generic[SourceType]):
|
||||
self._expected_logs = expected_logs
|
||||
return self
|
||||
|
||||
def set_expected_records(self, expected_records: List[Mapping[str, Any]]) -> "TestScenarioBuilder[SourceType]":
|
||||
def set_expected_records(self, expected_records: Optional[List[Mapping[str, Any]]]) -> "TestScenarioBuilder[SourceType]":
|
||||
self._expected_records = expected_records
|
||||
return self
|
||||
|
||||
@@ -191,7 +192,9 @@ class TestScenarioBuilder(Generic[SourceType]):
|
||||
if self.source_builder is None:
|
||||
raise ValueError("source_builder is not set")
|
||||
source = self.source_builder.build(
|
||||
self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh)
|
||||
self._configured_catalog(SyncMode.incremental if self._incremental_scenario_config else SyncMode.full_refresh),
|
||||
self._config,
|
||||
self._incremental_scenario_config.input_state if self._incremental_scenario_config else None,
|
||||
)
|
||||
return TestScenario(
|
||||
self._name,
|
||||
|
||||
@@ -661,29 +661,7 @@ wait_for_rediscovery_scenario_single_stream = (
|
||||
]
|
||||
}
|
||||
)
|
||||
.set_expected_records(
|
||||
[
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_a_11",
|
||||
"col2": 1,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_a_12",
|
||||
"col2": 2,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
# No records past that because the first record for the second file did not conform to the schema
|
||||
]
|
||||
)
|
||||
.set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops
|
||||
.set_expected_logs(
|
||||
{
|
||||
"read": [
|
||||
@@ -722,56 +700,7 @@ wait_for_rediscovery_scenario_multi_stream = (
|
||||
]
|
||||
}
|
||||
)
|
||||
.set_expected_records(
|
||||
[
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_aa1_11",
|
||||
"col2": 1,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a/a1.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_aa1_12",
|
||||
"col2": 2,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "a/a1.csv",
|
||||
},
|
||||
"stream": "stream1",
|
||||
},
|
||||
# {"data": {"col1": "val_aa2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"},
|
||||
# {"data": {"col1": "val_aa2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a2.csv"}, "stream": "stream1"},
|
||||
# {"data": {"col1": "val_aa3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
|
||||
# {"data": {"col1": "val_aa3_12", None: "val_aa3_22", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
|
||||
# {"data": {"col1": "val_aa3_13", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a3.csv"}, "stream": "stream1"},
|
||||
# {"data": {"col1": "val_aa4_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "a/a4.csv"}, "stream": "stream1"},
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_bb1_11",
|
||||
"col2": 1,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "b/b1.csv",
|
||||
},
|
||||
"stream": "stream2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"col1": "val_bb1_12",
|
||||
"col2": 2,
|
||||
"_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
|
||||
"_ab_source_file_url": "b/b1.csv",
|
||||
},
|
||||
"stream": "stream2",
|
||||
},
|
||||
# {"data": {"col1": "val_bb2_11", "col2": "this is text that will trigger validation policy", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"},
|
||||
# {"data": {"col1": "val_bb2_12", "col2": 2, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b2.csv"}, "stream": "stream2"},
|
||||
# {"data": {"col1": "val_bb3_11", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"},
|
||||
# {"data": {"col1": "val_bb3_12", "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z", "_ab_source_file_url": "b/b3.csv"}, "stream": "stream2"},
|
||||
]
|
||||
)
|
||||
.set_expected_records(None) # When syncing streams concurrently we don't know how many records will be emitted before the sync stops
|
||||
.set_expected_logs(
|
||||
{
|
||||
"read": [
|
||||
|
||||
@@ -0,0 +1,365 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
import unittest
from datetime import datetime
from unittest.mock import MagicMock, Mock

import pytest
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.file_based.availability_strategy import DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.discovery_policy import DefaultDiscoveryPolicy
from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector
from airbyte_cdk.sources.file_based.file_types import default_parsers
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import EmitRecordPolicy
from airbyte_cdk.sources.file_based.stream import DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import (
    FileBasedStreamFacade,
    FileBasedStreamPartition,
    FileBasedStreamPartitionGenerator,
)
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from freezegun import freeze_time

_ANY_SYNC_MODE = SyncMode.full_refresh
_ANY_STATE = {"state_key": "state_value"}
_ANY_CURSOR_FIELD = ["a", "cursor", "key"]
_STREAM_NAME = "stream"
_ANY_CURSOR = Mock(spec=FileBasedNoopCursor)

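# Shared stand-in values: the tests below use these wherever the concrete sync mode, state,
# cursor field, cursor or stream name does not matter for the assertion being made.
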
@pytest.mark.parametrize(
    "sync_mode",
    [
        pytest.param(SyncMode.full_refresh, id="test_full_refresh"),
        pytest.param(SyncMode.incremental, id="test_incremental"),
    ],
)
def test_file_based_stream_partition_generator(sync_mode):
    stream = Mock()
    message_repository = Mock()
    stream_slices = [{"files": [RemoteFile(uri="1", last_modified=datetime.now())]},
                     {"files": [RemoteFile(uri="2", last_modified=datetime.now())]}]
    stream.stream_slices.return_value = stream_slices

    partition_generator = FileBasedStreamPartitionGenerator(stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)

    partitions = list(partition_generator.generate())
    slices = [partition.to_slice() for partition in partitions]
    assert slices == stream_slices
    stream.stream_slices.assert_called_once_with(sync_mode=_ANY_SYNC_MODE, cursor_field=_ANY_CURSOR_FIELD, stream_state=_ANY_STATE)


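# FileBasedStreamPartition.read() is expected to yield the records produced by the wrapped
# stream's read_records() (normalized by the stream's transformer and schema), while
# AirbyteMessage entries such as log messages are routed to the message repository instead
# of being yielded, as the following test demonstrates.
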
@pytest.mark.parametrize(
    "transformer, expected_records",
    [
        pytest.param(
            TypeTransformer(TransformConfig.NoTransform),
            [Record({"data": "1"}, _STREAM_NAME), Record({"data": "2"}, _STREAM_NAME)],
            id="test_no_transform",
        ),
        pytest.param(
            TypeTransformer(TransformConfig.DefaultSchemaNormalization),
            [Record({"data": 1}, _STREAM_NAME), Record({"data": 2}, _STREAM_NAME)],
            id="test_default_transform",
        ),
    ],
)
def test_file_based_stream_partition(transformer, expected_records):
    stream = Mock()
    stream.name = _STREAM_NAME
    stream.get_json_schema.return_value = {"type": "object", "properties": {"data": {"type": ["integer"]}}}
    stream.transformer = transformer
    message_repository = InMemoryMessageRepository()
    _slice = None
    sync_mode = SyncMode.full_refresh
    cursor_field = None
    state = None
    partition = FileBasedStreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR)

    a_log_message = AirbyteMessage(
        type=MessageType.LOG,
        log=AirbyteLogMessage(
            level=Level.INFO,
            message='slice:{"partition": 1}',
        ),
    )

    stream_data = [a_log_message, {"data": "1"}, {"data": "2"}]
    stream.read_records.return_value = stream_data

    records = list(partition.read())
    messages = list(message_repository.consume_queue())

    assert records == expected_records
    assert messages == [a_log_message]


@pytest.mark.parametrize(
    "exception_type, expected_display_message",
    [
        pytest.param(Exception, None, id="test_exception_no_display_message"),
        pytest.param(ExceptionWithDisplayMessage, "display_message", id="test_exception_with_display_message"),
    ],
)
def test_file_based_stream_partition_raising_exception(exception_type, expected_display_message):
    stream = Mock()
    stream.get_error_display_message.return_value = expected_display_message

    message_repository = InMemoryMessageRepository()
    _slice = None

    partition = FileBasedStreamPartition(stream, _slice, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)

    stream.read_records.side_effect = Exception()

    with pytest.raises(exception_type) as e:
        list(partition.read())
    # The check must run after the `with` block; `e.value` holds the exception that was raised.
    if isinstance(e.value, ExceptionWithDisplayMessage):
        assert e.value.display_message == expected_display_message


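# The expected values below suggest that a partition's hash combines the stream name with,
# when a slice is present, each file's last_modified timestamp and URI ("<timestamp>_<uri>").
# This is an inference from the expected hashes, not a statement of the adapter's exact implementation.
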
@freeze_time("2023-06-09T00:00:00Z")
@pytest.mark.parametrize(
    "_slice, expected_hash",
    [
        pytest.param(
            {"files": [RemoteFile(uri="1", last_modified=datetime.strptime("2023-06-09T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ"))]},
            hash(("stream", "2023-06-09T00:00:00.000000Z_1")),
            id="test_hash_with_slice",
        ),
        pytest.param(None, hash("stream"), id="test_hash_no_slice"),
    ],
)
def test_file_based_stream_partition_hash(_slice, expected_hash):
    stream = Mock()
    stream.name = "stream"
    partition = FileBasedStreamPartition(stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR)

    _hash = partition.__hash__()
    assert _hash == expected_hash


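# A minimal usage sketch (not part of the tests below) of how a facade is typically obtained
# and read, assuming a file-based `stream`, a `source`, a `logger`, a `state` mapping and a
# FileBasedNoopCursor `cursor` are already available:
#
#     facade = FileBasedStreamFacade.create_from_stream(stream, source, logger, state, cursor)
#     records = list(facade.read_full_refresh(None, None, None))
#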
class StreamFacadeTest(unittest.TestCase):
    def setUp(self):
        self._abstract_stream = Mock()
        self._abstract_stream.name = "stream"
        self._abstract_stream.as_airbyte_stream.return_value = AirbyteStream(
            name="stream",
            json_schema={"type": "object"},
            supported_sync_modes=[SyncMode.full_refresh],
        )
        self._legacy_stream = DefaultFileBasedStream(
            cursor=FileBasedNoopCursor(MagicMock()),
            config=FileBasedStreamConfig(name="stream", format=CsvFormat()),
            catalog_schema={},
            stream_reader=MagicMock(),
            availability_strategy=DefaultFileBasedAvailabilityStrategy(MagicMock()),
            discovery_policy=DefaultDiscoveryPolicy(),
            parsers=default_parsers,
            validation_policy=EmitRecordPolicy(),
            errors_collector=FileBasedErrorsCollector(),
        )
        self._cursor = Mock(spec=Cursor)
        self._logger = Mock()
        self._slice_logger = Mock()
        self._slice_logger.should_log_slice_message.return_value = False
        self._facade = FileBasedStreamFacade(self._abstract_stream, self._legacy_stream, self._cursor, self._slice_logger, self._logger)
        self._source = Mock()

        self._stream = Mock()
        self._stream.primary_key = "id"

    def test_name_is_delegated_to_wrapped_stream(self):
        assert self._facade.name == self._abstract_stream.name

    def test_cursor_field_is_a_string(self):
        self._abstract_stream.cursor_field = "cursor_field"
        assert self._facade.cursor_field == "cursor_field"

    def test_source_defined_cursor_is_true(self):
        assert self._facade.source_defined_cursor

    def test_json_schema_is_delegated_to_wrapped_stream(self):
        json_schema = {"type": "object"}
        self._abstract_stream.get_json_schema.return_value = json_schema
        assert self._facade.get_json_schema() == json_schema
        self._abstract_stream.get_json_schema.assert_called_once_with()

    def test_given_cursor_is_noop_when_supports_incremental_then_return_legacy_stream_response(self):
        assert (
            FileBasedStreamFacade(
                self._abstract_stream, self._legacy_stream, _ANY_CURSOR, Mock(spec=SliceLogger), Mock(spec=logging.Logger)
            ).supports_incremental
            == self._legacy_stream.supports_incremental
        )

    def test_given_cursor_is_not_noop_when_supports_incremental_then_return_true(self):
        assert FileBasedStreamFacade(
            self._abstract_stream, self._legacy_stream, Mock(spec=Cursor), Mock(spec=SliceLogger), Mock(spec=logging.Logger)
        ).supports_incremental

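    # The read_* methods on the facade are expected to drain the partitions produced by the
    # wrapped AbstractStream.generate_partitions() and yield each record's data, whether the
    # caller goes through read_records, read_full_refresh or read_incremental.
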
    def test_full_refresh(self):
        expected_stream_data = [{"data": 1}, {"data": 2}]
        records = [Record(data, "stream") for data in expected_stream_data]

        partition = Mock()
        partition.read.return_value = records
        self._abstract_stream.generate_partitions.return_value = [partition]

        actual_stream_data = list(self._facade.read_records(SyncMode.full_refresh, None, {}, None))

        assert actual_stream_data == expected_stream_data

    def test_read_records_full_refresh(self):
        expected_stream_data = [{"data": 1}, {"data": 2}]
        records = [Record(data, "stream") for data in expected_stream_data]
        partition = Mock()
        partition.read.return_value = records
        self._abstract_stream.generate_partitions.return_value = [partition]

        actual_stream_data = list(self._facade.read_full_refresh(None, None, None))

        assert actual_stream_data == expected_stream_data

    def test_read_records_incremental(self):
        expected_stream_data = [{"data": 1}, {"data": 2}]
        records = [Record(data, "stream") for data in expected_stream_data]
        partition = Mock()
        partition.read.return_value = records
        self._abstract_stream.generate_partitions.return_value = [partition]

        actual_stream_data = list(self._facade.read_incremental(None, None, None, None, None, None, None))

        assert actual_stream_data == expected_stream_data

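    # create_from_stream is expected to normalize the legacy stream's metadata: primary keys
    # become a flat list, a single-element cursor_field list collapses to a string, and nested
    # keys or unexpected types are rejected with a ValueError (see the tests below).
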
    def test_create_from_stream_stream(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = "id"
        stream.cursor_field = "cursor"

        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)

        assert facade.name == "stream"
        assert facade.cursor_field == "cursor"
        assert facade._abstract_stream._primary_key == ["id"]

    def test_create_from_stream_stream_with_none_primary_key(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = None
        stream.cursor_field = []

        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
        assert facade._abstract_stream._primary_key == []

    def test_create_from_stream_with_composite_primary_key(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = ["id", "name"]
        stream.cursor_field = []

        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
        assert facade._abstract_stream._primary_key == ["id", "name"]

    def test_create_from_stream_with_empty_list_cursor(self):
        stream = Mock()
        stream.primary_key = "id"
        stream.cursor_field = []

        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)

        assert facade.cursor_field == []

    def test_create_from_stream_raises_exception_if_primary_key_is_nested(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = [["field", "id"]]

        with self.assertRaises(ValueError):
            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)

    def test_create_from_stream_raises_exception_if_primary_key_has_invalid_type(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = 123

        with self.assertRaises(ValueError):
            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)

    def test_create_from_stream_raises_exception_if_cursor_field_is_nested(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = "id"
        stream.cursor_field = ["field", "cursor"]

        with self.assertRaises(ValueError):
            FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)

    def test_create_from_stream_with_cursor_field_as_list(self):
        stream = Mock()
        stream.name = "stream"
        stream.primary_key = "id"
        stream.cursor_field = ["cursor"]

        facade = FileBasedStreamFacade.create_from_stream(stream, self._source, self._logger, _ANY_STATE, self._cursor)
        assert facade.cursor_field == "cursor"

    def test_create_from_stream_none_message_repository(self):
        self._stream.name = "stream"
        self._stream.primary_key = "id"
        self._stream.cursor_field = "cursor"
        self._source.message_repository = None

        with self.assertRaises(ValueError):
            FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, {}, self._cursor)

    def test_get_error_display_message_no_display_message(self):
        self._stream.get_error_display_message.return_value = "display_message"

        facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)

        expected_display_message = None
        e = Exception()

        display_message = facade.get_error_display_message(e)

        assert expected_display_message == display_message

    def test_get_error_display_message_with_display_message(self):
        self._stream.get_error_display_message.return_value = "display_message"

        facade = FileBasedStreamFacade.create_from_stream(self._stream, self._source, self._logger, _ANY_STATE, self._cursor)

        expected_display_message = "display_message"
        e = ExceptionWithDisplayMessage("display_message")

        display_message = facade.get_error_display_message(e)

        assert expected_display_message == display_message


@pytest.mark.parametrize(
    "exception, expected_display_message",
    [
        pytest.param(Exception("message"), None, id="test_no_display_message"),
        pytest.param(ExceptionWithDisplayMessage("message"), "message", id="test_with_display_message"),
    ],
)
def test_get_error_display_message(exception, expected_display_message):
    stream = Mock()
    legacy_stream = Mock()
    cursor = Mock(spec=Cursor)
    facade = FileBasedStreamFacade(stream, legacy_stream, cursor, Mock(spec=SliceLogger), Mock())

    display_message = facade.get_error_display_message(exception)

    assert display_message == expected_display_message
@@ -73,8 +73,21 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
    records, log_messages = output.records_and_state_messages, output.logs
    logs = [message.log for message in log_messages if message.log.level.value in scenario.log_levels]
    expected_records = scenario.expected_records

    if expected_records is None:
        return

    assert len(records) == len(expected_records)
    for actual, expected in zip(records, expected_records):

    sorted_expected_records = sorted(
        filter(lambda e: "data" in e, expected_records),
        key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x["data"].items(), key=lambda x: x[0]) if k != "emitted_at"),
    )
    sorted_records = sorted(
        filter(lambda r: r.record, records),
        key=lambda x: ",".join(f"{k}={v}" for k, v in sorted(x.record.data.items(), key=lambda x: x[0]) if k != "emitted_at"),
    )
    for actual, expected in zip(sorted_records, sorted_expected_records):
        if actual.record:
            assert len(actual.record.data) == len(expected["data"])
            for key, value in actual.record.data.items():
@@ -83,8 +96,11 @@ def _verify_read_output(output: EntrypointOutput, scenario: TestScenario[Abstrac
                else:
                    assert value == expected["data"][key]
            assert actual.record.stream == expected["stream"]
        elif actual.state:
            assert actual.state.data == expected

    expected_states = filter(lambda e: "data" not in e, expected_records)
    states = filter(lambda r: r.state, records)
    for actual, expected in zip(states, expected_states):  # states should be emitted in sorted order
        assert actual.state.data == expected

    if scenario.expected_logs:
        read_logs = scenario.expected_logs.get("read")
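    # Records are matched after sorting both sides by their data fields because a concurrent sync
    # does not guarantee emission order; state messages are still compared in the order they were emitted.
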
@@ -129,7 +145,7 @@ def verify_check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: Tes
    output = check(capsys, tmp_path, scenario)
    if expected_msg:
        # expected_msg is a string. what's the expected value field?
        assert expected_msg.value in output["message"]  # type: ignore
        assert expected_msg in output["message"]  # type: ignore
        assert output["status"] == scenario.expected_check_status

    else:

@@ -11,6 +11,7 @@ from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import Conc
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor
@@ -123,6 +124,6 @@ class StreamFacadeSourceBuilder(SourceBuilder[StreamFacadeSource]):
        self._input_state = state
        return self

    def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> StreamFacadeSource:
    def build(self, configured_catalog: Optional[Mapping[str, Any]], config: Optional[Mapping[str, Any]], state: Optional[TState]) -> StreamFacadeSource:
        threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
        return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, self._input_state)
        return StreamFacadeSource(self._streams, threadpool, self._cursor_field, self._cursor_boundaries, state)

@@ -119,7 +119,7 @@ class ConcurrentSourceBuilder(SourceBuilder[ConcurrentCdkSource]):
        self._streams: List[DefaultStream] = []
        self._message_repository = None

    def build(self, configured_catalog: Optional[Mapping[str, Any]]) -> ConcurrentCdkSource:
    def build(self, configured_catalog: Optional[Mapping[str, Any]], _, __) -> ConcurrentCdkSource:
        return ConcurrentCdkSource(self._streams, self._message_repository, 1, 1)

    def set_streams(self, streams: List[DefaultStream]) -> "ConcurrentSourceBuilder":