Run mypy on airbyte-cdk as part of the build pipeline and fix typing issues in the file-based module (#27790)
* Try running only on modified files
* make a change
* return something with the wrong type
* Revert "return something with the wrong type" This reverts commit 23b828371e.
* fix typing in file-based
* format
* Mypy
* fix
* leave as Mapping
* Revert "leave as Mapping" This reverts commit 908f063f70.
* Use Dict
* update
* move dict()
* Revert "move dict()" This reverts commit fa347a8236.
* Revert "Revert "move dict()"" This reverts commit c9237df2e4.
* Revert "Revert "Revert "move dict()""" This reverts commit 5ac1616414.
* use Mapping
* point to config file
* comment
* strict = False
* remove --
* Revert "comment" This reverts commit 6000814a82.
* install types
* install types in same command as mypy runs
* non-interactive
* freeze version
* pydantic plugin
* plugins
* update
* ignore missing import
* Revert "ignore missing import" This reverts commit 1da7930fb7.
* Install pydantic instead
* fix
* this passes locally
* strict = true
* format
* explicitly import models
* Update
* remove old mypy.ini config
* temporarily disable mypy
* format
* any
* format
* fix tests
* format
* Automated Commit - Formatting Changes
* Revert "temporarily disable mypy" This reverts commit eb8470fa3f.
* implicit reexport
* update test
* fix mypy
* Automated Commit - Formatting Changes
* fix some errors in tests
* more type fixes
* more fixes
* more
* .
* done with tests
* fix last files
* format
* Update gradle
* change source-stripe
* only run mypy on cdk
* remove strict
* Add more rules
* update
* ignore missing imports
* cast to string
* Allow untyped decorator
* reset to master
* move to the cdk
* derp
* move explicit imports around
* Automated Commit - Formatting Changes
* Revert "move explicit imports around" This reverts commit 56e306b72f.
* move explicit imports around
* Upgrade mypy version
* point to config file
* Update readme
* Ignore errors in the models module
* Automated Commit - Formatting Changes
* move check to gradle build
* Any
* try checking out master too
* Revert "try checking out master too" This reverts commit 8a8f3e373c.
* fetch master
* install mypy
* try without origin
* fetch from the script
* checkout master
* ls the branches
* remotes/origin/master
* remove some cruft
* comment
* remove pydantic types
* unpin mypy
* fetch from the script
* Update connectors base too
* modify a non-cdk file to confirm it doesn't get checked by mypy
* run mypy after generateComponentManifestClassFiles
* run from the venv
* pass files as arguments
* update
* fix when running without args
* with subdir
* path
* try without /
* ./
* remove filter
* try resetting
* Revert "try resetting" This reverts commit 3a54c424de.
* exclude autogen file
* do not use the github action
* works locally
* remove extra fetch
* run on connectors base
* try bad typing
* Revert "try bad typing" This reverts commit 33b512a3e4.
* reset stripe
* Revert "reset stripe" This reverts commit 28f23fc6dd.
* Revert "Revert "reset stripe"" This reverts commit 5bf5dee371.
* missing return type
* do not ignore the autogen file
* remove extra installs
* run from venv
* Only check files modified on current branch
* Revert "Only check files modified on current branch" This reverts commit b4b728e654.
* use merge-base
* Revert "use merge-base" This reverts commit 3136670cbf.
* try with updated mypy
* bump
* run other steps after mypy
* reset task ordering
* run mypy though
* looser config
* tests pass
* fix mypy issues
* type: ignore
* optional
* this is always a bool
* ignore
* fix typing issues
* remove ignore
* remove mapping
* Automated Commit - Formatting Changes
* Revert "remove ignore" This reverts commit 9ffeeb6cb1.
* update config

---------

Co-authored-by: girarda <girarda@users.noreply.github.com>
Co-authored-by: Joe Bell <joseph.bell@airbyte.io>
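For context, the thrust of the pipeline change is that the build now runs mypy over the CDK with an explicit config file and fails on errors. The actual Gradle wiring is not shown on this page; the sketch below only illustrates the kind of invocation involved, with a hypothetical config path and target directory.

```python
# Minimal sketch of a build step invoking mypy with an explicit config file.
# The config path and target package below are illustrative, not the repo's actual layout.
import subprocess
import sys


def run_mypy(config_file: str = "mypy.ini", target: str = "airbyte_cdk/sources/file_based") -> int:
    # `python -m mypy --config-file ...` is the standard CLI; a non-zero exit code fails the build.
    result = subprocess.run([sys.executable, "-m", "mypy", "--config-file", config_file, target])
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_mypy())
```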
@@ -44,7 +44,7 @@ class CsvFormat(BaseModel):
double_quote: bool = Field(
title="Double Quote", default=True, description="Whether two quotes in a quoted CSV value denote a single quote in the data."
)
quoting_behavior: Optional[QuotingBehavior] = Field(
quoting_behavior: QuotingBehavior = Field(
title="Quoting Behavior",
default=QuotingBehavior.QUOTE_SPECIAL_CHARACTERS,
description="The quoting behavior determines when a value in a row should have quote marks added around it. For example, if Quote Non-numeric is specified, while reading, quotes are expected for row values that do not contain numbers. Or for Quote All, every row value will be expecting quotes.",
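The first hunk drops `Optional` from a field that always carries a default: `quoting_behavior` can never actually be `None`, so the `Optional[QuotingBehavior]` annotation only forced pointless `None` checks downstream. A minimal standalone illustration (the enum values here are placeholders, not the CDK's definitions):

```python
from enum import Enum

from pydantic import BaseModel, Field


class QuotingBehavior(Enum):
    QUOTE_ALL = "Quote All"
    QUOTE_SPECIAL_CHARACTERS = "Quote Special Characters"


class Example(BaseModel):
    # A default means the attribute is always populated, so the narrower
    # non-Optional annotation is both accurate and easier on callers under mypy.
    quoting_behavior: QuotingBehavior = Field(default=QuotingBehavior.QUOTE_SPECIAL_CHARACTERS)


print(Example().quoting_behavior)  # QuotingBehavior.QUOTE_SPECIAL_CHARACTERS
```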
@@ -54,7 +54,7 @@ class CsvFormat(BaseModel):
# of using pyarrow

@validator("delimiter")
def validate_delimiter(cls, v):
def validate_delimiter(cls, v: str) -> str:
if len(v) != 1:
raise ValueError("delimiter should only be one character")
if v in {"\r", "\n"}:
@@ -62,19 +62,19 @@ class CsvFormat(BaseModel):
return v

@validator("quote_char")
def validate_quote_char(cls, v):
def validate_quote_char(cls, v: str) -> str:
if len(v) != 1:
raise ValueError("quote_char should only be one character")
return v

@validator("escape_char")
def validate_escape_char(cls, v):
def validate_escape_char(cls, v: str) -> str:
if len(v) != 1:
raise ValueError("escape_char should only be one character")
return v

@validator("encoding")
def validate_encoding(cls, v):
def validate_encoding(cls, v: str) -> str:
try:
codecs.lookup(v)
except LookupError:
@@ -116,13 +116,13 @@ class FileBasedStreamConfig(BaseModel):
)

@validator("file_type", pre=True)
def validate_file_type(cls, v):
def validate_file_type(cls, v: str) -> str:
if v not in VALID_FILE_TYPES:
raise ValueError(f"Format filetype {v} is not a supported file type")
return v

@validator("format", pre=True)
def transform_format(cls, v):
def transform_format(cls, v: Mapping[str, str]) -> Any:
if isinstance(v, Mapping):
file_type = v.get("filetype", "")
if file_type:
@@ -132,6 +132,8 @@ class FileBasedStreamConfig(BaseModel):
return v

@validator("input_schema", pre=True)
def transform_input_schema(cls, v):
def transform_input_schema(cls, v: Optional[Union[str, Mapping[str, Any]]]) -> Optional[Mapping[str, Any]]:
if v:
return type_mapping_to_jsonschema(v)
else:
return None
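The remaining hunks in this file follow one pattern: every `@validator` method gains parameter and return annotations so that strict mypy settings (no untyped definitions) accept them; the commit's "Allow untyped decorator" and "pydantic plugin" notes relate to the `@validator` decorator itself being untyped from mypy's point of view. A minimal sketch of the pattern, outside the CDK, using pydantic v1-style validators as the diff does:

```python
import codecs

from pydantic import BaseModel, validator


class CsvOptions(BaseModel):
    delimiter: str = ","
    encoding: str = "utf8"

    @validator("delimiter")
    def validate_delimiter(cls, v: str) -> str:
        # Annotating `v` and the return type is what keeps `disallow_untyped_defs` happy.
        if len(v) != 1:
            raise ValueError("delimiter should only be one character")
        return v

    @validator("encoding")
    def validate_encoding(cls, v: str) -> str:
        codecs.lookup(v)  # raises LookupError for unknown encodings
        return v


print(CsvOptions(delimiter=";").delimiter)  # ";"
```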
@@ -13,15 +13,14 @@ from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream


class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
def __init__(self, stream_reader: AbstractFileBasedStreamReader):
self.stream_reader = stream_reader

def check_availability(
self, stream: AbstractFileBasedStream, logger: logging.Logger, _: Optional[Source]
) -> Tuple[bool, Optional[str]]:
def check_availability(self, stream: Stream, logger: logging.Logger, _: Optional[Source]) -> Tuple[bool, Optional[str]]:
"""
Perform a connection check for the stream.

@@ -38,6 +37,8 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):
- If the user provided a schema in the config, check that a subset of records in
one file conform to the schema via a call to stream.conforms_to_schema(schema).
"""
if not isinstance(stream, AbstractFileBasedStream):
raise ValueError(f"Stream {stream.name} is not a file-based stream.")
try:
files = self._check_list_files(stream)
self._check_parse_record(stream, files[0], logger)
@@ -60,7 +61,7 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):

return files

def _check_parse_record(self, stream: AbstractFileBasedStream, file: RemoteFile, logger: logging.Logger):
def _check_parse_record(self, stream: AbstractFileBasedStream, file: RemoteFile, logger: logging.Logger) -> None:
parser = stream.get_parser(stream.config.file_type)

try:
@@ -75,7 +76,7 @@ class DefaultFileBasedAvailabilityStrategy(AvailabilityStrategy):

schema = stream.catalog_schema or stream.config.input_schema
if schema and stream.validation_policy.validate_schema_before_sync:
if not conforms_to_schema(record, schema):
if not conforms_to_schema(record, schema): # type: ignore
raise CheckAvailabilityError(
FileBasedSourceError.ERROR_VALIDATING_RECORD,
stream=stream.name,
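`check_availability` now matches the base `AvailabilityStrategy` signature (`stream: Stream`) instead of the narrower `AbstractFileBasedStream`, and narrows with an `isinstance` check inside the body. Widening an overridden parameter back to the base type is the standard way to satisfy mypy's override (Liskov) checks. A stripped-down sketch of that pattern, with placeholder class names:

```python
from abc import ABC, abstractmethod


class Stream: ...


class FileBasedStream(Stream):
    name = "files"


class AvailabilityStrategy(ABC):
    @abstractmethod
    def check_availability(self, stream: Stream) -> bool: ...


class FileBasedAvailabilityStrategy(AvailabilityStrategy):
    def check_availability(self, stream: Stream) -> bool:
        # Accept the base type so the override stays signature-compatible, then narrow.
        if not isinstance(stream, FileBasedStream):
            raise ValueError(f"Stream {stream} is not a file-based stream.")
        return True  # the real implementation would list files and parse a record


print(FileBasedAvailabilityStrategy().check_availability(FileBasedStream()))  # True
```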
@@ -37,7 +37,7 @@ class FileBasedSourceError(Enum):


class BaseFileBasedSourceError(Exception):
def __init__(self, error: FileBasedSourceError, **kwargs):
def __init__(self, error: FileBasedSourceError, **kwargs): # type: ignore # noqa
super().__init__(
f"{FileBasedSourceError(error).value} Contact Support if you need assistance.\n{' '.join([f'{k}={v}' for k, v in kwargs.items()])}"
)
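The `# type: ignore` on `BaseFileBasedSourceError.__init__` silences mypy's complaint about the untyped `**kwargs`. A common alternative (not what this commit chose) is to annotate the packed keywords, which types each keyword value rather than the whole dict:

```python
from typing import Any


class ExampleError(Exception):
    def __init__(self, message: str, **kwargs: Any) -> None:
        # `**kwargs: Any` means every extra keyword argument is typed as Any,
        # so no suppression comment is needed.
        details = " ".join(f"{k}={v}" for k, v in kwargs.items())
        super().__init__(f"{message} {details}".strip())


err = ExampleError("Something failed", stream="my_stream", file="a.csv")
print(err)
```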
@@ -5,10 +5,9 @@
import logging
import traceback
from abc import ABC
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type
from typing import Any, List, Mapping, Optional, Tuple, Type

from airbyte_cdk.models import ConfiguredAirbyteCatalog
from airbyte_cdk.models.airbyte_protocol import ConnectorSpecification
from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConnectorSpecification
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
@@ -21,6 +20,7 @@ from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeP
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from pydantic.error_wrappers import ValidationError

@@ -35,15 +35,15 @@ class FileBasedSource(AbstractSource, ABC):
availability_strategy: Optional[AvailabilityStrategy],
spec_class: Type[AbstractFileBasedSpec],
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Dict[str, FileTypeParser] = None,
validation_policies: Dict[str, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
parsers: Mapping[str, FileTypeParser] = default_parsers,
validation_policies: Mapping[str, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
max_history_size: int = DEFAULT_MAX_HISTORY_SIZE,
):
self.stream_reader = stream_reader
self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
self.spec_class = spec_class
self.discovery_policy = discovery_policy
self.parsers = parsers or default_parsers
self.parsers = parsers
self.validation_policies = validation_policies
self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
self.max_history_size = max_history_size
@@ -70,6 +70,8 @@ class FileBasedSource(AbstractSource, ABC):

errors = []
for stream in streams:
if not isinstance(stream, AbstractFileBasedStream):
raise ValueError(f"Stream {stream} is not a file-based stream.")
try:
(
stream_is_available,
@@ -78,18 +80,18 @@ class FileBasedSource(AbstractSource, ABC):
except Exception:
errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}")
else:
if not stream_is_available:
if not stream_is_available and reason:
errors.append(reason)

return not bool(errors), (errors or None)

def streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
"""
Return a list of this source's streams.
"""
try:
parsed_config = self.spec_class(**config)
streams = []
streams: List[Stream] = []
for stream_config in parsed_config.streams:
self._validate_input_schema(stream_config)
streams.append(
@@ -119,13 +121,13 @@ class FileBasedSource(AbstractSource, ABC):
connectionSpecification=self.spec_class.schema(),
)

def _validate_and_get_validation_policy(self, stream_config: FileBasedStreamConfig):
def _validate_and_get_validation_policy(self, stream_config: FileBasedStreamConfig) -> AbstractSchemaValidationPolicy:
if stream_config.validation_policy not in self.validation_policies:
raise ValidationError(
f"`validation_policy` must be one of {list(self.validation_policies.keys())}", model=FileBasedStreamConfig
)
return self.validation_policies[stream_config.validation_policy]

def _validate_input_schema(self, stream_config: FileBasedStreamConfig):
def _validate_input_schema(self, stream_config: FileBasedStreamConfig) -> None:
if stream_config.schemaless and stream_config.input_schema:
raise ValidationError("`input_schema` and `schemaless` options cannot both be set", model=FileBasedStreamConfig)
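In `FileBasedSource.__init__`, the parser and validation-policy parameters switch from `Dict[...] = None` to `Mapping[...]` with a real default. `Mapping` advertises that the argument is read-only and its value type is covariant, so callers may pass dicts of subclasses without invariance errors, and defaulting to `default_parsers` directly removes the `= None` default that never matched the declared `Dict` type. A small sketch with placeholder classes:

```python
from typing import Mapping


class Parser: ...


class CsvParser(Parser): ...


# A module-level, effectively read-only default (mirrors `default_parsers`).
DEFAULT_PARSERS: Mapping[str, Parser] = {"csv": CsvParser()}


def build_source(parsers: Mapping[str, Parser] = DEFAULT_PARSERS) -> Mapping[str, Parser]:
    # The old `Dict[str, Parser] = None` signature needed Optional plus a
    # `parsers or default_parsers` dance inside the body; a Mapping default
    # is safe to share because this function never mutates it.
    return parsers


print(build_source({"csv": CsvParser()}))
```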
@@ -1,9 +1,12 @@
from typing import Mapping

from .avro_parser import AvroParser
from .csv_parser import CsvParser
from .file_type_parser import FileTypeParser
from .jsonl_parser import JsonlParser
from .parquet_parser import ParquetParser

default_parsers = {
default_parsers: Mapping[str, FileTypeParser] = {
"avro": AvroParser(),
"csv": CsvParser(),
"jsonl": JsonlParser(),

@@ -19,7 +19,7 @@ class AvroParser(FileTypeParser):
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Dict[str, Any]:
...
raise NotImplementedError()

def parse_records(
self,
@@ -28,4 +28,4 @@ class AvroParser(FileTypeParser):
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
...
raise NotImplementedError()
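Several parser stubs replace a bare `...` body with `raise NotImplementedError()`. With a declared non-None return type, newer mypy releases can flag a non-abstract function whose body is empty (the `empty-body` check), so raising explicitly both satisfies the checker and makes the unimplemented state obvious at runtime. A tiny sketch, with a hypothetical method name:

```python
from typing import Any, Dict


class FileTypeParser:
    def infer_schema(self, file_uri: str) -> Dict[str, Any]:
        # A body of just `...` alongside a non-None return type can trip mypy's
        # empty-body check; raising keeps the stub honest at type-check time and at runtime.
        raise NotImplementedError()
```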
@@ -17,7 +17,7 @@ from airbyte_cdk.sources.file_based.schema_helpers import TYPE_PYTHON_MAPPING

DIALECT_NAME = "_config_dialect"

config_to_quoting: [QuotingBehavior, int] = {
config_to_quoting: Mapping[QuotingBehavior, int] = {
QuotingBehavior.QUOTE_ALL: csv.QUOTE_ALL,
QuotingBehavior.QUOTE_SPECIAL_CHARACTERS: csv.QUOTE_MINIMAL,
QuotingBehavior.QUOTE_NONNUMERIC: csv.QUOTE_NONNUMERIC,
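The old `config_to_quoting: [QuotingBehavior, int]` annotation was not a type at all — it was a list literal in annotation position, which mypy rejects — whereas `Mapping[QuotingBehavior, int]` is the intended dict-like annotation. For comparison, with a placeholder enum:

```python
import csv
from enum import Enum
from typing import Mapping


class QuotingBehavior(Enum):
    QUOTE_ALL = "Quote All"
    QUOTE_SPECIAL_CHARACTERS = "Quote Special Characters"


# Valid: a Mapping annotation with explicit key and value types.
config_to_quoting: Mapping[QuotingBehavior, int] = {
    QuotingBehavior.QUOTE_ALL: csv.QUOTE_ALL,
    QuotingBehavior.QUOTE_SPECIAL_CHARACTERS: csv.QUOTE_MINIMAL,
}
# Invalid (what the old line did): `x: [QuotingBehavior, int] = {...}` annotates
# with a list literal, which is not a type and fails under mypy.
print(config_to_quoting[QuotingBehavior.QUOTE_ALL])
```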
@@ -47,13 +47,13 @@ class CsvParser(FileTypeParser):
with stream_reader.open_file(file) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
reader = csv.DictReader(fp, dialect=dialect_name)
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore
schema = {field.strip(): {"type": "string"} for field in next(reader)}
csv.unregister_dialect(dialect_name)
return schema
else:
with stream_reader.open_file(file) as fp:
reader = csv.DictReader(fp)
reader = csv.DictReader(fp) # type: ignore
return {field.strip(): {"type": "string"} for field in next(reader)}

def parse_records(
@@ -63,7 +63,7 @@ class CsvParser(FileTypeParser):
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
schema = config.input_schema
schema: Mapping[str, Any] = config.input_schema # type: ignore
config_format = config.format.get(config.file_type) if config.format else None
if config_format:
# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
@@ -80,16 +80,16 @@ class CsvParser(FileTypeParser):
with stream_reader.open_file(file) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
reader = csv.DictReader(fp, dialect=dialect_name)
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore
yield from self._read_and_cast_types(reader, schema, logger)
else:
with stream_reader.open_file(file) as fp:
reader = csv.DictReader(fp)
reader = csv.DictReader(fp) # type: ignore
yield from self._read_and_cast_types(reader, schema, logger)

@staticmethod
def _read_and_cast_types(
reader: csv.DictReader, schema: Optional[Mapping[str, str]], logger: logging.Logger
reader: csv.DictReader, schema: Optional[Mapping[str, Any]], logger: logging.Logger # type: ignore
) -> Iterable[Dict[str, Any]]:
"""
If the user provided a schema, attempt to cast the record values to the associated type.
@@ -120,9 +120,9 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg

for key, value in row.items():
prop_type = property_types.get(key)
cast_value = value
cast_value: Any = value

if prop_type in TYPE_PYTHON_MAPPING:
if prop_type in TYPE_PYTHON_MAPPING and prop_type is not None:
_, python_type = TYPE_PYTHON_MAPPING[prop_type]

if python_type is None:
@@ -169,5 +169,5 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
return result


def _format_warning(key: str, value: str, expected_type: str) -> str:
def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str:
return f"{key}: value={value},expected_type={expected_type}"
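Two small patterns recur in `cast_types`: declaring `cast_value: Any` because the variable is later reassigned to values of different types, and adding an explicit `is not None` guard before indexing with `prop_type`, since `dict.get()` returns an Optional. A condensed sketch with a reduced mapping:

```python
from typing import Any, Dict, Mapping, Optional, Tuple, Type

TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = {
    "integer": ("integer", int),
    "string": ("string", str),
}


def cast_value_for(key: str, value: str, property_types: Dict[str, Any]) -> Any:
    prop_type = property_types.get(key)  # Optional: may be None if the key is unknown
    cast_value: Any = value  # reassigned below to int/str/..., so widen to Any up front
    if prop_type in TYPE_PYTHON_MAPPING and prop_type is not None:
        _, python_type = TYPE_PYTHON_MAPPING[prop_type]  # safe: prop_type is narrowed to str
        if python_type is not None:
            cast_value = python_type(value)
    return cast_value


print(cast_value_for("age", "42", {"age": "integer"}))  # 42
```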
@@ -22,7 +22,7 @@ class JsonlParser(FileTypeParser):
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Dict[str, Any]:
...
raise NotImplementedError()

def parse_records(
self,
@@ -31,4 +31,4 @@ class JsonlParser(FileTypeParser):
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
...
raise NotImplementedError()

@@ -21,4 +21,6 @@ class RemoteFile(BaseModel):
extensions = self.uri.split(".")[1:]
if not extensions:
return True
if not self.file_type:
return True
return any(self.file_type.casefold() in e.casefold() for e in extensions)
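The `RemoteFile` change adds an early `return True` when `file_type` is unset, so the later `self.file_type.casefold()` call operates on a value mypy knows is not `None`. The same guard pattern in isolation, as a free function:

```python
from typing import List, Optional


def extension_matches(uri: str, file_type: Optional[str]) -> bool:
    extensions: List[str] = uri.split(".")[1:]
    if not extensions:
        return True
    if not file_type:
        # Early return narrows `file_type` to `str` below, so .casefold() is safe.
        return True
    return any(file_type.casefold() in e.casefold() for e in extensions)


print(extension_matches("data/file.csv", "csv"))  # True
```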
@@ -6,11 +6,11 @@ import json
from copy import deepcopy
from enum import Enum
from functools import total_ordering
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union

from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, SchemaInferenceError

JsonSchemaSupportedType = Union[List, Literal["string"], str]
JsonSchemaSupportedType = Union[List[str], Literal["string"], str]
SchemaType = Dict[str, Dict[str, JsonSchemaSupportedType]]

schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}}
@@ -25,14 +25,14 @@ class ComparableType(Enum):
STRING = 4
OBJECT = 5

def __lt__(self, other):
def __lt__(self, other: Any) -> bool:
if self.__class__ is other.__class__:
return self.value < other.value
return self.value < other.value # type: ignore
else:
return NotImplemented


TYPE_PYTHON_MAPPING = {
TYPE_PYTHON_MAPPING: Mapping[str, Tuple[str, Optional[Type[Any]]]] = {
"null": ("null", None),
"array": ("array", list),
"boolean": ("boolean", bool),
@@ -44,7 +44,7 @@ TYPE_PYTHON_MAPPING = {
}


def get_comparable_type(value: Any) -> ComparableType:
def get_comparable_type(value: Any) -> Optional[ComparableType]:
if value == "null":
return ComparableType.NULL
if value == "boolean":
@@ -57,9 +57,11 @@ def get_comparable_type(value: Any) -> ComparableType:
return ComparableType.STRING
if value == "object":
return ComparableType.OBJECT
else:
return None


def get_inferred_type(value: Any) -> ComparableType:
def get_inferred_type(value: Any) -> Optional[ComparableType]:
if value is None:
return ComparableType.NULL
if isinstance(value, bool):
@@ -72,6 +74,8 @@ def get_inferred_type(value: Any) -> ComparableType:
return ComparableType.STRING
if isinstance(value, dict):
return ComparableType.OBJECT
else:
return None


def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
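`get_comparable_type` and `get_inferred_type` previously fell off the end of the function when nothing matched, returning `None` implicitly while being annotated as returning `ComparableType`. Widening the return type to `Optional[ComparableType]` and spelling out `else: return None` makes the signature truthful and forces callers to handle the `None` case. In miniature, with a reduced enum:

```python
from enum import Enum
from typing import Any, Optional


class ComparableType(Enum):
    NULL = 0
    BOOLEAN = 1
    STRING = 4


def get_inferred_type(value: Any) -> Optional[ComparableType]:
    if value is None:
        return ComparableType.NULL
    if isinstance(value, bool):
        return ComparableType.BOOLEAN
    if isinstance(value, str):
        return ComparableType.STRING
    else:
        # Explicit None keeps the annotation honest; callers must check for it.
        return None


print(get_inferred_type(3.14) is None)  # True: float is not handled in this tiny sketch
```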
@@ -91,10 +95,10 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
and nothing else.
"""
for k, t in list(schema1.items()) + list(schema2.items()):
if not isinstance(t, dict) or not _is_valid_type(t.get("type")):
if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t)

merged_schema = deepcopy(schema1)
merged_schema: Dict[str, Any] = deepcopy(schema1)
for k2, t2 in schema2.items():
t1 = merged_schema.get(k2)
if t1 is None:
@@ -136,7 +140,7 @@ def _choose_wider_type(key: str, t1: Dict[str, Any], t2: Dict[str, Any]) -> Dict
) # accessing the type_mapping value


def is_equal_or_narrower_type(value: Any, expected_type: str):
def is_equal_or_narrower_type(value: Any, expected_type: str) -> bool:
if isinstance(value, list):
# We do not compare lists directly; the individual items are compared.
# If we hit this condition, it means that the expected type is not
@@ -151,7 +155,7 @@ def is_equal_or_narrower_type(value: Any, expected_type: str):
return ComparableType(inferred_type) <= ComparableType(get_comparable_type(expected_type))


def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, str]) -> bool:
def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
"""
Return true iff the record conforms to the supplied schema.

@@ -185,9 +189,12 @@ def conforms_to_schema(record: Mapping[str, Any], schema: Mapping[str, str]) ->
return True


def _parse_json_input(input_schema: Optional[Union[str, Dict[str, str]]]) -> Optional[Mapping[str, str]]:
def _parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]:
try:
schema = json.loads(input_schema)
if isinstance(input_schema, str):
schema: Mapping[str, str] = json.loads(input_schema)
else:
schema = input_schema
if not all(isinstance(s, str) for s in schema.values()):
raise ConfigValidationError(
FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA, details="Invalid input schema; nested schemas are not supported."
@@ -199,7 +206,7 @@ def _parse_json_input(input_schema: Optional[Union[str, Dict[str, str]]]) -> Opt
return schema


def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, str]]]) -> Optional[Mapping[str, str]]:
def type_mapping_to_jsonschema(input_schema: Optional[Union[str, Mapping[str, str]]]) -> Optional[Mapping[str, Any]]:
"""
Return the user input schema (type mapping), transformed to JSON Schema format.
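`_parse_json_input` now accepts `Union[str, Mapping[str, str]]` and branches with `isinstance`, so each branch hands mypy a concrete type instead of calling `json.loads` on something that might already be a mapping. Roughly, as a simplified standalone function (returning `None` where the real code raises a config error):

```python
import json
from typing import Mapping, Optional, Union


def parse_json_input(input_schema: Union[str, Mapping[str, str]]) -> Optional[Mapping[str, str]]:
    try:
        if isinstance(input_schema, str):
            # In this branch mypy knows input_schema is a str, so json.loads is fine.
            schema: Mapping[str, str] = json.loads(input_schema)
        else:
            schema = input_schema
        if not all(isinstance(s, str) for s in schema.values()):
            return None  # the real code raises a ConfigValidationError here
        return schema
    except json.JSONDecodeError:
        return None


print(parse_json_input('{"name": "string"}'))
```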
@@ -3,7 +3,7 @@
#

from abc import ABC, abstractmethod
from typing import Any, Mapping
from typing import Any, Mapping, Optional


class AbstractSchemaValidationPolicy(ABC):
@@ -11,8 +11,8 @@ class AbstractSchemaValidationPolicy(ABC):
validate_schema_before_sync = False # Whether to verify that records conform to the schema during the stream's availabilty check

@abstractmethod
def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
"""
Return True if the record passes the user's validation policy.
"""
...
raise NotImplementedError()

@@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Mapping
from typing import Any, Mapping, Optional

from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, StopSyncPerValidationPolicy
from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema
@@ -12,23 +12,23 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSc
class EmitRecordPolicy(AbstractSchemaValidationPolicy):
name = "emit_record"

def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
return True


class SkipRecordPolicy(AbstractSchemaValidationPolicy):
name = "skip_record"

def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
return conforms_to_schema(record, schema)
def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
return schema is not None and conforms_to_schema(record, schema)


class WaitForDiscoverPolicy(AbstractSchemaValidationPolicy):
name = "wait_for_discover"
validate_schema_before_sync = True

def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Mapping[str, Any]) -> bool:
if not conforms_to_schema(record, schema):
def record_passes_validation_policy(self, record: Mapping[str, Any], schema: Optional[Mapping[str, Any]]) -> bool:
if schema is None or not conforms_to_schema(record, schema):
raise StopSyncPerValidationPolicy(FileBasedSourceError.STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY)
return True


@@ -3,7 +3,7 @@
#

from abc import abstractmethod
from functools import cached_property
from functools import cached_property, lru_cache
from typing import Any, Dict, Iterable, List, Mapping, Optional

from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode
@@ -14,7 +14,7 @@ from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFile
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_validation_policies import AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.types import StreamSlice, StreamState
from airbyte_cdk.sources.file_based.types import StreamSlice
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy

@@ -69,15 +69,17 @@ class AbstractFileBasedStream(Stream):
def read_records(
self,
sync_mode: SyncMode,
cursor_field: List[str] = None,
cursor_field: Optional[List[str]] = None,
stream_slice: Optional[StreamSlice] = None,
stream_state: Optional[StreamState] = None,
stream_state: Optional[Mapping[str, Any]] = None,
) -> Iterable[Mapping[str, Any]]:
"""
Yield all records from all remote files in `list_files_for_this_sync`.
This method acts as an adapter between the generic Stream interface and the file-based's
stream since file-based streams manage their own states.
"""
if stream_slice is None:
raise ValueError("stream_slice must be set")
return self.read_records_from_slice(stream_slice)

@abstractmethod
@@ -88,7 +90,7 @@ class AbstractFileBasedStream(Stream):
...

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: StreamState = None
self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
"""
This method acts as an adapter between the generic Stream interface and the file-based's
@@ -105,6 +107,7 @@ class AbstractFileBasedStream(Stream):
...

@abstractmethod
@lru_cache(maxsize=None)
def get_json_schema(self) -> Mapping[str, Any]:
"""
Return the JSON Schema for a stream.
@@ -133,7 +136,7 @@ class AbstractFileBasedStream(Stream):
)

@cached_property
def availability_strategy(self):
def availability_strategy(self) -> AvailabilityStrategy:
return self._availability_strategy

@property
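`read_records` and `stream_slices` replace defaults like `cursor_field: List[str] = None` with `Optional[List[str]] = None`. Recent mypy versions enable `no_implicit_optional` by default, so a `None` default no longer implies an Optional parameter; the annotation has to say so explicitly. The hunk also guards `stream_slice` before use, narrowing it from `Optional[...]`. A condensed sketch of both fixes as a free function:

```python
from typing import Any, Iterable, List, Mapping, Optional

StreamSlice = Mapping[str, Any]


def read_records(
    cursor_field: Optional[List[str]] = None,  # was `List[str] = None` (implicit Optional)
    stream_slice: Optional[StreamSlice] = None,
    stream_state: Optional[Mapping[str, Any]] = None,
) -> Iterable[Mapping[str, Any]]:
    if stream_slice is None:
        # Narrow the Optional before handing it to code that requires a real slice.
        raise ValueError("stream_slice must be set")
    return [dict(stream_slice)]


print(list(read_records(stream_slice={"files": []})))
```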
@@ -4,7 +4,7 @@

import logging
from datetime import datetime, timedelta
from typing import Iterable, Mapping, Optional
from typing import Iterable, MutableMapping, Optional

from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.cursor.file_based_cursor import FileBasedCursor
@@ -16,7 +16,7 @@ class DefaultFileBasedCursor(FileBasedCursor):
DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

def __init__(self, max_history_size: int, days_to_sync_if_history_is_full: Optional[int]):
self._file_to_datetime_history: Mapping[str:datetime] = {}
self._file_to_datetime_history: MutableMapping[str, str] = {}
self._max_history_size = max_history_size
self._time_window_if_history_is_full = timedelta(
days=days_to_sync_if_history_is_full or self.DEFAULT_DAYS_TO_SYNC_IF_HISTORY_IS_FULL
@@ -62,7 +62,7 @@ class DefaultFileBasedCursor(FileBasedCursor):
def _should_sync_file(self, file: RemoteFile, logger: logging.Logger) -> bool:
if file.uri in self._file_to_datetime_history:
# If the file's uri is in the history, we should sync the file if it has been modified since it was synced
updated_at_from_history = datetime.strptime(self._file_to_datetime_history.get(file.uri), self.DATE_TIME_FORMAT)
updated_at_from_history = datetime.strptime(self._file_to_datetime_history[file.uri], self.DATE_TIME_FORMAT)
if file.last_modified < updated_at_from_history:
logger.warning(
f"The file {file.uri}'s last modified date is older than the last time it was synced. This is unexpected. Skipping the file."
@@ -71,6 +71,8 @@ class DefaultFileBasedCursor(FileBasedCursor):
return file.last_modified > updated_at_from_history
return file.last_modified > updated_at_from_history
if self._is_history_full():
if self._initial_earliest_file_in_history is None:
return True
if file.last_modified > self._initial_earliest_file_in_history.last_modified:
# If the history is partial and the file's datetime is strictly greater than the earliest file in the history,
# we should sync it

@@ -5,7 +5,7 @@
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Iterable, Mapping
from typing import Any, Iterable, MutableMapping

from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.types import StreamState
@@ -32,7 +32,7 @@ class FileBasedCursor(ABC):
"""

@abstractmethod
def get_state(self) -> Mapping[str, Any]:
def get_state(self) -> MutableMapping[str, Any]:
"""
Get the state of the cursor.
"""
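Two fixes land in the cursor: the history annotation `Mapping[str:datetime]` was a slice expression rather than a type (and the mapping is mutated, so `MutableMapping[str, str]` is the honest spelling), and `history.get(uri)` returns `Optional[str]`, which `datetime.strptime` rejects; indexing after the `in` check returns a plain `str`. Condensed, with an illustrative URI and timestamp:

```python
from datetime import datetime
from typing import MutableMapping

DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

# `Mapping[str:datetime]` is a slice, not a type; this is the valid spelling,
# and MutableMapping reflects that entries are added as files are synced.
file_to_datetime_history: MutableMapping[str, str] = {}
file_to_datetime_history["s3://bucket/a.csv"] = "2023-06-01T00:00:00.000000Z"

uri = "s3://bucket/a.csv"
if uri in file_to_datetime_history:
    # Indexing (not .get()) returns str, so strptime's signature is satisfied.
    updated_at = datetime.strptime(file_to_datetime_history[uri], DATE_TIME_FORMAT)
    print(updated_at.year)  # 2023
```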
@@ -6,9 +6,11 @@ import asyncio
import itertools
import traceback
from functools import cache
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Union
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
from airbyte_cdk.sources.file_based.exceptions import (
FileBasedSourceError,
InvalidSchemaError,
@@ -36,7 +38,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
ab_file_name_col = "_ab_source_file_url"
airbyte_columns = [ab_last_mod_col, ab_file_name_col]

def __init__(self, cursor: FileBasedCursor, **kwargs):
def __init__(self, cursor: FileBasedCursor, **kwargs: Any):
super().__init__(**kwargs)
self._cursor = cursor

@@ -45,12 +47,12 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
return self._cursor.get_state()

@state.setter
def state(self, value: MutableMapping[str, Any]):
def state(self, value: MutableMapping[str, Any]) -> None:
"""State setter, accept state serialized by state getter."""
self._cursor.set_initial_state(value)

@property
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
def primary_key(self) -> PrimaryKeyType:
return self.config.primary_key

def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
@@ -93,7 +95,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):

except StopSyncPerValidationPolicy:
yield AirbyteMessage(
type=Type.LOG,
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.WARN,
message=f"Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema. stream={self.name} file={file.uri} validation_policy={self.config.validation_policy} n_skipped={n_skipped}",
@@ -103,7 +105,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):

except Exception:
yield AirbyteMessage(
type=Type.LOG,
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.ERROR,
message=f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream={self.name} file={file.uri} line_no={line_no} n_skipped={n_skipped}",
@@ -115,7 +117,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
else:
if n_skipped:
yield AirbyteMessage(
type=Type.LOG,
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.WARN,
message=f"Records in file did not pass validation policy. stream={self.name} file={file.uri} n_skipped={n_skipped} validation_policy={self.validation_policy.name}",
@@ -141,12 +143,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
except Exception as exc:
raise SchemaInferenceError(FileBasedSourceError.SCHEMA_INFERENCE_ERROR, stream=self.name) from exc
else:
schema["properties"] = {**extra_fields, **schema["properties"]}
return schema
return {"type": "object", "properties": {**extra_fields, **schema["properties"]}}

def _get_raw_json_schema(self) -> JsonSchema:
if self.config.input_schema:
return self.config.input_schema
return self.config.input_schema # type: ignore
elif self.config.schemaless:
return schemaless_schema
else:
@@ -180,7 +181,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
The output of this method is cached so we don't need to list the files more than once.
This means we won't pick up changes to the files during a sync.
"""
return list(self._stream_reader.get_matching_files(self.config.globs))
return list(self._stream_reader.get_matching_files(self.config.globs or []))

def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
loop = asyncio.get_event_loop()
@@ -193,13 +194,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
Each file type has a corresponding `infer_schema` handler.
Dispatch on file type.
"""
base_schema: Dict[str, str] = {}
pending_tasks = set()
base_schema: Dict[str, Any] = {}
pending_tasks: Set[asyncio.tasks.Task[Dict[str, Any]]] = set()

n_started, n_files = 0, len(files)
files = iter(files)
files_iterator = iter(files)
while pending_tasks or n_started < n_files:
while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and (file := next(files, None)):
while len(pending_tasks) <= self._discovery_policy.n_concurrent_requests and (file := next(files_iterator, None)):
pending_tasks.add(asyncio.create_task(self._infer_file_schema(file)))
n_started += 1
# Return when the first task is completed so that we can enqueue a new task as soon as the
@@ -210,7 +211,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):

return base_schema

async def _infer_file_schema(self, file: RemoteFile) -> Mapping[str, Any]:
async def _infer_file_schema(self, file: RemoteFile) -> Dict[str, Any]:
try:
return await self.get_parser(self.config.file_type).infer_schema(self.config, file, self._stream_reader, self.logger)
except Exception as exc:
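The stream module imports the protocol's `Type` enum as `MessageType` so it no longer collides with `typing.Type`, and the pending-task set gets an explicit `Set[asyncio.Task[...]]` annotation so mypy knows what each completed task returns. A self-contained sketch of both ideas, with a stand-in enum in place of the real `airbyte_cdk.models.Type`:

```python
import asyncio
from enum import Enum
from typing import Any, Dict, Set


class Type(Enum):  # stands in for airbyte_cdk.models.Type in this sketch
    LOG = "LOG"
    RECORD = "RECORD"


# Alias at import time so the name no longer shadows typing.Type elsewhere,
# mirroring: from airbyte_cdk.models import Type as MessageType
MessageType = Type


async def gather_schemas() -> Dict[str, Any]:
    # Annotating the set tells mypy the result type of each completed task.
    pending_tasks: Set["asyncio.Task[Dict[str, Any]]"] = set()

    async def infer() -> Dict[str, Any]:
        return {"col": {"type": "string"}}

    pending_tasks.add(asyncio.create_task(infer()))
    done, _ = await asyncio.wait(pending_tasks)
    return {k: v for task in done for k, v in task.result().items()}


print(MessageType.LOG)
print(asyncio.run(gather_schemas()))
```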
@@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Any, Mapping
from typing import Any, Mapping, MutableMapping

StreamSlice = Mapping[str, Any]
StreamState = Mapping[str, Any]
StreamState = MutableMapping[str, Any]