Socket support for Dataflow (#65606)
@@ -6,7 +6,7 @@ import javax.xml.xpath.XPathFactory
import org.w3c.dom.Document

allprojects {
    version = "0.1.20"
    version = "0.1.21"
    apply plugin: 'java-library'
    apply plugin: 'maven-publish'

@@ -1,5 +1,9 @@
**Load CDK**

## Version 0.1.21

* **Changed:** Adds basic socket support.

## Version 0.1.20

* **Changed:** Fix hard failure edge case in stream initialization in the dataflow CDK lifecycle.

@@ -7,7 +7,7 @@ package io.airbyte.cdk.load.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.PipelineRunner
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -23,7 +23,7 @@ import kotlinx.coroutines.runBlocking
class DestinationLifecycle(
    private val destinationInitializer: DestinationWriter,
    private val destinationCatalog: DestinationCatalog,
    private val pipeline: DataFlowPipeline,
    private val pipeline: PipelineRunner,
    private val completionTracker: StreamCompletionTracker,
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {

@@ -10,15 +10,13 @@ import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap

typealias StoreKey = DestinationStream.Descriptor

@Singleton
class AggregateStore(
    private val aggFactory: AggregateFactory,
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
    memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
    private val log = KotlinLogging.logger {}

@@ -102,3 +100,11 @@ data class AggregateEntry(
        return stalenessTrigger.isComplete(ts)
    }
}

/* For testing purposes so we can mock. */
class AggregateStoreFactory(
    private val aggFactory: AggregateFactory,
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
    fun make() = AggregateStore(aggFactory, memoryAndParallelismConfig)
}

@@ -0,0 +1,126 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.config

import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStoreFactory
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.input.DataFlowPipelineInputFlow
import io.airbyte.cdk.load.dataflow.input.DestinationMessageInputFlow
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStage
import io.airbyte.cdk.load.dataflow.pipeline.PipelineCompletionHandler
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.file.ClientSocket
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Requires
import io.micronaut.context.annotation.Value
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream

/**
 * Conditionally creates input streams / sockets based on channel medium, then wires up a pipeline
 * to each input with separate aggregate stores but shared state stores.
 */
@Factory
class InputBeanFactory {
    @Requires(property = "airbyte.destination.core.data-channel.medium", value = "SOCKET")
    @Singleton
    fun sockets(
        @Value("\${airbyte.destination.core.data-channel.socket-paths}") socketPaths: List<String>,
        @Value("\${airbyte.destination.core.data-channel.socket-buffer-size-bytes}")
        bufferSizeBytes: Int,
        @Value("\${airbyte.destination.core.data-channel.socket-connection-timeout-ms}")
        socketConnectionTimeoutMs: Long,
    ): List<ClientSocket> =
        socketPaths.map {
            ClientSocket(
                socketPath = it,
                bufferSizeBytes = bufferSizeBytes,
                connectTimeoutMs = socketConnectionTimeoutMs,
            )
        }

    @Requires(property = "airbyte.destination.core.data-channel.medium", value = "SOCKET")
    @Named("inputStreams")
    @Singleton
    fun socketStreams(
        sockets: List<ClientSocket>,
    ): List<InputStream> = sockets.map(ClientSocket::openInputStream)

    @Requires(property = "airbyte.destination.core.data-channel.medium", value = "STDIO")
    @Named("inputStreams")
    @Singleton
    fun stdInStreams(): List<InputStream> = listOf(System.`in`)

    @Singleton
    fun messageFlows(
        @Named("inputStreams") inputStreams: List<InputStream>,
        deserializer: ProtocolMessageDeserializer,
    ): List<DestinationMessageInputFlow> =
        inputStreams.map {
            DestinationMessageInputFlow(
                inputStream = it,
                deserializer = deserializer,
            )
        }

    @Singleton
    fun inputFlows(
        messageFlows: List<DestinationMessageInputFlow>,
        stateStore: StateStore,
        stateKeyClient: StateKeyClient,
        completionTracker: StreamCompletionTracker,
    ): List<DataFlowPipelineInputFlow> =
        messageFlows.map {
            DataFlowPipelineInputFlow(
                inputFlow = it,
                stateStore = stateStore,
                stateKeyClient = stateKeyClient,
                completionTracker = completionTracker,
            )
        }

    @Singleton
    fun aggregateStoreFactory(
        aggFactory: AggregateFactory,
        memoryAndParallelismConfig: MemoryAndParallelismConfig,
    ) =
        AggregateStoreFactory(
            aggFactory,
            memoryAndParallelismConfig,
        )

    @Singleton
    fun pipes(
        inputFlows: List<DataFlowPipelineInputFlow>,
        @Named("parse") parse: DataFlowStage,
        @Named("flush") flush: DataFlowStage,
        @Named("state") state: DataFlowStage,
        aggregateStoreFactory: AggregateStoreFactory,
        stateHistogramStore: StateHistogramStore,
        memoryAndParallelismConfig: MemoryAndParallelismConfig,
    ): List<DataFlowPipeline> =
        inputFlows.map {
            val aggStore = aggregateStoreFactory.make()
            val aggregate = AggregateStage(aggStore)
            val completionHandler = PipelineCompletionHandler(aggStore, stateHistogramStore)

            DataFlowPipeline(
                input = it,
                parse = parse,
                aggregate = aggregate,
                flush = flush,
                state = state,
                completionHandler = completionHandler,
                memoryAndParallelismConfig = memoryAndParallelismConfig,
            )
        }
}

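The factory above is Micronaut-driven, but the wiring it performs is linear: sockets to input streams to message flows to pipeline input flows to pipelines, with a fresh aggregate store per pipeline. A hand-wired sketch of the same chain, assuming the stages and state components are already constructed (paths and sizes here are illustrative, not from the commit):

// Sketch: manual equivalent of the bean wiring above (illustrative values).
val factory = InputBeanFactory()
val sockets = factory.sockets(
    socketPaths = listOf("/tmp/airbyte-1.sock", "/tmp/airbyte-2.sock"), // hypothetical paths
    bufferSizeBytes = 8192,
    socketConnectionTimeoutMs = 5_000L,
)
val streams = factory.socketStreams(sockets)
val messageFlows = factory.messageFlows(streams, deserializer)
val inputFlows = factory.inputFlows(messageFlows, stateStore, stateKeyClient, completionTracker)
val pipes = factory.pipes(
    inputFlows = inputFlows,
    parse = parse,
    flush = flush,
    state = state,
    aggregateStoreFactory = factory.aggregateStoreFactory(aggFactory, memoryAndParallelismConfig),
    stateHistogramStore = stateHistogramStore,
    memoryAndParallelismConfig = memoryAndParallelismConfig,
)
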
@@ -14,8 +14,8 @@ import kotlin.time.Duration.Companion.minutes
 * progress.
 * - maxEstBytesPerAgg configures the estimated size of each aggregate.
 * - The max memory consumption is (maxEstBytesPerAgg * maxConcurrentAggregates) +
 * (maxEstBytesPerAgg * 2). Example with default values: (70,000,000 * 5) + (70,000,000 * 2) =
 * 350,000,000 + 140,000,000 = 490,000,000 bytes (approx 0.49 GB).
 * (maxEstBytesPerAgg * maxBufferedAggregates). Example with default values: (50,000,000 * 5) +
 * (50,000,000 * 3) = 250,000,000 + 150,000,000 = 400,000,000 bytes (approx 0.4 GB).
 * - stalenessDeadlinePerAggMs is how long we will wait to flush an aggregate if it is not
 * fulfilling the requirement of entry count or max memory.
 * - maxRecordsPerAgg configures the max number of records in an aggregate.
@@ -23,10 +23,10 @@ import kotlin.time.Duration.Companion.minutes
 */
data class MemoryAndParallelismConfig(
    val maxOpenAggregates: Int = 5,
    val maxBufferedAggregates: Int = 5,
    val maxBufferedAggregates: Int = 3,
    val stalenessDeadlinePerAgg: Duration = 5.minutes,
    val maxRecordsPerAgg: Long = 100_000L,
    val maxEstBytesPerAgg: Long = 70_000_000L,
    val maxEstBytesPerAgg: Long = 50_000_000L,
    val maxConcurrentLifecycleOperations: Int = 10,
) {
    init {

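The bound in the KDoc is simple arithmetic over the config fields; a sketch of the same calculation (the helper name is illustrative, and the KDoc's "maxConcurrentAggregates" is read here as the maxOpenAggregates field):

// Sketch: worst-case memory implied by the defaults above (hypothetical helper).
fun maxMemoryBytes(cfg: MemoryAndParallelismConfig): Long =
    cfg.maxEstBytesPerAgg * cfg.maxOpenAggregates +
        cfg.maxEstBytesPerAgg * cfg.maxBufferedAggregates

// New defaults: 50_000_000 * 5 + 50_000_000 * 3 = 400_000_000 bytes (~0.4 GB).
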
@@ -12,7 +12,7 @@ import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import jakarta.inject.Singleton
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector

@@ -22,13 +22,14 @@ import kotlinx.coroutines.flow.FlowCollector
 * Adds state ids to the input, handling the serial case where we infer the state id from a global
 * counter.
 */
@Singleton
class DataFlowPipelineInputFlow(
    private val inputFlow: Flow<DestinationMessage>,
    private val stateStore: StateStore,
    private val stateKeyClient: StateKeyClient,
    private val completionTracker: StreamCompletionTracker,
) : Flow<DataFlowStageIO> {
    val log = KotlinLogging.logger {}

    override suspend fun collect(
        collector: FlowCollector<DataFlowStageIO>,
    ) {
@@ -48,5 +49,7 @@ class DataFlowPipelineInputFlow(
                else -> Unit
            }
        }

        log.info { "Finished routing input." }
    }
}

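Roughly, the wrapper is consumed like any other Flow; a minimal sketch assuming the collaborators are provided (for example, by the factory above):

// Sketch: adapt a message flow into pipeline stage IO and drain it.
val stageInput = DataFlowPipelineInputFlow(
    inputFlow = messages, // Flow<DestinationMessage>, e.g. a DestinationMessageInputFlow
    stateStore = stateStore,
    stateKeyClient = stateKeyClient,
    completionTracker = completionTracker,
)
runBlocking { stageInput.collect { io -> println(io) } } // each element feeds the parse stage
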
@@ -7,16 +7,15 @@ package io.airbyte.cdk.load.dataflow.input
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.withContext

/** Takes bytes and emits DestinationMessages */
@Singleton
class DestinationMessageInputFlow(
    @Named("inputStream") private val inputStream: InputStream,
    private val inputStream: InputStream,
    private val deserializer: ProtocolMessageDeserializer,
) : Flow<DestinationMessage> {
    val log = KotlinLogging.logger {}
@@ -26,21 +25,29 @@ class DestinationMessageInputFlow(
    ) {
        var msgCount = 0L
        var estBytes = 0L
        inputStream
            .bufferedReader()
            .lineSequence()
            .filter { it.isNotEmpty() }
            .forEach { line ->
                val message = deserializer.deserialize(line)

                collector.emit(message)
        val reader = inputStream.bufferedReader()
        try {
            reader
                .lineSequence()
                .filter { it.isNotEmpty() }
                .forEach { line ->
                    val message = deserializer.deserialize(line)

                estBytes += line.length
                if (++msgCount % 100_000 == 0L) {
                    log.info { "Processed $msgCount messages (${estBytes/1024/1024}Mb)" }
                    collector.emit(message)

                    estBytes += line.length
                    if (++msgCount % 100_000 == 0L) {
                        log.info { "Processed $msgCount messages (${estBytes/1024/1024}Mb)" }
                    }
                }
        } finally {
            withContext(Dispatchers.IO) {
                reader.close()
                log.info { "Input stream reader closed." }
            }
        }

        log.info { "Finished processing input" }
        log.info { "Finished reading input." }
    }
}

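A minimal consumption sketch for the flow above, assuming a deserializer instance is available (the input line is illustrative):

// Sketch: emit DestinationMessages from an in-memory, newline-delimited byte stream.
val input = """{"type":"RECORD"}""".byteInputStream() // hypothetical single-line input
val messages = DestinationMessageInputFlow(inputStream = input, deserializer = deserializer)
runBlocking { messages.collect { msg -> println("received $msg") } }
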
@@ -6,31 +6,25 @@ package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import jakarta.inject.Named
import jakarta.inject.Singleton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.transform

@Singleton
class DataFlowPipeline(
    private val input: Flow<DataFlowStageIO>,
    @Named("parse") private val parse: DataFlowStage,
    @Named("aggregate") private val aggregate: AggregateStage,
    @Named("flush") private val flush: DataFlowStage,
    @Named("state") private val state: DataFlowStage,
    private val startHandler: PipelineStartHandler,
    private val parse: DataFlowStage,
    private val aggregate: AggregateStage,
    private val flush: DataFlowStage,
    private val state: DataFlowStage,
    private val completionHandler: PipelineCompletionHandler,
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
    suspend fun run() {
        input
            .onStart { startHandler.run() }
            .map(parse::apply)
            .transform { aggregate.apply(it, this) }
            .buffer(capacity = memoryAndParallelismConfig.maxBufferedAggregates)

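The run() body above is plain kotlinx.coroutines Flow composition; this self-contained sketch mirrors the same onStart/map/transform/buffer shape with toy stages (not the CDK types):

// Sketch: the operator chain used by DataFlowPipeline, with toy stages.
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    (1..10).asFlow()
        .onStart { println("pipeline starting") }          // cf. startHandler.run()
        .map { it * 2 }                                    // cf. parse stage
        .transform { if (it % 4 == 0) emit(it) }           // cf. aggregate: emits only full batches
        .buffer(capacity = 3)                              // cf. maxBufferedAggregates
        .collect { println("flush $it") }                  // cf. flush/state stages
}
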
@@ -6,34 +6,25 @@ package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope

@Singleton
class PipelineCompletionHandler(
    private val aggStore: AggregateStore,
    private val stateHistogramStore: StateHistogramStore,
    private val reconciler: StateReconciler,
) {
    private val log = KotlinLogging.logger {}

    suspend fun apply(
        cause: Throwable?,
    ) = coroutineScope {
        // shutdown the reconciler regardless of success or failure, so we don't hang
        reconciler.disable()

        if (cause != null) {
            log.error { "Destination Pipeline Completed — Exceptionally" }
            throw cause
        }

        log.info { "Destination Pipeline Completed — Successfully" }

        val remainingAggregates = aggStore.getAll()

        log.info { "Flushing ${remainingAggregates.size} final aggregates..." }
@@ -46,6 +37,6 @@ class PipelineCompletionHandler(
            }
            .awaitAll()

        reconciler.flushCompleteStates()
        log.info { "Final aggregates flushed." }
    }
}

@@ -0,0 +1,43 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch

@Singleton
class PipelineRunner(
    private val reconciler: StateReconciler,
    val pipelines: List<DataFlowPipeline>,
) {
    private val log = KotlinLogging.logger {}

    suspend fun run() = coroutineScope {
        log.info { "Destination Pipeline Starting..." }
        log.info { "Running with ${pipelines.size} input streams..." }

        reconciler.run(CoroutineScope(Dispatchers.IO))

        try {
            pipelines.map { p -> launch { p.run() } }.joinAll()
            log.info { "Individual pipelines complete..." }
        } finally {
            // shutdown the reconciler regardless of success or failure, so we don't hang
            log.info { "Disabling reconciler..." }
            reconciler.disable()
        }

        log.info { "Flushing final states..." }
        reconciler.flushCompleteStates()

        log.info { "Destination Pipeline Completed — Successfully" }
    }
}

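A rough sketch of driving the runner, assuming the reconciler and the pipelines list come from the factory wiring above:

// Sketch: run all pipelines to completion, then flush any remaining states.
val runner = PipelineRunner(reconciler, pipes)
runBlocking { runner.run() } // one coroutine per pipeline; reconciler disabled in finally
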
@@ -1,24 +0,0 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers

@Singleton
class PipelineStartHandler(
    private val reconciler: StateReconciler,
) {
    private val log = KotlinLogging.logger {}

    fun run() {
        log.info { "Destination Pipeline Starting..." }

        reconciler.run(CoroutineScope(Dispatchers.IO))
    }
}

@@ -6,12 +6,8 @@ package io.airbyte.cdk.load.dataflow.stages

import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStageIO
import jakarta.inject.Named
import jakarta.inject.Singleton
import kotlinx.coroutines.flow.FlowCollector

@Named("aggregate")
@Singleton
class AggregateStage(
    val store: AggregateStore,
) {

@@ -57,4 +57,53 @@ class ClientSocket(
        }
        log.info { "Reading from socket $socketPath complete" }
    }

    fun openInputStream(): InputStream {
        log.info { "Connecting client socket at $socketPath" }
        val socketFile = File(socketPath)
        var totalWaitMs = 0L

        while (!socketFile.exists()) {
            log.info { "Waiting for socket file to be created: $socketPath" }
            Thread.sleep(connectWaitDelayMs)
            totalWaitMs += connectWaitDelayMs
            if (totalWaitMs > connectTimeoutMs) {
                throw IllegalStateException(
                    "Socket file $socketPath not created after $connectTimeoutMs ms"
                )
            }
        }
        log.info { "Socket file $socketPath created" }

        val address = UnixDomainSocketAddress.of(socketFile.toPath())
        val openedSocket = SocketChannel.open(StandardProtocolFamily.UNIX)

        log.info { "Socket file $socketPath opened" }

        if (!openedSocket.connect(address)) {
            throw IllegalStateException("Failed to connect to socket $socketPath")
        }

        // HACK: The dockerized destination tests use this exact message
        // as a signal that it's safe to create the TCP connection to the
        // socat sidecar that feeds data into the socket. Removing it
        // will break tests. TODO: Anything else.
        log.info { "Socket file $socketPath connected for reading" }

        val inputStream = Channels.newInputStream(openedSocket).buffered(bufferSizeBytes)

        return SocketInputStream(openedSocket, inputStream)
    }
}

class SocketInputStream(
    private val socketChannel: SocketChannel,
    private val inputStream: InputStream,
) : InputStream() {
    override fun read(): Int = inputStream.read()

    override fun close() {
        inputStream.close()
        socketChannel.close()
    }
}

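A short sketch of reading from the socket stream; the path and sizes are illustrative, and use {} relies on SocketInputStream closing both the buffered stream and the channel:

// Sketch: connect to a Unix domain socket and read newline-delimited input.
val socket = ClientSocket(
    socketPath = "/tmp/airbyte.sock", // hypothetical path
    bufferSizeBytes = 8192,
    connectTimeoutMs = 5_000L,
)
socket.openInputStream().use { stream ->
    stream.bufferedReader().lineSequence().forEach(::println)
}
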
@@ -8,7 +8,7 @@ import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.PipelineRunner
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
@@ -22,7 +22,7 @@ class DestinationLifecycleTest {

    private val destinationInitializer: DestinationWriter = mockk(relaxed = true)
    private val destinationCatalog: DestinationCatalog = mockk(relaxed = true)
    private val pipeline: DataFlowPipeline = mockk(relaxed = true)
    private val pipeline: PipelineRunner = mockk(relaxed = true)
    private val completionTracker: StreamCompletionTracker = mockk(relaxed = true)
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig =
        MemoryAndParallelismConfig(maxOpenAggregates = 1, maxBufferedAggregates = 1)

@@ -0,0 +1,308 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.config

import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStoreFactory
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.input.DataFlowPipelineInputFlow
import io.airbyte.cdk.load.dataflow.input.DestinationMessageInputFlow
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStage
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.file.ClientSocket
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.mockk
import io.mockk.unmockkAll
import io.mockk.verify
import java.io.ByteArrayInputStream
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
class InputBeanFactoryTest {

    @MockK private lateinit var deserializer: ProtocolMessageDeserializer

    @MockK private lateinit var stateStore: StateStore

    @MockK private lateinit var stateKeyClient: StateKeyClient

    @MockK private lateinit var completionTracker: StreamCompletionTracker

    @MockK private lateinit var parseStage: DataFlowStage

    @MockK private lateinit var flushStage: DataFlowStage

    @MockK private lateinit var stateStage: DataFlowStage

    @MockK private lateinit var aggregateStoreFactory: AggregateStoreFactory

    @MockK private lateinit var stateHistogramStore: StateHistogramStore

    private var memoryAndParallelismConfig = MemoryAndParallelismConfig()

    private lateinit var factory: InputBeanFactory

    @BeforeEach
    fun setup() {
        factory = InputBeanFactory()
    }

    @AfterEach
    fun tearDown() {
        unmockkAll()
    }

    @Test
    fun `sockets should create ClientSocket instances from socket paths`() {
        // Given
        val socketPaths = listOf("/tmp/socket1", "/tmp/socket2", "/tmp/socket3")
        val bufferSizeBytes = 8192
        val socketConnectionTimeoutMs = 5000L

        // When
        val result =
            factory.sockets(
                socketPaths = socketPaths,
                bufferSizeBytes = bufferSizeBytes,
                socketConnectionTimeoutMs = socketConnectionTimeoutMs
            )

        // Then
        assertEquals(3, result.size)
        assertEquals("/tmp/socket1", result[0].socketPath)
        assertEquals("/tmp/socket2", result[1].socketPath)
        assertEquals("/tmp/socket3", result[2].socketPath)
    }

    @Test
    fun `sockets should create single ClientSocket for single path`() {
        // Given
        val socketPaths = listOf("/tmp/single.socket")
        val bufferSizeBytes = 16384
        val socketConnectionTimeoutMs = 10000L

        // When
        val result =
            factory.sockets(
                socketPaths = socketPaths,
                bufferSizeBytes = bufferSizeBytes,
                socketConnectionTimeoutMs = socketConnectionTimeoutMs
            )

        // Then
        assertEquals(1, result.size)
        assertEquals("/tmp/single.socket", result[0].socketPath)
    }

    @Test
    fun `socketStreams should create input streams from ClientSocket list`() {
        // Given
        val socket1 = mockk<ClientSocket>()
        val socket2 = mockk<ClientSocket>()
        val socket3 = mockk<ClientSocket>()
        val sockets = listOf(socket1, socket2, socket3)

        val mockInputStream1 = ByteArrayInputStream("stream1".toByteArray())
        val mockInputStream2 = ByteArrayInputStream("stream2".toByteArray())
        val mockInputStream3 = ByteArrayInputStream("stream3".toByteArray())

        every { socket1.openInputStream() } returns mockInputStream1
        every { socket2.openInputStream() } returns mockInputStream2
        every { socket3.openInputStream() } returns mockInputStream3

        // When
        val result = factory.socketStreams(sockets)

        // Then
        assertEquals(3, result.size)
        assertEquals(mockInputStream1, result[0])
        assertEquals(mockInputStream2, result[1])
        assertEquals(mockInputStream3, result[2])

        verify(exactly = 1) { socket1.openInputStream() }
        verify(exactly = 1) { socket2.openInputStream() }
        verify(exactly = 1) { socket3.openInputStream() }
    }

    @Test
    fun `socketStreams should handle single socket`() {
        // Given
        val socket = mockk<ClientSocket>()
        val sockets = listOf(socket)
        val mockInputStream = ByteArrayInputStream("single stream".toByteArray())

        every { socket.openInputStream() } returns mockInputStream

        // When
        val result = factory.socketStreams(sockets)

        // Then
        assertEquals(1, result.size)
        assertEquals(mockInputStream, result[0])
        verify(exactly = 1) { socket.openInputStream() }
    }

    @Test
    fun `stdInStreams should return System in`() {
        // When
        val result = factory.stdInStreams()

        // Then
        assertEquals(1, result.size)
        assertEquals(System.`in`, result[0])
    }

    @Test
    fun `messageFlows should create DestinationMessageInputFlow for each input stream`() {
        // Given
        val inputStream1 = ByteArrayInputStream("stream1".toByteArray())
        val inputStream2 = ByteArrayInputStream("stream2".toByteArray())
        val inputStreams = listOf(inputStream1, inputStream2)

        // When
        val result = factory.messageFlows(inputStreams, deserializer)

        // Then
        assertEquals(2, result.size)
        assertNotNull(result[0])
        assertNotNull(result[1])
    }

    @Test
    fun `inputFlows should create DataFlowPipelineInputFlow for each message flow`() {
        // Given
        val messageFlow1 = mockk<DestinationMessageInputFlow>()
        val messageFlow2 = mockk<DestinationMessageInputFlow>()
        val messageFlows = listOf(messageFlow1, messageFlow2)

        // When
        val result =
            factory.inputFlows(
                messageFlows = messageFlows,
                stateStore = stateStore,
                stateKeyClient = stateKeyClient,
                completionTracker = completionTracker
            )

        // Then
        assertEquals(2, result.size)
        assertNotNull(result[0])
        assertNotNull(result[1])
    }

    @Test
    fun `pipes should create DataFlowPipeline for each input flow`() {
        every { aggregateStoreFactory.make() } returns mockk()

        // Given
        val inputFlow1 = mockk<DataFlowPipelineInputFlow>()
        val inputFlow2 = mockk<DataFlowPipelineInputFlow>()
        val inputFlows = listOf(inputFlow1, inputFlow2)

        // When
        val result =
            factory.pipes(
                inputFlows = inputFlows,
                parse = parseStage,
                flush = flushStage,
                state = stateStage,
                aggregateStoreFactory = aggregateStoreFactory,
                stateHistogramStore = stateHistogramStore,
                memoryAndParallelismConfig = memoryAndParallelismConfig
            )

        // Then
        assertEquals(2, result.size)
        assertNotNull(result[0])
        assertNotNull(result[1])
    }

    @Test
    fun `pipes should create different aggregate store for each pipeline`() {
        // Given
        val inputFlow1 = mockk<DataFlowPipelineInputFlow>()
        val inputFlow2 = mockk<DataFlowPipelineInputFlow>()
        val inputFlow3 = mockk<DataFlowPipelineInputFlow>()
        val inputFlows = listOf(inputFlow1, inputFlow2, inputFlow3)

        val aggregateStore1 = mockk<AggregateStore>(relaxed = true)
        val aggregateStore2 = mockk<AggregateStore>(relaxed = true)
        val aggregateStore3 = mockk<AggregateStore>(relaxed = true)

        // Mock the factory to return different instances
        every { aggregateStoreFactory.make() } returnsMany
            listOf(aggregateStore1, aggregateStore2, aggregateStore3)

        // When
        val result =
            factory.pipes(
                inputFlows = inputFlows,
                parse = parseStage,
                flush = flushStage,
                state = stateStage,
                aggregateStoreFactory = aggregateStoreFactory,
                stateHistogramStore = stateHistogramStore,
                memoryAndParallelismConfig = memoryAndParallelismConfig
            )

        // Then
        assertEquals(3, result.size)

        // Verify that the factory was called 3 times to create 3 different stores
        verify(exactly = 3) { aggregateStoreFactory.make() }

        // Each pipeline should have received a different aggregate store
        // We can't directly verify which store went to which pipeline without accessing internals,
        // but we verified that 3 different stores were created via the factory
    }

    @Test
    fun `integration test - complete flow from socket paths to pipelines`() {
        // Given
        val mockInputStream1 = ByteArrayInputStream("stream1".toByteArray())
        val mockInputStream2 = ByteArrayInputStream("stream2".toByteArray())

        every { aggregateStoreFactory.make() } returns mockk()

        val inputStreams = listOf(mockInputStream1, mockInputStream2)

        val messageFlows = factory.messageFlows(inputStreams, deserializer)

        val inputFlows =
            factory.inputFlows(
                messageFlows = messageFlows,
                stateStore = stateStore,
                stateKeyClient = stateKeyClient,
                completionTracker = completionTracker
            )

        val pipes =
            factory.pipes(
                inputFlows = inputFlows,
                parse = parseStage,
                flush = flushStage,
                state = stateStage,
                aggregateStoreFactory = aggregateStoreFactory,
                stateHistogramStore = stateHistogramStore,
                memoryAndParallelismConfig = memoryAndParallelismConfig
            )

        // Then
        assertEquals(2, inputStreams.size)
        assertEquals(2, messageFlows.size)
        assertEquals(2, inputFlows.size)
        assertEquals(2, pipes.size)
    }
}

@@ -20,7 +20,6 @@ class DataFlowPipelineTest {
    private val aggregate = mockk<AggregateStage>()
    private val flush = mockk<DataFlowStage>()
    private val state = mockk<DataFlowStage>()
    private val startHandler = mockk<PipelineStartHandler>()
    private val completionHandler = mockk<PipelineCompletionHandler>()
    private val memoryAndParallelismConfig =
        MemoryAndParallelismConfig(
@@ -40,7 +39,6 @@ class DataFlowPipelineTest {
            aggregate,
            flush,
            state,
            startHandler,
            completionHandler,
            memoryAndParallelismConfig
        )
@@ -50,7 +48,6 @@ class DataFlowPipelineTest {
        val flushedIO = mockk<DataFlowStageIO>()
        val stateIO = mockk<DataFlowStageIO>()

        coEvery { startHandler.run() } returns Unit
        coEvery { parse.apply(initialIO) } returns parsedIO
        coEvery { aggregate.apply(parsedIO, any()) } coAnswers
            {
@@ -66,7 +63,6 @@ class DataFlowPipelineTest {

        // Then
        coVerifySequence {
            startHandler.run()
            parse.apply(initialIO)
            aggregate.apply(parsedIO, any())
            flush.apply(aggregatedIO)

@@ -9,7 +9,6 @@ import io.airbyte.cdk.load.dataflow.aggregate.AggregateEntry
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
@@ -33,8 +32,6 @@ class PipelineCompletionHandlerTest {

    @MockK private lateinit var stateHistogramStore: StateHistogramStore

    @MockK private lateinit var reconciler: StateReconciler

    private lateinit var pipelineCompletionHandler: PipelineCompletionHandler

    @BeforeEach
@@ -43,10 +40,7 @@ class PipelineCompletionHandlerTest {
            PipelineCompletionHandler(
                aggStore = aggStore,
                stateHistogramStore = stateHistogramStore,
                reconciler = reconciler
            )

        coEvery { reconciler.disable() } just Runs
    }

    @Test
@@ -58,7 +52,6 @@ class PipelineCompletionHandlerTest {
        val thrownException =
            assertThrows<RuntimeException> { pipelineCompletionHandler.apply(testException) }

        coVerify(exactly = 1) { reconciler.disable() }
        assertEquals("Test exception", thrownException.message)
    }

@@ -92,7 +85,6 @@ class PipelineCompletionHandlerTest {
        coEvery { mockAggregate1.flush() } just Runs
        coEvery { mockAggregate2.flush() } just Runs
        every { stateHistogramStore.acceptFlushedCounts(any()) } returns mockk()
        every { reconciler.flushCompleteStates() } just Runs

        // When
        pipelineCompletionHandler.apply(null)
@@ -102,23 +94,18 @@ class PipelineCompletionHandlerTest {
        coVerify(exactly = 1) { mockAggregate2.flush() }
        verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram1) }
        verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram2) }
        coVerify(exactly = 1) { reconciler.disable() }
        verify(exactly = 1) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `apply should handle empty aggregates list`() = runTest {
        // Given
        every { aggStore.getAll() } returns emptyList()
        every { reconciler.flushCompleteStates() } just Runs

        // When
        pipelineCompletionHandler.apply(null)

        // Then
        verify(exactly = 1) { aggStore.getAll() }
        coVerify(exactly = 1) { reconciler.disable() }
        verify(exactly = 1) { reconciler.flushCompleteStates() }
        verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }
    }

@@ -140,12 +127,10 @@ class PipelineCompletionHandlerTest {

        every { aggStore.getAll() } returns listOf(aggregateEntry)
        coEvery { mockAggregate.flush() } throws flushException
        every { reconciler.flushCompleteStates() } just Runs

        // When & Then
        assertThrows<RuntimeException> { pipelineCompletionHandler.apply(null) }

        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 1) { mockAggregate.flush() }
        // Note: acceptFlushedCounts should not be called if flush fails
        verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }

@@ -0,0 +1,249 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.just
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
class PipelineRunnerTest {

    @MockK(relaxed = true) private lateinit var reconciler: StateReconciler

    @MockK private lateinit var pipeline1: DataFlowPipeline

    @MockK private lateinit var pipeline2: DataFlowPipeline

    @MockK private lateinit var pipeline3: DataFlowPipeline

    private lateinit var runner: PipelineRunner

    @BeforeEach
    fun setup() {
        runner = PipelineRunner(reconciler, listOf(pipeline1, pipeline2, pipeline3))
    }

    @Test
    fun `run should execute all pipelines concurrently`() = runTest {
        // Given
        coEvery { pipeline1.run() } coAnswers { delay(100) }
        coEvery { pipeline2.run() } coAnswers { delay(50) }
        coEvery { pipeline3.run() } coAnswers { delay(75) }

        // When
        val startTime = System.currentTimeMillis()
        runner.run()
        val endTime = System.currentTimeMillis()

        // Then
        coVerify(exactly = 1) { pipeline1.run() }
        coVerify(exactly = 1) { pipeline2.run() }
        coVerify(exactly = 1) { pipeline3.run() }

        // Verify they ran concurrently (total time should be close to the longest pipeline)
        // Allow some buffer for execution overhead
        assertTrue((endTime - startTime) < 200, "Pipelines should run concurrently")

        coVerify(exactly = 1) { reconciler.run(any()) }
        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 1) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `run should disable reconciler even when pipeline fails`() = runTest {
        // Given
        val exception = RuntimeException("Pipeline failed")
        coEvery { pipeline1.run() } throws exception
        val failingRunner = PipelineRunner(reconciler, listOf(pipeline1))

        // When/Then
        val thrownException = assertThrows<RuntimeException> { runBlocking { failingRunner.run() } }

        assertEquals("Pipeline failed", thrownException.message)

        // Verify reconciler was still disabled in finally block
        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 0) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `run should execute operations in correct order`() = runTest {
        // Given
        coEvery { pipeline1.run() } just Runs
        val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))

        // When
        singlePipelineRunner.run()

        // Then - verify order of operations
        coVerifySequence {
            reconciler.run(any())
            pipeline1.run()
            reconciler.disable()
            reconciler.flushCompleteStates()
        }
    }

    @Test
    fun `run should pass correct CoroutineScope to reconciler`() = runTest {
        // Given
        var capturedScope: CoroutineScope? = null

        every { reconciler.run(any()) } answers { capturedScope = firstArg() }

        val emptyRunner = PipelineRunner(reconciler, emptyList())

        // When
        emptyRunner.run()

        // Then
        assertTrue(capturedScope != null, "CoroutineScope should be passed to reconciler")
        // Verify it's using IO dispatcher
        assertTrue(capturedScope.toString().contains("Dispatchers.IO"), "Should use IO dispatcher")
    }

    @Test
    fun `run should handle single pipeline`() = runTest {
        // Given
        coEvery { pipeline1.run() } just Runs
        val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))

        // When
        singlePipelineRunner.run()

        // Then
        coVerify(exactly = 1) { pipeline1.run() }
        coVerify(exactly = 1) { reconciler.run(any()) }
        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 1) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `run should handle reconciler disable failure`() = runTest {
        // Given
        coEvery { pipeline1.run() } just Runs
        coEvery { reconciler.disable() } throws RuntimeException("Failed to disable")
        val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))

        // When/Then
        val exception =
            assertThrows<RuntimeException> { runBlocking { singlePipelineRunner.run() } }

        assertEquals("Failed to disable", exception.message)

        // Verify pipeline was executed before the failure
        coVerify(exactly = 1) { pipeline1.run() }
        coVerify(exactly = 0) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `run should handle reconciler flushCompleteStates failure`() = runTest {
        // Given
        coEvery { pipeline1.run() } just Runs
        every { reconciler.flushCompleteStates() } throws RuntimeException("Failed to flush")
        val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))

        // When/Then
        val exception =
            assertThrows<RuntimeException> { runBlocking { singlePipelineRunner.run() } }

        assertEquals("Failed to flush", exception.message)

        // Verify everything up to flush was executed
        coVerify(exactly = 1) { pipeline1.run() }
        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 1) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `pipelines property should return the list of pipelines`() {
        // When
        val result = runner.pipelines

        // Then
        assertEquals(3, result.size)
        assertEquals(pipeline1, result[0])
        assertEquals(pipeline2, result[1])
        assertEquals(pipeline3, result[2])
    }

    @Test
    fun `run should handle large number of pipelines`() = runTest {
        // Given
        val pipelines =
            (1..100).map {
                @MockK val pipeline = io.mockk.mockk<DataFlowPipeline>()
                coEvery { pipeline.run() } just Runs
                pipeline
            }

        val largeRunner = PipelineRunner(reconciler, pipelines)

        // When
        largeRunner.run()

        // Then
        pipelines.forEach { pipeline -> coVerify(exactly = 1) { pipeline.run() } }
        coVerify(exactly = 1) { reconciler.run(any()) }
        coVerify(exactly = 1) { reconciler.disable() }
        coVerify(exactly = 1) { reconciler.flushCompleteStates() }
    }

    @Test
    fun `run should wait for all pipelines to complete before disabling reconciler`() = runTest {
        // Given
        var pipeline1Complete = false
        var pipeline2Complete = false
        var reconcilerDisabled = false

        coEvery { pipeline1.run() } coAnswers
            {
                delay(100)
                pipeline1Complete = true
            }
        coEvery { pipeline2.run() } coAnswers
            {
                delay(50)
                pipeline2Complete = true
            }
        coEvery { reconciler.disable() } answers
            {
                reconcilerDisabled = true
                assertTrue(
                    pipeline1Complete,
                    "Pipeline 1 should be complete before disabling reconciler"
                )
                assertTrue(
                    pipeline2Complete,
                    "Pipeline 2 should be complete before disabling reconciler"
                )
            }

        val twoRunner = PipelineRunner(reconciler, listOf(pipeline1, pipeline2))

        // When
        twoRunner.run()

        // Then
        assertTrue(reconcilerDisabled, "Reconciler should have been disabled")
    }
}

@@ -1,42 +0,0 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.pipeline

import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.just
import io.mockk.verify
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
class PipelineStartHandlerTest {

    @MockK private lateinit var reconciler: StateReconciler

    private lateinit var pipelineStartHandler: PipelineStartHandler

    @BeforeEach
    fun setUp() {
        pipelineStartHandler = PipelineStartHandler(reconciler)
    }

    @Test
    fun `run should call reconciler run method`() {
        // Given

        every { reconciler.run(any()) } just Runs

        // When
        pipelineStartHandler.run()

        // Then
        verify(exactly = 1) { reconciler.run(any()) }
    }
}