1
0
mirror of synced 2026-01-15 06:06:30 -05:00
Files
airbyte/airbyte-cdk/python/unit_tests/test/test_entrypoint_wrapper.py
Ella Rohm-Ensing fc12432305 airbyte-cdk: only update airbyte-protocol-models to pydantic v2 (#39524)
## What

Migrate the protocol message models to Pydantic V2 to speed up emitting records. This gives us a 2.5x speedup over V1.

Close https://github.com/airbytehq/airbyte-internal-issues/issues/8333

## How
- Switch to using protocol models generated for pydantic_v2, in a new (temporary) package, `airbyte-protocol-models-pdv2`.
- Update pydantic dependency of the CDK accordingly to v2.
- To minimize impact, keep using the `pydantic.v1` compatibility layer in all airbyte-cdk pydantic code that does not interact with the protocol models.

## Review guide
1. Check out the code and clear your CDK virtual env (either `rm -rf .venv && python -m venv .venv` or `poetry env list; poetry env remove <env>`). For reasons that are not yet clear, this is necessary to fully clean out the `airbyte_protocol` library. Then run `poetry lock --no-update && poetry install --all-extras`. This should install the CDK with the new models.
2. Run unit tests on the CDK
3. Take your favorite connector, point its `pyproject.toml` at the local CDK (see the example in `source-s3`), and try running its tests and its regression tests.

## User Impact

> [!warning]
> This is a major CDK change due to the pydantic dependency change - if connectors use pydantic 1.10, they will break and will need to do similar `from pydantic.v1` updates to get running again. Therefore, we should release this as a major CDK version bump.

## Can this PR be safely reverted and rolled back?
- [x] YES 💚
- [ ] NO 

Even if sources migrate to this version, the state format should not change, so a revert should be possible.

## Follow up work - Ella to move into issues

<details>

### Source-s3 - turn this into an issue
- [ ] Update source s3 CDK version and any required code changes
- [ ] Fix source-s3 unit tests
- [ ] Run source-s3 regression tests
- [ ] Merge and release source-s3 by June 21st

### Docs
- [ ] Update documentation on how to build with CDK 

### CDK pieces
- [ ] Update file-based CDK format validation to use Pydantic V2
  - This is doable, but it requires a breaking change to `OneOfOptionConfig`. A few unhandled test cases present issues we are not yet sure how to handle.
- [ ] Update low-code component generators to use Pydantic V2
  - This is doable, there are a few issues around custom component generation that are unhandled.

### Further CDK performance work - create issues for these
- [ ] Research whether we can replace prints with buffered output (write to a byte buffer and then flush to stdout)
- [ ] Replace `json` with `orjson`
...

</details>
2024-06-21 01:53:44 +02:00

333 lines
15 KiB
Python

# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
import json
import logging
import os
from typing import Any, Iterator, List, Mapping, Optional
from unittest import TestCase
from unittest.mock import Mock, patch
from airbyte_cdk.sources.abstract_source import AbstractSource
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, discover, read
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_protocol.models import (
AirbyteAnalyticsTraceMessage,
AirbyteCatalog,
AirbyteErrorTraceMessage,
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStreamState,
AirbyteStreamStatus,
AirbyteStreamStatusTraceMessage,
AirbyteTraceMessage,
ConfiguredAirbyteCatalog,
Level,
StreamDescriptor,
TraceType,
Type,
)
def _a_state_message(stream_name: str, stream_state: Mapping[str, Any]) -> AirbyteMessage:
    """Build a per-stream STATE message whose state blob holds ``stream_state``.

    :param stream_name: name placed in the stream descriptor
    :param stream_state: key/value pairs expanded into an ``AirbyteStateBlob``
    :return: an ``AirbyteMessage`` of type ``STATE``
    """
    descriptor = StreamDescriptor(name=stream_name)
    blob = AirbyteStateBlob(**stream_state)
    per_stream_state = AirbyteStreamState(stream_descriptor=descriptor, stream_state=blob)
    return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(stream=per_stream_state))
def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteMessage:
    """Build a STREAM_STATUS trace message reporting ``status`` for ``stream_name``.

    ``emitted_at`` is fixed at 0 rather than a real timestamp.

    :param stream_name: name placed in the stream descriptor
    :param status: the stream status to report
    :return: an ``AirbyteMessage`` of type ``TRACE``
    """
    status_payload = AirbyteStreamStatusTraceMessage(
        stream_descriptor=StreamDescriptor(name=stream_name),
        status=status,
    )
    trace = AirbyteTraceMessage(type=TraceType.STREAM_STATUS, emitted_at=0, stream_status=status_payload)
    return AirbyteMessage(type=Type.TRACE, trace=trace)
