## What Migrating to Pydantic V2 for Protocol Messages to speed up emitting records. This gives us a 2.5x boost over V1. Close https://github.com/airbytehq/airbyte-internal-issues/issues/8333 ## How - Switch to using protocol models generated for pydantic_v2, in a new (temporary) package, `airbyte-protocol-models-pdv2`. - Update the pydantic dependency of the CDK accordingly to v2. - For minimal impact, still use the compatibility code `pydantic.v1` in all of our pydantic code from airbyte-cdk that does not interact with the protocol models. ## Review guide 1. Check out the code and clear your CDK virtual env (either `rm -rf .venv && python -m venv .venv` or `poetry env list; poetry env remove <env>`). This is necessary to fully clean out the `airbyte_protocol` library, for some reason. Then: `poetry lock --no-update && poetry install --all-extras`. This should install the CDK with new models. 2. Run unit tests on the CDK 3. Take your favorite connector, point its `pyproject.toml` at the local CDK (see example in `source-s3`), and try running its tests and its regression tests. ## User Impact > [!warning] > This is a major CDK change due to the pydantic dependency change - if connectors use pydantic 1.10, they will break and will need to do similar `from pydantic.v1` updates to get running again. Therefore, we should release this as a major CDK version bump. ## Can this PR be safely reverted and rolled back? - [x] YES 💚 - [ ] NO ❌ Even if sources migrate to this version, the state format should not change, so a revert should be possible. 
## Follow up work - Ella to move into issues <details> ### Source-s3 - turn this into an issue - [ ] Update source s3 CDK version and any required code changes - [ ] Fix source-s3 unit tests - [ ] Run source-s3 regression tests - [ ] Merge and release source-s3 by June 21st ### Docs - [ ] Update documentation on how to build with CDK ### CDK pieces - [ ] Update file-based CDK format validation to use Pydantic V2 - This is doable, and requires a breaking change to change `OneOfOptionConfig`. There are a few unhandled test cases that present issues we're unsure of how to handle so far. - [ ] Update low-code component generators to use Pydantic V2 - This is doable, there are a few issues around custom component generation that are unhandled. ### Further CDK performance work - create issues for these - [ ] Research if we can replace prints with buffered output (write to byte buffer and then flush to stdout) - [ ] Replace `json` with `orjson` ... </details>
207 lines
8.3 KiB
Python
207 lines
8.3 KiB
Python
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
|
|
"""
|
|
The AirbyteEntrypoint is important because it is a service layer that orchestrates how we execute commands from the
|
|
[common interface](https://docs.airbyte.com/understanding-airbyte/airbyte-protocol#common-interface) through the source Python
|
|
implementation. There is some logic about which message we send to the platform and when which is relevant for integration testing. Other
|
|
than that, there are integration points that are annoying to integrate with using Python code:
|
|
* Sources communicate with the platform using stdout. The implication is that the source could just print every message instead of
|
|
returning things to source.<method> or to using the message repository. WARNING: As part of integration testing, we will not support
|
|
messages that are simply printed. The reason is that capturing stdout relies on overriding sys.stdout (see
|
|
https://docs.python.org/3/library/contextlib.html#contextlib.redirect_stdout) which clashes with how pytest captures logs and brings
|
|
considerations for multithreaded applications. If code you work with uses `print` statements, please migrate to
|
|
source.message_repository to emit those messages
|
|
* The entrypoint interface relies on files being written on the file system
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
import traceback
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
from typing import Any, List, Mapping, Optional, Union
|
|
|
|
from airbyte_cdk.entrypoint import AirbyteEntrypoint
|
|
from airbyte_cdk.exception_handler import assemble_uncaught_exception
|
|
from airbyte_cdk.logger import AirbyteLogFormatter
|
|
from airbyte_cdk.sources import Source
|
|
from airbyte_protocol.models import (
|
|
AirbyteLogMessage,
|
|
AirbyteMessage,
|
|
AirbyteStateMessage,
|
|
AirbyteStreamStatus,
|
|
ConfiguredAirbyteCatalog,
|
|
Level,
|
|
TraceType,
|
|
Type,
|
|
)
|
|
from pydantic import ValidationError as V2ValidationError
|
|
|
|
|
|
class EntrypointOutput:
    """Structured view over the messages emitted during a source entrypoint invocation.

    Each raw line is parsed into an AirbyteMessage; lines that are not valid AirbyteMessage JSON are
    wrapped as INFO log messages, mirroring how the platform treats unparseable stdout output. If an
    uncaught exception is supplied, it is appended as an error trace message.
    """

    def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None):
        """
        :param messages: raw lines emitted by the source (stdout messages plus captured logs)
        :param uncaught_exception: exception raised during the run, if any; converted to a trace message
        :raises ValueError: if a message cannot be represented as an AirbyteMessage
        """
        try:
            self._messages = [self._parse_message(message) for message in messages]
        except V2ValidationError as exception:
            raise ValueError("All messages are expected to be AirbyteMessage") from exception

        if uncaught_exception:
            # Surface the exception as an error trace message so tests can assert on it uniformly.
            self._messages.append(assemble_uncaught_exception(type(uncaught_exception), uncaught_exception).as_airbyte_message())

    @staticmethod
    def _parse_message(message: str) -> AirbyteMessage:
        """Parse one raw line into an AirbyteMessage, falling back to an INFO log message."""
        try:
            return AirbyteMessage.parse_obj(json.loads(message))
        except (json.JSONDecodeError, V2ValidationError):
            # The platform assumes that logs that are not of AirbyteMessage format are log messages
            return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message=message))

    @property
    def records_and_state_messages(self) -> List[AirbyteMessage]:
        return self._get_message_by_types([Type.RECORD, Type.STATE])

    @property
    def records(self) -> List[AirbyteMessage]:
        return self._get_message_by_types([Type.RECORD])

    @property
    def state_messages(self) -> List[AirbyteMessage]:
        return self._get_message_by_types([Type.STATE])

    @property
    def most_recent_state(self) -> Any:
        """Return the stream state of the last state message emitted.

        :raises ValueError: if no state message was emitted
        """
        state_messages = self._get_message_by_types([Type.STATE])
        if not state_messages:
            raise ValueError("Can't provide most recent state as there are no state messages")
        return state_messages[-1].state.stream

    @property
    def logs(self) -> List[AirbyteMessage]:
        return self._get_message_by_types([Type.LOG])

    @property
    def trace_messages(self) -> List[AirbyteMessage]:
        return self._get_message_by_types([Type.TRACE])

    @property
    def analytics_messages(self) -> List[AirbyteMessage]:
        return self._get_trace_message_by_trace_type(TraceType.ANALYTICS)

    @property
    def errors(self) -> List[AirbyteMessage]:
        return self._get_trace_message_by_trace_type(TraceType.ERROR)

    @property
    def catalog(self) -> AirbyteMessage:
        """Return the single catalog message emitted.

        :raises ValueError: if zero or more than one catalog message was emitted
        """
        catalog = self._get_message_by_types([Type.CATALOG])
        if len(catalog) != 1:
            raise ValueError(f"Expected exactly one catalog but got {len(catalog)}")
        return catalog[0]

    def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
        """Return the stream status updates emitted for `stream_name`, in emission order."""
        # Comprehension instead of map/filter with lambdas: same semantics, clearer to read.
        return [
            message.trace.stream_status.status
            for message in self._get_trace_message_by_trace_type(TraceType.STREAM_STATUS)
            if message.trace.stream_status.stream_descriptor.name == stream_name
        ]

    def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]:
        return [message for message in self._messages if message.type in message_types]

    def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]:
        return [message for message in self._get_message_by_types([Type.TRACE]) if message.trace.type == trace_type]
|
|
|
|
|
|
def _run_command(source: Source, args: List[str], expecting_exception: bool = False) -> EntrypointOutput:
    """Run the source entrypoint with `args`, capturing emitted messages and INFO-level logs.

    :param source: the source implementation under test
    :param args: CLI-style arguments to pass to the entrypoint (e.g. ["read", "--config", ...])
    :param expecting_exception: if False, an uncaught exception is printed to make test output actionable
    :return: an EntrypointOutput wrapping all emitted messages, captured logs, and any uncaught exception
    """
    log_capture_buffer = StringIO()
    stream_handler = logging.StreamHandler(log_capture_buffer)
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(AirbyteLogFormatter())
    parent_logger = logging.getLogger("")
    parent_logger.addHandler(stream_handler)

    messages = []
    uncaught_exception = None
    try:
        # parse_args may raise (argparse exits on invalid args); keep it inside the try/finally so the
        # capture handler is always detached and does not accumulate on the root logger across calls.
        parsed_args = AirbyteEntrypoint.parse_args(args)
        source_entrypoint = AirbyteEntrypoint(source)
        try:
            for message in source_entrypoint.run(parsed_args):
                messages.append(message)
        except Exception as exception:
            if not expecting_exception:
                print("Printing unexpected error from entrypoint_wrapper")
                print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
            uncaught_exception = exception

        # Drop the trailing empty string produced by the final newline.
        captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]
    finally:
        parent_logger.removeHandler(stream_handler)

    return EntrypointOutput(messages + captured_logs, uncaught_exception)
|
|
|
|
|
|
def discover(
    source: Source,
    config: Mapping[str, Any],
    expecting_exception: bool = False,
) -> EntrypointOutput:
    """Run the `discover` command against `source` with the given config.

    config must be json serializable
    :param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
        provide expecting_exception=True so that the test output logs are cleaner
    """
    # The temp directory must outlive the command so the config file is readable during the run.
    with tempfile.TemporaryDirectory() as tmp_directory:
        config_file = make_file(Path(tmp_directory) / "config.json", config)

        return _run_command(source, ["discover", "--config", config_file, "--debug"], expecting_exception)
|
|
|
|
|
|
def read(
    source: Source,
    config: Mapping[str, Any],
    catalog: ConfiguredAirbyteCatalog,
    state: Optional[List[AirbyteStateMessage]] = None,
    expecting_exception: bool = False,
) -> EntrypointOutput:
    """Run the `read` command against `source` with the given config, catalog, and optional state.

    config and state must be json serializable

    :param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
        provide expecting_exception=True so that the test output logs are cleaner
    """
    # Files are written into a temp directory that lives for the duration of the command.
    with tempfile.TemporaryDirectory() as tmp_directory:
        tmp_path = Path(tmp_directory)
        args = [
            "read",
            "--config",
            make_file(tmp_path / "config.json", config),
            "--catalog",
            make_file(tmp_path / "catalog.json", catalog.json()),
        ]
        if state is not None:
            serialized_state = f"[{','.join(stream_state.json() for stream_state in state)}]"
            args += ["--state", make_file(tmp_path / "state.json", serialized_state)]

        return _run_command(source, args, expecting_exception)
|
|
|
|
|
|
def make_file(path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]]) -> str:
    """Write `file_contents` to `path` and return the path as a string.

    Strings are written verbatim; any other value is JSON-serialized first.
    """
    serialized = file_contents if isinstance(file_contents, str) else json.dumps(file_contents)
    path.write_text(serialized)
    return str(path)
|