
CDK: Fix typing errors (#9037)

* fix typing, drop AirbyteLogger

* format

* bump the version

* use logger instead of fixture logger

Co-authored-by: Eugene Kulak <kulak.eugene@gmail.com>
Co-authored-by: auganbay <auganenu@gmail.com>
Authored by Eugene Kulak on 2022-01-14 06:29:34 +02:00, committed by GitHub
parent 41f89d1ab2, commit f83eca58ea
26 changed files with 126 additions and 144 deletions

View File

@@ -1,5 +1,8 @@
# Changelog
## 0.1.47
Fix typing errors.
## 0.1.45
Integrate Sentry for performance and errors tracking.

View File

@@ -4,12 +4,12 @@
import json
import logging
import os
import pkgutil
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification
@@ -48,7 +48,7 @@ class Connector(ABC):
with open(config_path, "w") as fh:
fh.write(json.dumps(config))
def spec(self, logger: AirbyteLogger) -> ConnectorSpecification:
def spec(self, logger: logging.Logger) -> ConnectorSpecification:
"""
Returns the spec for this integration. The spec is a JSON-Schema object describing the required configurations (e.g: username and password)
required to run this integration.
@@ -59,7 +59,7 @@ class Connector(ABC):
return ConnectorSpecification.parse_obj(json.loads(raw_spec))
@abstractmethod
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""
Tests if the input configuration can be used to successfully connect to the integration e.g: if a provided Stripe API token can be used to connect
to the Stripe API.
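Note: under the new signatures a connector only swaps the annotation; the stdlib logger it receives supports the same info/debug/error calls. A minimal sketch (MyConnector and its log line are illustrative, not part of this diff):

```python
import logging
from typing import Any, Mapping

from airbyte_cdk.connector import Connector
from airbyte_cdk.models import AirbyteConnectionStatus, Status


class MyConnector(Connector):
    def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
        # A plain logging.Logger arrives here now instead of AirbyteLogger.
        logger.info("checking connection with config keys: %s", list(config))
        return AirbyteConnectionStatus(status=Status.SUCCEEDED)
```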

View File

@@ -4,19 +4,20 @@
import argparse
import io
import logging
import sys
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Mapping
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.connector import Connector
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit
from pydantic import ValidationError
logger = logging.getLogger("airbyte")
class Destination(Connector, ABC):
logger = AirbyteLogger()
VALID_CMDS = {"spec", "check", "write"}
@abstractmethod
@@ -26,7 +27,7 @@ class Destination(Connector, ABC):
"""Implement to define how the connector writes data to the destination"""
def _run_check(self, config: Mapping[str, Any]) -> AirbyteMessage:
check_result = self.check(self.logger, config)
check_result = self.check(logger, config)
return AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=check_result)
def _parse_input_stream(self, input_stream: io.TextIOWrapper) -> Iterable[AirbyteMessage]:
@@ -35,16 +36,16 @@ class Destination(Connector, ABC):
try:
yield AirbyteMessage.parse_raw(line)
except ValidationError:
self.logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {line}")
logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {line}")
def _run_write(
self, config: Mapping[str, Any], configured_catalog_path: str, input_stream: io.TextIOWrapper
) -> Iterable[AirbyteMessage]:
catalog = ConfiguredAirbyteCatalog.parse_file(configured_catalog_path)
input_messages = self._parse_input_stream(input_stream)
self.logger.info("Begin writing to the destination...")
logger.info("Begin writing to the destination...")
yield from self.write(config=config, configured_catalog=catalog, input_messages=input_messages)
self.logger.info("Writing complete.")
logger.info("Writing complete.")
def parse_args(self, args: List[str]) -> argparse.Namespace:
"""
@@ -86,7 +87,7 @@ class Destination(Connector, ABC):
if cmd not in self.VALID_CMDS:
raise Exception(f"Unrecognized command: {cmd}")
spec = self.spec(self.logger)
spec = self.spec(logger)
if cmd == "spec":
yield AirbyteMessage(type=Type.SPEC, spec=spec)
return
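Note: the class-level `logger = AirbyteLogger()` attribute gives way to a module-level stdlib logger. Since `logging.getLogger` memoizes by name, every module asking for "airbyte" shares one configured logger; a quick sketch:

```python
import logging

logger = logging.getLogger("airbyte")

# getLogger returns the same object for the same name, so there is no
# per-instance logger state to carry on the Destination class anymore.
assert logging.getLogger("airbyte") is logger
```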

View File

@@ -6,7 +6,7 @@ import logging
import logging.config
import sys
import traceback
from typing import List
from typing import List, Tuple
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage
@@ -49,7 +49,6 @@ def init_unhandled_exception_output_filtering(logger: logging.Logger) -> None:
def init_logger(name: str = None):
"""Initial set up of logger"""
logging.setLoggerClass(AirbyteNativeLogger)
logging.addLevelName(TRACE_LEVEL_NUM, "TRACE")
logger = logging.getLogger(name)
logger.setLevel(TRACE_LEVEL_NUM)
@@ -61,7 +60,7 @@ def init_logger(name: str = None):
class AirbyteLogFormatter(logging.Formatter):
"""Output log records using AirbyteMessage"""
_secrets = []
_secrets: List[str] = []
@classmethod
def update_secrets(cls, secrets: List[str]):
@@ -88,46 +87,22 @@ class AirbyteLogFormatter(logging.Formatter):
return log_message.json(exclude_unset=True)
class AirbyteNativeLogger(logging.Logger):
"""Using native logger with implementing all AirbyteLogger features"""
def log_by_prefix(msg: str, default_level: str) -> Tuple[int, str]:
"""Custom method, which takes log level from first word of message"""
valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"]
split_line = msg.split()
first_word = next(iter(split_line), None)
if first_word in valid_log_types:
log_level = logging.getLevelName(first_word)
rendered_message = " ".join(split_line[1:])
else:
log_level = logging.getLevelName(default_level)
rendered_message = msg
def __init__(self, name):
super().__init__(name)
self.valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"]
def log_by_prefix(self, msg, default_level):
"""Custom method, which takes log level from first word of message"""
split_line = msg.split()
first_word = next(iter(split_line), None)
if first_word in self.valid_log_types:
log_level = logging.getLevelName(first_word)
rendered_message = " ".join(split_line[1:])
else:
default_level = default_level if default_level in self.valid_log_types else "INFO"
log_level = logging.getLevelName(default_level)
rendered_message = msg
self.log(log_level, rendered_message)
def trace(self, msg, *args, **kwargs):
self._log(TRACE_LEVEL_NUM, msg, args, **kwargs)
return log_level, rendered_message
class AirbyteLogger:
def __init__(self):
self.valid_log_types = ["FATAL", "ERROR", "WARN", "INFO", "DEBUG", "TRACE"]
def log_by_prefix(self, message, default_level):
"""Custom method, which takes log level from first word of message"""
split_line = message.split()
first_word = next(iter(split_line), None)
if first_word in self.valid_log_types:
log_level = first_word
rendered_message = " ".join(split_line[1:])
else:
log_level = default_level
rendered_message = message
self.log(log_level, rendered_message)
def log(self, level, message):
log_record = AirbyteLogMessage(level=level, message=message)
log_message = AirbyteMessage(type="LOG", log=log_record)
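Note: `log_by_prefix` is now a pure module-level function returning a `(level, message)` tuple rather than a method that logs as a side effect; call sites unpack it into `logger.log`, as the singer changes below show. A small usage sketch:

```python
import logging

from airbyte_cdk.logger import log_by_prefix

logger = logging.getLogger("airbyte")

logger.log(*log_by_prefix("ERROR connection refused", "INFO"))  # logs at ERROR
logger.log(*log_by_prefix("no recognized prefix", "INFO"))      # falls back to INFO
```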

View File

@@ -4,19 +4,18 @@
import copy
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Status,
@@ -38,8 +37,9 @@ class AbstractSource(Source, ABC):
"""
@abstractmethod
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
"""
:param logger: source logger
:param config: The user-provided configuration as specified by the source's spec.
This usually contains information required to check connection e.g. tokens, secrets and keys etc.
:return: A tuple of (boolean, error). If boolean is true, then the connection check is successful
@@ -57,19 +57,19 @@ class AbstractSource(Source, ABC):
"""
# Stream name to instance map for applying output object transformation
_stream_to_instance_map: Dict[str, AirbyteStream] = {}
_stream_to_instance_map: Dict[str, Stream] = {}
@property
def name(self) -> str:
"""Source name"""
return self.__class__.__name__
def discover(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteCatalog:
def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
"""Implements the Discover operation from the Airbyte Specification. See https://docs.airbyte.io/architecture/airbyte-specification."""
streams = [stream.as_airbyte_stream() for stream in self.streams(config=config)]
return AirbyteCatalog(streams=streams)
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""Implements the Check Connection operation from the Airbyte Specification. See https://docs.airbyte.io/architecture/airbyte-specification."""
try:
check_succeeded, error = self.check_connection(logger, config)
@@ -81,7 +81,7 @@ class AbstractSource(Source, ABC):
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
def read(
self, logger: AirbyteLogger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
) -> Iterator[AirbyteMessage]:
"""Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.io/architecture/airbyte-specification."""
connector_state = copy.deepcopy(state or {})
@@ -118,7 +118,7 @@ class AbstractSource(Source, ABC):
def _read_stream(
self,
logger: AirbyteLogger,
logger: logging.Logger,
stream_instance: Stream,
configured_stream: ConfiguredAirbyteStream,
connector_state: MutableMapping[str, Any],
@@ -160,7 +160,7 @@ class AbstractSource(Source, ABC):
def _read_incremental(
self,
logger: AirbyteLogger,
logger: logging.Logger,
stream_instance: Stream,
configured_stream: ConfiguredAirbyteStream,
connector_state: MutableMapping[str, Any],
@@ -222,7 +222,7 @@ class AbstractSource(Source, ABC):
return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(data=connector_state))
@lru_cache(maxsize=None)
def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, dict]:
def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, Mapping[str, Any]]:
"""
Lookup stream's transform object and jsonschema based on stream name.
This function would be called a lot so using caching to save on costly
@@ -230,7 +230,7 @@ class AbstractSource(Source, ABC):
:param stream_name name of stream from catalog.
:return tuple with stream transformer object and discover json schema.
"""
stream_instance = self._stream_to_instance_map.get(stream_name)
stream_instance = self._stream_to_instance_map[stream_name]
return stream_instance.transformer, stream_instance.get_json_schema()
def _as_airbyte_record(self, stream_name: str, data: Mapping[str, Any]):
@@ -240,6 +240,6 @@ class AbstractSource(Source, ABC):
# need it to normalize values against json schema. By default no action
# taken unless configured. See
# docs/connector-development/cdk-python/schemas.md for details.
transformer.transform(data, schema)
transformer.transform(data, schema) # type: ignore
message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
return AirbyteMessage(type=MessageType.RECORD, record=message)
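Note: the lookup switch from `.get(stream_name)` to `[stream_name]` is what satisfies the type checker: `.get` types as `Optional[Stream]`, so mypy flags attribute access on the result, while indexing types as `Stream` and raises `KeyError` for unknown names instead of failing later on `None`. A stand-in sketch (plain dict in place of `_stream_to_instance_map`):

```python
from typing import Dict, List

streams: Dict[str, List[str]] = {"users": []}

maybe = streams.get("users")  # Optional[List[str]]: mypy wants a None check
exact = streams["users"]      # List[str]: attribute access is immediately safe
exact.append("ok")
```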

View File

@@ -17,9 +17,9 @@ class BaseConfig(BaseModel):
"""
@classmethod
def schema(cls, **kwargs) -> Dict[str, Any]:
def schema(cls, *args, **kwargs) -> Dict[str, Any]:
"""We're overriding the schema classmethod to enable some post-processing"""
schema = super().schema(**kwargs)
schema = super().schema(*args, **kwargs)
rename_key(schema, old_key="anyOf", new_key="oneOf") # UI supports only oneOf
expand_refs(schema)
schema.pop("description", None) # description added from the docstring
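Note: pydantic v1's `BaseModel.schema` takes `by_alias` positionally, so an override that only forwards `**kwargs` breaks positional callers and the inherited signature. Accepting `*args` too restores both; a runnable sketch assuming pydantic v1 (`Cfg` is illustrative):

```python
from typing import Any, Dict

from pydantic import BaseModel


class Cfg(BaseModel):
    name: str = "x"

    @classmethod
    def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Forward positionals so Cfg.schema(False) (by_alias passed
        # positionally) still reaches the base implementation.
        return super().schema(*args, **kwargs)


print(Cfg.schema(False)["properties"])
```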

View File

@@ -4,10 +4,10 @@
import copy
import logging
from datetime import datetime
from typing import Any, Iterable, Mapping, MutableMapping, Type
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
@@ -39,13 +39,13 @@ class BaseSource(Source):
"""Construct client"""
return self.client_class(**config)
def discover(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteCatalog:
def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
"""Discover streams"""
client = self._get_client(config)
return AirbyteCatalog(streams=[stream for stream in client.streams])
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""Check connection"""
client = self._get_client(config)
alive, error = client.health_check()
@@ -55,7 +55,7 @@ class BaseSource(Source):
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
def read(
self, logger: AirbyteLogger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
) -> Iterable[AirbyteMessage]:
state = state or {}
client = self._get_client(config)
@@ -73,7 +73,7 @@ class BaseSource(Source):
logger.info(f"Finished syncing {self.name}")
def _read_stream(
self, logger: AirbyteLogger, client: BaseClient, configured_stream: ConfiguredAirbyteStream, state: MutableMapping[str, Any]
self, logger: logging.Logger, client: BaseClient, configured_stream: ConfiguredAirbyteStream, state: MutableMapping[str, Any]
):
stream_name = configured_stream.stream.name
use_incremental = configured_stream.sync_mode == SyncMode.incremental and client.stream_has_state(stream_name)

View File

@@ -12,6 +12,7 @@ from datetime import datetime
from io import TextIOWrapper
from typing import Any, DefaultDict, Dict, Iterator, List, Mapping, Optional, Tuple
from airbyte_cdk.logger import log_by_prefix
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteMessage,
@@ -138,7 +139,7 @@ class SingerHelper:
shell_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
)
for line in completed_process.stderr.splitlines():
logger.log_by_prefix(line, "ERROR")
logger.log(*log_by_prefix(line, "ERROR"))
return json.loads(completed_process.stdout)
@@ -169,9 +170,9 @@ class SingerHelper:
if message_data is not None:
yield message_data
else:
logger.log_by_prefix(line, "INFO")
logger.log(*log_by_prefix(line, "INFO"))
else:
logger.log_by_prefix(line, "ERROR")
logger.log(*log_by_prefix(line, "ERROR"))
@staticmethod
def _read_lines(process: subprocess.Popen) -> Iterator[Tuple[str, TextIOWrapper]]:

View File

@@ -4,12 +4,12 @@
import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, Iterable, Mapping, MutableMapping
from airbyte_cdk.connector import Connector
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, ConfiguredAirbyteCatalog
@@ -29,14 +29,14 @@ class Source(Connector, ABC):
@abstractmethod
def read(
self, logger: AirbyteLogger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
) -> Iterable[AirbyteMessage]:
"""
Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state.
"""
@abstractmethod
def discover(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteCatalog:
def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteCatalog:
"""
Returns an AirbyteCatalog representing the available streams and fields in this integration. For example, given valid credentials to a
Postgres database, returns an Airbyte catalog where each postgres table is a stream, and each table column is a field.

View File

@@ -110,11 +110,13 @@ class Stream(ABC):
"""
def stream_slices(
self, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
"""
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
:param sync_mode:
:param cursor_field:
:param stream_state:
:return:
"""

View File

@@ -33,13 +33,13 @@ class HttpStream(Stream, ABC):
"""
source_defined_cursor = True # Most HTTP streams use a source defined cursor (i.e: the user can't configure it like on a SQL table)
page_size = None # Use this variable to define page size for API http requests with pagination support
page_size: Optional[int] = None # Use this variable to define page size for API http requests with pagination support
# TODO: remove legacy HttpAuthenticator authenticator references
def __init__(self, authenticator: Union[AuthBase, HttpAuthenticator] = None):
self._session = requests.Session()
self._authenticator = NoAuth()
self._authenticator: HttpAuthenticator = NoAuth()
if isinstance(authenticator, AuthBase):
self._session.auth = authenticator
elif authenticator:
@@ -107,7 +107,7 @@ class HttpStream(Stream, ABC):
return 5
@property
def retry_factor(self) -> int:
def retry_factor(self) -> float:
"""
Override if needed. Specifies factor for backoff policy.
"""
@@ -130,6 +130,7 @@ class HttpStream(Stream, ABC):
@abstractmethod
def path(
self,
*,
stream_state: Mapping[str, Any] = None,
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
@@ -206,6 +207,7 @@ class HttpStream(Stream, ABC):
def parse_response(
self,
response: requests.Response,
*,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
@@ -214,6 +216,9 @@ class HttpStream(Stream, ABC):
Parses the raw response object into a list of records.
By default, this returns an iterable containing the input. Override to parse differently.
:param response:
:param stream_state:
:param stream_slice:
:param next_page_token:
:return: An iterable containing the parsed response
"""
@@ -236,6 +241,7 @@ class HttpStream(Stream, ABC):
This method is called only if should_backoff() returns True for the input request.
:param response:
:return how long to backoff in seconds. The return value may be a floating point number for subsecond precision. Returning None defers backoff
to the default backoff behavior (e.g using an exponential algorithm).
"""
@@ -310,11 +316,11 @@ class HttpStream(Stream, ABC):
max_tries: The maximum number of attempts to make before giving
up ...The default value of None means there is no limit to
the number of tries.
This implies that if max_tries is excplicitly set to None there is no
This implies that if max_tries is explicitly set to None there is no
limit to retry attempts, otherwise it is limited number of tries. But
this is not true for current version of backoff packages (1.8.0). Setting
max_tries to 0 or negative number would result in endless retry atempts.
Add this condition to avoid an endless loop if it hasnt been set
max_tries to 0 or negative number would result in endless retry attempts.
Add this condition to avoid an endless loop if it hasn't been set
explicitly (i.e. max_retries is not None).
"""
if max_tries is not None:
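Note: the docstring above describes the guard that follows; a self-contained sketch of the normalization it implies (retries plus the initial attempt, with `None` reserved for unlimited):

```python
from typing import Optional


def normalize_max_tries(max_retries: Optional[int]) -> Optional[int]:
    # backoff 1.8.0 treats 0 or negative max_tries as "retry forever",
    # so clamp explicit values and keep None as the unlimited sentinel.
    if max_retries is not None:
        return max(0, max_retries) + 1  # tries = retries + initial attempt
    return None


assert normalize_max_tries(None) is None
assert normalize_max_tries(0) == 1
assert normalize_max_tries(-5) == 1
```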

View File

@@ -5,6 +5,7 @@
import sys
import time
from typing import Optional
import backoff
from airbyte_cdk.logger import AirbyteLogger
@@ -18,7 +19,7 @@ TRANSIENT_EXCEPTIONS = (DefaultBackoffException, exceptions.ConnectTimeout, exce
logger = AirbyteLogger()
def default_backoff_handler(max_tries: int, factor: int, **kwargs):
def default_backoff_handler(max_tries: Optional[int], factor: float, **kwargs):
def log_retry_attempt(details):
_, exc, _ = sys.exc_info()
if exc.response:
@@ -46,7 +47,7 @@ def default_backoff_handler(max_tries: int, factor: int, **kwargs):
)
def user_defined_backoff_handler(max_tries: int, **kwargs):
def user_defined_backoff_handler(max_tries: Optional[int], **kwargs):
def sleep_on_ratelimit(details):
_, exc, _ = sys.exc_info()
if isinstance(exc, UserDefinedBackoffException):
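Note: the loosened annotations (`Optional[int]` tries, `float` factor) match what the `backoff` package actually accepts: `max_tries=None` retries indefinitely and `factor` feeds `backoff.expo` as a float. A minimal sketch of the wrapped pattern (`fetch` is illustrative):

```python
import backoff
import requests


@backoff.on_exception(backoff.expo, requests.exceptions.ConnectionError, max_tries=5, factor=0.5)
def fetch(url: str) -> requests.Response:
    # Retried with exponential backoff (0.5s, 1s, 2s, ...) up to 5 tries.
    return requests.get(url, timeout=10)
```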

View File

@@ -23,7 +23,7 @@ class Oauth2Authenticator(AuthBase):
client_secret: str,
refresh_token: str,
scopes: List[str] = None,
token_expiry_date: pendulum.datetime = None,
token_expiry_date: pendulum.DateTime = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
):
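Note: `pendulum.datetime` is a factory function, while `pendulum.DateTime` is the class, so only the latter is valid in an annotation:

```python
import pendulum

# The function builds instances of the class; annotate with the class.
expiry: pendulum.DateTime = pendulum.datetime(2022, 1, 14, tz="UTC")
print(isinstance(expiry, pendulum.DateTime))  # True
```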

View File

@@ -1 +1,5 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#
# Initialize Utils Package

View File

@@ -77,8 +77,8 @@ class BaseSchemaModel(BaseModel):
prop["oneOf"] = [{"type": "null"}, {"$ref": ref}]
@classmethod
def schema(cls, **kwargs) -> Dict[str, Any]:
def schema(cls, *args, **kwargs) -> Dict[str, Any]:
"""We're overriding the schema classmethod to enable some post-processing"""
schema = super().schema(**kwargs)
schema = super().schema(*args, **kwargs)
expand_refs(schema)
return schema

View File

@@ -4,7 +4,7 @@
from distutils.util import strtobool
from enum import Flag, auto
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Mapping, Optional
from airbyte_cdk.logger import AirbyteLogger
from jsonschema import Draft7Validator, validators
@@ -36,7 +36,7 @@ class TypeTransformer:
Class for transforming object before output.
"""
_custom_normalizer: Callable[[Any, Dict[str, Any]], Any] = None
_custom_normalizer: Optional[Callable[[Any, Dict[str, Any]], Any]] = None
def __init__(self, config: TransformConfig):
"""
@@ -90,7 +90,7 @@ class TypeTransformer:
:param subschema part of the jsonschema containing field type/format data.
:return transformed field value.
"""
target_type = subschema.get("type")
target_type = subschema.get("type", [])
if original_item is None and "null" in target_type:
return None
if isinstance(target_type, list):
@@ -160,11 +160,11 @@ class TypeTransformer:
return normalizator
def transform(self, record: Dict[str, Any], schema: Dict[str, Any]):
def transform(self, record: Dict[str, Any], schema: Mapping[str, Any]):
"""
Normalize and validate according to config.
:param record record instance for normalization/transformation. All modification are done by modifing existent object.
:schema object's jsonschema for normalization.
:param record: record instance for normalization/transformation. All modification are done by modifying existent object.
:param schema: object's jsonschema for normalization.
"""
if TransformConfig.NoTransform in self._config:
return
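Note: the `.get("type", [])` default protects the `"null" in target_type` membership test when a subschema omits `"type"` entirely; previously that test raised `TypeError` on `None`:

```python
subschema = {"format": "date-time"}      # no "type" key at all
target_type = subschema.get("type", [])  # [] instead of None
print("null" in target_type)             # False, rather than TypeError
```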

View File

@@ -6,6 +6,7 @@ import datetime
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional
from airbyte_cdk.logger import AirbyteLogger
@@ -60,7 +61,7 @@ class EventTimer:
class Event:
name: str
start: float = field(default_factory=time.perf_counter_ns)
end: float = field(default=None)
end: Optional[float] = field(default=None)
@property
def duration(self) -> float:
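Note: a `None` default has to show up in the annotation for mypy to accept it; a minimal sketch mirroring the fixed dataclass field:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Event:
    name: str
    end: Optional[float] = field(default=None)  # None until the event closes


print(Event(name="sync").end)  # None
```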

View File

@@ -3,7 +3,7 @@
#
from functools import reduce
from typing import Any, List, Mapping, Optional
from typing import Any, Iterable, List, Mapping, Optional, Tuple
def all_key_pairs_dot_notation(dict_obj: Mapping) -> Mapping[str, Any]:
@@ -12,7 +12,7 @@ def all_key_pairs_dot_notation(dict_obj: Mapping) -> Mapping[str, Any]:
keys are prefixed with the list of keys passed in as prefix.
"""
def _all_key_pairs_dot_notation(_dict_obj: Mapping, prefix: List[str] = []) -> Mapping[str, Any]:
def _all_key_pairs_dot_notation(_dict_obj: Mapping, prefix: List[str] = []) -> Iterable[Tuple[str, Any]]:
for key, value in _dict_obj.items():
if isinstance(value, dict):
prefix.append(str(key))
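Note: a function that yields is a generator, so its return annotation should be an `Iterable` of pairs, not a `Mapping`; a standalone sketch of the corrected shape:

```python
from typing import Any, Iterable, Mapping, Tuple


def pairs(dict_obj: Mapping) -> Iterable[Tuple[str, Any]]:
    for key, value in dict_obj.items():
        yield str(key), value


print(dict(pairs({"a": 1})))  # {'a': 1}
```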

View File

@@ -15,7 +15,7 @@ README = (HERE / "README.md").read_text()
setup(
name="airbyte-cdk",
version="0.1.46",
version="0.1.47",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",

View File

@@ -4,14 +4,14 @@
import copy
import logging
from unittest.mock import patch
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models.airbyte_protocol import SyncMode
from airbyte_cdk.sources.singer import SingerHelper, SyncModeInfo
from airbyte_cdk.sources.singer.source import BaseSingerSource, ConfigContainer
LOGGER = AirbyteLogger()
logger = logging.getLogger("airbyte")
class TetsBaseSinger(BaseSingerSource):
@@ -57,7 +57,7 @@ basic_singer_catalog = {"streams": [USER_STREAM, ROLES_STREAM]}
@patch.object(SingerHelper, "_read_singer_catalog", return_value=basic_singer_catalog)
def test_singer_discover_single_pk(mock_read_catalog):
airbyte_catalog = TetsBaseSinger().discover(LOGGER, ConfigContainer({}, ""))
airbyte_catalog = TetsBaseSinger().discover(logger, ConfigContainer({}, ""))
_user_stream = airbyte_catalog.streams[0]
_roles_stream = airbyte_catalog.streams[1]
assert _user_stream.source_defined_primary_key == [["id"]]
@@ -69,7 +69,7 @@ def test_singer_discover_with_composite_pk():
singer_catalog_composite_pk = copy.deepcopy(basic_singer_catalog)
singer_catalog_composite_pk["streams"][0]["key_properties"] = ["id", "name"]
with patch.object(SingerHelper, "_read_singer_catalog", return_value=singer_catalog_composite_pk):
airbyte_catalog = TetsBaseSinger().discover(LOGGER, ConfigContainer({}, ""))
airbyte_catalog = TetsBaseSinger().discover(logger, ConfigContainer({}, ""))
_user_stream = airbyte_catalog.streams[0]
_roles_stream = airbyte_catalog.streams[1]
@@ -81,7 +81,7 @@ def test_singer_discover_with_composite_pk():
@patch.object(BaseSingerSource, "get_primary_key_overrides", return_value={"users": ["updated_at"]})
@patch.object(SingerHelper, "_read_singer_catalog", return_value=basic_singer_catalog)
def test_singer_discover_pk_overrides(mock_pk_override, mock_read_catalog):
airbyte_catalog = TetsBaseSinger().discover(LOGGER, ConfigContainer({}, ""))
airbyte_catalog = TetsBaseSinger().discover(logger, ConfigContainer({}, ""))
_user_stream = airbyte_catalog.streams[0]
_roles_stream = airbyte_catalog.streams[1]
assert _user_stream.source_defined_primary_key == [["updated_at"]]
@@ -91,7 +91,7 @@ def test_singer_discover_pk_overrides(mock_pk_override, mock_read_catalog):
@patch.object(SingerHelper, "_read_singer_catalog", return_value=basic_singer_catalog)
def test_singer_discover_metadata(mock_read_catalog):
airbyte_catalog = TetsBaseSinger().discover(LOGGER, ConfigContainer({}, ""))
airbyte_catalog = TetsBaseSinger().discover(logger, ConfigContainer({}, ""))
_user_stream = airbyte_catalog.streams[0]
_roles_stream = airbyte_catalog.streams[1]
@@ -105,7 +105,7 @@ def test_singer_discover_metadata(mock_read_catalog):
def test_singer_discover_sync_mode_overrides(mock_read_catalog):
sync_mode_override = SyncModeInfo(supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], default_cursor_field=["name"])
with patch.object(BaseSingerSource, "get_sync_mode_overrides", return_value={"roles": sync_mode_override}):
airbyte_catalog = TetsBaseSinger().discover(LOGGER, ConfigContainer({}, ""))
airbyte_catalog = TetsBaseSinger().discover(logger, ConfigContainer({}, ""))
_roles_stream = airbyte_catalog.streams[1]
assert _roles_stream.supported_sync_modes == sync_mode_override.supported_sync_modes

View File

@@ -29,18 +29,10 @@ class StubBasicReadHttpStream(HttpStream):
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def path(
self, stream_state: Mapping[str, Any] = None, stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> str:
def path(self, **kwargs) -> str:
return ""
def parse_response(
self,
response: requests.Response,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
stubResp = {"data": self.resp_counter}
self.resp_counter += 1
yield stubResp
@@ -364,10 +356,10 @@ class CacheHttpSubStream(HttpSubStream):
def __init__(self, parent):
super().__init__(parent=parent)
def parse_response(self, **kwargs) -> Iterable[Mapping]:
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
return []
def next_page_token(self, **kwargs) -> Optional[Mapping[str, Any]]:
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def path(self, **kwargs) -> str:
@@ -406,14 +398,14 @@ class CacheHttpStreamWithSlices(CacheHttpStream):
paths = ["", "search"]
def path(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> str:
return f'{stream_slice.get("path")}'
return f'{stream_slice["path"]}' if stream_slice else ""
def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
for path in self.paths:
yield {"path": path}
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
yield response
yield {"value": len(response.text)}
@patch("airbyte_cdk.sources.streams.core.logging", MagicMock())

View File

@@ -2,12 +2,11 @@
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import pytest
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
@@ -25,13 +24,15 @@ from airbyte_cdk.models import (
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
logger = logging.getLogger("airbyte")
class MockSource(AbstractSource):
def __init__(self, check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None, streams: List[Stream] = None):
self._streams = streams
self.check_lambda = check_lambda
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
if self.check_lambda:
return self.check_lambda()
return (False, "Missing callable.")
@@ -42,11 +43,6 @@ class MockSource(AbstractSource):
return self._streams
@pytest.fixture
def logger() -> AirbyteLogger:
return AirbyteLogger()
def test_successful_check():
"""Tests that if a source returns TRUE for the connection check the appropriate connectionStatus success message is returned"""
expected = AirbyteConnectionStatus(status=Status.SUCCEEDED)
@@ -112,7 +108,7 @@ def test_discover(mocker):
assert expected == src.discover(logger, {})
def test_read_nonexistent_stream_raises_exception(mocker, logger):
def test_read_nonexistent_stream_raises_exception(mocker):
"""Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception"""
s1 = MockStream(name="s1")
s2 = MockStream(name="this_stream_doesnt_exist_in_the_source")
@@ -149,7 +145,7 @@ def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
return messages
def test_valid_full_refresh_read_no_slices(logger, mocker):
def test_valid_full_refresh_read_no_slices(mocker):
"""Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1")
@@ -168,7 +164,7 @@ def test_valid_full_refresh_read_no_slices(logger, mocker):
assert expected == messages
def test_valid_full_refresh_read_with_slices(mocker, logger):
def test_valid_full_refresh_read_with_slices(mocker):
"""Tests that running a full refresh sync on streams which use slices produces the expected AirbyteMessages"""
slices = [{"1": "1"}, {"2": "2"}]
# When attempting to sync a slice, just output that slice as a record
@@ -194,7 +190,7 @@ def _state(state_data: Dict[str, Any]):
return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data))
def test_valid_incremental_read_with_checkpoint_interval(mocker, logger):
def test_valid_incremental_read_with_checkpoint_interval(mocker):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message after reading N records within a stream"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream([({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], name="s1")
@@ -226,7 +222,7 @@ def test_valid_incremental_read_with_checkpoint_interval(mocker, logger):
assert expected == messages
def test_valid_incremental_read_with_no_interval(mocker, logger):
def test_valid_incremental_read_with_no_interval(mocker):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message only after fully reading the stream and does
not output any STATE messages during syncing the stream."""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
@@ -252,7 +248,7 @@ def test_valid_incremental_read_with_no_interval(mocker, logger):
assert expected == messages
def test_valid_incremental_read_with_slices(mocker, logger):
def test_valid_incremental_read_with_slices(mocker):
"""Tests that an incremental read which uses slices outputs each record in the slice followed by a STATE message, for each slice"""
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
@@ -291,7 +287,7 @@ def test_valid_incremental_read_with_slices(mocker, logger):
assert expected == messages
def test_valid_incremental_read_with_slices_and_interval(mocker, logger):
def test_valid_incremental_read_with_slices_and_interval(mocker):
"""
Tests that an incremental read which uses slices and a checkpoint interval:
1. outputs all records

View File

@@ -4,12 +4,12 @@
import json
import logging
import tempfile
from typing import Any, Mapping, MutableMapping
from unittest.mock import MagicMock
import pytest
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode, Type
from airbyte_cdk.sources import AbstractSource, Source
from airbyte_cdk.sources.streams.core import Stream
@@ -19,14 +19,14 @@ from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
class MockSource(Source):
def read(
self, logger: AirbyteLogger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
):
pass
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]):
def check(self, logger: logging.Logger, config: Mapping[str, Any]):
pass
def discover(self, logger: AirbyteLogger, config: Mapping[str, Any]):
def discover(self, logger: logging.Logger, config: Mapping[str, Any]):
pass

View File

@@ -4,13 +4,13 @@
import json
import logging
import tempfile
from pathlib import Path
from typing import Any, Mapping
import pytest
from airbyte_cdk import AirbyteSpec, Connector
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import AirbyteConnectionStatus
@@ -39,7 +39,7 @@ class TestAirbyteSpec:
class MockConnector(Connector):
def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
def check(self, logger: logging.Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
pass
@@ -68,6 +68,6 @@ def test_read_config(nonempty_file, integration: Connector, mock_config):
def test_write_config(integration, mock_config):
config_path = Path(tempfile.gettempdir()) / "config.json"
integration.write_config(mock_config, config_path)
integration.write_config(mock_config, str(config_path))
with open(config_path, "r") as actual:
assert mock_config == json.loads(actual.read())

View File

@@ -54,7 +54,7 @@ def test_level_transform(logger, caplog):
def test_trace(logger, caplog):
logger.trace("Test trace 1")
logger.log(logging.getLevelName("TRACE"), "Test trace 1")
record = caplog.records[0]
assert record.levelname == "TRACE"
assert record.message == "Test trace 1"
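Note: with the custom `AirbyteNativeLogger.trace` method gone, trace logging goes through the generic `.log()` using the registered level name; `getLevelName` maps both ways once `addLevelName` has run. A sketch (the numeric value of TRACE is assumed here):

```python
import logging

TRACE_LEVEL_NUM = 5  # value assumed; the CDK registers its own constant
logging.addLevelName(TRACE_LEVEL_NUM, "TRACE")

logging.getLogger("airbyte").log(logging.getLevelName("TRACE"), "Test trace 1")
```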

View File

@@ -8,7 +8,7 @@ from argparse import Namespace
from typing import Any, Iterable, Mapping, MutableMapping
import pytest
from airbyte_cdk import AirbyteEntrypoint, AirbyteLogger
from airbyte_cdk import AirbyteEntrypoint
from airbyte_cdk.logger import AirbyteLogFormatter
from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConnectorSpecification, Type
from airbyte_cdk.sources import Source
@@ -29,7 +29,7 @@ ANOTHER_NOT_SECRET_VALUE = "I am not a secret"
class MockSource(Source):
def read(
self,
logger: AirbyteLogger,
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: MutableMapping[str, Any] = None,
@@ -152,7 +152,7 @@ def test_airbyte_secrets_are_masked_on_uncaught_exceptions(mocker, caplog):
class BrokenSource(MockSource):
def read(
self,
logger: AirbyteLogger,
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: MutableMapping[str, Any] = None,
@@ -198,7 +198,7 @@ def test_non_airbyte_secrets_are_not_masked_on_uncaught_exceptions(mocker, caplo
class BrokenSource(MockSource):
def read(
self,
logger: AirbyteLogger,
logger: logging.Logger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: MutableMapping[str, Any] = None,