Async Snowflake Destination (#26703)
* snowflake at end of coding retreat week * turn off async snowflake * add comment * Fixed test to point to the ShimMessageConsumer instead of AsyncMessageConsumer * Automated Change --------- Co-authored-by: ryankfu <ryan.fu@airbyte.io> Co-authored-by: ryankfu <ryankfu@users.noreply.github.com>
This commit is contained in:
@@ -4,15 +4,19 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.integrations.base.AirbyteMessageConsumer;
|
||||
import io.airbyte.integrations.base.SerializedAirbyteMessageConsumer;
|
||||
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
|
||||
import io.airbyte.integrations.destination_async.buffers.BufferEnqueue;
|
||||
import io.airbyte.integrations.destination_async.buffers.BufferManager;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.FlushFailure;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
|
||||
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.Optional;
|
||||
@@ -31,12 +35,10 @@ import org.slf4j.LoggerFactory;
|
||||
* {@link FlushWorkers}. See the other linked class for more detail.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AsyncStreamConsumer implements AirbyteMessageConsumer {
|
||||
public class AsyncStreamConsumer implements SerializedAirbyteMessageConsumer {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AsyncStreamConsumer.class);
|
||||
|
||||
private static final String NON_STREAM_STATE_IDENTIFIER = "GLOBAL";
|
||||
private final Consumer<AirbyteMessage> outputRecordCollector;
|
||||
private final OnStartFunction onStart;
|
||||
private final OnCloseFunction onClose;
|
||||
private final ConfiguredAirbyteCatalog catalog;
|
||||
@@ -44,6 +46,7 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
|
||||
private final BufferEnqueue bufferEnqueue;
|
||||
private final FlushWorkers flushWorkers;
|
||||
private final Set<StreamDescriptor> streamNames;
|
||||
private final FlushFailure flushFailure;
|
||||
|
||||
private boolean hasStarted;
|
||||
private boolean hasClosed;
|
||||
@@ -54,16 +57,27 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
|
||||
final DestinationFlushFunction flusher,
|
||||
final ConfiguredAirbyteCatalog catalog,
|
||||
final BufferManager bufferManager) {
|
||||
this(outputRecordCollector, onStart, onClose, flusher, catalog, bufferManager, new FlushFailure());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public AsyncStreamConsumer(final Consumer<AirbyteMessage> outputRecordCollector,
|
||||
final OnStartFunction onStart,
|
||||
final OnCloseFunction onClose,
|
||||
final DestinationFlushFunction flusher,
|
||||
final ConfiguredAirbyteCatalog catalog,
|
||||
final BufferManager bufferManager,
|
||||
final FlushFailure flushFailure) {
|
||||
hasStarted = false;
|
||||
hasClosed = false;
|
||||
|
||||
this.outputRecordCollector = outputRecordCollector;
|
||||
this.onStart = onStart;
|
||||
this.onClose = onClose;
|
||||
this.catalog = catalog;
|
||||
this.bufferManager = bufferManager;
|
||||
bufferEnqueue = bufferManager.getBufferEnqueue();
|
||||
flushWorkers = new FlushWorkers(this.bufferManager.getBufferDequeue(), flusher);
|
||||
this.flushFailure = flushFailure;
|
||||
flushWorkers = new FlushWorkers(bufferManager.getBufferDequeue(), flusher, outputRecordCollector, flushFailure, bufferManager.getStateManager());
|
||||
streamNames = StreamDescriptorUtils.fromConfiguredCatalog(catalog);
|
||||
}
|
||||
|
||||
@@ -79,15 +93,77 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void accept(final AirbyteMessage message) throws Exception {
|
||||
public void accept(final String messageString, final Integer sizeInBytes) throws Exception {
|
||||
Preconditions.checkState(hasStarted, "Cannot accept records until consumer has started");
|
||||
propagateFlushWorkerExceptionIfPresent();
|
||||
/*
|
||||
* intentionally putting extractStream outside the buffer manager so that if in the future we want
|
||||
* to try to use a threadpool to partial deserialize to get record type and stream name, we can do
|
||||
* it without touching buffer manager.
|
||||
* to try to use a thread pool to partially deserialize to get record type and stream name, we can
|
||||
* do it without touching buffer manager.
|
||||
*/
|
||||
extractStream(message)
|
||||
.ifPresent(streamDescriptor -> bufferEnqueue.addRecord(streamDescriptor, message));
|
||||
deserializeAirbyteMessage(messageString)
|
||||
.ifPresent(message -> {
|
||||
if (message.getType() == Type.RECORD) {
|
||||
validateRecord(message);
|
||||
}
|
||||
|
||||
bufferEnqueue.addRecord(message, sizeInBytes);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Deserializes to a {@link PartialAirbyteMessage} which can represent both a Record or a State
|
||||
* Message
|
||||
*
|
||||
* @param messageString the string to deserialize
|
||||
* @return PartialAirbyteMessage if the message is valid, empty otherwise
|
||||
*/
|
||||
private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
|
||||
final Optional<PartialAirbyteMessage> messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
|
||||
.map(partial -> partial.withSerialized(messageString));
|
||||
if (messageOptional.isPresent()) {
|
||||
return messageOptional;
|
||||
} else {
|
||||
if (isStateMessage(messageString)) {
|
||||
throw new IllegalStateException("Invalid state message: " + messageString);
|
||||
} else {
|
||||
LOGGER.error("Received invalid message: " + messageString);
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests whether the provided JSON string represents a state message.
|
||||
*
|
||||
* @param input a JSON string that represents an {@link AirbyteMessage}.
|
||||
* @return {@code true} if the message is a state message, {@code false} otherwise.
|
||||
*/
|
||||
private static boolean isStateMessage(final String input) {
|
||||
final Optional<AirbyteTypeMessage> deserialized = Jsons.tryDeserialize(input, AirbyteTypeMessage.class);
|
||||
return deserialized.filter(airbyteTypeMessage -> airbyteTypeMessage.getType() == Type.STATE).isPresent();
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom class that can be used to parse a JSON message to determine the type of the represented
|
||||
* {@link AirbyteMessage}.
|
||||
*/
|
||||
private static class AirbyteTypeMessage {
|
||||
|
||||
@JsonProperty("type")
|
||||
@JsonPropertyDescription("Message type")
|
||||
private AirbyteMessage.Type type;
|
||||
|
||||
@JsonProperty("type")
|
||||
public AirbyteMessage.Type getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
@JsonProperty("type")
|
||||
public void setType(final AirbyteMessage.Type type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -100,48 +176,32 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
|
||||
// we need to close the workers before closing the bufferManagers (and underlying buffers)
|
||||
// or we risk losing in-memory data.
|
||||
flushWorkers.close();
|
||||
|
||||
bufferManager.close();
|
||||
onClose.call();
|
||||
LOGGER.info("{} closed.", AsyncStreamConsumer.class);
|
||||
|
||||
// as this throws an exception, we need to be after all other close functions.
|
||||
propagateFlushWorkerExceptionIfPresent();
|
||||
LOGGER.info("{} closed", AsyncStreamConsumer.class);
|
||||
}
|
||||
|
||||
// todo (cgardens) - handle global state.
|
||||
/**
|
||||
* Extract the stream from the message, if the message is a record or state. Otherwise, we don't
|
||||
* care.
|
||||
*
|
||||
* @param message message to extract stream from
|
||||
* @return stream descriptor if the message is a record or state, otherwise empty. In the case of
|
||||
* global state messages the stream descriptor is hardcoded
|
||||
*/
|
||||
private Optional<StreamDescriptor> extractStream(final AirbyteMessage message) {
|
||||
if (message.getType() == Type.RECORD) {
|
||||
final StreamDescriptor streamDescriptor = new StreamDescriptor()
|
||||
.withNamespace(message.getRecord().getNamespace())
|
||||
.withName(message.getRecord().getStream());
|
||||
|
||||
validateRecord(message, streamDescriptor);
|
||||
|
||||
return Optional.of(streamDescriptor);
|
||||
} else if (message.getType() == Type.STATE) {
|
||||
if (message.getState().getType() == AirbyteStateType.STREAM) {
|
||||
return Optional.of(message.getState().getStream().getStreamDescriptor());
|
||||
} else {
|
||||
return Optional.of(new StreamDescriptor().withNamespace(NON_STREAM_STATE_IDENTIFIER).withNamespace(NON_STREAM_STATE_IDENTIFIER));
|
||||
}
|
||||
} else {
|
||||
return Optional.empty();
|
||||
private void propagateFlushWorkerExceptionIfPresent() throws Exception {
|
||||
if (flushFailure.isFailed()) {
|
||||
throw flushFailure.getException();
|
||||
}
|
||||
}
|
||||
|
||||
private void validateRecord(final AirbyteMessage message, final StreamDescriptor streamDescriptor) {
|
||||
private void validateRecord(final PartialAirbyteMessage message) {
|
||||
final StreamDescriptor streamDescriptor = new StreamDescriptor()
|
||||
.withNamespace(message.getRecord().getNamespace())
|
||||
.withName(message.getRecord().getStream());
|
||||
// if stream is not part of list of streams to sync to then throw invalid stream exception
|
||||
if (!streamNames.contains(streamDescriptor)) {
|
||||
throwUnrecognizedStream(catalog, message);
|
||||
}
|
||||
}
|
||||
|
||||
private static void throwUnrecognizedStream(final ConfiguredAirbyteCatalog catalog, final AirbyteMessage message) {
|
||||
private static void throwUnrecognizedStream(final ConfiguredAirbyteCatalog catalog, final PartialAirbyteMessage message) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format("Message contained record from a stream that was not in the catalog. \ncatalog: %s , \nmessage: %s",
|
||||
Jsons.serialize(catalog), Jsons.serialize(message)));
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async;
|
||||
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.stream.Stream;
|
||||
@@ -35,7 +36,7 @@ public interface DestinationFlushFunction {
|
||||
* {@link #getOptimalBatchSizeBytes()} size
|
||||
* @throws Exception
|
||||
*/
|
||||
void flush(StreamDescriptor decs, Stream<AirbyteMessage> stream) throws Exception;
|
||||
void flush(StreamDescriptor decs, Stream<PartialAirbyteMessage> stream) throws Exception;
|
||||
|
||||
/**
|
||||
* When invoking {@link #flush(StreamDescriptor, Stream)}, best effort attempt to invoke flush with
|
||||
|
||||
@@ -4,21 +4,28 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async;
|
||||
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.integrations.destination_async.buffers.BufferDequeue;
|
||||
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.FlushFailure;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
||||
/**
|
||||
* Parallel flushing of Destination data.
|
||||
@@ -27,7 +34,7 @@ import org.apache.commons.io.FileUtils;
|
||||
* allows for parallel data flushing.
|
||||
* <p>
|
||||
* Parallelising is important as it 1) minimises Destination backpressure 2) minimises the effect of
|
||||
* IO pauses on Destination performance. The second point is particularly important since majority
|
||||
* IO pauses on Destination performance. The second point is particularly important since a majority
|
||||
* of Destination work is IO bound.
|
||||
* <p>
|
||||
* The {@link #supervisorThread} assigns work to worker threads by looping over
|
||||
@@ -41,91 +48,77 @@ import org.apache.commons.io.FileUtils;
|
||||
@Slf4j
|
||||
public class FlushWorkers implements AutoCloseable {
|
||||
|
||||
public static final long TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES = (long) (Runtime.getRuntime().maxMemory() * 0.8);
|
||||
private static final long QUEUE_FLUSH_THRESHOLD_BYTES = 10 * 1024 * 1024; // 10MB
|
||||
private static final long MAX_TIME_BETWEEN_REC_MINS = 5L;
|
||||
private static final long SUPERVISOR_INITIAL_DELAY_SECS = 0L;
|
||||
private static final long SUPERVISOR_PERIOD_SECS = 1L;
|
||||
private static final long DEBUG_INITIAL_DELAY_SECS = 0L;
|
||||
private static final long DEBUG_PERIOD_SECS = 10L;
|
||||
private final ScheduledExecutorService supervisorThread = Executors.newScheduledThreadPool(1);
|
||||
private final ExecutorService workerPool = Executors.newFixedThreadPool(5);
|
||||
|
||||
private final ScheduledExecutorService supervisorThread;
|
||||
private final ExecutorService workerPool;
|
||||
private final BufferDequeue bufferDequeue;
|
||||
private final DestinationFlushFunction flusher;
|
||||
private final ScheduledExecutorService debugLoop = Executors.newSingleThreadScheduledExecutor();
|
||||
private final ConcurrentHashMap<StreamDescriptor, AtomicInteger> streamToInProgressWorkers = new ConcurrentHashMap<>();
|
||||
private final Consumer<AirbyteMessage> outputRecordCollector;
|
||||
private final ScheduledExecutorService debugLoop;
|
||||
private final RunningFlushWorkers runningFlushWorkers;
|
||||
private final DetectStreamToFlush detectStreamToFlush;
|
||||
|
||||
public FlushWorkers(final BufferDequeue bufferDequeue, final DestinationFlushFunction flushFunction) {
|
||||
private final FlushFailure flushFailure;
|
||||
|
||||
private final AtomicBoolean isClosing;
|
||||
private final GlobalAsyncStateManager stateManager;
|
||||
|
||||
/**
 * Wires up the flush workers: stores the collaborators, creates the three executors (debug loop,
 * supervisor, and a fixed pool of 5 workers) and the bookkeeping objects used to pick the next
 * stream to flush.
 *
 * @param bufferDequeue source of buffered messages and queue metadata
 * @param flushFunction destination-specific flush implementation
 * @param outputRecordCollector consumer that receives emitted messages
 * @param flushFailure holder used to report exceptions raised while flushing
 * @param stateManager state manager whose states are flushed on close
 */
public FlushWorkers(final BufferDequeue bufferDequeue,
                    final DestinationFlushFunction flushFunction,
                    final Consumer<AirbyteMessage> outputRecordCollector,
                    final FlushFailure flushFailure,
                    final GlobalAsyncStateManager stateManager) {
  // collaborators handed in by the caller
  this.bufferDequeue = bufferDequeue;
  this.outputRecordCollector = outputRecordCollector;
  this.flushFailure = flushFailure;
  this.stateManager = stateManager;
  this.flusher = flushFunction;

  // executors, created in the same order as before
  this.debugLoop = Executors.newSingleThreadScheduledExecutor();
  this.supervisorThread = Executors.newScheduledThreadPool(1);
  this.workerPool = Executors.newFixedThreadPool(5);

  // bookkeeping; detectStreamToFlush depends on the fields initialised above
  this.isClosing = new AtomicBoolean(false);
  this.runningFlushWorkers = new RunningFlushWorkers();
  this.detectStreamToFlush = new DetectStreamToFlush(bufferDequeue, this.runningFlushWorkers, this.isClosing, this.flusher);
}
|
||||
|
||||
public void start() {
|
||||
supervisorThread.scheduleAtFixedRate(this::retrieveWork, SUPERVISOR_INITIAL_DELAY_SECS, SUPERVISOR_PERIOD_SECS,
|
||||
supervisorThread.scheduleAtFixedRate(this::retrieveWork,
|
||||
SUPERVISOR_INITIAL_DELAY_SECS,
|
||||
SUPERVISOR_PERIOD_SECS,
|
||||
TimeUnit.SECONDS);
|
||||
debugLoop.scheduleAtFixedRate(this::printWorkerInfo,
|
||||
DEBUG_INITIAL_DELAY_SECS,
|
||||
DEBUG_PERIOD_SECS,
|
||||
TimeUnit.SECONDS);
|
||||
debugLoop.scheduleAtFixedRate(this::printWorkerInfo, DEBUG_INITIAL_DELAY_SECS, DEBUG_PERIOD_SECS, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
flushAll();
|
||||
|
||||
supervisorThread.shutdown();
|
||||
final var supervisorShut = supervisorThread.awaitTermination(5L, TimeUnit.MINUTES);
|
||||
log.info("Supervisor shut status: {}", supervisorShut);
|
||||
|
||||
log.info("Starting worker pool shutdown..");
|
||||
workerPool.shutdown();
|
||||
final var workersShut = workerPool.awaitTermination(5L, TimeUnit.MINUTES);
|
||||
log.info("Workers shut status: {}", workersShut);
|
||||
|
||||
debugLoop.shutdownNow();
|
||||
}
|
||||
|
||||
private void retrieveWork() {
|
||||
// todo (cgardens) - i'm not convinced this makes sense. as we get close to the limit, we should
|
||||
// flush more eagerly, but "flush all" is never a particularly useful thing in this world.
|
||||
// if the total size is > n, flush all buffers
|
||||
if (bufferDequeue.getTotalGlobalQueueSizeBytes() > TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES) {
|
||||
flushAll();
|
||||
return;
|
||||
}
|
||||
try {
|
||||
log.info("Retrieve Work -- Finding queues to flush");
|
||||
final ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) workerPool;
|
||||
int allocatableThreads = threadPoolExecutor.getMaximumPoolSize() - threadPoolExecutor.getActiveCount();
|
||||
|
||||
final ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) workerPool;
|
||||
var allocatableThreads = threadPoolExecutor.getMaximumPoolSize() - threadPoolExecutor.getActiveCount();
|
||||
while (allocatableThreads > 0) {
|
||||
final Optional<StreamDescriptor> next = detectStreamToFlush.getNextStreamToFlush();
|
||||
|
||||
// todo (cgardens) - build a score to prioritize which queue to flush next. e.g. if a queue is very
|
||||
// large, flush it first. if a queue has not been flushed in a while, flush it next.
|
||||
// otherwise, if each individual stream has crossed a specific threshold, flush
|
||||
for (final StreamDescriptor stream : bufferDequeue.getBufferedStreams()) {
|
||||
if (allocatableThreads == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// while we allow out-of-order processing for speed improvements via multiple workers reading from
|
||||
// the same queue, also avoid scheduling more workers than what is already in progress.
|
||||
final var inProgressSizeByte = (bufferDequeue.getQueueSizeBytes(stream).get() -
|
||||
streamToInProgressWorkers.getOrDefault(stream, new AtomicInteger(0)).get() * QUEUE_FLUSH_THRESHOLD_BYTES);
|
||||
final var exceedSize = inProgressSizeByte >= QUEUE_FLUSH_THRESHOLD_BYTES;
|
||||
final var tooLongSinceLastRecord = bufferDequeue.getTimeOfLastRecord(stream)
|
||||
.map(time -> time.isBefore(Instant.now().minus(MAX_TIME_BETWEEN_REC_MINS, ChronoUnit.MINUTES)))
|
||||
.orElse(false);
|
||||
|
||||
if (exceedSize || tooLongSinceLastRecord) {
|
||||
log.info(
|
||||
"Allocated stream {}, exceedSize:{}, tooLongSinceLastRecord: {}, bytes in queue: {} computed in-progress bytes: {} , threshold bytes: {}",
|
||||
stream.getName(), exceedSize, tooLongSinceLastRecord,
|
||||
FileUtils.byteCountToDisplaySize(bufferDequeue.getQueueSizeBytes(stream).get()),
|
||||
FileUtils.byteCountToDisplaySize(inProgressSizeByte),
|
||||
FileUtils.byteCountToDisplaySize(QUEUE_FLUSH_THRESHOLD_BYTES));
|
||||
allocatableThreads--;
|
||||
if (streamToInProgressWorkers.containsKey(stream)) {
|
||||
streamToInProgressWorkers.get(stream).getAndAdd(1);
|
||||
if (next.isPresent()) {
|
||||
final StreamDescriptor desc = next.get();
|
||||
final UUID flushWorkerId = UUID.randomUUID();
|
||||
runningFlushWorkers.trackFlushWorker(desc, flushWorkerId);
|
||||
allocatableThreads--;
|
||||
flush(desc, flushWorkerId);
|
||||
} else {
|
||||
streamToInProgressWorkers.put(stream, new AtomicInteger(1));
|
||||
break;
|
||||
}
|
||||
flush(stream);
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
log.error("Flush worker error: ", e);
|
||||
flushFailure.propagateException(e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,28 +135,101 @@ public class FlushWorkers implements AutoCloseable {
|
||||
|
||||
}
|
||||
|
||||
private void flushAll() {
|
||||
log.info("Flushing all buffers..");
|
||||
for (final StreamDescriptor desc : bufferDequeue.getBufferedStreams()) {
|
||||
flush(desc);
|
||||
}
|
||||
}
|
||||
|
||||
private void flush(final StreamDescriptor desc) {
|
||||
private void flush(final StreamDescriptor desc, final UUID flushWorkerId) {
|
||||
workerPool.submit(() -> {
|
||||
log.info("Worker picked up work..");
|
||||
log.info("Flush Worker ({}) -- Worker picked up work.", humanReadableFlushWorkerId(flushWorkerId));
|
||||
try {
|
||||
log.info("Attempting to read from queue {}. Current queue size: {}", desc, bufferDequeue.getQueueSizeInRecords(desc).get());
|
||||
log.info("Flush Worker ({}) -- Attempting to read from queue namespace: {}, stream: {}.",
|
||||
humanReadableFlushWorkerId(flushWorkerId),
|
||||
desc.getNamespace(),
|
||||
desc.getName());
|
||||
|
||||
try (final var batch = bufferDequeue.take(desc, flusher.getOptimalBatchSizeBytes())) {
|
||||
flusher.flush(desc, batch.getData());
|
||||
runningFlushWorkers.registerBatchSize(desc, flushWorkerId, batch.getSizeInBytes());
|
||||
final Map<Long, Long> stateIdToCount = batch.getData()
|
||||
.stream()
|
||||
.map(MessageWithMeta::stateId)
|
||||
.collect(Collectors.groupingBy(
|
||||
stateId -> stateId,
|
||||
Collectors.counting()));
|
||||
log.info("Flush Worker ({}) -- Batch contains: {} records, {} bytes.",
|
||||
humanReadableFlushWorkerId(flushWorkerId),
|
||||
batch.getData().size(),
|
||||
AirbyteFileUtils.byteCountToDisplaySize(batch.getSizeInBytes()));
|
||||
|
||||
flusher.flush(desc, batch.getData().stream().map(MessageWithMeta::message));
|
||||
emitStateMessages(batch.flushStates(stateIdToCount));
|
||||
}
|
||||
|
||||
log.info("Worker finished flushing. Current queue size: {}", bufferDequeue.getQueueSizeInRecords(desc));
|
||||
log.info("Flush Worker ({}) -- Worker finished flushing. Current queue size: {}",
|
||||
humanReadableFlushWorkerId(flushWorkerId),
|
||||
bufferDequeue.getQueueSizeInRecords(desc).orElseThrow());
|
||||
} catch (final Exception e) {
|
||||
log.error(String.format("Flush Worker (%s) -- flush worker error: ", humanReadableFlushWorkerId(flushWorkerId)), e);
|
||||
flushFailure.propagateException(e);
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
runningFlushWorkers.completeFlushWorker(desc, flushWorkerId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
log.info("Closing flush workers -- waiting for all buffers to flush");
|
||||
isClosing.set(true);
|
||||
// wait for all buffers to be flushed.
|
||||
while (true) {
|
||||
final Map<StreamDescriptor, Long> streamDescriptorToRemainingRecords = bufferDequeue.getBufferedStreams()
|
||||
.stream()
|
||||
.collect(Collectors.toMap(desc -> desc, desc -> bufferDequeue.getQueueSizeInRecords(desc).orElseThrow()));
|
||||
|
||||
final boolean anyRecordsLeft = streamDescriptorToRemainingRecords
|
||||
.values()
|
||||
.stream()
|
||||
.anyMatch(size -> size > 0);
|
||||
|
||||
if (!anyRecordsLeft) {
|
||||
break;
|
||||
}
|
||||
|
||||
final var workerInfo = new StringBuilder().append("REMAINING_BUFFERS_INFO").append(System.lineSeparator());
|
||||
streamDescriptorToRemainingRecords.entrySet()
|
||||
.stream()
|
||||
.filter(entry -> entry.getValue() > 0)
|
||||
.forEach(entry -> workerInfo.append(String.format(" Namespace: %s Stream: %s -- remaining records: %d",
|
||||
entry.getKey().getNamespace(),
|
||||
entry.getKey().getName(),
|
||||
entry.getValue())));
|
||||
log.info(workerInfo.toString());
|
||||
log.info("Waiting for all streams to flush.");
|
||||
Thread.sleep(1000);
|
||||
}
|
||||
log.info("Closing flush workers -- all buffers flushed");
|
||||
|
||||
// before shutting down the supervisor, flush all state.
|
||||
emitStateMessages(stateManager.flushStates());
|
||||
supervisorThread.shutdown();
|
||||
final var supervisorShut = supervisorThread.awaitTermination(5L, TimeUnit.MINUTES);
|
||||
log.info("Closing flush workers -- Supervisor shutdown status: {}", supervisorShut);
|
||||
|
||||
log.info("Closing flush workers -- Starting worker pool shutdown..");
|
||||
workerPool.shutdown();
|
||||
final var workersShut = workerPool.awaitTermination(5L, TimeUnit.MINUTES);
|
||||
log.info("Closing flush workers -- Workers shutdown status: {}", workersShut);
|
||||
|
||||
debugLoop.shutdownNow();
|
||||
}
|
||||
|
||||
private void emitStateMessages(final List<PartialAirbyteMessage> partials) {
|
||||
partials
|
||||
.stream()
|
||||
.map(partial -> Jsons.deserialize(partial.getSerialized(), AirbyteMessage.class))
|
||||
.forEach(outputRecordCollector);
|
||||
}
|
||||
|
||||
private static String humanReadableFlushWorkerId(final UUID flushWorkerId) {
|
||||
return flushWorkerId.toString().substring(0, 5);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -5,18 +5,20 @@
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.buffers.MemoryBoundedLinkedBlockingQueue.MemoryItem;
|
||||
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.time.Instant;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* Represents the minimal interface over the underlying buffer queues required for dequeue
|
||||
@@ -29,22 +31,25 @@ import java.util.stream.Stream;
|
||||
public class BufferDequeue {
|
||||
|
||||
private final GlobalMemoryManager memoryManager;
|
||||
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
|
||||
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
|
||||
private final GlobalAsyncStateManager stateManager;
|
||||
private final ConcurrentMap<StreamDescriptor, ReentrantLock> bufferLocks;
|
||||
|
||||
public BufferDequeue(final GlobalMemoryManager memoryManager,
|
||||
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
|
||||
final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers,
|
||||
final GlobalAsyncStateManager stateManager) {
|
||||
this.memoryManager = memoryManager;
|
||||
this.buffers = buffers;
|
||||
this.stateManager = stateManager;
|
||||
bufferLocks = new ConcurrentHashMap<>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Primary dequeue method. Best-effort read a specified optimal memory size from the queue.
|
||||
* Primary dequeue method. Reads from queue up to optimalBytesToRead OR until the queue is empty.
|
||||
*
|
||||
* @param streamDescriptor specific buffer to take from
|
||||
* @param optimalBytesToRead bytes to read, if possible
|
||||
* @return
|
||||
* @return autocloseable batch object, that frees memory.
|
||||
*/
|
||||
public MemoryAwareMessageBatch take(final StreamDescriptor streamDescriptor, final long optimalBytesToRead) {
|
||||
final var queue = buffers.get(streamDescriptor);
|
||||
@@ -57,33 +62,27 @@ public class BufferDequeue {
|
||||
try {
|
||||
final AtomicLong bytesRead = new AtomicLong();
|
||||
|
||||
final var s = Stream.generate(() -> {
|
||||
try {
|
||||
return queue.poll(20, TimeUnit.MILLISECONDS);
|
||||
} catch (final InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}).takeWhile(memoryItem -> {
|
||||
// if no new records after waiting, the stream is done.
|
||||
if (memoryItem == null) {
|
||||
return false;
|
||||
}
|
||||
final List<MessageWithMeta> output = new LinkedList<>();
|
||||
while (queue.size() > 0) {
|
||||
final MemoryItem<MessageWithMeta> memoryItem = queue.peek().orElseThrow();
|
||||
|
||||
// otherwise pull records until we hit the memory limit.
|
||||
final long newSize = memoryItem.size() + bytesRead.get();
|
||||
if (newSize <= optimalBytesToRead) {
|
||||
bytesRead.addAndGet(memoryItem.size());
|
||||
return true;
|
||||
output.add(queue.poll().item());
|
||||
} else {
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
}).map(MemoryBoundedLinkedBlockingQueue.MemoryItem::item)
|
||||
.toList()
|
||||
.stream();
|
||||
}
|
||||
|
||||
queue.addMaxMemory(-bytesRead.get());
|
||||
|
||||
return new MemoryAwareMessageBatch(s, bytesRead.get(), memoryManager);
|
||||
return new MemoryAwareMessageBatch(
|
||||
output,
|
||||
bytesRead.get(),
|
||||
memoryManager,
|
||||
stateManager);
|
||||
} finally {
|
||||
bufferLocks.get(streamDescriptor).unlock();
|
||||
}
|
||||
@@ -91,10 +90,9 @@ public class BufferDequeue {
|
||||
|
||||
/**
|
||||
* The following methods provide metadata for buffer flushing calculations. Consumers are
|
||||
* expected to call {@link #getBufferedStreams()} to retrieve the currently buffered streams as a
|
||||
* handle to the remaining methods.
|
||||
* expected to call it to retrieve the currently buffered streams as a handle to the remaining
|
||||
* methods.
|
||||
*/
|
||||
|
||||
public Set<StreamDescriptor> getBufferedStreams() {
|
||||
return new HashSet<>(buffers.keySet());
|
||||
}
|
||||
@@ -104,7 +102,7 @@ public class BufferDequeue {
|
||||
}
|
||||
|
||||
public long getTotalGlobalQueueSizeBytes() {
|
||||
return buffers.values().stream().map(MemoryBoundedLinkedBlockingQueue::getCurrentMemoryUsage).mapToLong(Long::longValue).sum();
|
||||
return buffers.values().stream().map(StreamAwareQueue::getCurrentMemoryUsage).mapToLong(Long::longValue).sum();
|
||||
}
|
||||
|
||||
public Optional<Long> getQueueSizeInRecords(final StreamDescriptor streamDescriptor) {
|
||||
@@ -112,14 +110,14 @@ public class BufferDequeue {
|
||||
}
|
||||
|
||||
public Optional<Long> getQueueSizeBytes(final StreamDescriptor streamDescriptor) {
|
||||
return getBuffer(streamDescriptor).map(MemoryBoundedLinkedBlockingQueue::getCurrentMemoryUsage);
|
||||
return getBuffer(streamDescriptor).map(StreamAwareQueue::getCurrentMemoryUsage);
|
||||
}
|
||||
|
||||
public Optional<Instant> getTimeOfLastRecord(final StreamDescriptor streamDescriptor) {
|
||||
return getBuffer(streamDescriptor).flatMap(MemoryBoundedLinkedBlockingQueue::getTimeOfLastMessage);
|
||||
return getBuffer(streamDescriptor).flatMap(StreamAwareQueue::getTimeOfLastMessage);
|
||||
}
|
||||
|
||||
private Optional<MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> getBuffer(final StreamDescriptor streamDescriptor) {
|
||||
private Optional<StreamAwareQueue> getBuffer(final StreamDescriptor streamDescriptor) {
|
||||
if (buffers.containsKey(streamDescriptor)) {
|
||||
return Optional.of(buffers.get(streamDescriptor));
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordSizeEstimator;
|
||||
import static java.lang.Thread.sleep;
|
||||
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
|
||||
@@ -17,54 +19,65 @@ import java.util.concurrent.ConcurrentMap;
|
||||
*/
|
||||
public class BufferEnqueue {
|
||||
|
||||
private final long initialQueueSizeBytes;
|
||||
private final RecordSizeEstimator recordSizeEstimator;
|
||||
|
||||
private final GlobalMemoryManager memoryManager;
|
||||
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
|
||||
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
|
||||
private final GlobalAsyncStateManager stateManager;
|
||||
|
||||
public BufferEnqueue(final GlobalMemoryManager memoryManager,
|
||||
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
|
||||
this(GlobalMemoryManager.BLOCK_SIZE_BYTES, memoryManager, buffers);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public BufferEnqueue(final long initialQueueSizeBytes,
|
||||
final GlobalMemoryManager memoryManager,
|
||||
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
|
||||
this.initialQueueSizeBytes = initialQueueSizeBytes;
|
||||
final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers,
|
||||
final GlobalAsyncStateManager stateManager) {
|
||||
this.memoryManager = memoryManager;
|
||||
this.buffers = buffers;
|
||||
this.recordSizeEstimator = new RecordSizeEstimator();
|
||||
this.stateManager = stateManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Buffer a record. Contains memory management logic to dynamically adjust queue size via
|
||||
* {@link GlobalMemoryManager} accounting for incoming records.
|
||||
*
|
||||
* @param streamDescriptor stream to buffer record to
|
||||
* @param message to buffer
|
||||
* @param sizeInBytes
|
||||
*/
|
||||
public void addRecord(final StreamDescriptor streamDescriptor, final AirbyteMessage message) {
|
||||
if (!buffers.containsKey(streamDescriptor)) {
|
||||
buffers.put(streamDescriptor, new MemoryBoundedLinkedBlockingQueue<>(memoryManager.requestMemory()));
|
||||
public void addRecord(final PartialAirbyteMessage message, final Integer sizeInBytes) {
|
||||
if (message.getType() == Type.RECORD) {
|
||||
handleRecord(message, sizeInBytes);
|
||||
} else if (message.getType() == Type.STATE) {
|
||||
stateManager.trackState(message, sizeInBytes);
|
||||
}
|
||||
}
|
||||
|
||||
// todo (cgardens) - handle estimating state message size.
|
||||
final long messageSize = message.getType() == AirbyteMessage.Type.RECORD ? recordSizeEstimator.getEstimatedByteSize(message.getRecord()) : 1024;
|
||||
private void handleRecord(final PartialAirbyteMessage message, final Integer sizeInBytes) {
|
||||
final StreamDescriptor streamDescriptor = extractStateFromRecord(message);
|
||||
if (streamDescriptor != null && !buffers.containsKey(streamDescriptor)) {
|
||||
buffers.put(streamDescriptor, new StreamAwareQueue(memoryManager.requestMemory()));
|
||||
}
|
||||
final long stateId = stateManager.getStateIdAndIncrementCounter(streamDescriptor);
|
||||
|
||||
final var queue = buffers.get(streamDescriptor);
|
||||
var addedToQueue = queue.offer(message, messageSize);
|
||||
var addedToQueue = queue.offer(message, sizeInBytes, stateId);
|
||||
|
||||
// todo (cgardens) - what if the record being added is bigger than the block size?
|
||||
// if failed, try to increase memory and add to queue.
|
||||
int i = 0;
|
||||
while (!addedToQueue) {
|
||||
final var newlyAllocatedMemory = memoryManager.requestMemory();
|
||||
if (newlyAllocatedMemory > 0) {
|
||||
queue.addMaxMemory(newlyAllocatedMemory);
|
||||
}
|
||||
addedToQueue = queue.offer(message, messageSize);
|
||||
addedToQueue = queue.offer(message, sizeInBytes, stateId);
|
||||
i++;
|
||||
if (i > 5) {
|
||||
try {
|
||||
sleep(500);
|
||||
} catch (final InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static StreamDescriptor extractStateFromRecord(final PartialAirbyteMessage message) {
|
||||
return new StreamDescriptor()
|
||||
.withNamespace(message.getRecord().getNamespace())
|
||||
.withName(message.getRecord().getStream());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import static io.airbyte.integrations.destination_async.FlushWorkers.TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.airbyte.integrations.destination_async.AirbyteFileUtils;
|
||||
import io.airbyte.integrations.destination_async.FlushWorkers;
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
@@ -17,24 +17,44 @@ import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@Slf4j
|
||||
public class BufferManager {
|
||||
|
||||
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(BufferManager.class);
|
||||
|
||||
public final long maxMemory;
|
||||
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
|
||||
private final BufferEnqueue bufferEnqueue;
|
||||
private final BufferDequeue bufferDequeue;
|
||||
private final GlobalMemoryManager memoryManager;
|
||||
private final ScheduledExecutorService debugLoop = Executors.newSingleThreadScheduledExecutor();
|
||||
|
||||
private final GlobalAsyncStateManager stateManager;
|
||||
private final ScheduledExecutorService debugLoop;
|
||||
|
||||
public BufferManager() {
|
||||
memoryManager = new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES);
|
||||
this((long) (Runtime.getRuntime().maxMemory() * 0.8));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public BufferManager(final long memoryLimit) {
|
||||
maxMemory = memoryLimit;
|
||||
LOGGER.info("Memory available to the JVM {}", FileUtils.byteCountToDisplaySize(maxMemory));
|
||||
memoryManager = new GlobalMemoryManager(maxMemory);
|
||||
this.stateManager = new GlobalAsyncStateManager(memoryManager);
|
||||
buffers = new ConcurrentHashMap<>();
|
||||
bufferEnqueue = new BufferEnqueue(memoryManager, buffers);
|
||||
bufferDequeue = new BufferDequeue(memoryManager, buffers);
|
||||
bufferEnqueue = new BufferEnqueue(memoryManager, buffers, stateManager);
|
||||
bufferDequeue = new BufferDequeue(memoryManager, buffers, stateManager);
|
||||
debugLoop = Executors.newSingleThreadScheduledExecutor();
|
||||
debugLoop.scheduleAtFixedRate(this::printQueueInfo, 0, 10, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
public GlobalAsyncStateManager getStateManager() {
|
||||
return stateManager;
|
||||
}
|
||||
|
||||
public BufferEnqueue getBufferEnqueue() {
|
||||
return bufferEnqueue;
|
||||
}
|
||||
@@ -57,16 +77,18 @@ public class BufferManager {
|
||||
final var queueInfo = new StringBuilder().append("QUEUE INFO").append(System.lineSeparator());
|
||||
|
||||
queueInfo
|
||||
.append(String.format(" Global Mem Manager -- max: %s, allocated: %s",
|
||||
FileUtils.byteCountToDisplaySize(memoryManager.getMaxMemoryBytes()),
|
||||
FileUtils.byteCountToDisplaySize(memoryManager.getCurrentMemoryBytes())))
|
||||
.append(String.format(" Global Mem Manager -- max: %s, allocated: %s (%s MB), %% used: %s",
|
||||
AirbyteFileUtils.byteCountToDisplaySize(memoryManager.getMaxMemoryBytes()),
|
||||
AirbyteFileUtils.byteCountToDisplaySize(memoryManager.getCurrentMemoryBytes()),
|
||||
(double) memoryManager.getCurrentMemoryBytes() / 1024 / 1024,
|
||||
(double) memoryManager.getCurrentMemoryBytes() / memoryManager.getMaxMemoryBytes()))
|
||||
.append(System.lineSeparator());
|
||||
|
||||
for (final var entry : buffers.entrySet()) {
|
||||
final var queue = entry.getValue();
|
||||
queueInfo.append(
|
||||
String.format(" Queue name: %s, num records: %d, num bytes: %s",
|
||||
entry.getKey().getName(), queue.size(), FileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage())))
|
||||
entry.getKey().getName(), queue.size(), AirbyteFileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage())))
|
||||
.append(System.lineSeparator());
|
||||
}
|
||||
log.info(queueInfo.toString());
|
||||
|
||||
@@ -5,8 +5,13 @@
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import java.util.stream.Stream;
|
||||
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* POJO abstraction representing one discrete buffer read. This allows ergonomic dequeues by
|
||||
@@ -21,24 +26,47 @@ import java.util.stream.Stream;
|
||||
*/
|
||||
public class MemoryAwareMessageBatch implements AutoCloseable {
|
||||
|
||||
private Stream<AirbyteMessage> batch;
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MemoryAwareMessageBatch.class);
|
||||
private final List<MessageWithMeta> batch;
|
||||
|
||||
private final long sizeInBytes;
|
||||
private final GlobalMemoryManager memoryManager;
|
||||
private final GlobalAsyncStateManager stateManager;
|
||||
|
||||
public MemoryAwareMessageBatch(final Stream<AirbyteMessage> batch, final long sizeInBytes, final GlobalMemoryManager memoryManager) {
|
||||
public MemoryAwareMessageBatch(final List<MessageWithMeta> batch,
|
||||
final long sizeInBytes,
|
||||
final GlobalMemoryManager memoryManager,
|
||||
final GlobalAsyncStateManager stateManager) {
|
||||
this.batch = batch;
|
||||
this.sizeInBytes = sizeInBytes;
|
||||
this.memoryManager = memoryManager;
|
||||
this.stateManager = stateManager;
|
||||
}
|
||||
|
||||
public Stream<AirbyteMessage> getData() {
|
||||
public long getSizeInBytes() {
|
||||
return sizeInBytes;
|
||||
}
|
||||
|
||||
public List<MessageWithMeta> getData() {
|
||||
return batch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
batch = null;
|
||||
memoryManager.free(sizeInBytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* For the batch, marks all the states that have now been flushed. Also returns states that can be
|
||||
* flushed. This method is destructive, it assumes that whatever consumes the state messages emits
|
||||
* them, internally it purges the states it returns.
|
||||
* <p>
|
||||
*
|
||||
* @return list of states that can be flushed
|
||||
*/
|
||||
public List<PartialAirbyteMessage> flushStates(final Map<Long, Long> stateIdToCount) {
|
||||
stateIdToCount.forEach(stateManager::decrement);
|
||||
return stateManager.flushStates();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import javax.annotation.Nonnull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
@@ -17,7 +15,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
* bounded on number of items in the queue, it is bounded by the memory it is allowed to use. The
|
||||
* amount of memory it is allowed to use can be resized after it is instantiated.
|
||||
* <p>
|
||||
* This class intentaionally hides the underlying queue inside of it. For this class to work, it has
|
||||
* This class intentionally hides the underlying queue inside of it. For this class to work, it has
|
||||
* to override each method on a queue that adds or removes records from the queue. The Queue
|
||||
* interface has a lot of methods to override, and we don't want to spend the time overriding a lot
|
||||
* of methods that won't be used. By hiding the queue, we avoid someone accidentally using a queue
|
||||
@@ -41,11 +39,7 @@ class MemoryBoundedLinkedBlockingQueue<E> {
|
||||
}
|
||||
|
||||
public void addMaxMemory(final long maxMemoryUsage) {
|
||||
this.hiddenQueue.maxMemoryUsage.addAndGet(maxMemoryUsage);
|
||||
}
|
||||
|
||||
public Optional<Instant> getTimeOfLastMessage() {
|
||||
return Optional.ofNullable(hiddenQueue.timeOfLastMessage.get());
|
||||
hiddenQueue.maxMemoryUsage.addAndGet(maxMemoryUsage);
|
||||
}
|
||||
|
||||
public int size() {
|
||||
@@ -56,6 +50,10 @@ class MemoryBoundedLinkedBlockingQueue<E> {
|
||||
return hiddenQueue.offer(e, itemSizeInBytes);
|
||||
}
|
||||
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> peek() {
|
||||
return hiddenQueue.peek();
|
||||
}
|
||||
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> take() throws InterruptedException {
|
||||
return hiddenQueue.take();
|
||||
}
|
||||
@@ -78,12 +76,10 @@ class MemoryBoundedLinkedBlockingQueue<E> {
|
||||
|
||||
private final AtomicLong currentMemoryUsage;
|
||||
private final AtomicLong maxMemoryUsage;
|
||||
private final AtomicReference<Instant> timeOfLastMessage;
|
||||
|
||||
public HiddenQueue(final long maxMemoryUsage) {
|
||||
currentMemoryUsage = new AtomicLong(0);
|
||||
this.maxMemoryUsage = new AtomicLong(maxMemoryUsage);
|
||||
timeOfLastMessage = new AtomicReference<>(null);
|
||||
}
|
||||
|
||||
public boolean offer(final E e, final long itemSizeInBytes) {
|
||||
@@ -92,9 +88,6 @@ class MemoryBoundedLinkedBlockingQueue<E> {
|
||||
final boolean success = super.offer(new MemoryItem<>(e, itemSizeInBytes));
|
||||
if (!success) {
|
||||
currentMemoryUsage.addAndGet(-itemSizeInBytes);
|
||||
} else {
|
||||
// it succeeded!
|
||||
timeOfLastMessage.set(Instant.now());
|
||||
}
|
||||
log.debug("offer status: {}", success);
|
||||
return success;
|
||||
@@ -105,6 +98,7 @@ class MemoryBoundedLinkedBlockingQueue<E> {
|
||||
}
|
||||
}
|
||||
|
||||
@Nonnull
|
||||
@Override
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> take() throws InterruptedException {
|
||||
final MemoryItem<E> memoryItem = super.take();
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import java.time.Instant;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class StreamAwareQueue {
|
||||
|
||||
private final AtomicReference<Instant> timeOfLastMessage;
|
||||
|
||||
private final MemoryBoundedLinkedBlockingQueue<MessageWithMeta> memoryAwareQueue;
|
||||
|
||||
public StreamAwareQueue(final long maxMemoryUsage) {
|
||||
memoryAwareQueue = new MemoryBoundedLinkedBlockingQueue<>(maxMemoryUsage);
|
||||
timeOfLastMessage = new AtomicReference<>();
|
||||
}
|
||||
|
||||
public long getCurrentMemoryUsage() {
|
||||
return memoryAwareQueue.getCurrentMemoryUsage();
|
||||
}
|
||||
|
||||
public void addMaxMemory(final long maxMemoryUsage) {
|
||||
memoryAwareQueue.addMaxMemory(maxMemoryUsage);
|
||||
}
|
||||
|
||||
public Optional<Instant> getTimeOfLastMessage() {
|
||||
return Optional.ofNullable(timeOfLastMessage.get());
|
||||
}
|
||||
|
||||
public Optional<MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta>> peek() {
|
||||
return Optional.ofNullable(memoryAwareQueue.peek());
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return memoryAwareQueue.size();
|
||||
}
|
||||
|
||||
public boolean offer(final PartialAirbyteMessage message, final long messageSizeInBytes, final long stateId) {
|
||||
if (memoryAwareQueue.offer(new MessageWithMeta(message, stateId), messageSizeInBytes)) {
|
||||
timeOfLastMessage.set(Instant.now());
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> take() throws InterruptedException {
|
||||
return memoryAwareQueue.take();
|
||||
}
|
||||
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> poll() {
|
||||
return memoryAwareQueue.poll();
|
||||
}
|
||||
|
||||
public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> poll(final long timeout, final TimeUnit unit) throws InterruptedException {
|
||||
return memoryAwareQueue.poll(timeout, unit);
|
||||
}
|
||||
|
||||
public record MessageWithMeta(PartialAirbyteMessage message, long stateId) {}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.partial_messages;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PartialAirbyteMessage {
|
||||
|
||||
@JsonProperty("type")
|
||||
@JsonPropertyDescription("Message type")
|
||||
private AirbyteMessage.Type type;
|
||||
|
||||
@JsonProperty("record")
|
||||
private PartialAirbyteRecordMessage record;
|
||||
|
||||
@JsonProperty("state")
|
||||
private PartialAirbyteStateMessage state;
|
||||
|
||||
@JsonProperty("serialized")
|
||||
private String serialized;
|
||||
|
||||
public PartialAirbyteMessage() {}
|
||||
|
||||
@JsonProperty("type")
|
||||
public AirbyteMessage.Type getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
@JsonProperty("type")
|
||||
public void setType(final AirbyteMessage.Type type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
public PartialAirbyteMessage withType(final AirbyteMessage.Type type) {
|
||||
this.type = type;
|
||||
return this;
|
||||
}
|
||||
|
||||
@JsonProperty("record")
|
||||
public PartialAirbyteRecordMessage getRecord() {
|
||||
return record;
|
||||
}
|
||||
|
||||
@JsonProperty("record")
|
||||
public void setRecord(final PartialAirbyteRecordMessage record) {
|
||||
this.record = record;
|
||||
}
|
||||
|
||||
public PartialAirbyteMessage withRecord(final PartialAirbyteRecordMessage record) {
|
||||
this.record = record;
|
||||
return this;
|
||||
}
|
||||
|
||||
@JsonProperty("state")
|
||||
public PartialAirbyteStateMessage getState() {
|
||||
return state;
|
||||
}
|
||||
|
||||
@JsonProperty("state")
|
||||
public void setState(final PartialAirbyteStateMessage state) {
|
||||
this.state = state;
|
||||
}
|
||||
|
||||
public PartialAirbyteMessage withState(final PartialAirbyteStateMessage state) {
|
||||
this.state = state;
|
||||
return this;
|
||||
}
|
||||
|
||||
@JsonProperty("serialized")
|
||||
public String getSerialized() {
|
||||
return serialized;
|
||||
}
|
||||
|
||||
@JsonProperty("serialized")
|
||||
public void setSerialized(final String serialized) {
|
||||
this.serialized = serialized;
|
||||
}
|
||||
|
||||
public PartialAirbyteMessage withSerialized(final String serialized) {
|
||||
this.serialized = serialized;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
final PartialAirbyteMessage that = (PartialAirbyteMessage) o;
|
||||
return type == that.type && Objects.equals(record, that.record) && Objects.equals(state, that.state)
|
||||
&& Objects.equals(serialized, that.serialized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(type, record, state, serialized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "PartialAirbyteMessage{" +
|
||||
"type=" + type +
|
||||
", record=" + record +
|
||||
", state=" + state +
|
||||
", serialized='" + serialized + '\'' +
|
||||
'}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.partial_messages;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PartialAirbyteRecordMessage {
|
||||
|
||||
@JsonProperty("namespace")
|
||||
private String namespace;
|
||||
@JsonProperty("stream")
|
||||
private String stream;
|
||||
|
||||
public PartialAirbyteRecordMessage() {}
|
||||
|
||||
@JsonProperty("namespace")
|
||||
public String getNamespace() {
|
||||
return namespace;
|
||||
}
|
||||
|
||||
@JsonProperty("namespace")
|
||||
public void setNamespace(final String namespace) {
|
||||
this.namespace = namespace;
|
||||
}
|
||||
|
||||
public PartialAirbyteRecordMessage withNamespace(final String namespace) {
|
||||
this.namespace = namespace;
|
||||
return this;
|
||||
}
|
||||
|
||||
@JsonProperty("stream")
|
||||
public String getStream() {
|
||||
return stream;
|
||||
}
|
||||
|
||||
@JsonProperty("stream")
|
||||
public void setStream(final String stream) {
|
||||
this.stream = stream;
|
||||
}
|
||||
|
||||
public PartialAirbyteRecordMessage withStream(final String stream) {
|
||||
this.stream = stream;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
final PartialAirbyteRecordMessage that = (PartialAirbyteRecordMessage) o;
|
||||
return Objects.equals(namespace, that.namespace) && Objects.equals(stream, that.stream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(namespace, stream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "PartialAirbyteRecordMessage{" +
|
||||
"namespace='" + namespace + '\'' +
|
||||
", stream='" + stream + '\'' +
|
||||
'}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.partial_messages;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PartialAirbyteStateMessage {
|
||||
|
||||
@JsonProperty("type")
|
||||
private AirbyteStateType type;
|
||||
|
||||
@JsonProperty("stream")
|
||||
private PartialAirbyteStreamState stream;
|
||||
|
||||
public PartialAirbyteStateMessage() {}
|
||||
|
||||
@JsonProperty("type")
|
||||
public AirbyteStateType getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
@JsonProperty("type")
|
||||
public void setType(final AirbyteStateType type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
public PartialAirbyteStateMessage withType(final AirbyteStateType type) {
|
||||
this.type = type;
|
||||
return this;
|
||||
}
|
||||
|
||||
@JsonProperty("stream")
|
||||
public PartialAirbyteStreamState getStream() {
|
||||
return stream;
|
||||
}
|
||||
|
||||
@JsonProperty("stream")
|
||||
public void setStream(final PartialAirbyteStreamState stream) {
|
||||
this.stream = stream;
|
||||
}
|
||||
|
||||
public PartialAirbyteStateMessage withStream(final PartialAirbyteStreamState stream) {
|
||||
this.stream = stream;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
final PartialAirbyteStateMessage that = (PartialAirbyteStateMessage) o;
|
||||
return type == that.type && Objects.equals(stream, that.stream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(type, stream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "PartialAirbyteStateMessage{" +
|
||||
"type=" + type +
|
||||
", stream=" + stream +
|
||||
'}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.partial_messages;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.Objects;
|
||||
|
||||
public class PartialAirbyteStreamState {
|
||||
|
||||
@JsonProperty("stream_descriptor")
|
||||
private StreamDescriptor streamDescriptor;
|
||||
|
||||
public PartialAirbyteStreamState() {
|
||||
streamDescriptor = streamDescriptor;
|
||||
}
|
||||
|
||||
@JsonProperty("stream_descriptor")
|
||||
public StreamDescriptor getStreamDescriptor() {
|
||||
return streamDescriptor;
|
||||
}
|
||||
|
||||
@JsonProperty("stream_descriptor")
|
||||
public void setStreamDescriptor(final StreamDescriptor streamDescriptor) {
|
||||
this.streamDescriptor = streamDescriptor;
|
||||
}
|
||||
|
||||
public PartialAirbyteStreamState withStreamDescriptor(final StreamDescriptor streamDescriptor) {
|
||||
this.streamDescriptor = streamDescriptor;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
final PartialAirbyteStreamState that = (PartialAirbyteStreamState) o;
|
||||
return Objects.equals(streamDescriptor, that.streamDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(streamDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "PartialAirbyteStreamState{" +
|
||||
"streamDescriptor=" + streamDescriptor +
|
||||
'}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.state;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
 * Holder for a failure flag and the exception that caused it, backed by atomics so it can be
 * written by one thread and read by another.
 */
public class FlushFailure {

  private final AtomicBoolean isFailed = new AtomicBoolean(false);

  private final AtomicReference<Exception> exceptionAtomicReference = new AtomicReference<>();

  /**
   * Records the given exception and marks this holder as failed.
   * <p>
   * The exception is stored before the flag is raised, so a thread that observes
   * {@link #isFailed()} returning {@code true} for this call also sees the exception from
   * {@link #getException()} rather than {@code null}.
   *
   * @param e the exception to record
   */
  public void propagateException(final Exception e) {
    exceptionAtomicReference.set(e);
    isFailed.set(true);
  }

  /**
   * @return {@code true} once {@link #propagateException(Exception)} has been called
   */
  public boolean isFailed() {
    return isFailed.get();
  }

  /**
   * @return the most recently recorded exception, or {@code null} if none has been recorded
   */
  public Exception getException() {
    return exceptionAtomicReference.get();
  }

}
|
||||
@@ -0,0 +1,328 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.state;
|
||||
|
||||
import static java.lang.Thread.sleep;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStreamState;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.mina.util.ConcurrentHashSet;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Responsible for managing state within the Destination. The general approach is a ref counter
|
||||
* approach - each state message is associated with a record count. This count represents the number
|
||||
* of preceding records. For a state to be emitted, all preceding records have to be written to the
|
||||
* destination i.e. the counter is 0.
|
||||
* <p>
|
||||
* A per-stream state queue is maintained internally, with each state within the queue having a
|
||||
* counter. This means we *ALLOW* records succeeding an unemitted state to be written. This
|
||||
* decouples record writing from state management at the cost of potentially repeating work if an
|
||||
* upstream state is never written.
|
||||
* <p>
|
||||
* One important detail here is the difference between how PER-STREAM & NON-PER-STREAM are handled.
|
||||
* The PER-STREAM case is simple, and is as described above. The NON-PER-STREAM case is slightly
|
||||
* tricky. Because we don't know the stream type to begin with, we always assume PER_STREAM until
|
||||
* the first state message arrives. If this state message is a GLOBAL state, we alias all existing
|
||||
* state ids to a single global state id via a set of alias ids. From then onwards, we use one id -
|
||||
* {@link #SENTINEL_GLOBAL_DESC} regardless of stream. Read
|
||||
* {@link #convertToGlobalIfNeeded(AirbyteMessage)} for more detail.
|
||||
*/
|
||||
@Slf4j
|
||||
public class GlobalAsyncStateManager {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(GlobalAsyncStateManager.class);
|
||||
|
||||
private static final StreamDescriptor SENTINEL_GLOBAL_DESC = new StreamDescriptor().withName(UUID.randomUUID().toString());
|
||||
private final GlobalMemoryManager memoryManager;
|
||||
|
||||
/**
|
||||
* Memory that the manager has allocated to it to use. It can ask for more memory as needed.
|
||||
*/
|
||||
private final AtomicLong memoryAllocated;
|
||||
/**
|
||||
* Memory that the manager is currently using.
|
||||
*/
|
||||
private final AtomicLong memoryUsed;
|
||||
|
||||
boolean preState = true;
|
||||
private final ConcurrentMap<Long, AtomicLong> stateIdToCounter = new ConcurrentHashMap<>();
|
||||
private final ConcurrentMap<StreamDescriptor, LinkedList<Long>> streamToStateIdQ = new ConcurrentHashMap<>();
|
||||
|
||||
private final ConcurrentMap<Long, ImmutablePair<PartialAirbyteMessage, Long>> stateIdToState = new ConcurrentHashMap<>();
|
||||
// empty in the STREAM case.
|
||||
|
||||
// Alias-ing only exists in the non-STREAM case where we have to convert existing state ids to one
|
||||
// single global id.
|
||||
// This only happens once.
|
||||
private final Set<Long> aliasIds = new ConcurrentHashSet<>();
|
||||
private long retroactiveGlobalStateId = 0;
|
||||
|
||||
public GlobalAsyncStateManager(final GlobalMemoryManager memoryManager) {
|
||||
this.memoryManager = memoryManager;
|
||||
memoryAllocated = new AtomicLong(memoryManager.requestMemory());
|
||||
memoryUsed = new AtomicLong();
|
||||
}
|
||||
|
||||
// Always assume STREAM to begin, and convert only if needed. Most state is per stream anyway.
|
||||
private AirbyteStateMessage.AirbyteStateType stateType = AirbyteStateMessage.AirbyteStateType.STREAM;
|
||||
|
||||
/**
|
||||
* Main method to process state messages.
|
||||
* <p>
|
||||
* The first incoming state message tells us the type of state we are dealing with. We then convert
|
||||
* internal data structures if needed.
|
||||
* <p>
|
||||
* Because state messages are a watermark, all preceding records need to be flushed before the state
|
||||
* message can be processed.
|
||||
*/
|
||||
public void trackState(final PartialAirbyteMessage message, final long sizeInBytes) {
|
||||
if (preState) {
|
||||
convertToGlobalIfNeeded(message);
|
||||
preState = false;
|
||||
}
|
||||
// stateType should not change after a conversion.
|
||||
Preconditions.checkArgument(stateType == extractStateType(message));
|
||||
|
||||
closeState(message, sizeInBytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Identical to {@link #getStateId(StreamDescriptor)} except this increments the associated counter
|
||||
* by 1. Intended to be called whenever a record is ingested.
|
||||
*
|
||||
* @param streamDescriptor - stream to get stateId for.
|
||||
* @return state id
|
||||
*/
|
||||
public long getStateIdAndIncrementCounter(final StreamDescriptor streamDescriptor) {
|
||||
return getStateIdAndIncrement(streamDescriptor, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Each decrement represent one written record for a state. A zero counter means there are no more
|
||||
* inflight records associated with a state and the state can be flushed.
|
||||
*
|
||||
* @param stateId reference to a state.
|
||||
* @param count to decrement.
|
||||
*/
|
||||
public void decrement(final long stateId, final long count) {
|
||||
log.trace("decrementing state id: {}, count: {}", stateId, count);
|
||||
stateIdToCounter.get(getStateAfterAlias(stateId)).addAndGet(-count);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns state messages with no more inflight records i.e. counter = 0 across all streams.
|
||||
* Intended to be called by {@link io.airbyte.integrations.destination_async.FlushWorkers} after a
|
||||
* worker has finished flushing its record batch.
|
||||
* <p>
|
||||
* The return list of states should be emitted back to the platform.
|
||||
*
|
||||
* @return list of state messages with no more inflight records.
|
||||
*/
|
||||
public List<PartialAirbyteMessage> flushStates() {
|
||||
final List<PartialAirbyteMessage> output = new ArrayList<>();
|
||||
Long bytesFlushed = 0L;
|
||||
for (final Map.Entry<StreamDescriptor, LinkedList<Long>> entry : streamToStateIdQ.entrySet()) {
|
||||
// remove all states with 0 counters.
|
||||
final LinkedList<Long> stateIdQueue = entry.getValue();
|
||||
while (true) {
|
||||
final Long oldestState = stateIdQueue.peek();
|
||||
final boolean emptyQ = oldestState == null;
|
||||
final boolean noCorrespondingStateMsg = stateIdToState.get(oldestState) == null;
|
||||
if (emptyQ || noCorrespondingStateMsg) {
|
||||
break;
|
||||
}
|
||||
|
||||
final boolean noPrevRecs = !stateIdToCounter.containsKey(oldestState);
|
||||
final boolean allRecsEmitted = stateIdToCounter.get(oldestState).get() == 0;
|
||||
if (noPrevRecs || allRecsEmitted) {
|
||||
entry.getValue().poll(); // poll to remove. no need to read as the earlier peek is still valid.
|
||||
output.add(stateIdToState.get(oldestState).getLeft());
|
||||
bytesFlushed += stateIdToState.get(oldestState).getRight();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeBytes(bytesFlushed);
|
||||
return output;
|
||||
}
|
||||
|
||||
private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, final long increment) {
|
||||
final StreamDescriptor resolvedDescriptor = stateType == AirbyteStateMessage.AirbyteStateType.STREAM ? streamDescriptor : SENTINEL_GLOBAL_DESC;
|
||||
if (!streamToStateIdQ.containsKey(resolvedDescriptor)) {
|
||||
registerNewStreamDescriptor(resolvedDescriptor);
|
||||
}
|
||||
final Long stateId = streamToStateIdQ.get(resolvedDescriptor).peekLast();
|
||||
final var update = stateIdToCounter.get(stateId).addAndGet(increment);
|
||||
log.trace("State id: {}, count: {}", stateId, update);
|
||||
return stateId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the internal id of a state message. This is the id that should be used to reference a
|
||||
* state when interacting with all methods in this class.
|
||||
*
|
||||
* @param streamDescriptor - stream to get stateId for.
|
||||
* @return state id
|
||||
*/
|
||||
private long getStateId(final StreamDescriptor streamDescriptor) {
|
||||
return getStateIdAndIncrement(streamDescriptor, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Pass this the number of bytes that were flushed. It will track those internally and if the
|
||||
* memoryUsed gets signficantly lower than what is allocated, then it will return it to the memory
|
||||
* manager. We don't always return to the memory manager to avoid needlessly allocating /
|
||||
* de-allocating memory rapidly over a few bytes.
|
||||
*
|
||||
* @param bytesFlushed bytes that were flushed (and should be removed from memory used).
|
||||
*/
|
||||
private void freeBytes(final long bytesFlushed) {
|
||||
LOGGER.debug("Bytes flushed memory to store state message. Allocated: {}, Used: {}, Flushed: {}, % Used: {}",
|
||||
FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
|
||||
FileUtils.byteCountToDisplaySize(memoryUsed.get()),
|
||||
FileUtils.byteCountToDisplaySize(bytesFlushed),
|
||||
(double) memoryUsed.get() / memoryAllocated.get());
|
||||
|
||||
memoryManager.free(bytesFlushed);
|
||||
memoryAllocated.addAndGet(-bytesFlushed);
|
||||
memoryUsed.addAndGet(-bytesFlushed);
|
||||
LOGGER.debug("Returned {} of memory back to the memory manager.", FileUtils.byteCountToDisplaySize(bytesFlushed));
|
||||
}
|
||||
|
||||
private void convertToGlobalIfNeeded(final PartialAirbyteMessage message) {
|
||||
// instead of checking for global or legacy, check for the inverse of stream.
|
||||
stateType = extractStateType(message);
|
||||
if (stateType != AirbyteStateMessage.AirbyteStateType.STREAM) {// alias old stream-level state ids to single global state id
|
||||
// upon conversion, all previous tracking data structures need to be cleared as we move
|
||||
// into the non-STREAM world for correctness.
|
||||
|
||||
aliasIds.addAll(streamToStateIdQ.values().stream().flatMap(Collection::stream).toList());
|
||||
streamToStateIdQ.clear();
|
||||
retroactiveGlobalStateId = StateIdProvider.getNextId();
|
||||
|
||||
streamToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>());
|
||||
streamToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId);
|
||||
|
||||
final long combinedCounter = stateIdToCounter.values()
|
||||
.stream()
|
||||
.mapToLong(AtomicLong::get)
|
||||
.sum();
|
||||
stateIdToCounter.clear();
|
||||
stateIdToCounter.put(retroactiveGlobalStateId, new AtomicLong(combinedCounter));
|
||||
}
|
||||
}
|
||||
|
||||
private AirbyteStateMessage.AirbyteStateType extractStateType(final PartialAirbyteMessage message) {
|
||||
if (message.getState().getType() == null) {
|
||||
// Treated the same as GLOBAL.
|
||||
return AirbyteStateMessage.AirbyteStateType.LEGACY;
|
||||
} else {
|
||||
return message.getState().getType();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When a state message is received, 'close' the previous state to associate the existing state id
|
||||
* to the newly arrived state message. We also increment the state id in preparation for the next
|
||||
* state message.
|
||||
*/
|
||||
private void closeState(final PartialAirbyteMessage message, final long sizeInBytes) {
|
||||
final StreamDescriptor resolvedDescriptor = extractStream(message).orElse(SENTINEL_GLOBAL_DESC);
|
||||
stateIdToState.put(getStateId(resolvedDescriptor), ImmutablePair.of(message, sizeInBytes));
|
||||
registerNewStateId(resolvedDescriptor);
|
||||
|
||||
allocateMemoryToState(sizeInBytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given the size of a state message, tracks how much memory the manager is using and requests
|
||||
* additional memory from the memory manager if needed.
|
||||
*
|
||||
* @param sizeInBytes size of the state message
|
||||
*/
|
||||
@SuppressWarnings("BusyWait")
|
||||
private void allocateMemoryToState(final long sizeInBytes) {
|
||||
if (memoryAllocated.get() < memoryUsed.get() + sizeInBytes) {
|
||||
while (memoryAllocated.get() < memoryUsed.get() + sizeInBytes) {
|
||||
memoryAllocated.addAndGet(memoryManager.requestMemory());
|
||||
try {
|
||||
LOGGER.debug("Insufficient memory to store state message. Allocated: {}, Used: {}, Size of State Msg: {}, Needed: {}",
|
||||
FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
|
||||
FileUtils.byteCountToDisplaySize(memoryUsed.get()),
|
||||
FileUtils.byteCountToDisplaySize(sizeInBytes),
|
||||
FileUtils.byteCountToDisplaySize(sizeInBytes - (memoryAllocated.get() - memoryUsed.get())));
|
||||
sleep(1000);
|
||||
} catch (final InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
memoryUsed.addAndGet(sizeInBytes);
|
||||
LOGGER.debug("State Manager memory usage: Allocated: {}, Used: {}, % Used {}",
|
||||
FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
|
||||
FileUtils.byteCountToDisplaySize(memoryUsed.get()),
|
||||
(double) memoryUsed.get() / memoryAllocated.get());
|
||||
}
|
||||
|
||||
private static Optional<StreamDescriptor> extractStream(final PartialAirbyteMessage message) {
|
||||
return Optional.ofNullable(message.getState().getStream()).map(PartialAirbyteStreamState::getStreamDescriptor);
|
||||
}
|
||||
|
||||
private long getStateAfterAlias(final long stateId) {
|
||||
if (aliasIds.contains(stateId)) {
|
||||
return retroactiveGlobalStateId;
|
||||
} else {
|
||||
return stateId;
|
||||
}
|
||||
}
|
||||
|
||||
private void registerNewStreamDescriptor(final StreamDescriptor resolvedDescriptor) {
|
||||
streamToStateIdQ.put(resolvedDescriptor, new LinkedList<>());
|
||||
registerNewStateId(resolvedDescriptor);
|
||||
}
|
||||
|
||||
private void registerNewStateId(final StreamDescriptor resolvedDescriptor) {
|
||||
final long stateId = StateIdProvider.getNextId();
|
||||
streamToStateIdQ.get(resolvedDescriptor).add(stateId);
|
||||
stateIdToCounter.put(stateId, new AtomicLong(0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Simplify internal tracking by providing a global always increasing counter for state ids.
|
||||
*/
|
||||
private static class StateIdProvider {
|
||||
|
||||
private static long pk = 0;
|
||||
|
||||
public static long getNextId() {
|
||||
return pk++;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.atLeast;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
|
||||
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordSizeEstimator;
|
||||
import io.airbyte.integrations.destination_async.buffers.BufferManager;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
|
||||
import io.airbyte.integrations.destination_async.state.FlushFailure;
|
||||
import io.airbyte.protocol.models.Field;
|
||||
import io.airbyte.protocol.models.JsonSchemaType;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
|
||||
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStreamState;
|
||||
import io.airbyte.protocol.models.v0.CatalogHelpers;
|
||||
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Stream;
|
||||
import org.apache.commons.lang.RandomStringUtils;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
class AsyncStreamConsumerTest {
|
||||
|
||||
private static final int RECORD_SIZE_20_BYTES = 20;
|
||||
|
||||
private static final String SCHEMA_NAME = "public";
|
||||
private static final String STREAM_NAME = "id_and_name";
|
||||
private static final String STREAM_NAME2 = STREAM_NAME + 2;
|
||||
private static final StreamDescriptor STREAM1_DESC = new StreamDescriptor()
|
||||
.withNamespace(SCHEMA_NAME)
|
||||
.withName(STREAM_NAME);
|
||||
|
||||
private static final ConfiguredAirbyteCatalog CATALOG = new ConfiguredAirbyteCatalog().withStreams(List.of(
|
||||
CatalogHelpers.createConfiguredAirbyteStream(
|
||||
STREAM_NAME,
|
||||
SCHEMA_NAME,
|
||||
Field.of("id", JsonSchemaType.NUMBER),
|
||||
Field.of("name", JsonSchemaType.STRING)),
|
||||
CatalogHelpers.createConfiguredAirbyteStream(
|
||||
STREAM_NAME2,
|
||||
SCHEMA_NAME,
|
||||
Field.of("id", JsonSchemaType.NUMBER),
|
||||
Field.of("name", JsonSchemaType.STRING))));
|
||||
|
||||
private static final AirbyteMessage STATE_MESSAGE1 = new AirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new AirbyteStateMessage()
|
||||
.withType(AirbyteStateType.STREAM)
|
||||
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(1))));
|
||||
private static final AirbyteMessage STATE_MESSAGE2 = new AirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new AirbyteStateMessage()
|
||||
.withType(AirbyteStateType.STREAM)
|
||||
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(2))));
|
||||
|
||||
private AsyncStreamConsumer consumer;
|
||||
private OnStartFunction onStart;
|
||||
private DestinationFlushFunction flushFunction;
|
||||
private OnCloseFunction onClose;
|
||||
private Consumer<AirbyteMessage> outputRecordCollector;
|
||||
private FlushFailure flushFailure;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
onStart = mock(OnStartFunction.class);
|
||||
onClose = mock(OnCloseFunction.class);
|
||||
flushFunction = mock(DestinationFlushFunction.class);
|
||||
outputRecordCollector = mock(Consumer.class);
|
||||
flushFailure = mock(FlushFailure.class);
|
||||
consumer = new AsyncStreamConsumer(
|
||||
outputRecordCollector,
|
||||
onStart,
|
||||
onClose,
|
||||
flushFunction,
|
||||
CATALOG,
|
||||
new BufferManager(),
|
||||
flushFailure);
|
||||
|
||||
when(flushFunction.getOptimalBatchSizeBytes()).thenReturn(10_000L);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test1StreamWith1State() throws Exception {
|
||||
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
|
||||
|
||||
consumer.start();
|
||||
consumeRecords(consumer, expectedRecords);
|
||||
consumer.accept(Jsons.serialize(STATE_MESSAGE1), RECORD_SIZE_20_BYTES);
|
||||
consumer.close();
|
||||
|
||||
verifyStartAndClose();
|
||||
|
||||
verifyRecords(STREAM_NAME, SCHEMA_NAME, expectedRecords);
|
||||
|
||||
verify(outputRecordCollector).accept(STATE_MESSAGE1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test1StreamWith2State() throws Exception {
|
||||
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
|
||||
|
||||
consumer.start();
|
||||
consumeRecords(consumer, expectedRecords);
|
||||
consumer.accept(Jsons.serialize(STATE_MESSAGE1), RECORD_SIZE_20_BYTES);
|
||||
consumer.accept(Jsons.serialize(STATE_MESSAGE2), RECORD_SIZE_20_BYTES);
|
||||
consumer.close();
|
||||
|
||||
verifyStartAndClose();
|
||||
|
||||
verifyRecords(STREAM_NAME, SCHEMA_NAME, expectedRecords);
|
||||
|
||||
verify(outputRecordCollector, times(1)).accept(STATE_MESSAGE2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test1StreamWith0State() throws Exception {
|
||||
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
|
||||
|
||||
consumer.start();
|
||||
consumeRecords(consumer, expectedRecords);
|
||||
consumer.close();
|
||||
|
||||
verifyStartAndClose();
|
||||
|
||||
verifyRecords(STREAM_NAME, SCHEMA_NAME, expectedRecords);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldBlockWhenQueuesAreFull() throws Exception {
|
||||
consumer.start();
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
* Tests that the consumer will block when the buffer is full. Achieves this by setting optimal
|
||||
* batch size to 0, so the flush worker never actually pulls anything from the queue.
|
||||
*/
|
||||
@Test
|
||||
void testBackPressure() throws Exception {
|
||||
flushFunction = mock(DestinationFlushFunction.class);
|
||||
flushFailure = mock(FlushFailure.class);
|
||||
consumer = new AsyncStreamConsumer(
|
||||
m -> {},
|
||||
() -> {},
|
||||
() -> {},
|
||||
flushFunction,
|
||||
CATALOG,
|
||||
new BufferManager(1024 * 10),
|
||||
flushFailure);
|
||||
when(flushFunction.getOptimalBatchSizeBytes()).thenReturn(0L);
|
||||
|
||||
final AtomicLong recordCount = new AtomicLong();
|
||||
|
||||
consumer.start();
|
||||
|
||||
final ExecutorService executor = Executors.newSingleThreadExecutor();
|
||||
while (true) {
|
||||
final Future<?> future = executor.submit(() -> {
|
||||
try {
|
||||
consumer.accept(Jsons.serialize(new AirbyteMessage()
|
||||
.withType(Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME)
|
||||
.withNamespace(SCHEMA_NAME)
|
||||
.withEmittedAt(Instant.now().toEpochMilli())
|
||||
.withData(Jsons.jsonNode(recordCount.getAndIncrement())))),
|
||||
RECORD_SIZE_20_BYTES);
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
future.get(1, TimeUnit.SECONDS);
|
||||
} catch (final TimeoutException e) {
|
||||
future.cancel(true); // Stop the operation running in thread
|
||||
break;
|
||||
}
|
||||
}
|
||||
executor.shutdownNow();
|
||||
|
||||
assertTrue(recordCount.get() < 1000, String.format("Record count was %s", recordCount.get()));
|
||||
}
|
||||
|
||||
@Nested
|
||||
class ErrorHandling {
|
||||
|
||||
@Test
|
||||
void testErrorOnAccept() throws Exception {
|
||||
when(flushFailure.isFailed()).thenReturn(false).thenReturn(true);
|
||||
when(flushFailure.getException()).thenReturn(new IOException("test exception"));
|
||||
|
||||
final var m = new AirbyteMessage()
|
||||
.withType(Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME)
|
||||
.withNamespace(SCHEMA_NAME)
|
||||
.withEmittedAt(Instant.now().toEpochMilli())
|
||||
.withData(Jsons.deserialize("")));
|
||||
consumer.start();
|
||||
consumer.accept(Jsons.serialize(m), RECORD_SIZE_20_BYTES);
|
||||
assertThrows(IOException.class, () -> consumer.accept(Jsons.serialize(m), RECORD_SIZE_20_BYTES));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testErrorOnClose() throws Exception {
|
||||
when(flushFailure.isFailed()).thenReturn(true);
|
||||
when(flushFailure.getException()).thenReturn(new IOException("test exception"));
|
||||
|
||||
consumer.start();
|
||||
assertThrows(IOException.class, () -> consumer.close());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<PartialAirbyteMessage> records) {
|
||||
records.forEach(m -> {
|
||||
try {
|
||||
consumer.accept(m.getSerialized(), RECORD_SIZE_20_BYTES);
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NOTE: Generates records at chunks of 160 bytes
|
||||
@SuppressWarnings("SameParameterValue")
|
||||
private static List<PartialAirbyteMessage> generateRecords(final long targetSizeInBytes) {
|
||||
final List<PartialAirbyteMessage> output = Lists.newArrayList();
|
||||
long bytesCounter = 0;
|
||||
for (int i = 0;; i++) {
|
||||
final JsonNode payload =
|
||||
Jsons.jsonNode(ImmutableMap.of("id", RandomStringUtils.randomAlphabetic(7), "name", "human " + String.format("%8d", i)));
|
||||
final long sizeInBytes = RecordSizeEstimator.getStringByteSize(payload);
|
||||
bytesCounter += sizeInBytes;
|
||||
final PartialAirbyteMessage airbyteMessage = new PartialAirbyteMessage()
|
||||
.withType(Type.RECORD)
|
||||
.withRecord(new PartialAirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME)
|
||||
.withNamespace(SCHEMA_NAME))
|
||||
.withSerialized(Jsons.serialize(new AirbyteMessage()
|
||||
.withType(Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME)
|
||||
.withNamespace(SCHEMA_NAME)
|
||||
.withData(payload))));
|
||||
if (bytesCounter > targetSizeInBytes) {
|
||||
break;
|
||||
} else {
|
||||
output.add(airbyteMessage);
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
private void verifyStartAndClose() throws Exception {
|
||||
verify(onStart).call();
|
||||
verify(onClose).call();
|
||||
}
|
||||
|
||||
@SuppressWarnings({"unchecked", "SameParameterValue"})
|
||||
private void verifyRecords(final String streamName, final String namespace, final Collection<PartialAirbyteMessage> expectedRecords)
|
||||
throws Exception {
|
||||
final ArgumentCaptor<Stream<PartialAirbyteMessage>> argumentCaptor = ArgumentCaptor.forClass(Stream.class);
|
||||
verify(flushFunction, atLeast(1)).flush(
|
||||
eq(new StreamDescriptor().withNamespace(namespace).withName(streamName)),
|
||||
argumentCaptor.capture());
|
||||
|
||||
// captures the output of all the workers, since our records could come out in any of them.
|
||||
final List<PartialAirbyteMessage> actualRecords = argumentCaptor
|
||||
.getAllValues()
|
||||
.stream()
|
||||
// flatten those results into a single list for the simplicity of comparison
|
||||
.flatMap(s -> s)
|
||||
.toList();
|
||||
assertEquals(expectedRecords.stream().toList(), actualRecords);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.airbyte.integrations.destination_async.buffers.BufferDequeue;
|
||||
import io.airbyte.integrations.destination_async.buffers.MemoryAwareMessageBatch;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.state.FlushFailure;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class FlushWorkersTest {
|
||||
|
||||
@Test
|
||||
void testErrorHandling() throws Exception {
|
||||
final AtomicBoolean hasThrownError = new AtomicBoolean(false);
|
||||
final var desc = new StreamDescriptor().withName("test");
|
||||
final var dequeue = mock(BufferDequeue.class);
|
||||
when(dequeue.getBufferedStreams()).thenReturn(Set.of(desc));
|
||||
when(dequeue.take(desc, 1000)).thenReturn(new MemoryAwareMessageBatch(List.of(), 10, null, null));
|
||||
when(dequeue.getQueueSizeBytes(desc)).thenReturn(Optional.of(10L));
|
||||
when(dequeue.getQueueSizeInRecords(desc)).thenAnswer(ignored -> {
|
||||
if (hasThrownError.get()) {
|
||||
return Optional.of(0L);
|
||||
} else {
|
||||
return Optional.of(1L);
|
||||
}
|
||||
});
|
||||
|
||||
final var flushFailure = new FlushFailure();
|
||||
final var workers = new FlushWorkers(dequeue, new ErrorOnFlush(hasThrownError), m -> {}, flushFailure, mock(GlobalAsyncStateManager.class));
|
||||
workers.start();
|
||||
workers.close();
|
||||
|
||||
Assertions.assertTrue(flushFailure.isFailed());
|
||||
Assertions.assertEquals(IOException.class, flushFailure.getException().getClass());
|
||||
}
|
||||
|
||||
private static class ErrorOnFlush implements DestinationFlushFunction {
|
||||
|
||||
private final AtomicBoolean hasThrownError;
|
||||
|
||||
public ErrorOnFlush(final AtomicBoolean hasThrownError) {
|
||||
this.hasThrownError = hasThrownError;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(final StreamDescriptor desc, final Stream<PartialAirbyteMessage> stream) throws Exception {
|
||||
hasThrownError.set(true);
|
||||
throw new IOException("Error on flush");
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getOptimalBatchSizeBytes() {
|
||||
return 1000;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -8,9 +8,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
|
||||
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
@@ -19,14 +19,14 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
public class BufferDequeueTest {
|
||||
|
||||
private static final int RECORD_SIZE_20_BYTES = 20;
|
||||
public static final String RECORD_20_BYTES = "abc";
|
||||
private static final String STREAM_NAME = "stream1";
|
||||
private static final StreamDescriptor STREAM_DESC = new StreamDescriptor().withName(STREAM_NAME);
|
||||
private static final AirbyteMessage RECORD_MSG_20_BYTES = new AirbyteMessage()
|
||||
private static final PartialAirbyteMessage RECORD_MSG_20_BYTES = new PartialAirbyteMessage()
|
||||
.withType(Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME)
|
||||
.withData(Jsons.jsonNode(RECORD_20_BYTES)));
|
||||
.withRecord(new PartialAirbyteRecordMessage()
|
||||
.withStream(STREAM_NAME));
|
||||
|
||||
@Nested
|
||||
class Take {
|
||||
@@ -37,15 +37,17 @@ public class BufferDequeueTest {
|
||||
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
|
||||
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
|
||||
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
|
||||
// total size of records is 80, so we expect 50 to get us 2 records (prefer to under-pull records
|
||||
// than over-pull).
|
||||
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, 50)) {
|
||||
assertEquals(2, take.getData().toList().size());
|
||||
assertEquals(2, take.getData().size());
|
||||
// verify it only took the records from the queue that it actually returned.
|
||||
assertEquals(2, dequeue.getQueueSizeInRecords(STREAM_DESC).orElseThrow());
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -57,12 +59,12 @@ public class BufferDequeueTest {
|
||||
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
|
||||
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
|
||||
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
|
||||
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, 60)) {
|
||||
assertEquals(3, take.getData().toList().size());
|
||||
assertEquals(3, take.getData().size());
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -74,11 +76,11 @@ public class BufferDequeueTest {
|
||||
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
|
||||
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
|
||||
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
|
||||
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, Long.MAX_VALUE)) {
|
||||
assertEquals(2, take.getData().toList().size());
|
||||
assertEquals(2, take.getData().size());
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -92,11 +94,13 @@ public class BufferDequeueTest {
|
||||
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
|
||||
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
|
||||
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
|
||||
|
||||
final var secondStream = new StreamDescriptor().withName("stream_2");
|
||||
enqueue.addRecord(secondStream, RECORD_MSG_20_BYTES);
|
||||
final PartialAirbyteMessage recordFromSecondStream = Jsons.clone(RECORD_MSG_20_BYTES);
|
||||
recordFromSecondStream.getRecord().withStream(secondStream.getName());
|
||||
enqueue.addRecord(recordFromSecondStream, RECORD_SIZE_20_BYTES);
|
||||
|
||||
assertEquals(60, dequeue.getTotalGlobalQueueSizeBytes());
|
||||
|
||||
|
||||
@@ -5,32 +5,35 @@
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
|
||||
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class BufferEnqueueTest {
|
||||
|
||||
private static final int RECORD_SIZE_20_BYTES = 20;
|
||||
|
||||
@Test
|
||||
void testAddRecordShouldAdd() {
|
||||
final var twoMB = 2 * 1024 * 1024;
|
||||
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>>();
|
||||
final var enqueue = new BufferEnqueue(new GlobalMemoryManager(twoMB), streamToBuffer);
|
||||
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, StreamAwareQueue>();
|
||||
final var enqueue = new BufferEnqueue(new GlobalMemoryManager(twoMB), streamToBuffer, mock(GlobalAsyncStateManager.class));
|
||||
|
||||
final var streamName = "stream";
|
||||
final var stream = new StreamDescriptor().withName(streamName);
|
||||
final var record = new AirbyteMessage()
|
||||
final var record = new PartialAirbyteMessage()
|
||||
.withType(AirbyteMessage.Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(streamName)
|
||||
.withData(Jsons.jsonNode(BufferDequeueTest.RECORD_20_BYTES)));
|
||||
.withRecord(new PartialAirbyteRecordMessage()
|
||||
.withStream(streamName));
|
||||
|
||||
enqueue.addRecord(stream, record);
|
||||
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
|
||||
assertEquals(1, streamToBuffer.get(stream).size());
|
||||
assertEquals(20L, streamToBuffer.get(stream).getCurrentMemoryUsage());
|
||||
|
||||
@@ -39,20 +42,19 @@ public class BufferEnqueueTest {
|
||||
@Test
|
||||
public void testAddRecordShouldExpand() {
|
||||
final var oneKb = 1024;
|
||||
final var initialQueueSizeBytes = 20;
|
||||
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>>();
|
||||
final var enqueue = new BufferEnqueue(initialQueueSizeBytes, new GlobalMemoryManager(oneKb), streamToBuffer);
|
||||
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, StreamAwareQueue>();
|
||||
final var enqueue =
|
||||
new BufferEnqueue(new GlobalMemoryManager(oneKb), streamToBuffer, mock(GlobalAsyncStateManager.class));
|
||||
|
||||
final var streamName = "stream";
|
||||
final var stream = new StreamDescriptor().withName(streamName);
|
||||
final var record = new AirbyteMessage()
|
||||
final var record = new PartialAirbyteMessage()
|
||||
.withType(AirbyteMessage.Type.RECORD)
|
||||
.withRecord(new AirbyteRecordMessage()
|
||||
.withStream(streamName)
|
||||
.withData(Jsons.jsonNode(BufferDequeueTest.RECORD_20_BYTES)));
|
||||
.withRecord(new PartialAirbyteRecordMessage()
|
||||
.withStream(streamName));
|
||||
|
||||
enqueue.addRecord(stream, record);
|
||||
enqueue.addRecord(stream, record);
|
||||
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
|
||||
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
|
||||
assertEquals(2, streamToBuffer.get(stream).size());
|
||||
assertEquals(40, streamToBuffer.get(stream).getCurrentMemoryUsage());
|
||||
|
||||
|
||||
@@ -26,28 +26,6 @@ public class MemoryBoundedLinkedBlockingQueueTest {
|
||||
assertEquals("abc", item.item());
|
||||
}
|
||||
|
||||
@Test
|
||||
void test() throws InterruptedException {
|
||||
final MemoryBoundedLinkedBlockingQueue<String> queue = new MemoryBoundedLinkedBlockingQueue<>(1024);
|
||||
|
||||
assertEquals(0, queue.getCurrentMemoryUsage());
|
||||
assertNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
|
||||
queue.offer("abc", 6);
|
||||
queue.offer("abc", 6);
|
||||
queue.offer("abc", 6);
|
||||
|
||||
assertEquals(18, queue.getCurrentMemoryUsage());
|
||||
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
|
||||
queue.take();
|
||||
queue.take();
|
||||
queue.take();
|
||||
|
||||
assertEquals(0, queue.getCurrentMemoryUsage());
|
||||
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBlocksOnFullMemory() throws InterruptedException {
|
||||
final MemoryBoundedLinkedBlockingQueue<String> queue = new MemoryBoundedLinkedBlockingQueue<>(10);
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.buffers;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/**
 * Checks the memory-usage and last-message-time bookkeeping that {@code StreamAwareQueue} reports
 * while messages are offered and then taken.
 */
public class StreamAwareQueueTest {
|
||||
|
||||
/**
 * Offers three messages, drains them, and asserts the reported memory usage and last-message time
 * at each stage.
 */
@Test
|
||||
void test() throws InterruptedException {
|
||||
final StreamAwareQueue queue = new StreamAwareQueue(1024);
|
||||
|
||||
// A freshly constructed queue is expected to report no memory in use and no last-message time.
assertEquals(0, queue.getCurrentMemoryUsage());
|
||||
assertNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
|
||||
// Each offer passes 6 as the second argument; the third argument differs per message (1, 2, 3).
// NOTE(review): presumably (message, size in bytes, state id) - confirm against StreamAwareQueue.offer.
queue.offer(new PartialAirbyteMessage(), 6, 1);
|
||||
queue.offer(new PartialAirbyteMessage(), 6, 2);
|
||||
queue.offer(new PartialAirbyteMessage(), 6, 3);
|
||||
|
||||
// 3 offers x 6 = 18, and a last-message time is now expected to be present.
assertEquals(18, queue.getCurrentMemoryUsage());
|
||||
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
|
||||
queue.take();
|
||||
queue.take();
|
||||
queue.take();
|
||||
|
||||
// After draining, usage is expected to return to 0 while the last-message time stays set.
assertEquals(0, queue.getCurrentMemoryUsage());
|
||||
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
/*
|
||||
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
||||
*/
|
||||
|
||||
package io.airbyte.integrations.destination_async.state;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStateMessage;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStreamState;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
|
||||
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class GlobalAsyncStateManagerTest {
|
||||
|
||||
// Limit handed to every GlobalMemoryManager built in these tests.
// 100 * 1024 * 1024 = 104,857,600, which still fits in an int before it is widened to long.
private static final long TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES = 100 * 1024 * 1024; // 100MB
|
||||
|
||||
// Size argument passed to trackState alongside every state message in these tests.
private static final long STATE_MSG_SIZE = 1000;
|
||||
|
||||
// Stream names; the second and third are the first with "2" / "3" appended.
private static final String STREAM_NAME = "id_and_name";
|
||||
private static final String STREAM_NAME2 = STREAM_NAME + 2;
|
||||
private static final String STREAM_NAME3 = STREAM_NAME + 3;
|
||||
// Descriptors carry only a name; no namespace is set.
private static final StreamDescriptor STREAM1_DESC = new StreamDescriptor()
|
||||
.withName(STREAM_NAME);
|
||||
private static final StreamDescriptor STREAM2_DESC = new StreamDescriptor()
|
||||
.withName(STREAM_NAME2);
|
||||
private static final StreamDescriptor STREAM3_DESC = new StreamDescriptor()
|
||||
.withName(STREAM_NAME3);
|
||||
|
||||
// GLOBAL-typed state fixtures.
// NOTE(review): GLOBAL_STATE_MESSAGE1 and GLOBAL_STATE_MESSAGE2 are built identically; if
// PartialAirbyteMessage implements value equality, assertEquals cannot tell them apart - confirm.
private static final PartialAirbyteMessage GLOBAL_STATE_MESSAGE1 = new PartialAirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new PartialAirbyteStateMessage()
|
||||
.withType(AirbyteStateType.GLOBAL));
|
||||
private static final PartialAirbyteMessage GLOBAL_STATE_MESSAGE2 = new PartialAirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new PartialAirbyteStateMessage()
|
||||
.withType(AirbyteStateType.GLOBAL));
|
||||
// STREAM-typed state fixtures, both for STREAM1_DESC and likewise built identically to each other.
private static final PartialAirbyteMessage STREAM1_STATE_MESSAGE1 = new PartialAirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new PartialAirbyteStateMessage()
|
||||
.withType(AirbyteStateType.STREAM)
|
||||
.withStream(new PartialAirbyteStreamState().withStreamDescriptor(STREAM1_DESC)));
|
||||
private static final PartialAirbyteMessage STREAM1_STATE_MESSAGE2 = new PartialAirbyteMessage()
|
||||
.withType(Type.STATE)
|
||||
.withState(new PartialAirbyteStateMessage()
|
||||
.withType(AirbyteStateType.STREAM)
|
||||
.withStream(new PartialAirbyteStreamState().withStreamDescriptor(STREAM1_DESC)));
|
||||
|
||||
/**
 * Smoke test: two records on one stream share a state id, nothing is flushed before a state
 * message is tracked, and a tracked state message is returned by the next flush.
 */
@Test
|
||||
void testBasic() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// With no state message in between, both calls are expected to hand back the same state id.
final var firstStateId = stateManager.getStateIdAndIncrementCounter(STREAM1_DESC);
|
||||
final var secondStateId = stateManager.getStateIdAndIncrementCounter(STREAM1_DESC);
|
||||
assertEquals(firstStateId, secondStateId);
|
||||
|
||||
// Decrement by 2 to match the two increments above.
stateManager.decrement(firstStateId, 2);
|
||||
// because no state message has been tracked, there is nothing to flush yet.
|
||||
var flushed = stateManager.flushStates();
|
||||
assertEquals(0, flushed.size());
|
||||
|
||||
// Once a state message is tracked, the next flush is expected to return exactly that message.
stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
flushed = stateManager.flushStates();
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE1), flushed);
|
||||
}
|
||||
|
||||
@Nested
|
||||
class GlobalState {
|
||||
|
||||
/**
 * With no records counted, a tracked GLOBAL state is expected to flush immediately; a STREAM
 * state tracked afterwards on the same manager is expected to be rejected.
 */
@Test
|
||||
void testEmptyQueuesGlobalState() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// GLOBAL
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Mixing state types: a STREAM state after a GLOBAL one must throw IllegalArgumentException.
assertThrows(IllegalArgumentException.class, () -> stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE));
|
||||
}
|
||||
|
||||
/**
 * Records arrive on three streams before the first state message, which turns out to be GLOBAL.
 * The test expects the state to be flushed only after the records of all three streams have been
 * decremented.
 */
@Test
|
||||
void testConversion() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// 10 records per stream, counted before any state message has been tracked.
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
final var preConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
|
||||
final var preConvertId2 = simulateIncomingRecords(STREAM3_DESC, 10, stateManager);
|
||||
// The three ids must be pairwise distinct. Set.of itself throws on duplicate elements, so a
// collision surfaces as an IllegalArgumentException rather than a failed size assertion.
assertEquals(3, Set.of(preConvertId0, preConvertId1, preConvertId2).size());
|
||||
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
|
||||
// Since this is actually a global state, we can only flush after all streams are done.
|
||||
stateManager.decrement(preConvertId0, 10);
|
||||
assertEquals(List.of(), stateManager.flushStates());
|
||||
stateManager.decrement(preConvertId1, 10);
|
||||
assertEquals(List.of(), stateManager.flushStates());
|
||||
stateManager.decrement(preConvertId2, 10);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
|
||||
}
|
||||
|
||||
/**
 * One stream with GLOBAL state, two rounds: in each round 10 records are counted, a GLOBAL state
 * is tracked, the 10 records are decremented, and that round's state is expected from the flush.
 */
@Test
|
||||
void testCorrectFlushingOneStream() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// Round 1: records counted before the first state message has been seen.
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
stateManager.decrement(preConvertId0, 10);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Round 2: records counted after a GLOBAL state has already been tracked.
final var afterConvertId1 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE2, STATE_MSG_SIZE);
|
||||
stateManager.decrement(afterConvertId1, 10);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE2), stateManager.flushStates());
|
||||
}
|
||||
|
||||
/**
 * Two streams with GLOBAL state. Before the first state message the streams are expected to get
 * distinct state ids; after it they are expected to share a single id.
 */
@Test
|
||||
void testCorrectFlushingManyStreams() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// Round 1: no state message seen yet, so each stream is expected to have its own id.
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
final var preConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
|
||||
assertNotEquals(preConvertId0, preConvertId1);
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
stateManager.decrement(preConvertId0, 10);
|
||||
stateManager.decrement(preConvertId1, 10);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Round 2: a GLOBAL state has been tracked, so both streams are expected to share one id.
final var afterConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
final var afterConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
|
||||
assertEquals(afterConvertId0, afterConvertId1);
|
||||
stateManager.trackState(GLOBAL_STATE_MESSAGE2, STATE_MSG_SIZE);
|
||||
// A single decrement of 20 covers the 10 + 10 records counted against the shared id.
stateManager.decrement(afterConvertId0, 20);
|
||||
assertEquals(List.of(GLOBAL_STATE_MESSAGE2), stateManager.flushStates());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Nested
|
||||
class PerStreamState {
|
||||
|
||||
/**
 * With no records counted, a tracked STREAM state is expected to flush immediately; a GLOBAL
 * state tracked afterwards on the same manager is expected to be rejected.
 */
@Test
|
||||
void testEmptyQueues() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// STREAM
|
||||
stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Mixing state types: a GLOBAL state after a STREAM one must throw IllegalArgumentException.
assertThrows(IllegalArgumentException.class, () -> stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE));
|
||||
}
|
||||
|
||||
/**
 * One stream with STREAM state, two rounds: records are counted, a state is tracked, the same
 * number of records is decremented, and that round's state is expected from the flush.
 */
@Test
|
||||
void testCorrectFlushingOneStream() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
// Round 1: 3 records, then the first state message.
var stateId = simulateIncomingRecords(STREAM1_DESC, 3, stateManager);
|
||||
stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
stateManager.decrement(stateId, 3);
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Round 2: 10 records, then the second state message; stateId is reassigned for this round.
stateId = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
|
||||
stateManager.trackState(STREAM1_STATE_MESSAGE2, STATE_MSG_SIZE);
|
||||
stateManager.decrement(stateId, 10);
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE2), stateManager.flushStates());
|
||||
|
||||
}
|
||||
|
||||
/**
 * Two streams with STREAM state: 3 records are counted on stream 1 and 7 on stream 2, and the
 * test asserts what each flush returns as those counts are decremented.
 */
@Test
|
||||
void testCorrectFlushingManyStream() {
|
||||
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
|
||||
|
||||
final var stream1StateId = simulateIncomingRecords(STREAM1_DESC, 3, stateManager);
|
||||
final var stream2StateId = simulateIncomingRecords(STREAM2_DESC, 7, stateManager);
|
||||
|
||||
// Stream 1: its state is tracked and all 3 of its records are decremented, so it is expected to flush.
stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
|
||||
stateManager.decrement(stream1StateId, 3);
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE1), stateManager.flushStates());
|
||||
|
||||
// Stream 2: 4 of its 7 records are decremented and no new state has been tracked; nothing to flush.
stateManager.decrement(stream2StateId, 4);
|
||||
assertEquals(List.of(), stateManager.flushStates());
|
||||
// NOTE(review): the message tracked here is a STREAM1 state, yet the decrements around it are
// for stream 2's id (the class defines no stream-2 state fixture) - confirm this is intended.
stateManager.trackState(STREAM1_STATE_MESSAGE2, STATE_MSG_SIZE);
|
||||
// Decrement the remaining 3 of stream 2's 7 records.
stateManager.decrement(stream2StateId, 3);
|
||||
// only flush state if counter is 0.
|
||||
assertEquals(List.of(STREAM1_STATE_MESSAGE2), stateManager.flushStates());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
 * Simulates {@code count} records arriving on a stream by calling
 * {@code getStateIdAndIncrementCounter} once per record.
 *
 * @param desc stream the simulated records belong to
 * @param count number of calls to make
 * @param manager state manager under test
 * @return the id returned by the last call, or 0 if {@code count} is not positive
 */
private static long simulateIncomingRecords(final StreamDescriptor desc, final long count, final GlobalAsyncStateManager manager) {
|
||||
var stateId = 0L;
|
||||
for (int i = 0; i < count; i++) {
|
||||
// Only the most recent id is kept; earlier return values are overwritten.
stateId = manager.getStateIdAndIncrementCounter(desc);
|
||||
}
|
||||
return stateId;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,13 +4,14 @@
|
||||
|
||||
package io.airbyte.integrations.destination.staging;
|
||||
|
||||
import io.airbyte.commons.json.Jsons;
|
||||
import io.airbyte.db.jdbc.JdbcDatabase;
|
||||
import io.airbyte.integrations.destination.jdbc.WriteConfig;
|
||||
import io.airbyte.integrations.destination.record_buffer.FileBuffer;
|
||||
import io.airbyte.integrations.destination.s3.csv.CsvSerializedBuffer;
|
||||
import io.airbyte.integrations.destination.s3.csv.StagingDatabaseCsvSheetGenerator;
|
||||
import io.airbyte.integrations.destination_async.DestinationFlushFunction;
|
||||
import io.airbyte.protocol.models.Jsons;
|
||||
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.AirbyteMessage;
|
||||
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
|
||||
import io.airbyte.protocol.models.v0.StreamDescriptor;
|
||||
@@ -41,27 +42,22 @@ class AsyncFlush implements DestinationFlushFunction {
|
||||
this.catalog = catalog;
|
||||
}
|
||||
|
||||
// todo(davin): exceptions are too broad.
|
||||
@Override
|
||||
public void flush(final StreamDescriptor decs, final Stream<AirbyteMessage> stream) throws Exception {
|
||||
// write this to a file - serilizable buffer?
|
||||
// where do we create all the write configs?
|
||||
log.info("Starting staging flush..");
|
||||
CsvSerializedBuffer writer = null;
|
||||
public void flush(final StreamDescriptor decs, final Stream<PartialAirbyteMessage> stream) throws Exception {
|
||||
final CsvSerializedBuffer writer;
|
||||
try {
|
||||
writer = new CsvSerializedBuffer(
|
||||
new FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX),
|
||||
new StagingDatabaseCsvSheetGenerator(),
|
||||
true);
|
||||
|
||||
log.info("Converting to CSV file..");
|
||||
|
||||
// reassign as lambdas require references to be final.
|
||||
final CsvSerializedBuffer finalWriter = writer;
|
||||
stream.forEach(record -> {
|
||||
try {
|
||||
// todo(davin): handle non-record airbyte messages.
|
||||
finalWriter.accept(record.getRecord());
|
||||
// todo (cgardens) - most writers just go ahead and re-serialize the contents of the record message.
|
||||
// we should either just pass the raw string or at least have a way to do that and create a default
|
||||
// impl that maintains backwards compatible behavior.
|
||||
writer.accept(Jsons.deserialize(record.getSerialized(), AirbyteMessage.class).getRecord());
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -70,9 +66,8 @@ class AsyncFlush implements DestinationFlushFunction {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
log.info("Converted to CSV file..");
|
||||
writer.flush();
|
||||
log.info("Flushing buffer for stream {} ({}) to staging", decs.getName(), FileUtils.byteCountToDisplaySize(writer.getByteCount()));
|
||||
log.info("Flushing CSV buffer for stream {} ({}) to staging", decs.getName(), FileUtils.byteCountToDisplaySize(writer.getByteCount()));
|
||||
if (!streamDescToWriteConfig.containsKey(decs)) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format("Message contained record from a stream that was not in the catalog. \ncatalog: %s", Jsons.serialize(catalog)));
|
||||
@@ -85,7 +80,6 @@ class AsyncFlush implements DestinationFlushFunction {
|
||||
stagingOperations.getStagingPath(StagingConsumerFactory.RANDOM_CONNECTION_ID, schemaName, writeConfig.getStreamName(),
|
||||
writeConfig.getWriteDatetime());
|
||||
try {
|
||||
log.info("Starting upload to stage..");
|
||||
final String stagedFile = stagingOperations.uploadRecordsToStage(database, writer, schemaName, stageName, stagingPath);
|
||||
GeneralStagingFunctions.copyIntoTableFromStage(database, stageName, stagingPath, List.of(stagedFile), writeConfig.getOutputTableName(),
|
||||
schemaName,
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.google.common.base.Preconditions;
|
||||
import io.airbyte.commons.exceptions.ConfigErrorException;
|
||||
import io.airbyte.db.jdbc.JdbcDatabase;
|
||||
import io.airbyte.integrations.base.AirbyteMessageConsumer;
|
||||
import io.airbyte.integrations.base.SerializedAirbyteMessageConsumer;
|
||||
import io.airbyte.integrations.destination.NamingConventionTransformer;
|
||||
import io.airbyte.integrations.destination.buffered_stream_consumer.BufferedStreamConsumer;
|
||||
import io.airbyte.integrations.destination.jdbc.WriteConfig;
|
||||
@@ -78,14 +79,14 @@ public class StagingConsumerFactory {
|
||||
stagingOperations::isValidData);
|
||||
}
|
||||
|
||||
public AirbyteMessageConsumer createAsync(final Consumer<AirbyteMessage> outputRecordCollector,
|
||||
final JdbcDatabase database,
|
||||
final StagingOperations stagingOperations,
|
||||
final NamingConventionTransformer namingResolver,
|
||||
final BufferCreateFunction onCreateBuffer,
|
||||
final JsonNode config,
|
||||
final ConfiguredAirbyteCatalog catalog,
|
||||
final boolean purgeStagingData) {
|
||||
public SerializedAirbyteMessageConsumer createAsync(final Consumer<AirbyteMessage> outputRecordCollector,
|
||||
final JdbcDatabase database,
|
||||
final StagingOperations stagingOperations,
|
||||
final NamingConventionTransformer namingResolver,
|
||||
final BufferCreateFunction onCreateBuffer,
|
||||
final JsonNode config,
|
||||
final ConfiguredAirbyteCatalog catalog,
|
||||
final boolean purgeStagingData) {
|
||||
final List<WriteConfig> writeConfigs = createWriteConfigs(namingResolver, config, catalog);
|
||||
final var streamDescToWriteConfig = streamDescToWriteConfig(writeConfigs);
|
||||
final var flusher = new AsyncFlush(streamDescToWriteConfig, stagingOperations, database, catalog);
|
||||
|
||||
Reference in New Issue
Block a user