# # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import os from argparse import Namespace from collections import defaultdict from copy import deepcopy from typing import Any, List, Mapping, MutableMapping, Union from unittest import mock from unittest.mock import MagicMock, patch import pytest import requests from airbyte_cdk import AirbyteEntrypoint from airbyte_cdk import entrypoint as entrypoint_module from airbyte_cdk.models import ( AirbyteCatalog, AirbyteConnectionStatus, AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, AirbyteRecordMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStream, AirbyteStreamState, AirbyteStreamStatus, AirbyteStreamStatusTraceMessage, AirbyteTraceMessage, ConnectorSpecification, OrchestratorType, Status, StreamDescriptor, SyncMode, TraceType, Type, ) from airbyte_cdk.models.airbyte_protocol import AirbyteStateStats from airbyte_cdk.sources import Source from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor from airbyte_cdk.utils import AirbyteTracedException class MockSource(Source): def read(self, **kwargs): pass def discover(self, **kwargs): pass def check(self, **kwargs): pass @property def message_repository(self): pass def _as_arglist(cmd: str, named_args: Mapping[str, Any]) -> List[str]: out = [cmd] for k, v in named_args.items(): out.append(f"--{k}") if v: out.append(v) return out @pytest.fixture def spec_mock(mocker): expected = ConnectorSpecification(connectionSpecification={}) mock = MagicMock(return_value=expected) mocker.patch.object(MockSource, "spec", mock) return mock MESSAGE_FROM_REPOSITORY = AirbyteMessage( type=Type.CONTROL, control=AirbyteControlMessage( type=OrchestratorType.CONNECTOR_CONFIG, emitted_at=10, connectorConfig=AirbyteControlConnectorConfigMessage(config={"any config": "a config value"}), ), ) @pytest.fixture def entrypoint(mocker) -> AirbyteEntrypoint: message_repository = MagicMock() message_repository.consume_queue.side_effect = [[message for message in [MESSAGE_FROM_REPOSITORY]], [], []] mocker.patch.object(MockSource, "message_repository", new_callable=mocker.PropertyMock, return_value=message_repository) return AirbyteEntrypoint(MockSource()) def test_airbyte_entrypoint_init(mocker): mocker.patch.object(entrypoint_module, "init_uncaught_exception_handler") AirbyteEntrypoint(MockSource()) entrypoint_module.init_uncaught_exception_handler.assert_called_once_with(entrypoint_module.logger) @pytest.mark.parametrize( ["cmd", "args", "expected_args"], [ ("spec", {"debug": ""}, {"command": "spec", "debug": True}), ("check", {"config": "config_path"}, {"command": "check", "config": "config_path", "debug": False}), ("discover", {"config": "config_path", "debug": ""}, {"command": "discover", "config": "config_path", "debug": True}), ( "read", {"config": "config_path", "catalog": "catalog_path", "state": "None"}, {"command": "read", "config": "config_path", "catalog": "catalog_path", "state": "None", "debug": False}, ), ( "read", {"config": "config_path", "catalog": "catalog_path", "state": "state_path", "debug": ""}, {"command": "read", "config": "config_path", "catalog": "catalog_path", "state": "state_path", "debug": True}, ), ], ) def test_parse_valid_args(cmd: str, args: Mapping[str, Any], expected_args, entrypoint: AirbyteEntrypoint): arglist = _as_arglist(cmd, args) parsed_args = entrypoint.parse_args(arglist) assert vars(parsed_args) == expected_args @pytest.mark.parametrize( ["cmd", "args"], [ ("check", {"config": "config_path"}), ("discover", {"config": "config_path"}), ("read", {"config": "config_path", "catalog": "catalog_path"}), ], ) def test_parse_missing_required_args(cmd: str, args: MutableMapping[str, Any], entrypoint: AirbyteEntrypoint): required_args = {"check": ["config"], "discover": ["config"], "read": ["config", "catalog"]} for required_arg in required_args[cmd]: argcopy = deepcopy(args) del argcopy[required_arg] with pytest.raises(BaseException): entrypoint.parse_args(_as_arglist(cmd, argcopy)) def _wrap_message(submessage: Union[AirbyteConnectionStatus, ConnectorSpecification, AirbyteRecordMessage, AirbyteCatalog]) -> str: if isinstance(submessage, AirbyteConnectionStatus): message = AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=submessage) elif isinstance(submessage, ConnectorSpecification): message = AirbyteMessage(type=Type.SPEC, spec=submessage) elif isinstance(submessage, AirbyteCatalog): message = AirbyteMessage(type=Type.CATALOG, catalog=submessage) elif isinstance(submessage, AirbyteRecordMessage): message = AirbyteMessage(type=Type.RECORD, record=submessage) else: raise Exception(f"Unknown message type: {submessage}") return message.json(exclude_unset=True) def test_run_spec(entrypoint: AirbyteEntrypoint, mocker): parsed_args = Namespace(command="spec") expected = ConnectorSpecification(connectionSpecification={"hi": "hi"}) mocker.patch.object(MockSource, "spec", return_value=expected) messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True), _wrap_message(expected)] == messages @pytest.fixture def config_mock(mocker, request): config = request.param if hasattr(request, "param") else {"username": "fake"} mocker.patch.object(MockSource, "read_config", return_value=config) mocker.patch.object(MockSource, "configure", return_value=config) return config @pytest.mark.parametrize( "config_mock, schema, config_valid", [ ({"username": "fake"}, {"type": "object", "properties": {"name": {"type": "string"}}, "additionalProperties": False}, False), ({"username": "fake"}, {"type": "object", "properties": {"username": {"type": "string"}}, "additionalProperties": False}, True), ({"username": "fake"}, {"type": "object", "properties": {"user": {"type": "string"}}}, True), ({"username": "fake"}, {"type": "object", "properties": {"user": {"type": "string", "airbyte_secret": True}}}, True), ( {"username": "fake", "_limit": 22}, {"type": "object", "properties": {"username": {"type": "string"}}, "additionalProperties": False}, True, ), ], indirect=["config_mock"], ) def test_config_validate(entrypoint: AirbyteEntrypoint, mocker, config_mock, schema, config_valid): parsed_args = Namespace(command="check", config="config_path") check_value = AirbyteConnectionStatus(status=Status.SUCCEEDED) mocker.patch.object(MockSource, "check", return_value=check_value) mocker.patch.object(MockSource, "spec", return_value=ConnectorSpecification(connectionSpecification=schema)) messages = list(entrypoint.run(parsed_args)) if config_valid: assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True), _wrap_message(check_value)] == messages else: assert len(messages) == 2 assert messages[0] == MESSAGE_FROM_REPOSITORY.json(exclude_unset=True) connection_status_message = AirbyteMessage.parse_raw(messages[1]) assert connection_status_message.type == Type.CONNECTION_STATUS assert connection_status_message.connectionStatus.status == Status.FAILED assert connection_status_message.connectionStatus.message.startswith("Config validation error:") def test_run_check(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="check", config="config_path") check_value = AirbyteConnectionStatus(status=Status.SUCCEEDED) mocker.patch.object(MockSource, "check", return_value=check_value) messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True), _wrap_message(check_value)] == messages assert spec_mock.called def test_run_check_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="check", config="config_path") mocker.patch.object(MockSource, "check", side_effect=ValueError("Any error")) with pytest.raises(ValueError): messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True)] == messages def test_run_discover(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="discover", config="config_path") expected = AirbyteCatalog(streams=[AirbyteStream(name="stream", json_schema={"k": "v"}, supported_sync_modes=[SyncMode.full_refresh])]) mocker.patch.object(MockSource, "discover", return_value=expected) messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True), _wrap_message(expected)] == messages assert spec_mock.called def test_run_discover_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="discover", config="config_path") mocker.patch.object(MockSource, "discover", side_effect=ValueError("Any error")) with pytest.raises(ValueError): messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True)] == messages def test_run_read(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") expected = AirbyteRecordMessage(stream="stream", data={"data": "stuff"}, emitted_at=1) mocker.patch.object(MockSource, "read_state", return_value={}) mocker.patch.object(MockSource, "read_catalog", return_value={}) mocker.patch.object(MockSource, "read", return_value=[AirbyteMessage(record=expected, type=Type.RECORD)]) messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True), _wrap_message(expected)] == messages assert spec_mock.called def test_given_message_emitted_during_config_when_read_then_emit_message_before_next_steps( entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock ): parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") mocker.patch.object(MockSource, "read_catalog", side_effect=ValueError) messages = entrypoint.run(parsed_args) assert next(messages) == MESSAGE_FROM_REPOSITORY.json(exclude_unset=True) with pytest.raises(ValueError): next(messages) def test_run_read_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock): parsed_args = Namespace(command="read", config="config_path", state="statepath", catalog="catalogpath") mocker.patch.object(MockSource, "read_state", return_value={}) mocker.patch.object(MockSource, "read_catalog", return_value={}) mocker.patch.object(MockSource, "read", side_effect=ValueError("Any error")) with pytest.raises(ValueError): messages = list(entrypoint.run(parsed_args)) assert [MESSAGE_FROM_REPOSITORY.json(exclude_unset=True)] == messages def test_invalid_command(entrypoint: AirbyteEntrypoint, config_mock): with pytest.raises(Exception): list(entrypoint.run(Namespace(command="invalid", config="conf"))) @pytest.mark.parametrize( "deployment_mode, url, expected_error", [ pytest.param("CLOUD", "https://airbyte.com", None, id="test_cloud_public_endpoint_is_successful"), pytest.param("CLOUD", "https://192.168.27.30", AirbyteTracedException, id="test_cloud_private_ip_address_is_rejected"), pytest.param("CLOUD", "https://localhost:8080/api/v1/cast", AirbyteTracedException, id="test_cloud_private_endpoint_is_rejected"), pytest.param("CLOUD", "http://past.lives.net/api/v1/inyun", ValueError, id="test_cloud_unsecured_endpoint_is_rejected"), pytest.param("CLOUD", "https://not:very/cash:443.money", ValueError, id="test_cloud_invalid_url_format"), pytest.param("CLOUD", "https://192.168.27.30 ", ValueError, id="test_cloud_incorrect_ip_format_is_rejected"), pytest.param("cloud", "https://192.168.27.30", AirbyteTracedException, id="test_case_insensitive_cloud_environment_variable"), pytest.param("OSS", "https://airbyte.com", None, id="test_oss_public_endpoint_is_successful"), pytest.param("OSS", "https://192.168.27.30", None, id="test_oss_private_endpoint_is_successful"), pytest.param("OSS", "https://localhost:8080/api/v1/cast", None, id="test_oss_private_endpoint_is_successful"), pytest.param("OSS", "http://past.lives.net/api/v1/inyun", None, id="test_oss_unsecured_endpoint_is_successful"), ], ) @patch.object(requests.Session, "send", lambda self, request, **kwargs: requests.Response()) def test_filter_internal_requests(deployment_mode, url, expected_error): with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): AirbyteEntrypoint(source=MockSource()) session = requests.Session() prepared_request = requests.PreparedRequest() prepared_request.method = "GET" prepared_request.headers = {"header": "value"} prepared_request.url = url if expected_error: with pytest.raises(expected_error): session.send(request=prepared_request) else: actual_response = session.send(request=prepared_request) assert isinstance(actual_response, requests.Response) @pytest.mark.parametrize( "incoming_message, stream_message_count, expected_message, expected_records_by_stream", [ pytest.param( AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers"): 100.0}, AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers"): 101.0}, id="test_handle_record_message", ), pytest.param( AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))), {HashableStreamDescriptor(name="customers"): 100.0}, AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=100.0))), {HashableStreamDescriptor(name="customers"): 0.0}, id="test_handle_state_message", ), pytest.param( AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), defaultdict(float), AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers"): 1.0}, id="test_handle_first_record_message", ), pytest.param( AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.STREAM_STATUS, stream_status=AirbyteStreamStatusTraceMessage( stream_descriptor=StreamDescriptor(name="customers"), status=AirbyteStreamStatus.COMPLETE), emitted_at=1)), {HashableStreamDescriptor(name="customers"): 5.0}, AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.STREAM_STATUS, stream_status=AirbyteStreamStatusTraceMessage( stream_descriptor=StreamDescriptor(name="customers"), status=AirbyteStreamStatus.COMPLETE), emitted_at=1)), {HashableStreamDescriptor(name="customers"): 5.0}, id="test_handle_other_message_type", ), pytest.param( AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0}, AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 28.0}, id="test_handle_record_message_for_other_stream", ), pytest.param( AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))), {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0}, AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=27.0))), {HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 0.0}, id="test_handle_state_message_for_other_stream", ), pytest.param( AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers", namespace="public"): 100.0}, AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1)), {HashableStreamDescriptor(name="customers", namespace="public"): 101.0}, id="test_handle_record_message_with_descriptor", ), pytest.param( AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="customers", namespace="public"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))), {HashableStreamDescriptor(name="customers", namespace="public"): 100.0}, AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="customers", namespace="public"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=100.0))), {HashableStreamDescriptor(name="customers", namespace="public"): 0.0}, id="test_handle_state_message_with_descriptor", ), pytest.param( AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="others", namespace="public"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))), {HashableStreamDescriptor(name="customers", namespace="public"): 100.0}, AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState( stream_descriptor=StreamDescriptor(name="others", namespace="public"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=0.0))), {HashableStreamDescriptor(name="customers", namespace="public"): 100.0, HashableStreamDescriptor(name="others", namespace="public"): 0.0}, id="test_handle_state_message_no_records", ), ] ) def test_handle_record_counts(incoming_message, stream_message_count, expected_message, expected_records_by_stream): entrypoint = AirbyteEntrypoint(source=MockSource()) actual_message = entrypoint.handle_record_counts(message=incoming_message, stream_message_count=stream_message_count) assert actual_message == expected_message for stream_descriptor, message_count in stream_message_count.items(): assert isinstance(message_count, float) # Python assertions against different number types won't fail if the value is equivalent assert message_count == expected_records_by_stream[stream_descriptor] if actual_message.type == Type.STATE: assert isinstance(actual_message.state.sourceStats.recordCount, float), "recordCount value should be expressed as a float"