support semi incremental by adding extractor record filter (#13520)
* support semi incremental by adding extractor record filter * refactor extractor into a record_selector that supports extraction and filtering of response records
This commit is contained in:
@@ -1,15 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from airbyte_cdk.sources.declarative.types import Record
|
||||
|
||||
|
||||
class HttpExtractor(ABC):
|
||||
@abstractmethod
|
||||
def extract_records(self, response: requests.Response) -> List[Record]:
|
||||
pass
|
||||
@@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
import requests
|
||||
from airbyte_cdk.sources.declarative.types import Record
|
||||
|
||||
|
||||
class HttpSelector(ABC):
|
||||
@abstractmethod
|
||||
def select_records(
|
||||
self,
|
||||
response: requests.Response,
|
||||
stream_state: Mapping[str, Any],
|
||||
stream_slice: Mapping[str, Any] = None,
|
||||
next_page_token: Mapping[str, Any] = None,
|
||||
) -> List[Record]:
|
||||
pass
|
||||
@@ -6,13 +6,12 @@ from typing import List
|
||||
|
||||
import requests
|
||||
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
|
||||
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
|
||||
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
|
||||
from airbyte_cdk.sources.declarative.types import Record
|
||||
from jello import lib as jello_lib
|
||||
|
||||
|
||||
class JelloExtractor(HttpExtractor):
|
||||
class JelloExtractor:
|
||||
default_transform = "."
|
||||
|
||||
def __init__(self, transform: str, decoder: Decoder, config, kwargs=None):
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
|
||||
from airbyte_cdk.sources.declarative.types import Record
|
||||
|
||||
|
||||
class RecordFilter:
|
||||
def __init__(self, config, condition: str = None):
|
||||
self._config = config
|
||||
self._filter_interpolator = InterpolatedBoolean(condition)
|
||||
|
||||
def filter_records(
|
||||
self,
|
||||
records: List[Record],
|
||||
stream_state: Mapping[str, Any],
|
||||
stream_slice: Mapping[str, Any] = None,
|
||||
next_page_token: Mapping[str, Any] = None,
|
||||
) -> List[Record]:
|
||||
kwargs = {"stream_state": stream_state, "stream_slice": stream_slice, "next_page_token": next_page_token}
|
||||
return [record for record in records if self._filter_interpolator.eval(self._config, record=record, **kwargs)]
|
||||
@@ -0,0 +1,36 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
import requests
|
||||
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
|
||||
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
|
||||
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
|
||||
from airbyte_cdk.sources.declarative.types import Record
|
||||
|
||||
|
||||
class RecordSelector(HttpSelector):
|
||||
"""
|
||||
Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering
|
||||
records based on a heuristic.
|
||||
"""
|
||||
|
||||
def __init__(self, extractor: JelloExtractor, record_filter: RecordFilter = None):
|
||||
self._extractor = extractor
|
||||
self._record_filter = record_filter
|
||||
|
||||
def select_records(
|
||||
self,
|
||||
response: requests.Response,
|
||||
stream_state: Mapping[str, Any],
|
||||
stream_slice: Mapping[str, Any] = None,
|
||||
next_page_token: Mapping[str, Any] = None,
|
||||
) -> List[Record]:
|
||||
all_records = self._extractor.extract_records(response)
|
||||
if self._record_filter:
|
||||
return self._record_filter.filter_records(
|
||||
all_records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
|
||||
)
|
||||
return all_records
|
||||
@@ -15,8 +15,8 @@ class ConditionalPaginator:
|
||||
A paginator that performs pagination by incrementing a page number and stops based on a provided stop condition.
|
||||
"""
|
||||
|
||||
def __init__(self, stop_condition_template: str, state: DictState, decoder: Decoder, config):
|
||||
self._stop_condition_template = InterpolatedBoolean(stop_condition_template)
|
||||
def __init__(self, stop_condition: str, state: DictState, decoder: Decoder, config):
|
||||
self._stop_condition_interpolator = InterpolatedBoolean(stop_condition)
|
||||
self._state: DictState = state
|
||||
self._decoder = decoder
|
||||
self._config = config
|
||||
@@ -24,7 +24,7 @@ class ConditionalPaginator:
|
||||
def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
|
||||
decoded_response = self._decoder.decode(response)
|
||||
headers = response.headers
|
||||
should_stop = self._stop_condition_template.eval(
|
||||
should_stop = self._stop_condition_interpolator.eval(
|
||||
self._config, decoded_response=decoded_response, headers=headers, last_records=last_records
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
|
||||
|
||||
import requests
|
||||
from airbyte_cdk.models import SyncMode
|
||||
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
|
||||
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
|
||||
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
|
||||
from airbyte_cdk.sources.declarative.requesters.requester import Requester
|
||||
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
|
||||
@@ -22,7 +22,7 @@ class SimpleRetriever(Retriever, HttpStream):
|
||||
primary_key,
|
||||
requester: Requester,
|
||||
paginator: Paginator,
|
||||
extractor: HttpExtractor,
|
||||
record_selector: HttpSelector,
|
||||
stream_slicer: StreamSlicer,
|
||||
state: State,
|
||||
):
|
||||
@@ -30,7 +30,7 @@ class SimpleRetriever(Retriever, HttpStream):
|
||||
self._primary_key = primary_key
|
||||
self._paginator = paginator
|
||||
self._requester = requester
|
||||
self._extractor = extractor
|
||||
self._record_selector = record_selector
|
||||
super().__init__(self._requester.get_authenticator())
|
||||
self._iterator: StreamSlicer = stream_slicer
|
||||
self._state: State = state.deep_copy()
|
||||
@@ -190,7 +190,9 @@ class SimpleRetriever(Retriever, HttpStream):
|
||||
next_page_token: Mapping[str, Any] = None,
|
||||
) -> Iterable[Mapping]:
|
||||
self._last_response = response
|
||||
records = self._extractor.extract_records(response)
|
||||
records = self._record_selector.select_records(
|
||||
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
|
||||
)
|
||||
self._last_records = records
|
||||
return records
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
import pytest
|
||||
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_name, filter_template, records, expected_records",
|
||||
[
|
||||
(
|
||||
"test_using_state_filter",
|
||||
"{{ record['created_at'] > stream_state['created_at'] }}",
|
||||
[{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
|
||||
[{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
|
||||
),
|
||||
(
|
||||
"test_with_slice_filter",
|
||||
"{{ record['last_seen'] >= stream_slice['last_seen'] }}",
|
||||
[{"id": 1, "last_seen": "06-06-21"}, {"id": 2, "last_seen": "06-07-21"}, {"id": 3, "last_seen": "06-10-21"}],
|
||||
[{"id": 3, "last_seen": "06-10-21"}],
|
||||
),
|
||||
(
|
||||
"test_with_next_page_token_filter",
|
||||
"{{ record['id'] >= next_page_token['last_seen_id'] }}",
|
||||
[{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
|
||||
[{"id": 14}, {"id": 15}],
|
||||
),
|
||||
(
|
||||
"test_missing_filter_fields_return_no_results",
|
||||
"{{ record['id'] >= next_page_token['path_to_nowhere'] }}",
|
||||
[{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_record_filter(test_name, filter_template, records, expected_records):
|
||||
config = {"response_override": "stop_if_you_see_me"}
|
||||
stream_state = {"created_at": "06-06-21"}
|
||||
stream_slice = {"last_seen": "06-10-21"}
|
||||
next_page_token = {"last_seen_id": 14}
|
||||
record_filter = RecordFilter(config=config, condition=filter_template)
|
||||
|
||||
actual_records = record_filter.filter_records(
|
||||
records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
|
||||
)
|
||||
assert actual_records == expected_records
|
||||
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
|
||||
#
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
|
||||
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
|
||||
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
|
||||
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_name, transform_template, filter_template, body, expected_records",
|
||||
[
|
||||
(
|
||||
"test_with_extractor_and_filter",
|
||||
"_.data",
|
||||
"{{ record['created_at'] > stream_state['created_at'] }}",
|
||||
{"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}]},
|
||||
[{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
|
||||
),
|
||||
(
|
||||
"test_no_record_filter_returns_all_records",
|
||||
"_.data",
|
||||
None,
|
||||
{"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}]},
|
||||
[{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_record_filter(test_name, transform_template, filter_template, body, expected_records):
|
||||
config = {"response_override": "stop_if_you_see_me"}
|
||||
stream_state = {"created_at": "06-06-21"}
|
||||
stream_slice = {"last_seen": "06-10-21"}
|
||||
next_page_token = {"last_seen_id": 14}
|
||||
|
||||
response = create_response(body)
|
||||
decoder = JsonDecoder()
|
||||
extractor = JelloExtractor(transform=transform_template, decoder=decoder, config=config, kwargs={})
|
||||
if filter_template is None:
|
||||
record_filter = None
|
||||
else:
|
||||
record_filter = RecordFilter(config=config, condition=filter_template)
|
||||
record_selector = RecordSelector(extractor=extractor, record_filter=record_filter)
|
||||
|
||||
actual_records = record_selector.select_records(
|
||||
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
|
||||
)
|
||||
assert actual_records == expected_records
|
||||
|
||||
|
||||
def create_response(body):
|
||||
response = requests.Response()
|
||||
response._content = json.dumps(body).encode("utf-8")
|
||||
return response
|
||||
@@ -22,8 +22,8 @@ def test():
|
||||
next_page_token = {"cursor": "cursor_value"}
|
||||
paginator.next_page_token.return_value = next_page_token
|
||||
|
||||
extractor = MagicMock()
|
||||
extractor.extract_records.return_value = records
|
||||
record_selector = MagicMock()
|
||||
record_selector.select_records.return_value = records
|
||||
|
||||
iterator = MagicMock()
|
||||
stream_slices = [{"date": "2022-01-01"}, {"date": "2022-01-02"}]
|
||||
@@ -62,7 +62,7 @@ def test():
|
||||
use_cache = True
|
||||
requester.use_cache = use_cache
|
||||
|
||||
retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, extractor, iterator, state)
|
||||
retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, record_selector, iterator, state)
|
||||
|
||||
# hack because we clone the state...
|
||||
retriever._state = state
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
|
||||
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
|
||||
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
|
||||
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
|
||||
from airbyte_cdk.sources.declarative.parsers.factory import DeclarativeComponentFactory
|
||||
from airbyte_cdk.sources.declarative.parsers.yaml_parser import YamlParser
|
||||
from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_options_provider import (
|
||||
@@ -86,6 +88,13 @@ decoder:
|
||||
extractor:
|
||||
class_name: airbyte_cdk.sources.declarative.extractors.jello.JelloExtractor
|
||||
decoder: "*ref(decoder)"
|
||||
selector:
|
||||
class_name: airbyte_cdk.sources.declarative.extractors.record_selector.RecordSelector
|
||||
extractor:
|
||||
decoder: "*ref(decoder)"
|
||||
record_filter:
|
||||
class_name: airbyte_cdk.sources.declarative.extractors.record_filter.RecordFilter
|
||||
condition: "{{ record['id'] > stream_state['id'] }}"
|
||||
metadata_paginator:
|
||||
class_name: "airbyte_cdk.sources.declarative.requesters.paginators.next_page_url_paginator.NextPageUrlPaginator"
|
||||
next_page_token_template:
|
||||
@@ -139,6 +148,8 @@ list_stream:
|
||||
default: "marketing/lists"
|
||||
paginator:
|
||||
ref: "*ref(metadata_paginator)"
|
||||
record_selector:
|
||||
ref: "*ref(selector)"
|
||||
check:
|
||||
class_name: airbyte_cdk.sources.declarative.checks.check_stream.CheckStream
|
||||
stream_names: ["list_stream"]
|
||||
@@ -156,8 +167,11 @@ check:
|
||||
assert type(stream._retriever) == SimpleRetriever
|
||||
assert stream._retriever._requester._method == HttpMethod.GET
|
||||
assert stream._retriever._requester._authenticator._tokens == ["verysecrettoken"]
|
||||
assert type(stream._retriever._extractor._decoder) == JsonDecoder
|
||||
assert stream._retriever._extractor._transform == ".result[]"
|
||||
assert type(stream._retriever._record_selector) == RecordSelector
|
||||
assert type(stream._retriever._record_selector._extractor._decoder) == JsonDecoder
|
||||
assert stream._retriever._record_selector._extractor._transform == ".result[]"
|
||||
assert type(stream._retriever._record_selector._record_filter) == RecordFilter
|
||||
assert stream._retriever._record_selector._record_filter._filter_interpolator._condition == "{{ record['id'] > stream_state['id'] }}"
|
||||
assert stream._schema_loader._file_path._string == "./source_sendgrid/schemas/lists.json"
|
||||
|
||||
checker = factory.create_component(config["check"], input_config)()
|
||||
|
||||
Reference in New Issue
Block a user