1
0
mirror of synced 2025-12-25 02:09:19 -05:00

Async Snowflake Destination (#26703)

* snowflake at end of coding retreat week

* turn off async snowflake

* add comment

* Fixed test to point to the ShimMessageConsumer instead of AsyncMessageConsumer

* Automated Change

---------

Co-authored-by: ryankfu <ryan.fu@airbyte.io>
Co-authored-by: ryankfu <ryankfu@users.noreply.github.com>
This commit is contained in:
Charles
2023-05-31 16:01:36 -07:00
committed by GitHub
parent 88438bc6f5
commit 4ac62f3b4f
27 changed files with 1863 additions and 302 deletions

View File

@@ -4,15 +4,19 @@
package io.airbyte.integrations.destination_async;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.SerializedAirbyteMessageConsumer;
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
import io.airbyte.integrations.destination_async.buffers.BufferEnqueue;
import io.airbyte.integrations.destination_async.buffers.BufferManager;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.Optional;
@@ -31,12 +35,10 @@ import org.slf4j.LoggerFactory;
* {@link FlushWorkers}. See the other linked class for more detail.
*/
@Slf4j
public class AsyncStreamConsumer implements AirbyteMessageConsumer {
public class AsyncStreamConsumer implements SerializedAirbyteMessageConsumer {
private static final Logger LOGGER = LoggerFactory.getLogger(AsyncStreamConsumer.class);
private static final String NON_STREAM_STATE_IDENTIFIER = "GLOBAL";
private final Consumer<AirbyteMessage> outputRecordCollector;
private final OnStartFunction onStart;
private final OnCloseFunction onClose;
private final ConfiguredAirbyteCatalog catalog;
@@ -44,6 +46,7 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
private final BufferEnqueue bufferEnqueue;
private final FlushWorkers flushWorkers;
private final Set<StreamDescriptor> streamNames;
private final FlushFailure flushFailure;
private boolean hasStarted;
private boolean hasClosed;
@@ -54,16 +57,27 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
final DestinationFlushFunction flusher,
final ConfiguredAirbyteCatalog catalog,
final BufferManager bufferManager) {
this(outputRecordCollector, onStart, onClose, flusher, catalog, bufferManager, new FlushFailure());
}
@VisibleForTesting
public AsyncStreamConsumer(final Consumer<AirbyteMessage> outputRecordCollector,
final OnStartFunction onStart,
final OnCloseFunction onClose,
final DestinationFlushFunction flusher,
final ConfiguredAirbyteCatalog catalog,
final BufferManager bufferManager,
final FlushFailure flushFailure) {
hasStarted = false;
hasClosed = false;
this.outputRecordCollector = outputRecordCollector;
this.onStart = onStart;
this.onClose = onClose;
this.catalog = catalog;
this.bufferManager = bufferManager;
bufferEnqueue = bufferManager.getBufferEnqueue();
flushWorkers = new FlushWorkers(this.bufferManager.getBufferDequeue(), flusher);
this.flushFailure = flushFailure;
flushWorkers = new FlushWorkers(bufferManager.getBufferDequeue(), flusher, outputRecordCollector, flushFailure, bufferManager.getStateManager());
streamNames = StreamDescriptorUtils.fromConfiguredCatalog(catalog);
}
@@ -79,15 +93,77 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
}
@Override
public void accept(final AirbyteMessage message) throws Exception {
public void accept(final String messageString, final Integer sizeInBytes) throws Exception {
Preconditions.checkState(hasStarted, "Cannot accept records until consumer has started");
propagateFlushWorkerExceptionIfPresent();
/*
* intentionally putting extractStream outside the buffer manager so that if in the future we want
* to try to use a threadpool to partial deserialize to get record type and stream name, we can do
* it without touching buffer manager.
* to try to use a thread pool to partially deserialize to get record type and stream name, we can
* do it without touching buffer manager.
*/
extractStream(message)
.ifPresent(streamDescriptor -> bufferEnqueue.addRecord(streamDescriptor, message));
deserializeAirbyteMessage(messageString)
.ifPresent(message -> {
if (message.getType() == Type.RECORD) {
validateRecord(message);
}
bufferEnqueue.addRecord(message, sizeInBytes);
});
}
/**
 * Deserializes the raw string into a {@link PartialAirbyteMessage}, which can represent either a
 * record or a state message. On success the original string is attached to the result through
 * {@code withSerialized}.
 *
 * @param messageString the string to deserialize
 * @return the partial message if the string could be deserialized, empty otherwise
 * @throws IllegalStateException if the string cannot be deserialized but is identified as a state
 *         message by {@code isStateMessage}
 */
private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
  final Optional<PartialAirbyteMessage> messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
      .map(partial -> partial.withSerialized(messageString));
  if (messageOptional.isPresent()) {
    return messageOptional;
  }

  // An unparseable state message is never skipped: it aborts instead of being logged and dropped.
  if (isStateMessage(messageString)) {
    throw new IllegalStateException("Invalid state message: " + messageString);
  }

  // Any other unparseable message is logged and ignored. Parameterized logging avoids building
  // the message through string concatenation.
  LOGGER.error("Received invalid message: {}", messageString);
  return Optional.empty();
}
/**
 * Checks whether a serialized message declares itself to be a state message.
 * <p>
 * Only the {@code type} property is inspected, by deserializing into {@link AirbyteTypeMessage}.
 *
 * @param input a JSON string that represents an {@link AirbyteMessage}.
 * @return {@code true} when the string deserializes and its type is {@code STATE}, {@code false}
 *         otherwise.
 */
private static boolean isStateMessage(final String input) {
  return Jsons.tryDeserialize(input, AirbyteTypeMessage.class)
      .map(AirbyteTypeMessage::getType)
      .filter(messageType -> messageType == Type.STATE)
      .isPresent();
}
/**
* Custom class that can be used to parse a JSON message to determine the type of the represented
* {@link AirbyteMessage}.
* <p>
* It declares only the {@code type} property, so it exposes nothing else from the message.
*/
private static class AirbyteTypeMessage {
// Bound to the JSON property "type".
@JsonProperty("type")
@JsonPropertyDescription("Message type")
private AirbyteMessage.Type type;
/**
* @return the message type currently held in the {@code type} field
*/
@JsonProperty("type")
public AirbyteMessage.Type getType() {
return type;
}
/**
* @param type the message type to store
*/
@JsonProperty("type")
public void setType(final AirbyteMessage.Type type) {
this.type = type;
}
}
@Override
@@ -100,48 +176,32 @@ public class AsyncStreamConsumer implements AirbyteMessageConsumer {
// we need to close the workers before closing the bufferManagers (and underlying buffers)
// or we risk losing in-memory data.
flushWorkers.close();
bufferManager.close();
onClose.call();
LOGGER.info("{} closed.", AsyncStreamConsumer.class);
// as this throws an exception, we need to be after all other close functions.
propagateFlushWorkerExceptionIfPresent();
LOGGER.info("{} closed", AsyncStreamConsumer.class);
}
// todo (cgardens) - handle global state.
/**
* Extract the stream from the message, if the message is a record or state. Otherwise, we don't
* care.
*
* @param message message to extract stream from
* @return stream descriptor if the message is a record or state, otherwise empty. In the case of
* global state messages the stream descriptor is hardcoded
*/
private Optional<StreamDescriptor> extractStream(final AirbyteMessage message) {
if (message.getType() == Type.RECORD) {
final StreamDescriptor streamDescriptor = new StreamDescriptor()
.withNamespace(message.getRecord().getNamespace())
.withName(message.getRecord().getStream());
validateRecord(message, streamDescriptor);
return Optional.of(streamDescriptor);
} else if (message.getType() == Type.STATE) {
if (message.getState().getType() == AirbyteStateType.STREAM) {
return Optional.of(message.getState().getStream().getStreamDescriptor());
} else {
return Optional.of(new StreamDescriptor().withNamespace(NON_STREAM_STATE_IDENTIFIER).withNamespace(NON_STREAM_STATE_IDENTIFIER));
}
} else {
return Optional.empty();
private void propagateFlushWorkerExceptionIfPresent() throws Exception {
if (flushFailure.isFailed()) {
throw flushFailure.getException();
}
}
private void validateRecord(final AirbyteMessage message, final StreamDescriptor streamDescriptor) {
private void validateRecord(final PartialAirbyteMessage message) {
final StreamDescriptor streamDescriptor = new StreamDescriptor()
.withNamespace(message.getRecord().getNamespace())
.withName(message.getRecord().getStream());
// if stream is not part of list of streams to sync to then throw invalid stream exception
if (!streamNames.contains(streamDescriptor)) {
throwUnrecognizedStream(catalog, message);
}
}
private static void throwUnrecognizedStream(final ConfiguredAirbyteCatalog catalog, final AirbyteMessage message) {
private static void throwUnrecognizedStream(final ConfiguredAirbyteCatalog catalog, final PartialAirbyteMessage message) {
throw new IllegalArgumentException(
String.format("Message contained record from a stream that was not in the catalog. \ncatalog: %s , \nmessage: %s",
Jsons.serialize(catalog), Jsons.serialize(message)));

View File

@@ -4,6 +4,7 @@
package io.airbyte.integrations.destination_async;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.stream.Stream;
@@ -35,7 +36,7 @@ public interface DestinationFlushFunction {
* {@link #getOptimalBatchSizeBytes()} size
* @throws Exception
*/
void flush(StreamDescriptor decs, Stream<AirbyteMessage> stream) throws Exception;
void flush(StreamDescriptor decs, Stream<PartialAirbyteMessage> stream) throws Exception;
/**
* When invoking {@link #flush(StreamDescriptor, Stream)}, best effort attempt to invoke flush with

View File

@@ -4,21 +4,28 @@
package io.airbyte.integrations.destination_async;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.destination_async.buffers.BufferDequeue;
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.ConcurrentHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
/**
* Parallel flushing of Destination data.
@@ -27,7 +34,7 @@ import org.apache.commons.io.FileUtils;
* allows for parallel data flushing.
* <p>
* Parallelising is important as it 1) minimises Destination backpressure 2) minimises the effect of
* IO pauses on Destination performance. The second point is particularly important since majority
* IO pauses on Destination performance. The second point is particularly important since a majority
* of Destination work is IO bound.
* <p>
* The {@link #supervisorThread} assigns work to worker threads by looping over
@@ -41,91 +48,77 @@ import org.apache.commons.io.FileUtils;
@Slf4j
public class FlushWorkers implements AutoCloseable {
public static final long TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES = (long) (Runtime.getRuntime().maxMemory() * 0.8);
private static final long QUEUE_FLUSH_THRESHOLD_BYTES = 10 * 1024 * 1024; // 10MB
private static final long MAX_TIME_BETWEEN_REC_MINS = 5L;
private static final long SUPERVISOR_INITIAL_DELAY_SECS = 0L;
private static final long SUPERVISOR_PERIOD_SECS = 1L;
private static final long DEBUG_INITIAL_DELAY_SECS = 0L;
private static final long DEBUG_PERIOD_SECS = 10L;
private final ScheduledExecutorService supervisorThread = Executors.newScheduledThreadPool(1);
private final ExecutorService workerPool = Executors.newFixedThreadPool(5);
private final ScheduledExecutorService supervisorThread;
private final ExecutorService workerPool;
private final BufferDequeue bufferDequeue;
private final DestinationFlushFunction flusher;
private final ScheduledExecutorService debugLoop = Executors.newSingleThreadScheduledExecutor();
private final ConcurrentHashMap<StreamDescriptor, AtomicInteger> streamToInProgressWorkers = new ConcurrentHashMap<>();
private final Consumer<AirbyteMessage> outputRecordCollector;
private final ScheduledExecutorService debugLoop;
private final RunningFlushWorkers runningFlushWorkers;
private final DetectStreamToFlush detectStreamToFlush;
public FlushWorkers(final BufferDequeue bufferDequeue, final DestinationFlushFunction flushFunction) {
private final FlushFailure flushFailure;
private final AtomicBoolean isClosing;
private final GlobalAsyncStateManager stateManager;
public FlushWorkers(final BufferDequeue bufferDequeue,
final DestinationFlushFunction flushFunction,
final Consumer<AirbyteMessage> outputRecordCollector,
final FlushFailure flushFailure,
final GlobalAsyncStateManager stateManager) {
this.bufferDequeue = bufferDequeue;
this.outputRecordCollector = outputRecordCollector;
this.flushFailure = flushFailure;
this.stateManager = stateManager;
flusher = flushFunction;
debugLoop = Executors.newSingleThreadScheduledExecutor();
supervisorThread = Executors.newScheduledThreadPool(1);
workerPool = Executors.newFixedThreadPool(5);
isClosing = new AtomicBoolean(false);
runningFlushWorkers = new RunningFlushWorkers();
detectStreamToFlush = new DetectStreamToFlush(bufferDequeue, runningFlushWorkers, isClosing, flusher);
}
public void start() {
supervisorThread.scheduleAtFixedRate(this::retrieveWork, SUPERVISOR_INITIAL_DELAY_SECS, SUPERVISOR_PERIOD_SECS,
supervisorThread.scheduleAtFixedRate(this::retrieveWork,
SUPERVISOR_INITIAL_DELAY_SECS,
SUPERVISOR_PERIOD_SECS,
TimeUnit.SECONDS);
debugLoop.scheduleAtFixedRate(this::printWorkerInfo,
DEBUG_INITIAL_DELAY_SECS,
DEBUG_PERIOD_SECS,
TimeUnit.SECONDS);
debugLoop.scheduleAtFixedRate(this::printWorkerInfo, DEBUG_INITIAL_DELAY_SECS, DEBUG_PERIOD_SECS, TimeUnit.SECONDS);
}
@Override
public void close() throws Exception {
flushAll();
supervisorThread.shutdown();
final var supervisorShut = supervisorThread.awaitTermination(5L, TimeUnit.MINUTES);
log.info("Supervisor shut status: {}", supervisorShut);
log.info("Starting worker pool shutdown..");
workerPool.shutdown();
final var workersShut = workerPool.awaitTermination(5L, TimeUnit.MINUTES);
log.info("Workers shut status: {}", workersShut);
debugLoop.shutdownNow();
}
private void retrieveWork() {
// todo (cgardens) - i'm not convinced this makes sense. as we get close to the limit, we should
// flush more eagerly, but "flush all" is never a particularly useful thing in this world.
// if the total size is > n, flush all buffers
if (bufferDequeue.getTotalGlobalQueueSizeBytes() > TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES) {
flushAll();
return;
}
try {
log.info("Retrieve Work -- Finding queues to flush");
final ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) workerPool;
int allocatableThreads = threadPoolExecutor.getMaximumPoolSize() - threadPoolExecutor.getActiveCount();
final ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) workerPool;
var allocatableThreads = threadPoolExecutor.getMaximumPoolSize() - threadPoolExecutor.getActiveCount();
while (allocatableThreads > 0) {
final Optional<StreamDescriptor> next = detectStreamToFlush.getNextStreamToFlush();
// todo (cgardens) - build a score to prioritize which queue to flush next. e.g. if a queue is very
// large, flush it first. if a queue has not been flushed in a while, flush it next.
// otherwise, if each individual stream has crossed a specific threshold, flush
for (final StreamDescriptor stream : bufferDequeue.getBufferedStreams()) {
if (allocatableThreads == 0) {
break;
}
// while we allow out-of-order processing for speed improvements via multiple workers reading from
// the same queue, also avoid scheduling more workers than what is already in progress.
final var inProgressSizeByte = (bufferDequeue.getQueueSizeBytes(stream).get() -
streamToInProgressWorkers.getOrDefault(stream, new AtomicInteger(0)).get() * QUEUE_FLUSH_THRESHOLD_BYTES);
final var exceedSize = inProgressSizeByte >= QUEUE_FLUSH_THRESHOLD_BYTES;
final var tooLongSinceLastRecord = bufferDequeue.getTimeOfLastRecord(stream)
.map(time -> time.isBefore(Instant.now().minus(MAX_TIME_BETWEEN_REC_MINS, ChronoUnit.MINUTES)))
.orElse(false);
if (exceedSize || tooLongSinceLastRecord) {
log.info(
"Allocated stream {}, exceedSize:{}, tooLongSinceLastRecord: {}, bytes in queue: {} computed in-progress bytes: {} , threshold bytes: {}",
stream.getName(), exceedSize, tooLongSinceLastRecord,
FileUtils.byteCountToDisplaySize(bufferDequeue.getQueueSizeBytes(stream).get()),
FileUtils.byteCountToDisplaySize(inProgressSizeByte),
FileUtils.byteCountToDisplaySize(QUEUE_FLUSH_THRESHOLD_BYTES));
allocatableThreads--;
if (streamToInProgressWorkers.containsKey(stream)) {
streamToInProgressWorkers.get(stream).getAndAdd(1);
if (next.isPresent()) {
final StreamDescriptor desc = next.get();
final UUID flushWorkerId = UUID.randomUUID();
runningFlushWorkers.trackFlushWorker(desc, flushWorkerId);
allocatableThreads--;
flush(desc, flushWorkerId);
} else {
streamToInProgressWorkers.put(stream, new AtomicInteger(1));
break;
}
flush(stream);
}
} catch (final Exception e) {
log.error("Flush worker error: ", e);
flushFailure.propagateException(e);
throw new RuntimeException(e);
}
}
@@ -142,28 +135,101 @@ public class FlushWorkers implements AutoCloseable {
}
private void flushAll() {
log.info("Flushing all buffers..");
for (final StreamDescriptor desc : bufferDequeue.getBufferedStreams()) {
flush(desc);
}
}
private void flush(final StreamDescriptor desc) {
private void flush(final StreamDescriptor desc, final UUID flushWorkerId) {
workerPool.submit(() -> {
log.info("Worker picked up work..");
log.info("Flush Worker ({}) -- Worker picked up work.", humanReadableFlushWorkerId(flushWorkerId));
try {
log.info("Attempting to read from queue {}. Current queue size: {}", desc, bufferDequeue.getQueueSizeInRecords(desc).get());
log.info("Flush Worker ({}) -- Attempting to read from queue namespace: {}, stream: {}.",
humanReadableFlushWorkerId(flushWorkerId),
desc.getNamespace(),
desc.getName());
try (final var batch = bufferDequeue.take(desc, flusher.getOptimalBatchSizeBytes())) {
flusher.flush(desc, batch.getData());
runningFlushWorkers.registerBatchSize(desc, flushWorkerId, batch.getSizeInBytes());
final Map<Long, Long> stateIdToCount = batch.getData()
.stream()
.map(MessageWithMeta::stateId)
.collect(Collectors.groupingBy(
stateId -> stateId,
Collectors.counting()));
log.info("Flush Worker ({}) -- Batch contains: {} records, {} bytes.",
humanReadableFlushWorkerId(flushWorkerId),
batch.getData().size(),
AirbyteFileUtils.byteCountToDisplaySize(batch.getSizeInBytes()));
flusher.flush(desc, batch.getData().stream().map(MessageWithMeta::message));
emitStateMessages(batch.flushStates(stateIdToCount));
}
log.info("Worker finished flushing. Current queue size: {}", bufferDequeue.getQueueSizeInRecords(desc));
log.info("Flush Worker ({}) -- Worker finished flushing. Current queue size: {}",
humanReadableFlushWorkerId(flushWorkerId),
bufferDequeue.getQueueSizeInRecords(desc).orElseThrow());
} catch (final Exception e) {
log.error(String.format("Flush Worker (%s) -- flush worker error: ", humanReadableFlushWorkerId(flushWorkerId)), e);
flushFailure.propagateException(e);
throw new RuntimeException(e);
} finally {
runningFlushWorkers.completeFlushWorker(desc, flushWorkerId);
}
});
}
/**
 * Waits for the buffers to drain, emits outstanding state, and shuts down the executors.
 * <p>
 * Sets {@code isClosing}, then polls once per second until no buffered stream reports remaining
 * records. The messages returned by {@code stateManager.flushStates()} are emitted before the
 * supervisor, the worker pool and the debug loop are shut down.
 * <p>
 * The wait is abandoned as soon as a flush failure has been recorded. {@code retrieveWork} rethrows
 * after recording a failure, and a task scheduled with {@code scheduleAtFixedRate} is not executed
 * again once one execution throws. Without this check the buffers might never drain and this
 * method would never return.
 *
 * @throws Exception if the thread is interrupted while sleeping or awaiting executor termination,
 *         or if a buffered stream has no queue size available
 */
@Override
public void close() throws Exception {
  log.info("Closing flush workers -- waiting for all buffers to flush");
  isClosing.set(true);

  // wait for all buffers to be flushed.
  boolean allBuffersFlushed = false;
  while (true) {
    final Map<StreamDescriptor, Long> streamDescriptorToRemainingRecords = bufferDequeue.getBufferedStreams()
        .stream()
        .collect(Collectors.toMap(desc -> desc, desc -> bufferDequeue.getQueueSizeInRecords(desc).orElseThrow()));

    final boolean anyRecordsLeft = streamDescriptorToRemainingRecords
        .values()
        .stream()
        .anyMatch(size -> size > 0);

    if (!anyRecordsLeft) {
      allBuffersFlushed = true;
      break;
    }

    // Once a failure is recorded there is no guarantee anything will drain the queues again.
    if (flushFailure.isFailed()) {
      log.error("Closing flush workers -- a flush failure was recorded, no longer waiting for remaining buffers to flush");
      break;
    }

    // One line per stream that still has records, so the entries do not run together.
    final String remainingBuffers = streamDescriptorToRemainingRecords.entrySet()
        .stream()
        .filter(entry -> entry.getValue() > 0)
        .map(entry -> String.format(" Namespace: %s Stream: %s -- remaining records: %d",
            entry.getKey().getNamespace(),
            entry.getKey().getName(),
            entry.getValue()))
        .collect(Collectors.joining(System.lineSeparator()));
    log.info("REMAINING_BUFFERS_INFO{}{}", System.lineSeparator(), remainingBuffers);

    log.info("Waiting for all streams to flush.");
    Thread.sleep(1000);
  }
  if (allBuffersFlushed) {
    log.info("Closing flush workers -- all buffers flushed");
  }

  // before shutting down the supervisor, flush all state.
  emitStateMessages(stateManager.flushStates());

  supervisorThread.shutdown();
  final var supervisorShut = supervisorThread.awaitTermination(5L, TimeUnit.MINUTES);
  log.info("Closing flush workers -- Supervisor shutdown status: {}", supervisorShut);

  log.info("Closing flush workers -- Starting worker pool shutdown..");
  workerPool.shutdown();
  final var workersShut = workerPool.awaitTermination(5L, TimeUnit.MINUTES);
  log.info("Closing flush workers -- Workers shutdown status: {}", workersShut);

  debugLoop.shutdownNow();
}
/**
 * Hands each given partial message to the output record collector.
 * <p>
 * Every partial is turned back into a full {@link AirbyteMessage} by deserializing its serialized
 * form, and is passed on in list order.
 *
 * @param partials the partial messages to emit
 */
private void emitStateMessages(final List<PartialAirbyteMessage> partials) {
  for (final PartialAirbyteMessage partial : partials) {
    final AirbyteMessage fullMessage = Jsons.deserialize(partial.getSerialized(), AirbyteMessage.class);
    outputRecordCollector.accept(fullMessage);
  }
}
/**
 * Shortens a flush worker id for log output.
 *
 * @param flushWorkerId the worker id
 * @return the first five characters of the id's string form
 */
private static String humanReadableFlushWorkerId(final UUID flushWorkerId) {
  final String fullId = flushWorkerId.toString();
  return fullId.substring(0, 5);
}
}

View File

@@ -5,18 +5,20 @@
package io.airbyte.integrations.destination_async.buffers;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.integrations.destination_async.buffers.MemoryBoundedLinkedBlockingQueue.MemoryItem;
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.time.Instant;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Stream;
/**
* Represents the minimal interface over the underlying buffer queues required for dequeue
@@ -29,22 +31,25 @@ import java.util.stream.Stream;
public class BufferDequeue {
private final GlobalMemoryManager memoryManager;
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
private final GlobalAsyncStateManager stateManager;
private final ConcurrentMap<StreamDescriptor, ReentrantLock> bufferLocks;
public BufferDequeue(final GlobalMemoryManager memoryManager,
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers,
final GlobalAsyncStateManager stateManager) {
this.memoryManager = memoryManager;
this.buffers = buffers;
this.stateManager = stateManager;
bufferLocks = new ConcurrentHashMap<>();
}
/**
* Primary dequeue method. Best-effort read a specified optimal memory size from the queue.
* Primary dequeue method. Reads from queue up to optimalBytesToRead OR until the queue is empty.
*
* @param streamDescriptor specific buffer to take from
* @param optimalBytesToRead bytes to read, if possible
* @return
* @return autocloseable batch object, that frees memory.
*/
public MemoryAwareMessageBatch take(final StreamDescriptor streamDescriptor, final long optimalBytesToRead) {
final var queue = buffers.get(streamDescriptor);
@@ -57,33 +62,27 @@ public class BufferDequeue {
try {
final AtomicLong bytesRead = new AtomicLong();
final var s = Stream.generate(() -> {
try {
return queue.poll(20, TimeUnit.MILLISECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}
}).takeWhile(memoryItem -> {
// if no new records after waiting, the stream is done.
if (memoryItem == null) {
return false;
}
final List<MessageWithMeta> output = new LinkedList<>();
while (queue.size() > 0) {
final MemoryItem<MessageWithMeta> memoryItem = queue.peek().orElseThrow();
// otherwise pull records until we hit the memory limit.
final long newSize = memoryItem.size() + bytesRead.get();
if (newSize <= optimalBytesToRead) {
bytesRead.addAndGet(memoryItem.size());
return true;
output.add(queue.poll().item());
} else {
return false;
break;
}
}).map(MemoryBoundedLinkedBlockingQueue.MemoryItem::item)
.toList()
.stream();
}
queue.addMaxMemory(-bytesRead.get());
return new MemoryAwareMessageBatch(s, bytesRead.get(), memoryManager);
return new MemoryAwareMessageBatch(
output,
bytesRead.get(),
memoryManager,
stateManager);
} finally {
bufferLocks.get(streamDescriptor).unlock();
}
@@ -91,10 +90,9 @@ public class BufferDequeue {
/**
* The following methods provide metadata for buffer flushing calculations. Consumers are
* expected to call {@link #getBufferedStreams()} to retrieve the currently buffered streams as a
* handle to the remaining methods.
* expected to call it to retrieve the currently buffered streams as a handle to the remaining
* methods.
*/
public Set<StreamDescriptor> getBufferedStreams() {
return new HashSet<>(buffers.keySet());
}
@@ -104,7 +102,7 @@ public class BufferDequeue {
}
public long getTotalGlobalQueueSizeBytes() {
return buffers.values().stream().map(MemoryBoundedLinkedBlockingQueue::getCurrentMemoryUsage).mapToLong(Long::longValue).sum();
return buffers.values().stream().map(StreamAwareQueue::getCurrentMemoryUsage).mapToLong(Long::longValue).sum();
}
public Optional<Long> getQueueSizeInRecords(final StreamDescriptor streamDescriptor) {
@@ -112,14 +110,14 @@ public class BufferDequeue {
}
public Optional<Long> getQueueSizeBytes(final StreamDescriptor streamDescriptor) {
return getBuffer(streamDescriptor).map(MemoryBoundedLinkedBlockingQueue::getCurrentMemoryUsage);
return getBuffer(streamDescriptor).map(StreamAwareQueue::getCurrentMemoryUsage);
}
public Optional<Instant> getTimeOfLastRecord(final StreamDescriptor streamDescriptor) {
return getBuffer(streamDescriptor).flatMap(MemoryBoundedLinkedBlockingQueue::getTimeOfLastMessage);
return getBuffer(streamDescriptor).flatMap(StreamAwareQueue::getTimeOfLastMessage);
}
private Optional<MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> getBuffer(final StreamDescriptor streamDescriptor) {
private Optional<StreamAwareQueue> getBuffer(final StreamDescriptor streamDescriptor) {
if (buffers.containsKey(streamDescriptor)) {
return Optional.of(buffers.get(streamDescriptor));
}

View File

@@ -4,10 +4,12 @@
package io.airbyte.integrations.destination_async.buffers;
import com.google.common.annotations.VisibleForTesting;
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordSizeEstimator;
import static java.lang.Thread.sleep;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.concurrent.ConcurrentMap;
@@ -17,54 +19,65 @@ import java.util.concurrent.ConcurrentMap;
*/
public class BufferEnqueue {
private final long initialQueueSizeBytes;
private final RecordSizeEstimator recordSizeEstimator;
private final GlobalMemoryManager memoryManager;
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
private final GlobalAsyncStateManager stateManager;
public BufferEnqueue(final GlobalMemoryManager memoryManager,
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
this(GlobalMemoryManager.BLOCK_SIZE_BYTES, memoryManager, buffers);
}
@VisibleForTesting
public BufferEnqueue(final long initialQueueSizeBytes,
final GlobalMemoryManager memoryManager,
final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers) {
this.initialQueueSizeBytes = initialQueueSizeBytes;
final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers,
final GlobalAsyncStateManager stateManager) {
this.memoryManager = memoryManager;
this.buffers = buffers;
this.recordSizeEstimator = new RecordSizeEstimator();
this.stateManager = stateManager;
}
/**
* Buffer a record. Contains memory management logic to dynamically adjust queue size via
* {@link GlobalMemoryManager} accounting for incoming records.
*
* @param streamDescriptor stream to buffer record to
* @param message to buffer
* @param sizeInBytes
*/
public void addRecord(final StreamDescriptor streamDescriptor, final AirbyteMessage message) {
if (!buffers.containsKey(streamDescriptor)) {
buffers.put(streamDescriptor, new MemoryBoundedLinkedBlockingQueue<>(memoryManager.requestMemory()));
public void addRecord(final PartialAirbyteMessage message, final Integer sizeInBytes) {
if (message.getType() == Type.RECORD) {
handleRecord(message, sizeInBytes);
} else if (message.getType() == Type.STATE) {
stateManager.trackState(message, sizeInBytes);
}
}
// todo (cgardens) - handle estimating state message size.
final long messageSize = message.getType() == AirbyteMessage.Type.RECORD ? recordSizeEstimator.getEstimatedByteSize(message.getRecord()) : 1024;
private void handleRecord(final PartialAirbyteMessage message, final Integer sizeInBytes) {
final StreamDescriptor streamDescriptor = extractStateFromRecord(message);
if (streamDescriptor != null && !buffers.containsKey(streamDescriptor)) {
buffers.put(streamDescriptor, new StreamAwareQueue(memoryManager.requestMemory()));
}
final long stateId = stateManager.getStateIdAndIncrementCounter(streamDescriptor);
final var queue = buffers.get(streamDescriptor);
var addedToQueue = queue.offer(message, messageSize);
var addedToQueue = queue.offer(message, sizeInBytes, stateId);
// todo (cgardens) - what if the record being added is bigger than the block size?
// if failed, try to increase memory and add to queue.
int i = 0;
while (!addedToQueue) {
final var newlyAllocatedMemory = memoryManager.requestMemory();
if (newlyAllocatedMemory > 0) {
queue.addMaxMemory(newlyAllocatedMemory);
}
addedToQueue = queue.offer(message, messageSize);
addedToQueue = queue.offer(message, sizeInBytes, stateId);
i++;
if (i > 5) {
try {
sleep(500);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}
}
}
}
// Builds the stream descriptor (namespace + stream name) that identifies the record's stream.
// NOTE(review): despite the name, this extracts the stream descriptor, not a state.
private static StreamDescriptor extractStateFromRecord(final PartialAirbyteMessage message) {
  final var recordMessage = message.getRecord();
  final String namespace = recordMessage.getNamespace();
  final String streamName = recordMessage.getStream();
  return new StreamDescriptor().withNamespace(namespace).withName(streamName);
}
}

View File

@@ -4,11 +4,11 @@
package io.airbyte.integrations.destination_async.buffers;
import static io.airbyte.integrations.destination_async.FlushWorkers.TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES;
import com.google.common.annotations.VisibleForTesting;
import io.airbyte.integrations.destination_async.AirbyteFileUtils;
import io.airbyte.integrations.destination_async.FlushWorkers;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@@ -17,24 +17,44 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@Slf4j
public class BufferManager {
private final ConcurrentMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>> buffers;
private static final Logger LOGGER = LoggerFactory.getLogger(BufferManager.class);
public final long maxMemory;
private final ConcurrentMap<StreamDescriptor, StreamAwareQueue> buffers;
private final BufferEnqueue bufferEnqueue;
private final BufferDequeue bufferDequeue;
private final GlobalMemoryManager memoryManager;
private final ScheduledExecutorService debugLoop = Executors.newSingleThreadScheduledExecutor();
private final GlobalAsyncStateManager stateManager;
private final ScheduledExecutorService debugLoop;
public BufferManager() {
memoryManager = new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES);
this((long) (Runtime.getRuntime().maxMemory() * 0.8));
}
@VisibleForTesting
public BufferManager(final long memoryLimit) {
maxMemory = memoryLimit;
LOGGER.info("Memory available to the JVM {}", FileUtils.byteCountToDisplaySize(maxMemory));
memoryManager = new GlobalMemoryManager(maxMemory);
this.stateManager = new GlobalAsyncStateManager(memoryManager);
buffers = new ConcurrentHashMap<>();
bufferEnqueue = new BufferEnqueue(memoryManager, buffers);
bufferDequeue = new BufferDequeue(memoryManager, buffers);
bufferEnqueue = new BufferEnqueue(memoryManager, buffers, stateManager);
bufferDequeue = new BufferDequeue(memoryManager, buffers, stateManager);
debugLoop = Executors.newSingleThreadScheduledExecutor();
debugLoop.scheduleAtFixedRate(this::printQueueInfo, 0, 10, TimeUnit.SECONDS);
}
/**
 * @return the state manager created by this buffer manager; the same instance is handed to the
 *         enqueue and dequeue helpers.
 */
public GlobalAsyncStateManager getStateManager() {
return stateManager;
}
/**
 * @return the enqueue helper that writes into this manager's buffers.
 */
public BufferEnqueue getBufferEnqueue() {
return bufferEnqueue;
}
@@ -57,16 +77,18 @@ public class BufferManager {
final var queueInfo = new StringBuilder().append("QUEUE INFO").append(System.lineSeparator());
queueInfo
.append(String.format(" Global Mem Manager -- max: %s, allocated: %s",
FileUtils.byteCountToDisplaySize(memoryManager.getMaxMemoryBytes()),
FileUtils.byteCountToDisplaySize(memoryManager.getCurrentMemoryBytes())))
.append(String.format(" Global Mem Manager -- max: %s, allocated: %s (%s MB), %% used: %s",
AirbyteFileUtils.byteCountToDisplaySize(memoryManager.getMaxMemoryBytes()),
AirbyteFileUtils.byteCountToDisplaySize(memoryManager.getCurrentMemoryBytes()),
(double) memoryManager.getCurrentMemoryBytes() / 1024 / 1024,
(double) memoryManager.getCurrentMemoryBytes() / memoryManager.getMaxMemoryBytes()))
.append(System.lineSeparator());
for (final var entry : buffers.entrySet()) {
final var queue = entry.getValue();
queueInfo.append(
String.format(" Queue name: %s, num records: %d, num bytes: %s",
entry.getKey().getName(), queue.size(), FileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage())))
entry.getKey().getName(), queue.size(), AirbyteFileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage())))
.append(System.lineSeparator());
}
log.info(queueInfo.toString());

View File

@@ -5,8 +5,13 @@
package io.airbyte.integrations.destination_async.buffers;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import java.util.stream.Stream;
import io.airbyte.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* POJO abstraction representing one discrete buffer read. This allows ergonomics dequeues by
@@ -21,24 +26,47 @@ import java.util.stream.Stream;
*/
public class MemoryAwareMessageBatch implements AutoCloseable {
private Stream<AirbyteMessage> batch;
private static final Logger LOGGER = LoggerFactory.getLogger(MemoryAwareMessageBatch.class);
private final List<MessageWithMeta> batch;
private final long sizeInBytes;
private final GlobalMemoryManager memoryManager;
private final GlobalAsyncStateManager stateManager;
public MemoryAwareMessageBatch(final Stream<AirbyteMessage> batch, final long sizeInBytes, final GlobalMemoryManager memoryManager) {
public MemoryAwareMessageBatch(final List<MessageWithMeta> batch,
final long sizeInBytes,
final GlobalMemoryManager memoryManager,
final GlobalAsyncStateManager stateManager) {
this.batch = batch;
this.sizeInBytes = sizeInBytes;
this.memoryManager = memoryManager;
this.stateManager = stateManager;
}
public Stream<AirbyteMessage> getData() {
public long getSizeInBytes() {
return sizeInBytes;
}
public List<MessageWithMeta> getData() {
return batch;
}
@Override
public void close() throws Exception {
batch = null;
memoryManager.free(sizeInBytes);
}
/**
 * Marks the given record counts as flushed and returns the state messages that can now be
 * emitted. Each entry decrements the counter of the corresponding state id; afterwards the state
 * manager is asked for every state that has no records left in flight.
 * <p>
 * This method is destructive: the state manager drops the returned states from its queues, so the
 * caller is expected to emit every message it receives.
 *
 * @param stateIdToCount number of flushed records per state id
 * @return list of states that can be flushed
 */
public List<PartialAirbyteMessage> flushStates(final Map<Long, Long> stateIdToCount) {
  for (final Map.Entry<Long, Long> flushed : stateIdToCount.entrySet()) {
    stateManager.decrement(flushed.getKey(), flushed.getValue());
  }
  return stateManager.flushStates();
}
}

View File

@@ -4,12 +4,10 @@
package io.airbyte.integrations.destination_async.buffers;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
/**
@@ -17,7 +15,7 @@ import lombok.extern.slf4j.Slf4j;
* bounded on number of items in the queue, it is bounded by the memory it is allowed to use. The
* amount of memory it is allowed to use can be resized after it is instantiated.
* <p>
* This class intentaionally hides the underlying queue inside of it. For this class to work, it has
* This class intentionally hides the underlying queue inside of it. For this class to work, it has
* to override each method on a queue that adds or removes records from the queue. The Queue
* interface has a lot of methods to override, and we don't want to spend the time overriding a lot
* of methods that won't be used. By hiding the queue, we avoid someone accidentally using a queue
@@ -41,11 +39,7 @@ class MemoryBoundedLinkedBlockingQueue<E> {
}
public void addMaxMemory(final long maxMemoryUsage) {
this.hiddenQueue.maxMemoryUsage.addAndGet(maxMemoryUsage);
}
public Optional<Instant> getTimeOfLastMessage() {
return Optional.ofNullable(hiddenQueue.timeOfLastMessage.get());
hiddenQueue.maxMemoryUsage.addAndGet(maxMemoryUsage);
}
public int size() {
@@ -56,6 +50,10 @@ class MemoryBoundedLinkedBlockingQueue<E> {
return hiddenQueue.offer(e, itemSizeInBytes);
}
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> peek() {
return hiddenQueue.peek();
}
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> take() throws InterruptedException {
return hiddenQueue.take();
}
@@ -78,12 +76,10 @@ class MemoryBoundedLinkedBlockingQueue<E> {
private final AtomicLong currentMemoryUsage;
private final AtomicLong maxMemoryUsage;
private final AtomicReference<Instant> timeOfLastMessage;
public HiddenQueue(final long maxMemoryUsage) {
currentMemoryUsage = new AtomicLong(0);
this.maxMemoryUsage = new AtomicLong(maxMemoryUsage);
timeOfLastMessage = new AtomicReference<>(null);
}
public boolean offer(final E e, final long itemSizeInBytes) {
@@ -92,9 +88,6 @@ class MemoryBoundedLinkedBlockingQueue<E> {
final boolean success = super.offer(new MemoryItem<>(e, itemSizeInBytes));
if (!success) {
currentMemoryUsage.addAndGet(-itemSizeInBytes);
} else {
// it succeeded!
timeOfLastMessage.set(Instant.now());
}
log.debug("offer status: {}", success);
return success;
@@ -105,6 +98,7 @@ class MemoryBoundedLinkedBlockingQueue<E> {
}
}
@Nonnull
@Override
public MemoryBoundedLinkedBlockingQueue.MemoryItem<E> take() throws InterruptedException {
final MemoryItem<E> memoryItem = super.take();

View File

@@ -0,0 +1,69 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.buffers;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
/**
 * Wrapper around a {@link MemoryBoundedLinkedBlockingQueue} that stores each message together with
 * the id of the state it belongs to, and remembers when a message was last accepted.
 */
@Slf4j
public class StreamAwareQueue {

  private final MemoryBoundedLinkedBlockingQueue<MessageWithMeta> memoryAwareQueue;
  private final AtomicReference<Instant> timeOfLastMessage;

  /**
   * @param maxMemoryUsage initial memory limit handed to the underlying queue
   */
  public StreamAwareQueue(final long maxMemoryUsage) {
    timeOfLastMessage = new AtomicReference<>();
    memoryAwareQueue = new MemoryBoundedLinkedBlockingQueue<>(maxMemoryUsage);
  }

  /** @return memory currently used, as reported by the underlying queue. */
  public long getCurrentMemoryUsage() {
    return memoryAwareQueue.getCurrentMemoryUsage();
  }

  /** Raises the underlying queue's memory limit by the given amount. */
  public void addMaxMemory(final long maxMemoryUsage) {
    memoryAwareQueue.addMaxMemory(maxMemoryUsage);
  }

  /** @return time of the last successful {@link #offer}, or empty if nothing was accepted yet. */
  public Optional<Instant> getTimeOfLastMessage() {
    final Instant lastAccepted = timeOfLastMessage.get();
    return Optional.ofNullable(lastAccepted);
  }

  /** @return the head of the queue without removing it, or empty if the queue has no head. */
  public Optional<MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta>> peek() {
    final MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> head = memoryAwareQueue.peek();
    return Optional.ofNullable(head);
  }

  /** @return number of items in the underlying queue. */
  public int size() {
    return memoryAwareQueue.size();
  }

  /**
   * Offers a message to the underlying queue. The time of last message is only updated when the
   * queue accepts the message.
   *
   * @param message message to enqueue
   * @param messageSizeInBytes size used for memory accounting
   * @param stateId id of the state this message is counted against
   * @return true if the message was accepted, false otherwise
   */
  public boolean offer(final PartialAirbyteMessage message, final long messageSizeInBytes, final long stateId) {
    final boolean accepted = memoryAwareQueue.offer(new MessageWithMeta(message, stateId), messageSizeInBytes);
    if (accepted) {
      timeOfLastMessage.set(Instant.now());
    }
    return accepted;
  }

  /** Delegates to the underlying queue's {@code take()}. */
  public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> take() throws InterruptedException {
    return memoryAwareQueue.take();
  }

  /** Delegates to the underlying queue's {@code poll()}. */
  public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> poll() {
    return memoryAwareQueue.poll();
  }

  /** Delegates to the underlying queue's timed {@code poll(timeout, unit)}. */
  public MemoryBoundedLinkedBlockingQueue.MemoryItem<MessageWithMeta> poll(final long timeout, final TimeUnit unit) throws InterruptedException {
    return memoryAwareQueue.poll(timeout, unit);
  }

  /** A queued message paired with the id of the state it is counted against. */
  public record MessageWithMeta(PartialAirbyteMessage message, long stateId) {}

}

View File

@@ -0,0 +1,117 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.partial_messages;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import java.util.Objects;
/**
 * Jackson-mapped message holder exposing the {@code type}, {@code record}, {@code state} and
 * {@code serialized} properties. Every property has a getter, a setter and a fluent
 * {@code with...} variant that returns this instance.
 */
public class PartialAirbyteMessage {

  @JsonProperty("type")
  @JsonPropertyDescription("Message type")
  private AirbyteMessage.Type type;

  @JsonProperty("record")
  private PartialAirbyteRecordMessage record;

  @JsonProperty("state")
  private PartialAirbyteStateMessage state;

  @JsonProperty("serialized")
  private String serialized;

  /** Creates an instance with every property unset (null). */
  public PartialAirbyteMessage() {}

  @JsonProperty("type")
  public AirbyteMessage.Type getType() {
    return type;
  }

  @JsonProperty("type")
  public void setType(final AirbyteMessage.Type type) {
    this.type = type;
  }

  /** Fluent variant of {@link #setType}. */
  public PartialAirbyteMessage withType(final AirbyteMessage.Type type) {
    this.type = type;
    return this;
  }

  @JsonProperty("record")
  public PartialAirbyteRecordMessage getRecord() {
    return record;
  }

  @JsonProperty("record")
  public void setRecord(final PartialAirbyteRecordMessage record) {
    this.record = record;
  }

  /** Fluent variant of {@link #setRecord}. */
  public PartialAirbyteMessage withRecord(final PartialAirbyteRecordMessage record) {
    this.record = record;
    return this;
  }

  @JsonProperty("state")
  public PartialAirbyteStateMessage getState() {
    return state;
  }

  @JsonProperty("state")
  public void setState(final PartialAirbyteStateMessage state) {
    this.state = state;
  }

  /** Fluent variant of {@link #setState}. */
  public PartialAirbyteMessage withState(final PartialAirbyteStateMessage state) {
    this.state = state;
    return this;
  }

  @JsonProperty("serialized")
  public String getSerialized() {
    return serialized;
  }

  @JsonProperty("serialized")
  public void setSerialized(final String serialized) {
    this.serialized = serialized;
  }

  /** Fluent variant of {@link #setSerialized}. */
  public PartialAirbyteMessage withSerialized(final String serialized) {
    this.serialized = serialized;
    return this;
  }

  /** Two messages are equal when they are of the same class and all four properties match. */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final PartialAirbyteMessage other = (PartialAirbyteMessage) o;
    if (type != other.type) {
      return false;
    }
    if (!Objects.equals(record, other.record)) {
      return false;
    }
    if (!Objects.equals(state, other.state)) {
      return false;
    }
    return Objects.equals(serialized, other.serialized);
  }

  @Override
  public int hashCode() {
    return Objects.hash(type, record, state, serialized);
  }

  @Override
  public String toString() {
    return new StringBuilder("PartialAirbyteMessage{")
        .append("type=").append(type)
        .append(", record=").append(record)
        .append(", state=").append(state)
        .append(", serialized='").append(serialized).append('\'')
        .append('}')
        .toString();
  }

}

View File

@@ -0,0 +1,74 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.partial_messages;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Objects;
/**
 * Jackson-mapped holder for the {@code namespace} and {@code stream} properties of a record
 * message. Each property has a getter, a setter and a fluent {@code with...} variant.
 */
public class PartialAirbyteRecordMessage {

  @JsonProperty("namespace")
  private String namespace;

  @JsonProperty("stream")
  private String stream;

  /** Creates an instance with both properties unset (null). */
  public PartialAirbyteRecordMessage() {}

  @JsonProperty("namespace")
  public String getNamespace() {
    return namespace;
  }

  @JsonProperty("namespace")
  public void setNamespace(final String namespace) {
    this.namespace = namespace;
  }

  /** Fluent variant of {@link #setNamespace}. */
  public PartialAirbyteRecordMessage withNamespace(final String namespace) {
    this.namespace = namespace;
    return this;
  }

  @JsonProperty("stream")
  public String getStream() {
    return stream;
  }

  @JsonProperty("stream")
  public void setStream(final String stream) {
    this.stream = stream;
  }

  /** Fluent variant of {@link #setStream}. */
  public PartialAirbyteRecordMessage withStream(final String stream) {
    this.stream = stream;
    return this;
  }

  /** Equal when the other object is of the same class and has the same namespace and stream. */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final PartialAirbyteRecordMessage other = (PartialAirbyteRecordMessage) o;
    if (!Objects.equals(namespace, other.namespace)) {
      return false;
    }
    return Objects.equals(stream, other.stream);
  }

  @Override
  public int hashCode() {
    return Objects.hash(namespace, stream);
  }

  @Override
  public String toString() {
    return new StringBuilder("PartialAirbyteRecordMessage{")
        .append("namespace='").append(namespace).append('\'')
        .append(", stream='").append(stream).append('\'')
        .append('}')
        .toString();
  }

}

View File

@@ -0,0 +1,76 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.partial_messages;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
import java.util.Objects;
/**
 * Jackson-mapped holder for the {@code type} and {@code stream} properties of a state message.
 * Each property has a getter, a setter and a fluent {@code with...} variant.
 */
public class PartialAirbyteStateMessage {

  @JsonProperty("type")
  private AirbyteStateType type;

  @JsonProperty("stream")
  private PartialAirbyteStreamState stream;

  /** Creates an instance with both properties unset (null). */
  public PartialAirbyteStateMessage() {}

  @JsonProperty("type")
  public AirbyteStateType getType() {
    return type;
  }

  @JsonProperty("type")
  public void setType(final AirbyteStateType type) {
    this.type = type;
  }

  /** Fluent variant of {@link #setType}. */
  public PartialAirbyteStateMessage withType(final AirbyteStateType type) {
    this.type = type;
    return this;
  }

  @JsonProperty("stream")
  public PartialAirbyteStreamState getStream() {
    return stream;
  }

  @JsonProperty("stream")
  public void setStream(final PartialAirbyteStreamState stream) {
    this.stream = stream;
  }

  /** Fluent variant of {@link #setStream}. */
  public PartialAirbyteStateMessage withStream(final PartialAirbyteStreamState stream) {
    this.stream = stream;
    return this;
  }

  /** Equal when the other object is of the same class and has the same type and stream. */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final PartialAirbyteStateMessage other = (PartialAirbyteStateMessage) o;
    if (type != other.type) {
      return false;
    }
    return Objects.equals(stream, other.stream);
  }

  @Override
  public int hashCode() {
    return Objects.hash(type, stream);
  }

  @Override
  public String toString() {
    return new StringBuilder("PartialAirbyteStateMessage{")
        .append("type=").append(type)
        .append(", stream=").append(stream)
        .append('}')
        .toString();
  }

}

View File

@@ -0,0 +1,59 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.partial_messages;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.Objects;
/**
 * Jackson-mapped holder for the {@code stream_descriptor} property of a stream state. The property
 * has a getter, a setter and a fluent {@code with...} variant.
 */
public class PartialAirbyteStreamState {

  @JsonProperty("stream_descriptor")
  private StreamDescriptor streamDescriptor;

  /**
   * Creates an instance whose stream descriptor is unset (null) until a setter is called. The
   * previous body assigned the field to itself, which had no effect.
   */
  public PartialAirbyteStreamState() {}

  @JsonProperty("stream_descriptor")
  public StreamDescriptor getStreamDescriptor() {
    return streamDescriptor;
  }

  @JsonProperty("stream_descriptor")
  public void setStreamDescriptor(final StreamDescriptor streamDescriptor) {
    this.streamDescriptor = streamDescriptor;
  }

  /** Fluent variant of {@link #setStreamDescriptor}. */
  public PartialAirbyteStreamState withStreamDescriptor(final StreamDescriptor streamDescriptor) {
    this.streamDescriptor = streamDescriptor;
    return this;
  }

  /** Equal when the other object is of the same class and has an equal stream descriptor. */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final PartialAirbyteStreamState that = (PartialAirbyteStreamState) o;
    return Objects.equals(streamDescriptor, that.streamDescriptor);
  }

  @Override
  public int hashCode() {
    return Objects.hash(streamDescriptor);
  }

  @Override
  public String toString() {
    return "PartialAirbyteStreamState{" +
        "streamDescriptor=" + streamDescriptor +
        '}';
  }

}

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.state;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Holder used to record that a flush failed and which exception caused it, so that another part
 * of the program can check for the failure later.
 */
public class FlushFailure {

  private final AtomicBoolean isFailed = new AtomicBoolean(false);
  private final AtomicReference<Exception> exceptionAtomicReference = new AtomicReference<>();

  /**
   * Records a failure.
   *
   * @param e exception that caused the failure
   */
  public void propagateException(final Exception e) {
    // Store the exception BEFORE raising the flag: a reader that observes isFailed() == true must
    // never find getException() still returning null.
    exceptionAtomicReference.set(e);
    isFailed.set(true);
  }

  /** @return true once {@link #propagateException} has been called. */
  public boolean isFailed() {
    return isFailed.get();
  }

  /** @return the most recently recorded exception, or null if none was recorded. */
  public Exception getException() {
    return exceptionAtomicReference.get();
  }

}

View File

@@ -0,0 +1,328 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.state;
import static java.lang.Thread.sleep;
import com.google.common.base.Preconditions;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStreamState;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.mina.util.ConcurrentHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Responsible for managing state within the Destination. The general approach is a ref counter
* approach - each state message is associated with a record count. This count represents the number
* of preceding records. For a state to be emitted, all preceding records have to be written to the
* destination i.e. the counter is 0.
* <p>
* A per-stream state queue is maintained internally, with each state within the queue having a
* counter. This means we *ALLOW* records succeeding an unemitted state to be written. This
* decouples record writing from state management at the cost of potentially repeating work if an
* upstream state is never written.
* <p>
* One important detail here is the difference between how PER-STREAM & NON-PER-STREAM is handled.
* The PER-STREAM case is simple, and is as described above. The NON-PER-STREAM case is slightly
* tricky. Because we don't know the stream type to begin with, we always assume PER_STREAM until
* the first state message arrives. If this state message is a GLOBAL state, we alias all existing
* state ids to a single global state id via a set of alias ids. From then onwards, we use one id -
* {@link #SENTINEL_GLOBAL_DESC} regardless of stream. Read
* {@link #convertToGlobalIfNeeded(PartialAirbyteMessage)} for more detail.
*/
@Slf4j
public class GlobalAsyncStateManager {
private static final Logger LOGGER = LoggerFactory.getLogger(GlobalAsyncStateManager.class);
private static final StreamDescriptor SENTINEL_GLOBAL_DESC = new StreamDescriptor().withName(UUID.randomUUID().toString());
private final GlobalMemoryManager memoryManager;
/**
* Memory that the manager has allocated to it to use. It can ask for more memory as needed.
*/
private final AtomicLong memoryAllocated;
/**
* Memory that the manager is currently using.
*/
private final AtomicLong memoryUsed;
boolean preState = true;
private final ConcurrentMap<Long, AtomicLong> stateIdToCounter = new ConcurrentHashMap<>();
private final ConcurrentMap<StreamDescriptor, LinkedList<Long>> streamToStateIdQ = new ConcurrentHashMap<>();
private final ConcurrentMap<Long, ImmutablePair<PartialAirbyteMessage, Long>> stateIdToState = new ConcurrentHashMap<>();
// empty in the STREAM case.
// Alias-ing only exists in the non-STREAM case where we have to convert existing state ids to one
// single global id.
// This only happens once.
private final Set<Long> aliasIds = new ConcurrentHashSet<>();
private long retroactiveGlobalStateId = 0;
/**
 * Creates a state manager that starts with one memory grant requested from the given memory
 * manager and no memory in use.
 *
 * @param memoryManager source of the memory this manager accounts state messages against
 */
public GlobalAsyncStateManager(final GlobalMemoryManager memoryManager) {
  this.memoryManager = memoryManager;
  this.memoryUsed = new AtomicLong(0);
  this.memoryAllocated = new AtomicLong(memoryManager.requestMemory());
}
// Always assume STREAM to begin, and convert only if needed. Most state is per stream anyway.
private AirbyteStateMessage.AirbyteStateType stateType = AirbyteStateMessage.AirbyteStateType.STREAM;
/**
 * Entry point for incoming state messages.
 * <p>
 * The very first state message determines which kind of state (per-stream or not) this sync uses;
 * the internal tracking structures are converted at that point if needed. Every later message must
 * be of that same kind.
 * <p>
 * A state message acts as a watermark: it may only be emitted once all records that preceded it
 * have been flushed.
 *
 * @param message state message to track
 * @param sizeInBytes size of the state message, used for memory accounting
 */
public void trackState(final PartialAirbyteMessage message, final long sizeInBytes) {
  final boolean isFirstStateMessage = preState;
  if (isFirstStateMessage) {
    convertToGlobalIfNeeded(message);
    preState = false;
  }
  // stateType should not change after a conversion.
  final AirbyteStateMessage.AirbyteStateType incomingType = extractStateType(message);
  Preconditions.checkArgument(stateType == incomingType);
  closeState(message, sizeInBytes);
}
/**
 * Same lookup as {@link #getStateId(StreamDescriptor)}, but also adds one to the record counter of
 * the returned state id. Meant to be called once per ingested record.
 *
 * @param streamDescriptor - stream to get stateId for.
 * @return state id
 */
public long getStateIdAndIncrementCounter(final StreamDescriptor streamDescriptor) {
  final long oneRecord = 1;
  return getStateIdAndIncrement(streamDescriptor, oneRecord);
}
/**
 * Lowers the in-flight record counter of a state by the given amount, i.e. reports that this many
 * records have been written. Once the counter is zero the state has no more in-flight records and
 * can be flushed. Aliased ids are resolved to the global state id first.
 *
 * @param stateId reference to a state.
 * @param count to decrement.
 */
public void decrement(final long stateId, final long count) {
  log.trace("decrementing state id: {}, count: {}", stateId, count);
  final long resolvedStateId = getStateAfterAlias(stateId);
  final AtomicLong counter = stateIdToCounter.get(resolvedStateId);
  counter.addAndGet(-count);
}
/**
 * Returns state messages with no more inflight records i.e. counter = 0 across all streams.
 * Intended to be called by {@link io.airbyte.integrations.destination_async.FlushWorkers} after a
 * worker has finished flushing its record batch.
 * <p>
 * For every stream, states are taken from the front of that stream's queue for as long as the
 * oldest state id already has a state message and no records left in flight. The returned states
 * are removed from the queue and their bytes are released via {@code freeBytes}.
 * <p>
 * The return list of states should be emitted back to the platform.
 *
 * @return list of state messages with no more inflight records.
 */
public List<PartialAirbyteMessage> flushStates() {
  final List<PartialAirbyteMessage> output = new ArrayList<>();
  // primitive accumulator: avoids boxing on every addition.
  long bytesFlushed = 0L;
  for (final LinkedList<Long> stateIdQueue : streamToStateIdQ.values()) {
    // remove all states with 0 counters.
    while (true) {
      final Long oldestState = stateIdQueue.peek();
      // Check for an empty queue first: ConcurrentHashMap lookups throw on a null key.
      if (oldestState == null) {
        break;
      }
      // The id is still open, i.e. its state message has not arrived yet.
      final ImmutablePair<PartialAirbyteMessage, Long> stateAndSize = stateIdToState.get(oldestState);
      if (stateAndSize == null) {
        break;
      }
      // A missing counter means no records were ever counted against this state.
      final AtomicLong counter = stateIdToCounter.get(oldestState);
      final boolean noPrevRecs = counter == null;
      final boolean allRecsEmitted = !noPrevRecs && counter.get() == 0;
      if (!noPrevRecs && !allRecsEmitted) {
        break;
      }
      stateIdQueue.poll(); // poll to remove. no need to read as the earlier peek is still valid.
      output.add(stateAndSize.getLeft());
      bytesFlushed += stateAndSize.getRight();
    }
  }
  freeBytes(bytesFlushed);
  return output;
}
/**
 * Looks up the newest (still open) state id of a stream and adds {@code increment} to its record
 * counter. When the state type is not STREAM, every stream resolves to the single global sentinel
 * descriptor. A descriptor seen for the first time is registered on the fly.
 *
 * @param streamDescriptor stream to resolve
 * @param increment amount added to the counter; may be 0 for a plain lookup
 * @return the state id the counter belongs to
 */
private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, final long increment) {
  final StreamDescriptor resolvedDescriptor;
  if (stateType == AirbyteStateMessage.AirbyteStateType.STREAM) {
    resolvedDescriptor = streamDescriptor;
  } else {
    resolvedDescriptor = SENTINEL_GLOBAL_DESC;
  }
  final boolean unknownStream = !streamToStateIdQ.containsKey(resolvedDescriptor);
  if (unknownStream) {
    registerNewStreamDescriptor(resolvedDescriptor);
  }
  final Long stateId = streamToStateIdQ.get(resolvedDescriptor).peekLast();
  final long updatedCount = stateIdToCounter.get(stateId).addAndGet(increment);
  log.trace("State id: {}, count: {}", stateId, updatedCount);
  return stateId;
}
/**
 * Returns the internal id of the state currently open for a stream, without changing its record
 * counter. This id is what every other method of this class expects when a state is referenced.
 *
 * @param streamDescriptor - stream to get stateId for.
 * @return state id
 */
private long getStateId(final StreamDescriptor streamDescriptor) {
  final long noIncrement = 0;
  return getStateIdAndIncrement(streamDescriptor, noIncrement);
}
/**
 * Releases the memory of flushed state messages. The given number of bytes is handed back to the
 * memory manager and subtracted from both the allocated and the used totals of this manager. The
 * bytes are always returned right away; there is no threshold.
 *
 * @param bytesFlushed bytes that were flushed (and should be removed from memory used).
 */
private void freeBytes(final long bytesFlushed) {
  final long allocatedBefore = memoryAllocated.get();
  final long usedBefore = memoryUsed.get();
  final String flushedDisplay = FileUtils.byteCountToDisplaySize(bytesFlushed);
  LOGGER.debug("Bytes flushed memory to store state message. Allocated: {}, Used: {}, Flushed: {}, % Used: {}",
      FileUtils.byteCountToDisplaySize(allocatedBefore),
      FileUtils.byteCountToDisplaySize(usedBefore),
      flushedDisplay,
      (double) usedBefore / allocatedBefore);

  memoryManager.free(bytesFlushed);
  memoryAllocated.addAndGet(-bytesFlushed);
  memoryUsed.addAndGet(-bytesFlushed);
  LOGGER.debug("Returned {} of memory back to the memory manager.", flushedDisplay);
}
/**
 * Records the state type of the given message and, when that type is anything other than STREAM,
 * collapses all tracking done so far into a single global state id: every id handed out so far
 * becomes an alias of the new id, and the new id's counter starts at the sum of all existing
 * counters. Called by {@code trackState} for the first state message only.
 *
 * @param message first state message seen by this manager
 */
private void convertToGlobalIfNeeded(final PartialAirbyteMessage message) {
// instead of checking for global or legacy, check for the inverse of stream.
stateType = extractStateType(message);
if (stateType != AirbyteStateMessage.AirbyteStateType.STREAM) {// alias old stream-level state ids to single global state id
// upon conversion, all previous tracking data structures need to be cleared as we move
// into the non-STREAM world for correctness.
// Snapshot every id issued so far BEFORE clearing the queues, so later decrements on those ids
// can be redirected to the global id.
aliasIds.addAll(streamToStateIdQ.values().stream().flatMap(Collection::stream).toList());
streamToStateIdQ.clear();
// From now on there is a single queue, keyed by the sentinel descriptor, holding the global id.
retroactiveGlobalStateId = StateIdProvider.getNextId();
streamToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>());
streamToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId);
// Sum the per-id counters BEFORE clearing them; the global id inherits all in-flight records.
final long combinedCounter = stateIdToCounter.values()
.stream()
.mapToLong(AtomicLong::get)
.sum();
stateIdToCounter.clear();
stateIdToCounter.put(retroactiveGlobalStateId, new AtomicLong(combinedCounter));
}
}
/**
 * Reads the state type off a state message, falling back to LEGACY when the message does not
 * declare one.
 *
 * @param message state message to inspect
 * @return the declared state type, or LEGACY if none is set
 */
private AirbyteStateMessage.AirbyteStateType extractStateType(final PartialAirbyteMessage message) {
  final AirbyteStateMessage.AirbyteStateType declaredType = message.getState().getType();
  // A missing type is LEGACY, which is treated the same as GLOBAL.
  return declaredType == null ? AirbyteStateMessage.AirbyteStateType.LEGACY : declaredType;
}
/**
 * 'Closes' the currently open state id of a stream: the arriving state message (and its size) is
 * attached to that id, then a fresh id is opened for the records that follow, and the message's
 * size is added to this manager's memory accounting.
 *
 * @param message state message that just arrived
 * @param sizeInBytes size of the state message
 */
private void closeState(final PartialAirbyteMessage message, final long sizeInBytes) {
  // Messages without a stream section are tracked under the global sentinel descriptor.
  final StreamDescriptor resolvedDescriptor = extractStream(message).orElse(SENTINEL_GLOBAL_DESC);
  final long closedStateId = getStateId(resolvedDescriptor);
  final ImmutablePair<PartialAirbyteMessage, Long> stateAndSize = ImmutablePair.of(message, sizeInBytes);
  stateIdToState.put(closedStateId, stateAndSize);
  registerNewStateId(resolvedDescriptor);
  allocateMemoryToState(sizeInBytes);
}
/**
 * Given the size of a state message, tracks how much memory the manager is using and requests
 * additional memory from the memory manager if needed. While the allocation is still too small
 * after a request, the method waits one second before asking again.
 *
 * @param sizeInBytes size of the state message
 * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is interrupted
 *         while waiting; the thread's interrupt flag is restored first
 */
@SuppressWarnings("BusyWait")
private void allocateMemoryToState(final long sizeInBytes) {
  while (memoryAllocated.get() < memoryUsed.get() + sizeInBytes) {
    memoryAllocated.addAndGet(memoryManager.requestMemory());
    // Stop as soon as the request covered the need; only wait when we are still short.
    if (memoryAllocated.get() >= memoryUsed.get() + sizeInBytes) {
      break;
    }
    try {
      LOGGER.debug("Insufficient memory to store state message. Allocated: {}, Used: {}, Size of State Msg: {}, Needed: {}",
          FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
          FileUtils.byteCountToDisplaySize(memoryUsed.get()),
          FileUtils.byteCountToDisplaySize(sizeInBytes),
          FileUtils.byteCountToDisplaySize(sizeInBytes - (memoryAllocated.get() - memoryUsed.get())));
      sleep(1000);
    } catch (final InterruptedException e) {
      // Restore the interrupt flag so code further up the stack can still see the interruption.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
  }
  memoryUsed.addAndGet(sizeInBytes);
  LOGGER.debug("State Manager memory usage: Allocated: {}, Used: {}, % Used {}",
      FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
      FileUtils.byteCountToDisplaySize(memoryUsed.get()),
      (double) memoryUsed.get() / memoryAllocated.get());
}
// Returns the stream descriptor carried by a state message, or empty when the message has no
// stream section (or that section has no descriptor).
private static Optional<StreamDescriptor> extractStream(final PartialAirbyteMessage message) {
  final PartialAirbyteStreamState streamState = message.getState().getStream();
  return Optional.ofNullable(streamState).map(s -> s.getStreamDescriptor());
}
// Maps an aliased (pre-conversion) state id to the global state id; any other id is returned as is.
private long getStateAfterAlias(final long stateId) {
  final boolean isAliased = aliasIds.contains(stateId);
  return isAliased ? retroactiveGlobalStateId : stateId;
}
// Starts tracking a stream: gives it an empty queue of state ids and opens its first state id.
private void registerNewStreamDescriptor(final StreamDescriptor resolvedDescriptor) {
  final LinkedList<Long> emptyStateIdQueue = new LinkedList<>();
  streamToStateIdQ.put(resolvedDescriptor, emptyStateIdQueue);
  registerNewStateId(resolvedDescriptor);
}
// Opens a new state id for the stream: appends it to the stream's queue and gives it a zero counter.
private void registerNewStateId(final StreamDescriptor resolvedDescriptor) {
  final long stateId = StateIdProvider.getNextId();
  final LinkedList<Long> stateIdQueue = streamToStateIdQ.get(resolvedDescriptor);
  stateIdQueue.add(stateId);
  final AtomicLong freshCounter = new AtomicLong(0);
  stateIdToCounter.put(stateId, freshCounter);
}
/**
 * Hands out state ids from a single, ever-increasing counter shared by the whole class, which
 * keeps internal tracking simple. The counter is a plain static field with no synchronization.
 */
private static class StateIdProvider {

  private static long pk = 0;

  /** @return the next unused id; the first call returns 0. */
  public static long getNextId() {
    final long issuedId = pk;
    pk = issuedId + 1;
    return issuedId;
  }

}
}

View File

@@ -0,0 +1,312 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordSizeEstimator;
import io.airbyte.integrations.destination_async.buffers.BufferManager;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.protocol.models.Field;
import io.airbyte.protocol.models.JsonSchemaType;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.v0.AirbyteStreamState;
import io.airbyte.protocol.models.v0.CatalogHelpers;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.io.IOException;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.apache.commons.lang.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
class AsyncStreamConsumerTest {
// Value passed as the second argument of every consumer.accept call in this test class.
private static final int RECORD_SIZE_20_BYTES = 20;
// Namespace and stream names shared by the catalog, the generated records and the state messages.
private static final String SCHEMA_NAME = "public";
private static final String STREAM_NAME = "id_and_name";
private static final String STREAM_NAME2 = STREAM_NAME + 2;
// Descriptor of the first stream; both state messages below refer to it.
private static final StreamDescriptor STREAM1_DESC = new StreamDescriptor()
.withNamespace(SCHEMA_NAME)
.withName(STREAM_NAME);
// Catalog with two streams (STREAM_NAME and STREAM_NAME2), each with an "id" and a "name" field.
private static final ConfiguredAirbyteCatalog CATALOG = new ConfiguredAirbyteCatalog().withStreams(List.of(
CatalogHelpers.createConfiguredAirbyteStream(
STREAM_NAME,
SCHEMA_NAME,
Field.of("id", JsonSchemaType.NUMBER),
Field.of("name", JsonSchemaType.STRING)),
CatalogHelpers.createConfiguredAirbyteStream(
STREAM_NAME2,
SCHEMA_NAME,
Field.of("id", JsonSchemaType.NUMBER),
Field.of("name", JsonSchemaType.STRING))));
// Two STREAM-type state messages for STREAM1_DESC that differ only in their state payload (1 vs 2).
private static final AirbyteMessage STATE_MESSAGE1 = new AirbyteMessage()
.withType(Type.STATE)
.withState(new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(1))));
private static final AirbyteMessage STATE_MESSAGE2 = new AirbyteMessage()
.withType(Type.STATE)
.withState(new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(2))));
// Consumer under test and its collaborators; all of them are (re)created in setup().
private AsyncStreamConsumer consumer;
private OnStartFunction onStart;
private DestinationFlushFunction flushFunction;
private OnCloseFunction onClose;
private Consumer<AirbyteMessage> outputRecordCollector;
private FlushFailure flushFailure;
/**
 * Rebuilds the consumer before each test with mocked start/close/flush functions, a mocked output
 * collector and a mocked {@link FlushFailure}, on top of a default {@link BufferManager}.
 */
@SuppressWarnings("unchecked")
@BeforeEach
void setup() {
onStart = mock(OnStartFunction.class);
onClose = mock(OnCloseFunction.class);
flushFunction = mock(DestinationFlushFunction.class);
outputRecordCollector = mock(Consumer.class);
flushFailure = mock(FlushFailure.class);
consumer = new AsyncStreamConsumer(
outputRecordCollector,
onStart,
onClose,
flushFunction,
CATALOG,
new BufferManager(),
flushFailure);
// Stubbed after the consumer is built.
// NOTE(review): presumably the batch size is only read once the consumer is started — confirm
// before reordering this line.
when(flushFunction.getOptimalBatchSizeBytes()).thenReturn(10_000L);
}
/**
 * Records for one stream followed by a single state message: every record must reach the flush
 * function and the state message must be handed to the output collector.
 */
@Test
void test1StreamWith1State() throws Exception {
  final List<PartialAirbyteMessage> records = generateRecords(1_000);
  final String serializedState = Jsons.serialize(STATE_MESSAGE1);

  consumer.start();
  consumeRecords(consumer, records);
  consumer.accept(serializedState, RECORD_SIZE_20_BYTES);
  consumer.close();

  verifyStartAndClose();
  verifyRecords(STREAM_NAME, SCHEMA_NAME, records);
  verify(outputRecordCollector).accept(STATE_MESSAGE1);
}
/**
 * Records for one stream followed by two state messages: every record must reach the flush function
 * and the second state message must be handed to the output collector exactly once.
 */
@Test
void test1StreamWith2State() throws Exception {
  final List<PartialAirbyteMessage> records = generateRecords(1_000);
  final String firstState = Jsons.serialize(STATE_MESSAGE1);
  final String secondState = Jsons.serialize(STATE_MESSAGE2);

  consumer.start();
  consumeRecords(consumer, records);
  consumer.accept(firstState, RECORD_SIZE_20_BYTES);
  consumer.accept(secondState, RECORD_SIZE_20_BYTES);
  consumer.close();

  verifyStartAndClose();
  verifyRecords(STREAM_NAME, SCHEMA_NAME, records);
  verify(outputRecordCollector, times(1)).accept(STATE_MESSAGE2);
}
/**
 * Records for one stream and no state message at all: every record must still reach the flush
 * function once the consumer is closed.
 */
@Test
void test1StreamWith0State() throws Exception {
  final List<PartialAirbyteMessage> records = generateRecords(1_000);

  consumer.start();
  consumeRecords(consumer, records);
  consumer.close();

  verifyStartAndClose();
  verifyRecords(STREAM_NAME, SCHEMA_NAME, records);
}
/**
 * Placeholder: currently only checks that a consumer can be started and closed again.
 * TODO: add the assertions the name promises, or remove this test in favour of testBackPressure.
 */
@Test
void testShouldBlockWhenQueuesAreFull() throws Exception {
  consumer.start();
  // close what was started so this test does not leave a running consumer behind for later tests.
  consumer.close();
}
/*
 * Tests that the consumer will block when the buffer is full. Achieves this by setting optimal
 * batch size to 0, so the flush worker never actually pulls anything from the queue.
 *
 * Records are submitted one at a time on a helper thread; the first accept call that does not
 * return within one second is taken as proof that the consumer is applying back pressure.
 */
@Test
void testBackPressure() throws Exception {
  flushFunction = mock(DestinationFlushFunction.class);
  flushFailure = mock(FlushFailure.class);
  consumer = new AsyncStreamConsumer(
      m -> {},
      () -> {},
      () -> {},
      flushFunction,
      CATALOG,
      new BufferManager(1024 * 10),
      flushFailure);
  when(flushFunction.getOptimalBatchSizeBytes()).thenReturn(0L);

  final AtomicLong recordCount = new AtomicLong();

  // NOTE(review): the consumer is deliberately not closed here; presumably close() would wait for
  // buffers that are never drained in this setup — confirm before adding a close() call.
  consumer.start();

  final ExecutorService executor = Executors.newSingleThreadExecutor();
  try {
    while (true) {
      final Future<?> future = executor.submit(() -> {
        try {
          consumer.accept(Jsons.serialize(new AirbyteMessage()
              .withType(Type.RECORD)
              .withRecord(new AirbyteRecordMessage()
                  .withStream(STREAM_NAME)
                  .withNamespace(SCHEMA_NAME)
                  .withEmittedAt(Instant.now().toEpochMilli())
                  .withData(Jsons.jsonNode(recordCount.getAndIncrement())))),
              RECORD_SIZE_20_BYTES);
        } catch (final Exception e) {
          throw new RuntimeException(e);
        }
      });

      try {
        future.get(1, TimeUnit.SECONDS);
      } catch (final TimeoutException e) {
        future.cancel(true); // Stop the operation running in thread
        break;
      }
    }
  } finally {
    // always release the helper thread, including when future.get fails with something other than
    // a timeout.
    executor.shutdownNow();
  }

  assertTrue(recordCount.get() < 1000, String.format("Record count was %s", recordCount.get()));
}
/**
 * Checks that an exception recorded in the {@link FlushFailure} is rethrown to the caller of the
 * consumer.
 */
@Nested
class ErrorHandling {

  /**
   * The failure flag is reported as false on the first check and true afterwards: the first accept
   * call succeeds and the second one throws the recorded exception.
   */
  @Test
  void testErrorOnAccept() throws Exception {
    when(flushFailure.isFailed()).thenReturn(false).thenReturn(true);
    when(flushFailure.getException()).thenReturn(new IOException("test exception"));

    final var recordMessage = new AirbyteMessage()
        .withType(Type.RECORD)
        .withRecord(new AirbyteRecordMessage()
            .withStream(STREAM_NAME)
            .withNamespace(SCHEMA_NAME)
            .withEmittedAt(Instant.now().toEpochMilli())
            .withData(Jsons.deserialize("")));

    consumer.start();
    consumer.accept(Jsons.serialize(recordMessage), RECORD_SIZE_20_BYTES);
    assertThrows(IOException.class, () -> consumer.accept(Jsons.serialize(recordMessage), RECORD_SIZE_20_BYTES));
  }

  /**
   * With the failure flag already set, closing the consumer throws the recorded exception.
   */
  @Test
  void testErrorOnClose() throws Exception {
    when(flushFailure.isFailed()).thenReturn(true);
    when(flushFailure.getException()).thenReturn(new IOException("test exception"));

    consumer.start();
    assertThrows(IOException.class, consumer::close);
  }

}
/**
 * Feeds the serialized form of every record to the consumer, in iteration order.
 *
 * @param consumer consumer under test
 * @param records records to feed
 * @throws RuntimeException wrapping whatever the consumer throws while accepting a record
 */
private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<PartialAirbyteMessage> records) {
  for (final PartialAirbyteMessage message : records) {
    try {
      consumer.accept(message.getSerialized(), RECORD_SIZE_20_BYTES);
    } catch (final Exception e) {
      throw new RuntimeException(e);
    }
  }
}
// Builds RECORD messages for STREAM_NAME until adding one more would push the estimated payload size
// past the target. The record that would exceed the target is not included.
// NOTE (carried over from the previous version, not verified here): records come in chunks of 160
// bytes as measured by RecordSizeEstimator.
@SuppressWarnings("SameParameterValue")
private static List<PartialAirbyteMessage> generateRecords(final long targetSizeInBytes) {
  final List<PartialAirbyteMessage> records = Lists.newArrayList();
  long accumulatedBytes = 0;
  int sequence = 0;
  while (true) {
    final JsonNode data = Jsons.jsonNode(ImmutableMap.of(
        "id", RandomStringUtils.randomAlphabetic(7),
        "name", "human " + String.format("%8d", sequence)));
    accumulatedBytes += RecordSizeEstimator.getStringByteSize(data);
    if (accumulatedBytes > targetSizeInBytes) {
      return records;
    }

    final String serialized = Jsons.serialize(new AirbyteMessage()
        .withType(Type.RECORD)
        .withRecord(new AirbyteRecordMessage()
            .withStream(STREAM_NAME)
            .withNamespace(SCHEMA_NAME)
            .withData(data)));
    records.add(new PartialAirbyteMessage()
        .withType(Type.RECORD)
        .withRecord(new PartialAirbyteRecordMessage()
            .withStream(STREAM_NAME)
            .withNamespace(SCHEMA_NAME))
        .withSerialized(serialized));
    sequence++;
  }
}
/**
 * Asserts that the mocked start and close functions were each invoked (Mockito's default: exactly
 * once).
 */
private void verifyStartAndClose() throws Exception {
verify(onStart).call();
verify(onClose).call();
}
/**
 * Asserts that the flush function was called at least once for the given stream and that the
 * records it received, concatenated in call order, equal the expected records.
 *
 * @param streamName name of the stream expected in every flush call
 * @param namespace namespace of the stream expected in every flush call
 * @param expectedRecords records the flush function should have received, in order
 */
@SuppressWarnings({"unchecked", "SameParameterValue"})
private void verifyRecords(final String streamName, final String namespace, final Collection<PartialAirbyteMessage> expectedRecords)
    throws Exception {
  final ArgumentCaptor<Stream<PartialAirbyteMessage>> flushedBatches = ArgumentCaptor.forClass(Stream.class);
  final StreamDescriptor expectedStream = new StreamDescriptor().withNamespace(namespace).withName(streamName);

  verify(flushFunction, atLeast(1)).flush(eq(expectedStream), flushedBatches.capture());

  // the records may have been split across several flush calls, so merge every captured batch
  // into one list before comparing.
  final List<PartialAirbyteMessage> flushedRecords = flushedBatches
      .getAllValues()
      .stream()
      .flatMap(batch -> batch)
      .toList();

  final List<PartialAirbyteMessage> expected = expectedRecords.stream().toList();
  assertEquals(expected, flushedRecords);
}
}

View File

@@ -0,0 +1,73 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.airbyte.integrations.destination_async.buffers.BufferDequeue;
import io.airbyte.integrations.destination_async.buffers.MemoryAwareMessageBatch;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class FlushWorkersTest {

  /**
   * An exception thrown by the flush function must end up in the shared {@link FlushFailure} by the
   * time the workers have been closed.
   */
  @Test
  void testErrorHandling() throws Exception {
    final AtomicBoolean flushAttempted = new AtomicBoolean(false);
    final var stream = new StreamDescriptor().withName("test");

    final var dequeue = mock(BufferDequeue.class);
    when(dequeue.getBufferedStreams()).thenReturn(Set.of(stream));
    when(dequeue.take(stream, 1000)).thenReturn(new MemoryAwareMessageBatch(List.of(), 10, null, null));
    when(dequeue.getQueueSizeBytes(stream)).thenReturn(Optional.of(10L));
    // report one queued record until the flush function has been called, and none afterwards.
    when(dequeue.getQueueSizeInRecords(stream)).thenAnswer(ignored -> Optional.of(flushAttempted.get() ? 0L : 1L));

    final var flushFailure = new FlushFailure();
    final var workers = new FlushWorkers(dequeue, new ErrorOnFlush(flushAttempted), m -> {}, flushFailure, mock(GlobalAsyncStateManager.class));
    workers.start();
    workers.close();

    Assertions.assertTrue(flushFailure.isFailed());
    Assertions.assertEquals(IOException.class, flushFailure.getException().getClass());
  }

  /**
   * Flush function that records that it was called and then always fails with an
   * {@link IOException}.
   */
  private static class ErrorOnFlush implements DestinationFlushFunction {

    private final AtomicBoolean called;

    public ErrorOnFlush(final AtomicBoolean hasThrownError) {
      called = hasThrownError;
    }

    @Override
    public void flush(final StreamDescriptor desc, final Stream<PartialAirbyteMessage> stream) throws Exception {
      called.set(true);
      throw new IOException("Error on flush");
    }

    @Override
    public long getOptimalBatchSizeBytes() {
      return 1000;
    }

  }

}

View File

@@ -8,9 +8,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
@@ -19,14 +19,14 @@ import org.junit.jupiter.api.Test;
public class BufferDequeueTest {
private static final int RECORD_SIZE_20_BYTES = 20;
public static final String RECORD_20_BYTES = "abc";
private static final String STREAM_NAME = "stream1";
private static final StreamDescriptor STREAM_DESC = new StreamDescriptor().withName(STREAM_NAME);
private static final AirbyteMessage RECORD_MSG_20_BYTES = new AirbyteMessage()
private static final PartialAirbyteMessage RECORD_MSG_20_BYTES = new PartialAirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withData(Jsons.jsonNode(RECORD_20_BYTES)));
.withRecord(new PartialAirbyteRecordMessage()
.withStream(STREAM_NAME));
@Nested
class Take {
@@ -37,15 +37,17 @@ public class BufferDequeueTest {
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
// total size of records is 80, so we expect 50 to get us 2 records (prefer to under-pull records
// than over-pull).
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, 50)) {
assertEquals(2, take.getData().toList().size());
assertEquals(2, take.getData().size());
// verify it only took the records from the queue that it actually returned.
assertEquals(2, dequeue.getQueueSizeInRecords(STREAM_DESC).orElseThrow());
} catch (final Exception e) {
throw new RuntimeException(e);
}
@@ -57,12 +59,12 @@ public class BufferDequeueTest {
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, 60)) {
assertEquals(3, take.getData().toList().size());
assertEquals(3, take.getData().size());
} catch (final Exception e) {
throw new RuntimeException(e);
}
@@ -74,11 +76,11 @@ public class BufferDequeueTest {
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
try (final MemoryAwareMessageBatch take = dequeue.take(STREAM_DESC, Long.MAX_VALUE)) {
assertEquals(2, take.getData().toList().size());
assertEquals(2, take.getData().size());
} catch (final Exception e) {
throw new RuntimeException(e);
}
@@ -92,11 +94,13 @@ public class BufferDequeueTest {
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(STREAM_DESC, RECORD_MSG_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES);
final var secondStream = new StreamDescriptor().withName("stream_2");
enqueue.addRecord(secondStream, RECORD_MSG_20_BYTES);
final PartialAirbyteMessage recordFromSecondStream = Jsons.clone(RECORD_MSG_20_BYTES);
recordFromSecondStream.getRecord().withStream(secondStream.getName());
enqueue.addRecord(recordFromSecondStream, RECORD_SIZE_20_BYTES);
assertEquals(60, dequeue.getTotalGlobalQueueSizeBytes());

View File

@@ -5,32 +5,35 @@
package io.airbyte.integrations.destination_async.buffers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
import io.airbyte.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.jupiter.api.Test;
public class BufferEnqueueTest {
private static final int RECORD_SIZE_20_BYTES = 20;
@Test
void testAddRecordShouldAdd() {
final var twoMB = 2 * 1024 * 1024;
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>>();
final var enqueue = new BufferEnqueue(new GlobalMemoryManager(twoMB), streamToBuffer);
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, StreamAwareQueue>();
final var enqueue = new BufferEnqueue(new GlobalMemoryManager(twoMB), streamToBuffer, mock(GlobalAsyncStateManager.class));
final var streamName = "stream";
final var stream = new StreamDescriptor().withName(streamName);
final var record = new AirbyteMessage()
final var record = new PartialAirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(streamName)
.withData(Jsons.jsonNode(BufferDequeueTest.RECORD_20_BYTES)));
.withRecord(new PartialAirbyteRecordMessage()
.withStream(streamName));
enqueue.addRecord(stream, record);
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
assertEquals(1, streamToBuffer.get(stream).size());
assertEquals(20L, streamToBuffer.get(stream).getCurrentMemoryUsage());
@@ -39,20 +42,19 @@ public class BufferEnqueueTest {
@Test
public void testAddRecordShouldExpand() {
final var oneKb = 1024;
final var initialQueueSizeBytes = 20;
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, MemoryBoundedLinkedBlockingQueue<AirbyteMessage>>();
final var enqueue = new BufferEnqueue(initialQueueSizeBytes, new GlobalMemoryManager(oneKb), streamToBuffer);
final var streamToBuffer = new ConcurrentHashMap<StreamDescriptor, StreamAwareQueue>();
final var enqueue =
new BufferEnqueue(new GlobalMemoryManager(oneKb), streamToBuffer, mock(GlobalAsyncStateManager.class));
final var streamName = "stream";
final var stream = new StreamDescriptor().withName(streamName);
final var record = new AirbyteMessage()
final var record = new PartialAirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(streamName)
.withData(Jsons.jsonNode(BufferDequeueTest.RECORD_20_BYTES)));
.withRecord(new PartialAirbyteRecordMessage()
.withStream(streamName));
enqueue.addRecord(stream, record);
enqueue.addRecord(stream, record);
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
enqueue.addRecord(record, RECORD_SIZE_20_BYTES);
assertEquals(2, streamToBuffer.get(stream).size());
assertEquals(40, streamToBuffer.get(stream).getCurrentMemoryUsage());

View File

@@ -26,28 +26,6 @@ public class MemoryBoundedLinkedBlockingQueueTest {
assertEquals("abc", item.item());
}
@Test
void test() throws InterruptedException {
final MemoryBoundedLinkedBlockingQueue<String> queue = new MemoryBoundedLinkedBlockingQueue<>(1024);
assertEquals(0, queue.getCurrentMemoryUsage());
assertNull(queue.getTimeOfLastMessage().orElse(null));
queue.offer("abc", 6);
queue.offer("abc", 6);
queue.offer("abc", 6);
assertEquals(18, queue.getCurrentMemoryUsage());
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
queue.take();
queue.take();
queue.take();
assertEquals(0, queue.getCurrentMemoryUsage());
assertNotNull(queue.getTimeOfLastMessage().orElse(null));
}
@Test
void testBlocksOnFullMemory() throws InterruptedException {
final MemoryBoundedLinkedBlockingQueue<String> queue = new MemoryBoundedLinkedBlockingQueue<>(10);

View File

@@ -0,0 +1,38 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.buffers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import org.junit.jupiter.api.Test;
public class StreamAwareQueueTest {

  /**
   * Memory usage grows by the size given to each offer and returns to zero once everything has been
   * taken; the time of the last message is absent before the first offer and present afterwards.
   */
  @Test
  void test() throws InterruptedException {
    final StreamAwareQueue queue = new StreamAwareQueue(1024);

    // nothing offered yet
    assertEquals(0, queue.getCurrentMemoryUsage());
    assertNull(queue.getTimeOfLastMessage().orElse(null));

    // three messages of 6 bytes each, with 1, 2 and 3 as the third argument
    for (int seq = 1; seq <= 3; seq++) {
      queue.offer(new PartialAirbyteMessage(), 6, seq);
    }
    assertEquals(18, queue.getCurrentMemoryUsage());
    assertNotNull(queue.getTimeOfLastMessage().orElse(null));

    // drain the queue again
    for (int taken = 0; taken < 3; taken++) {
      queue.take();
    }
    assertEquals(0, queue.getCurrentMemoryUsage());
    assertNotNull(queue.getTimeOfLastMessage().orElse(null));
  }

}

View File

@@ -0,0 +1,206 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.integrations.destination_async.state;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.airbyte.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStateMessage;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteStreamState;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
class GlobalAsyncStateManagerTest {
// Memory budget handed to the GlobalMemoryManager in every test.
private static final long TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES = 100 * 1024 * 1024; // 100MB
// Size passed to trackState for every state message.
private static final long STATE_MSG_SIZE = 1000;
private static final String STREAM_NAME = "id_and_name";
private static final String STREAM_NAME2 = STREAM_NAME + 2;
private static final String STREAM_NAME3 = STREAM_NAME + 3;
// Descriptors for three streams; none of them sets a namespace.
private static final StreamDescriptor STREAM1_DESC = new StreamDescriptor()
.withName(STREAM_NAME);
private static final StreamDescriptor STREAM2_DESC = new StreamDescriptor()
.withName(STREAM_NAME2);
private static final StreamDescriptor STREAM3_DESC = new StreamDescriptor()
.withName(STREAM_NAME3);
// Two GLOBAL-type state messages, built identically.
private static final PartialAirbyteMessage GLOBAL_STATE_MESSAGE1 = new PartialAirbyteMessage()
.withType(Type.STATE)
.withState(new PartialAirbyteStateMessage()
.withType(AirbyteStateType.GLOBAL));
private static final PartialAirbyteMessage GLOBAL_STATE_MESSAGE2 = new PartialAirbyteMessage()
.withType(Type.STATE)
.withState(new PartialAirbyteStateMessage()
.withType(AirbyteStateType.GLOBAL));
// Two STREAM-type state messages for STREAM1_DESC, built identically.
private static final PartialAirbyteMessage STREAM1_STATE_MESSAGE1 = new PartialAirbyteMessage()
.withType(Type.STATE)
.withState(new PartialAirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new PartialAirbyteStreamState().withStreamDescriptor(STREAM1_DESC)));
private static final PartialAirbyteMessage STREAM1_STATE_MESSAGE2 = new PartialAirbyteMessage()
.withType(Type.STATE)
.withState(new PartialAirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new PartialAirbyteStreamState().withStreamDescriptor(STREAM1_DESC)));
/**
 * Two records of the same stream share one state id; after both are decremented nothing is flushed
 * until a state message has been tracked.
 */
@Test
void testBasic() {
  final GlobalAsyncStateManager manager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));

  final var idOfFirstRecord = manager.getStateIdAndIncrementCounter(STREAM1_DESC);
  final var idOfSecondRecord = manager.getStateIdAndIncrementCounter(STREAM1_DESC);
  assertEquals(idOfFirstRecord, idOfSecondRecord);

  manager.decrement(idOfFirstRecord, 2);
  // no state message has been tracked so far, so there is nothing to emit.
  final var flushedBeforeState = manager.flushStates();
  assertEquals(0, flushedBeforeState.size());

  manager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
  final var flushedAfterState = manager.flushStates();
  assertEquals(List.of(STREAM1_STATE_MESSAGE1), flushedAfterState);
}
@Nested
class GlobalState {
@Test
void testEmptyQueuesGlobalState() {
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
// GLOBAL
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
assertThrows(IllegalArgumentException.class, () -> stateManager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE));
}
@Test
void testConversion() {
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
final var preConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
final var preConvertId2 = simulateIncomingRecords(STREAM3_DESC, 10, stateManager);
assertEquals(3, Set.of(preConvertId0, preConvertId1, preConvertId2).size());
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
// Since this is actually a global state, we can only flush after all streams are done.
stateManager.decrement(preConvertId0, 10);
assertEquals(List.of(), stateManager.flushStates());
stateManager.decrement(preConvertId1, 10);
assertEquals(List.of(), stateManager.flushStates());
stateManager.decrement(preConvertId2, 10);
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
}
@Test
void testCorrectFlushingOneStream() {
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
stateManager.decrement(preConvertId0, 10);
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
final var afterConvertId1 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
stateManager.trackState(GLOBAL_STATE_MESSAGE2, STATE_MSG_SIZE);
stateManager.decrement(afterConvertId1, 10);
assertEquals(List.of(GLOBAL_STATE_MESSAGE2), stateManager.flushStates());
}
@Test
void testCorrectFlushingManyStreams() {
final GlobalAsyncStateManager stateManager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));
final var preConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
final var preConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
assertNotEquals(preConvertId0, preConvertId1);
stateManager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE);
stateManager.decrement(preConvertId0, 10);
stateManager.decrement(preConvertId1, 10);
assertEquals(List.of(GLOBAL_STATE_MESSAGE1), stateManager.flushStates());
final var afterConvertId0 = simulateIncomingRecords(STREAM1_DESC, 10, stateManager);
final var afterConvertId1 = simulateIncomingRecords(STREAM2_DESC, 10, stateManager);
assertEquals(afterConvertId0, afterConvertId1);
stateManager.trackState(GLOBAL_STATE_MESSAGE2, STATE_MSG_SIZE);
stateManager.decrement(afterConvertId0, 20);
assertEquals(List.of(GLOBAL_STATE_MESSAGE2), stateManager.flushStates());
}
}
/**
 * Behaviour of the state manager when the source emits STREAM state messages.
 */
@Nested
class PerStreamState {

  /**
   * With no records pending, a stream state is flushed right away; a global state is rejected
   * afterwards.
   */
  @Test
  void testEmptyQueues() {
    final GlobalAsyncStateManager manager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));

    // STREAM
    manager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
    assertEquals(List.of(STREAM1_STATE_MESSAGE1), manager.flushStates());

    assertThrows(IllegalArgumentException.class, () -> manager.trackState(GLOBAL_STATE_MESSAGE1, STATE_MSG_SIZE));
  }

  /**
   * One stream, two rounds of records each followed by a stream state: each state is flushed once
   * its round has been decremented.
   */
  @Test
  void testCorrectFlushingOneStream() {
    final GlobalAsyncStateManager manager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));

    final var firstRoundId = simulateIncomingRecords(STREAM1_DESC, 3, manager);
    manager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
    manager.decrement(firstRoundId, 3);
    assertEquals(List.of(STREAM1_STATE_MESSAGE1), manager.flushStates());

    final var secondRoundId = simulateIncomingRecords(STREAM1_DESC, 10, manager);
    manager.trackState(STREAM1_STATE_MESSAGE2, STATE_MSG_SIZE);
    manager.decrement(secondRoundId, 10);
    assertEquals(List.of(STREAM1_STATE_MESSAGE2), manager.flushStates());
  }

  /**
   * Two streams with pending records, state messages only for the first one: its states are
   * flushed while the second stream is decremented in two steps.
   */
  @Test
  void testCorrectFlushingManyStream() {
    final GlobalAsyncStateManager manager = new GlobalAsyncStateManager(new GlobalMemoryManager(TOTAL_QUEUES_MAX_SIZE_LIMIT_BYTES));

    final var idStream1 = simulateIncomingRecords(STREAM1_DESC, 3, manager);
    final var idStream2 = simulateIncomingRecords(STREAM2_DESC, 7, manager);

    manager.trackState(STREAM1_STATE_MESSAGE1, STATE_MSG_SIZE);
    manager.decrement(idStream1, 3);
    assertEquals(List.of(STREAM1_STATE_MESSAGE1), manager.flushStates());

    manager.decrement(idStream2, 4);
    assertEquals(List.of(), manager.flushStates());
    manager.trackState(STREAM1_STATE_MESSAGE2, STATE_MSG_SIZE);
    manager.decrement(idStream2, 3);
    // only flush state if counter is 0.
    assertEquals(List.of(STREAM1_STATE_MESSAGE2), manager.flushStates());
  }

}
/**
 * Asks the manager for a state id {@code count} times for the given stream, as if that many records
 * had arrived.
 *
 * @param desc stream the simulated records belong to
 * @param count number of records to simulate
 * @param manager state manager under test
 * @return the id returned by the last call, or 0 when {@code count} is not positive
 */