# --- Canned messages that a mocked entrypoint "emits" in the tests below ---
_A_CATALOG_MESSAGE = AirbyteMessage(type=Type.CATALOG, catalog=AirbyteCatalog(streams=[]))
_A_RECORD = AirbyteMessage(
    type=Type.RECORD,
    record=AirbyteRecordMessage(
        stream="stream",
        data={"record key": "record value"},
        emitted_at=0,
    ),
)
_A_STATE_MESSAGE = _a_state_message("stream_name", {"state key": "state value for _A_STATE_MESSAGE"})
_A_LOG = AirbyteMessage(
    type=Type.LOG,
    log=AirbyteLogMessage(level=Level.INFO, message="This is an Airbyte log message"),
)


def _a_trace_message(trace_type: TraceType, **payload: Any) -> AirbyteMessage:
    """Wrap a trace payload (passed as a keyword such as ``error=``) in a TRACE message with ``emitted_at=0``."""
    return AirbyteMessage(
        type=Type.TRACE,
        trace=AirbyteTraceMessage(type=trace_type, emitted_at=0, **payload),
    )


_AN_ERROR_MESSAGE = _a_trace_message(
    TraceType.ERROR,
    error=AirbyteErrorTraceMessage(message="AirbyteErrorTraceMessage message"),
)
_AN_ANALYTIC_MESSAGE = _a_trace_message(
    TraceType.ANALYTICS,
    analytics=AirbyteAnalyticsTraceMessage(type="an analytic type", value="an analytic value"),
)

# --- Inputs handed to discover/read ---
_A_STREAM_NAME = "a stream name"
_A_CONFIG = {"config_key": "config_value"}
# NOTE(review): the catalog stream is "a_stream_name" while _A_STREAM_NAME is "a stream name";
# AirbyteEntrypoint is patched in every test, so the mismatch looks harmless here — confirm before relying on it.
_A_CATALOG = ConfiguredAirbyteCatalog.parse_obj(
    {
        "streams": [
            {
                "stream": {"name": "a_stream_name", "json_schema": {}, "supported_sync_modes": ["full_refresh"]},
                "sync_mode": "full_refresh",
                "destination_sync_mode": "append",
            }
        ]
    }
)
_A_STATE = StateBuilder().with_stream_state(_A_STREAM_NAME, {"state_key": "state_value"}).build()
_A_LOG_MESSAGE = "a log message"
def _to_entrypoint_output(messages: List[AirbyteMessage]) -> Iterator[str]:
    """Serialize ``messages`` into the JSON strings a mocked entrypoint ``run`` returns.

    Fields that were never explicitly set are left out (``exclude_unset=True``).
    The result is a single-pass iterator.
    """

    def _serialize(message):
        return message.json(exclude_unset=True)

    return map(_serialize, messages)
def _a_mocked_source() -> AbstractSource:
    """Return a ``Mock`` constrained to the ``AbstractSource`` interface.

    ``message_repository`` is pinned to ``None`` instead of being left as an
    auto-created child mock; the Mock constructor forwards this keyword to
    ``configure_mock``.
    """
    return Mock(spec=AbstractSource, message_repository=None)
def _validate_tmp_json_file(expected, file_path) -> None:
    """Assert that the JSON document stored at ``file_path`` equals ``expected``.

    :raises AssertionError: when the decoded content differs from ``expected``
    """
    with open(file_path) as json_file:
        decoded = json.load(json_file)
    assert decoded == expected
def _validate_tmp_catalog(expected, file_path) -> None:
    """Assert that the catalog file at ``file_path`` parses to ``expected``.

    :raises AssertionError: when the parsed ``ConfiguredAirbyteCatalog`` differs from ``expected``
    """
    parsed_catalog = ConfiguredAirbyteCatalog.parse_file(file_path)
    assert parsed_catalog == expected
def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog: Optional[Any] = None, expected_state: Optional[Any] = None):
    """Build a ``side_effect`` for the mocked ``run`` that checks the temporary input files.

    The wrapper deletes its temporary files once it returns, so their content has to be
    checked from inside the ``run`` call. Paths are taken from the argv passed to
    ``entrypoint.parse_args``: index 2 is the config file, 4 the catalog, 6 the state.

    :param entrypoint: the patched ``AirbyteEntrypoint`` mock
    :param expected_config: expected JSON content of the config file
    :param expected_catalog: expected catalog; not checked when ``None``
    :param expected_state: expected JSON content of the state file; not checked when ``None``
    :return: a callable usable as ``side_effect`` that returns the mocked ``run`` return value
    """

    def _validate_tmp_files(_parsed_args):
        argv = entrypoint.parse_args.call_args.args[0]
        _validate_tmp_json_file(expected_config, argv[2])
        # Compare against None (not truthiness) so a provided-but-falsy expectation,
        # e.g. an empty state list, is still validated instead of silently skipped.
        if expected_catalog is not None:
            _validate_tmp_catalog(expected_catalog, argv[4])
        if expected_state is not None:
            _validate_tmp_json_file(expected_state, argv[6])
        return entrypoint.return_value.run.return_value

    return _validate_tmp_files
