From 78728410f420aa98fd3944f2399c3c8533ae17ce Mon Sep 17 00:00:00 2001 From: Joe Reuter Date: Wed, 19 Jul 2023 15:08:35 +0200 Subject: [PATCH] Low code CDK: Fix mypy errors (#28386) * ignore unit tests in mypy check * Update airbyte-cdk/python/bin/run-mypy-on-modified-files.sh Co-authored-by: Alexandre Girard * work through mypy errors * fix a bunch of stuff * fix more type hints * fix model_to_component_factory types * format * ignore list instead of allow list --------- Co-authored-by: Alexandre Girard --- .../parsers/model_to_component_factory.py | 294 ++++++++++-------- .../retrievers/simple_retriever.py | 90 +++--- .../retrievers/test_simple_retriever.py | 4 +- 3 files changed, 212 insertions(+), 176 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 6199f37915a..5f306f78a58 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -7,7 +7,7 @@ from __future__ import annotations import importlib import inspect import re -from typing import Any, Callable, List, Literal, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, List, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints from airbyte_cdk.models import Level from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator @@ -110,7 +110,7 @@ from airbyte_cdk.sources.declarative.types import Config from airbyte_cdk.sources.message import InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, MessageRepository from pydantic import BaseModel -ComponentDefinition: Union[Literal, Mapping, List] +ComponentDefinition = Mapping[str, Any] DEFAULT_BACKOFF_STRATEGY = ExponentialBackoffStrategy @@ -119,23 +119,23 @@
DEFAULT_BACKOFF_STRATEGY = ExponentialBackoffStrategy class ModelToComponentFactory: def __init__( self, - limit_pages_fetched_per_slice: int = None, - limit_slices_fetched: int = None, + limit_pages_fetched_per_slice: Optional[int] = None, + limit_slices_fetched: Optional[int] = None, emit_connector_builder_messages: bool = False, disable_retries: bool = False, - message_repository: MessageRepository = None, + message_repository: Optional[MessageRepository] = None, ): self._init_mappings() self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice self._limit_slices_fetched = limit_slices_fetched self._emit_connector_builder_messages = emit_connector_builder_messages self._disable_retries = disable_retries - self._message_repository = message_repository or InMemoryMessageRepository( + self._message_repository = message_repository or InMemoryMessageRepository( # type: ignore self._evaluate_log_level(emit_connector_builder_messages) ) - def _init_mappings(self): - self.PYDANTIC_MODEL_TO_CONSTRUCTOR: [Type[BaseModel], Callable] = { + def _init_mappings(self) -> None: + self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = { AddedFieldDefinitionModel: self.create_added_field_definition, AddFieldsModel: self.create_add_fields, ApiKeyAuthenticatorModel: self.create_api_key_authenticator, @@ -190,7 +190,9 @@ class ModelToComponentFactory: # Needed for the case where we need to perform a second parse on the fields of a custom component self.TYPE_NAME_TO_MODEL = {cls.__name__: cls for cls in self.PYDANTIC_MODEL_TO_CONSTRUCTOR} - def create_component(self, model_type: Type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs) -> type: + def create_component( + self, model_type: Type[BaseModel], component_definition: ComponentDefinition, config: Config, **kwargs: Any + ) -> Any: """ Takes a given Pydantic model type and Mapping representing a component definition and creates a declarative component and 
subcomponents which will be used at runtime. This is done by first parsing the mapping into a Pydantic model and then creating @@ -213,26 +215,28 @@ class ModelToComponentFactory: return self._create_component_from_model(model=declarative_component_model, config=config, **kwargs) - def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs) -> Any: + def _create_component_from_model(self, model: BaseModel, config: Config, **kwargs: Any) -> Any: if model.__class__ not in self.PYDANTIC_MODEL_TO_CONSTRUCTOR: raise ValueError(f"{model.__class__} with attributes {model} is not a valid component type") component_constructor = self.PYDANTIC_MODEL_TO_CONSTRUCTOR.get(model.__class__) + if not component_constructor: + raise ValueError(f"Could not find constructor for {model.__class__}") return component_constructor(model=model, config=config, **kwargs) @staticmethod - def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs) -> AddedFieldDefinition: - interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters) - return AddedFieldDefinition(path=model.path, value=interpolated_value, parameters=model.parameters) + def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs: Any) -> AddedFieldDefinition: + interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters or {}) + return AddedFieldDefinition(path=model.path, value=interpolated_value, parameters=model.parameters or {}) - def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs) -> AddFields: + def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields: added_field_definitions = [ self._create_component_from_model(model=added_field_definition_model, config=config) for added_field_definition_model in model.fields ] - return AddFields(fields=added_field_definitions, parameters=model.parameters) + return 
AddFields(fields=added_field_definitions, parameters=model.parameters or {}) @staticmethod - def create_api_key_authenticator(model: ApiKeyAuthenticatorModel, config: Config, **kwargs) -> ApiKeyAuthenticator: + def create_api_key_authenticator(model: ApiKeyAuthenticatorModel, config: Config, **kwargs: Any) -> ApiKeyAuthenticator: if model.inject_into is None and model.header is None: raise ValueError("Expected either inject_into or header to be set for ApiKeyAuthenticator") @@ -243,52 +247,56 @@ class ModelToComponentFactory: RequestOption( inject_into=RequestOptionType(model.inject_into.inject_into.value), field_name=model.inject_into.field_name, - parameters=model.parameters, + parameters=model.parameters or {}, ) if model.inject_into else RequestOption( inject_into=RequestOptionType.header, - field_name=model.header, - parameters=model.parameters, + field_name=model.header or "", + parameters=model.parameters or {}, ) ) - return ApiKeyAuthenticator(api_token=model.api_token, request_option=request_option, config=config, parameters=model.parameters) + return ApiKeyAuthenticator( + api_token=model.api_token or "", request_option=request_option, config=config, parameters=model.parameters or {} + ) @staticmethod - def create_basic_http_authenticator(model: BasicHttpAuthenticatorModel, config: Config, **kwargs) -> BasicHttpAuthenticator: - return BasicHttpAuthenticator(password=model.password, username=model.username, config=config, parameters=model.parameters) + def create_basic_http_authenticator(model: BasicHttpAuthenticatorModel, config: Config, **kwargs: Any) -> BasicHttpAuthenticator: + return BasicHttpAuthenticator( + password=model.password or "", username=model.username, config=config, parameters=model.parameters or {} + ) @staticmethod - def create_bearer_authenticator(model: BearerAuthenticatorModel, config: Config, **kwargs) -> BearerAuthenticator: + def create_bearer_authenticator(model: BearerAuthenticatorModel, config: Config, **kwargs: Any) -> 
BearerAuthenticator: return BearerAuthenticator( api_token=model.api_token, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) @staticmethod - def create_check_stream(model: CheckStreamModel, config: Config, **kwargs): + def create_check_stream(model: CheckStreamModel, config: Config, **kwargs: Any) -> CheckStream: return CheckStream(stream_names=model.stream_names, parameters={}) - def create_composite_error_handler(self, model: CompositeErrorHandlerModel, config: Config, **kwargs) -> CompositeErrorHandler: + def create_composite_error_handler(self, model: CompositeErrorHandlerModel, config: Config, **kwargs: Any) -> CompositeErrorHandler: error_handlers = [ self._create_component_from_model(model=error_handler_model, config=config) for error_handler_model in model.error_handlers ] - return CompositeErrorHandler(error_handlers=error_handlers, parameters=model.parameters) + return CompositeErrorHandler(error_handlers=error_handlers, parameters=model.parameters or {}) @staticmethod - def create_constant_backoff_strategy(model: ConstantBackoffStrategyModel, config: Config, **kwargs) -> ConstantBackoffStrategy: + def create_constant_backoff_strategy(model: ConstantBackoffStrategyModel, config: Config, **kwargs: Any) -> ConstantBackoffStrategy: return ConstantBackoffStrategy( backoff_time_in_seconds=model.backoff_time_in_seconds, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) - def create_cursor_pagination(self, model: CursorPaginationModel, config: Config, **kwargs) -> CursorPaginationStrategy: + def create_cursor_pagination(self, model: CursorPaginationModel, config: Config, **kwargs: Any) -> CursorPaginationStrategy: if model.decoder: decoder = self._create_component_from_model(model=model.decoder, config=config) else: - decoder = JsonDecoder(parameters=model.parameters) + decoder = JsonDecoder(parameters=model.parameters or {}) return CursorPaginationStrategy( cursor_value=model.cursor_value, 
@@ -296,10 +304,10 @@ class ModelToComponentFactory: page_size=model.page_size, stop_condition=model.stop_condition, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) - def create_custom_component(self, model, config: Config, **kwargs) -> type: + def create_custom_component(self, model: Any, config: Config, **kwargs: Any) -> Any: """ Generically creates a custom component based on the model type and a class_name reference to the custom Python class being instantiated. Only the model's additional properties that match the custom class definition are passed to the constructor @@ -347,14 +355,14 @@ class ModelToComponentFactory: return custom_component_class(**kwargs) @staticmethod - def _get_class_from_fully_qualified_class_name(class_name: str) -> type: + def _get_class_from_fully_qualified_class_name(class_name: str) -> Any: split = class_name.split(".") module = ".".join(split[:-1]) class_name = split[-1] return getattr(importlib.import_module(module), class_name) @staticmethod - def _derive_component_type_from_type_hints(field_type: str) -> Optional[str]: + def _derive_component_type_from_type_hints(field_type: Any) -> Optional[str]: interface = field_type while True: origin = get_origin(interface) @@ -371,7 +379,7 @@ class ModelToComponentFactory: return None @staticmethod - def is_builtin_type(cls) -> bool: + def is_builtin_type(cls: Optional[Type[Any]]) -> bool: if not cls: return False return cls.__module__ == "builtins" @@ -384,7 +392,7 @@ class ModelToComponentFactory: else: return [] - def _create_nested_component(self, model, model_field: str, model_value: Any, config: Config) -> Any: + def _create_nested_component(self, model: Any, model_field: str, model_value: Any, config: Config) -> Any: type_name = model_value.get("type", None) if not type_name: # If no type is specified, we can assume this is a dictionary object which can be returned instead of a subcomponent @@ -420,13 +428,13 @@ class ModelToComponentFactory: 
@staticmethod def _is_component(model_value: Any) -> bool: - return isinstance(model_value, dict) and model_value.get("type") + return isinstance(model_value, dict) and model_value.get("type") is not None - def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: Config, **kwargs) -> DatetimeBasedCursor: - start_datetime = ( + def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config: Config, **kwargs: Any) -> DatetimeBasedCursor: + start_datetime: Union[str, MinMaxDatetime] = ( model.start_datetime if isinstance(model.start_datetime, str) else self.create_min_max_datetime(model.start_datetime, config) ) - end_datetime = None + end_datetime: Union[str, MinMaxDatetime, None] = None if model.is_data_feed and model.end_datetime: raise ValueError("Data feed does not support end_datetime") if model.end_datetime: @@ -438,7 +446,7 @@ class ModelToComponentFactory: RequestOption( inject_into=RequestOptionType(model.end_time_option.inject_into.value), field_name=model.end_time_option.field_name, - parameters=model.parameters, + parameters=model.parameters or {}, ) if model.end_time_option else None @@ -447,7 +455,7 @@ class ModelToComponentFactory: RequestOption( inject_into=RequestOptionType(model.start_time_option.inject_into.value), field_name=model.start_time_option.field_name, - parameters=model.parameters, + parameters=model.parameters or {}, ) if model.start_time_option else None @@ -467,10 +475,10 @@ class ModelToComponentFactory: partition_field_start=model.partition_field_start, message_repository=self._message_repository, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) - def create_declarative_stream(self, model: DeclarativeStreamModel, config: Config, **kwargs) -> DeclarativeStream: + def create_declarative_stream(self, model: DeclarativeStreamModel, config: Config, **kwargs: Any) -> DeclarativeStream: # When constructing a declarative stream, we assemble the incremental_sync 
component and retriever's partition_router field # components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the # Retriever. This is done in the declarative stream not the retriever to support custom retrievers. The custom create methods in @@ -505,31 +513,31 @@ class ModelToComponentFactory: schema_loader = DefaultSchemaLoader(config=config, parameters=options) return DeclarativeStream( - name=model.name, + name=model.name or "", primary_key=primary_key, retriever=retriever, schema_loader=schema_loader, stream_cursor_field=cursor_field or "", config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) def _merge_stream_slicers(self, model: DeclarativeStreamModel, config: Config) -> Optional[StreamSlicer]: stream_slicer = None if hasattr(model.retriever, "partition_router") and model.retriever.partition_router: stream_slicer_model = model.retriever.partition_router - stream_slicer = ( - CartesianProductStreamSlicer( + if isinstance(stream_slicer_model, list): + stream_slicer = CartesianProductStreamSlicer( [self._create_component_from_model(model=slicer, config=config) for slicer in stream_slicer_model], parameters={} ) - if type(stream_slicer_model) == list - else self._create_component_from_model(model=stream_slicer_model, config=config) - ) + else: + stream_slicer = self._create_component_from_model(model=stream_slicer_model, config=config) if model.incremental_sync and stream_slicer: + incremental_sync_model = model.incremental_sync return PerPartitionCursor( cursor_factory=CursorFactory( - lambda: self._create_component_from_model(model=model.incremental_sync, config=config), + lambda: self._create_component_from_model(model=incremental_sync_model, config=config), ), partition_router=stream_slicer, ) @@ -540,13 +548,13 @@ class ModelToComponentFactory: else: return None - def create_default_error_handler(self, model: DefaultErrorHandlerModel, config: Config, 
**kwargs) -> DefaultErrorHandler: + def create_default_error_handler(self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any) -> DefaultErrorHandler: backoff_strategies = [] if model.backoff_strategies: for backoff_strategy_model in model.backoff_strategies: backoff_strategies.append(self._create_component_from_model(model=backoff_strategy_model, config=config)) else: - backoff_strategies.append(DEFAULT_BACKOFF_STRATEGY(config=config, parameters=model.parameters)) + backoff_strategies.append(DEFAULT_BACKOFF_STRATEGY(config=config, parameters=model.parameters or {})) response_filters = [] if model.response_filters: @@ -555,22 +563,25 @@ class ModelToComponentFactory: else: response_filters.append( HttpResponseFilter( - ResponseAction.RETRY, http_codes=HttpResponseFilter.DEFAULT_RETRIABLE_ERRORS, config=config, parameters=model.parameters + ResponseAction.RETRY, + http_codes=HttpResponseFilter.DEFAULT_RETRIABLE_ERRORS, + config=config, + parameters=model.parameters or {}, ) ) - response_filters.append(HttpResponseFilter(ResponseAction.IGNORE, config=config, parameters=model.parameters)) + response_filters.append(HttpResponseFilter(ResponseAction.IGNORE, config=config, parameters=model.parameters or {})) return DefaultErrorHandler( backoff_strategies=backoff_strategies, max_retries=model.max_retries, response_filters=response_filters, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) def create_default_paginator( self, model: DefaultPaginatorModel, config: Config, *, url_base: str, cursor_used_for_stop_condition: Optional[Cursor] = None - ) -> DefaultPaginator: + ) -> Union[DefaultPaginator, PaginatorTestReadDecorator]: decoder = self._create_component_from_model(model=model.decoder, config=config) if model.decoder else JsonDecoder(parameters={}) page_size_option = ( self._create_component_from_model(model=model.page_size_option, config=config) if model.page_size_option else None @@ -591,19 +602,20 @@ class 
ModelToComponentFactory: pagination_strategy=pagination_strategy, url_base=url_base, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) if self._limit_pages_fetched_per_slice: return PaginatorTestReadDecorator(paginator, self._limit_pages_fetched_per_slice) return paginator - def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs) -> DpathExtractor: + def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs: Any) -> DpathExtractor: decoder = self._create_component_from_model(model.decoder, config=config) if model.decoder else JsonDecoder(parameters={}) - return DpathExtractor(decoder=decoder, field_path=model.field_path, config=config, parameters=model.parameters) + model_field_path: List[Union[InterpolatedString, str]] = [x for x in model.field_path] + return DpathExtractor(decoder=decoder, field_path=model_field_path, config=config, parameters=model.parameters or {}) @staticmethod def create_exponential_backoff_strategy(model: ExponentialBackoffStrategyModel, config: Config) -> ExponentialBackoffStrategy: - return ExponentialBackoffStrategy(factor=model.factor, parameters=model.parameters, config=config) + return ExponentialBackoffStrategy(factor=model.factor or 5, parameters=model.parameters or {}, config=config) def create_http_requester(self, model: HttpRequesterModel, config: Config, *, name: str) -> HttpRequester: authenticator = ( @@ -614,7 +626,7 @@ class ModelToComponentFactory: error_handler = ( self._create_component_from_model(model=model.error_handler, config=config) if model.error_handler - else DefaultErrorHandler(backoff_strategies=[], response_filters=[], config=config, parameters=model.parameters) + else DefaultErrorHandler(backoff_strategies=[], response_filters=[], config=config, parameters=model.parameters or {}) ) request_options_provider = InterpolatedRequestOptionsProvider( @@ -623,7 +635,11 @@ class ModelToComponentFactory: 
request_headers=model.request_headers, request_parameters=model.request_parameters, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, + ) + + model_http_method = ( + model.http_method if isinstance(model.http_method, str) else model.http_method.value if model.http_method is not None else "GET" ) return HttpRequester( @@ -632,14 +648,14 @@ class ModelToComponentFactory: path=model.path, authenticator=authenticator, error_handler=error_handler, - http_method=model.http_method, + http_method=model_http_method, request_options_provider=request_options_provider, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) @staticmethod - def create_http_response_filter(model: HttpResponseFilterModel, config: Config, **kwargs) -> HttpResponseFilter: + def create_http_response_filter(model: HttpResponseFilterModel, config: Config, **kwargs: Any) -> HttpResponseFilter: action = ResponseAction(model.action.value) http_codes = ( set(model.http_codes) if model.http_codes else set() @@ -648,32 +664,32 @@ class ModelToComponentFactory: return HttpResponseFilter( action=action, error_message=model.error_message or "", - error_message_contains=model.error_message_contains, + error_message_contains=model.error_message_contains or "", http_codes=http_codes, predicate=model.predicate or "", config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) @staticmethod - def create_inline_schema_loader(model: InlineSchemaLoaderModel, config: Config, **kwargs) -> InlineSchemaLoader: - return InlineSchemaLoader(schema=model.schema_, parameters={}) + def create_inline_schema_loader(model: InlineSchemaLoaderModel, config: Config, **kwargs: Any) -> InlineSchemaLoader: + return InlineSchemaLoader(schema=model.schema_ or {}, parameters={}) @staticmethod - def create_json_decoder(model: JsonDecoderModel, config: Config, **kwargs) -> JsonDecoder: + def create_json_decoder(model: JsonDecoderModel, config: Config, 
**kwargs: Any) -> JsonDecoder: return JsonDecoder(parameters={}) @staticmethod - def create_json_file_schema_loader(model: JsonFileSchemaLoaderModel, config: Config, **kwargs) -> JsonFileSchemaLoader: - return JsonFileSchemaLoader(file_path=model.file_path, config=config, parameters=model.parameters) + def create_json_file_schema_loader(model: JsonFileSchemaLoaderModel, config: Config, **kwargs: Any) -> JsonFileSchemaLoader: + return JsonFileSchemaLoader(file_path=model.file_path or "", config=config, parameters=model.parameters or {}) @staticmethod - def create_list_partition_router(model: ListPartitionRouterModel, config: Config, **kwargs) -> ListPartitionRouter: + def create_list_partition_router(model: ListPartitionRouterModel, config: Config, **kwargs: Any) -> ListPartitionRouter: request_option = ( RequestOption( inject_into=RequestOptionType(model.request_option.inject_into.value), field_name=model.request_option.field_name, - parameters=model.parameters, + parameters=model.parameters or {}, ) if model.request_option else None @@ -683,72 +699,78 @@ class ModelToComponentFactory: request_option=request_option, values=model.values, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) @staticmethod - def create_min_max_datetime(model: MinMaxDatetimeModel, config: Config, **kwargs) -> MinMaxDatetime: + def create_min_max_datetime(model: MinMaxDatetimeModel, config: Config, **kwargs: Any) -> MinMaxDatetime: return MinMaxDatetime( datetime=model.datetime, - datetime_format=model.datetime_format, - max_datetime=model.max_datetime, - min_datetime=model.min_datetime, - parameters=model.parameters, + datetime_format=model.datetime_format or "", + max_datetime=model.max_datetime or "", + min_datetime=model.min_datetime or "", + parameters=model.parameters or {}, ) @staticmethod - def create_no_auth(model: NoAuthModel, config: Config, **kwargs) -> NoAuth: - return NoAuth(parameters=model.parameters) + def create_no_auth(model: 
NoAuthModel, config: Config, **kwargs: Any) -> NoAuth: + return NoAuth(parameters=model.parameters or {}) @staticmethod - def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs) -> NoPagination: + def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs: Any) -> NoPagination: return NoPagination(parameters={}) - def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Config, **kwargs) -> DeclarativeOauth2Authenticator: + def create_oauth_authenticator(self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any) -> DeclarativeOauth2Authenticator: if model.refresh_token_updater: - return DeclarativeSingleUseRefreshTokenOauth2Authenticator( + # ignore type error beause fixing it would have a lot of dependencies, revisit later + return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore config, - InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters).eval(config), - access_token_name=InterpolatedString.create(model.access_token_name, parameters=model.parameters).eval(config), + InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters or {}).eval(config), + access_token_name=InterpolatedString.create( + model.access_token_name or "access_token", parameters=model.parameters or {} + ).eval(config), refresh_token_name=model.refresh_token_updater.refresh_token_name, - expires_in_name=InterpolatedString.create(model.expires_in_name, parameters=model.parameters).eval(config), - client_id=InterpolatedString.create(model.client_id, parameters=model.parameters).eval(config), - client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters).eval(config), + expires_in_name=InterpolatedString.create(model.expires_in_name or "expires_in", parameters=model.parameters or {}).eval( + config + ), + client_id=InterpolatedString.create(model.client_id, parameters=model.parameters or {}).eval(config), + 
client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters or {}).eval(config), access_token_config_path=model.refresh_token_updater.access_token_config_path, refresh_token_config_path=model.refresh_token_updater.refresh_token_config_path, token_expiry_date_config_path=model.refresh_token_updater.token_expiry_date_config_path, - grant_type=InterpolatedString.create(model.grant_type, parameters=model.parameters).eval(config), - refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters).eval(config), + grant_type=InterpolatedString.create(model.grant_type or "refresh_token", parameters=model.parameters or {}).eval(config), + refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters or {}).eval(config), scopes=model.scopes, token_expiry_date_format=model.token_expiry_date_format, message_repository=self._message_repository, ) - return DeclarativeOauth2Authenticator( - access_token_name=model.access_token_name, + # ignore type error beause fixing it would have a lot of dependencies, revisit later + return DeclarativeOauth2Authenticator( # type: ignore + access_token_name=model.access_token_name or "access_token", client_id=model.client_id, client_secret=model.client_secret, - expires_in_name=model.expires_in_name, - grant_type=model.grant_type, + expires_in_name=model.expires_in_name or "expires_in", + grant_type=model.grant_type or "refresh_token", refresh_request_body=model.refresh_request_body, refresh_token=model.refresh_token, scopes=model.scopes, token_expiry_date=model.token_expiry_date, - token_expiry_date_format=model.token_expiry_date_format, + token_expiry_date_format=model.token_expiry_date_format, # type: ignore token_refresh_endpoint=model.token_refresh_endpoint, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, message_repository=self._message_repository, ) @staticmethod - def 
create_offset_increment(model: OffsetIncrementModel, config: Config, **kwargs) -> OffsetIncrement: - return OffsetIncrement(page_size=model.page_size, config=config, parameters=model.parameters) + def create_offset_increment(model: OffsetIncrementModel, config: Config, **kwargs: Any) -> OffsetIncrement: + return OffsetIncrement(page_size=model.page_size, config=config, parameters=model.parameters or {}) @staticmethod - def create_page_increment(model: PageIncrementModel, config: Config, **kwargs) -> PageIncrement: - return PageIncrement(page_size=model.page_size, start_from_page=model.start_from_page, parameters=model.parameters) + def create_page_increment(model: PageIncrementModel, config: Config, **kwargs: Any) -> PageIncrement: + return PageIncrement(page_size=model.page_size, start_from_page=model.start_from_page or 0, parameters=model.parameters or {}) - def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Config, **kwargs) -> ParentStreamConfig: + def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Config, **kwargs: Any) -> ParentStreamConfig: declarative_stream = self._create_component_from_model(model.stream, config=config) request_option = self._create_component_from_model(model.request_option, config=config) if model.request_option else None return ParentStreamConfig( @@ -757,51 +779,55 @@ class ModelToComponentFactory: stream=declarative_stream, partition_field=model.partition_field, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) @staticmethod - def create_record_filter(model: RecordFilterModel, config: Config, **kwargs) -> RecordFilter: - return RecordFilter(condition=model.condition, config=config, parameters=model.parameters) + def create_record_filter(model: RecordFilterModel, config: Config, **kwargs: Any) -> RecordFilter: + return RecordFilter(condition=model.condition or "", config=config, parameters=model.parameters or {}) @staticmethod - def 
create_request_path(model: RequestPathModel, config: Config, **kwargs) -> RequestPath: + def create_request_path(model: RequestPathModel, config: Config, **kwargs: Any) -> RequestPath: return RequestPath(parameters={}) @staticmethod - def create_request_option(model: RequestOptionModel, config: Config, **kwargs) -> RequestOption: + def create_request_option(model: RequestOptionModel, config: Config, **kwargs: Any) -> RequestOption: inject_into = RequestOptionType(model.inject_into.value) return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={}) def create_record_selector( - self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs + self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs: Any ) -> RecordSelector: extractor = self._create_component_from_model(model=model.extractor, config=config) record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None return RecordSelector( - extractor=extractor, config=config, record_filter=record_filter, transformations=transformations, parameters=model.parameters + extractor=extractor, + config=config, + record_filter=record_filter, + transformations=transformations, + parameters=model.parameters or {}, ) @staticmethod - def create_remove_fields(model: RemoveFieldsModel, config: Config, **kwargs) -> RemoveFields: + def create_remove_fields(model: RemoveFieldsModel, config: Config, **kwargs: Any) -> RemoveFields: return RemoveFields(field_pointers=model.field_pointers, parameters={}) @staticmethod def create_session_token_authenticator( - model: SessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs + model: SessionTokenAuthenticatorModel, config: Config, *, url_base: str, **kwargs: Any ) -> SessionTokenAuthenticator: return SessionTokenAuthenticator( api_url=url_base, header=model.header, login_url=model.login_url, - 
password=model.password, - session_token=model.session_token, - session_token_response_key=model.session_token_response_key, - username=model.username, + password=model.password or "", + session_token=model.session_token or "", + session_token_response_key=model.session_token_response_key or "", + username=model.username or "", validate_session_url=model.validate_session_url, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, ) def create_simple_retriever( @@ -840,8 +866,8 @@ class ModelToComponentFactory: stream_slicer=stream_slicer, cursor=cursor, config=config, - maximum_number_of_slices=self._limit_slices_fetched, - parameters=model.parameters, + maximum_number_of_slices=self._limit_slices_fetched or 5, + parameters=model.parameters or {}, disable_retries=self._disable_retries, message_repository=self._message_repository, ) @@ -854,13 +880,13 @@ class ModelToComponentFactory: stream_slicer=stream_slicer, cursor=cursor, config=config, - parameters=model.parameters, + parameters=model.parameters or {}, disable_retries=self._disable_retries, message_repository=self._message_repository, ) @staticmethod - def create_spec(model: SpecModel, config: Config, **kwargs) -> Spec: + def create_spec(model: SpecModel, config: Config, **kwargs: Any) -> Spec: return Spec( connection_specification=model.connection_specification, documentation_url=model.documentation_url, @@ -868,7 +894,9 @@ class ModelToComponentFactory: parameters={}, ) - def create_substream_partition_router(self, model: SubstreamPartitionRouterModel, config: Config, **kwargs) -> SubstreamPartitionRouter: + def create_substream_partition_router( + self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any + ) -> SubstreamPartitionRouter: parent_stream_configs = [] if model.parent_stream_configs: parent_stream_configs.extend( @@ -878,9 +906,9 @@ class ModelToComponentFactory: ] ) - return SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, 
parameters=model.parameters, config=config) + return SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters=model.parameters or {}, config=config) - def _create_message_repository_substream_wrapper(self, model, config): + def _create_message_repository_substream_wrapper(self, model: ParentStreamConfigModel, config: Config) -> Any: substream_factory = ModelToComponentFactory( limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice, limit_slices_fetched=self._limit_slices_fetched, @@ -895,19 +923,19 @@ class ModelToComponentFactory: return substream_factory._create_component_from_model(model=model, config=config) @staticmethod - def create_wait_time_from_header(model: WaitTimeFromHeaderModel, config: Config, **kwargs) -> WaitTimeFromHeaderBackoffStrategy: - return WaitTimeFromHeaderBackoffStrategy(header=model.header, parameters=model.parameters, config=config, regex=model.regex) + def create_wait_time_from_header(model: WaitTimeFromHeaderModel, config: Config, **kwargs: Any) -> WaitTimeFromHeaderBackoffStrategy: + return WaitTimeFromHeaderBackoffStrategy(header=model.header, parameters=model.parameters or {}, config=config, regex=model.regex) @staticmethod def create_wait_until_time_from_header( - model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs + model: WaitUntilTimeFromHeaderModel, config: Config, **kwargs: Any ) -> WaitUntilTimeFromHeaderBackoffStrategy: return WaitUntilTimeFromHeaderBackoffStrategy( - header=model.header, parameters=model.parameters, config=config, min_wait=model.min_wait, regex=model.regex + header=model.header, parameters=model.parameters or {}, config=config, min_wait=model.min_wait, regex=model.regex ) - def get_message_repository(self): + def get_message_repository(self) -> MessageRepository: return self._message_repository - def _evaluate_log_level(self, emit_connector_builder_messages) -> Level: + def _evaluate_log_level(self, emit_connector_builder_messages: bool) -> Level: return 
Level.DEBUG if emit_connector_builder_messages else Level.INFO diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index 55ce658873d..a6fe6c1ffe1 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -4,7 +4,7 @@ from dataclasses import InitVar, dataclass, field from itertools import islice -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union import requests from airbyte_cdk.models import AirbyteMessage, Level, SyncMode @@ -61,25 +61,25 @@ class SimpleRetriever(Retriever, HttpStream): primary_key: Optional[Union[str, List[str], List[List[str]]]] _primary_key: str = field(init=False, repr=False, default="") paginator: Optional[Paginator] = None - stream_slicer: Optional[StreamSlicer] = SinglePartitionRouter(parameters={}) + stream_slicer: StreamSlicer = SinglePartitionRouter(parameters={}) cursor: Optional[Cursor] = None disable_retries: bool = False message_repository: MessageRepository = NoopMessageRepository() - def __post_init__(self, parameters: Mapping[str, Any]): - self.paginator = self.paginator or NoPagination(parameters=parameters) + def __post_init__(self, parameters: Mapping[str, Any]) -> None: + self._paginator = self.paginator or NoPagination(parameters=parameters) + self._last_response: Optional[requests.Response] = None + self._records_from_last_response: List[Record] = [] HttpStream.__init__(self, self.requester.get_authenticator()) - self._last_response = None - self._records_from_last_response = None self._parameters = parameters - self.name = InterpolatedString(self._name, parameters=parameters) + self._name = InterpolatedString(self._name, parameters=parameters) if 
isinstance(self._name, str) else self._name - @property + @property # type: ignore def name(self) -> str: """ :return: Stream name """ - return self._name.eval(self.config) + return str(self._name.eval(self.config)) if isinstance(self._name, InterpolatedString) else self._name @name.setter def name(self, value: str) -> None: @@ -103,8 +103,9 @@ class SimpleRetriever(Retriever, HttpStream): def max_retries(self) -> Union[int, None]: if self.disable_retries: return 0 - if hasattr(self.requester.error_handler, "max_retries"): - return self.requester.error_handler.max_retries + # this will be removed once simple_retriever is decoupled from http_stream + if hasattr(self.requester.error_handler, "max_retries"): # type: ignore + return self.requester.error_handler.max_retries # type: ignore return self._DEFAULT_MAX_RETRY def should_retry(self, response: requests.Response) -> bool: @@ -117,7 +118,7 @@ class SimpleRetriever(Retriever, HttpStream): Unexpected but transient exceptions (connection timeout, DNS resolution failed, etc..) are retried by default. 
""" - return self.requester.interpret_response_status(response).action == ResponseAction.RETRY + return bool(self.requester.interpret_response_status(response).action == ResponseAction.RETRY) def backoff_time(self, response: requests.Response) -> Optional[float]: """ @@ -148,11 +149,11 @@ class SimpleRetriever(Retriever, HttpStream): self, stream_slice: Optional[StreamSlice], next_page_token: Optional[Mapping[str, Any]], - requester_method, - paginator_method, - stream_slicer_method, - auth_options_method, - ): + requester_method: Callable[..., Mapping[str, Any]], + paginator_method: Callable[..., Mapping[str, Any]], + stream_slicer_method: Callable[..., Mapping[str, Any]], + auth_options_method: Callable[..., Mapping[str, Any]], + ) -> MutableMapping[str, Any]: """ Get the request_option from the requester and from the paginator Raise a ValueError if there's a key collision @@ -197,7 +198,7 @@ class SimpleRetriever(Retriever, HttpStream): stream_slice, next_page_token, self.requester.get_request_headers, - self.paginator.get_request_headers, + self._paginator.get_request_headers, self.stream_slicer.get_request_headers, # auth headers are handled separately by passing the authenticator to the HttpStream constructor lambda: {}, @@ -219,7 +220,7 @@ class SimpleRetriever(Retriever, HttpStream): stream_slice, next_page_token, self.requester.get_request_params, - self.paginator.get_request_params, + self._paginator.get_request_params, self.stream_slicer.get_request_params, self.requester.get_authenticator().get_request_params, ) @@ -229,7 +230,7 @@ class SimpleRetriever(Retriever, HttpStream): stream_state: StreamState, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Union[Mapping, str]]: + ) -> Optional[Union[Mapping[str, Any], str]]: """ Specifies how to populate the body of the request with a non-JSON payload. 
@@ -244,7 +245,7 @@ class SimpleRetriever(Retriever, HttpStream): stream_state=self.state, stream_slice=stream_slice, next_page_token=next_page_token ) if isinstance(base_body_data, str): - paginator_body_data = self.paginator.get_request_body_data() + paginator_body_data = self._paginator.get_request_body_data() if paginator_body_data: raise ValueError( f"Cannot combine requester's body data= {base_body_data} with paginator's body_data: {paginator_body_data}" @@ -254,10 +255,11 @@ class SimpleRetriever(Retriever, HttpStream): return self._get_request_options( stream_slice, next_page_token, - self.requester.get_request_body_data, - self.paginator.get_request_body_data, - self.stream_slicer.get_request_body_data, - self.requester.get_authenticator().get_request_body_data, + # body data can be a string as well, this will be fixed in the rewrite using http requester instead of http stream + self.requester.get_request_body_data, # type: ignore + self._paginator.get_request_body_data, # type: ignore + self.stream_slicer.get_request_body_data, # type: ignore + self.requester.get_authenticator().get_request_body_data, # type: ignore ) def request_body_json( @@ -265,7 +267,7 @@ class SimpleRetriever(Retriever, HttpStream): stream_state: StreamState, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping]: + ) -> Optional[Mapping[str, Any]]: """ Specifies how to populate the body of the request with a JSON payload. 
@@ -275,10 +277,11 @@ class SimpleRetriever(Retriever, HttpStream): return self._get_request_options( stream_slice, next_page_token, - self.requester.get_request_body_json, - self.paginator.get_request_body_json, - self.stream_slicer.get_request_body_json, - self.requester.get_authenticator().get_request_body_json, + # body json can be None as well, this will be fixed in the rewrite using http requester instead of http stream + self.requester.get_request_body_json, # type: ignore + self._paginator.get_request_body_json, # type: ignore + self.stream_slicer.get_request_body_json, # type: ignore + self.requester.get_authenticator().get_request_body_json, # type: ignore ) def request_kwargs( @@ -311,7 +314,7 @@ class SimpleRetriever(Retriever, HttpStream): :return: """ # Warning: use self.state instead of the stream_state passed as argument! - paginator_path = self.paginator.path() + paginator_path = self._paginator.path() if paginator_path: return paginator_path else: @@ -361,7 +364,7 @@ class SimpleRetriever(Retriever, HttpStream): self._records_from_last_response = records return records - @property + @property # type: ignore def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]: """The stream's primary key""" return self._primary_key @@ -379,17 +382,19 @@ class SimpleRetriever(Retriever, HttpStream): :return: The token for the next page from the input response object. Returning None means there are no more pages to read in this response. """ - return self.paginator.next_page_token(response, self._records_from_last_response) + return self._paginator.next_page_token(response, self._records_from_last_response) def read_records( self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_slice: Optional[StreamSlice] = None, + stream_state: Optional[StreamState] = None, ) -> Iterable[StreamData]: # Warning: use self.state instead of the stream_state passed as argument! 
stream_slice = stream_slice or {} # None-check - self.paginator.reset() + # Fixing paginator types has a long tail of dependencies + self._paginator.reset() # type: ignore # Note: Adding the state per partition led to a difficult situation where the state for a partition is not the same as the # stream_state. This means that if any class downstream wants to access the state, it would need to perform some kind of selection # based on the partition. To short circuit this, we do the selection here which avoid downstream classes to know about it the @@ -401,8 +406,8 @@ class SimpleRetriever(Retriever, HttpStream): # * Why is this abstraction not on the DeclarativeStream level? DeclarativeStream does not have a notion of stream slicers and we # would like to avoid exposing the stream state outside of the cursor. This case is needed as of 2023-06-14 because of # interpolation. - if self.cursor and hasattr(self.cursor, "select_state"): - slice_state = self.cursor.select_state(stream_slice) + if self.cursor and hasattr(self.cursor, "select_state"): # type: ignore + slice_state = self.cursor.select_state(stream_slice) # type: ignore elif self.cursor: slice_state = self.cursor.get_stream_state() else: @@ -441,8 +446,10 @@ class SimpleRetriever(Retriever, HttpStream): return Record(dict(stream_data), stream_slice) elif isinstance(stream_data, AirbyteMessage) and stream_data.record: return Record(stream_data.record.data, stream_slice) + return None - def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: + # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever + def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore """ Specifies the slices for this stream. See the stream slicing section of the docs for more information. 
@@ -455,11 +462,11 @@ class SimpleRetriever(Retriever, HttpStream): return self.stream_slicer.stream_slices() @property - def state(self) -> MutableMapping[str, Any]: + def state(self) -> Mapping[str, Any]: return self.cursor.get_stream_state() if self.cursor else {} @state.setter - def state(self, value: StreamState): + def state(self, value: StreamState) -> None: """State setter, accept state serialized by state getter.""" if self.cursor: self.cursor.set_initial_state(value) @@ -483,14 +490,15 @@ class SimpleRetrieverTestReadDecorator(SimpleRetriever): maximum_number_of_slices: int = 5 - def __post_init__(self, options: Mapping[str, Any]): + def __post_init__(self, options: Mapping[str, Any]) -> None: super().__post_init__(options) if self.maximum_number_of_slices and self.maximum_number_of_slices < 1: raise ValueError( f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}" ) - def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: + # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. 
Will be resolved by the decoupling of http stream and simple retriever + def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore return islice(super().stream_slices(), self.maximum_number_of_slices) def parse_records( diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/airbyte-cdk/python/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index f68e4ef63ce..3ecbb50543a 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -101,7 +101,7 @@ def test_simple_retriever_full(mock_http_stream): assert retriever.stream_slices() == stream_slices assert retriever._last_response is None - assert retriever._records_from_last_response is None + assert retriever._records_from_last_response == [] assert retriever.parse_response(response, stream_state={}) == records assert retriever._last_response == response assert retriever._records_from_last_response == records @@ -185,7 +185,7 @@ def test_simple_retriever_with_request_response_log_last_records(mock_http_strea ) assert retriever._last_response is None - assert retriever._records_from_last_response is None + assert retriever._records_from_last_response == [] assert retriever.parse_response(response, stream_state={}) == request_response_logs assert retriever._last_response == response assert retriever._records_from_last_response == request_response_logs