
[ISSUE #6548] make all fields nullable except from pk and cursor field (#36201)

This commit is contained in:
Maxime Carbonneau-Leclerc
2024-03-20 09:48:38 -04:00
committed by GitHub
parent a38fdacdf8
commit 2f34f084e4
7 changed files with 440 additions and 84 deletions

View File

@@ -24,7 +24,7 @@ from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
from airbyte_protocol.models.airbyte_protocol import (
AirbyteControlMessage,
AirbyteLogMessage,
@@ -45,6 +45,32 @@ class MessageGrouper:
self._max_slices = max_slices
self._max_record_limit = max_record_limit
def _pk_to_nested_and_composite_field(self, field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]:
if not field:
return [[]]
if isinstance(field, str):
return [[field]]
is_composite_key = isinstance(field[0], str)
if is_composite_key:
return [[i] for i in field] # type: ignore # the type of field is expected to be List[str] here
return field # type: ignore # the type of field is expected to be List[List[str]] here
def _cursor_field_to_nested_and_composite_field(self, field: Union[str, List[str]]) -> List[List[str]]:
if not field:
return [[]]
if isinstance(field, str):
return [[field]]
is_nested_key = isinstance(field[0], str)
if is_nested_key:
return [field] # type: ignore # the type of field is expected to be List[str] here
raise ValueError(f"Unknown type for cursor field `{field}")
def get_message_groups(
self,
source: DeclarativeSource,
@@ -54,7 +80,11 @@ class MessageGrouper:
) -> StreamRead:
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
raise ValueError(f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}")
schema_inferrer = SchemaInferrer()
stream = source.streams(config)[0] # The connector builder currently only supports reading from a single stream at a time
schema_inferrer = SchemaInferrer(
self._pk_to_nested_and_composite_field(stream.primary_key),
self._cursor_field_to_nested_and_composite_field(stream.cursor_field),
)
datetime_format_inferrer = DatetimeFormatInferrer()
if record_limit is None:
@@ -88,14 +118,20 @@ class MessageGrouper:
else:
raise ValueError(f"Unknown message group type: {type(message_group)}")
try:
configured_stream = configured_catalog.streams[0] # The connector builder currently only supports reading from a single stream at a time
schema = schema_inferrer.get_stream_schema(configured_stream.stream.name)
except SchemaValidationException as exception:
for validation_error in exception.validation_errors:
log_messages.append(LogMessage(validation_error, "ERROR"))
schema = exception.schema
return StreamRead(
logs=log_messages,
slices=slices,
test_read_limit_reached=self._has_reached_limit(slices),
auxiliary_requests=auxiliary_requests,
inferred_schema=schema_inferrer.get_stream_schema(
configured_catalog.streams[0].stream.name
), # The connector builder currently only supports reading from a single stream at a time
inferred_schema=schema,
latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None,
inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
)
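The try/except above makes schema validation failures non-fatal: each validation error is surfaced as an ERROR log message, and the partially valid schema carried by the exception is still returned to the caller. A hedged sketch of the same pattern in isolation, assuming an environment where airbyte_cdk is installed and records have already been accumulated for the stream:

from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException

inferrer = SchemaInferrer(pk=[["field_that_records_do_not_have"]])
# ... accumulate records for "my_stream" here ...
try:
    schema = inferrer.get_stream_schema("my_stream")
except SchemaValidationException as exception:
    for validation_error in exception.validation_errors:
        print(f"ERROR: {validation_error}")  # the message grouper emits these as LogMessages
    schema = exception.schema  # best-effort schema, minus the invalid required markers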

View File

@@ -1,29 +1,62 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from typing import Any, Dict, List
from typing import List, Union, overload
from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_protocol.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, SyncMode
class ConfiguredAirbyteStreamBuilder:
def __init__(self) -> None:
self._stream = {
"stream": {
"name": "any name",
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": "full_refresh",
"destination_sync_mode": "overwrite",
}
def with_name(self, name: str) -> "ConfiguredAirbyteStreamBuilder":
self._stream["stream"]["name"] = name # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self
def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder":
self._stream["sync_mode"] = sync_mode.name
return self
def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder":
self._stream["primary_key"] = pk
self._stream["stream"]["source_defined_primary_key"] = pk # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self
def build(self) -> ConfiguredAirbyteStream:
return ConfiguredAirbyteStream.parse_obj(self._stream)
class CatalogBuilder:
def __init__(self) -> None:
self._streams: List[Dict[str, Any]] = []
self._streams: List[ConfiguredAirbyteStreamBuilder] = []
@overload
def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder":
...
@overload
def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder":
self._streams.append(
{
"stream": {
"name": name,
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": sync_mode.name,
"destination_sync_mode": "overwrite",
}
)
...
def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mode: Union[SyncMode, None] = None) -> "CatalogBuilder":
        # As we are introducing a fully fledged ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface
        # with_stream(str, SyncMode).
        # To avoid a breaking change, `name` needs to stay in the API, but it can be either a name or a builder.
name_or_builder = name
builder = name_or_builder if isinstance(name_or_builder, ConfiguredAirbyteStreamBuilder) else ConfiguredAirbyteStreamBuilder().with_name(name_or_builder).with_sync_mode(sync_mode)
self._streams.append(builder)
return self
def build(self) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog.parse_obj({"streams": self._streams})
return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), self._streams)))
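Both call styles the overloads above accept, side by side; the stream name and key here are hypothetical and not part of this commit:

from airbyte_protocol.models import SyncMode

# Deprecated style, kept to avoid a breaking change:
catalog = CatalogBuilder().with_stream("hashiras", SyncMode.full_refresh).build()

# New style, via the fully fledged stream builder:
catalog = (
    CatalogBuilder()
    .with_stream(
        ConfiguredAirbyteStreamBuilder()
        .with_name("hashiras")
        .with_sync_mode(SyncMode.incremental)
        .with_primary_key([["id"]])
    )
    .build()
)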

View File

@@ -3,18 +3,23 @@
#
from collections import defaultdict
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional
from airbyte_cdk.models import AirbyteRecordMessage
from genson import SchemaBuilder, SchemaNode
from genson.schema.strategies.object import Object
from genson.schema.strategies.scalar import Number
_NULL_TYPE = "null"
class NoRequiredObj(Object):
"""
This class has Object behaviour, but it does not generate "required[]" fields
every time it parses object. So we dont add unnecessary extra field.
every time it parses object. So we don't add unnecessary extra field.
    The logic is that even after reading all the data from a source, a future record could still omit any of these fields. Hence, we
    make everything nullable.
"""
def to_schema(self) -> Mapping[str, Any]:
@@ -41,6 +46,25 @@ class NoRequiredSchemaBuilder(SchemaBuilder):
InferredSchema = Dict[str, Any]
class SchemaValidationException(Exception):
@classmethod
def merge_exceptions(cls, exceptions: List["SchemaValidationException"]) -> "SchemaValidationException":
        # We assume the schema is the same across all exceptions
return SchemaValidationException(exceptions[0].schema, [x for exception in exceptions for x in exception._validation_errors])
def __init__(self, schema: InferredSchema, validation_errors: List[Exception]):
self._schema = schema
self._validation_errors = validation_errors
@property
def schema(self) -> InferredSchema:
return self._schema
@property
def validation_errors(self) -> List[str]:
return list(map(lambda error: str(error), self._validation_errors))
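A small sketch of how the exception composes, with hypothetical errors and schema:

schema = {"type": "object", "properties": {}}
first = SchemaValidationException(schema, [ValueError("pk `id` not found")])
second = SchemaValidationException(schema, [ValueError("cursor `date` not found")])

merged = SchemaValidationException.merge_exceptions([first, second])
assert merged.schema == schema  # the schema is taken from the first exception
assert merged.validation_errors == ["pk `id` not found", "cursor `date` not found"]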
class SchemaInferrer:
"""
This class is used to infer a JSON schema which fits all the records passed into it
@@ -53,23 +77,15 @@ class SchemaInferrer:
stream_to_builder: Dict[str, SchemaBuilder]
def __init__(self) -> None:
def __init__(self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None) -> None:
self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder)
self._pk = [] if pk is None else pk
self._cursor_field = [] if cursor_field is None else cursor_field
def accumulate(self, record: AirbyteRecordMessage) -> None:
"""Uses the input record to add to the inferred schemas maintained by this object"""
self.stream_to_builder[record.stream].add_object(record.data)
def get_inferred_schemas(self) -> Dict[str, InferredSchema]:
"""
Returns the JSON schemas for all encountered streams inferred by inspecting all records
passed via the accumulate method
"""
schemas = {}
for stream_name, builder in self.stream_to_builder.items():
schemas[stream_name] = self._clean(builder.to_schema())
return schemas
def _clean(self, node: InferredSchema) -> InferredSchema:
"""
Recursively cleans up a produced schema:
@@ -78,23 +94,119 @@ class SchemaInferrer:
"""
if isinstance(node, dict):
if "anyOf" in node:
if len(node["anyOf"]) == 2 and {"type": "null"} in node["anyOf"]:
real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == "null" else node["anyOf"][0]
if len(node["anyOf"]) == 2 and {"type": _NULL_TYPE} in node["anyOf"]:
real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == _NULL_TYPE else node["anyOf"][0]
node.update(real_type)
node["type"] = [node["type"], "null"]
node["type"] = [node["type"], _NULL_TYPE]
node.pop("anyOf")
if "properties" in node and isinstance(node["properties"], dict):
for key, value in list(node["properties"].items()):
if isinstance(value, dict) and value.get("type", None) == "null":
if isinstance(value, dict) and value.get("type", None) == _NULL_TYPE:
node["properties"].pop(key)
else:
self._clean(value)
if "items" in node:
self._clean(node["items"])
# this check needs to follow the "anyOf" cleaning as it might populate `type`
if isinstance(node["type"], list):
if _NULL_TYPE in node["type"]:
# we want to make sure null is always at the end as it makes schemas more readable
node["type"].remove(_NULL_TYPE)
node["type"].append(_NULL_TYPE)
else:
node["type"] = [node["type"], _NULL_TYPE]
return node
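The net effect of _clean is that every inferred node becomes nullable and "null" always sits last in the type list. A before/after illustration with hypothetical data, not part of this commit:

# What genson might produce for a field observed as both null and string:
before = {"anyOf": [{"type": "null"}, {"type": "string"}]}

# What _clean yields for the same node: the real type plus "null", null last:
after = {"type": ["string", "null"]}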
def _add_required_properties(self, node: InferredSchema) -> InferredSchema:
"""
        This method takes the properties that should be marked as required (self._pk and self._cursor_field) and traverses the schema
        to mark every node along those paths as required.
"""
        # Remove nullability from the root: `_clean` makes everything nullable, but the root must stay an object
node["type"] = "object"
exceptions = []
for field in [x for x in [self._pk, self._cursor_field] if x]:
try:
self._add_fields_as_required(node, field)
except SchemaValidationException as exception:
exceptions.append(exception)
if exceptions:
raise SchemaValidationException.merge_exceptions(exceptions)
return node
def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None:
"""
        Take a list of nested key paths (the list represents a composite key) and traverse the schema to mark every node as required.
"""
errors: List[Exception] = []
for path in composite_key:
try:
self._add_field_as_required(node, path)
except ValueError as exception:
errors.append(exception)
if errors:
raise SchemaValidationException(node, errors)
def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None) -> None:
"""
        Take one nested key path and traverse the schema, marking every node along it as required.
"""
self._remove_null_from_type(node)
if self._is_leaf(path):
return
if not traveled_path:
traveled_path = []
if "properties" not in node:
# This validation is only relevant when `traveled_path` is empty
raise ValueError(
f"Path {traveled_path} does not refer to an object but is `{node}` and hence {path} can't be marked as required."
)
next_node = path[0]
if next_node not in node["properties"]:
raise ValueError(f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required.")
if "type" not in node:
# We do not expect this case to happen but we added a specific error message just in case
raise ValueError(
f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken"
)
if node["type"] not in ["object", ["null", "object"], ["object", "null"]]:
raise ValueError(f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`")
if "required" not in node or not node["required"]:
node["required"] = [next_node]
elif next_node not in node["required"]:
node["required"].append(next_node)
traveled_path.append(next_node)
self._add_field_as_required(node["properties"][next_node], path[1:], traveled_path)
def _is_leaf(self, path: List[str]) -> bool:
return len(path) == 0
def _remove_null_from_type(self, node: InferredSchema) -> None:
if isinstance(node["type"], list):
if "null" in node["type"]:
node["type"].remove("null")
if len(node["type"]) == 1:
node["type"] = node["type"][0]
def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]:
"""
Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.
"""
return self._clean(self.stream_to_builder[stream_name].to_schema()) if stream_name in self.stream_to_builder else None
return (
self._add_required_properties(self._clean(self.stream_to_builder[stream_name].to_schema()))
if stream_name in self.stream_to_builder
else None
)
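Putting it together: the inferrer now marks pk and cursor paths as required and strips the null type from them, while everything else stays nullable. A minimal end-to-end sketch with hypothetical record data, assuming an airbyte-cdk environment:

from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer

inferrer = SchemaInferrer(pk=[["id"]], cursor_field=[["date"]])
inferrer.accumulate(
    AirbyteRecordMessage(stream="hashiras", data={"id": "Shinobu Kocho", "date": "2023-03-03"}, emitted_at=0)
)

schema = inferrer.get_stream_schema("hashiras")
assert schema["required"] == ["id", "date"]            # pk first, then cursor field
assert schema["properties"]["id"]["type"] == "string"  # required fields lose the null type
assert schema["properties"]["date"]["type"] == "string"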

View File

@@ -499,7 +499,19 @@ def test_config_update():
@patch("traceback.TracebackException.from_exception")
def test_read_returns_error_response(mock_from_exception):
class MockDeclarativeStream:
@property
def primary_key(self):
return [[]]
@property
def cursor_field(self):
return []
class MockManifestDeclarativeSource:
def streams(self, config):
return [MockDeclarativeStream()]
def read(self, logger, config, catalog, state):
raise ValueError("error_message")

View File

@@ -21,6 +21,9 @@ from airbyte_cdk.models import (
from airbyte_cdk.models import Type as MessageType
from unit_tests.connector_builder.utils import create_configured_catalog
_NO_PK = [[]]
_NO_CURSOR_FIELD = []
MAX_PAGES_PER_SLICE = 4
MAX_SLICES = 3
@@ -96,7 +99,7 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None:
response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}}
expected_schema = {
"$schema": "http://json-schema.org/schema#",
"properties": {"name": {"type": "string"}, "date": {"type": "string"}},
"properties": {"name": {"type": ["string", "null"]}, "date": {"type": ["string", "null"]}},
"type": "object",
}
expected_datetime_fields = {"date": "%Y-%m-%d"}
@@ -537,6 +540,7 @@ def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit
def test_read_stream_returns_error_if_stream_does_not_exist() -> None:
mock_source = MagicMock()
mock_source.read.side_effect = ValueError("error")
mock_source.streams.return_value = [make_mock_stream()]
full_config: Mapping[str, Any] = {**CONFIG, **{"__injected_declarative_manifest": MANIFEST}}
@@ -636,12 +640,58 @@ def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) ->
assert len(stream_read.slices) == 0
@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read")
def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None:
mock_source = make_mock_source(mock_entrypoint_read, iter([
request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"),
record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}),
record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}),
]))
mock_source.streams.return_value = [Mock()]
mock_source.streams.return_value[0].primary_key = [["id"]]
mock_source.streams.return_value[0].cursor_field = _NO_CURSOR_FIELD
connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)
stream_read: StreamRead = connector_builder_handler.get_message_groups(
source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras")
)
assert stream_read.inferred_schema["required"] == ["id"]
@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read")
def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None:
mock_source = make_mock_source(mock_entrypoint_read, iter([
request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"),
record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}),
record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}),
]))
mock_source.streams.return_value = [Mock()]
mock_source.streams.return_value[0].primary_key = _NO_PK
mock_source.streams.return_value[0].cursor_field = ["date"]
connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)
stream_read: StreamRead = connector_builder_handler.get_message_groups(
source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras")
)
assert stream_read.inferred_schema["required"] == ["date"]
def make_mock_source(mock_entrypoint_read: Mock, return_value: Iterator[AirbyteMessage]) -> MagicMock:
mock_source = MagicMock()
mock_entrypoint_read.return_value = return_value
mock_source.streams.return_value = [make_mock_stream()]
return mock_source
def make_mock_stream():
mock_stream = MagicMock()
mock_stream.primary_key = []
mock_stream.cursor_field = []
return mock_stream
def request_log_message(request: Mapping[str, Any]) -> AirbyteMessage:
return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}"))

View File

@@ -6,7 +6,7 @@ from typing import List, Mapping
import pytest
from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
NOW = 1234567
@@ -19,7 +19,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": "abc"}},
{"stream": "my_stream", "data": {"field_A": "def"}},
],
{"my_stream": {"field_A": {"type": "string"}}},
{"my_stream": {"field_A": {"type": ["string", "null"]}}},
id="test_basic",
),
pytest.param(
@@ -27,7 +27,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": 1.0}},
{"stream": "my_stream", "data": {"field_A": "abc"}},
],
{"my_stream": {"field_A": {"type": ["number", "string"]}}},
{"my_stream": {"field_A": {"type": ["number", "string", "null"]}}},
id="test_deriving_schema_refine",
),
pytest.param(
@@ -38,10 +38,10 @@ NOW = 1234567
{
"my_stream": {
"obj": {
"type": "object",
"type": ["object", "null"],
"properties": {
"data": {"type": "array", "items": {"type": "number"}},
"other_key": {"type": "string"},
"data": {"type": ["array", "null"], "items": {"type": ["number", "null"]}},
"other_key": {"type": ["string", "null"]},
},
}
}
@@ -53,7 +53,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": 1}},
{"stream": "my_stream", "data": {"field_A": 2}},
],
{"my_stream": {"field_A": {"type": "number"}}},
{"my_stream": {"field_A": {"type": ["number", "null"]}}},
id="test_integer_number",
),
pytest.param(
@@ -68,7 +68,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": None}},
{"stream": "my_stream", "data": {"field_A": "abc"}},
],
{"my_stream": {"field_A": {"type": ["null", "string"]}}},
{"my_stream": {"field_A": {"type": ["string", "null"]}}},
id="test_null_optional",
),
pytest.param(
@@ -76,7 +76,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": None}},
{"stream": "my_stream", "data": {"field_A": {"nested": "abc"}}},
],
{"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": "string"}}}}},
{"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}},
id="test_any_of",
),
pytest.param(
@@ -84,7 +84,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": None}},
{"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": None}}},
],
{"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": "string"}}}}},
{"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}},
id="test_any_of_with_null",
),
pytest.param(
@@ -97,7 +97,7 @@ NOW = 1234567
"my_stream": {
"field_A": {
"type": ["object", "null"],
"properties": {"nested": {"type": "string"}, "nully": {"type": ["null", "string"]}},
"properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}},
}
}
},
@@ -113,7 +113,7 @@ NOW = 1234567
"my_stream": {
"field_A": {
"type": ["object", "null"],
"properties": {"nested": {"type": "string"}, "nully": {"type": ["null", "string"]}},
"properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}},
}
}
},
@@ -123,7 +123,7 @@ NOW = 1234567
[
{"stream": "my_stream", "data": {"field_A": "abc", "nested": {"field_B": None}}},
],
{"my_stream": {"field_A": {"type": "string"}, "nested": {"type": "object", "properties": {}}}},
{"my_stream": {"field_A": {"type": ["string", "null"]}, "nested": {"type": ["object", "null"], "properties": {}}}},
id="test_nested_null",
),
pytest.param(
@@ -132,8 +132,8 @@ NOW = 1234567
],
{
"my_stream": {
"field_A": {"type": "string"},
"nested": {"type": "array", "items": {"type": "object", "properties": {"field_C": {"type": "string"}}}},
"field_A": {"type": ["string", "null"]},
"nested": {"type": ["array", "null"], "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}},
}
},
id="test_array_nested_null",
@@ -145,8 +145,8 @@ NOW = 1234567
],
{
"my_stream": {
"field_A": {"type": "string"},
"nested": {"type": ["array", "null"], "items": {"type": "object", "properties": {"field_C": {"type": "string"}}}},
"field_A": {"type": ["string", "null"]},
"nested": {"type": ["array", "null"], "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}},
}
},
id="test_array_top_level_null",
@@ -156,7 +156,7 @@ NOW = 1234567
{"stream": "my_stream", "data": {"field_A": None}},
{"stream": "my_stream", "data": {"field_A": "abc"}},
],
{"my_stream": {"field_A": {"type": ["null", "string"]}}},
{"my_stream": {"field_A": {"type": ["string", "null"]}}},
id="test_null_string",
),
],
@@ -167,36 +167,127 @@ def test_schema_derivation(input_records: List, expected_schemas: Mapping):
inferrer.accumulate(AirbyteRecordMessage(stream=record["stream"], data=record["data"], emitted_at=NOW))
for stream_name, expected_schema in expected_schemas.items():
assert inferrer.get_inferred_schemas()[stream_name] == {
assert inferrer.get_stream_schema(stream_name) == {
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": expected_schema,
}
def test_deriving_schema_multiple_streams():
inferrer = SchemaInferrer()
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW))
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream2", data={"field_A": "abc"}, emitted_at=NOW))
inferred_schemas = inferrer.get_inferred_schemas()
assert inferred_schemas["my_stream"] == {
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": {"field_A": {"type": "number"}},
}
assert inferred_schemas["my_stream2"] == {
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": {"field_A": {"type": "string"}},
}
_STREAM_NAME = "a stream name"
_ANY_VALUE = "any value"
_IS_PK = True
_IS_CURSOR_FIELD = True
def test_get_individual_schema():
inferrer = SchemaInferrer()
inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW))
assert inferrer.get_stream_schema("my_stream") == {
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": {"field_A": {"type": "number"}},
}
assert inferrer.get_stream_schema("another_stream") is None
def _create_inferrer_with_required_field(is_pk: bool, field: List[List[str]]) -> SchemaInferrer:
if is_pk:
return SchemaInferrer(field)
return SchemaInferrer([[]], field)
@pytest.mark.parametrize(
"is_pk",
[
pytest.param(_IS_PK, id="required_field_is_pk"),
pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"),
]
)
def test_field_is_on_root(is_pk: bool):
inferrer = _create_inferrer_with_required_field(is_pk, [["property"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": _ANY_VALUE}, emitted_at=NOW))
assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"]
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "string"
@pytest.mark.parametrize(
"is_pk",
[
pytest.param(_IS_PK, id="required_field_is_pk"),
pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"),
]
)
def test_field_is_nested(is_pk: bool):
inferrer = _create_inferrer_with_required_field(is_pk, [["property", "nested_property"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": {"nested_property": _ANY_VALUE}}, emitted_at=NOW))
assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"]
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "object"
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["required"] == ["nested_property"]
@pytest.mark.parametrize(
"is_pk",
[
pytest.param(_IS_PK, id="required_field_is_pk"),
pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"),
]
)
def test_field_is_composite(is_pk: bool):
inferrer = _create_inferrer_with_required_field(is_pk, [["property 1"], ["property 2"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": _ANY_VALUE, "property 2": _ANY_VALUE}, emitted_at=NOW))
assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"]
@pytest.mark.parametrize(
"is_pk",
[
pytest.param(_IS_PK, id="required_field_is_pk"),
pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"),
]
)
def test_field_is_composite_and_nested(is_pk: bool):
inferrer = _create_inferrer_with_required_field(is_pk, [["property 1", "nested"], ["property 2"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": {"nested": _ANY_VALUE}, "property 2": _ANY_VALUE}, emitted_at=NOW))
assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"]
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["type"] == "object"
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 2"]["type"] == "string"
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["required"] == ["nested"]
assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["properties"]["nested"]["type"] == "string"
def test_given_pk_does_not_exist_when_get_inferred_schemas_then_raise_error():
inferrer = SchemaInferrer([["pk does not exist"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW))
with pytest.raises(SchemaValidationException) as exception:
inferrer.get_stream_schema(_STREAM_NAME)
assert len(exception.value.validation_errors) == 1
def test_given_pk_path_is_partially_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is():
inferrer = SchemaInferrer([["id", "nested pk that does not exist"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW))
with pytest.raises(SchemaValidationException) as exception:
inferrer.get_stream_schema(_STREAM_NAME)
assert len(exception.value.validation_errors) == 1
assert "Path ['id']" in exception.value.validation_errors[0]
def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_valid_path_is_required():
inferrer = SchemaInferrer([["id 1"], ["id 2"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW))
with pytest.raises(SchemaValidationException) as exception:
inferrer.get_stream_schema(_STREAM_NAME)
assert exception.value.schema["required"] == ["id 1"]
def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is():
inferrer = SchemaInferrer([["id 1"], ["id 2"]])
inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW))
with pytest.raises(SchemaValidationException) as exception:
inferrer.get_stream_schema(_STREAM_NAME)
assert len(exception.value.validation_errors) == 1
assert "id 2" in exception.value.validation_errors[0]

View File

@@ -1597,6 +1597,28 @@ VALID_CATALOG_TRANSITIONS = [
)
},
),
Transition(
name="Given the same types, the order does not matter",
should_fail=False,
previous={
"test_stream": AirbyteStream.parse_obj(
{
"name": "test_stream",
"json_schema": {"properties": {"user": {"type": "object", "properties": {"username": {"type": ["null", "string"]}}}}},
"supported_sync_modes": ["full_refresh"],
}
)
},
current={
"test_stream": AirbyteStream.parse_obj(
{
"name": "test_stream",
"json_schema": {"properties": {"user": {"type": "object", "properties": {"username": {"type": ["string", "null"]}}}}},
"supported_sync_modes": ["full_refresh"],
}
)
},
),
Transition(
name="Changing 'type' field to list should not fail.",
should_fail=False,