class EntrypointWrapperDiscoverTest(TestCase):
    """Tests for the ``discover`` helper of the test entrypoint wrapper (``AirbyteEntrypoint`` is patched)."""

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    @staticmethod
    def test_init_validation_error():
        # A line that cannot be parsed as an AirbyteMessage is kept as an INFO log holding the raw line.
        invalid_message = '{"type": "INVALID_TYPE"}'
        entrypoint_output = EntrypointOutput([invalid_message])
        messages = entrypoint_output._messages
        assert len(messages) == 1
        assert messages[0].type == Type.LOG
        assert messages[0].log.level == Level.INFO
        assert messages[0].log.message == invalid_message

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_parameters(self, entrypoint):
        # The side effect checks the temporary config file content during the run call.
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG)

        discover(self._a_source, _A_CONFIG)

        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "discover"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_files_are_temporary(self, entrypoint):
        discover(self._a_source, _A_CONFIG)
        # argv[2] is the config file path that follows "--config".
        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_discover_when_discover_then_output_has_logs(self, entrypoint):
        # Named ``_parsed_args`` (not ``self``) to avoid shadowing the test instance.
        def _do_some_logging(_parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = discover(self._a_source, _A_CONFIG)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_catalog_when_discover_then_output_has_catalog(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_CATALOG_MESSAGE])
        output = discover(self._a_source, _A_CONFIG)
        assert output.catalog == _A_CATALOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_discover_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])
        output = discover(self._a_source, _A_CONFIG)
        assert output.logs == [_A_LOG]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_discover_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])
        output = discover(self._a_source, _A_CONFIG)
        assert output.analytics_messages == [_AN_ANALYTIC_MESSAGE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_discover_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
        discover(self._a_source, _A_CONFIG)
        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_discover_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
        discover(self._a_source, _A_CONFIG, expecting_exception=True)
        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_discover_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")
        output = discover(self._a_source, _A_CONFIG)
        assert output.errors
class EntrypointWrapperReadTest(TestCase):
    """Tests for the ``read`` helper of the test entrypoint wrapper (``AirbyteEntrypoint`` is patched)."""

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_parameters(self, entrypoint):
        # The side effect checks the temporary config/catalog/state files during the run call.
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG, _A_CATALOG, _A_STATE)

        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "read"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"
        assert entrypoint.parse_args.call_args.args[0][3] == "--catalog"
        assert entrypoint.parse_args.call_args.args[0][5] == "--state"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_files_are_temporary(self, entrypoint):
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        argv = entrypoint.parse_args.call_args.args[0]
        # Config, catalog and state file paths follow their respective flags.
        assert not os.path.exists(argv[2])
        assert not os.path.exists(argv[4])
        assert not os.path.exists(argv[6])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_run_when_read_then_output_has_logs(self, entrypoint):
        # Named ``_parsed_args`` (not ``self``) to avoid shadowing the test instance.
        def _do_some_logging(_parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_record_when_read_then_output_has_record(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.records == [_A_RECORD]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_when_read_then_output_has_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.state_messages == [_A_STATE_MESSAGE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_and_records_when_read_then_output_has_records_and_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD, _A_STATE_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.records_and_state_messages == [_A_RECORD, _A_STATE_MESSAGE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_many_state_messages_when_read_then_most_recent_state_is_last_emitted(self, entrypoint):
        state_value = {"state_key": "last state value"}
        last_emitted_state = AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name="stream_name"),
            stream_state=AirbyteStateBlob(**state_value),
        )
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE, _a_state_message("stream_name", state_value)])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert output.most_recent_state == last_emitted_state

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_read_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.logs == [_A_LOG]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_read_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.analytics_messages == [_AN_ANALYTIC_MESSAGE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_when_read_then_return_statuses(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_for_many_streams_when_read_then_filter_other_streams(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message("another stream name", AirbyteStreamStatus.INCOMPLETE),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        # Assert the exact filtered statuses, not just the count, so a leaked INCOMPLETE from the
        # other stream (or a reordering) is caught.
        assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_read_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_read_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE, expecting_exception=True)
        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_read_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.errors