#
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
#
from tempfile import NamedTemporaryFile
from unittest.mock import patch, sentinel

import pandas as pd
import pytest
from pandas import read_csv, read_excel, testing
from paramiko import SSHException
from source_file.client import Client, URLFile
from source_file.utils import backoff_handler
from urllib3.exceptions import ProtocolError

from airbyte_cdk.utils import AirbyteTracedException

# NOTE: the `client`, `config`, `test_read_config`, `absolute_path`, and
# `test_files` fixtures used below are presumably provided by the suite's
# conftest.py; they are not defined in this module.


@pytest.fixture
def wrong_format_client():
    return Client(
        dataset_name="test_dataset",
        url="scp://test_dataset",
        provider={"provider": {"storage": "HTTPS", "reader_impl": "gcsfs", "user_agent": False}},
        format="wrong",
    )


@pytest.fixture
def csv_format_client():
    return Client(
        dataset_name="test_dataset",
        url="scp://test_dataset",
        provider={"provider": {"storage": "HTTPS", "reader_impl": "gcsfs", "user_agent": False}},
        format="csv",
    )


@pytest.mark.parametrize(
    "storage, expected_scheme, url",
    [
        ("GCS", "gs://", "http://localhost"),
        ("S3", "s3://", "http://localhost"),
        ("AZBLOB", "azure://", "http://localhost"),
        ("HTTPS", "https://", "http://localhost"),
        ("SSH", "scp://", "http://localhost"),
        ("SCP", "scp://", "http://localhost"),
        ("SFTP", "sftp://", "http://localhost"),
        ("WEBHDFS", "webhdfs://", "http://localhost"),
        ("LOCAL", "file://", "http://localhost"),
        ("WRONG", "", ""),
    ],
)
def test_storage_scheme(storage, expected_scheme, url):
    urlfile = URLFile(provider={"storage": storage}, url=url)
    assert urlfile.storage_scheme == expected_scheme
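

# Hedged illustration (an addition, not part of the original suite): the "WRONG"
# case above suggests any unrecognized storage value falls back to an empty
# scheme; this restates that contract for a second arbitrary value, assuming the
# same default branch handles all unknown storages.
def test_storage_scheme_unknown_storage():
    urlfile = URLFile(provider={"storage": "SOME_FUTURE_STORAGE"}, url="http://localhost")
    assert urlfile.storage_scheme == ""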


def test_load_dataframes(client, wrong_format_client, absolute_path, test_files):
    f = f"{absolute_path}/{test_files}/test.csv"
    read_file = next(client.load_dataframes(fp=f))
    expected = read_csv(f)
    assert read_file.equals(expected)

    with pytest.raises(AirbyteTracedException):
        next(wrong_format_client.load_dataframes(fp=f))

    with pytest.raises(StopIteration):
        next(client.load_dataframes(fp=f, skip_data=True))


def test_raises_configuration_error_with_incorrect_file_type(csv_format_client, absolute_path, test_files):
    f = f"{absolute_path}/{test_files}/archive_with_test_xlsx.zip"
    with pytest.raises(AirbyteTracedException):
        next(csv_format_client.load_dataframes(fp=f))


def test_load_dataframes_xlsb(config, absolute_path, test_files):
    config["format"] = "excel_binary"
    client = Client(**config)
    f = f"{absolute_path}/{test_files}/test.xlsb"
    read_file = next(client.load_dataframes(fp=f))
    expected = read_excel(f, engine="pyxlsb")
    assert read_file.equals(expected)
@pytest.mark.parametrize("file_name, should_raise_error", [("test.xlsx", False), ("test_one_line.xlsx", True)])
def test_load_dataframes_xlsx(config, absolute_path, test_files, file_name, should_raise_error):
config["format"] = "excel"
client = Client(**config)
f = f"{absolute_path}/{test_files}/{file_name}"
if should_raise_error:
with pytest.raises(AirbyteTracedException):
next(client.load_dataframes(fp=f))
else:
read_file = next(client.load_dataframes(fp=f))
expected = read_excel(f, engine="openpyxl")
assert read_file.equals(expected)
@pytest.mark.parametrize("file_format, file_path", [("json", "formats/json/demo.json"), ("jsonl", "formats/jsonl/jsonl_nested.jsonl")])
def test_load_nested_json(client, config, absolute_path, test_files, file_format, file_path):
if file_format == "jsonl":
config["format"] = file_format
client = Client(**config)
f = f"{absolute_path}/{test_files}/{file_path}"
with open(f, mode="rb") as file:
assert client.load_nested_json(fp=file)


@pytest.mark.parametrize(
    "current_type, dtype, expected",
    [
        ("string", "string", "string"),
        ("", object, "string"),
        ("", "int64", "number"),
        ("", "float64", "number"),
        ("boolean", "bool", "boolean"),
        ("number", "int64", "number"),
        ("number", "float64", "number"),
        ("number", "datetime64[ns]", "date-time"),
    ],
)
def test_dtype_to_json_type(client, current_type, dtype, expected):
    assert client.dtype_to_json_type(current_type, dtype) == expected
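

# Illustrative companion to the table above (an addition, not part of the original
# suite): feed dtype_to_json_type the dtypes pandas actually infers for a small
# DataFrame, assuming the mapping exercised by the parametrized cases also holds
# for real dtype objects.
def test_dtype_to_json_type_from_dataframe(client):
    df = pd.DataFrame({"a": [1, 2], "b": [1.5, 2.5], "c": ["x", "y"]})
    expected = {"a": "number", "b": "number", "c": "string"}
    for column, dtype in df.dtypes.items():
        # pandas dtype objects compare equal to their string names ("int64", etc.)
        assert client.dtype_to_json_type("", dtype) == expected[column]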


def test_cache_stream(client, absolute_path, test_files):
    f = f"{absolute_path}/{test_files}/test.csv"
    with open(f, mode="rb") as file:
        assert client._cache_stream(file)


def test_unzip_stream(client, absolute_path, test_files):
    f = f"{absolute_path}/{test_files}/test.csv.zip"
    with open(f, mode="rb") as file:
        assert client._unzip(file)


@pytest.mark.parametrize("file_name", ["test.csv.zip", "test.csv.Zip", "test.csv.ZIP"])
def test_unzip_file_extension_case_insensitive(absolute_path, test_files, file_name):
    """Zip detection should not depend on the capitalization of the extension."""
    config = {
        "dataset_name": "test_dataset",
        "format": "csv",
        "url": f"{absolute_path}/{test_files}/{file_name}",
        "provider": {"storage": "local"},
        "reader_options": {"encoding": "utf-8"},
    }
    client = Client(**config)
    with patch.object(client, "_unzip", wraps=client._unzip) as monkey:
        next(client.read())
    monkey.assert_called()


def test_open_aws_url():
    url = "s3://my_bucket/my_key"
    provider = {"storage": "S3"}
    with pytest.raises(OSError):
        URLFile(url=url, provider=provider)._open_aws_url()

    provider.update({"aws_access_key_id": "aws_access_key_id", "aws_secret_access_key": "aws_secret_access_key"})
    with pytest.raises(OSError):
        URLFile(url=url, provider=provider)._open_aws_url()


def test_open_azblob_url():
    # No credentials at all is invalid.
    provider = {"storage": "AZBLOB"}
    with pytest.raises(ValueError):
        URLFile(url="", provider=provider)._open_azblob_url()

    # Presumably also invalid: supplying both sas_token and shared_key at once.
    provider.update({"storage_account": "storage_account", "sas_token": "sas_token", "shared_key": "shared_key"})
    with pytest.raises(ValueError):
        URLFile(url="", provider=provider)._open_azblob_url()


def test_open_gcs_url():
    provider = {"storage": "GCS"}
    with pytest.raises(IndexError):
        URLFile(url="", provider=provider)._open_gcs_url()

    provider.update({"service_account_json": '{"service_account_json": "service_account_json"}'})
    with pytest.raises(ValueError):
        URLFile(url="", provider=provider)._open_gcs_url()

    # Deliberately malformed JSON (missing opening quote) should surface as a traced exception.
    provider.update({"service_account_json": '{service_account_json": "service_account_json"}'})
    with pytest.raises(AirbyteTracedException):
        URLFile(url="", provider=provider)._open_gcs_url()


def test_read(test_read_config):
    client = Client(**test_read_config)
    client.sleep_on_retry_sec = 0  # just for test
    with patch.object(client, "load_dataframes", side_effect=ConnectionResetError) as mock_method:
        # client.read() is a generator, so it must be consumed for load_dataframes to
        # run at all; the original `return client.read(...)` skipped the assertion below.
        with pytest.raises(ConnectionResetError):
            next(client.read(["date", "key"]))
        mock_method.assert_called()


def test_read_network_issues(test_read_config):
    test_read_config.update(format="excel")
    client = Client(**test_read_config)
    client.sleep_on_retry_sec = 0  # just for test
    with patch.object(client, "_cache_stream", side_effect=ProtocolError), pytest.raises(AirbyteTracedException):
        next(client.read(["date", "key"]))


def test_urlfile_open_backoff_sftp(monkeypatch, mocker):
    call_count = 0
    result = sentinel.result

    def patched_open(self):
        nonlocal call_count
        call_count += 1
        if call_count < 7:
            raise SSHException("Error reading SSH protocol banner[Errno 104] Connection reset by peer")
        return result

    sleep_mock = mocker.patch("time.sleep")
    monkeypatch.setattr(URLFile, "_open", patched_open)
    provider = {"storage": "SFTP", "user": "user", "password": "password", "host": "sftp.domain.com", "port": 22}
    reader = URLFile(url="/DISTDA.CSV", provider=provider, binary=False)

    # First open(): all five tries raise, so the backoff gives up and re-raises.
    with pytest.raises(SSHException):
        reader.open()
    assert reader._file is None
    assert call_count == 5

    # Second open(): the sixth underlying call still raises, the seventh succeeds.
    reader.open()
    assert reader._file is result
    assert call_count == 7
    assert sleep_mock.call_count == 5


def test_backoff_handler(caplog):
    details = {"tries": 1, "wait": 1}
    backoff_handler(details)
    expected = [("airbyte", 20, "Caught retryable error after 1 tries. Waiting 1 seconds then retrying...")]  # 20 == logging.INFO
    assert caplog.record_tuples == expected
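

# Hedged follow-up (an addition, not part of the original suite): assuming
# backoff_handler simply interpolates the `tries` and `wait` values from the
# details dict into the same message, other values should format accordingly.
def test_backoff_handler_other_details(caplog):
    backoff_handler({"tries": 3, "wait": 5})
    expected = [("airbyte", 20, "Caught retryable error after 3 tries. Waiting 5 seconds then retrying...")]
    assert caplog.record_tuples == expected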


def generate_excel_file(data):
    """
    Helper function to generate an Excel file with the given data.

    The file is created with delete=False so it survives the `with` blocks in the
    tests below; pandas writes to it by name, so the handle is only used as a
    context manager.
    """
    tmp_file = NamedTemporaryFile(suffix=".xlsx", delete=False)
    df = pd.DataFrame(data)
    df.to_excel(tmp_file.name, index=False, header=False)
    tmp_file.seek(0)
    return tmp_file
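

# Illustrative sanity check for the helper (an addition, not part of the original
# suite): data written by generate_excel_file should round-trip through
# pandas.read_excel when read back with header=None, mirroring the header=False
# used when writing.
def test_generate_excel_file_roundtrip():
    data = [["a", "b"], ["c", "d"]]
    with generate_excel_file(data) as tmp:
        frame = read_excel(tmp.name, header=None, engine="openpyxl")
    assert frame.values.tolist() == data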


def test_excel_reader_option_names(config):
    """
    Test the 'names' option for the Excel reader.
    """
    config["format"] = "excel"
    config["reader_options"] = {"names": ["CustomCol1", "CustomCol2"]}
    client = Client(**config)
    data = [["Value1", "Value2"], ["Value3", "Value4"]]
    expected_data = [
        {"CustomCol1": "Value1", "CustomCol2": "Value2"},
        {"CustomCol1": "Value3", "CustomCol2": "Value4"},
    ]
    with generate_excel_file(data) as tmp:
        read_file = next(client.load_dataframes(fp=tmp.name))
    assert isinstance(read_file, pd.DataFrame)
    assert read_file.to_dict(orient="records") == expected_data


def test_excel_reader_option_skiprows(config):
    """
    Test the 'skiprows' option for the Excel reader.
    """
    config["format"] = "excel"
    config["reader_options"] = {"skiprows": 1}
    client = Client(**config)
    data = [["SkipRow1", "SkipRow2"], ["Header1", "Header2"], ["Value1", "Value2"]]
    expected_data = [
        {"Header1": "Value1", "Header2": "Value2"},
    ]
    with generate_excel_file(data) as tmp:
        read_file = next(client.load_dataframes(fp=tmp.name))
    assert isinstance(read_file, pd.DataFrame)
    assert read_file.to_dict(orient="records") == expected_data


def test_excel_reader_option_header(config):
    """
    Test the 'header' option for the Excel reader.
    """
    config["format"] = "excel"
    config["reader_options"] = {"header": 1}
    client = Client(**config)
    data = [["UnusedRow1", "UnusedRow2"], ["Header1", "Header2"], ["Value1", "Value2"]]
    expected_data = [
        {"Header1": "Value1", "Header2": "Value2"},
    ]
    with generate_excel_file(data) as tmp:
        read_file = next(client.load_dataframes(fp=tmp.name))
    assert isinstance(read_file, pd.DataFrame)
    assert read_file.to_dict(orient="records") == expected_data
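

# Hedged sketch extending the pattern above (an addition, not part of the original
# suite): assuming reader_options are passed straight through to pandas.read_excel,
# as the names/skiprows/header tests suggest, 'usecols' should select a subset of
# columns the same way.
def test_excel_reader_option_usecols(config):
    config["format"] = "excel"
    config["reader_options"] = {"usecols": [0]}
    client = Client(**config)
    data = [["Header1", "Header2"], ["Value1", "Value2"]]
    expected_data = [
        {"Header1": "Value1"},
    ]
    with generate_excel_file(data) as tmp:
        read_file = next(client.load_dataframes(fp=tmp.name))
    assert isinstance(read_file, pd.DataFrame)
    assert read_file.to_dict(orient="records") == expected_data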