88 lines
3.0 KiB
Python
88 lines
3.0 KiB
Python
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
|
|
|
|
import pytest
|
|
from airbyte_cdk.models import (
|
|
AirbyteControlConnectorConfigMessage,
|
|
AirbyteControlMessage,
|
|
AirbyteMessage,
|
|
AirbyteRecordMessage,
|
|
AirbyteStateBlob,
|
|
AirbyteStateMessage,
|
|
AirbyteStateStats,
|
|
AirbyteStateType,
|
|
AirbyteStreamState,
|
|
OrchestratorType,
|
|
StreamDescriptor,
|
|
Type,
|
|
)
|
|
from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
|
|
from airbyte_cdk.utils.message_utils import get_stream_descriptor
|
|
|
|
|
|
def test_get_record_message_stream_descriptor():
    """A RECORD message maps to a descriptor with its stream name and namespace."""
    record = AirbyteRecordMessage(
        stream="test_stream",
        namespace="test_namespace",
        data={"id": "12345"},
        emitted_at=1,
    )
    msg = AirbyteMessage(type=Type.RECORD, record=record)

    expected = HashableStreamDescriptor(name="test_stream", namespace="test_namespace")
    actual = get_stream_descriptor(msg)
    assert actual == expected
|
|
|
|
|
|
def test_get_record_message_stream_descriptor_no_namespace():
    """A RECORD message without a namespace yields a descriptor whose namespace is None."""
    record = AirbyteRecordMessage(
        stream="test_stream",
        data={"id": "12345"},
        emitted_at=1,
    )
    msg = AirbyteMessage(type=Type.RECORD, record=record)

    expected = HashableStreamDescriptor(name="test_stream", namespace=None)
    actual = get_stream_descriptor(msg)
    assert actual == expected
|
|
|
|
|
|
def test_get_state_message_stream_descriptor():
    """A per-stream STATE message maps to the descriptor embedded in its stream state."""
    stream_state = AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name="test_stream", namespace="test_namespace"),
        stream_state=AirbyteStateBlob(updated_at="2024-02-02"),
    )
    state = AirbyteStateMessage(
        type=AirbyteStateType.STREAM,
        stream=stream_state,
        sourceStats=AirbyteStateStats(recordCount=27.0),
    )
    msg = AirbyteMessage(type=Type.STATE, state=state)

    expected = HashableStreamDescriptor(name="test_stream", namespace="test_namespace")
    actual = get_stream_descriptor(msg)
    assert actual == expected
|
|
|
|
|
|
def test_get_state_message_stream_descriptor_no_namespace():
    """A per-stream STATE message whose descriptor has no namespace yields namespace None."""
    stream_state = AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name="test_stream"),
        stream_state=AirbyteStateBlob(updated_at="2024-02-02"),
    )
    state = AirbyteStateMessage(
        type=AirbyteStateType.STREAM,
        stream=stream_state,
        sourceStats=AirbyteStateStats(recordCount=27.0),
    )
    msg = AirbyteMessage(type=Type.STATE, state=state)

    expected = HashableStreamDescriptor(name="test_stream", namespace=None)
    actual = get_stream_descriptor(msg)
    assert actual == expected
|
|
|
|
|
|
def test_get_other_message_stream_descriptor_fails():
    """get_stream_descriptor raises NotImplementedError for a CONTROL message."""
    connector_config = AirbyteControlConnectorConfigMessage(config={"any config": "a config value"})
    control = AirbyteControlMessage(
        type=OrchestratorType.CONNECTOR_CONFIG,
        emitted_at=10,
        connectorConfig=connector_config,
    )
    msg = AirbyteMessage(type=Type.CONTROL, control=control)

    with pytest.raises(NotImplementedError):
        get_stream_descriptor(msg)
|