private static long simulateIncomingRecords(final StreamDescriptor desc, final long count, final GlobalAsyncStateManager manager) {
  long lastStateId = 0L;
  int simulated = 0;
  while (simulated < count) {
    lastStateId = manager.getStateIdAndIncrementCounter(desc);
    simulated++;
  }
  return lastStateId;
}
}

View File

@@ -4,13 +4,14 @@
package io.airbyte.integrations.destination.staging;
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.destination.jdbc.WriteConfig;
import io.airbyte.integrations.destination.record_buffer.FileBuffer;
import io.airbyte.integrations.destination.s3.csv.CsvSerializedBuffer;
import io.airbyte.integrations.destination.s3.csv.StagingDatabaseCsvSheetGenerator;
import io.airbyte.integrations.destination_async.DestinationFlushFunction;
import io.airbyte.protocol.models.Jsons;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.StreamDescriptor;
@@ -41,27 +42,22 @@ class AsyncFlush implements DestinationFlushFunction {
this.catalog = catalog;
}
// todo(davin): exceptions are too broad.
@Override
public void flush(final StreamDescriptor decs, final Stream<AirbyteMessage> stream) throws Exception {
// write this to a file - serilizable buffer?
// where do we create all the write configs?
log.info("Starting staging flush..");
CsvSerializedBuffer writer = null;
public void flush(final StreamDescriptor decs, final Stream<PartialAirbyteMessage> stream) throws Exception {
final CsvSerializedBuffer writer;
try {
writer = new CsvSerializedBuffer(
new FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX),
new StagingDatabaseCsvSheetGenerator(),
true);
log.info("Converting to CSV file..");
// reassign as lambdas require references to be final.
final CsvSerializedBuffer finalWriter = writer;
stream.forEach(record -> {
try {
// todo(davin): handle non-record airbyte messages.
finalWriter.accept(record.getRecord());
// todo (cgardens) - most writers just go ahead and re-serialize the contents of the record message.
// we should either just pass the raw string or at least have a way to do that and create a default
// impl that maintains backwards compatible behavior.
writer.accept(Jsons.deserialize(record.getSerialized(), AirbyteMessage.class).getRecord());
} catch (final Exception e) {
throw new RuntimeException(e);
}
@@ -70,9 +66,8 @@ class AsyncFlush implements DestinationFlushFunction {
throw new RuntimeException(e);
}
log.info("Converted to CSV file..");
writer.flush();
log.info("Flushing buffer for stream {} ({}) to staging", decs.getName(), FileUtils.byteCountToDisplaySize(writer.getByteCount()));
log.info("Flushing CSV buffer for stream {} ({}) to staging", decs.getName(), FileUtils.byteCountToDisplaySize(writer.getByteCount()));
if (!streamDescToWriteConfig.containsKey(decs)) {
throw new IllegalArgumentException(
String.format("Message contained record from a stream that was not in the catalog. \ncatalog: %s", Jsons.serialize(catalog)));
@@ -85,7 +80,6 @@ class AsyncFlush implements DestinationFlushFunction {
stagingOperations.getStagingPath(StagingConsumerFactory.RANDOM_CONNECTION_ID, schemaName, writeConfig.getStreamName(),
writeConfig.getWriteDatetime());
try {
log.info("Starting upload to stage..");
final String stagedFile = stagingOperations.uploadRecordsToStage(database, writer, schemaName, stageName, stagingPath);
GeneralStagingFunctions.copyIntoTableFromStage(database, stageName, stagingPath, List.of(stagedFile), writeConfig.getOutputTableName(),
schemaName,

View File

@@ -12,6 +12,7 @@ import com.google.common.base.Preconditions;
import io.airbyte.commons.exceptions.ConfigErrorException;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.SerializedAirbyteMessageConsumer;
import io.airbyte.integrations.destination.NamingConventionTransformer;
import io.airbyte.integrations.destination.buffered_stream_consumer.BufferedStreamConsumer;
import io.airbyte.integrations.destination.jdbc.WriteConfig;
@@ -78,14 +79,14 @@ public class StagingConsumerFactory {
stagingOperations::isValidData);
}
public AirbyteMessageConsumer createAsync(final Consumer<AirbyteMessage> outputRecordCollector,
final JdbcDatabase database,
final StagingOperations stagingOperations,
final NamingConventionTransformer namingResolver,
final BufferCreateFunction onCreateBuffer,
final JsonNode config,
final ConfiguredAirbyteCatalog catalog,
final boolean purgeStagingData) {
public SerializedAirbyteMessageConsumer createAsync(final Consumer<AirbyteMessage> outputRecordCollector,
final JdbcDatabase database,
final StagingOperations stagingOperations,
final NamingConventionTransformer namingResolver,
final BufferCreateFunction onCreateBuffer,
final JsonNode config,
final ConfiguredAirbyteCatalog catalog,
final boolean purgeStagingData) {
final List<WriteConfig> writeConfigs = createWriteConfigs(namingResolver, config, catalog);
final var streamDescToWriteConfig = streamDescToWriteConfig(writeConfigs);
final var flusher = new AsyncFlush(streamDescToWriteConfig, stagingOperations, database, catalog);