1
0
mirror of synced 2026-01-07 18:06:03 -05:00

[26989] Add request filter for cloud and integration test fixtures for e2e sync testing (#27534)

* add the request filters and integration test fixtures

* pr feedback and some tweaks to the testing framework

* optimize the cache for more hits

* formatting

* remove cache
This commit is contained in:
Brian Lai
2023-06-22 12:14:07 -04:00
committed by GitHub
parent d96f6cbcdb
commit 02e4bd07f7
6 changed files with 503 additions and 3 deletions

View File

@@ -2,14 +2,17 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import argparse
import importlib
import ipaddress
import logging
import os.path
import socket
import sys
import tempfile
from functools import wraps
from typing import Any, Iterable, List, Mapping
from urllib.parse import urlparse
from airbyte_cdk.connector import TConfig
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
@@ -21,13 +24,24 @@ from airbyte_cdk.sources.source import TCatalog, TState
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from requests import Session
# Module-level logger; it is also the logger handed to init_uncaught_exception_handler when an AirbyteEntrypoint is created.
logger = init_logger("airbyte")
# URL schemes accepted by the cloud request filter; a request whose scheme is not listed here is rejected with a ValueError.
VALID_URL_SCHEMES = ["https"]
# Value of the DEPLOYMENT_MODE environment variable (compared case-insensitively) that enables the internal request filter.
CLOUD_DEPLOYMENT_MODE = "cloud"
class AirbyteEntrypoint(object):
def __init__(self, source: Source):
    """
    Registers the uncaught exception handler, enables the internal request filter when running in cloud, and stores
    the source together with a logger named after it.

    :param source: the source this entrypoint drives; its optional ``name`` attribute is used for the logger name
    """
    init_uncaught_exception_handler(logger)

    # The deployment mode is inspected at construction time because instantiating the entrypoint is the code path
    # shared by syncs and connector builder test requests.
    running_in_cloud = os.environ.get("DEPLOYMENT_MODE", "").casefold() == CLOUD_DEPLOYMENT_MODE
    if running_in_cloud:
        _init_internal_request_filter()

    # Sources without a ``name`` attribute end up with the logger "airbyte." (empty suffix).
    source_name = getattr(source, "name", "")
    self.source = source
    self.logger = logging.getLogger(f"airbyte.{source_name}")
@@ -172,6 +186,57 @@ def launch(source: Source, args: List[str]):
print(message)
def _init_internal_request_filter():
    """
    Wraps the Python requests library to prevent sending requests to internal URL endpoints.

    Session.send is replaced by a wrapper that validates the URL of every outgoing request before delegating to the
    original implementation. The wrapper raises a ValueError when:
      * the URL scheme is not listed in VALID_URL_SCHEMES,
      * the URL has no hostname,
      * the hostname cannot be resolved (socket.gaierror), or
      * any address the hostname resolves to belongs to a private network.

    The function is idempotent: once the filter is installed, further calls return immediately. Without this guard
    every call would wrap the previous wrapper, so each request would repeat the validation (including the DNS lookup)
    once per installed layer.
    """
    # Compare against True explicitly so that a mocked Session.send (whose attributes are all truthy) is not mistaken
    # for an already installed filter.
    if getattr(Session.send, "_airbyte_internal_request_filter", None) is True:
        return

    wrapped_fn = Session.send

    @wraps(wrapped_fn)
    def filtered_send(self, request, **kwargs):
        parsed_url = urlparse(request.url)

        if parsed_url.scheme not in VALID_URL_SCHEMES:
            raise ValueError(
                "Invalid Protocol Scheme: The endpoint that data is being requested from is using an invalid or insecure "
                + f"protocol {parsed_url.scheme}. Valid protocol schemes: {','.join(VALID_URL_SCHEMES)}"
            )

        if not parsed_url.hostname:
            raise ValueError("Invalid URL specified: The endpoint that data is being requested from is not a valid URL")

        # Only the lookup can raise socket.gaierror, so it is the only statement inside the try block.
        try:
            is_private = _is_private_url(parsed_url.hostname, parsed_url.port)
        except socket.gaierror as error:
            # The address lookup failed, for example because the developer specified an IP address string that is not
            # formatted correctly (such as one with trailing whitespace). The original error is kept as the cause.
            raise ValueError(f"Invalid hostname or IP address '{parsed_url.hostname}' specified.") from error

        if is_private:
            raise ValueError(
                "Invalid URL endpoint: The endpoint that data is being requested from belongs to a private network. Source "
                + "connectors only support requesting data from public API endpoints."
            )

        # The lookup above is for validation only; the original request object is forwarded unchanged.
        return wrapped_fn(self, request, **kwargs)

    # Marker consulted by the guard at the top of this function on subsequent calls.
    filtered_send._airbyte_internal_request_filter = True
    Session.send = filtered_send
def _is_private_url(hostname: str, port: int) -> bool:
"""
Helper method that checks if any of the IP addresses associated with a hostname belong to a private network.
"""
address_info_entries = socket.getaddrinfo(hostname, port)
for entry in address_info_entries:
# getaddrinfo() returns entries in the form of a 5-tuple where the IP is stored as the sockaddr. For IPv4 this
# is a 2-tuple and for IPv6 it is a 4-tuple, but the address is always the first value of the tuple at 0.
# See https://docs.python.org/3/library/socket.html#socket.getaddrinfo for more details.
ip_address = entry[4][0]
if ipaddress.ip_address(ip_address).is_private:
return True
return False
def main():
impl_module = os.environ.get("AIRBYTE_IMPL_MODULE", Source.__module__)
impl_class = os.environ.get("AIRBYTE_IMPL_PATH", Source.__name__)