
Publish stream status messages in CDK (#24994)

* Publish stream status messages in CDK

* Automated Commit - Formatting Changes

* Convert to StreamDescriptor

* Automated Commit - Formatting Changes

* Bump to latest protocol model

* Automated Commit - Formatting Changes

* Bump protocol version

* Add tests for stream status message creation

* Formatting

* Formatting

* Fix failing test

* Actually emit state message

* Automated Commit - Formatting Changes

* Bump airbyte-protocol

* PR feedback

* Fix parameter input

* Correctly yield status message

* PR feedback

* Formatting

* Fix failing tests

* Automated Commit - Formatting Changes

* Revert accidental change

* Automated Change

* Replace STOPPED with COMPLETE/INCOMPLETE

* Update source-facebook-marketing changelog

* Revert "Update source-facebook-marketing changelog"

This reverts commit 709edb800c.

---------

Co-authored-by: jdpgrailsdev <jdpgrailsdev@users.noreply.github.com>
Jonathan Pearlin authored 2023-04-26 10:30:36 -04:00, committed by GitHub
parent 3e6c8a37b0, commit 2ebfa459cf
9 changed files with 242 additions and 37 deletions
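
For orientation before the diffs: what this commit publishes are Airbyte protocol TRACE messages whose trace type is STREAM_STATUS, one per lifecycle transition of each stream. The sketch below shows how such a message is assembled; it mirrors the _as_stream_status test helper added in this commit, and all model names come from the diff's own imports (an illustration, not the CDK's internal emission code).

import datetime

from airbyte_cdk.models import (
    AirbyteMessage,
    AirbyteStreamStatus,
    AirbyteStreamStatusTraceMessage,
    AirbyteTraceMessage,
    StreamDescriptor,
    TraceType,
)
from airbyte_cdk.models import Type as MessageType

# Build a STARTED status message for a stream named "s1".
status_message = AirbyteMessage(
    type=MessageType.TRACE,
    trace=AirbyteTraceMessage(
        type=TraceType.STREAM_STATUS,
        emitted_at=datetime.datetime.now().timestamp() * 1000.0,  # epoch milliseconds
        stream_status=AirbyteStreamStatusTraceMessage(
            stream_descriptor=StreamDescriptor(name="s1"),
            status=AirbyteStreamStatus.STARTED,
        ),
    ),
)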

View File

@@ -1242,7 +1242,7 @@ def _create_page(response_body):
 def test_read_manifest_declarative_source(test_name, manifest, pages, expected_records, expected_calls):
     _stream_name = "Rates"
     with patch.object(HttpStream, "_fetch_next_page", side_effect=pages) as mock_http_stream:
-        output_data = [message.record.data for message in _run_read(manifest, _stream_name)]
+        output_data = [message.record.data for message in _run_read(manifest, _stream_name) if message.record]
         assert expected_records == output_data
         mock_http_stream.assert_has_calls(expected_calls)

View File

@@ -21,6 +21,9 @@ from airbyte_cdk.models import (
     AirbyteStateType,
     AirbyteStream,
     AirbyteStreamState,
+    AirbyteStreamStatus,
+    AirbyteStreamStatusTraceMessage,
+    AirbyteTraceMessage,
     ConfiguredAirbyteCatalog,
     ConfiguredAirbyteStream,
     DestinationSyncMode,
@@ -28,8 +31,10 @@ from airbyte_cdk.models import (
     Status,
     StreamDescriptor,
     SyncMode,
+    Type,
+    TraceType,
 )
-from airbyte_cdk.models import Type
+from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.sources import AbstractSource
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.streams import IncrementalMixin, Stream
@@ -250,6 +255,19 @@ def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]
     return [_as_record(stream, datum) for datum in data]
+
+
+def _as_stream_status(stream: str, status: AirbyteStreamStatus) -> AirbyteMessage:
+    trace_message = AirbyteTraceMessage(
+        emitted_at=datetime.datetime.now().timestamp() * 1000.0,
+        type=TraceType.STREAM_STATUS,
+        stream_status=AirbyteStreamStatusTraceMessage(
+            stream_descriptor=StreamDescriptor(name=stream),
+            status=status,
+        ),
+    )
+    return AirbyteMessage(type=MessageType.TRACE, trace=trace_message)


 def _as_state(state_data: Dict[str, Any], stream_name: str = "", per_stream_state: Dict[str, Any] = None):
     if per_stream_state:
         return AirbyteMessage(
@@ -277,6 +295,8 @@ def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
     for msg in messages:
         if msg.type == Type.RECORD and msg.record:
             msg.record.emitted_at = GLOBAL_EMITTED_AT
+        if msg.type == Type.TRACE and msg.trace:
+            msg.trace.emitted_at = GLOBAL_EMITTED_AT
     return messages
@@ -296,7 +316,17 @@ def test_valid_full_refresh_read_no_slices(mocker):
         ]
     )
-    expected = _as_records("s1", stream_output) + _as_records("s2", stream_output)
+    expected = _fix_emitted_at(
+        [
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
+            *_as_records("s1", stream_output),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
+            *_as_records("s2", stream_output),
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE)
+        ])

     messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))

     assert expected == messages
@@ -326,7 +356,17 @@ def test_valid_full_refresh_read_with_slices(mocker):
         ]
     )
-    expected = [*_as_records("s1", slices), *_as_records("s2", slices)]
+    expected = _fix_emitted_at(
+        [
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
+            *_as_records("s1", slices),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
+            *_as_records("s2", slices),
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE)
+        ])

     messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
@@ -448,18 +488,24 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
             _as_record("s1", stream_output[0]),
             _as_record("s1", stream_output[1]),
             _as_state({"s1": new_state_from_connector}, "s1", new_state_from_connector)
             if per_stream_enabled
             else _as_state({"s1": new_state_from_connector}),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
             _as_record("s2", stream_output[0]),
             _as_record("s2", stream_output[1]),
             _as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}, "s2", new_state_from_connector)
             if per_stream_enabled
             else _as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}),
-        ]
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
+        ])

         messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))

         assert messages == expected
@@ -521,18 +567,24 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
             _as_record("s1", stream_output[0]),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
             _as_record("s1", stream_output[1]),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
             _as_record("s2", stream_output[0]),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
             _as_record("s2", stream_output[1]),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
-        ]
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
+        ])

         messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))

         assert expected == messages
@@ -582,12 +634,18 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
             *_as_records("s1", stream_output),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
             *_as_records("s2", stream_output),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
-        ]
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
+        ])

         messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
@@ -658,20 +716,26 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
             # stream 1 slice 1
             *_as_records("s1", stream_output),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
             # stream 1 slice 2
             *_as_records("s1", stream_output),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s2", AirbyteStreamStatus.RUNNING),
             # stream 2 slice 1
             *_as_records("s2", stream_output),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
             # stream 2 slice 2
             *_as_records("s2", stream_output),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
-        ]
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
+        ])

         messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
@@ -753,10 +817,14 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+            _as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
+            _as_stream_status("s2", AirbyteStreamStatus.STARTED),
             _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
-        ]
+            _as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
+        ])

         messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
@@ -837,8 +905,10 @@ class TestIncrementalRead:
             ]
         )
-        expected = [
+        expected = _fix_emitted_at([
             # stream 1 slice 1
+            _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+            _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
             _as_record("s1", stream_output[0]),
             _as_record("s1", stream_output[1]),
             _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
@@ -850,7 +920,10 @@ class TestIncrementalRead:
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_record("s1", stream_output[2]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
# stream 2 slice 1
_as_stream_status("s2", AirbyteStreamStatus.STARTED),
_as_stream_status("s2", AirbyteStreamStatus.RUNNING),
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
@@ -862,7 +935,8 @@ class TestIncrementalRead:
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[2]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
_as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
])
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
@@ -942,6 +1016,8 @@ class TestIncrementalRead:
         expected = _fix_emitted_at(
             [
+                _as_stream_status("s1", AirbyteStreamStatus.STARTED),
+                _as_stream_status("s1", AirbyteStreamStatus.RUNNING),
                 # stream 1 slice 1
                 stream_data_to_airbyte_message("s1", stream_output[0]),
                 stream_data_to_airbyte_message("s1", stream_output[1]),
@@ -956,7 +1032,10 @@ class TestIncrementalRead:
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
stream_data_to_airbyte_message("s1", stream_output[3]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_stream_status("s1", AirbyteStreamStatus.COMPLETE),
# stream 2 slice 1
_as_stream_status("s2", AirbyteStreamStatus.STARTED),
_as_stream_status("s2", AirbyteStreamStatus.RUNNING),
stream_data_to_airbyte_message("s2", stream_output[0]),
stream_data_to_airbyte_message("s2", stream_output[1]),
stream_data_to_airbyte_message("s2", stream_output[2]),
@@ -970,6 +1049,7 @@ class TestIncrementalRead:
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
stream_data_to_airbyte_message("s2", stream_output[3]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_stream_status("s2", AirbyteStreamStatus.COMPLETE),
]
)
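
Taken together, the expectations above pin down the per-stream lifecycle this commit establishes: STARTED is emitted before anything else, RUNNING once the stream yields its first record, and COMPLETE after the stream finishes; a stream that emits only state and no records (the empty-output case above) goes straight from STARTED to COMPLETE. A hypothetical helper summarizing that rule, assuming only the AirbyteStreamStatus enum from the diff:

from airbyte_cdk.models import AirbyteStreamStatus

def expected_statuses(has_records: bool) -> list:
    # STARTED always comes first; RUNNING appears only if at least one record is read.
    statuses = [AirbyteStreamStatus.STARTED]
    if has_records:
        statuses.append(AirbyteStreamStatus.RUNNING)
    statuses.append(AirbyteStreamStatus.COMPLETE)
    return statuses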

View File

@@ -365,8 +365,8 @@ def test_internal_config(abstract_source, catalog):
     # Test with empty config
     logger = logging.getLogger(f"airbyte.{getattr(abstract_source, 'name', '')}")
     records = [r for r in abstract_source.read(logger=logger, config={}, catalog=catalog, state={})]
-    # 3 for http stream and 3 for non http stream
-    assert len(records) == 3 + 3
+    # 3 records from the http stream, 3 from the non-http stream, plus 3 stream status messages for each of the two streams
+    assert len(records) == 3 + 3 + 3 + 3
     assert http_stream.read_records.called
     assert non_http_stream.read_records.called
     # Make sure page_size hasn't been set
@@ -375,21 +375,21 @@ def test_internal_config(abstract_source, catalog):
     # Test with records limit set to 1
     internal_config = {"some_config": 100, "_limit": 1}
     records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})]
-    # 1 from http stream + 1 from non http stream
-    assert len(records) == 1 + 1
+    # 1 record from the http stream + 1 from the non-http stream, plus 3 stream status messages for each of the two streams
+    assert len(records) == 1 + 1 + 3 + 3
     assert "_limit" not in abstract_source.streams_config
     assert "some_config" in abstract_source.streams_config

     # Test with records limit set to a number that exceeds the expected record count
     internal_config = {"some_config": 100, "_limit": 20}
     records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})]
-    assert len(records) == 3 + 3
+    assert len(records) == 3 + 3 + 3 + 3

     # Check that the page_size parameter is set for the http instance only
     internal_config = {"some_config": 100, "_page_size": 2}
     records = [r for r in abstract_source.read(logger=logger, config=internal_config, catalog=catalog, state={})]
     assert "_page_size" not in abstract_source.streams_config
     assert "some_config" in abstract_source.streams_config
-    assert len(records) == 3 + 3
+    assert len(records) == 3 + 3 + 3 + 3
     assert http_stream.page_size == 2
     # Make sure page_size hasn't been set for non-http streams
     assert not non_http_stream.page_size
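
The revised counts all follow the same arithmetic: each stream that runs contributes its records plus exactly three status trace messages (STARTED, RUNNING, COMPLETE). A worked tally for the empty-config case above, as illustrative arithmetic only:

records_per_stream = 3
status_messages_per_stream = 3  # STARTED, RUNNING, COMPLETE
total = 2 * (records_per_stream + status_messages_per_stream)  # two streams
assert total == 3 + 3 + 3 + 3  # 12 messages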
@@ -402,6 +402,7 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
     STREAM_LIMIT = 2
     SLICE_DEBUG_LOG_COUNT = 1
     FULL_RECORDS_NUMBER = 3
+    TRACE_STATUS_COUNT = 3
     streams = abstract_source.streams(None)
     http_stream = streams[0]
     http_stream.read_records.return_value = [{}] * FULL_RECORDS_NUMBER
@@ -409,7 +410,7 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
     catalog.streams[0].sync_mode = SyncMode.full_refresh
     records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
-    assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT
+    assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT
     logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
     # Check that the log line matches the configured limit
     read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -418,14 +419,16 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
     # With no limit, check that a state record is produced for the incremental stream
     catalog.streams[0].sync_mode = SyncMode.incremental
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
-    assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + 1
-    assert records[-1].type == Type.STATE
+    assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + 1
+    assert records[-2].type == Type.STATE
+    assert records[-1].type == Type.TRACE

     # Set a limit and check that state is still produced for the incremental stream
     logger_mock.reset_mock()
     records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
-    assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + 1
-    assert records[-1].type == Type.STATE
+    assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT + 1
+    assert records[-2].type == Type.STATE
+    assert records[-1].type == Type.TRACE
     logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
     read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
     assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
@@ -436,6 +439,7 @@ SCHEMA = {"type": "object", "properties": {"value": {"type": "string"}}}
 def test_source_config_no_transform(mocker, abstract_source, catalog):
     SLICE_DEBUG_LOG_COUNT = 1
+    TRACE_STATUS_COUNT = 3
     logger_mock = mocker.MagicMock()
     logger_mock.level = logging.DEBUG
     streams = abstract_source.streams(None)
@@ -443,7 +447,7 @@ def test_source_config_no_transform(mocker, abstract_source, catalog):
     http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
     http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
-    assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT)
+    assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT)
     assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": 23}] * 2 * 5
     assert http_stream.get_json_schema.call_count == 5
     assert non_http_stream.get_json_schema.call_count == 5
@@ -453,6 +457,7 @@ def test_source_config_transform(mocker, abstract_source, catalog):
     logger_mock = mocker.MagicMock()
     logger_mock.level = logging.DEBUG
     SLICE_DEBUG_LOG_COUNT = 2
+    TRACE_STATUS_COUNT = 6
     streams = abstract_source.streams(None)
     http_stream, non_http_stream = streams
     http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
@@ -460,7 +465,7 @@ def test_source_config_transform(mocker, abstract_source, catalog):
     http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
     http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
-    assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
+    assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT
     assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}] * 2
@@ -468,13 +473,14 @@ def test_source_config_transform_and_no_transform(mocker, abstract_source, catal
     logger_mock = mocker.MagicMock()
     logger_mock.level = logging.DEBUG
     SLICE_DEBUG_LOG_COUNT = 2
+    TRACE_STATUS_COUNT = 6
     streams = abstract_source.streams(None)
     http_stream, non_http_stream = streams
     http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
     http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
     http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
-    assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
+    assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT + TRACE_STATUS_COUNT
     assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}]
@@ -520,8 +526,8 @@ def test_read_default_http_availability_strategy_stream_available(catalog, mocke
     source = MockAbstractSource(streams=streams)
     logger = logging.getLogger(f"airbyte.{getattr(abstract_source, 'name', '')}")

     records = [r for r in source.read(logger=logger, config={}, catalog=catalog, state={})]
-    # 3 for http stream and 3 for non http stream
-    assert len(records) == 3 + 3
+    # 3 records from the http stream, 3 from the non-http stream, plus 3 stream status messages for each of the two streams
+    assert len(records) == 3 + 3 + 3 + 3
     assert http_stream.read_records.called
     assert non_http_stream.read_records.called
@@ -578,8 +584,8 @@ def test_read_default_http_availability_strategy_stream_unavailable(catalog, moc
     with caplog.at_level(logging.WARNING):
         records = [r for r in source.read(logger=logger, config={}, catalog=catalog, state={})]

-    # 0 for http stream and 3 for non http stream
-    assert len(records) == 0 + 3
+    # 0 records from the http stream, 3 from the non-http stream, plus 3 status trace messages
+    assert len(records) == 0 + 3 + 3
     assert non_http_stream.read_records.called
     expected_logs = [
         f"Skipped syncing stream '{http_stream.name}' because it was unavailable.",