airbyte/airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py
Alexandre Girard df01616951 [Issue #23497] Deduplicate query parameters for declarative connectors (#28550)
* remove duplicate param

* remove duplicate params

* fix some of the typing issues

* fix typing issues

* newline

* format

* Enable by default

* Add missing file

* refactor and remove flag

* none check

* move line of code

* fix typing in rate_limiting

* comment

* use typedef

* else branch

* format

* gate the feature

* rename test

* fix the test

* only dedupe if the values are the same

* Add some tests

* convert values to strings

* Document the change

* implement in requester too
2023-07-25 14:22:25 -07:00
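
The commit above deduplicates query parameters that appear both in the request URL and in the request options. As a rough sketch of the behavior the commit message describes (drop a parameter only when the stringified values match) — this is an illustration of the idea only, not the CDK's actual implementation, and is not part of the file below:

from typing import Any, Mapping
from urllib.parse import parse_qs, urlparse

def deduplicate_query_params(url: str, params: Mapping[str, Any]) -> Mapping[str, Any]:
    # Hypothetical helper: drop a param only when the URL already carries the
    # same key with the same value (compared as strings, per the commit notes);
    # differing values are kept so the conflict stays visible to the server.
    query_dict = {k: v[0] for k, v in parse_qs(urlparse(url).query).items() if len(v) == 1}
    duplicates = {k for k in query_dict if k in params and str(params[k]) == query_dict[k]}
    return {k: v for k, v in params.items() if k not in duplicates}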


#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import json
import logging
from copy import deepcopy
from json import JSONDecodeError
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union
from urllib.parse import parse_qs, urlparse

from airbyte_cdk.connector_builder.models import (
    AuxiliaryRequest,
    HttpRequest,
    HttpResponse,
    LogMessage,
    StreamRead,
    StreamReadPages,
    StreamReadSlices,
)
from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
from airbyte_protocol.models.airbyte_protocol import (
    AirbyteControlMessage,
    AirbyteLogMessage,
    AirbyteMessage,
    AirbyteTraceMessage,
    ConfiguredAirbyteCatalog,
    Level,
    OrchestratorType,
    TraceType,
)
from airbyte_protocol.models.airbyte_protocol import Type as MessageType


class MessageGrouper:
    logger = logging.getLogger("airbyte.connector-builder")

    def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000):
        self._max_pages_per_slice = max_pages_per_slice
        self._max_slices = max_slices
        self._max_record_limit = max_record_limit

    def get_message_groups(
        self,
        source: DeclarativeSource,
        config: Mapping[str, Any],
        configured_catalog: ConfiguredAirbyteCatalog,
        record_limit: Optional[int] = None,
    ) -> StreamRead:
        if record_limit is not None and not (1 <= record_limit <= 1000):
            raise ValueError(f"Record limit must be between 1 and 1000. Got {record_limit}")
        schema_inferrer = SchemaInferrer()
        datetime_format_inferrer = DatetimeFormatInferrer()
        if record_limit is None:
            record_limit = self._max_record_limit
        else:
            record_limit = min(record_limit, self._max_record_limit)
        slices = []
        log_messages = []
        latest_config_update: Optional[AirbyteControlMessage] = None
        auxiliary_requests = []
        for message_group in self._get_message_groups(
            self._read_stream(source, config, configured_catalog),
            schema_inferrer,
            datetime_format_inferrer,
            record_limit,
        ):
            if isinstance(message_group, AirbyteLogMessage):
                log_messages.append(LogMessage(**{"message": message_group.message, "level": message_group.level.value}))
            elif isinstance(message_group, AirbyteTraceMessage):
                if message_group.type == TraceType.ERROR:
                    error_message = f"{message_group.error.message} - {message_group.error.stack_trace}"
                    log_messages.append(LogMessage(**{"message": error_message, "level": "ERROR"}))
            elif isinstance(message_group, AirbyteControlMessage):
                if not latest_config_update or latest_config_update.emitted_at <= message_group.emitted_at:
                    latest_config_update = message_group
            elif isinstance(message_group, AuxiliaryRequest):
                auxiliary_requests.append(message_group)
            elif isinstance(message_group, StreamReadSlices):
                slices.append(message_group)
            else:
                raise ValueError(f"Unknown message group type: {type(message_group)}")

        return StreamRead(
            logs=log_messages,
            slices=slices,
            test_read_limit_reached=self._has_reached_limit(slices),
            auxiliary_requests=auxiliary_requests,
            inferred_schema=schema_inferrer.get_stream_schema(
                configured_catalog.streams[0].stream.name
            ),  # The connector builder currently only supports reading from a single stream at a time
            latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None,
            inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
        )

    def _get_message_groups(
        self, messages: Iterator[AirbyteMessage], schema_inferrer: SchemaInferrer, datetime_format_inferrer: DatetimeFormatInferrer, limit: int
    ) -> Iterable[Union[StreamReadPages, AirbyteControlMessage, AirbyteLogMessage, AirbyteTraceMessage, AuxiliaryRequest]]:
        """
        Message groups are partitioned according to when request log messages are received. Subsequent response log messages
        and record messages belong to the prior request log message; when we encounter another request, the current message
        group is closed and appended. This continues until <limit> records have been read.

        Messages received from the CDK read operation will always arrive in the following order:
        {type: LOG, log: {message: "request: ..."}}
        {type: LOG, log: {message: "response: ..."}}
        ... 0 or more record messages
        {type: RECORD, record: {data: ...}}
        {type: RECORD, record: {data: ...}}
        Repeats for each request/response made

        Note: The exception is that normal log messages can be received at any time; these are not incorporated into the grouping
        """
        records_count = 0
        at_least_one_page_in_group = False
        current_page_records: List[Mapping[str, Any]] = []
        current_slice_descriptor: Optional[Dict[str, Any]] = None
        current_slice_pages: List[StreamReadPages] = []
        current_page_request: Optional[HttpRequest] = None
        current_page_response: Optional[HttpResponse] = None
        had_error = False

        while records_count < limit and (message := next(messages, None)):
            json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None
            if json_object is not None and not isinstance(json_object, dict):
                raise ValueError(f"Expected log message to be a dict, got {json_object} of type {type(json_object)}")
            json_message: Optional[Dict[str, JsonType]] = json_object
            if self._need_to_close_page(at_least_one_page_in_group, message, json_message):
                self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records, True)
                current_page_request = None
                current_page_response = None

            if at_least_one_page_in_group and message.type == MessageType.LOG and message.log.message.startswith(AbstractSource.SLICE_LOG_PREFIX):
                yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor)
                current_slice_descriptor = self._parse_slice_description(message.log.message)
                current_slice_pages = []
                at_least_one_page_in_group = False
            elif message.type == MessageType.LOG and message.log.message.startswith(AbstractSource.SLICE_LOG_PREFIX):
                # parsing the first slice
                current_slice_descriptor = self._parse_slice_description(message.log.message)
            elif message.type == MessageType.LOG:
                if json_message is not None and self._is_http_log(json_message):
                    if self._is_auxiliary_http_request(json_message):
                        airbyte_cdk = json_message.get("airbyte_cdk", {})
                        if not isinstance(airbyte_cdk, dict):
                            raise ValueError(f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}")
                        stream = airbyte_cdk.get("stream", {})
                        if not isinstance(stream, dict):
                            raise ValueError(f"Expected stream to be a dict, got {stream} of type {type(stream)}")
                        title_prefix = (
                            "Parent stream: " if stream.get("is_substream", False) else ""
                        )
                        http = json_message.get("http", {})
                        if not isinstance(http, dict):
                            raise ValueError(f"Expected http to be a dict, got {http} of type {type(http)}")
                        yield AuxiliaryRequest(
                            title=title_prefix + str(http.get("title", None)),
                            description=str(http.get("description", None)),
                            request=self._create_request_from_log_message(json_message),
                            response=self._create_response_from_log_message(json_message),
                        )
                    else:
                        at_least_one_page_in_group = True
                        current_page_request = self._create_request_from_log_message(json_message)
                        current_page_response = self._create_response_from_log_message(json_message)
                else:
                    if message.log.level == Level.ERROR:
                        had_error = True
                    yield message.log
            elif message.type == MessageType.TRACE:
                if message.trace.type == TraceType.ERROR:
                    had_error = True
                    yield message.trace
            elif message.type == MessageType.RECORD:
                current_page_records.append(message.record.data)
                records_count += 1
                schema_inferrer.accumulate(message.record)
                datetime_format_inferrer.accumulate(message.record)
            elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG:
                yield message.control
        else:
            self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records, validate_page_complete=not had_error)
            yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor)

    @staticmethod
    def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessage, json_message: Optional[Dict[str, Any]]) -> bool:
        return (
            at_least_one_page_in_group
            and message.type == MessageType.LOG
            and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:"))
        )

    @staticmethod
    def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool:
        if not json_message:
            return False
        else:
            return MessageGrouper._is_http_log(json_message) and not MessageGrouper._is_auxiliary_http_request(json_message)

    @staticmethod
    def _is_http_log(message: Dict[str, JsonType]) -> bool:
        return bool(message.get("http", False))

    @staticmethod
    def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool:
        """
        An auxiliary request is a request that is performed but will not directly lead to records for the specific stream
        being queried. A couple of examples are:
        * OAuth authentication
        * Substream slice generation
        """
        if not message:
            return False

        is_http = MessageGrouper._is_http_log(message)
        return is_http and message.get("http", {}).get("is_auxiliary", False)

    @staticmethod
    def _close_page(
        current_page_request: Optional[HttpRequest],
        current_page_response: Optional[HttpResponse],
        current_slice_pages: List[StreamReadPages],
        current_page_records: List[Mapping[str, Any]],
        validate_page_complete: bool,
    ) -> None:
        """
        Close a page when parsing message groups

        @param validate_page_complete: in some cases, we expect the CDK to not return a response. As of today, this will only
        happen before an uncaught exception; therefore, `validate_page_complete=False` is only passed for the last page being closed
        """
        if validate_page_complete and (not current_page_request or not current_page_response):
            raise ValueError("Every message grouping should have at least one request and response")

        current_slice_pages.append(
            StreamReadPages(request=current_page_request, response=current_page_response, records=deepcopy(current_page_records))  # type: ignore
        )
        current_page_records.clear()

    def _read_stream(self, source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog) -> Iterator[AirbyteMessage]:
        # The generator can raise an exception. Iterate over the generated messages; if next() raises, catch the exception
        # and yield it as an AirbyteMessage wrapping an AirbyteTracedException.
        try:
            yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, {})
        except Exception as e:
            error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}"
            yield AirbyteTracedException.from_exception(e, message=error_message).as_airbyte_message()

    @staticmethod
    def _parse_json(log_message: AirbyteLogMessage) -> JsonType:
        # TODO: As a temporary stopgap, the CDK emits request/response data as a log message string. Ideally this should come in the
        # form of a custom message object defined in the Airbyte protocol, but this unblocks us in the immediate while the
        # protocol change is worked on.
        try:
            json_object: JsonType = json.loads(log_message.message)
            return json_object
        except JSONDecodeError:
            return None
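
    # Illustrative shape of an HTTP log message consumed by the helpers below
    # (assembled from the accessors in this class; only the fields actually
    # read are shown, other fields may exist):
    # {
    #     "http": {"title": "...", "description": "...", "is_auxiliary": false,
    #              "request": {"method": "GET", "headers": {...}, "body": {"content": "..."}},
    #              "response": {"status_code": 200, "headers": {...}, "body": {"content": "..."}}},
    #     "url": {"full": "https://api.example.com/items?page=2"},
    #     "airbyte_cdk": {"stream": {"is_substream": false}}
    # }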

    @staticmethod
    def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest:
        url = urlparse(json_http_message.get("url", {}).get("full", ""))
        full_path = f"{url.scheme}://{url.hostname}{url.path}" if url else ""
        request = json_http_message.get("http", {}).get("request", {})
        parameters = parse_qs(url.query) or None
        return HttpRequest(
            url=full_path,
            http_method=request.get("method", ""),
            headers=request.get("headers"),
            parameters=parameters,
            body=request.get("body", {}).get("content", ""),
        )

    @staticmethod
    def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse:
        response = json_http_message.get("http", {}).get("response", {})
        body = response.get("body", {}).get("content", "")
        return HttpResponse(status=response.get("status_code"), body=body, headers=response.get("headers"))

    def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool:
        if len(slices) >= self._max_slices:
            return True

        record_count = 0
        for _slice in slices:
            if len(_slice.pages) >= self._max_pages_per_slice:
                return True
            for page in _slice.pages:
                record_count += len(page.records)
                if record_count >= self._max_record_limit:
                    return True
        return False

    def _parse_slice_description(self, log_message: str) -> Dict[str, Any]:
        return json.loads(log_message.replace(AbstractSource.SLICE_LOG_PREFIX, "", 1))  # type: ignore

    @staticmethod
    def _clean_config(config: Dict[str, Any]) -> Dict[str, Any]:
        cleaned_config = deepcopy(config)
        for key in config.keys():
            if key.startswith("__"):
                del cleaned_config[key]
        return cleaned_config
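
For orientation, here is a minimal, hypothetical usage sketch of the class above. `source`, `config`, and `configured_catalog` are placeholders you would build from a real declarative connector setup; they are not defined in this file:

from airbyte_cdk.connector_builder.message_grouper import MessageGrouper

grouper = MessageGrouper(max_pages_per_slice=5, max_slices=5, max_record_limit=1000)
stream_read = grouper.get_message_groups(source, config, configured_catalog, record_limit=100)

for slice_group in stream_read.slices:
    for page in slice_group.pages:
        # Each page pairs an HTTP request/response with the records it produced
        print(page.request.url, len(page.records))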