1
0
mirror of synced 2025-12-23 21:03:15 -05:00

Support semi-incremental by adding an extractor record filter (#13520)

* Support semi-incremental by adding an extractor record filter

* refactor extractor into a record_selector that supports extraction and filtering of response records
This commit is contained in:
Brian Lai
2022-06-23 00:09:44 -04:00
committed by GitHub
parent c6d83b3239
commit a61224887e
11 changed files with 216 additions and 29 deletions

View File

@@ -1,15 +0,0 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from abc import ABC, abstractmethod
from typing import List
import requests
from airbyte_cdk.sources.declarative.types import Record
class HttpExtractor(ABC):
    """Abstract interface for components that pull records out of an HTTP response."""

    @abstractmethod
    def extract_records(self, response: requests.Response) -> List[Record]:
        """
        Return the records found in the given response.

        :param response: The HTTP response to read records from
        :return: The list of records extracted from the response
        """

View File

@@ -0,0 +1,21 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from abc import ABC, abstractmethod
from typing import Any, List, Mapping
import requests
from airbyte_cdk.sources.declarative.types import Record
class HttpSelector(ABC):
    """Abstract interface for components that turn an HTTP response into a list of records."""

    @abstractmethod
    def select_records(
        self,
        response: requests.Response,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Return the records selected from the given response.

        :param response: The HTTP response to select records from
        :param stream_state: The stream state made available to the selection
        :param stream_slice: The stream slice made available to the selection (defaults to None)
        :param next_page_token: The pagination token made available to the selection (defaults to None)
        :return: The list of selected records
        """

View File

@@ -6,13 +6,12 @@ from typing import List
import requests
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.declarative.types import Record
from jello import lib as jello_lib
class JelloExtractor(HttpExtractor):
class JelloExtractor:
default_transform = "."
def __init__(self, transform: str, decoder: Decoder, config, kwargs=None):

View File

@@ -0,0 +1,24 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from typing import Any, List, Mapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.types import Record
class RecordFilter:
    """Keeps only the records for which an interpolated boolean condition evaluates to a truthy value."""

    def __init__(self, config, condition: str = None):
        """
        :param config: The configuration handed to the interpolator on every evaluation
        :param condition: The condition string wrapped in an InterpolatedBoolean
        """
        self._config = config
        self._filter_interpolator = InterpolatedBoolean(condition)

    def filter_records(
        self,
        records: List[Record],
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Evaluate the condition once per record and return the records that pass.

        :param records: The candidate records
        :param stream_state: Exposed to the condition as ``stream_state``
        :param stream_slice: Exposed to the condition as ``stream_slice``
        :param next_page_token: Exposed to the condition as ``next_page_token``
        :return: The records for which the condition evaluated truthy, in their original order
        """
        # The same context is shared by every evaluation; only ``record`` changes per iteration.
        interpolation_context = {
            "stream_state": stream_state,
            "stream_slice": stream_slice,
            "next_page_token": next_page_token,
        }
        kept = []
        for candidate in records:
            if self._filter_interpolator.eval(self._config, record=candidate, **interpolation_context):
                kept.append(candidate)
        return kept

View File

@@ -0,0 +1,36 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
from typing import Any, List, Mapping
import requests
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.types import Record
class RecordSelector(HttpSelector):
    """
    Responsible for translating an HTTP response into a list of records by extracting records from the response and optionally filtering
    records based on a heuristic.
    """

    def __init__(self, extractor: JelloExtractor, record_filter: RecordFilter = None):
        """
        :param extractor: Component whose ``extract_records(response)`` returns the records contained in a response
        :param record_filter: Optional filter applied to the extracted records; when None, no filtering is performed
        """
        self._extractor = extractor
        self._record_filter = record_filter

    def select_records(
        self,
        response: requests.Response,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Extract the records from the response and, when a record filter is configured, filter them.

        :param response: The HTTP response to extract records from
        :param stream_state: Forwarded to the record filter
        :param stream_slice: Forwarded to the record filter
        :param next_page_token: Forwarded to the record filter
        :return: All extracted records when no filter is configured, otherwise the filter's output
        """
        all_records = self._extractor.extract_records(response)
        # Test for None explicitly rather than by truthiness, so a configured filter
        # object that happens to evaluate falsy is still applied.
        if self._record_filter is None:
            return all_records
        return self._record_filter.filter_records(
            all_records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
        )

View File

@@ -15,8 +15,8 @@ class ConditionalPaginator:
A paginator that performs pagination by incrementing a page number and stops based on a provided stop condition.
"""
def __init__(self, stop_condition_template: str, state: DictState, decoder: Decoder, config):
self._stop_condition_template = InterpolatedBoolean(stop_condition_template)
def __init__(self, stop_condition: str, state: DictState, decoder: Decoder, config):
self._stop_condition_interpolator = InterpolatedBoolean(stop_condition)
self._state: DictState = state
self._decoder = decoder
self._config = config
@@ -24,7 +24,7 @@ class ConditionalPaginator:
def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
decoded_response = self._decoder.decode(response)
headers = response.headers
should_stop = self._stop_condition_template.eval(
should_stop = self._stop_condition_interpolator.eval(
self._config, decoded_response=decoded_response, headers=headers, last_records=last_records
)

View File

@@ -6,7 +6,7 @@ from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
import requests
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
from airbyte_cdk.sources.declarative.requesters.requester import Requester
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
@@ -22,7 +22,7 @@ class SimpleRetriever(Retriever, HttpStream):
primary_key,
requester: Requester,
paginator: Paginator,
extractor: HttpExtractor,
record_selector: HttpSelector,
stream_slicer: StreamSlicer,
state: State,
):
@@ -30,7 +30,7 @@ class SimpleRetriever(Retriever, HttpStream):
self._primary_key = primary_key
self._paginator = paginator
self._requester = requester
self._extractor = extractor
self._record_selector = record_selector
super().__init__(self._requester.get_authenticator())
self._iterator: StreamSlicer = stream_slicer
self._state: State = state.deep_copy()
@@ -190,7 +190,9 @@ class SimpleRetriever(Retriever, HttpStream):
next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
self._last_response = response
records = self._extractor.extract_records(response)
records = self._record_selector.select_records(
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)
self._last_records = records
return records

View File

@@ -0,0 +1,48 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import pytest
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
@pytest.mark.parametrize(
    "test_name, filter_template, records, expected_records",
    [
        (
            "test_using_state_filter",
            "{{ record['created_at'] > stream_state['created_at'] }}",
            [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
            [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
        ),
        (
            "test_with_slice_filter",
            "{{ record['last_seen'] >= stream_slice['last_seen'] }}",
            [{"id": 1, "last_seen": "06-06-21"}, {"id": 2, "last_seen": "06-07-21"}, {"id": 3, "last_seen": "06-10-21"}],
            [{"id": 3, "last_seen": "06-10-21"}],
        ),
        (
            "test_with_next_page_token_filter",
            "{{ record['id'] >= next_page_token['last_seen_id'] }}",
            [{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
            [{"id": 14}, {"id": 15}],
        ),
        (
            "test_missing_filter_fields_return_no_results",
            "{{ record['id'] >= next_page_token['path_to_nowhere'] }}",
            [{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
            [],
        ),
    ],
)
def test_record_filter(test_name, filter_template, records, expected_records):
    """Check that RecordFilter returns exactly the expected records for each condition template."""
    # Fixed interpolation inputs shared by every parametrized case.
    interpolation_inputs = {
        "stream_state": {"created_at": "06-06-21"},
        "stream_slice": {"last_seen": "06-10-21"},
        "next_page_token": {"last_seen_id": 14},
    }
    record_filter = RecordFilter(config={"response_override": "stop_if_you_see_me"}, condition=filter_template)

    assert record_filter.filter_records(records, **interpolation_inputs) == expected_records

View File

@@ -0,0 +1,58 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import json
import pytest
import requests
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
@pytest.mark.parametrize(
    "test_name, transform_template, filter_template, body, expected_records",
    [
        (
            "test_with_extractor_and_filter",
            "_.data",
            "{{ record['created_at'] > stream_state['created_at'] }}",
            {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}]},
            [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
        ),
        (
            "test_no_record_filter_returns_all_records",
            "_.data",
            None,
            {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}]},
            [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}],
        ),
    ],
)
def test_record_filter(test_name, transform_template, filter_template, body, expected_records):
    """Check that RecordSelector extracts records from a response and applies the record filter only when one is given."""
    config = {"response_override": "stop_if_you_see_me"}
    extractor = JelloExtractor(transform=transform_template, decoder=JsonDecoder(), config=config, kwargs={})
    # A None template means the selector is built without any record filter.
    record_filter = None if filter_template is None else RecordFilter(config=config, condition=filter_template)
    record_selector = RecordSelector(extractor=extractor, record_filter=record_filter)

    actual_records = record_selector.select_records(
        response=create_response(body),
        stream_state={"created_at": "06-06-21"},
        stream_slice={"last_seen": "06-10-21"},
        next_page_token={"last_seen_id": 14},
    )

    assert actual_records == expected_records
def create_response(body):
    """Build a requests.Response whose content is ``body`` serialized as UTF-8 encoded JSON."""
    payload = json.dumps(body).encode("utf-8")
    fake_response = requests.Response()
    # Set the private content attribute directly so no HTTP request is needed.
    fake_response._content = payload
    return fake_response

View File

@@ -22,8 +22,8 @@ def test():
next_page_token = {"cursor": "cursor_value"}
paginator.next_page_token.return_value = next_page_token
extractor = MagicMock()
extractor.extract_records.return_value = records
record_selector = MagicMock()
record_selector.select_records.return_value = records
iterator = MagicMock()
stream_slices = [{"date": "2022-01-01"}, {"date": "2022-01-02"}]
@@ -62,7 +62,7 @@ def test():
use_cache = True
requester.use_cache = use_cache
retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, extractor, iterator, state)
retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, record_selector, iterator, state)
# hack because we clone the state...
retriever._state = state

View File

@@ -4,6 +4,8 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.parsers.factory import DeclarativeComponentFactory
from airbyte_cdk.sources.declarative.parsers.yaml_parser import YamlParser
from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_options_provider import (
@@ -86,6 +88,13 @@ decoder:
extractor:
class_name: airbyte_cdk.sources.declarative.extractors.jello.JelloExtractor
decoder: "*ref(decoder)"
selector:
class_name: airbyte_cdk.sources.declarative.extractors.record_selector.RecordSelector
extractor:
decoder: "*ref(decoder)"
record_filter:
class_name: airbyte_cdk.sources.declarative.extractors.record_filter.RecordFilter
condition: "{{ record['id'] > stream_state['id'] }}"
metadata_paginator:
class_name: "airbyte_cdk.sources.declarative.requesters.paginators.next_page_url_paginator.NextPageUrlPaginator"
next_page_token_template:
@@ -139,6 +148,8 @@ list_stream:
default: "marketing/lists"
paginator:
ref: "*ref(metadata_paginator)"
record_selector:
ref: "*ref(selector)"
check:
class_name: airbyte_cdk.sources.declarative.checks.check_stream.CheckStream
stream_names: ["list_stream"]
@@ -156,8 +167,11 @@ check:
assert type(stream._retriever) == SimpleRetriever
assert stream._retriever._requester._method == HttpMethod.GET
assert stream._retriever._requester._authenticator._tokens == ["verysecrettoken"]
assert type(stream._retriever._extractor._decoder) == JsonDecoder
assert stream._retriever._extractor._transform == ".result[]"
assert type(stream._retriever._record_selector) == RecordSelector
assert type(stream._retriever._record_selector._extractor._decoder) == JsonDecoder
assert stream._retriever._record_selector._extractor._transform == ".result[]"
assert type(stream._retriever._record_selector._record_filter) == RecordFilter
assert stream._retriever._record_selector._record_filter._filter_interpolator._condition == "{{ record['id'] > stream_state['id'] }}"
assert stream._schema_loader._file_path._string == "./source_sendgrid/schemas/lists.json"
checker = factory.create_component(config["check"], input_config)()