
Run mypy on airbyte-cdk as part of the build pipeline and fix typing issues in the file-based module (#27790)

* Try running only on modified files

* make a change

* return something with the wrong type

* Revert "return something with the wrong type"

This reverts commit 23b828371e.

* fix typing in file-based

* format

* Mypy

* fix

* leave as Mapping

* Revert "leave as Mapping"

This reverts commit 908f063f70.

* Use Dict

* update

* move dict()

* Revert "move dict()"

This reverts commit fa347a8236.

* Revert "Revert "move dict()""

This reverts commit c9237df2e4.

* Revert "Revert "Revert "move dict()"""

This reverts commit 5ac1616414.

* use Mapping

* point to config file

* comment

* strict = False

* remove --

* Revert "comment"

This reverts commit 6000814a82.

* install types

* install types in same command as mypy runs

* non-interactive

* freeze version

* pydantic plugin

* plugins

* update

* ignore missing import

* Revert "ignore missing import"

This reverts commit 1da7930fb7.

* Install pydantic instead

* fix

* this passes locally

* strict = true

* format

* explicitly import models

* Update

* remove old mypy.ini config

* temporarily disable mypy

* format

* any

* format

* fix tests

* format

* Automated Commit - Formatting Changes

* Revert "temporarily disable mypy"

This reverts commit eb8470fa3f.

* implicit reexport

* update test

* fix mypy

* Automated Commit - Formatting Changes

* fix some errors in tests

* more type fixes

* more fixes

* more

* .

* done with tests

* fix last files

* format

* Update gradle

* change source-stripe

* only run mypy on cdk

* remove strict

* Add more rules

* update

* ignore missing imports

* cast to string

* Allow untyped decorator

* reset to master

* move to the cdk

* derp

* move explicit imports around

* Automated Commit - Formatting Changes

* Revert "move explicit imports around"

This reverts commit 56e306b72f.

* move explicit imports around

* Upgrade mypy version

* point to config file

* Update readme

* Ignore errors in the models module

* Automated Commit - Formatting Changes

* move check to gradle build

* Any

* try checking out master too

* Revert "try checking out master too"

This reverts commit 8a8f3e373c.

* fetch master

* install mypy

* try without origin

* fetch from the script

* checkout master

* ls the branches

* remotes/origin/master

* remove some cruft

* comment

* remove pydantic types

* unpin mypy

* fetch from the script

* Update connectors base too

* modify a non-cdk file to confirm it doesn't get checked by mypy

* run mypy after generateComponentManifestClassFiles

* run from the venv

* pass files as arguments

* update

* fix when running without args

* with subdir

* path

* try without /

* ./

* remove filter

* try resetting

* Revert "try resetting"

This reverts commit 3a54c424de.

* exclude autogen file

* do not use the github action

* works locally

* remove extra fetch

* run on connectors base

* try bad  typing

* Revert "try bad  typing"

This reverts commit 33b512a3e4.

* reset stripe

* Revert "reset stripe"

This reverts commit 28f23fc6dd.

* Revert "Revert "reset stripe""

This reverts commit 5bf5dee371.

* missing return type

* do not ignore the autogen file

* remove extra installs

* run from venv

* Only check files modified on current branch

* Revert "Only check files modified on current branch"

This reverts commit b4b728e654.

* use merge-base

* Revert "use merge-base"

This reverts commit 3136670cbf.

* try with updated mypy

* bump

* run other steps after mypy

* reset task ordering

* run mypy though

* looser config

* tests pass

* fix mypy issues

* type: ignore

* optional

* this is always a bool

* ignore

* fix typing issues

* remove ignore

* remove mapping

* Automated Commit - Formatting Changes

* Revert "remove ignore"

This reverts commit 9ffeeb6cb1.

* update config

---------

Co-authored-by: girarda <girarda@users.noreply.github.com>
Co-authored-by: Joe Bell <joseph.bell@airbyte.io>
Alexandre Girard authored on 2023-07-13 16:55:48 -07:00; committed by GitHub
parent 792878b253
commit 97a353d5c5
38 changed files with 355 additions and 214 deletions

View File

@@ -44,7 +44,7 @@ class CsvFormat(BaseModel):
     double_quote: bool = Field(
         title="Double Quote", default=True, description="Whether two quotes in a quoted CSV value denote a single quote in the data."
     )
-    quoting_behavior: Optional[QuotingBehavior] = Field(
+    quoting_behavior: QuotingBehavior = Field(
         title="Quoting Behavior",
         default=QuotingBehavior.QUOTE_SPECIAL_CHARACTERS,
         description="The quoting behavior determines when a value in a row should have quote marks added around it. For example, if Quote Non-numeric is specified, while reading, quotes are expected for row values that do not contain numbers. Or for Quote All, every row value will be expecting quotes.",
@@ -54,7 +54,7 @@ class CsvFormat(BaseModel):
     # of using pyarrow
     @validator("delimiter")
-    def validate_delimiter(cls, v):
+    def validate_delimiter(cls, v: str) -> str:
         if len(v) != 1:
             raise ValueError("delimiter should only be one character")
         if v in {"\r", "\n"}:
@@ -62,19 +62,19 @@ class CsvFormat(BaseModel):
         return v

     @validator("quote_char")
-    def validate_quote_char(cls, v):
+    def validate_quote_char(cls, v: str) -> str:
         if len(v) != 1:
             raise ValueError("quote_char should only be one character")
         return v

     @validator("escape_char")
-    def validate_escape_char(cls, v):
+    def validate_escape_char(cls, v: str) -> str:
         if len(v) != 1:
             raise ValueError("escape_char should only be one character")
         return v

     @validator("encoding")
-    def validate_encoding(cls, v):
+    def validate_encoding(cls, v: str) -> str:
         try:
             codecs.lookup(v)
         except LookupError:
@@ -116,13 +116,13 @@ class FileBasedStreamConfig(BaseModel):
     )

     @validator("file_type", pre=True)
-    def validate_file_type(cls, v):
+    def validate_file_type(cls, v: str) -> str:
         if v not in VALID_FILE_TYPES:
             raise ValueError(f"Format filetype {v} is not a supported file type")
         return v

     @validator("format", pre=True)
-    def transform_format(cls, v):
+    def transform_format(cls, v: Mapping[str, str]) -> Any:
         if isinstance(v, Mapping):
             file_type = v.get("filetype", "")
             if file_type:
@@ -132,6 +132,8 @@ class FileBasedStreamConfig(BaseModel):
         return v

     @validator("input_schema", pre=True)
-    def transform_input_schema(cls, v):
+    def transform_input_schema(cls, v: Optional[Union[str, Mapping[str, Any]]]) -> Optional[Mapping[str, Any]]:
         if v:
             return type_mapping_to_jsonschema(v)
+        else:
+            return None
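Aside: a minimal, self-contained sketch of how the annotated validators above behave at runtime. It uses the pydantic v1 `validator` API seen in this diff; `MiniCsvFormat` is a hypothetical, trimmed analogue of `CsvFormat`, not code from this commit.

    from pydantic import BaseModel, ValidationError, validator

    class MiniCsvFormat(BaseModel):  # hypothetical stand-in for CsvFormat
        delimiter: str = ","

        @validator("delimiter")
        def validate_delimiter(cls, v: str) -> str:  # annotating v and the return lets mypy check the body
            if len(v) != 1:
                raise ValueError("delimiter should only be one character")
            return v

    try:
        MiniCsvFormat(delimiter=";;")
    except ValidationError as exc:
        print(exc)  # surfaces the one-character constraint as a validation error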

View File

@@ -13,15 +13,14 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile
 from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema
 from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
+from airbyte_cdk.sources.streams.core import Stream


 class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
     def __init__(self, stream_reader: AbstractFileBasedStreamReader):
         self.stream_reader = stream_reader

-    def check_availability(
-        self, stream: AbstractFileBasedStream, logger: logging.Logger, _: Optional[Source]
-    ) -> Tuple[bool, Optional[str]]:
+    def check_availability(self, stream: Stream, logger: logging.Logger, _: Optional[Source]) -> Tuple[bool, Optional[str]]:
         """
         Perform a connection check for the stream.
@@ -38,6 +37,8 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
         - If the user provided a schema in the config, check that a subset of records in
           one file conform to the schema via a call to stream.conforms_to_schema(schema).
         """
+        if not isinstance(stream, AbstractFileBasedStream):
+            raise ValueError(f"Stream {stream.name} is not a file-based stream.")
         try:
             files = self._check_list_files(stream)
             self._check_parse_record(stream, files[0], logger)
@@ -60,7 +61,7 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
         return files

-    def _check_parse_record(self, stream: AbstractFileBasedStream, file: RemoteFile, logger: logging.Logger):
+    def _check_parse_record(self, stream: AbstractFileBasedStream, file: RemoteFile, logger: logging.Logger) -> None:
         parser = stream.get_parser(stream.config.file_type)

         try:
@@ -75,7 +76,7 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
         schema = stream.catalog_schema or stream.config.input_schema
         if schema and stream.validation_policy.validate_schema_before_sync:
-            if not conforms_to_schema(record, schema):
+            if not conforms_to_schema(record, schema):  # type: ignore
                 raise CheckAvailabilityError(
                     FileBasedSourceError.ERROR_VALIDATING_RECORD,
                     stream=stream.name,
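The `isinstance` guard added above is what lets the signature widen to `Stream` while the body keeps using file-based attributes: mypy narrows the type after the check. A self-contained sketch of the pattern, with hypothetical class names:

    class Stream:
        name = "example"

    class FileBasedStream(Stream):
        def list_files(self) -> list:
            return []

    def check(stream: Stream) -> None:
        if not isinstance(stream, FileBasedStream):
            raise ValueError(f"Stream {stream.name} is not a file-based stream.")
        stream.list_files()  # fine: mypy has narrowed stream to FileBasedStream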

View File

@@ -37,7 +37,7 @@ class FileBasedSourceError(Enum):


 class BaseFileBasedSourceError(Exception):
-    def __init__(self, error: FileBasedSourceError, **kwargs):
+    def __init__(self, error: FileBasedSourceError, **kwargs):  # type: ignore # noqa
         super().__init__(
             f"{FileBasedSourceError(error).value} Contact Support if you need assistance.\n{' '.join([f'{k}={v}' for k, v in kwargs.items()])}"
         )

View File

@@ -5,10 +5,9 @@

 import logging
 import traceback
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Type
+from typing import Any, List, Mapping, Optional, Tuple, Type

-from airbyte_cdk.models import ConfiguredAirbyteCatalog
-from airbyte_cdk.models.airbyte_protocol import ConnectorSpecification
+from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConnectorSpecification
 from airbyte_cdk.sources import AbstractSource
 from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
 from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
@@ -21,6 +20,7 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeP
 from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
 from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
 from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
+from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
 from pydantic.error_wrappers import ValidationError
@@ -35,15 +35,15 @@ class FileBasedSource(AbstractSource, ABC):
         availability_strategy: Optional[AvailabilityStrategy],
         spec_class: Type[AbstractFileBasedSpec],
         discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
-        parsers: Dict[str, FileTypeParser] = None,
-        validation_policies: Dict[str, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
+        parsers: Mapping[str, FileTypeParser] = default_parsers,
+        validation_policies: Mapping[str, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
         max_history_size: int = DEFAULT_MAX_HISTORY_SIZE,
     ):
         self.stream_reader = stream_reader
         self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
         self.spec_class = spec_class
         self.discovery_policy = discovery_policy
-        self.parsers = parsers or default_parsers
+        self.parsers = parsers
         self.validation_policies = validation_policies
         self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
         self.max_history_size = max_history_size
@@ -70,6 +70,8 @@ class FileBasedSource(AbstractSource, ABC):
         errors = []
         for stream in streams:
+            if not isinstance(stream, AbstractFileBasedStream):
+                raise ValueError(f"Stream {stream} is not a file-based stream.")
             try:
                 (
                     stream_is_available,
@@ -78,18 +80,18 @@ class FileBasedSource(AbstractSource, ABC):
             except Exception:
                 errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}")
             else:
-                if not stream_is_available:
+                if not stream_is_available and reason:
                     errors.append(reason)

         return not bool(errors), (errors or None)

-    def streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
+    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
         """
         Return a list of this source's streams.
         """
         try:
             parsed_config = self.spec_class(**config)
-            streams = []
+            streams: List[Stream] = []
             for stream_config in parsed_config.streams:
                 self._validate_input_schema(stream_config)
                 streams.append(
@@ -119,13 +121,13 @@ class FileBasedSource(AbstractSource, ABC):
             connectionSpecification=self.spec_class.schema(),
         )

-    def _validate_and_get_validation_policy(self, stream_config: FileBasedStreamConfig):
+    def _validate_and_get_validation_policy(self, stream_config: FileBasedStreamConfig) -> AbstractSchemaValidationPolicy:
         if stream_config.validation_policy not in self.validation_policies:
             raise ValidationError(
                 f"`validation_policy` must be one of {list(self.validation_policies.keys())}", model=FileBasedStreamConfig
             )
         return self.validation_policies[stream_config.validation_policy]

-    def _validate_input_schema(self, stream_config: FileBasedStreamConfig):
+    def _validate_input_schema(self, stream_config: FileBasedStreamConfig) -> None:
         if stream_config.schemaless and stream_config.input_schema:
             raise ValidationError("`input_schema` and `schemaless` options cannot both be set", model=FileBasedStreamConfig)

View File

@@ -1,9 +1,12 @@
+from typing import Mapping
+
 from .avro_parser import AvroParser
 from .csv_parser import CsvParser
+from .file_type_parser import FileTypeParser
 from .jsonl_parser import JsonlParser
 from .parquet_parser import ParquetParser

-default_parsers = {
+default_parsers: Mapping[str, FileTypeParser] = {
     "avro": AvroParser(),
     "csv": CsvParser(),
     "jsonl": JsonlParser(),

View File

@@ -19,7 +19,7 @@ class AvroParser(FileTypeParser):
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Dict[str, Any]:
-        ...
+        raise NotImplementedError()

     def parse_records(
         self,
@@ -28,4 +28,4 @@ class AvroParser(FileTypeParser):
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Iterable[Dict[str, Any]]:
-        ...
+        raise NotImplementedError()
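Swapping `...` for `raise NotImplementedError()` matters once return types are annotated: under strict settings mypy can flag an empty body that never produces the declared `Dict[str, Any]`, while an unconditional raise type-checks cleanly. A minimal sketch:

    from typing import Any, Dict

    def infer_schema_stub() -> Dict[str, Any]:
        ...  # strict mypy can report this body for never returning a Dict[str, Any]

    def infer_schema_todo() -> Dict[str, Any]:
        raise NotImplementedError()  # accepted: the annotation is never violated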

View File

@@ -17,7 +17,7 @@ from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING

 DIALECT_NAME = "_config_dialect"

-config_to_quoting: [QuotingBehavior, int] = {
+config_to_quoting: Mapping[QuotingBehavior, int] = {
     QuotingBehavior.QUOTE_ALL: csv.QUOTE_ALL,
     QuotingBehavior.QUOTE_SPECIAL_CHARACTERS: csv.QUOTE_MINIMAL,
     QuotingBehavior.QUOTE_NONNUMERIC: csv.QUOTE_NONNUMERIC,
@@ -47,13 +47,13 @@ class CsvParser(FileTypeParser):
             with stream_reader.open_file(file) as fp:
                 # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
                 # sources will likely require one. Rather than modify the interface now we can wait until the real use case
-                reader = csv.DictReader(fp, dialect=dialect_name)
+                reader = csv.DictReader(fp, dialect=dialect_name)  # type: ignore
                 schema = {field.strip(): {"type": "string"} for field in next(reader)}
                 csv.unregister_dialect(dialect_name)
                 return schema
         else:
             with stream_reader.open_file(file) as fp:
-                reader = csv.DictReader(fp)
+                reader = csv.DictReader(fp)  # type: ignore
                 return {field.strip(): {"type": "string"} for field in next(reader)}

     def parse_records(
@@ -63,7 +63,7 @@ class CsvParser(FileTypeParser):
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Iterable[Dict[str, Any]]:
-        schema = config.input_schema
+        schema: Mapping[str, Any] = config.input_schema  # type: ignore
         config_format = config.format.get(config.file_type) if config.format else None
         if config_format:
             # Formats are configured individually per-stream so a unique dialect should be registered for each stream.
@@ -80,16 +80,16 @@ class CsvParser(FileTypeParser):
             with stream_reader.open_file(file) as fp:
                 # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
                 # sources will likely require one. Rather than modify the interface now we can wait until the real use case
-                reader = csv.DictReader(fp, dialect=dialect_name)
+                reader = csv.DictReader(fp, dialect=dialect_name)  # type: ignore
                 yield from self._read_and_cast_types(reader, schema, logger)
         else:
             with stream_reader.open_file(file) as fp:
-                reader = csv.DictReader(fp)
+                reader = csv.DictReader(fp)  # type: ignore
                 yield from self._read_and_cast_types(reader, schema, logger)

     @staticmethod
     def _read_and_cast_types(
-        reader: csv.DictReader, schema: Optional[Mapping[str, str]], logger: logging.Logger
+        reader: csv.DictReader, schema: Optional[Mapping[str, Any]], logger: logging.Logger  # type: ignore
     ) -> Iterable[Dict[str, Any]]:
         """
         If the user provided a schema, attempt to cast the record values to the associated type.
@@ -120,9 +120,9 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
     for key, value in row.items():
         prop_type = property_types.get(key)
-        cast_value = value
+        cast_value: Any = value

-        if prop_type in TYPE_PYTHON_MAPPING:
+        if prop_type in TYPE_PYTHON_MAPPING and prop_type is not None:
             _, python_type = TYPE_PYTHON_MAPPING[prop_type]

             if python_type is None:
@@ -169,5 +169,5 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
     return result


-def _format_warning(key: str, value: str, expected_type: str) -> str:
+def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str:
     return f"{key}: value={value},expected_type={expected_type}"

View File

@@ -22,7 +22,7 @@ class JsonlParser(FileTypeParser):
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Dict[str, Any]:
-        ...
+        raise NotImplementedError()

     def parse_records(
         self,
@@ -31,4 +31,4 @@ class JsonlParser(FileTypeParser):
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Iterable[Dict[str, Any]]:
-        ...
+        raise NotImplementedError()

View File

@@ -21,4 +21,6 @@ class RemoteFile(BaseModel):
         extensions = self.uri.split(".")[1:]
         if not extensions:
             return True
+        if not self.file_type:
+            return True
         return any(self.file_type.casefold() in e.casefold() for e in extensions)

View File

@@ -6,11 +6,11 @@ import json
 from copy import deepcopy
 from enum import Enum
 from functools import total_ordering
-from typing import Any, Dict, List, Literal, Mapping, Optional, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union

 from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, SchemaInferenceError

-JsonSchemaSupportedType = Union[List, Literal["string"], str]
+JsonSchemaSupportedType = Union[List[str], Literal["string"], str]
 SchemaType = Dict[str, Dict[str, JsonSchemaSupportedType]]

 schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}}
@@ -25,14 +25,14 @@ class ComparableType(Enum):
     STRING = 4
     OBJECT = 5

-    def __lt__(self, other):
+    def __lt__(self, other: Any) -> bool:
         if self.__class__ is other.__class__:
-            return self.value < other.value
+            return self.value < other.value  # type: ignore
         else:
             return NotImplemented


-TYPE_PYTHON_MAPPING = {
+TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = {
     "null": ("null", None),
     "array": ("array", list),
     "boolean": ("boolean", bool),
@@ -44,7 +44,7 @@ TYPE_PYTHON_MAPPING = {
 }


-def get_comparable_type(value: Any) -> ComparableType:
+def get_comparable_type(value: Any) -> Optional[ComparableType]:
     if value == "null":
         return ComparableType.NULL
     if value == "boolean":
@@ -57,9 +57,11 @@ def get_comparable_type(value: Any) -> ComparableType:
         return ComparableType.STRING
     if value == "object":
         return ComparableType.OBJECT
+    else:
+        return None


-def get_inferred_type(value: Any) -> ComparableType:
+def get_inferred_type(value: Any) -> Optional[ComparableType]:
     if value is None:
         return ComparableType.NULL
     if isinstance(value, bool):
@@ -72,6 +74,8 @@ def get_inferred_type(value: Any) -> ComparableType:
         return ComparableType.STRING
     if isinstance(value, dict):
         return ComparableType.OBJECT
+    else:
+        return None


 def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
@@ -91,10 +95,10 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
     and nothing else.
     """
     for k, t in list(schema1.items()) + list(schema2.items()):
-        if not isinstance(t, dict) or not _is_valid_type(t.get("type")):
+        if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
             raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t)

-    merged_schema = deepcopy(schema1)
+    merged_schema: Dict[str, Any] = deepcopy(schema1)
     for k2, t2 in schema2.items():
         t1 = merged_schema.get(k2)
         if t1 is None:
@@ -136,7 +140,7 @@ def _choose_wider_type(key: str, t1: Dict[str, Any], t2: Dict[str, Any]) -> Dict
     )  # accessing the type_mapping value


-def is_equal_or_narrower_type(value: Any, expected_type: str):
+def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool:
     if isinstance(value, list):
         # We do not compare lists directly; the individual items are compared.
         # If we hit this condition, it means that the expected type is not
@@ -151,7 +155,7 @@ def is_equal_or_narrower_type(value: Any, expected_type: str):
     return ComparableType(inferred_type) <= ComparableType(get_comparable_type(expected_type))


-def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, str]) -> bool:
+def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
     """
     Return true iff the record conforms to the supplied schema.
@@ -185,9 +189,12 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, str]) ->
     return True


-def _parse_json_input(input_schema: Optional[Union[str, Dict[str, str]]]) -> Optional[Mapping[str, str]]:
+def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]:
     try:
-        schema = json.loads(input_schema)
+        if isinstance(input_schema, str):
+            schema: Mapping[str, str] = json.loads(input_schema)
+        else:
+            schema = input_schema
         if not all(isinstance(s, str) for s in schema.values()):
             raise ConfigValidationError(
                 FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, details="Invalid input schema; nested schemas are not supported."
@@ -199,7 +206,7 @@ def _parse_json_input(input_schema: Optional[Union[str, Dict[str, str]]]) -> Opt
     return schema


-def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, str]]]) -> Optional[Mapping[str, str]]:
+def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, str]]]) -> Optional[Mapping[str, Any]]:
     """
     Return the user input schema (type mapping), transformed to JSON Schema format.
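The `merge_schemas` change above follows the same discipline as the other Optional fixes: `t.get("type")` produces an Optional value that a non-Optional helper rejects, whereas checking `"type" not in t` first lets `t["type"]` stay non-Optional. A self-contained sketch, with a hypothetical `_is_valid_type` stand-in:

    from typing import Any, Dict

    def _is_valid_type(t: Any) -> bool:  # hypothetical stand-in for the real helper
        return t in {"string", "integer", "number", "boolean", "object", "array", "null"}

    schema: Dict[str, Dict[str, Any]] = {"col": {"type": "string"}}
    for k, t in schema.items():
        # Membership check first, then index: t["type"] is guaranteed to exist here.
        if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
            raise ValueError(f"unrecognized type for key {k}")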

View File

@@ -3,7 +3,7 @@
 #

 from abc import ABC, abstractmethod
-from typing import Any, Mapping
+from typing import Any, Mapping, Optional


 class AbstractSchemaValidationPolicy(ABC):
@@ -11,8 +11,8 @@ class AbstractSchemaValidationPolicy(ABC):
     validate_schema_before_sync = False  # Whether to verify that records conform to the schema during the stream's availabilty check

     @abstractmethod
-    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
+    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
         """
         Return True if the record passes the user's validation policy.
         """
-        ...
+        raise NotImplementedError()

View File

@@ -2,7 +2,7 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #

-from typing import Any, Mapping
+from typing import Any, Mapping, Optional

 from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, StopSyncPerValidationPolicy
 from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema
@@ -12,23 +12,23 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSc
 class EmitRecordPolicy(AbstractSchemaValidationPolicy):
     name = "emit_record"

-    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
+    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
         return True


 class SkipRecordPolicy(AbstractSchemaValidationPolicy):
     name = "skip_record"

-    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
-        return conforms_to_schema(record, schema)
+    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
+        return schema is not None and conforms_to_schema(record, schema)


 class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy):
     name = "wait_for_discover"
     validate_schema_before_sync = True

-    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
-        if not conforms_to_schema(record, schema):
+    def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
+        if schema is None or not conforms_to_schema(record, schema):
             raise StopSyncPerValidationPolicy(FileBasedSourceError.STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY)
         return True

View File

@@ -3,7 +3,7 @@
 #

 from abc import abstractmethod
-from functools import cached_property
+from functools import cached_property, lru_cache
 from typing import Any, Dict, Iterable, List, Mapping, Optional

 from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode
@@ -14,7 +14,7 @@ from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFile
 from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
 from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
-from airbyte_cdk.sources.file_based.types import StreamSlice, StreamState
+from airbyte_cdk.sources.file_based.types import StreamSlice
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
@@ -69,15 +69,17 @@ class AbstractFileBasedStream(Stream):
     def read_records(
         self,
         sync_mode: SyncMode,
-        cursor_field: List[str] = None,
+        cursor_field: Optional[List[str]] = None,
         stream_slice: Optional[StreamSlice] = None,
-        stream_state: Optional[StreamState] = None,
+        stream_state: Optional[Mapping[str, Any]] = None,
     ) -> Iterable[Mapping[str, Any]]:
         """
         Yield all records from all remote files in `list_files_for_this_sync`.
         This method acts as an adapter between the generic Stream interface and the file-based's
         stream since file-based streams manage their own states.
         """
+        if stream_slice is None:
+            raise ValueError("stream_slice must be set")
         return self.read_records_from_slice(stream_slice)

     @abstractmethod
@@ -88,7 +90,7 @@ class AbstractFileBasedStream(Stream):
         ...

     def stream_slices(
-        self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: StreamState = None
+        self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
     ) -> Iterable[Optional[Mapping[str, Any]]]:
         """
         This method acts as an adapter between the generic Stream interface and the file-based's
@@ -105,6 +107,7 @@ class AbstractFileBasedStream(Stream):
         ...

     @abstractmethod
+    @lru_cache(maxsize=None)
     def get_json_schema(self) -> Mapping[str, Any]:
         """
         Return the JSON Schema for a stream.
@@ -133,7 +136,7 @@ class AbstractFileBasedStream(Stream):
         )

     @cached_property
-    def availability_strategy(self):
+    def availability_strategy(self) -> AvailabilityStrategy:
         return self._availability_strategy

     @property
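The new `@lru_cache(maxsize=None)` on `get_json_schema` memoizes the schema so repeated calls don't redo the work. A minimal sketch of the effect, with a hypothetical class:

    from functools import lru_cache

    class ExampleStream:
        calls = 0

        @lru_cache(maxsize=None)
        def get_json_schema(self) -> str:
            ExampleStream.calls += 1  # the body runs once per instance
            return '{"type": "object"}'

    s = ExampleStream()
    s.get_json_schema()
    s.get_json_schema()
    print(ExampleStream.calls)  # 1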

View File

@@ -4,7 +4,7 @@

 import logging
 from datetime import datetime, timedelta
-from typing import Iterable, Mapping, Optional
+from typing import Iterable, MutableMapping, Optional

 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
 from airbyte_cdk.sources.file_based.stream.cursor.file_based_cursor import FileBasedCursor
@@ -16,7 +16,7 @@ class DefaultFileBasedCursor(FileBasedCursor):
     DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

     def __init__(self, max_history_size: int, days_to_sync_if_history_is_full: Optional[int]):
-        self._file_to_datetime_history: Mapping[str:datetime] = {}
+        self._file_to_datetime_history: MutableMapping[str, str] = {}
         self._max_history_size = max_history_size
         self._time_window_if_history_is_full = timedelta(
             days=days_to_sync_if_history_is_full or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL
@@ -62,7 +62,7 @@ class DefaultFileBasedCursor(FileBasedCursor):
     def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
         if file.uri in self._file_to_datetime_history:
             # If the file's uri is in the history, we should sync the file if it has been modified since it was synced
-            updated_at_from_history = datetime.strptime(self._file_to_datetime_history.get(file.uri), self.DATE_TIME_FORMAT)
+            updated_at_from_history = datetime.strptime(self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT)
             if file.last_modified < updated_at_from_history:
                 logger.warning(
                     f"The file {file.uri}'s last modified date is older than the last time it was synced. This is unexpected. Skipping the file."
@@ -71,6 +71,8 @@ class DefaultFileBasedCursor(FileBasedCursor):
             return file.last_modified > updated_at_from_history
         if self._is_history_full():
+            if self._initial_earliest_file_in_history is None:
+                return True
             if file.last_modified > self._initial_earliest_file_in_history.last_modified:
                 # If the history is partial and the file's datetime is strictly greater than the earliest file in the history,
                 # we should sync it
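The cursor change from `.get(file.uri)` to `[file.uri]` is another Optional fix: `datetime.strptime` requires `str`, while `.get()` types as `Optional[str]`; after the membership check, indexing is both safe and correctly typed. A runnable sketch:

    from datetime import datetime
    from typing import MutableMapping

    DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
    history: MutableMapping[str, str] = {"a.csv": "2023-07-13T00:00:00.000000Z"}

    uri = "a.csv"
    if uri in history:
        # history.get(uri) would type as Optional[str], which strptime rejects;
        # history[uri] is str, matching the fix above.
        print(datetime.strptime(history[uri], DATE_TIME_FORMAT))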

View File

@@ -5,7 +5,7 @@

 import logging
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Any, Iterable, Mapping
+from typing import Any, Iterable, MutableMapping

 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
 from airbyte_cdk.sources.file_based.types import StreamState
@@ -32,7 +32,7 @@ class FileBasedCursor(ABC):
     """

     @abstractmethod
-    def get_state(self) -> Mapping[str, Any]:
+    def get_state(self) -> MutableMapping[str, Any]:
         """
         Get the state of the cursor.
         """

View File

@@ -6,9 +6,11 @@ import asyncio
 import itertools
 import traceback
 from functools import cache
-from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Union
+from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union

-from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
+from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
+from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
 from airbyte_cdk.sources.file_based.exceptions import (
     FileBasedSourceError,
     InvalidSchemaError,
@@ -36,7 +38,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
     ab_file_name_col = "_ab_source_file_url"
     airbyte_columns = [ab_last_mod_col, ab_file_name_col]

-    def __init__(self, cursor: FileBasedCursor, **kwargs):
+    def __init__(self, cursor: FileBasedCursor, **kwargs: Any):
         super().__init__(**kwargs)
         self._cursor = cursor
@@ -45,12 +47,12 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         return self._cursor.get_state()

     @state.setter
-    def state(self, value: MutableMapping[str, Any]):
+    def state(self, value: MutableMapping[str, Any]) -> None:
         """State setter, accept state serialized by state getter."""
         self._cursor.set_initial_state(value)

     @property
-    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+    def primary_key(self) -> PrimaryKeyType:
         return self.config.primary_key

     def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
@@ -93,7 +95,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
             except StopSyncPerValidationPolicy:
                 yield AirbyteMessage(
-                    type=Type.LOG,
+                    type=MessageType.LOG,
                     log=AirbyteLogMessage(
                         level=Level.WARN,
                         message=f"Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema. stream={self.name} file={file.uri} validation_policy={self.config.validation_policy} n_skipped={n_skipped}",
@@ -103,7 +105,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
             except Exception:
                 yield AirbyteMessage(
-                    type=Type.LOG,
+                    type=MessageType.LOG,
                     log=AirbyteLogMessage(
                         level=Level.ERROR,
                         message=f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream={self.name} file={file.uri} line_no={line_no} n_skipped={n_skipped}",
@@ -115,7 +117,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
             else:
                 if n_skipped:
                     yield AirbyteMessage(
-                        type=Type.LOG,
+                        type=MessageType.LOG,
                         log=AirbyteLogMessage(
                             level=Level.WARN,
                             message=f"Records in file did not pass validation policy. stream={self.name} file={file.uri} n_skipped={n_skipped} validation_policy={self.validation_policy.name}",
@@ -141,12 +143,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         except Exception as exc:
             raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc
         else:
-            schema["properties"] = {**extra_fields, **schema["properties"]}
-            return schema
+            return {"type": "object", "properties": {**extra_fields, **schema["properties"]}}

     def _get_raw_json_schema(self) -> JsonSchema:
         if self.config.input_schema:
-            return self.config.input_schema
+            return self.config.input_schema  # type: ignore
         elif self.config.schemaless:
             return schemaless_schema
         else:
@@ -180,7 +181,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         The output of this method is cached so we don't need to list the files more than once.
         This means we won't pick up changes to the files during a sync.
         """
-        return list(self._stream_reader.get_matching_files(self.config.globs))
+        return list(self._stream_reader.get_matching_files(self.config.globs or []))

     def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
         loop = asyncio.get_event_loop()
@@ -193,13 +194,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         Each file type has a corresponding `infer_schema` handler.
         Dispatch on file type.
         """
-        base_schema: Dict[str, str] = {}
-        pending_tasks = set()
+        base_schema: Dict[str, Any] = {}
+        pending_tasks: Set[asyncio.tasks.Task[Dict[str, Any]]] = set()

         n_started, n_files = 0, len(files)
-        files = iter(files)
+        files_iterator = iter(files)
         while pending_tasks or n_started < n_files:
-            while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and (file := next(files, None)):
+            while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and (file := next(files_iterator, None)):
                 pending_tasks.add(asyncio.create_task(self._infer_file_schema(file)))
                 n_started += 1
             # Return when the first task is completed so that we can enqueue a new task as soon as the
@@ -210,7 +211,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         return base_schema

-    async def _infer_file_schema(self, file: RemoteFile) -> Mapping[str, Any]:
+    async def _infer_file_schema(self, file: RemoteFile) -> Dict[str, Any]:
         try:
             return await self.get_parser(self.config.file_type).infer_schema(self.config, file, self._stream_reader, self.logger)
         except Exception as exc:

View File

@@ -4,7 +4,7 @@

 from __future__ import annotations

-from typing import Any, Mapping
+from typing import Any, Mapping, MutableMapping

 StreamSlice = Mapping[str, Any]
-StreamState = Mapping[str, Any]
+StreamState = MutableMapping[str, Any]
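Finally, the `StreamState` alias moves from `Mapping` to `MutableMapping` because state is written, not just read: `Mapping` has no `__setitem__`, so mypy rejects in-place updates through it. A short sketch of the distinction:

    from typing import Any, Mapping, MutableMapping

    StreamSlice = Mapping[str, Any]          # read-only view is enough for slices
    StreamState = MutableMapping[str, Any]   # state gets updated in place

    def record_file(state: StreamState, uri: str, ts: str) -> None:
        state[uri] = ts  # legal only because StreamState is a MutableMapping

    state: StreamState = {}
    record_file(state, "a.csv", "2023-07-13T00:00:00.000000Z")
    print(state)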