
refactor!(airbyte-cdk): remove availability strategy (#40682)

Signed-off-by: Artem Inzhyyants <artem.inzhyyants@gmail.com>
Author: Artem Inzhyyants
Date: 2024-07-17 11:55:51 +02:00
Committed by: GitHub
Parent: 536e8e21c8
Commit: 2bfb2264d9
10 changed files with 18 additions and 287 deletions

View File

@@ -128,11 +128,6 @@ class AbstractSource(Source, ABC):
)
timer.start_event(f"Syncing stream {configured_stream.stream.name}")
stream_is_available, reason = stream_instance.check_availability(logger, self)
if not stream_is_available:
logger.warning(f"Skipped syncing stream '{stream_instance.name}' because it was unavailable. {reason}")
yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE)
continue
logger.info(f"Marking stream {configured_stream.stream.name} as STARTED")
yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.STARTED)
yield from self._read_stream(
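
With this gate removed, a configured stream is no longer silently skipped with an INCOMPLETE status when it cannot be reached; the failure surfaces from the read itself, as the updated concurrent-adapter test further down expects. A minimal caller-side sketch of that expectation; `source`, `config`, `catalog`, and `process` are hypothetical placeholders, not names from this commit:

from airbyte_cdk.utils.traced_exception import AirbyteTracedException

try:
    for message in source.read(logger, config, catalog, state={}):
        process(message)  # records, state messages and stream-status traces
except AirbyteTracedException as exc:
    # Unavailable streams now fail the sync instead of being skipped up front.
    logger.warning(f"Sync stopped: {exc.message or exc.internal_message}")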

View File

@@ -4,9 +4,9 @@
import concurrent
import logging
from queue import Queue
-from typing import Iterable, Iterator, List, Optional, Tuple
+from typing import Iterable, Iterator, List
-from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus
+from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
@@ -19,7 +19,6 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partitio
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger
from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message
class ConcurrentSource:
@@ -81,15 +80,6 @@ class ConcurrentSource:
streams: List[AbstractStream],
) -> Iterator[AirbyteMessage]:
self._logger.info("Starting syncing")
stream_instances_to_read_from, not_available_streams = self._get_streams_to_read_from(streams)
for stream, message in not_available_streams:
self._logger.warning(f"Skipped syncing stream '{stream.name}' because it was unavailable. {message}")
yield stream_status_as_airbyte_message(stream, AirbyteStreamStatus.INCOMPLETE)
# Return early if there are no streams to read from
if not stream_instances_to_read_from:
return
# We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less
# threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating
@@ -97,7 +87,7 @@ class ConcurrentSource:
# information and might even need to be configurable depending on the source
queue: Queue[QueueItem] = Queue(maxsize=10_000)
concurrent_stream_processor = ConcurrentReadProcessor(
-stream_instances_to_read_from,
+streams,
PartitionEnqueuer(queue, self._threadpool),
self._threadpool,
self._logger,
@@ -155,22 +145,3 @@ class ConcurrentSource:
yield from concurrent_stream_processor.on_record(queue_item)
else:
raise ValueError(f"Unknown queue item type: {type(queue_item)}")
def _get_streams_to_read_from(
self, streams: List[AbstractStream]
) -> Tuple[List[AbstractStream], List[Tuple[AbstractStream, Optional[str]]]]:
"""
Iterate over the configured streams and return a list of streams to read from.
If a stream is not configured, it will be skipped.
If a stream is configured but does not exist in the source and self.raise_exception_on_missing_stream is True, an exception will be raised
If a stream is not available, it will be skipped
"""
stream_instances_to_read_from = []
not_available_streams = []
for stream in streams:
stream_availability = stream.check_availability()
if not stream_availability.is_available():
not_available_streams.append((stream, stream_availability.message()))
continue
stream_instances_to_read_from.append(stream)
return stream_instances_to_read_from, not_available_streams
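
With `_get_streams_to_read_from` gone, every stream in the catalog is handed straight to the `ConcurrentReadProcessor`; the only flow control left here is the bounded queue described in the comment above, which blocks partition-generating threads while the main thread drains records. A generic, self-contained illustration of that backpressure pattern (plain `queue.Queue` and `threading`, not the CDK's own `QueueItem` types):

from queue import Queue
from threading import Thread

SENTINEL = object()
queue: "Queue[object]" = Queue(maxsize=10_000)  # bounded, like the QueueItem queue above

def generate_partitions(count: int) -> None:
    for i in range(count):
        queue.put(i)  # blocks once the queue is full, throttling the producer thread
    queue.put(SENTINEL)

Thread(target=generate_partitions, args=(50_000,), daemon=True).start()

while (item := queue.get()) is not SENTINEL:
    pass  # the main thread consumes items, freeing capacity for the producers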

View File

@@ -18,6 +18,7 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSc
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.streams import Stream
from deprecated import deprecated
class AbstractFileBasedStream(Stream):
@@ -152,6 +153,7 @@ class AbstractFileBasedStream(Stream):
)
@cached_property
@deprecated(version="3.7.0")
def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
return self._availability_strategy
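
The property keeps working for now but is flagged through the `deprecated` package imported above, which emits a `DeprecationWarning` whenever the wrapped getter runs. A small runnable sketch of that mechanism; the class and the reason string are illustrative, the decorators in this commit only pin `version="3.7.0"`:

import warnings

from deprecated import deprecated

class LegacyStream:
    @property
    @deprecated(version="3.7.0", reason="superseded by in-read availability handling")  # hypothetical reason
    def availability_strategy(self):
        return None

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = LegacyStream().availability_strategy
assert any(issubclass(w.category, DeprecationWarning) for w in caught)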

View File

@@ -128,6 +128,7 @@ class FileBasedStreamFacade(AbstractStreamFacade[DefaultStream], AbstractFileBas
return self._legacy_stream.supports_incremental
@property
@deprecated(version="3.7.0")
def availability_strategy(self) -> AbstractFileBasedAvailabilityStrategy:
return self._legacy_stream.availability_strategy

View File

@@ -348,6 +348,7 @@ class Stream(ABC):
"""
return False
@deprecated(version="3.7.0")
def check_availability(self, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]:
"""
Checks whether this stream is available.
@@ -364,6 +365,7 @@ class Stream(ABC):
return True, None
@property
@deprecated(version="3.7.0")
def availability_strategy(self) -> Optional["AvailabilityStrategy"]:
"""
:return: The AvailabilityStrategy used to check whether this stream is available.
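
Both members remain callable here, only marked as deprecated. For reference, a sketch of the legacy contract they implement, consistent with the `return True, None` fallback visible in the context above; the delegation to `availability_strategy` is an assumption about the unchanged body, not part of this diff:

def check_availability(self, logger, source=None):
    # Legacy behaviour: delegate to the (now deprecated) strategy when one is
    # configured, otherwise report the stream as available with no reason.
    if self.availability_strategy:
        return self.availability_strategy.check_availability(self, logger, source)
    return True, None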

View File

@@ -123,6 +123,7 @@ class HttpStream(Stream, ABC):
return 5
@property
@deprecated(version="3.7.0", reason="This functionality is handled by combination of HttpClient.ErrorHandler and AirbyteStreamStatus")
def availability_strategy(self) -> Optional[AvailabilityStrategy]:
return HttpAvailabilityStrategy()
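
The deprecation reason names the replacement: error handling now lives in the `HttpClient.ErrorHandler` / `AirbyteStreamStatus` combination rather than in a pre-read probe. For contrast, a sketch of how the legacy strategy returned here was typically consumed; `http_stream`, `logger`, and `source` are placeholders, and the `(available, reason)` tuple is the contract exercised by the removed tests at the bottom of this diff:

from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy

strategy = HttpAvailabilityStrategy()
available, reason = strategy.check_availability(http_stream, logger, source)
if not available:
    # e.g. a 403 used to map to "Forbidden. You don't have permission to access this resource."
    logger.warning(f"Skipped syncing stream '{http_stream.name}' because it was unavailable. {reason}")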

View File

@@ -18,6 +18,7 @@ from airbyte_cdk.models import (
DestinationSyncMode,
FailureType,
SyncMode,
TraceType,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
@@ -73,9 +74,11 @@ def test_concurrent_source_adapter(as_stream_status, remove_stack_trace):
unavailable_stream = _mock_stream("s3", [{"data": 3}], False)
concurrent_stream.name = "s2"
logger = Mock()
-adapter = _MockSource(concurrent_source, {regular_stream: False, concurrent_stream: True, unavailable_stream: False}, logger)
-messages = list(adapter.read(logger, {}, _configured_catalog([regular_stream, concurrent_stream, unavailable_stream])))
+adapter = _MockSource(concurrent_source, {regular_stream: False, concurrent_stream: True}, logger)
+with pytest.raises(AirbyteTracedException):
+    messages = []
+    for message in adapter.read(logger, {}, _configured_catalog([regular_stream, concurrent_stream, unavailable_stream])):
+        messages.append(message)
records = [m for m in messages if m.type == MessageType.RECORD]
@@ -93,7 +96,7 @@ def test_concurrent_source_adapter(as_stream_status, remove_stack_trace):
assert records == expected_records
-unavailable_stream_trace_messages = [m for m in messages if m.type == MessageType.TRACE and m.trace.stream_status.status == AirbyteStreamStatus.INCOMPLETE]
+unavailable_stream_trace_messages = [m for m in messages if m.type == MessageType.TRACE and m.trace.type == TraceType.STREAM_STATUS and m.trace.stream_status.status == AirbyteStreamStatus.INCOMPLETE]
expected_status = [as_stream_status("s3", AirbyteStreamStatus.INCOMPLETE)]
assert len(unavailable_stream_trace_messages) == 1
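
The tightened comprehension above guards on `trace.type` before touching `stream_status`: now that an unavailable stream surfaces as a traced error rather than a quiet skip, the collected messages can include TRACE messages whose `stream_status` payload is absent. A small helper-style sketch of that filtering, using the same `airbyte_cdk.models` enums:

from typing import List

from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus, TraceType
from airbyte_cdk.models import Type as MessageType

def incomplete_stream_statuses(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
    # Checking trace.type first avoids dereferencing stream_status on error or
    # analytics trace messages, where that field is not populated.
    return [
        m for m in messages
        if m.type == MessageType.TRACE
        and m.trace.type == TraceType.STREAM_STATUS
        and m.trace.stream_status.status == AirbyteStreamStatus.INCOMPLETE
    ]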

View File

@@ -383,12 +383,6 @@ class IncrementalStreamTest(TestCase):
"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
}
last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ")
http_mocker.get(
_create_legacies_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(),
_create_response().with_record(record=_create_record("legacies").with_cursor(last_record_date_0)).with_record(record=_create_record("legacies").with_cursor(last_record_date_0)).with_record(record=_create_record("legacies").with_cursor(last_record_date_0)).build(),
)
last_record_date_1 = (_NOW - timedelta(days=1)).strftime("%Y-%m-%dT%H:%M:%SZ")
http_mocker.get(
_create_legacies_request().with_start_date(_NOW - timedelta(days=1)).with_end_date(_NOW).build(),
@@ -412,14 +406,6 @@ class IncrementalStreamTest(TestCase):
"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
}
last_record_date_0 = (start_datetime + timedelta(days=4)).strftime("%Y-%m-%dT%H:%M:%SZ")
http_mocker.get(
_create_legacies_request().with_start_date(start_datetime).with_end_date(start_datetime + timedelta(days=7)).build(),
_create_response().with_record(record=_create_record("legacies").with_cursor(last_record_date_0)).with_record(
record=_create_record("legacies").with_cursor(last_record_date_0)).with_record(
record=_create_record("legacies").with_cursor(last_record_date_0)).build(),
)
last_record_date_1 = _NOW.strftime("%Y-%m-%dT%H:%M:%SZ")
incoming_state = AirbyteStateBlob(created_at=last_record_date_1)

View File

@@ -165,12 +165,6 @@ class ResumableFullRefreshStreamTest(TestCase):
state = StateBuilder().with_stream_state("justice_songs", {"page": 100}).build()
# Needed to handle the availability check request to get the first record
http_mocker.get(
_create_justice_songs_request().build(),
_create_response(pagination_has_more=True).with_pagination().with_record(record=_create_record("justice_songs")).with_record(record=_create_record("justice_songs")).build(),
)
http_mocker.get(
_create_justice_songs_request().with_page(100).build(),
_create_response(pagination_has_more=True).with_pagination().with_record(record=_create_record("justice_songs")).with_record(record=_create_record("justice_songs")).with_record(record=_create_record("justice_songs")).build(),
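
The deleted mocks existed only because the legacy HTTP availability check probed each stream by pulling a single record before the sync, costing one extra request per stream (the comment near the top of this hunk and the note inside the deleted unit tests further down both point at this). Roughly, the probe amounted to the following; `stream` is a placeholder for any `HttpStream` instance:

from airbyte_cdk.models import SyncMode

# One probing request per stream: first slice, first record, discarded afterwards.
first_slice = next(iter(stream.stream_slices(sync_mode=SyncMode.full_refresh)))
records = stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=first_slice)
next(iter(records), None)  # this is the request the tests no longer have to mock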

View File

@@ -7,17 +7,15 @@ import logging
import tempfile
from collections import defaultdict
from contextlib import nullcontext as does_not_raise
-from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
import pytest
import requests
from airbyte_cdk.models import (
AirbyteGlobalState,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStateType,
AirbyteStreamState,
AirbyteStreamStatus,
ConfiguredAirbyteCatalog,
StreamDescriptor,
SyncMode,
@@ -25,8 +23,7 @@ from airbyte_cdk.models import (
)
from airbyte_cdk.sources import AbstractSource, Source
from airbyte_cdk.sources.streams.core import Stream
-from airbyte_cdk.sources.streams.http.availability_strategy import HttpAvailabilityStrategy
-from airbyte_cdk.sources.streams.http.http import HttpStream, HttpSubStream
+from airbyte_cdk.sources.streams.http.http import HttpStream
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from pydantic import ValidationError
@@ -505,224 +502,3 @@ def test_source_config_transform_and_no_transform(mocker, abstract_source, catal
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + STATE_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}]
def test_read_default_http_availability_strategy_stream_available(catalog, mocker):
mocker.patch.multiple(HttpStream, __abstractmethods__=set())
mocker.patch.multiple(Stream, __abstractmethods__=set())
class MockHttpStream(mocker.MagicMock, HttpStream):
url_base = "http://example.com"
path = "/dummy/path"
get_json_schema = mocker.MagicMock()
@property
def cursor_field(self) -> Union[str, List[str]]:
return ["updated_at"]
def __init__(self, *args, **kvargs):
mocker.MagicMock.__init__(self)
HttpStream.__init__(self, *args, kvargs)
self.read_records = mocker.MagicMock()
@property
def state(self) -> MutableMapping[str, Any]:
return self._state
@state.setter
def state(self, value: MutableMapping[str, Any]) -> None:
self._state = value
def get_backoff_strategy(self):
return None
def get_error_handler(self):
return None
class MockStream(mocker.MagicMock, Stream):
page_size = None
get_json_schema = mocker.MagicMock()
def __init__(self, *args, **kvargs):
mocker.MagicMock.__init__(self)
self.read_records = mocker.MagicMock()
streams = [MockHttpStream(), MockStream()]
http_stream, non_http_stream = streams
assert isinstance(http_stream, HttpStream)
assert not isinstance(non_http_stream, HttpStream)
assert isinstance(http_stream.availability_strategy, HttpAvailabilityStrategy)
assert non_http_stream.availability_strategy is None
# Add an extra record for the default HttpAvailabilityStrategy to pull from
# during the try: next(records) check, since we are mocking the return value
# and not re-creating the generator like we would during actual reading
http_stream.read_records.return_value = iter([{"value": "test"}] + [{}] * 3)
non_http_stream.read_records.return_value = iter([{}] * 3)
source = MockAbstractSource(streams=streams)
logger = logging.getLogger(f"airbyte.{getattr(abstract_source, 'name', '')}")
records = [r for r in source.read(logger=logger, config={}, catalog=catalog, state={})]
# 3 for http stream, 3 for non http stream, 1 for state message for each stream (2x) and 3 for stream status messages for each stream (2x)
assert len(records) == 3 + 3 + 1 + 1 + 3 + 3
assert http_stream.read_records.called
assert non_http_stream.read_records.called
def test_read_default_http_availability_strategy_stream_unavailable(catalog, mocker, caplog, remove_stack_trace, as_stream_status):
mocker.patch.multiple(Stream, __abstractmethods__=set())
class MockHttpStream(HttpStream):
url_base = "https://test_base_url.com"
primary_key = ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.resp_counter = 1
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def path(self, **kwargs) -> str:
return ""
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
stub_response = {"data": self.resp_counter}
self.resp_counter += 1
yield stub_response
class MockStream(mocker.MagicMock, Stream):
page_size = None
get_json_schema = mocker.MagicMock()
def __init__(self, *args, **kvargs):
mocker.MagicMock.__init__(self)
self.read_records = mocker.MagicMock()
streams = [MockHttpStream(), MockStream()]
http_stream, non_http_stream = streams
assert isinstance(http_stream, HttpStream)
assert not isinstance(non_http_stream, HttpStream)
assert isinstance(http_stream.availability_strategy, HttpAvailabilityStrategy)
assert non_http_stream.availability_strategy is None
# Don't set anything for read_records return value for HttpStream, since
# it should be skipped due to the stream being unavailable
non_http_stream.read_records.return_value = iter([{}] * 3)
# Patch HTTP request to stream endpoint to make it unavailable
req = requests.Response()
req.status_code = 403
mocker.patch.object(requests.Session, "send", return_value=req)
source = MockAbstractSource(streams=streams)
logger = logging.getLogger("test_read_default_http_availability_strategy_stream_unavailable")
with caplog.at_level(logging.WARNING):
records = [r for r in source.read(logger=logger, config={}, catalog=catalog, state={})]
# 0 for http stream, 3 for non http stream, 1 for non http stream state message and 4 status trace messages
assert len(records) == 0 + 3 + 1 + 4
assert non_http_stream.read_records.called
expected_logs = [
f"Skipped syncing stream '{http_stream.name}' because it was unavailable.",
"Forbidden.",
"You don't have permission to access this resource.",
]
for message in expected_logs:
assert message in caplog.text
expected_status = [as_stream_status(f"{http_stream.name}", AirbyteStreamStatus.INCOMPLETE)]
records = [remove_stack_trace(record) for record in records]
assert records[0].trace.stream_status == expected_status[0].trace.stream_status
def test_read_default_http_availability_strategy_parent_stream_unavailable(catalog, mocker, caplog, remove_stack_trace, as_stream_status):
"""Test default availability strategy if error happens during slice extraction (reading of parent stream)"""
mocker.patch.multiple(Stream, __abstractmethods__=set())
class MockHttpParentStream(HttpStream):
url_base = "https://test_base_url.com"
primary_key = ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.resp_counter = 1
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def path(self, **kwargs) -> str:
return ""
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
stub_response = {"data": self.resp_counter}
self.resp_counter += 1
yield stub_response
def get_json_schema(self) -> Mapping[str, Any]:
return {}
class MockHttpStream(HttpSubStream):
url_base = "https://test_base_url.com"
primary_key = ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.resp_counter = 1
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def path(self, **kwargs) -> str:
return ""
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
stub_response = {"data": self.resp_counter}
self.resp_counter += 1
yield stub_response
http_stream = MockHttpStream(parent=MockHttpParentStream())
streams = [http_stream]
assert isinstance(http_stream, HttpSubStream)
assert isinstance(http_stream.availability_strategy, HttpAvailabilityStrategy)
# Patch HTTP request to stream endpoint to make it unavailable
req = requests.Response()
req.status_code = 403
mocker.patch.object(requests.Session, "send", return_value=req)
source = MockAbstractSource(streams=streams)
logger = logging.getLogger("test_read_default_http_availability_strategy_parent_stream_unavailable")
configured_catalog = {
"streams": [
{
"stream": {
"name": "mock_http_stream",
"json_schema": {"type": "object", "properties": {"k": "v"}},
"supported_sync_modes": ["full_refresh"],
},
"destination_sync_mode": "overwrite",
"sync_mode": "full_refresh",
}
]
}
catalog = ConfiguredAirbyteCatalog.model_validate(configured_catalog)
with caplog.at_level(logging.WARNING):
records = [r for r in source.read(logger=logger, config={}, catalog=catalog, state={})]
# 0 for http stream, 0 for non http stream and 1 status trace messages
assert len(records) == 1
expected_logs = [
f"Skipped syncing stream '{http_stream.name}' because it was unavailable.",
"Forbidden.",
"You don't have permission to access this resource.",
]
for message in expected_logs:
assert message in caplog.text
expected_status = [as_stream_status(f"{http_stream.name}", AirbyteStreamStatus.INCOMPLETE)]
records = [remove_stack_trace(record) for record in records]
assert records[0].trace.stream_status == expected_status[0].trace.stream_status