
[ISSUE #20771] limiting the number of requests performed to the backend without flag (#21525)

* [ISSUE #20771] limiting the number of requests performed to the backend without flag

* [ISSUE #20771] code reviewing my own code

* [ISSUE #20771] adding ABC to paginator

* [ISSUE #20771] format code

* [ISSUE #20771] adding slices to connector builder read request (#21605)

* [ISSUE #20771] adding slices to connector builder read request

* [ISSUE #20771] formatting

* [ISSUE #20771] set flag when limit requests reached (#21619)

* [ISSUE #20771] set flag when limit requests reached

* [ISSUE #20771] assert proper value on test read objects __init__

* [ISSUE #20771] code review and fix edge case

* [ISSUE #20771] fix flake8 error

* [ISSUE #20771] code review

* 🤖 Bump minor version of Airbyte CDK

* to run the CI
Maxime Carbonneau-Leclerc
2023-01-24 10:19:19 -05:00
committed by GitHub
parent 3efc6431ea
commit ca8cdc40aa
25 changed files with 1171 additions and 637 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.22.0
current_version = 0.23.0
commit = False
[bumpversion:file:setup.py]

View File

@@ -1,5 +1,8 @@
# Changelog
## 0.23.0
Limiting the number of HTTP requests during a test read
## 0.22.0
Surface the resolved manifest in the CDK

View File

@@ -2,6 +2,7 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union
@@ -9,10 +10,12 @@ from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional,
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteLogMessage,
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Level,
Status,
SyncMode,
)
@@ -34,6 +37,8 @@ class AbstractSource(Source, ABC):
in this class to create an Airbyte Specification compliant Source.
"""
SLICE_LOG_PREFIX = "slice:"
@abstractmethod
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
"""
@@ -236,7 +241,10 @@ class AbstractSource(Source, ABC):
has_slices = False
for _slice in slices:
has_slices = True
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(
type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"{self.SLICE_LOG_PREFIX}{json.dumps(_slice)}")
)
records = stream_instance.read_records(
sync_mode=SyncMode.incremental,
stream_slice=_slice,
@@ -285,7 +293,10 @@ class AbstractSource(Source, ABC):
)
total_records_counter = 0
for _slice in slices:
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(
type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"{self.SLICE_LOG_PREFIX}{json.dumps(_slice)}")
)
record_data_or_messages = stream_instance.read_records(
stream_slice=_slice,
sync_mode=SyncMode.full_refresh,
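A minimal sketch (not part of this diff) of the slice log message both read paths now emit when the logger is at DEBUG; the slice value is illustrative:

import json

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
from airbyte_cdk.models import Type as MessageType

_slice = {"date": "2022-01-01"}  # illustrative slice
message = AirbyteMessage(
    type=MessageType.LOG,
    log=AirbyteLogMessage(level=Level.INFO, message=f"slice:{json.dumps(_slice)}"),
)
# message.log.message == 'slice:{"date": "2022-01-01"}' -- this "slice:" prefix is
# the marker the connector builder server parses to group pages into slices.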

View File

@@ -50,7 +50,13 @@ class ManifestDeclarativeSource(DeclarativeSource):
VALID_TOP_LEVEL_FIELDS = {"check", "definitions", "schemas", "spec", "streams", "type", "version"}
def __init__(self, source_config: ConnectionDefinition, debug: bool = False, construct_using_pydantic_models: bool = False):
def __init__(
self,
source_config: ConnectionDefinition,
debug: bool = False,
component_factory: ModelToComponentFactory = None,
construct_using_pydantic_models: bool = False,
):
"""
:param source_config(Mapping[str, Any]): The manifest of low-code components that describe the source connector
:param debug(bool): True if debug mode is enabled
@@ -71,7 +77,12 @@ class ManifestDeclarativeSource(DeclarativeSource):
self._legacy_source_config = resolved_source_config
self._debug = debug
self._legacy_factory = DeclarativeComponentFactory() # Legacy factory used to instantiate declarative components from the manifest
self._constructor = ModelToComponentFactory() # New factory which converts the manifest to Pydantic models to construct components
if component_factory:
self._constructor = component_factory
else:
self._constructor = (
ModelToComponentFactory()
) # New factory which converts the manifest to Pydantic models to construct components
self._validate_source()
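A hedged usage sketch of the new component_factory injection point, assuming manifest is a valid low-code manifest; the positional limit arguments mirror how the connector builder adapter constructs the factory later in this commit:

from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory
from airbyte_cdk.sources.declarative.yaml_declarative_source import ManifestDeclarativeSource

# Injecting a pre-configured factory caps pages per slice and slices per read;
# omitting component_factory keeps the default ModelToComponentFactory().
factory = ModelToComponentFactory(5, 5)  # max pages per slice, max slices
source = ManifestDeclarativeSource(manifest, debug=True, component_factory=factory)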

View File

@@ -2,9 +2,9 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import DefaultPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import DefaultPaginator, PaginatorTestReadDecorator
from airbyte_cdk.sources.declarative.requesters.paginators.no_pagination import NoPagination
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
from airbyte_cdk.sources.declarative.requesters.paginators.strategies.pagination_strategy import PaginationStrategy
__all__ = ["DefaultPaginator", "NoPagination", "PaginationStrategy", "Paginator"]
__all__ = ["DefaultPaginator", "NoPagination", "PaginationStrategy", "Paginator", "PaginatorTestReadDecorator"]

View File

@@ -160,3 +160,69 @@ class DefaultPaginator(Paginator, JsonSchemaMixin):
if option_type != RequestOptionType.path:
options[self.page_size_option.field_name] = self.pagination_strategy.get_page_size()
return options
class PaginatorTestReadDecorator(Paginator):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
pages that are queried throughout a read command.
"""
_PAGE_COUNT_BEFORE_FIRST_NEXT_CALL = 1
def __init__(self, decorated, maximum_number_of_pages: int = 5):
if maximum_number_of_pages and maximum_number_of_pages < 1:
raise ValueError(f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}")
self._maximum_number_of_pages = maximum_number_of_pages
self._decorated = decorated
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
if self._page_count >= self._maximum_number_of_pages:
return None
self._page_count += 1
return self._decorated.next_page_token(response, last_records)
def path(self):
return self._decorated.path()
def get_request_params(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._decorated.get_request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, str]:
return self._decorated.get_request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
def get_request_body_data(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._decorated.get_request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return self._decorated.get_request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
def reset(self):
self._decorated.reset()
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
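A minimal usage sketch (not part of this diff); inner_paginator, response, and last_records are placeholders for a configured Paginator and the values a real read would supply:

# The internal counter starts at 1, so with maximum_number_of_pages=2 the first
# next_page_token() call delegates and the second short-circuits to None.
limited = PaginatorTestReadDecorator(inner_paginator, maximum_number_of_pages=2)
token = limited.next_page_token(response, last_records)  # delegates: moves to page 2
token = limited.next_page_token(response, last_records)  # cap reached: returns None
limited.reset()  # resets the decorated paginator and the page counter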

View File

@@ -2,7 +2,7 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional
@@ -12,7 +12,7 @@ from dataclasses_jsonschema import JsonSchemaMixin
@dataclass
class Paginator(RequestOptionsProvider, JsonSchemaMixin):
class Paginator(ABC, RequestOptionsProvider, JsonSchemaMixin):
"""
Defines the token to use to fetch the next page of records from the API.

View File

@@ -3,6 +3,6 @@
#
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever, SimpleRetrieverTestReadDecorator
__all__ = ["Retriever", "SimpleRetriever"]
__all__ = ["Retriever", "SimpleRetriever", "SimpleRetrieverTestReadDecorator"]

View File

@@ -5,6 +5,7 @@
import json
import logging
from dataclasses import InitVar, dataclass, field
from itertools import islice
from json import JSONDecodeError
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
@@ -416,6 +417,28 @@ class SimpleRetriever(Retriever, HttpStream, JsonSchemaMixin):
yield from self.parse_response(response, stream_slice=stream_slice, stream_state=stream_state)
@dataclass
class SimpleRetrieverTestReadDecorator(SimpleRetriever):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
slices that are queried throughout a read command.
"""
maximum_number_of_slices: int = 5
def __post_init__(self, options: Mapping[str, Any]):
super().__post_init__(options)
if self.maximum_number_of_slices and self.maximum_number_of_slices < 1:
raise ValueError(
f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}"
)
def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Optional[StreamState] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
return islice(super().stream_slices(sync_mode=sync_mode, stream_state=stream_state), self.maximum_number_of_slices)
def _prepared_request_to_airbyte_message(request: requests.PreparedRequest) -> AirbyteMessage:
# FIXME: this should return some sort of trace message
request_dict = {
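The slice cap in SimpleRetrieverTestReadDecorator above leans on itertools.islice being lazy; a self-contained sketch of that behavior with an illustrative slicer:

from itertools import islice

def all_slices():
    # Would yield nine daily slices if fully consumed.
    for day in range(1, 10):
        yield {"date": f"2022-01-0{day}"}

capped = list(islice(all_slices(), 5))  # the generator is only advanced five times
assert len(capped) == 5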

View File

@@ -15,7 +15,7 @@ README = (HERE / "README.md").read_text()
setup(
name="airbyte-cdk",
version="0.22.0",
version="0.23.0",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",

View File

@@ -9,7 +9,12 @@ import pytest
import requests
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import DefaultPaginator, RequestOption, RequestOptionType
from airbyte_cdk.sources.declarative.requesters.paginators.default_paginator import (
DefaultPaginator,
PaginatorTestReadDecorator,
RequestOption,
RequestOptionType,
)
from airbyte_cdk.sources.declarative.requesters.paginators.strategies.cursor_pagination_strategy import CursorPaginationStrategy
@@ -202,3 +207,25 @@ def test_reset():
strategy = MagicMock()
DefaultPaginator(strategy, config, url_base, options={}, page_size_option=page_size_request_option, page_token_option=page_token_request_option).reset()
assert strategy.reset.called
def test_limit_page_fetched():
maximum_number_of_pages = 5
number_of_next_performed = maximum_number_of_pages - 1
paginator = PaginatorTestReadDecorator(
DefaultPaginator(
page_size_option=MagicMock(),
page_token_option=MagicMock(),
pagination_strategy=MagicMock(),
config=MagicMock(),
url_base=MagicMock(),
options={},
),
maximum_number_of_pages
)
for _ in range(number_of_next_performed):
last_token = paginator.next_page_token(MagicMock(), MagicMock())
assert last_token
assert not paginator.next_page_token(MagicMock(), MagicMock())

View File

@@ -15,6 +15,7 @@ from airbyte_cdk.sources.declarative.requesters.request_option import RequestOpt
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import (
SimpleRetriever,
SimpleRetrieverTestReadDecorator,
_prepared_request_to_airbyte_message,
_response_to_airbyte_message,
)
@@ -629,3 +630,28 @@ def test_response_to_airbyte_message(test_name, response_body, response_headers,
actual_airbyte_message = _response_to_airbyte_message(response)
assert expected_airbyte_message == actual_airbyte_message
def test_limit_stream_slices():
maximum_number_of_slices = 4
stream_slicer = MagicMock()
stream_slicer.stream_slices.return_value = _generate_slices(maximum_number_of_slices * 2)
retriever = SimpleRetrieverTestReadDecorator(
name="stream_name",
primary_key=primary_key,
requester=MagicMock(),
paginator=MagicMock(),
record_selector=MagicMock(),
stream_slicer=stream_slicer,
maximum_number_of_slices=maximum_number_of_slices,
options={},
config={},
)
truncated_slices = list(retriever.stream_slices(sync_mode=SyncMode.incremental, stream_state=None))
assert truncated_slices == _generate_slices(maximum_number_of_slices)
def _generate_slices(number_of_slices):
return [{"date": f"2022-01-0{day + 1}"} for day in range(number_of_slices)]

View File

@@ -6,6 +6,7 @@ import json
import logging
import os
import sys
from unittest.mock import patch
import pytest
import yaml
@@ -541,6 +542,94 @@ class TestManifestDeclarativeSource:
with pytest.raises(ValidationError):
ManifestDeclarativeSource(source_config=manifest, construct_using_pydantic_models=construct_using_pydantic_models)
@patch("airbyte_cdk.sources.declarative.declarative_source.DeclarativeSource.read")
def test_given_debug_when_read_then_set_log_level(self, declarative_source_read):
any_valid_manifest = {
"version": "version",
"definitions": {
"schema_loader": {"name": "{{ options.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml"},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
"streams": [
{
"type": "DeclarativeStream",
"$options": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
{
"type": "DeclarativeStream",
"$options": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"type": "CustomRequester",
"class_name": "unit_tests.sources.declarative.external_component.SampleCustomComponent",
"path": "/v3/marketing/lists",
"custom_request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
],
"check": {"type": "CheckStream", "stream_names": ["lists"]},
}
source = ManifestDeclarativeSource(source_config=any_valid_manifest, debug=True, construct_using_pydantic_models=True)
debug_logger = logging.getLogger("logger.debug")
list(source.read(debug_logger, {}, {}, {}))
assert debug_logger.isEnabledFor(logging.DEBUG)
def test_generate_schema():
schema_str = ManifestDeclarativeSource.generate_schema()

View File

@@ -331,6 +331,57 @@ def test_valid_full_refresh_read_with_slices(mocker):
assert expected == messages
def test_read_full_refresh_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a full refresh, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s1",
)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.full_refresh),
]
)
messages = src.read(debug_logger, {}, catalog)
assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))
def test_read_incremental_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a incremental, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_slice": s, 'stream_state': {}}, [s]) for s in slices],
name="s1",
)
MockStream.supports_incremental = mocker.PropertyMock(return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.incremental),
]
)
messages = src.read(debug_logger, {}, catalog)
assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))
class TestIncrementalRead:
@pytest.mark.parametrize(
"use_legacy",

View File

@@ -401,6 +401,7 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
logger_mock.level = logging.DEBUG
del catalog.streams[1]
STREAM_LIMIT = 2
SLICE_DEBUG_LOG_COUNT = 1
FULL_RECORDS_NUMBER = 3
streams = abstract_source.streams(None)
http_stream = streams[0]
@@ -409,7 +410,7 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
catalog.streams[0].sync_mode = SyncMode.full_refresh
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
# Check if log line matches number of limit
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -418,13 +419,13 @@ def test_internal_config_limit(mocker, abstract_source, catalog):
# No limit, check if state record produced for incremental stream
catalog.streams[0].sync_mode = SyncMode.incremental
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == FULL_RECORDS_NUMBER + 1
assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE
# Set limit and check if state is produced when limit is set for incremental stream
logger_mock.reset_mock()
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT + 1
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -435,6 +436,7 @@ SCHEMA = {"type": "object", "properties": {"value": {"type": "string"}}}
def test_source_config_no_transform(mocker, abstract_source, catalog):
SLICE_DEBUG_LOG_COUNT = 1
logger_mock = mocker.MagicMock()
logger_mock.level = logging.DEBUG
streams = abstract_source.streams(None)
@@ -442,8 +444,8 @@ def test_source_config_no_transform(mocker, abstract_source, catalog):
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2 * 5
assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT)
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": 23}] * 2 * 5
assert http_stream.get_json_schema.call_count == 5
assert non_http_stream.get_json_schema.call_count == 5
@@ -451,6 +453,7 @@ def test_source_config_no_transform(mocker, abstract_source, catalog):
def test_source_config_transform(mocker, abstract_source, catalog):
logger_mock = mocker.MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
@@ -458,21 +461,22 @@ def test_source_config_transform(mocker, abstract_source, catalog):
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}] * 2
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}] * 2
def test_source_config_transform_and_no_transform(mocker, abstract_source, catalog):
logger_mock = mocker.MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}, {"value": 23}]
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}]
def test_read_default_http_availability_strategy_stream_available(catalog, mocker):

View File

@@ -4,10 +4,14 @@
from connector_builder.generated.apis.default_api_interface import initialize_router
from connector_builder.impl.default_api import DefaultApiImpl
from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapter
from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapterFactory
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
_MAXIMUM_NUMBER_OF_SLICES = 5
_ADAPTER_FACTORY = LowCodeSourceAdapterFactory(_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, _MAXIMUM_NUMBER_OF_SLICES)
app = FastAPI(
title="Connector Builder Server API",
description="Connector Builder Server API ",
@@ -22,4 +26,4 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(initialize_router(DefaultApiImpl(LowCodeSourceAdapter)))
app.include_router(initialize_router(DefaultApiImpl(_ADAPTER_FACTORY, _MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, _MAXIMUM_NUMBER_OF_SLICES)))

View File

@@ -19,11 +19,13 @@ class StreamRead(BaseModel):
logs: The logs of this StreamRead.
slices: The slices of this StreamRead.
test_read_limit_reached: The test_read_limit_reached of this StreamRead.
inferred_schema: The inferred_schema of this StreamRead [Optional].
"""
logs: List[object]
slices: List[StreamReadSlices]
test_read_limit_reached: bool
inferred_schema: Optional[Dict[str, Any]] = None
StreamRead.update_forward_refs()
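A hedged construction sketch showing the new required field (values illustrative):

stream_read = StreamRead(
    logs=[{"message": "example log"}],
    slices=[],
    test_read_limit_reached=False,  # now required alongside logs and slices
)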

View File

@@ -32,3 +32,11 @@ class CdkAdapter(ABC):
:param config: The user-provided configuration as specified by the source's spec.
:return: An iterator over `AirbyteMessage` objects.
"""
class CdkAdapterFactory(ABC):
@abstractmethod
def create(self, manifest: Dict[str, Any]) -> CdkAdapter:
"""Return an implementation of CdkAdapter"""
pass

View File

@@ -6,7 +6,7 @@ import json
import logging
import traceback
from json import JSONDecodeError
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
from typing import Any, Dict, Iterable, Iterator, Optional, Union
from urllib.parse import parse_qs, urljoin, urlparse
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Type
@@ -24,17 +24,21 @@ from connector_builder.generated.models.stream_read_slices import StreamReadSlic
from connector_builder.generated.models.streams_list_read import StreamsListRead
from connector_builder.generated.models.streams_list_read_streams import StreamsListReadStreams
from connector_builder.generated.models.streams_list_request_body import StreamsListRequestBody
from connector_builder.impl.adapter import CdkAdapter
from connector_builder.impl.adapter import CdkAdapter, CdkAdapterFactory
from fastapi import Body, HTTPException
from jsonschema import ValidationError
class DefaultApiImpl(DefaultApi):
logger = logging.getLogger("airbyte.connector-builder")
def __init__(self, adapter_cls: Callable[[Dict[str, Any]], CdkAdapter], max_record_limit: int = 1000):
self.adapter_cls = adapter_cls
def __init__(self, adapter_factory: CdkAdapterFactory, max_pages_per_slice, max_slices, max_record_limit: int = 1000):
self.adapter_factory = adapter_factory
self._max_pages_per_slice = max_pages_per_slice
self._max_slices = max_slices
self.max_record_limit = max_record_limit
super().__init__()
async def get_manifest_template(self) -> str:
@@ -128,7 +132,7 @@ spec:
else:
record_limit = min(stream_read_request_body.record_limit, self.max_record_limit)
single_slice = StreamReadSlices(pages=[])
slices = []
log_messages = []
try:
for message_group in self._get_message_groups(
@@ -139,7 +143,7 @@ spec:
if isinstance(message_group, AirbyteLogMessage):
log_messages.append({"message": message_group.message})
else:
single_slice.pages.append(message_group)
slices.append(message_group)
except Exception as error:
# TODO: We're temporarily using FastAPI's default exception model. Ideally we should use exceptions defined in the OpenAPI spec
self.logger.error(f"Could not perform read with with error: {error.args[0]} - {self._get_stacktrace_as_string(error)}")
@@ -149,9 +153,21 @@ spec:
)
return StreamRead(
logs=log_messages, slices=[single_slice], inferred_schema=schema_inferrer.get_stream_schema(stream_read_request_body.stream)
logs=log_messages,
slices=slices,
test_read_limit_reached=self._has_reached_limit(slices),
inferred_schema=schema_inferrer.get_stream_schema(stream_read_request_body.stream)
)
def _has_reached_limit(self, slices):
if len(slices) >= self._max_slices:
return True
for slice in slices:
if len(slice.pages) >= self._max_pages_per_slice:
return True
return False
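A standalone sketch of the same check, under illustrative caps and with hypothetical slice stubs:

from types import SimpleNamespace

def has_reached_limit(slices, max_slices=3, max_pages_per_slice=2):
    # Mirrors _has_reached_limit above; names here are illustrative.
    if len(slices) >= max_slices:
        return True
    return any(len(s.pages) >= max_pages_per_slice for s in slices)

assert has_reached_limit([SimpleNamespace(pages=[1, 2])])       # page cap hit
assert not has_reached_limit([SimpleNamespace(pages=[1])] * 2)  # under both caps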
async def resolve_manifest(
self, resolve_manifest_request_body: ResolveManifestRequestBody = Body(None, description="")
) -> ResolveManifest:
@@ -191,33 +207,56 @@ spec:
Note: The exception is that normal log messages can be received at any time which are not incorporated into grouping
"""
first_page = True
current_records = []
records_count = 0
at_least_one_page_in_group = False
current_page_records = []
current_slice_pages = []
current_page_request: Optional[HttpRequest] = None
current_page_response: Optional[HttpResponse] = None
while len(current_records) < limit and (message := next(messages, None)):
if first_page and message.type == Type.LOG and message.log.message.startswith("request:"):
first_page = False
request = self._create_request_from_log_message(message.log)
current_page_request = request
while records_count < limit and (message := next(messages, None)):
if self._need_to_close_page(at_least_one_page_in_group, message):
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
current_page_request = None
current_page_response = None
if at_least_one_page_in_group and message.type == Type.LOG and message.log.message.startswith("slice:"):
yield StreamReadSlices(pages=current_slice_pages)
current_slice_pages = []
at_least_one_page_in_group = False
elif message.type == Type.LOG and message.log.message.startswith("request:"):
if not current_page_request or not current_page_response:
raise ValueError("Every message grouping should have at least one request and response")
yield StreamReadPages(request=current_page_request, response=current_page_response, records=current_records)
if not at_least_one_page_in_group:
at_least_one_page_in_group = True
current_page_request = self._create_request_from_log_message(message.log)
current_records = []
elif message.type == Type.LOG and message.log.message.startswith("response:"):
current_page_response = self._create_response_from_log_message(message.log)
elif message.type == Type.LOG:
yield message.log
elif message.type == Type.RECORD:
current_records.append(message.record.data)
current_page_records.append(message.record.data)
records_count += 1
schema_inferrer.accumulate(message.record)
else:
if not current_page_request or not current_page_response:
raise ValueError("Every message grouping should have at least one request and response")
yield StreamReadPages(request=current_page_request, response=current_page_response, records=current_records)
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
yield StreamReadSlices(pages=current_slice_pages)
@staticmethod
def _need_to_close_page(at_least_one_page_in_group, message):
return (
at_least_one_page_in_group
and message.type == Type.LOG
and (message.log.message.startswith("request:") or message.log.message.startswith("slice:"))
)
@staticmethod
def _close_page(current_page_request, current_page_response, current_slice_pages, current_page_records):
if not current_page_request or not current_page_response:
raise ValueError("Every message grouping should have at least one request and response")
current_slice_pages.append(
StreamReadPages(request=current_page_request, response=current_page_response, records=current_page_records)
)
current_page_records.clear()
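To make the grouping concrete, a comment-only trace of an illustrative message order and what _get_message_groups yields for it:

# slice:{...}    -> no page is open yet, so it is yielded as a plain log message
# request:{...}  -> opens page 1 (at_least_one_page_in_group becomes True)
# response:{...} -> completes page 1's response
# RECORD         -> appended to page 1's records
# request:{...}  -> closes page 1 into the current slice, opens page 2
# response:{...}
# slice:{...}    -> closes page 2, yields StreamReadSlices(pages=[page 1, page 2])
# end of input   -> the final open page and slice are flushed the same way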
def _create_request_from_log_message(self, log_message: AirbyteLogMessage) -> Optional[HttpRequest]:
# TODO: As a temporary stopgap, the CDK emits request data as a log message string. Ideally this should come in the
@@ -255,7 +294,7 @@ spec:
def _create_low_code_adapter(self, manifest: Dict[str, Any]) -> CdkAdapter:
try:
return self.adapter_cls(manifest=manifest)
return self.adapter_factory.create(manifest)
except ValidationError as error:
# TODO: We're temporarily using FastAPI's default exception model. Ideally we should use exceptions defined in the OpenAPI spec
self.logger.error(f"Invalid connector manifest with error: {error.message} - {DefaultApiImpl._get_stacktrace_as_string(error)}")

View File

@@ -7,15 +7,20 @@ from typing import Any, Dict, Iterator, List
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalog, Level
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory
from airbyte_cdk.sources.declarative.yaml_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.streams.http import HttpStream
from connector_builder.impl.adapter import CdkAdapter
from connector_builder.impl.adapter import CdkAdapter, CdkAdapterFactory
class LowCodeSourceAdapter(CdkAdapter):
def __init__(self, manifest: Dict[str, Any]):
def __init__(self, manifest: Dict[str, Any], limit_page_fetched_per_slice, limit_slices_fetched):
# Request and response messages are only emitted for sources that have debug turned on
self._source = ManifestDeclarativeSource(manifest, debug=True)
self._source = ManifestDeclarativeSource(
manifest,
debug=True,
component_factory=ModelToComponentFactory(limit_page_fetched_per_slice, limit_slices_fetched)
)
def get_http_streams(self, config: Dict[str, Any]) -> List[HttpStream]:
http_streams = []
@@ -58,3 +63,13 @@ class LowCodeSourceAdapter(CdkAdapter):
except Exception as e:
yield AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=str(e)))
return
class LowCodeSourceAdapterFactory(CdkAdapterFactory):
def __init__(self, max_pages_per_slice, max_slices):
self._max_pages_per_slice = max_pages_per_slice
self._max_slices = max_slices
def create(self, manifest: Dict[str, Any]) -> CdkAdapter:
return LowCodeSourceAdapter(manifest, self._max_pages_per_slice, self._max_slices)

View File

@@ -41,7 +41,7 @@ setup(
},
packages=find_packages(exclude=("unit_tests", "integration_tests", "docs")),
package_data={},
install_requires=["airbyte-cdk==0.22", "fastapi", "uvicorn"],
install_requires=["airbyte-cdk==0.23", "fastapi", "uvicorn"],
python_requires=">=3.9.11",
extras_require={
"tests": [

View File

@@ -98,6 +98,7 @@ components:
required:
- logs
- slices
- test_read_limit_reached
properties:
logs:
type: array
@@ -144,6 +145,9 @@ components:
type: object
description: The STATE AirbyteMessage emitted at the end of this slice. This can be omitted if a stream slicer is not configured.
# $ref: "#/components/schemas/AirbyteProtocol/definitions/AirbyteStateMessage"
test_read_limit_reached:
type: boolean
description: Whether the maximum number of requests per slice or the maximum number of slices queried has been reached
inferred_schema:
type: object
description: The narrowest JSON Schema against which every AirbyteRecord in the slices can validate successfully. This is inferred from reading every record in the output slices.

View File

@@ -20,10 +20,13 @@ from connector_builder.generated.models.streams_list_read import StreamsListRead
from connector_builder.generated.models.streams_list_read_streams import StreamsListReadStreams
from connector_builder.generated.models.streams_list_request_body import StreamsListRequestBody
from connector_builder.impl.default_api import DefaultApiImpl
from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapter
from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapterFactory
from fastapi import HTTPException
from pydantic.error_wrappers import ValidationError
MAX_PAGES_PER_SLICE = 4
MAX_SLICES = 3
MANIFEST = {
"version": "0.1.0",
"type": "DeclarativeSource",
@@ -95,13 +98,17 @@ def record_message(stream: str, data: dict) -> AirbyteMessage:
return AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=1234))
def slice_message() -> AirbyteMessage:
return AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message='slice:{"key": "value"}'))
def test_list_streams():
expected_streams = [
StreamsListReadStreams(name="hashiras", url="https://demonslayers.com/api/v1/hashiras"),
StreamsListReadStreams(name="breathing-techniques", url="https://demonslayers.com/api/v1/breathing_techniques"),
]
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
streams_list_request_body = StreamsListRequestBody(manifest=MANIFEST, config=CONFIG)
loop = asyncio.get_event_loop()
actual_streams = loop.run_until_complete(api.list_streams(streams_list_request_body))
@@ -135,7 +142,7 @@ def test_list_streams_with_interpolated_urls():
expected_streams = StreamsListRead(streams=[StreamsListReadStreams(name="demons", url="https://upper-six.muzan.com/api/v1/demons")])
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
streams_list_request_body = StreamsListRequestBody(manifest=manifest, config=CONFIG)
loop = asyncio.get_event_loop()
actual_streams = loop.run_until_complete(api.list_streams(streams_list_request_body))
@@ -169,7 +176,7 @@ def test_list_streams_with_unresolved_interpolation():
# The interpolated string {{ config['not_in_config'] }} doesn't resolve to anything so it ends up blank during interpolation
expected_streams = StreamsListRead(streams=[StreamsListReadStreams(name="demons", url="https://.muzan.com/api/v1/demons")])
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
streams_list_request_body = StreamsListRequestBody(manifest=manifest, config=CONFIG)
loop = asyncio.get_event_loop()
@@ -212,7 +219,7 @@ def test_read_stream():
),
]
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
request_log_message(request),
@@ -226,7 +233,7 @@ def test_read_stream():
)
)
api = DefaultApiImpl(mock_source_adapter_cls)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
actual_response: StreamRead = loop.run_until_complete(
@@ -277,7 +284,7 @@ def test_read_stream_with_logs():
{"message": "log message after the response"},
]
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message before the request")),
@@ -291,7 +298,7 @@ def test_read_stream_with_logs():
)
)
api = DefaultApiImpl(mock_source_adapter_cls)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
actual_response: StreamRead = loop.run_until_complete(
@@ -320,7 +327,7 @@ def test_read_stream_record_limit(request_record_limit, max_record_limit):
"body": {"custom": "field"},
}
response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
request_log_message(request),
@@ -337,7 +344,7 @@ def test_read_stream_record_limit(request_record_limit, max_record_limit):
n_records = 2
record_limit = min(request_record_limit, max_record_limit)
api = DefaultApiImpl(mock_source_adapter_cls, max_record_limit=max_record_limit)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit)
loop = asyncio.get_event_loop()
actual_response: StreamRead = loop.run_until_complete(
api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras", record_limit=request_record_limit))
@@ -363,7 +370,7 @@ def test_read_stream_default_record_limit(max_record_limit):
"body": {"custom": "field"},
}
response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
request_log_message(request),
@@ -379,7 +386,7 @@ def test_read_stream_default_record_limit(max_record_limit):
)
n_records = 2
api = DefaultApiImpl(mock_source_adapter_cls, max_record_limit=max_record_limit)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit)
loop = asyncio.get_event_loop()
actual_response: StreamRead = loop.run_until_complete(
api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))
@@ -398,7 +405,7 @@ def test_read_stream_limit_0():
"body": {"custom": "field"},
}
response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
request_log_message(request),
@@ -412,7 +419,7 @@ def test_read_stream_limit_0():
]
)
)
api = DefaultApiImpl(mock_source_adapter_cls)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(ValidationError):
@@ -453,7 +460,7 @@ def test_read_stream_no_records():
),
]
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
request_log_message(request),
@@ -464,7 +471,7 @@ def test_read_stream_no_records():
)
)
api = DefaultApiImpl(mock_source_adapter_cls)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
actual_response: StreamRead = loop.run_until_complete(
@@ -501,7 +508,7 @@ def test_invalid_manifest():
expected_status_code = 400
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(HTTPException) as actual_exception:
loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=invalid_manifest, config={}, stream="hashiras")))
@@ -512,7 +519,7 @@ def test_invalid_manifest():
def test_read_stream_invalid_group_format():
response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}
mock_source_adapter_cls = make_mock_adapter_cls(
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
response_log_message(response),
@@ -522,7 +529,7 @@ def test_read_stream_invalid_group_format():
)
)
api = DefaultApiImpl(mock_source_adapter_cls)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(HTTPException) as actual_exception:
@@ -534,7 +541,7 @@ def test_read_stream_invalid_group_format():
def test_read_stream_returns_error_if_stream_does_not_exist():
expected_status_code = 400
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(HTTPException) as actual_exception:
loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config={}, stream="not_in_manifest")))
@@ -584,7 +591,7 @@ def test_read_stream_returns_error_if_stream_does_not_exist():
)
def test_create_request_from_log_message(log_message, expected_request):
airbyte_log_message = AirbyteLogMessage(level=Level.INFO, message=log_message)
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
actual_request = api._create_request_from_log_message(airbyte_log_message)
assert actual_request == expected_request
@@ -619,12 +626,101 @@ def test_create_response_from_log_message(log_message, expected_response):
response_message = f"response:{json.dumps(log_message)}"
airbyte_log_message = AirbyteLogMessage(level=Level.INFO, message=response_message)
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
actual_response = api._create_response_from_log_message(airbyte_log_message)
assert actual_response == expected_response
def test_read_stream_with_many_slices():
request = {}
response = {"status_code": 200}
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
slice_message(),
request_log_message(request),
response_log_message(response),
record_message("hashiras", {"name": "Muichiro Tokito"}),
slice_message(),
request_log_message(request),
response_log_message(response),
record_message("hashiras", {"name": "Shinobu Kocho"}),
record_message("hashiras", {"name": "Mitsuri Kanroji"}),
request_log_message(request),
response_log_message(response),
record_message("hashiras", {"name": "Obanai Iguro"}),
request_log_message(request),
response_log_message(response),
]
)
)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
stream_read: StreamRead = loop.run_until_complete(
api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))
)
assert not stream_read.test_read_limit_reached
assert len(stream_read.slices) == 2
assert len(stream_read.slices[0].pages) == 1
assert len(stream_read.slices[0].pages[0].records) == 1
assert len(stream_read.slices[1].pages) == 3
assert len(stream_read.slices[1].pages[0].records) == 2
assert len(stream_read.slices[1].pages[1].records) == 1
assert len(stream_read.slices[1].pages[2].records) == 0
def test_read_stream_given_maximum_number_of_slices_then_test_read_limit_reached():
maximum_number_of_slices = 5
request = {}
response = {"status_code": 200}
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[
slice_message(),
request_log_message(request),
response_log_message(response)
] * maximum_number_of_slices
)
)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
stream_read: StreamRead = loop.run_until_complete(
api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))
)
assert stream_read.test_read_limit_reached
def test_read_stream_given_maximum_number_of_pages_then_test_read_limit_reached():
maximum_number_of_pages_per_slice = 5
request = {}
response = {"status_code": 200}
mock_source_adapter_cls = make_mock_adapter_factory(
iter(
[slice_message()] + [request_log_message(request), response_log_message(response)] * maximum_number_of_pages_per_slice
)
)
api = DefaultApiImpl(mock_source_adapter_cls, MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
stream_read: StreamRead = loop.run_until_complete(
api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))
)
assert stream_read.test_read_limit_reached
def test_resolve_manifest():
_stream_name = "stream_with_custom_requester"
_stream_primary_key = "id"
@@ -775,7 +871,7 @@ def test_resolve_manifest():
"check": {"type": "CheckStream", "stream_names": ["lists"]},
}
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
actual_response: ResolveManifest = loop.run_until_complete(api.resolve_manifest(ResolveManifestRequestBody(manifest=manifest)))
@@ -794,7 +890,7 @@ def test_resolve_manifest_unresolvable_references():
"check": {"type": "CheckStream", "stream_names": ["lists"]},
}
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(HTTPException) as actual_exception:
loop.run_until_complete(api.resolve_manifest(ResolveManifestRequestBody(manifest=invalid_manifest)))
@@ -807,7 +903,7 @@ def test_resolve_manifest_invalid():
expected_status_code = 400
invalid_manifest = {"version": "version"}
api = DefaultApiImpl(LowCodeSourceAdapter)
api = DefaultApiImpl(LowCodeSourceAdapterFactory(MAX_PAGES_PER_SLICE, MAX_SLICES), MAX_PAGES_PER_SLICE, MAX_SLICES)
loop = asyncio.get_event_loop()
with pytest.raises(HTTPException) as actual_exception:
loop.run_until_complete(api.resolve_manifest(ResolveManifestRequestBody(manifest=invalid_manifest)))
@@ -816,9 +912,9 @@ def test_resolve_manifest_invalid():
assert actual_exception.value.status_code == expected_status_code
def make_mock_adapter_cls(return_value: Iterator) -> MagicMock:
mock_source_adapter_cls = MagicMock()
def make_mock_adapter_factory(return_value: Iterator) -> MagicMock:
mock_source_adapter_factory = MagicMock()
mock_source_adapter = MagicMock()
mock_source_adapter.read_stream.return_value = return_value
mock_source_adapter_cls.return_value = mock_source_adapter
return mock_source_adapter_cls
mock_source_adapter_factory.create.return_value = mock_source_adapter
return mock_source_adapter_factory

View File

@@ -10,6 +10,8 @@ import pytest
import requests
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, Level, Type
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.requesters.paginators import PaginatorTestReadDecorator
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetrieverTestReadDecorator
from airbyte_cdk.sources.declarative.parsers.custom_exceptions import UndefinedReferenceException
from airbyte_cdk.sources.streams.http import HttpStream
@@ -176,10 +178,37 @@ MANIFEST_WITH_REFERENCES = {
}
}
MANIFEST_WITH_PAGINATOR = {
"version": "0.1.0",
"type" : "DeclarativeSource",
"definitions": {
},
"streams": [
{
"type" : "DeclarativeStream",
"retriever": {
"type" : "SimpleRetriever",
"record_selector": {"extractor": {"field_pointer": ["items"], "type": "DpathExtractor"}, "type": "RecordSelector"},
"paginator": {
"type": "DefaultPaginator",
"pagination_strategy": {
"type": "OffsetIncrement",
"page_size": 10
},
"url_base": "https://demonslayers.com/api/v1/"
},
"requester": {"url_base": "https://demonslayers.com/api/v1/", "http_method": "GET", "type": "HttpRequester"},
},
"$options": {"name": "hashiras", "path": "/hashiras"},
},
],
"check": {"stream_names": ["hashiras"], "type": "CheckStream"},
}
def test_get_http_streams():
expected_urls = {"https://demonslayers.com/api/v1/breathing_techniques", "https://demonslayers.com/api/v1/hashiras"}
adapter = LowCodeSourceAdapter(MANIFEST)
adapter = LowCodeSourceAdapter(MANIFEST, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
actual_streams = adapter.get_http_streams(config={})
actual_urls = {http_stream.url_base + http_stream.path() for http_stream in actual_streams}
@@ -187,10 +216,13 @@ def test_get_http_streams():
assert actual_urls == expected_urls
MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
MAXIMUM_NUMBER_OF_SLICES = 5
def test_get_http_manifest_with_references():
expected_urls = {"https://demonslayers.com/api/v1/ranks"}
adapter = LowCodeSourceAdapter(MANIFEST_WITH_REFERENCES)
adapter = LowCodeSourceAdapter(MANIFEST_WITH_REFERENCES, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
actual_streams = adapter.get_http_streams(config={})
actual_urls = {http_stream.url_base + http_stream.path() for http_stream in actual_streams}
@@ -204,7 +236,7 @@ def test_get_http_streams_non_declarative_streams():
mock_source = MagicMock()
mock_source.streams.return_value = [non_declarative_stream]
adapter = LowCodeSourceAdapter(MANIFEST)
adapter = LowCodeSourceAdapter(MANIFEST, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
adapter._source = mock_source
with pytest.raises(TypeError):
adapter.get_http_streams(config={})
@@ -217,7 +249,7 @@ def test_get_http_streams_non_http_stream():
mock_source = MagicMock()
mock_source.streams.return_value = [declarative_stream_non_http_retriever]
adapter = LowCodeSourceAdapter(MANIFEST)
adapter = LowCodeSourceAdapter(MANIFEST, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
adapter._source = mock_source
with pytest.raises(TypeError):
adapter.get_http_streams(config={})
@@ -247,7 +279,7 @@ def test_read_streams():
mock_source = MagicMock()
mock_source.read.return_value = expected_messages
adapter = LowCodeSourceAdapter(MANIFEST)
adapter = LowCodeSourceAdapter(MANIFEST, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
adapter._source = mock_source
actual_messages = list(adapter.read_stream("hashiras", {}))
@@ -272,7 +304,7 @@ def test_read_streams_with_error():
mock_source.read.side_effect = return_value
adapter = LowCodeSourceAdapter(MANIFEST)
adapter = LowCodeSourceAdapter(MANIFEST, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
adapter._source = mock_source
actual_messages = list(adapter.read_stream("hashiras", {}))
@@ -309,4 +341,14 @@ def test_read_streams_invalid_reference():
}
with pytest.raises(UndefinedReferenceException):
LowCodeSourceAdapter(invalid_reference_manifest)
LowCodeSourceAdapter(invalid_reference_manifest, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
def test_stream_use_read_test_retriever_and_paginator():
adapter = LowCodeSourceAdapter(MANIFEST_WITH_PAGINATOR, MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, MAXIMUM_NUMBER_OF_SLICES)
streams = adapter.get_http_streams(config={})
assert streams
for stream in streams:
assert isinstance(stream, SimpleRetrieverTestReadDecorator)
assert isinstance(stream.paginator, PaginatorTestReadDecorator)