
Low code CDK: Fix mypy errors (#28386)

* ignore unit tests in mypy check

* Update airbyte-cdk/python/bin/run-mypy-on-modified-files.sh

Co-authored-by: Alexandre Girard <alexandre@airbyte.io>

* work through mypy errors

* fix a bunch of stuff

* fix more type hints

* fix model_to_component_factory types

* format

* ignore list instead of allow list

---------

Co-authored-by: Alexandre Girard <alexandre@airbyte.io>
Joe Reuter authored on 2023-07-19 15:08:35 +02:00 (committed by GitHub)
parent db16853fd8
commit 78728410f4
3 changed files with 212 additions and 176 deletions
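
Most of the changes below follow a few recurring patterns: arguments that default to None gain an explicit Optional[...] annotation, Optional model fields are coerced with an `or` fallback before being passed to parameters that are not Optional, and a handful of spots are silenced with `# type: ignore` where a full fix would pull in too many dependencies. A minimal sketch of the first two patterns, using illustrative names rather than CDK code:

from typing import Any, Mapping, Optional

# Under mypy's strict optional checking, `limit: int = None` is rejected;
# a default of None has to be reflected in the annotation.
def build(limit: Optional[int] = None, parameters: Optional[Mapping[str, Any]] = None) -> int:
    # Coercing with `or` turns the Optional value into the plain type the callee expects.
    resolved: Mapping[str, Any] = parameters or {}
    return limit if limit is not None else len(resolved)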

Diff for the file defining ModelToComponentFactory:

@@ -7,7 +7,7 @@ from __future__ import annotations
import importlib
import inspect
import re
from typing import Any, Callable, List, Literal, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints
from typing import Any, Callable, List, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints
from airbyte_cdk.models import Level
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator
@@ -110,7 +110,7 @@ from airbyte_cdk.sources.declarative.types import Config
from airbyte_cdk.sources.message import InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, MessageRepository
from pydantic import BaseModel
ComponentDefinition: Union[Literal, Mapping, List]
ComponentDefinition = Mapping[str, Any]
DEFAULT_BACKOFF_STRATEGY = ExponentialBackoffStrategy
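
For context on the ComponentDefinition change above: the removed line was a bare annotation with an invalid type expression (Literal needs arguments), so the name was never actually bound; the replacement is an ordinary type alias. A small illustration of the difference, standard library only:

from typing import Any, Mapping

# A bare annotation only declares a variable and binds nothing at runtime:
#     ComponentDefinition: Mapping[str, Any]
# Assigning the type expression creates an alias usable in signatures:
ComponentDefinition = Mapping[str, Any]

def parse_definition(definition: ComponentDefinition) -> None:
    ...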
@@ -119,23 +119,23 @@ DEFAULT_BACKOFF_STRATEGY = ExponentialBackoffStrategy
class ModelToComponentFactory:
def __init__(
self,
limit_pages_fetched_per_slice: int = None,
limit_slices_fetched: int = None,
limit_pages_fetched_per_slice: Optional[int] = None,
limit_slices_fetched: Optional[int] = None,
emit_connector_builder_messages: bool = False,
disable_retries: bool = False,
message_repository: MessageRepository = None,
message_repository: Optional[MessageRepository] = None,
):
self._init_mappings()
self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
self._limit_slices_fetched = limit_slices_fetched
self._emit_connector_builder_messages = emit_connector_builder_messages
self._disable_retries = disable_retries
self._message_repository = message_repository or InMemoryMessageRepository(
self._message_repository = message_repository or InMemoryMessageRepository( # type: ignore
self._evaluate_log_level(emit_connector_builder_messages)
)
def _init_mappings(self):
self.PYDANTIC_MODEL_TO_CONSTRUCTOR: [Type[BaseModel], Callable] = {
def _init_mappings(self) -> None:
self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = {
AddedFieldDefinitionModel: self.create_added_field_definition,
AddFieldsModel: self.create_add_fields,
ApiKeyAuthenticatorModel: self.create_api_key_authenticator,
@@ -190,7 +190,9 @@ class ModelToComponentFactory:
# Needed for the case where we need to perform a second parse on the fields of a custom component
self.TYPE_NAME_TO_MODEL = {cls.__name__: cls for cls in self.PYDANTIC_MODEL_TO_CONSTRUCTOR}
def create_component(self, model_type: Type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs) -> type:
def create_component(
self, model_type: Type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs: Any
) -> Any:
"""
Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and
subcomponents which will be used at runtime. This is done by first parsing the mapping into a Pydantic model and then creating
@@ -213,26 +215,28 @@ class ModelToComponentFactory:
return self._create_component_from_model(model=declarative_component_model, config=config, **kwargs)
def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs) -> Any:
def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any:
if model.__class__ not in self.PYDANTIC_MODEL_TO_CONSTRUCTOR:
raise ValueError(f"{model.__class__} with attributes {model} is not a valid component type")
component_constructor = self.PYDANTIC_MODEL_TO_CONSTRUCTOR.get(model.__class__)
if not component_constructor:
raise ValueError(f"Could not find constructor for {model.__class__}")
return component_constructor(model=model, config=config, **kwargs)
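
_create_component_from_model dispatches through PYDANTIC_MODEL_TO_CONSTRUCTOR, which is why that attribute is now annotated as Mapping[Type[BaseModel], Callable[..., Any]]. A self-contained sketch of the same lookup-and-dispatch pattern with toy models, not the real CDK ones:

from typing import Any, Callable, Mapping, Type

from pydantic import BaseModel

class WidgetModel(BaseModel):
    name: str

def create_widget(model: WidgetModel, config: Mapping[str, Any], **kwargs: Any) -> str:
    return f"widget:{model.name}"

CONSTRUCTORS: Mapping[Type[BaseModel], Callable[..., Any]] = {WidgetModel: create_widget}

def create(model: BaseModel, config: Mapping[str, Any]) -> Any:
    constructor = CONSTRUCTORS.get(model.__class__)
    if constructor is None:
        raise ValueError(f"{model.__class__} is not a valid component type")
    return constructor(model=model, config=config)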
@staticmethod
def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs) -> AddedFieldDefinition:
interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters)
return AddedFieldDefinition(path=model.path, value=interpolated_value, parameters=model.parameters)
def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs: Any) -> AddedFieldDefinition:
interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters or {})
return AddedFieldDefinition(path=model.path, value=interpolated_value, parameters=model.parameters or {})
def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs) -> AddFields:
def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields:
added_field_definitions = [
self._create_component_from_model(model=added_field_definition_model, config=config)
for added_field_definition_model in model.fields
]
return AddFields(fields=added_field_definitions, parameters=model.parameters)
return AddFields(fields=added_field_definitions, parameters=model.parameters or {})
@staticmethod
def create_api_key_authenticator(model: ApiKeyAuthenticatorModel, config: Config, **kwargs) -> ApiKeyAuthenticator:
def create_api_key_authenticator(model: ApiKeyAuthenticatorModel, config: Config, **kwargs: Any) -> ApiKeyAuthenticator:
if model.inject_into is None and model.header is None:
raise ValueError("Expected either inject_into or header to be set for ApiKeyAuthenticator")
@@ -243,52 +247,56 @@ class ModelToComponentFactory:
RequestOption(
inject_into=RequestOptionType(model.inject_into.inject_into.value),
field_name=model.inject_into.field_name,
parameters=model.parameters,
parameters=model.parameters or {},
)
if model.inject_into
else RequestOption(
inject_into=RequestOptionType.header,
field_name=model.header,
parameters=model.parameters,
field_name=model.header or "",
parameters=model.parameters or {},
)
)
return ApiKeyAuthenticator(api_token=model.api_token, request_option=request_option, config=config, parameters=model.parameters)
return ApiKeyAuthenticator(
api_token=model.api_token or "", request_option=request_option, config=config, parameters=model.parameters or {}
)
@staticmethod
def create_basic_http_authenticator(model: BasicHttpAuthenticatorModel, config: Config, **kwargs) -> BasicHttpAuthenticator:
return BasicHttpAuthenticator(password=model.password, username=model.username, config=config, parameters=model.parameters)
def create_basic_http_authenticator(model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any) -> BasicHttpAuthenticator:
return BasicHttpAuthenticator(
password=model.password or "", username=model.username, config=config, parameters=model.parameters or {}
)
@staticmethod
def create_bearer_authenticator(model: BearerAuthenticatorModel, config: Config, **kwargs) -> BearerAuthenticator:
def create_bearer_authenticator(model: BearerAuthenticatorModel, config: Config, **kwargs: Any) -> BearerAuthenticator:
return BearerAuthenticator(
api_token=model.api_token,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
@staticmethod
def create_check_stream(model: CheckStreamModel, config: Config, **kwargs):
def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream:
return CheckStream(stream_names=model.stream_names, parameters={})
def create_composite_error_handler(self, model: CompositeErrorHandlerModel, config: Config, **kwargs) -> CompositeErrorHandler:
def create_composite_error_handler(self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any) -> CompositeErrorHandler:
error_handlers = [
self._create_component_from_model(model=error_handler_model, config=config) for error_handler_model in model.error_handlers
]
return CompositeErrorHandler(error_handlers=error_handlers, parameters=model.parameters)
return CompositeErrorHandler(error_handlers=error_handlers, parameters=model.parameters or {})
@staticmethod
def create_constant_backoff_strategy(model: ConstantBackoffStrategyModel, config: Config, **kwargs) -> ConstantBackoffStrategy:
def create_constant_backoff_strategy(model: ConstantBackoffStrategyModel, config: Config, **kwargs: Any) -> ConstantBackoffStrategy:
return ConstantBackoffStrategy(
backoff_time_in_seconds=model.backoff_time_in_seconds,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def create_cursor_pagination(self, model: CursorPaginationModel, config: Config, **kwargs) -> CursorPaginationStrategy:
def create_cursor_pagination(self, model: CursorPaginationModel, config: Config, **kwargs: Any) -> CursorPaginationStrategy:
if model.decoder:
decoder = self._create_component_from_model(model=model.decoder, config=config)
else:
decoder = JsonDecoder(parameters=model.parameters)
decoder = JsonDecoder(parameters=model.parameters or {})
return CursorPaginationStrategy(
cursor_value=model.cursor_value,
@@ -296,10 +304,10 @@ class ModelToComponentFactory:
page_size=model.page_size,
stop_condition=model.stop_condition,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def create_custom_component(self, model, config: Config, **kwargs) -> type:
def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any:
"""
Generically creates a custom component based on the model type and a class_name reference to the custom Python class being
instantiated. Only the model's additional properties that match the custom class definition are passed to the constructor
@@ -347,14 +355,14 @@ class ModelToComponentFactory:
return custom_component_class(**kwargs)
@staticmethod
def _get_class_from_fully_qualified_class_name(class_name: str) -> type:
def _get_class_from_fully_qualified_class_name(class_name: str) -> Any:
split = class_name.split(".")
module = ".".join(split[:-1])
class_name = split[-1]
return getattr(importlib.import_module(module), class_name)
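
The helper above resolves a dotted path at runtime, so the most precise return annotation available is Any. A quick standalone illustration using a standard-library class (the class name is just an example):

import importlib
from typing import Any

def load_class(fully_qualified_class_name: str) -> Any:
    split = fully_qualified_class_name.split(".")
    module = ".".join(split[:-1])
    return getattr(importlib.import_module(module), split[-1])

ordered_dict_cls = load_class("collections.OrderedDict")  # <class 'collections.OrderedDict'>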
@staticmethod
def _derive_component_type_from_type_hints(field_type: str) -> Optional[str]:
def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]:
interface = field_type
while True:
origin = get_origin(interface)
@@ -371,7 +379,7 @@ class ModelToComponentFactory:
return None
@staticmethod
def is_builtin_type(cls) -> bool:
def is_builtin_type(cls: Optional[Type[Any]]) -> bool:
if not cls:
return False
return cls.__module__ == "builtins"
@@ -384,7 +392,7 @@ class ModelToComponentFactory:
else:
return []
def _create_nested_component(self, model, model_field: str, model_value: Any, config: Config) -> Any:
def _create_nested_component(self, model: Any, model_field: str, model_value: Any, config: Config) -> Any:
type_name = model_value.get("type", None)
if not type_name:
# If no type is specified, we can assume this is a dictionary object which can be returned instead of a subcomponent
@@ -420,13 +428,13 @@ class ModelToComponentFactory:
@staticmethod
def _is_component(model_value: Any) -> bool:
return isinstance(model_value, dict) and model_value.get("type")
return isinstance(model_value, dict) and model_value.get("type") is not None
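
The explicit `is not None` matters because `a and b` evaluates to one of its operands: without the comparison the expression can evaluate to whatever dict.get returned (typed Any), which a strict mypy configuration (e.g. warn_return_any) flags in a function declared to return bool. The same reasoning is behind the bool(...) wrapper added to should_retry in the second file. Illustrative sketch:

from typing import Any

def has_type_loose(value: Any) -> bool:
    # evaluates to False or to the stored "type" value (typed Any) -> flagged under strict settings
    return isinstance(value, dict) and value.get("type")

def has_type_strict(value: Any) -> bool:
    # the comparison makes the whole expression a genuine bool
    return isinstance(value, dict) and value.get("type") is not None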
def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: Config, **kwargs) -> DatetimeBasedCursor:
start_datetime = (
def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any) -> DatetimeBasedCursor:
start_datetime: Union[str, MinMaxDatetime] = (
model.start_datetime if isinstance(model.start_datetime, str) else self.create_min_max_datetime(model.start_datetime, config)
)
end_datetime = None
end_datetime: Union[str, MinMaxDatetime, None] = None
if model.is_data_feed and model.end_datetime:
raise ValueError("Data feed does not support end_datetime")
if model.end_datetime:
@@ -438,7 +446,7 @@ class ModelToComponentFactory:
RequestOption(
inject_into=RequestOptionType(model.end_time_option.inject_into.value),
field_name=model.end_time_option.field_name,
parameters=model.parameters,
parameters=model.parameters or {},
)
if model.end_time_option
else None
@@ -447,7 +455,7 @@ class ModelToComponentFactory:
RequestOption(
inject_into=RequestOptionType(model.start_time_option.inject_into.value),
field_name=model.start_time_option.field_name,
parameters=model.parameters,
parameters=model.parameters or {},
)
if model.start_time_option
else None
@@ -467,10 +475,10 @@ class ModelToComponentFactory:
partition_field_start=model.partition_field_start,
message_repository=self._message_repository,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def create_declarative_stream(self, model: DeclarativeStreamModel, config: Config, **kwargs) -> DeclarativeStream:
def create_declarative_stream(self, model: DeclarativeStreamModel, config: Config, **kwargs: Any) -> DeclarativeStream:
# When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field
# components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the
# Retriever. This is done in the declarative stream not the retriever to support custom retrievers. The custom create methods in
@@ -505,31 +513,31 @@ class ModelToComponentFactory:
schema_loader = DefaultSchemaLoader(config=config, parameters=options)
return DeclarativeStream(
name=model.name,
name=model.name or "",
primary_key=primary_key,
retriever=retriever,
schema_loader=schema_loader,
stream_cursor_field=cursor_field or "",
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def _merge_stream_slicers(self, model: DeclarativeStreamModel, config: Config) -> Optional[StreamSlicer]:
stream_slicer = None
if hasattr(model.retriever, "partition_router") and model.retriever.partition_router:
stream_slicer_model = model.retriever.partition_router
stream_slicer = (
CartesianProductStreamSlicer(
if isinstance(stream_slicer_model, list):
stream_slicer = CartesianProductStreamSlicer(
[self._create_component_from_model(model=slicer, config=config) for slicer in stream_slicer_model], parameters={}
)
if type(stream_slicer_model) == list
else self._create_component_from_model(model=stream_slicer_model, config=config)
)
else:
stream_slicer = self._create_component_from_model(model=stream_slicer_model, config=config)
if model.incremental_sync and stream_slicer:
incremental_sync_model = model.incremental_sync
return PerPartitionCursor(
cursor_factory=CursorFactory(
lambda: self._create_component_from_model(model=model.incremental_sync, config=config),
lambda: self._create_component_from_model(model=incremental_sync_model, config=config),
),
partition_router=stream_slicer,
)
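
Binding model.incremental_sync to a local before building the lambda lets mypy keep the non-None narrowing inside the closure (narrowing of an attribute does not survive into a lambda body), and it also gives the closure a stable reference. A minimal sketch of the pattern, unrelated to the CDK types:

from typing import Callable, Optional

class Holder:
    value: Optional[int] = None

def make_getter(holder: Holder) -> Callable[[], int]:
    if holder.value is not None:
        narrowed = holder.value  # narrowed to int here; the lambda keeps that type
        return lambda: narrowed
    raise ValueError("value is not set")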
@@ -540,13 +548,13 @@ class ModelToComponentFactory:
else:
return None
def create_default_error_handler(self, model: DefaultErrorHandlerModel, config: Config, **kwargs) -> DefaultErrorHandler:
def create_default_error_handler(self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any) -> DefaultErrorHandler:
backoff_strategies = []
if model.backoff_strategies:
for backoff_strategy_model in model.backoff_strategies:
backoff_strategies.append(self._create_component_from_model(model=backoff_strategy_model, config=config))
else:
backoff_strategies.append(DEFAULT_BACKOFF_STRATEGY(config=config, parameters=model.parameters))
backoff_strategies.append(DEFAULT_BACKOFF_STRATEGY(config=config, parameters=model.parameters or {}))
response_filters = []
if model.response_filters:
@@ -555,22 +563,25 @@ class ModelToComponentFactory:
else:
response_filters.append(
HttpResponseFilter(
ResponseAction.RETRY, http_codes=HttpResponseFilter.DEFAULT_RETRIABLE_ERRORS, config=config, parameters=model.parameters
ResponseAction.RETRY,
http_codes=HttpResponseFilter.DEFAULT_RETRIABLE_ERRORS,
config=config,
parameters=model.parameters or {},
)
)
response_filters.append(HttpResponseFilter(ResponseAction.IGNORE, config=config, parameters=model.parameters))
response_filters.append(HttpResponseFilter(ResponseAction.IGNORE, config=config, parameters=model.parameters or {}))
return DefaultErrorHandler(
backoff_strategies=backoff_strategies,
max_retries=model.max_retries,
response_filters=response_filters,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def create_default_paginator(
self, model: DefaultPaginatorModel, config: Config, *, url_base: str, cursor_used_for_stop_condition: Optional[Cursor] = None
) -> DefaultPaginator:
) -> Union[DefaultPaginator, PaginatorTestReadDecorator]:
decoder = self._create_component_from_model(model=model.decoder, config=config) if model.decoder else JsonDecoder(parameters={})
page_size_option = (
self._create_component_from_model(model=model.page_size_option, config=config) if model.page_size_option else None
@@ -591,19 +602,20 @@ class ModelToComponentFactory:
pagination_strategy=pagination_strategy,
url_base=url_base,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
if self._limit_pages_fetched_per_slice:
return PaginatorTestReadDecorator(paginator, self._limit_pages_fetched_per_slice)
return paginator
def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs) -> DpathExtractor:
def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs: Any) -> DpathExtractor:
decoder = self._create_component_from_model(model.decoder, config=config) if model.decoder else JsonDecoder(parameters={})
return DpathExtractor(decoder=decoder, field_path=model.field_path, config=config, parameters=model.parameters)
model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path]
return DpathExtractor(decoder=decoder, field_path=model_field_path, config=config, parameters=model.parameters or {})
@staticmethod
def create_exponential_backoff_strategy(model: ExponentialBackoffStrategyModel, config: Config) -> ExponentialBackoffStrategy:
return ExponentialBackoffStrategy(factor=model.factor, parameters=model.parameters, config=config)
return ExponentialBackoffStrategy(factor=model.factor or 5, parameters=model.parameters or {}, config=config)
def create_http_requester(self, model: HttpRequesterModel, config: Config, *, name: str) -> HttpRequester:
authenticator = (
@@ -614,7 +626,7 @@ class ModelToComponentFactory:
error_handler = (
self._create_component_from_model(model=model.error_handler, config=config)
if model.error_handler
else DefaultErrorHandler(backoff_strategies=[], response_filters=[], config=config, parameters=model.parameters)
else DefaultErrorHandler(backoff_strategies=[], response_filters=[], config=config, parameters=model.parameters or {})
)
request_options_provider = InterpolatedRequestOptionsProvider(
@@ -623,7 +635,11 @@ class ModelToComponentFactory:
request_headers=model.request_headers,
request_parameters=model.request_parameters,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
model_http_method = (
model.http_method if isinstance(model.http_method, str) else model.http_method.value if model.http_method is not None else "GET"
)
return HttpRequester(
@@ -632,14 +648,14 @@ class ModelToComponentFactory:
path=model.path,
authenticator=authenticator,
error_handler=error_handler,
http_method=model.http_method,
http_method=model_http_method,
request_options_provider=request_options_provider,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
@staticmethod
def create_http_response_filter(model: HttpResponseFilterModel, config: Config, **kwargs) -> HttpResponseFilter:
def create_http_response_filter(model: HttpResponseFilterModel, config: Config, **kwargs: Any) -> HttpResponseFilter:
action = ResponseAction(model.action.value)
http_codes = (
set(model.http_codes) if model.http_codes else set()
@@ -648,32 +664,32 @@ class ModelToComponentFactory:
return HttpResponseFilter(
action=action,
error_message=model.error_message or "",
error_message_contains=model.error_message_contains,
error_message_contains=model.error_message_contains or "",
http_codes=http_codes,
predicate=model.predicate or "",
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
@staticmethod
def create_inline_schema_loader(model: InlineSchemaLoaderModel, config: Config, **kwargs) -> InlineSchemaLoader:
return InlineSchemaLoader(schema=model.schema_, parameters={})
def create_inline_schema_loader(model: InlineSchemaLoaderModel, config: Config, **kwargs: Any) -> InlineSchemaLoader:
return InlineSchemaLoader(schema=model.schema_ or {}, parameters={})
@staticmethod
def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs) -> JsonDecoder:
def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs: Any) -> JsonDecoder:
return JsonDecoder(parameters={})
@staticmethod
def create_json_file_schema_loader(model: JsonFileSchemaLoaderModel, config: Config, **kwargs) -> JsonFileSchemaLoader:
return JsonFileSchemaLoader(file_path=model.file_path, config=config, parameters=model.parameters)
def create_json_file_schema_loader(model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any) -> JsonFileSchemaLoader:
return JsonFileSchemaLoader(file_path=model.file_path or "", config=config, parameters=model.parameters or {})
@staticmethod
def create_list_partition_router(model: ListPartitionRouterModel, config: Config, **kwargs) -> ListPartitionRouter:
def create_list_partition_router(model: ListPartitionRouterModel, config: Config, **kwargs: Any) -> ListPartitionRouter:
request_option = (
RequestOption(
inject_into=RequestOptionType(model.request_option.inject_into.value),
field_name=model.request_option.field_name,
parameters=model.parameters,
parameters=model.parameters or {},
)
if model.request_option
else None
@@ -683,72 +699,78 @@ class ModelToComponentFactory:
request_option=request_option,
values=model.values,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
@staticmethod
def create_min_max_datetime(model: MinMaxDatetimeModel, config: Config, **kwargs) -> MinMaxDatetime:
def create_min_max_datetime(model: MinMaxDatetimeModel, config: Config, **kwargs: Any) -> MinMaxDatetime:
return MinMaxDatetime(
datetime=model.datetime,
datetime_format=model.datetime_format,
max_datetime=model.max_datetime,
min_datetime=model.min_datetime,
parameters=model.parameters,
datetime_format=model.datetime_format or "",
max_datetime=model.max_datetime or "",
min_datetime=model.min_datetime or "",
parameters=model.parameters or {},
)
@staticmethod
def create_no_auth(model: NoAuthModel, config: Config, **kwargs) -> NoAuth:
return NoAuth(parameters=model.parameters)
def create_no_auth(model: NoAuthModel, config: Config, **kwargs: Any) -> NoAuth:
return NoAuth(parameters=model.parameters or {})
@staticmethod
def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs) -> NoPagination:
def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs: Any) -> NoPagination:
return NoPagination(parameters={})
def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Config, **kwargs) -> DeclarativeOauth2Authenticator:
def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any) -> DeclarativeOauth2Authenticator:
if model.refresh_token_updater:
return DeclarativeSingleUseRefreshTokenOauth2Authenticator(
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
config,
InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters).eval(config),
access_token_name=InterpolatedString.create(model.access_token_name, parameters=model.parameters).eval(config),
InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters or {}).eval(config),
access_token_name=InterpolatedString.create(
model.access_token_name or "access_token", parameters=model.parameters or {}
).eval(config),
refresh_token_name=model.refresh_token_updater.refresh_token_name,
expires_in_name=InterpolatedString.create(model.expires_in_name, parameters=model.parameters).eval(config),
client_id=InterpolatedString.create(model.client_id, parameters=model.parameters).eval(config),
client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters).eval(config),
expires_in_name=InterpolatedString.create(model.expires_in_name or "expires_in", parameters=model.parameters or {}).eval(
config
),
client_id=InterpolatedString.create(model.client_id, parameters=model.parameters or {}).eval(config),
client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters or {}).eval(config),
access_token_config_path=model.refresh_token_updater.access_token_config_path,
refresh_token_config_path=model.refresh_token_updater.refresh_token_config_path,
token_expiry_date_config_path=model.refresh_token_updater.token_expiry_date_config_path,
grant_type=InterpolatedString.create(model.grant_type, parameters=model.parameters).eval(config),
refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters).eval(config),
grant_type=InterpolatedString.create(model.grant_type or "refresh_token", parameters=model.parameters or {}).eval(config),
refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters or {}).eval(config),
scopes=model.scopes,
token_expiry_date_format=model.token_expiry_date_format,
message_repository=self._message_repository,
)
return DeclarativeOauth2Authenticator(
access_token_name=model.access_token_name,
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeOauth2Authenticator( # type: ignore
access_token_name=model.access_token_name or "access_token",
client_id=model.client_id,
client_secret=model.client_secret,
expires_in_name=model.expires_in_name,
grant_type=model.grant_type,
expires_in_name=model.expires_in_name or "expires_in",
grant_type=model.grant_type or "refresh_token",
refresh_request_body=model.refresh_request_body,
refresh_token=model.refresh_token,
scopes=model.scopes,
token_expiry_date=model.token_expiry_date,
token_expiry_date_format=model.token_expiry_date_format,
token_expiry_date_format=model.token_expiry_date_format, # type: ignore
token_refresh_endpoint=model.token_refresh_endpoint,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
message_repository=self._message_repository,
)
@staticmethod
def create_offset_increment(model: OffsetIncrementModel, config: Config, **kwargs) -> OffsetIncrement:
return OffsetIncrement(page_size=model.page_size, config=config, parameters=model.parameters)
def create_offset_increment(model: OffsetIncrementModel, config: Config, **kwargs: Any) -> OffsetIncrement:
return OffsetIncrement(page_size=model.page_size, config=config, parameters=model.parameters or {})
@staticmethod
def create_page_increment(model: PageIncrementModel, config: Config, **kwargs) -> PageIncrement:
return PageIncrement(page_size=model.page_size, start_from_page=model.start_from_page, parameters=model.parameters)
def create_page_increment(model: PageIncrementModel, config: Config, **kwargs: Any) -> PageIncrement:
return PageIncrement(page_size=model.page_size, start_from_page=model.start_from_page or 0, parameters=model.parameters or {})
def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Config, **kwargs) -> ParentStreamConfig:
def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Config, **kwargs: Any) -> ParentStreamConfig:
declarative_stream = self._create_component_from_model(model.stream, config=config)
request_option = self._create_component_from_model(model.request_option, config=config) if model.request_option else None
return ParentStreamConfig(
@@ -757,51 +779,55 @@ class ModelToComponentFactory:
stream=declarative_stream,
partition_field=model.partition_field,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
@staticmethod
def create_record_filter(model: RecordFilterModel, config: Config, **kwargs) -> RecordFilter:
return RecordFilter(condition=model.condition, config=config, parameters=model.parameters)
def create_record_filter(model: RecordFilterModel, config: Config, **kwargs: Any) -> RecordFilter:
return RecordFilter(condition=model.condition or "", config=config, parameters=model.parameters or {})
@staticmethod
def create_request_path(model: RequestPathModel, config: Config, **kwargs) -> RequestPath:
def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath:
return RequestPath(parameters={})
@staticmethod
def create_request_option(model: RequestOptionModel, config: Config, **kwargs) -> RequestOption:
def create_request_option(model: RequestOptionModel, config: Config, **kwargs: Any) -> RequestOption:
inject_into = RequestOptionType(model.inject_into.value)
return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={})
def create_record_selector(
self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs
self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs: Any
) -> RecordSelector:
extractor = self._create_component_from_model(model=model.extractor, config=config)
record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None
return RecordSelector(
extractor=extractor, config=config, record_filter=record_filter, transformations=transformations, parameters=model.parameters
extractor=extractor,
config=config,
record_filter=record_filter,
transformations=transformations,
parameters=model.parameters or {},
)
@staticmethod
def create_remove_fields(model: RemoveFieldsModel, config: Config, **kwargs) -> RemoveFields:
def create_remove_fields(model: RemoveFieldsModel, config: Config, **kwargs: Any) -> RemoveFields:
return RemoveFields(field_pointers=model.field_pointers, parameters={})
@staticmethod
def create_session_token_authenticator(
model: SessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs
model: SessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs: Any
) -> SessionTokenAuthenticator:
return SessionTokenAuthenticator(
api_url=url_base,
header=model.header,
login_url=model.login_url,
password=model.password,
session_token=model.session_token,
session_token_response_key=model.session_token_response_key,
username=model.username,
password=model.password or "",
session_token=model.session_token or "",
session_token_response_key=model.session_token_response_key or "",
username=model.username or "",
validate_session_url=model.validate_session_url,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
)
def create_simple_retriever(
@@ -840,8 +866,8 @@ class ModelToComponentFactory:
stream_slicer=stream_slicer,
cursor=cursor,
config=config,
maximum_number_of_slices=self._limit_slices_fetched,
parameters=model.parameters,
maximum_number_of_slices=self._limit_slices_fetched or 5,
parameters=model.parameters or {},
disable_retries=self._disable_retries,
message_repository=self._message_repository,
)
@@ -854,13 +880,13 @@ class ModelToComponentFactory:
stream_slicer=stream_slicer,
cursor=cursor,
config=config,
parameters=model.parameters,
parameters=model.parameters or {},
disable_retries=self._disable_retries,
message_repository=self._message_repository,
)
@staticmethod
def create_spec(model: SpecModel, config: Config, **kwargs) -> Spec:
def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec:
return Spec(
connection_specification=model.connection_specification,
documentation_url=model.documentation_url,
@@ -868,7 +894,9 @@ class ModelToComponentFactory:
parameters={},
)
def create_substream_partition_router(self, model: SubstreamPartitionRouterModel, config: Config, **kwargs) -> SubstreamPartitionRouter:
def create_substream_partition_router(
self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any
) -> SubstreamPartitionRouter:
parent_stream_configs = []
if model.parent_stream_configs:
parent_stream_configs.extend(
@@ -878,9 +906,9 @@ class ModelToComponentFactory:
]
)
return SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters=model.parameters, config=config)
return SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters=model.parameters or {}, config=config)
def _create_message_repository_substream_wrapper(self, model, config):
def _create_message_repository_substream_wrapper(self, model: ParentStreamConfigModel, config: Config) -> Any:
substream_factory = ModelToComponentFactory(
limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice,
limit_slices_fetched=self._limit_slices_fetched,
@@ -895,19 +923,19 @@ class ModelToComponentFactory:
return substream_factory._create_component_from_model(model=model, config=config)
@staticmethod
def create_wait_time_from_header(model: WaitTimeFromHeaderModel, config: Config, **kwargs) -> WaitTimeFromHeaderBackoffStrategy:
return WaitTimeFromHeaderBackoffStrategy(header=model.header, parameters=model.parameters, config=config, regex=model.regex)
def create_wait_time_from_header(model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any) -> WaitTimeFromHeaderBackoffStrategy:
return WaitTimeFromHeaderBackoffStrategy(header=model.header, parameters=model.parameters or {}, config=config, regex=model.regex)
@staticmethod
def create_wait_until_time_from_header(
model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs
model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs: Any
) -> WaitUntilTimeFromHeaderBackoffStrategy:
return WaitUntilTimeFromHeaderBackoffStrategy(
header=model.header, parameters=model.parameters, config=config, min_wait=model.min_wait, regex=model.regex
header=model.header, parameters=model.parameters or {}, config=config, min_wait=model.min_wait, regex=model.regex
)
def get_message_repository(self):
def get_message_repository(self) -> MessageRepository:
return self._message_repository
def _evaluate_log_level(self, emit_connector_builder_messages) -> Level:
def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level:
return Level.DEBUG if emit_connector_builder_messages else Level.INFO

Diff for the file defining SimpleRetriever:

@@ -4,7 +4,7 @@
from dataclasses import InitVar, dataclass, field
from itertools import islice
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union
import requests
from airbyte_cdk.models import AirbyteMessage, Level, SyncMode
@@ -61,25 +61,25 @@ class SimpleRetriever(Retriever, HttpStream):
primary_key: Optional[Union[str, List[str], List[List[str]]]]
_primary_key: str = field(init=False, repr=False, default="")
paginator: Optional[Paginator] = None
stream_slicer: Optional[StreamSlicer] = SinglePartitionRouter(parameters={})
stream_slicer: StreamSlicer = SinglePartitionRouter(parameters={})
cursor: Optional[Cursor] = None
disable_retries: bool = False
message_repository: MessageRepository = NoopMessageRepository()
def __post_init__(self, parameters: Mapping[str, Any]):
self.paginator = self.paginator or NoPagination(parameters=parameters)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._paginator = self.paginator or NoPagination(parameters=parameters)
self._last_response: Optional[requests.Response] = None
self._records_from_last_response: List[Record] = []
HttpStream.__init__(self, self.requester.get_authenticator())
self._last_response = None
self._records_from_last_response = None
self._parameters = parameters
self.name = InterpolatedString(self._name, parameters=parameters)
self._name = InterpolatedString(self._name, parameters=parameters) if isinstance(self._name, str) else self._name
@property
@property # type: ignore
def name(self) -> str:
"""
:return: Stream name
"""
return self._name.eval(self.config)
return str(self._name.eval(self.config)) if isinstance(self._name, InterpolatedString) else self._name
@name.setter
def name(self, value: str) -> None:
@@ -103,8 +103,9 @@ class SimpleRetriever(Retriever, HttpStream):
def max_retries(self) -> Union[int, None]:
if self.disable_retries:
return 0
if hasattr(self.requester.error_handler, "max_retries"):
return self.requester.error_handler.max_retries
# this will be removed once simple_retriever is decoupled from http_stream
if hasattr(self.requester.error_handler, "max_retries"): # type: ignore
return self.requester.error_handler.max_retries # type: ignore
return self._DEFAULT_MAX_RETRY
def should_retry(self, response: requests.Response) -> bool:
@@ -117,7 +118,7 @@ class SimpleRetriever(Retriever, HttpStream):
Unexpected but transient exceptions (connection timeout, DNS resolution failed, etc..) are retried by default.
"""
return self.requester.interpret_response_status(response).action == ResponseAction.RETRY
return bool(self.requester.interpret_response_status(response).action == ResponseAction.RETRY)
def backoff_time(self, response: requests.Response) -> Optional[float]:
"""
@@ -148,11 +149,11 @@ class SimpleRetriever(Retriever, HttpStream):
self,
stream_slice: Optional[StreamSlice],
next_page_token: Optional[Mapping[str, Any]],
requester_method,
paginator_method,
stream_slicer_method,
auth_options_method,
):
requester_method: Callable[..., Mapping[str, Any]],
paginator_method: Callable[..., Mapping[str, Any]],
stream_slicer_method: Callable[..., Mapping[str, Any]],
auth_options_method: Callable[..., Mapping[str, Any]],
) -> MutableMapping[str, Any]:
"""
Get the request_option from the requester and from the paginator
Raise a ValueError if there's a key collision
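
The body of _get_request_options is not part of this hunk, but the docstring describes merging several provider mappings and raising on duplicate keys. A rough sketch of that idea with plain dicts (an assumption for illustration, not the actual CDK implementation):

from typing import Any, Callable, Mapping, MutableMapping

def merge_request_options(*providers: Callable[[], Mapping[str, Any]]) -> MutableMapping[str, Any]:
    merged: MutableMapping[str, Any] = {}
    for provider in providers:
        options = provider() or {}
        duplicates = set(merged) & set(options)
        if duplicates:
            raise ValueError(f"Duplicate request option keys: {duplicates}")
        merged.update(options)
    return merged

# e.g. merge_request_options(lambda: {"page": 2}, lambda: {"per_page": 50}) -> {'page': 2, 'per_page': 50}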
@@ -197,7 +198,7 @@ class SimpleRetriever(Retriever, HttpStream):
stream_slice,
next_page_token,
self.requester.get_request_headers,
self.paginator.get_request_headers,
self._paginator.get_request_headers,
self.stream_slicer.get_request_headers,
# auth headers are handled separately by passing the authenticator to the HttpStream constructor
lambda: {},
@@ -219,7 +220,7 @@ class SimpleRetriever(Retriever, HttpStream):
stream_slice,
next_page_token,
self.requester.get_request_params,
self.paginator.get_request_params,
self._paginator.get_request_params,
self.stream_slicer.get_request_params,
self.requester.get_authenticator().get_request_params,
)
@@ -229,7 +230,7 @@ class SimpleRetriever(Retriever, HttpStream):
stream_state: StreamState,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Optional[Union[Mapping, str]]:
) -> Optional[Union[Mapping[str, Any], str]]:
"""
Specifies how to populate the body of the request with a non-JSON payload.
@@ -244,7 +245,7 @@ class SimpleRetriever(Retriever, HttpStream):
stream_state=self.state, stream_slice=stream_slice, next_page_token=next_page_token
)
if isinstance(base_body_data, str):
paginator_body_data = self.paginator.get_request_body_data()
paginator_body_data = self._paginator.get_request_body_data()
if paginator_body_data:
raise ValueError(
f"Cannot combine requester's body data= {base_body_data} with paginator's body_data: {paginator_body_data}"
@@ -254,10 +255,11 @@ class SimpleRetriever(Retriever, HttpStream):
return self._get_request_options(
stream_slice,
next_page_token,
self.requester.get_request_body_data,
self.paginator.get_request_body_data,
self.stream_slicer.get_request_body_data,
self.requester.get_authenticator().get_request_body_data,
# body data can be a string as well, this will be fixed in the rewrite using http requester instead of http stream
self.requester.get_request_body_data, # type: ignore
self._paginator.get_request_body_data, # type: ignore
self.stream_slicer.get_request_body_data, # type: ignore
self.requester.get_authenticator().get_request_body_data, # type: ignore
)
def request_body_json(
@@ -265,7 +267,7 @@ class SimpleRetriever(Retriever, HttpStream):
stream_state: StreamState,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Optional[Mapping]:
) -> Optional[Mapping[str, Any]]:
"""
Specifies how to populate the body of the request with a JSON payload.
@@ -275,10 +277,11 @@ class SimpleRetriever(Retriever, HttpStream):
return self._get_request_options(
stream_slice,
next_page_token,
self.requester.get_request_body_json,
self.paginator.get_request_body_json,
self.stream_slicer.get_request_body_json,
self.requester.get_authenticator().get_request_body_json,
# body json can be None as well, this will be fixed in the rewrite using http requester instead of http stream
self.requester.get_request_body_json, # type: ignore
self._paginator.get_request_body_json, # type: ignore
self.stream_slicer.get_request_body_json, # type: ignore
self.requester.get_authenticator().get_request_body_json, # type: ignore
)
def request_kwargs(
@@ -311,7 +314,7 @@ class SimpleRetriever(Retriever, HttpStream):
:return:
"""
# Warning: use self.state instead of the stream_state passed as argument!
paginator_path = self.paginator.path()
paginator_path = self._paginator.path()
if paginator_path:
return paginator_path
else:
@@ -361,7 +364,7 @@ class SimpleRetriever(Retriever, HttpStream):
self._records_from_last_response = records
return records
@property
@property # type: ignore
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
"""The stream's primary key"""
return self._primary_key
@@ -379,17 +382,19 @@ class SimpleRetriever(Retriever, HttpStream):
:return: The token for the next page from the input response object. Returning None means there are no more pages to read in this response.
"""
return self.paginator.next_page_token(response, self._records_from_last_response)
return self._paginator.next_page_token(response, self._records_from_last_response)
def read_records(
self,
sync_mode: SyncMode,
cursor_field: Optional[List[str]] = None,
stream_slice: Optional[StreamSlice] = None,
stream_state: Optional[StreamState] = None,
) -> Iterable[StreamData]:
# Warning: use self.state instead of the stream_state passed as argument!
stream_slice = stream_slice or {} # None-check
self.paginator.reset()
# Fixing paginator types has a long tail of dependencies
self._paginator.reset() # type: ignore
# Note: Adding the state per partition led to a difficult situation where the state for a partition is not the same as the
# stream_state. This means that if any class downstream wants to access the state, it would need to perform some kind of selection
# based on the partition. To short circuit this, we do the selection here which avoid downstream classes to know about it the
@@ -401,8 +406,8 @@ class SimpleRetriever(Retriever, HttpStream):
# * Why is this abstraction not on the DeclarativeStream level? DeclarativeStream does not have a notion of stream slicers and we
# would like to avoid exposing the stream state outside of the cursor. This case is needed as of 2023-06-14 because of
# interpolation.
if self.cursor and hasattr(self.cursor, "select_state"):
slice_state = self.cursor.select_state(stream_slice)
if self.cursor and hasattr(self.cursor, "select_state"): # type: ignore
slice_state = self.cursor.select_state(stream_slice) # type: ignore
elif self.cursor:
slice_state = self.cursor.get_stream_state()
else:
@@ -441,8 +446,10 @@ class SimpleRetriever(Retriever, HttpStream):
return Record(dict(stream_data), stream_slice)
elif isinstance(stream_data, AirbyteMessage) and stream_data.record:
return Record(stream_data.record.data, stream_slice)
return None
def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
# stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever
def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore
"""
Specifies the slices for this stream. See the stream slicing section of the docs for more information.
@@ -455,11 +462,11 @@ class SimpleRetriever(Retriever, HttpStream):
return self.stream_slicer.stream_slices()
@property
def state(self) -> MutableMapping[str, Any]:
def state(self) -> Mapping[str, Any]:
return self.cursor.get_stream_state() if self.cursor else {}
@state.setter
def state(self, value: StreamState):
def state(self, value: StreamState) -> None:
"""State setter, accept state serialized by state getter."""
if self.cursor:
self.cursor.set_initial_state(value)
@@ -483,14 +490,15 @@ class SimpleRetrieverTestReadDecorator(SimpleRetriever):
maximum_number_of_slices: int = 5
def __post_init__(self, options: Mapping[str, Any]):
def __post_init__(self, options: Mapping[str, Any]) -> None:
super().__post_init__(options)
if self.maximum_number_of_slices and self.maximum_number_of_slices < 1:
raise ValueError(
f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}"
)
def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
# stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever
def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore
return islice(super().stream_slices(), self.maximum_number_of_slices)
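
islice caps the number of slices produced during a test read without materialising the underlying iterable, for example:

from itertools import islice

all_slices = ({"page": index} for index in range(1_000_000))
limited = list(islice(all_slices, 5))  # only the first five slices are ever generated
# [{'page': 0}, {'page': 1}, {'page': 2}, {'page': 3}, {'page': 4}]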
def parse_records(