
Socket support for Dataflow (#65606)

This commit is contained in:
Ryan Br...
2025-09-04 16:18:50 -07:00
committed by GitHub
parent f42b67383e
commit e401d477ec
20 changed files with 829 additions and 138 deletions

View File

@@ -6,7 +6,7 @@ import javax.xml.xpath.XPathFactory
import org.w3c.dom.Document
allprojects {
version = "0.1.20"
version = "0.1.21"
apply plugin: 'java-library'
apply plugin: 'maven-publish'

View File

@@ -1,5 +1,9 @@
**Load CDK**
## Version 0.1.21
* **Changed:** Adds basic socket support.
## Version 0.1.20
* **Changed:** Fix hard-failure edge case in stream initialization in the dataflow CDK lifecycle.

View File

@@ -7,7 +7,7 @@ package io.airbyte.cdk.load.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.PipelineRunner
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -23,7 +23,7 @@ import kotlinx.coroutines.runBlocking
class DestinationLifecycle(
private val destinationInitializer: DestinationWriter,
private val destinationCatalog: DestinationCatalog,
private val pipeline: DataFlowPipeline,
private val pipeline: PipelineRunner,
private val completionTracker: StreamCompletionTracker,
private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {

View File

@@ -10,15 +10,13 @@ import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
typealias StoreKey = DestinationStream.Descriptor
@Singleton
class AggregateStore(
private val aggFactory: AggregateFactory,
private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
private val log = KotlinLogging.logger {}
@@ -102,3 +100,11 @@ data class AggregateEntry(
return stalenessTrigger.isComplete(ts)
}
}
/* For testing purposes so we can mock. */
class AggregateStoreFactory(
private val aggFactory: AggregateFactory,
private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
fun make() = AggregateStore(aggFactory, memoryAndParallelismConfig)
}

View File

@@ -0,0 +1,126 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.config
import io.airbyte.cdk.load.dataflow.aggregate.AggregateFactory
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStoreFactory
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.input.DataFlowPipelineInputFlow
import io.airbyte.cdk.load.dataflow.input.DestinationMessageInputFlow
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStage
import io.airbyte.cdk.load.dataflow.pipeline.PipelineCompletionHandler
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.file.ClientSocket
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Requires
import io.micronaut.context.annotation.Value
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
/**
* Conditionally creates input streams / sockets based on channel medium, then wires up a pipeline
* to each input with separate aggregate stores but shared state stores.
*/
@Factory
class InputBeanFactory {
@Requires(property = "airbyte.destination.core.data-channel.medium", value = "SOCKET")
@Singleton
fun sockets(
@Value("\${airbyte.destination.core.data-channel.socket-paths}") socketPaths: List<String>,
@Value("\${airbyte.destination.core.data-channel.socket-buffer-size-bytes}")
bufferSizeBytes: Int,
@Value("\${airbyte.destination.core.data-channel.socket-connection-timeout-ms}")
socketConnectionTimeoutMs: Long,
): List<ClientSocket> =
socketPaths.map {
ClientSocket(
socketPath = it,
bufferSizeBytes = bufferSizeBytes,
connectTimeoutMs = socketConnectionTimeoutMs,
)
}
@Requires(property = "airbyte.destination.core.data-channel.medium", value = "SOCKET")
@Named("inputStreams")
@Singleton
fun socketStreams(
sockets: List<ClientSocket>,
): List<InputStream> = sockets.map(ClientSocket::openInputStream)
@Requires(property = "airbyte.destination.core.data-channel.medium", value = "STDIO")
@Named("inputStreams")
@Singleton
fun stdInStreams(): List<InputStream> = listOf(System.`in`)
@Singleton
fun messageFlows(
@Named("inputStreams") inputStreams: List<InputStream>,
deserializer: ProtocolMessageDeserializer,
): List<DestinationMessageInputFlow> =
inputStreams.map {
DestinationMessageInputFlow(
inputStream = it,
deserializer = deserializer,
)
}
@Singleton
fun inputFlows(
messageFlows: List<DestinationMessageInputFlow>,
stateStore: StateStore,
stateKeyClient: StateKeyClient,
completionTracker: StreamCompletionTracker,
): List<DataFlowPipelineInputFlow> =
messageFlows.map {
DataFlowPipelineInputFlow(
inputFlow = it,
stateStore = stateStore,
stateKeyClient = stateKeyClient,
completionTracker = completionTracker,
)
}
@Singleton
fun aggregateStoreFactory(
aggFactory: AggregateFactory,
memoryAndParallelismConfig: MemoryAndParallelismConfig,
) =
AggregateStoreFactory(
aggFactory,
memoryAndParallelismConfig,
)
@Singleton
fun pipes(
inputFlows: List<DataFlowPipelineInputFlow>,
@Named("parse") parse: DataFlowStage,
@Named("flush") flush: DataFlowStage,
@Named("state") state: DataFlowStage,
aggregateStoreFactory: AggregateStoreFactory,
stateHistogramStore: StateHistogramStore,
memoryAndParallelismConfig: MemoryAndParallelismConfig,
): List<DataFlowPipeline> =
inputFlows.map {
val aggStore = aggregateStoreFactory.make()
val aggregate = AggregateStage(aggStore)
val completionHandler = PipelineCompletionHandler(aggStore, stateHistogramStore)
DataFlowPipeline(
input = it,
parse = parse,
aggregate = aggregate,
flush = flush,
state = state,
completionHandler = completionHandler,
memoryAndParallelismConfig = memoryAndParallelismConfig,
)
}
}
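
For illustration, the factory methods chain together roughly like this for the SOCKET medium. This is a hand-wiring sketch with made-up socket paths, buffer size, and timeout; deserializer, the three stages, and the store/config collaborators are assumed to be the beans Micronaut normally injects from the airbyte.destination.core.data-channel.* properties.

val factory = InputBeanFactory()
val sockets = factory.sockets(
    socketPaths = listOf("/var/run/sockets/ab-0.sock", "/var/run/sockets/ab-1.sock"), // illustrative
    bufferSizeBytes = 8192,
    socketConnectionTimeoutMs = 5_000L,
)
val streams = factory.socketStreams(sockets) // blocks until each socket file appears
val messageFlows = factory.messageFlows(streams, deserializer)
val inputFlows = factory.inputFlows(messageFlows, stateStore, stateKeyClient, completionTracker)
// One pipeline per socket, each with its own AggregateStore from the factory.
val pipelines = factory.pipes(
    inputFlows, parse, flush, state,
    aggregateStoreFactory, stateHistogramStore, memoryAndParallelismConfig,
)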

View File

@@ -14,8 +14,8 @@ import kotlin.time.Duration.Companion.minutes
* progress.
* - maxEstBytesPerAgg configures the estimated size of each aggregate.
* - The max memory consumption is (maxEstBytesPerAgg * maxOpenAggregates) +
* (maxEstBytesPerAgg * 2). Example with default values: (70,000,000 * 5) + (70,000,000 * 2) =
* 350,000,000 + 140,000,000 = 490,000,000 bytes (approx 0.49 GB).
* (maxEstBytesPerAgg * maxBufferedAggregates). Example with default values: (50,000,000 * 5) +
* (50,000,000 * 3) = 250,000,000 + 150,000,000 = 400,000,000 bytes (approx 0.4 GB).
* - stalenessDeadlinePerAgg is how long we will wait before flushing an aggregate that is not
* meeting the entry count or max memory thresholds.
* - maxRecordsPerAgg configures the max number of records in an aggregate.
@@ -23,10 +23,10 @@ import kotlin.time.Duration.Companion.minutes
*/
data class MemoryAndParallelismConfig(
val maxOpenAggregates: Int = 5,
val maxBufferedAggregates: Int = 5,
val maxBufferedAggregates: Int = 3,
val stalenessDeadlinePerAgg: Duration = 5.minutes,
val maxRecordsPerAgg: Long = 100_000L,
val maxEstBytesPerAgg: Long = 70_000_000L,
val maxEstBytesPerAgg: Long = 50_000_000L,
val maxConcurrentLifecycleOperations: Int = 10,
) {
init {

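As a worked check on the doc comment's bound, here is a hypothetical helper (not part of the CDK) that computes the ceiling from a config:

// Hypothetical helper mirroring the documented formula.
fun maxEstimatedMemoryBytes(cfg: MemoryAndParallelismConfig): Long =
    cfg.maxEstBytesPerAgg * (cfg.maxOpenAggregates + cfg.maxBufferedAggregates)

// New defaults: 50_000_000 * (5 + 3) = 400_000_000 bytes (~0.4 GB).
// The old doc hard-coded the buffered term at two aggregates:
// 70_000_000 * (5 + 2) = 490_000_000 bytes (~0.49 GB).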
View File

@@ -12,7 +12,7 @@ import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import jakarta.inject.Singleton
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
@@ -22,13 +22,14 @@ import kotlinx.coroutines.flow.FlowCollector
* Adds state ids to the input, handling the serial case where we infer the state id from a global
* counter.
*/
@Singleton
class DataFlowPipelineInputFlow(
private val inputFlow: Flow<DestinationMessage>,
private val stateStore: StateStore,
private val stateKeyClient: StateKeyClient,
private val completionTracker: StreamCompletionTracker,
) : Flow<DataFlowStageIO> {
val log = KotlinLogging.logger {}
override suspend fun collect(
collector: FlowCollector<DataFlowStageIO>,
) {
@@ -48,5 +49,7 @@ class DataFlowPipelineInputFlow(
else -> Unit
}
}
log.info { "Finished routing input." }
}
}

View File

@@ -7,16 +7,15 @@ package io.airbyte.cdk.load.dataflow.input
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Named
import jakarta.inject.Singleton
import java.io.InputStream
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.withContext
/** Takes bytes and emits DestinationMessages */
@Singleton
class DestinationMessageInputFlow(
@Named("inputStream") private val inputStream: InputStream,
private val inputStream: InputStream,
private val deserializer: ProtocolMessageDeserializer,
) : Flow<DestinationMessage> {
val log = KotlinLogging.logger {}
@@ -26,21 +25,29 @@ class DestinationMessageInputFlow(
) {
var msgCount = 0L
var estBytes = 0L
inputStream
.bufferedReader()
.lineSequence()
.filter { it.isNotEmpty() }
.forEach { line ->
val message = deserializer.deserialize(line)
collector.emit(message)
val reader = inputStream.bufferedReader()
try {
reader
.lineSequence()
.filter { it.isNotEmpty() }
.forEach { line ->
val message = deserializer.deserialize(line)
estBytes += line.length
if (++msgCount % 100_000 == 0L) {
log.info { "Processed $msgCount messages (${estBytes/1024/1024}Mb)" }
collector.emit(message)
estBytes += line.length
if (++msgCount % 100_000 == 0L) {
log.info { "Processed $msgCount messages (${estBytes/1024/1024}Mb)" }
}
}
} finally {
withContext(Dispatchers.IO) {
reader.close()
log.info { "Input stream reader closed." }
}
}
log.info { "Finished processing input" }
log.info { "Finished reading input." }
}
}
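
A minimal usage sketch for the reworked flow; the deserializer is assumed to be in scope (in the CDK it is injected):

// Reads newline-delimited protocol messages from stdin and prints them.
val messages = DestinationMessageInputFlow(
    inputStream = System.`in`,
    deserializer = deserializer,
)
runBlocking { messages.collect { msg -> println(msg) } }

The new try/finally guarantees the reader is closed on the IO dispatcher even when deserialization throws, which matters now that the underlying stream may be a socket rather than stdin.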

View File

@@ -6,31 +6,25 @@ package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import jakarta.inject.Named
import jakarta.inject.Singleton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.transform
@Singleton
class DataFlowPipeline(
private val input: Flow<DataFlowStageIO>,
@Named("parse") private val parse: DataFlowStage,
@Named("aggregate") private val aggregate: AggregateStage,
@Named("flush") private val flush: DataFlowStage,
@Named("state") private val state: DataFlowStage,
private val startHandler: PipelineStartHandler,
private val parse: DataFlowStage,
private val aggregate: AggregateStage,
private val flush: DataFlowStage,
private val state: DataFlowStage,
private val completionHandler: PipelineCompletionHandler,
private val memoryAndParallelismConfig: MemoryAndParallelismConfig,
) {
suspend fun run() {
input
.onStart { startHandler.run() }
.map(parse::apply)
.transform { aggregate.apply(it, this) }
.buffer(capacity = memoryAndParallelismConfig.maxBufferedAggregates)

View File

@@ -6,34 +6,25 @@ package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
@Singleton
class PipelineCompletionHandler(
private val aggStore: AggregateStore,
private val stateHistogramStore: StateHistogramStore,
private val reconciler: StateReconciler,
) {
private val log = KotlinLogging.logger {}
suspend fun apply(
cause: Throwable?,
) = coroutineScope {
// shutdown the reconciler regardless of success or failure, so we don't hang
reconciler.disable()
if (cause != null) {
log.error { "Destination Pipeline Completed — Exceptionally" }
throw cause
}
log.info { "Destination Pipeline Completed — Successfully" }
val remainingAggregates = aggStore.getAll()
log.info { "Flushing ${remainingAggregates.size} final aggregates..." }
@@ -46,6 +37,6 @@ class PipelineCompletionHandler(
}
.awaitAll()
reconciler.flushCompleteStates()
log.info { "Final aggregates flushed." }
}
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
@Singleton
class PipelineRunner(
private val reconciler: StateReconciler,
val pipelines: List<DataFlowPipeline>,
) {
private val log = KotlinLogging.logger {}
suspend fun run() = coroutineScope {
log.info { "Destination Pipeline Starting..." }
log.info { "Running with ${pipelines.size} input streams..." }
reconciler.run(CoroutineScope(Dispatchers.IO))
try {
pipelines.map { p -> launch { p.run() } }.joinAll()
log.info { "Individual pipelines complete..." }
} finally {
// shutdown the reconciler regardless of success or failure, so we don't hang
log.info { "Disabling reconciler..." }
reconciler.disable()
}
log.info { "Flushing final states..." }
reconciler.flushCompleteStates()
log.info { "Destination Pipeline Completed — Successfully" }
}
}
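
DestinationLifecycle now hands control to this runner instead of a single DataFlowPipeline; a reduced sketch of the hand-off, with the runner assumed to be the injected bean:

// One coroutine per pipeline; if any pipeline throws, its siblings are
// cancelled, run() rethrows, and the finally block still disables the
// reconciler so shutdown cannot hang.
runBlocking { pipelineRunner.run() }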

View File

@@ -1,24 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
@Singleton
class PipelineStartHandler(
private val reconciler: StateReconciler,
) {
private val log = KotlinLogging.logger {}
fun run() {
log.info { "Destination Pipeline Starting..." }
reconciler.run(CoroutineScope(Dispatchers.IO))
}
}

View File

@@ -6,12 +6,8 @@ package io.airbyte.cdk.load.dataflow.stages
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStageIO
import jakarta.inject.Named
import jakarta.inject.Singleton
import kotlinx.coroutines.flow.FlowCollector
@Named("aggregate")
@Singleton
class AggregateStage(
val store: AggregateStore,
) {

View File

@@ -57,4 +57,53 @@ class ClientSocket(
}
log.info { "Reading from socket $socketPath complete" }
}
fun openInputStream(): InputStream {
log.info { "Connecting client socket at $socketPath" }
val socketFile = File(socketPath)
var totalWaitMs = 0L
while (!socketFile.exists()) {
log.info { "Waiting for socket file to be created: $socketPath" }
Thread.sleep(connectWaitDelayMs)
totalWaitMs += connectWaitDelayMs
if (totalWaitMs > connectTimeoutMs) {
throw IllegalStateException(
"Socket file $socketPath not created after $connectTimeoutMs ms"
)
}
}
log.info { "Socket file $socketPath created" }
val address = UnixDomainSocketAddress.of(socketFile.toPath())
val openedSocket = SocketChannel.open(StandardProtocolFamily.UNIX)
log.info { "Socket file $socketPath opened" }
if (!openedSocket.connect(address)) {
throw IllegalStateException("Failed to connect to socket $socketPath")
}
// HACK: The dockerized destination tests use this exact message
// as a signal that it's safe to create the TCP connection to the
// socat sidecar that feeds data into the socket. Removing it
// will break tests. TODO: Anything else.
log.info { "Socket file $socketPath connected for reading" }
val inputStream = Channels.newInputStream(openedSocket).buffered(bufferSizeBytes)
return SocketInputStream(openedSocket, inputStream)
}
}
class SocketInputStream(
private val socketChannel: SocketChannel,
private val inputStream: InputStream,
) : InputStream() {
override fun read(): Int = inputStream.read()
override fun close() {
inputStream.close()
socketChannel.close()
}
}
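
A usage sketch for the new method; the path, buffer size, and timeout below are illustrative, not CDK defaults:

val socket = ClientSocket(
    socketPath = "/var/run/sockets/ab-0.sock",
    bufferSizeBytes = 8192,
    connectTimeoutMs = 5_000L,
)
// openInputStream() polls for the socket file up to connectTimeoutMs, then
// connects; closing the returned SocketInputStream also closes the channel.
socket.openInputStream().use { input ->
    input.bufferedReader().lineSequence().forEach(::println)
}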

View File

@@ -8,7 +8,7 @@ import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowPipeline
import io.airbyte.cdk.load.dataflow.pipeline.PipelineRunner
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
@@ -22,7 +22,7 @@ class DestinationLifecycleTest {
private val destinationInitializer: DestinationWriter = mockk(relaxed = true)
private val destinationCatalog: DestinationCatalog = mockk(relaxed = true)
private val pipeline: DataFlowPipeline = mockk(relaxed = true)
private val pipeline: PipelineRunner = mockk(relaxed = true)
private val completionTracker: StreamCompletionTracker = mockk(relaxed = true)
private val memoryAndParallelismConfig: MemoryAndParallelismConfig =
MemoryAndParallelismConfig(maxOpenAggregates = 1, maxBufferedAggregates = 1)

View File

@@ -0,0 +1,308 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.config
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStoreFactory
import io.airbyte.cdk.load.dataflow.finalization.StreamCompletionTracker
import io.airbyte.cdk.load.dataflow.input.DataFlowPipelineInputFlow
import io.airbyte.cdk.load.dataflow.input.DestinationMessageInputFlow
import io.airbyte.cdk.load.dataflow.pipeline.DataFlowStage
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.file.ClientSocket
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.mockk
import io.mockk.unmockkAll
import io.mockk.verify
import java.io.ByteArrayInputStream
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
@ExtendWith(MockKExtension::class)
class InputBeanFactoryTest {
@MockK private lateinit var deserializer: ProtocolMessageDeserializer
@MockK private lateinit var stateStore: StateStore
@MockK private lateinit var stateKeyClient: StateKeyClient
@MockK private lateinit var completionTracker: StreamCompletionTracker
@MockK private lateinit var parseStage: DataFlowStage
@MockK private lateinit var flushStage: DataFlowStage
@MockK private lateinit var stateStage: DataFlowStage
@MockK private lateinit var aggregateStoreFactory: AggregateStoreFactory
@MockK private lateinit var stateHistogramStore: StateHistogramStore
private var memoryAndParallelismConfig = MemoryAndParallelismConfig()
private lateinit var factory: InputBeanFactory
@BeforeEach
fun setup() {
factory = InputBeanFactory()
}
@AfterEach
fun tearDown() {
unmockkAll()
}
@Test
fun `sockets should create ClientSocket instances from socket paths`() {
// Given
val socketPaths = listOf("/tmp/socket1", "/tmp/socket2", "/tmp/socket3")
val bufferSizeBytes = 8192
val socketConnectionTimeoutMs = 5000L
// When
val result =
factory.sockets(
socketPaths = socketPaths,
bufferSizeBytes = bufferSizeBytes,
socketConnectionTimeoutMs = socketConnectionTimeoutMs
)
// Then
assertEquals(3, result.size)
assertEquals("/tmp/socket1", result[0].socketPath)
assertEquals("/tmp/socket2", result[1].socketPath)
assertEquals("/tmp/socket3", result[2].socketPath)
}
@Test
fun `sockets should create single ClientSocket for single path`() {
// Given
val socketPaths = listOf("/tmp/single.socket")
val bufferSizeBytes = 16384
val socketConnectionTimeoutMs = 10000L
// When
val result =
factory.sockets(
socketPaths = socketPaths,
bufferSizeBytes = bufferSizeBytes,
socketConnectionTimeoutMs = socketConnectionTimeoutMs
)
// Then
assertEquals(1, result.size)
assertEquals("/tmp/single.socket", result[0].socketPath)
}
@Test
fun `socketStreams should create input streams from ClientSocket list`() {
// Given
val socket1 = mockk<ClientSocket>()
val socket2 = mockk<ClientSocket>()
val socket3 = mockk<ClientSocket>()
val sockets = listOf(socket1, socket2, socket3)
val mockInputStream1 = ByteArrayInputStream("stream1".toByteArray())
val mockInputStream2 = ByteArrayInputStream("stream2".toByteArray())
val mockInputStream3 = ByteArrayInputStream("stream3".toByteArray())
every { socket1.openInputStream() } returns mockInputStream1
every { socket2.openInputStream() } returns mockInputStream2
every { socket3.openInputStream() } returns mockInputStream3
// When
val result = factory.socketStreams(sockets)
// Then
assertEquals(3, result.size)
assertEquals(mockInputStream1, result[0])
assertEquals(mockInputStream2, result[1])
assertEquals(mockInputStream3, result[2])
verify(exactly = 1) { socket1.openInputStream() }
verify(exactly = 1) { socket2.openInputStream() }
verify(exactly = 1) { socket3.openInputStream() }
}
@Test
fun `socketStreams should handle single socket`() {
// Given
val socket = mockk<ClientSocket>()
val sockets = listOf(socket)
val mockInputStream = ByteArrayInputStream("single stream".toByteArray())
every { socket.openInputStream() } returns mockInputStream
// When
val result = factory.socketStreams(sockets)
// Then
assertEquals(1, result.size)
assertEquals(mockInputStream, result[0])
verify(exactly = 1) { socket.openInputStream() }
}
@Test
fun `stdInStreams should return System in`() {
// When
val result = factory.stdInStreams()
// Then
assertEquals(1, result.size)
assertEquals(System.`in`, result[0])
}
@Test
fun `messageFlows should create DestinationMessageInputFlow for each input stream`() {
// Given
val inputStream1 = ByteArrayInputStream("stream1".toByteArray())
val inputStream2 = ByteArrayInputStream("stream2".toByteArray())
val inputStreams = listOf(inputStream1, inputStream2)
// When
val result = factory.messageFlows(inputStreams, deserializer)
// Then
assertEquals(2, result.size)
assertNotNull(result[0])
assertNotNull(result[1])
}
@Test
fun `inputFlows should create DataFlowPipelineInputFlow for each message flow`() {
// Given
val messageFlow1 = mockk<DestinationMessageInputFlow>()
val messageFlow2 = mockk<DestinationMessageInputFlow>()
val messageFlows = listOf(messageFlow1, messageFlow2)
// When
val result =
factory.inputFlows(
messageFlows = messageFlows,
stateStore = stateStore,
stateKeyClient = stateKeyClient,
completionTracker = completionTracker
)
// Then
assertEquals(2, result.size)
assertNotNull(result[0])
assertNotNull(result[1])
}
@Test
fun `pipes should create DataFlowPipeline for each input flow`() {
every { aggregateStoreFactory.make() } returns mockk()
// Given
val inputFlow1 = mockk<DataFlowPipelineInputFlow>()
val inputFlow2 = mockk<DataFlowPipelineInputFlow>()
val inputFlows = listOf(inputFlow1, inputFlow2)
// When
val result =
factory.pipes(
inputFlows = inputFlows,
parse = parseStage,
flush = flushStage,
state = stateStage,
aggregateStoreFactory = aggregateStoreFactory,
stateHistogramStore = stateHistogramStore,
memoryAndParallelismConfig = memoryAndParallelismConfig
)
// Then
assertEquals(2, result.size)
assertNotNull(result[0])
assertNotNull(result[1])
}
@Test
fun `pipes should create different aggregate store for each pipeline`() {
// Given
val inputFlow1 = mockk<DataFlowPipelineInputFlow>()
val inputFlow2 = mockk<DataFlowPipelineInputFlow>()
val inputFlow3 = mockk<DataFlowPipelineInputFlow>()
val inputFlows = listOf(inputFlow1, inputFlow2, inputFlow3)
val aggregateStore1 = mockk<AggregateStore>(relaxed = true)
val aggregateStore2 = mockk<AggregateStore>(relaxed = true)
val aggregateStore3 = mockk<AggregateStore>(relaxed = true)
// Mock the factory to return different instances
every { aggregateStoreFactory.make() } returnsMany
listOf(aggregateStore1, aggregateStore2, aggregateStore3)
// When
val result =
factory.pipes(
inputFlows = inputFlows,
parse = parseStage,
flush = flushStage,
state = stateStage,
aggregateStoreFactory = aggregateStoreFactory,
stateHistogramStore = stateHistogramStore,
memoryAndParallelismConfig = memoryAndParallelismConfig
)
// Then
assertEquals(3, result.size)
// Verify that the factory was called 3 times to create 3 different stores
verify(exactly = 3) { aggregateStoreFactory.make() }
// Each pipeline should have received a different aggregate store
// We can't directly verify which store went to which pipeline without accessing internals,
// but we verified that 3 different stores were created via the factory
}
@Test
fun `integration test - complete flow from socket paths to pipelines`() {
// Given
val mockInputStream1 = ByteArrayInputStream("stream1".toByteArray())
val mockInputStream2 = ByteArrayInputStream("stream2".toByteArray())
every { aggregateStoreFactory.make() } returns mockk()
val inputStreams = listOf(mockInputStream1, mockInputStream2)
val messageFlows = factory.messageFlows(inputStreams, deserializer)
val inputFlows =
factory.inputFlows(
messageFlows = messageFlows,
stateStore = stateStore,
stateKeyClient = stateKeyClient,
completionTracker = completionTracker
)
val pipes =
factory.pipes(
inputFlows = inputFlows,
parse = parseStage,
flush = flushStage,
state = stateStage,
aggregateStoreFactory = aggregateStoreFactory,
stateHistogramStore = stateHistogramStore,
memoryAndParallelismConfig = memoryAndParallelismConfig
)
// Then
assertEquals(2, inputStreams.size)
assertEquals(2, messageFlows.size)
assertEquals(2, inputFlows.size)
assertEquals(2, pipes.size)
}
}

View File

@@ -20,7 +20,6 @@ class DataFlowPipelineTest {
private val aggregate = mockk<AggregateStage>()
private val flush = mockk<DataFlowStage>()
private val state = mockk<DataFlowStage>()
private val startHandler = mockk<PipelineStartHandler>()
private val completionHandler = mockk<PipelineCompletionHandler>()
private val memoryAndParallelismConfig =
MemoryAndParallelismConfig(
@@ -40,7 +39,6 @@ class DataFlowPipelineTest {
aggregate,
flush,
state,
startHandler,
completionHandler,
memoryAndParallelismConfig
)
@@ -50,7 +48,6 @@ class DataFlowPipelineTest {
val flushedIO = mockk<DataFlowStageIO>()
val stateIO = mockk<DataFlowStageIO>()
coEvery { startHandler.run() } returns Unit
coEvery { parse.apply(initialIO) } returns parsedIO
coEvery { aggregate.apply(parsedIO, any()) } coAnswers
{
@@ -66,7 +63,6 @@ class DataFlowPipelineTest {
// Then
coVerifySequence {
startHandler.run()
parse.apply(initialIO)
aggregate.apply(parsedIO, any())
flush.apply(aggregatedIO)

View File

@@ -9,7 +9,6 @@ import io.airbyte.cdk.load.dataflow.aggregate.AggregateEntry
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
@@ -33,8 +32,6 @@ class PipelineCompletionHandlerTest {
@MockK private lateinit var stateHistogramStore: StateHistogramStore
@MockK private lateinit var reconciler: StateReconciler
private lateinit var pipelineCompletionHandler: PipelineCompletionHandler
@BeforeEach
@@ -43,10 +40,7 @@ class PipelineCompletionHandlerTest {
PipelineCompletionHandler(
aggStore = aggStore,
stateHistogramStore = stateHistogramStore,
reconciler = reconciler
)
coEvery { reconciler.disable() } just Runs
}
@Test
@@ -58,7 +52,6 @@ class PipelineCompletionHandlerTest {
val thrownException =
assertThrows<RuntimeException> { pipelineCompletionHandler.apply(testException) }
coVerify(exactly = 1) { reconciler.disable() }
assertEquals("Test exception", thrownException.message)
}
@@ -92,7 +85,6 @@ class PipelineCompletionHandlerTest {
coEvery { mockAggregate1.flush() } just Runs
coEvery { mockAggregate2.flush() } just Runs
every { stateHistogramStore.acceptFlushedCounts(any()) } returns mockk()
every { reconciler.flushCompleteStates() } just Runs
// When
pipelineCompletionHandler.apply(null)
@@ -102,23 +94,18 @@ class PipelineCompletionHandlerTest {
coVerify(exactly = 1) { mockAggregate2.flush() }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram1) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram2) }
coVerify(exactly = 1) { reconciler.disable() }
verify(exactly = 1) { reconciler.flushCompleteStates() }
}
@Test
fun `apply should handle empty aggregates list`() = runTest {
// Given
every { aggStore.getAll() } returns emptyList()
every { reconciler.flushCompleteStates() } just Runs
// When
pipelineCompletionHandler.apply(null)
// Then
verify(exactly = 1) { aggStore.getAll() }
coVerify(exactly = 1) { reconciler.disable() }
verify(exactly = 1) { reconciler.flushCompleteStates() }
verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }
}
@@ -140,12 +127,10 @@ class PipelineCompletionHandlerTest {
every { aggStore.getAll() } returns listOf(aggregateEntry)
coEvery { mockAggregate.flush() } throws flushException
every { reconciler.flushCompleteStates() } just Runs
// When & Then
assertThrows<RuntimeException> { pipelineCompletionHandler.apply(null) }
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 1) { mockAggregate.flush() }
// Note: acceptFlushedCounts should not be called if flush fails
verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }

View File

@@ -0,0 +1,249 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.coVerifySequence
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.just
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith
@ExtendWith(MockKExtension::class)
class PipelineRunnerTest {
@MockK(relaxed = true) private lateinit var reconciler: StateReconciler
@MockK private lateinit var pipeline1: DataFlowPipeline
@MockK private lateinit var pipeline2: DataFlowPipeline
@MockK private lateinit var pipeline3: DataFlowPipeline
private lateinit var runner: PipelineRunner
@BeforeEach
fun setup() {
runner = PipelineRunner(reconciler, listOf(pipeline1, pipeline2, pipeline3))
}
@Test
fun `run should execute all pipelines concurrently`() = runTest {
// Given
coEvery { pipeline1.run() } coAnswers { delay(100) }
coEvery { pipeline2.run() } coAnswers { delay(50) }
coEvery { pipeline3.run() } coAnswers { delay(75) }
// When
val startTime = System.currentTimeMillis()
runner.run()
val endTime = System.currentTimeMillis()
// Then
coVerify(exactly = 1) { pipeline1.run() }
coVerify(exactly = 1) { pipeline2.run() }
coVerify(exactly = 1) { pipeline3.run() }
// Verify they ran concurrently (total time should be close to the longest pipeline)
// Allow some buffer for execution overhead
assertTrue((endTime - startTime) < 200, "Pipelines should run concurrently")
coVerify(exactly = 1) { reconciler.run(any()) }
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 1) { reconciler.flushCompleteStates() }
}
@Test
fun `run should disable reconciler even when pipeline fails`() = runTest {
// Given
val exception = RuntimeException("Pipeline failed")
coEvery { pipeline1.run() } throws exception
val failingRunner = PipelineRunner(reconciler, listOf(pipeline1))
// When/Then
val thrownException = assertThrows<RuntimeException> { runBlocking { failingRunner.run() } }
assertEquals("Pipeline failed", thrownException.message)
// Verify reconciler was still disabled in finally block
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 0) { reconciler.flushCompleteStates() }
}
@Test
fun `run should execute operations in correct order`() = runTest {
// Given
coEvery { pipeline1.run() } just Runs
val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))
// When
singlePipelineRunner.run()
// Then - verify order of operations
coVerifySequence {
reconciler.run(any())
pipeline1.run()
reconciler.disable()
reconciler.flushCompleteStates()
}
}
@Test
fun `run should pass correct CoroutineScope to reconciler`() = runTest {
// Given
var capturedScope: CoroutineScope? = null
every { reconciler.run(any()) } answers { capturedScope = firstArg() }
val emptyRunner = PipelineRunner(reconciler, emptyList())
// When
emptyRunner.run()
// Then
assertTrue(capturedScope != null, "CoroutineScope should be passed to reconciler")
// Verify it's using IO dispatcher
assertTrue(capturedScope.toString().contains("Dispatchers.IO"), "Should use IO dispatcher")
}
@Test
fun `run should handle single pipeline`() = runTest {
// Given
coEvery { pipeline1.run() } just Runs
val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))
// When
singlePipelineRunner.run()
// Then
coVerify(exactly = 1) { pipeline1.run() }
coVerify(exactly = 1) { reconciler.run(any()) }
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 1) { reconciler.flushCompleteStates() }
}
@Test
fun `run should handle reconciler disable failure`() = runTest {
// Given
coEvery { pipeline1.run() } just Runs
coEvery { reconciler.disable() } throws RuntimeException("Failed to disable")
val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))
// When/Then
val exception =
assertThrows<RuntimeException> { runBlocking { singlePipelineRunner.run() } }
assertEquals("Failed to disable", exception.message)
// Verify pipeline was executed before the failure
coVerify(exactly = 1) { pipeline1.run() }
coVerify(exactly = 0) { reconciler.flushCompleteStates() }
}
@Test
fun `run should handle reconciler flushCompleteStates failure`() = runTest {
// Given
coEvery { pipeline1.run() } just Runs
every { reconciler.flushCompleteStates() } throws RuntimeException("Failed to flush")
val singlePipelineRunner = PipelineRunner(reconciler, listOf(pipeline1))
// When/Then
val exception =
assertThrows<RuntimeException> { runBlocking { singlePipelineRunner.run() } }
assertEquals("Failed to flush", exception.message)
// Verify everything up to flush was executed
coVerify(exactly = 1) { pipeline1.run() }
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 1) { reconciler.flushCompleteStates() }
}
@Test
fun `pipelines property should return the list of pipelines`() {
// When
val result = runner.pipelines
// Then
assertEquals(3, result.size)
assertEquals(pipeline1, result[0])
assertEquals(pipeline2, result[1])
assertEquals(pipeline3, result[2])
}
@Test
fun `run should handle large number of pipelines`() = runTest {
// Given
val pipelines =
(1..100).map {
@MockK val pipeline = io.mockk.mockk<DataFlowPipeline>()
coEvery { pipeline.run() } just Runs
pipeline
}
val largeRunner = PipelineRunner(reconciler, pipelines)
// When
largeRunner.run()
// Then
pipelines.forEach { pipeline -> coVerify(exactly = 1) { pipeline.run() } }
coVerify(exactly = 1) { reconciler.run(any()) }
coVerify(exactly = 1) { reconciler.disable() }
coVerify(exactly = 1) { reconciler.flushCompleteStates() }
}
@Test
fun `run should wait for all pipelines to complete before disabling reconciler`() = runTest {
// Given
var pipeline1Complete = false
var pipeline2Complete = false
var reconcilerDisabled = false
coEvery { pipeline1.run() } coAnswers
{
delay(100)
pipeline1Complete = true
}
coEvery { pipeline2.run() } coAnswers
{
delay(50)
pipeline2Complete = true
}
coEvery { reconciler.disable() } answers
{
reconcilerDisabled = true
assertTrue(
pipeline1Complete,
"Pipeline 1 should be complete before disabling reconciler"
)
assertTrue(
pipeline2Complete,
"Pipeline 2 should be complete before disabling reconciler"
)
}
val twoRunner = PipelineRunner(reconciler, listOf(pipeline1, pipeline2))
// When
twoRunner.run()
// Then
assertTrue(reconcilerDisabled, "Reconciler should have been disabled")
}
}

View File

@@ -1,42 +0,0 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.pipeline
import io.airbyte.cdk.load.dataflow.state.StateReconciler
import io.mockk.Runs
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.just
import io.mockk.verify
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
@ExtendWith(MockKExtension::class)
class PipelineStartHandlerTest {
@MockK private lateinit var reconciler: StateReconciler
private lateinit var pipelineStartHandler: PipelineStartHandler
@BeforeEach
fun setUp() {
pipelineStartHandler = PipelineStartHandler(reconciler)
}
@Test
fun `run should call reconciler run method`() {
// Given
every { reconciler.run(any()) } just Runs
// When
pipelineStartHandler.run()
// Then
verify(exactly = 1) { reconciler.run(any()) }
}
}