80 lines
3.0 KiB
Python
80 lines
3.0 KiB
Python
#
|
|
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
#
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import multiprocessing as mp
|
|
import traceback
|
|
from multiprocessing import Queue
|
|
from typing import Any, Callable, List, Mapping, cast
|
|
|
|
import dill
|
|
import orjson
|
|
|
|
from airbyte_cdk.models import AirbyteMessage, AirbyteMessageSerializer
|
|
|
|
|
|
def run_in_external_process(fn: Callable, timeout: int, max_timeout: int, logger: logging.Logger, args: List[Any]) -> Mapping[str, Any]:
    """
    Run ``fn(*args)`` in a separate process, doubling the wait after every timeout until ``max_timeout`` is reached.

    fn passed in must return a tuple of (desired return value, Exception OR None)
    This allows propagating any errors from the process up and raising accordingly

    :param fn: callable to execute; it is serialized with dill before being handed to the child process
    :param timeout: seconds to wait for a result on the first attempt
    :param max_timeout: upper bound, in seconds, for the wait of any single attempt
    :param logger: logger used to report retries and cleanup problems
    :param args: positional arguments forwarded to ``fn``
    :return: the first element of the tuple produced by ``fn``
    :raises TimeoutError: when an attempt times out and ``timeout`` has already reached ``max_timeout``
    :raises Exception: the exception ``fn`` reported as the second element of its tuple
    """
    # A timed-out Queue.get() raises the standard library's queue.Empty; import it from its
    # documented home rather than reaching through the multiprocessing.queues submodule.
    from queue import Empty

    # Upper bound on how long cleanup waits for a terminated child to exit.
    reap_timeout_seconds = 1

    result = None
    while result is None:
        q_worker: Queue = mp.Queue()
        proc = mp.Process(
            target=multiprocess_queuer,
            # use dill to pickle the function for Windows-compatibility
            args=(dill.dumps(fn), q_worker, *args),
        )
        proc.start()
        try:
            # this attempts to get return value from function with our specified timeout up to max
            result, potential_error = q_worker.get(timeout=min(timeout, max_timeout))
        except Empty:
            if timeout >= max_timeout:  # if we've got to max_timeout and tried once with that value
                raise TimeoutError(f"Timed out too many times while running {fn.__name__}, max timeout of {max_timeout} seconds reached.")
            logger.info(f"timed out while running {fn.__name__} after {timeout} seconds, retrying...")
            timeout *= 2  # double timeout and try again
        else:
            if potential_error is None:
                return result  # type: ignore[no-any-return]
            traceback.print_exception(type(potential_error), potential_error, potential_error.__traceback__)
            raise potential_error
        finally:
            try:
                proc.terminate()
                # Reap the child so it does not linger as a zombie; the wait is bounded so
                # cleanup can never hang the caller.
                proc.join(timeout=reap_timeout_seconds)
            except Exception as e:
                logger.info(f"'{fn.__name__}' proc unterminated, error: {e}")
            # Each attempt creates its own queue; release its pipe handles right away
            # instead of waiting for garbage collection.
            q_worker.close()
|
|
|
|
|
|
def multiprocess_queuer(func: Callable, queue: mp.Queue, *args: Any, **kwargs: Any) -> None:
    """Child-process entry point; lives at top-level to be Windows-compatible.

    ``func`` is deserialized with dill, called with ``args`` and ``kwargs``, and
    whatever it returns is put on ``queue``.
    """
    target = dill.loads(func)
    outcome = target(*args, **kwargs)
    queue.put(outcome)
|
|
|
|
|
|
def get_value_or_json_if_empty_string(options: str | None = None) -> str:
    """Return ``options`` without surrounding whitespace, or ``"{}"`` when nothing is left.

    ``None``, the empty string and whitespace-only strings all yield ``"{}"``, so the
    result is never an empty string.

    :param options: raw option text, possibly missing or blank
    :return: the stripped text, or ``"{}"`` as the fallback
    """
    stripped = options.strip() if options else ""
    # Fall back after stripping so that blank input such as "   " is treated as empty too.
    return stripped or "{}"
|
|
|
|
|
|
def airbyte_message_to_json(
    message: AirbyteMessage,
    *,
    newline: bool = False,
) -> str:
    """Serialize the given AirbyteMessage into a JSON string.

    When ``newline`` is true, a single newline character is appended to the result.
    """
    payload = cast(dict, AirbyteMessageSerializer.dump(message))
    encoded = orjson.dumps(payload).decode()
    return encoded + "\n" if newline else encoded
|
|
|
|
|
|
def airbyte_message_from_json(message_json: str) -> AirbyteMessage:
    """Parse the provided JSON string and load it as an AirbyteMessage object."""
    parsed = orjson.loads(message_json)
    return AirbyteMessageSerializer.load(parsed)
|