141 lines
5.3 KiB
Python
141 lines
5.3 KiB
Python
#
|
|
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
#
|
|
|
|
from unittest import mock
|
|
from unittest.mock import MagicMock, call
|
|
|
|
from airbyte_cdk.models import (
|
|
AirbyteLogMessage,
|
|
AirbyteMessage,
|
|
AirbyteRecordMessage,
|
|
AirbyteTraceMessage,
|
|
Level,
|
|
SyncMode,
|
|
TraceType,
|
|
Type,
|
|
)
|
|
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
|
|
from airbyte_cdk.sources.declarative.transformations import AddFields, RecordTransformation
|
|
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition
|
|
|
|
|
|
def test_declarative_stream():
    """Exercise DeclarativeStream against a mocked schema loader and retriever.

    The retriever mock hands back two dict records plus a log and a trace message,
    and a pass-through transformation is attached to the stream. The test asserts
    that name, schema, state, records, primary key, cursor field and slices all
    match what was configured, and that the transformation was called exactly once
    per dict record, in order.
    """
    expected_name = "stream"
    expected_primary_key = "pk"
    expected_cursor_field = "created_at"
    connector_config = {"api_key": "open_sesame"}

    # Schema loader returning a fixed schema.
    expected_schema = {"name": {"type": "string"}}
    loader = MagicMock()
    loader.get_json_schema.return_value = expected_schema

    # What the mocked retriever yields: dict records mixed with non-record messages.
    retrieved = [
        {"pk": 1234, "field": "value"},
        {"pk": 4567, "field": "different_value"},
        AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")),
        AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)),
    ]
    slices = [{"date": day} for day in ("2021-01-01", "2021-01-02", "2021-01-03")]

    mock_state = MagicMock()
    mocked_retriever = MagicMock()
    mocked_retriever.state = mock_state
    mocked_retriever.read_records.return_value = retrieved
    mocked_retriever.stream_slices.return_value = slices

    def _passthrough(record, config, stream_slice, stream_state):
        # Parameter names must match the keywords the transformation is called with.
        return record

    passthrough_transform = mock.create_autospec(spec=RecordTransformation)
    passthrough_transform.transform = MagicMock(side_effect=_passthrough)
    all_transformations = [passthrough_transform]

    stream = DeclarativeStream(
        name=expected_name,
        primary_key=expected_primary_key,
        stream_cursor_field="{{ parameters['cursor_field'] }}",
        schema_loader=loader,
        retriever=mocked_retriever,
        config=connector_config,
        transformations=all_transformations,
        parameters={"cursor_field": "created_at"},
    )

    assert stream.name == expected_name
    assert stream.get_json_schema() == expected_schema
    assert stream.state == mock_state

    first_slice = slices[0]
    emitted = list(stream.read_records(SyncMode.full_refresh, expected_cursor_field, first_slice, mock_state))
    assert emitted == retrieved

    assert stream.primary_key == expected_primary_key
    assert stream.cursor_field == expected_cursor_field

    produced_slices = stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=expected_cursor_field, stream_state=None)
    assert produced_slices == slices

    # Only dict records are expected to reach the transformations.
    expected_calls = [
        call(item, config=connector_config, stream_slice=first_slice, stream_state=mock_state)
        for item in retrieved
        if isinstance(item, dict)
    ]
    for transformation in all_transformations:
        assert len(transformation.transform.call_args_list) == len(expected_calls)
        transformation.transform.assert_has_calls(expected_calls, any_order=False)
|
|
|
|
|
|
def test_declarative_stream_with_add_fields_transform():
    """Exercise DeclarativeStream with a real AddFields transformation attached.

    The mocked retriever yields two dict records, one RECORD message, one LOG
    message and one TRACE message. The test asserts that reading the stream gives
    back the dict records and the RECORD message's data with "added_key" set to
    "added_value", while the LOG and TRACE messages compare equal to their inputs.
    """

    def _log_message():
        # Fresh instance on every call so input and expected lists share no objects.
        return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message"))

    def _trace_message():
        return AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345))

    def _record_message(data):
        return AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(data=data, emitted_at=12344, stream="stream"))

    expected_name = "stream"
    expected_primary_key = "pk"
    requested_cursor_field = "created_at"
    connector_config = {"api_key": "open_sesame"}

    # Schema loader returning a fixed schema.
    expected_schema = {"name": {"type": "string"}}
    loader = MagicMock()
    loader.get_json_schema.return_value = expected_schema

    # Input as produced by the retriever, before any transformation.
    records_from_retriever = [
        {"pk": 1234, "field": "value"},
        {"pk": 4567, "field": "different_value"},
        _record_message({"pk": 1357, "field": "a_value"}),
        _log_message(),
        _trace_message(),
    ]

    # Output expected from the stream, after the field has been added.
    records_after_transform = [
        {"pk": 1234, "field": "value", "added_key": "added_value"},
        {"pk": 4567, "field": "different_value", "added_key": "added_value"},
        _record_message({"pk": 1357, "field": "a_value", "added_key": "added_value"}),
        _log_message(),
        _trace_message(),
    ]

    slices = [{"date": day} for day in ("2021-01-01", "2021-01-02", "2021-01-03")]

    mock_state = MagicMock()
    mocked_retriever = MagicMock()
    mocked_retriever.state = mock_state
    mocked_retriever.read_records.return_value = records_from_retriever
    mocked_retriever.stream_slices.return_value = slices

    field_definitions = [AddedFieldDefinition(path=["added_key"], value="added_value", parameters={})]
    all_transformations = [AddFields(fields=field_definitions, parameters={})]

    stream = DeclarativeStream(
        name=expected_name,
        primary_key=expected_primary_key,
        stream_cursor_field="{{ parameters['cursor_field'] }}",
        schema_loader=loader,
        retriever=mocked_retriever,
        config=connector_config,
        transformations=all_transformations,
        parameters={"cursor_field": "created_at"},
    )

    assert stream.name == expected_name
    assert stream.get_json_schema() == expected_schema
    assert stream.state == mock_state

    first_slice = slices[0]
    emitted = list(stream.read_records(SyncMode.full_refresh, requested_cursor_field, first_slice, mock_state))
    assert emitted == records_after_transform
|