Add tests to the dataflow pipeline (#64887)
@@ -18,5 +18,6 @@ When a module is mentioned, it refers to this [folder](.)
- Update the [changelog.md](changelog.md) by adding a description for the new version
- If the version has already been bumped on the local branch, we shouldn't bump it again
- We format our code by running the command `pre-commit run --all-files` from the root of the project
- When writing a test function, the name must not start with `test`, because the prefix is redundant (see the example after this hunk)

When you are done with a change, always run the formatter and update the changelog if needed.
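To make the naming rule concrete, here is a hypothetical example (not part of this commit) of the backtick naming style the new tests use; `SizeTrigger` is a class tested later in this diff:

```kotlin
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class TriggerNamingExampleTest {
    // Not `fun testFiresAtThreshold()` — the `test` prefix is redundant under JUnit 5.
    @Test
    fun `fires once the accumulated size reaches the threshold`() {
        val trigger = SizeTrigger(100) // SizeTrigger is defined in this module
        trigger.increment(100)
        assertTrue(trigger.isComplete())
    }
}
```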
@@ -12,6 +12,8 @@

## Version 0.1.14

**Load CDK**

* **Changed:** Add agent.

## Version 0.1.13
@@ -0,0 +1,77 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow

import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import io.mockk.coEvery
import io.mockk.coVerifySequence
import io.mockk.mockk
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test

@ExperimentalCoroutinesApi
class DataFlowPipelineTest {
    private val parse = mockk<DataFlowStage>()
    private val aggregate = mockk<AggregateStage>()
    private val flush = mockk<DataFlowStage>()
    private val state = mockk<DataFlowStage>()
    private val startHandler = mockk<PipelineStartHandler>()
    private val completionHandler = mockk<PipelineCompletionHandler>()
    private val memoryAndParallelismConfig =
        MemoryAndParallelismConfig(
            maxOpenAggregates = 2,
            maxBufferedAggregates = 2,
        )

    @Test
    fun `pipeline execution flow`() = runTest {
        // Given
        val initialIO = mockk<DataFlowStageIO>()
        val input = flowOf(initialIO)
        val pipeline =
            DataFlowPipeline(
                input,
                parse,
                aggregate,
                flush,
                state,
                startHandler,
                completionHandler,
                memoryAndParallelismConfig
            )

        val parsedIO = mockk<DataFlowStageIO>()
        val aggregatedIO = mockk<DataFlowStageIO>()
        val flushedIO = mockk<DataFlowStageIO>()
        val stateIO = mockk<DataFlowStageIO>()

        coEvery { startHandler.run() } returns Unit
        coEvery { parse.apply(initialIO) } returns parsedIO
        coEvery { aggregate.apply(parsedIO, any()) } coAnswers
            {
                val collector = secondArg<FlowCollector<DataFlowStageIO>>()
                collector.emit(aggregatedIO)
            }
        coEvery { flush.apply(aggregatedIO) } returns flushedIO
        coEvery { state.apply(flushedIO) } returns stateIO
        coEvery { completionHandler.apply(null) } returns Unit

        // When
        pipeline.run()

        // Then
        coVerifySequence {
            startHandler.run()
            parse.apply(initialIO)
            aggregate.apply(parsedIO, any())
            flush.apply(aggregatedIO)
            state.apply(flushedIO)
            completionHandler.apply(null)
        }
    }
}
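The test above fixes the stage order: start handler, parse, aggregate (which emits completed aggregates into a collector), flush, state, and finally the completion handler with `null` on success. A minimal sketch of a `run` implementation consistent with that contract — illustrative only, assuming the stage signatures shown in the test and omitting the concurrency that `MemoryAndParallelismConfig` presumably controls:

```kotlin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.transform

// Sketch only; DataFlowStage, AggregateStage, the handlers, and DataFlowStageIO
// are this module's types, used exactly as the test exercises them.
class DataFlowPipelineSketch(
    private val input: Flow<DataFlowStageIO>,
    private val parse: DataFlowStage,
    private val aggregate: AggregateStage,
    private val flush: DataFlowStage,
    private val state: DataFlowStage,
    private val startHandler: PipelineStartHandler,
    private val completionHandler: PipelineCompletionHandler,
) {
    suspend fun run() {
        startHandler.run()
        try {
            input
                .map { parse.apply(it) } // parse each incoming record
                .transform { aggregate.apply(it, this) } // emit only completed aggregates
                .map { flush.apply(it) } // write each aggregate to the destination
                .map { state.apply(it) } // record flushed counts for state tracking
                .collect()
            completionHandler.apply(null) // null signals success, as the test asserts
        } catch (e: Exception) {
            completionHandler.apply(e) // assumed failure path, not covered by this test
        }
    }
}
```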
@@ -0,0 +1,60 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.mockk
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test

class DestinationLifecycleTest {

    private val destinationInitializer: DestinationWriter = mockk(relaxed = true)
    private val destinationCatalog: DestinationCatalog = mockk(relaxed = true)
    private val pipeline: DataFlowPipeline = mockk(relaxed = true)
    private val memoryAndParallelismConfig: MemoryAndParallelismConfig =
        MemoryAndParallelismConfig(maxOpenAggregates = 1, maxBufferedAggregates = 1)

    private val destinationLifecycle =
        DestinationLifecycle(
            destinationInitializer,
            destinationCatalog,
            pipeline,
            memoryAndParallelismConfig,
        )

    @Test
    fun `should execute full lifecycle in correct order`() = runTest {
        // Given
        val streamLoader1 = mockk<StreamLoader>(relaxed = true)
        val streamLoader2 = mockk<StreamLoader>(relaxed = true)
        val stream1 = mockk<DestinationStream>(relaxed = true)
        val stream2 = mockk<DestinationStream>(relaxed = true)

        coEvery { destinationCatalog.streams } returns listOf(stream1, stream2)
        coEvery { destinationInitializer.createStreamLoader(stream1) } returns streamLoader1
        coEvery { destinationInitializer.createStreamLoader(stream2) } returns streamLoader2

        // When
        destinationLifecycle.run()

        // Then
        coVerify(exactly = 1) { destinationInitializer.setup() }
        coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream1) }
        coVerify(exactly = 1) { streamLoader1.start() }
        coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream2) }
        coVerify(exactly = 1) { streamLoader2.start() }
        coVerify(exactly = 1) { pipeline.run() }
        coVerify(exactly = 1) { streamLoader1.close(true) }
        coVerify(exactly = 1) { streamLoader2.close(true) }
        coVerify(exactly = 1) { destinationInitializer.teardown() }
    }
}
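Read together, the verifications above imply a straightforward lifecycle. A sketch of the order they pin down (illustrative; the real `DestinationLifecycle` also takes the `MemoryAndParallelismConfig`, omitted here):

```kotlin
// Sketch only — the method names come straight from the verifications above.
class DestinationLifecycleSketch(
    private val writer: DestinationWriter,
    private val catalog: DestinationCatalog,
    private val pipeline: DataFlowPipeline,
) {
    suspend fun run() {
        writer.setup()
        // One loader per configured stream, each started before the pipeline runs.
        val loaders = catalog.streams.map { stream ->
            writer.createStreamLoader(stream).also { it.start() }
        }
        pipeline.run()
        // close(true) signals a successful sync; teardown is the final step.
        loaders.forEach { it.close(true) }
        writer.teardown()
    }
}
```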
@@ -0,0 +1,50 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.aggregate

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class SizeTriggerTest {
    @Test
    fun `trigger not fired when under threshold`() {
        val trigger = SizeTrigger(100)
        trigger.increment(50)
        assertFalse(trigger.isComplete())
        assertEquals(50L, trigger.watermark())
    }

    @Test
    fun `trigger fired when at threshold`() {
        val trigger = SizeTrigger(100)
        trigger.increment(100)
        assertTrue(trigger.isComplete())
        assertEquals(100L, trigger.watermark())
    }

    @Test
    fun `trigger fired when over threshold`() {
        val trigger = SizeTrigger(100)
        trigger.increment(150)
        assertTrue(trigger.isComplete())
        assertEquals(150L, trigger.watermark())
    }

    @Test
    fun `multiple increments`() {
        val trigger = SizeTrigger(100)
        trigger.increment(50)
        assertFalse(trigger.isComplete())
        assertEquals(50L, trigger.watermark())
        trigger.increment(49)
        assertFalse(trigger.isComplete())
        assertEquals(99L, trigger.watermark())
        trigger.increment(1)
        assertTrue(trigger.isComplete())
        assertEquals(100L, trigger.watermark())
    }
}
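The four tests above fully determine a tiny contract: the trigger accumulates byte counts, fires at or above the threshold, and exposes the running total as a watermark. A sketch of an implementation that satisfies it (illustrative, not the actual class):

```kotlin
// Sketch only — behavior inferred from SizeTriggerTest.
class SizeTriggerSketch(private val thresholdBytes: Long) {
    private var accumulatedBytes = 0L

    // Add the size of a newly buffered record to the running total.
    fun increment(sizeBytes: Long) {
        accumulatedBytes += sizeBytes
    }

    // Fires once the running total reaches or exceeds the threshold.
    fun isComplete(): Boolean = accumulatedBytes >= thresholdBytes

    // Expose the running total, e.g. for the assertions above.
    fun watermark(): Long = accumulatedBytes
}
```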
@@ -0,0 +1,55 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.aggregate

import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class TimeTriggerTest {

    @Test
    fun `time trigger logic`() {
        val triggerSize = 1000L // 1 second
        val trigger = TimeTrigger(triggerSize)

        val startTime = System.currentTimeMillis()

        // Update the trigger with the start time
        trigger.update(startTime)

        // Immediately after update, it should not be complete
        assertFalse(trigger.isComplete(startTime))

        // Just before the trigger time, it should not be complete
        assertFalse(trigger.isComplete(startTime + triggerSize - 1))

        // Exactly at the trigger time, it should be complete
        assertTrue(trigger.isComplete(startTime + triggerSize))

        // After the trigger time, it should be complete
        assertTrue(trigger.isComplete(startTime + triggerSize + 1))
    }

    @Test
    fun `updating timestamp resets the trigger`() {
        val triggerSize = 1000L
        val trigger = TimeTrigger(triggerSize)

        val startTime = System.currentTimeMillis()
        trigger.update(startTime)

        // It should be complete after triggerSize
        assertTrue(trigger.isComplete(startTime + triggerSize))

        // Update the timestamp to a later time
        val newStartTime = startTime + 500L
        trigger.update(newStartTime)

        // Now it should not be complete relative to the new start time
        assertFalse(trigger.isComplete(newStartTime + triggerSize - 1))
        assertTrue(trigger.isComplete(newStartTime + triggerSize))
    }
}
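These tests pin the boundary conditions exactly: not complete strictly before `lastUpdate + window`, complete at and after that instant, and `update` resets the window. A matching sketch (illustrative):

```kotlin
// Sketch only — boundary behavior taken from TimeTriggerTest.
class TimeTriggerSketch(private val windowMs: Long) {
    private var lastUpdatedMs = 0L

    // Record the latest activity timestamp; this restarts the window.
    fun update(nowMs: Long) {
        lastUpdatedMs = nowMs
    }

    // Complete exactly when a full window has elapsed since the last update.
    fun isComplete(nowMs: Long): Boolean = nowMs - lastUpdatedMs >= windowMs
}
```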
@@ -0,0 +1,99 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.input

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordSource
import io.airbyte.cdk.load.message.Undefined
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.util.UUID
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class DataFlowPipelineInputFlowTest {
    @Test
    fun `checkpoint message`() = runBlocking {
        // Given
        val checkpointMessage = mockk<CheckpointMessage>()
        val inputFlow = flowOf<DestinationMessage>(checkpointMessage)
        val stateStore = mockk<StateStore>(relaxed = true)
        val stateKeyClient = mockk<StateKeyClient>()
        val dataFlowPipelineInputFlow =
            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)

        // When
        val result = dataFlowPipelineInputFlow.toList()

        // Then
        coVerify(exactly = 1) { stateStore.accept(checkpointMessage) }
        assertEquals(0, result.size)
    }

    @Test
    fun `destination record`() = runBlocking {
        // Given
        val stream = mockk<DestinationStream>()
        every { stream.schema } returns mockk()
        every { stream.airbyteValueProxyFieldAccessors } returns emptyArray()
        val message = mockk<DestinationRecordSource>()
        every { message.fileReference } returns null
        val destinationRecord =
            DestinationRecord(
                stream,
                message,
                1L,
                null,
                UUID.randomUUID(),
            )
        val inputFlow = flowOf<DestinationMessage>(destinationRecord)
        val stateStore = mockk<StateStore>()
        val stateKeyClient = mockk<StateKeyClient>()
        val partitionKey = PartitionKey("partitionKey")
        every { stateKeyClient.getPartitionKey(any()) } returns partitionKey
        val dataFlowPipelineInputFlow =
            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)

        // When
        val result = dataFlowPipelineInputFlow.toList()

        // Then
        assertEquals(1, result.size)
        val expected =
            DataFlowStageIO(
                raw = destinationRecord.asDestinationRecordRaw(),
                partitionKey = partitionKey,
            )
        assertEquals(expected, result[0])
    }

    @Test
    fun `other message`() = runBlocking {
        // Given
        val undefinedMessage = Undefined
        val inputFlow = flowOf<DestinationMessage>(undefinedMessage)
        val stateStore = mockk<StateStore>()
        val stateKeyClient = mockk<StateKeyClient>()
        val dataFlowPipelineInputFlow =
            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)

        // When
        val result = dataFlowPipelineInputFlow.toList()

        // Then
        assertEquals(0, result.size)
    }
}
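The three cases above describe a routing flow: records are enriched with a partition key and passed through, checkpoints are diverted to the state store, and anything else is dropped. A sketch consistent with that behavior — illustrative; the argument `getPartitionKey` takes is mocked with `any()` above, so passing the record is an assumption:

```kotlin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

// Sketch only — routing inferred from the three tests above.
fun dataFlowPipelineInputFlowSketch(
    upstream: Flow<DestinationMessage>,
    stateStore: StateStore,
    stateKeyClient: StateKeyClient,
): Flow<DataFlowStageIO> = flow {
    upstream.collect { message ->
        when (message) {
            // Records are wrapped in stage IO, tagged with their partition key.
            is DestinationRecord ->
                emit(
                    DataFlowStageIO(
                        raw = message.asDestinationRecordRaw(),
                        partitionKey = stateKeyClient.getPartitionKey(message),
                    )
                )
            // Checkpoints feed the state store and are not emitted downstream.
            is CheckpointMessage -> stateStore.accept(message)
            // Everything else (e.g. Undefined) is dropped.
            else -> Unit
        }
    }
}
```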
@@ -0,0 +1,89 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.input

import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import java.io.ByteArrayInputStream
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test

class DestinationMessageInputFlowTest {

    @Test
    fun `should deserialize and emit messages from input stream`() = runTest {
        // Given
        val line1 = "message1"
        val line2 = "message2"
        val inputStream = ByteArrayInputStream("$line1\n$line2\n".toByteArray())
        val deserializer = mockk<ProtocolMessageDeserializer>()
        val message1 = mockk<DestinationMessage>()
        val message2 = mockk<DestinationMessage>()

        every { deserializer.deserialize(line1) } returns message1
        every { deserializer.deserialize(line2) } returns message2

        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)

        // When
        inputFlow.collect(collector)

        // Then
        verify(exactly = 1) { deserializer.deserialize(line1) }
        verify(exactly = 1) { deserializer.deserialize(line2) }
        coVerify(exactly = 1) { collector.emit(message1) }
        coVerify(exactly = 1) { collector.emit(message2) }
    }

    @Test
    fun `should ignore empty lines`() = runTest {
        // Given
        val line1 = "message1"
        val emptyLine = ""
        val line2 = "message2"
        val inputStream = ByteArrayInputStream("$line1\n$emptyLine\n$line2\n".toByteArray())
        val deserializer = mockk<ProtocolMessageDeserializer>()
        val message1 = mockk<DestinationMessage>()
        val message2 = mockk<DestinationMessage>()

        every { deserializer.deserialize(line1) } returns message1
        every { deserializer.deserialize(line2) } returns message2

        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)

        // When
        inputFlow.collect(collector)

        // Then
        verify(exactly = 1) { deserializer.deserialize(line1) }
        verify(exactly = 0) { deserializer.deserialize(emptyLine) } // Ensure the empty line is skipped
        verify(exactly = 1) { deserializer.deserialize(line2) }
        coVerify(exactly = 1) { collector.emit(message1) }
        coVerify(exactly = 1) { collector.emit(message2) }
    }

    @Test
    fun `should not emit if no record is present`() = runTest {
        // Given
        val inputStream = ByteArrayInputStream("".toByteArray())
        val deserializer = mockk<ProtocolMessageDeserializer>()

        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)

        // When
        inputFlow.collect(collector)

        // Then
        coVerify(exactly = 0) { collector.emit(any()) } // Ensure nothing is emitted for empty input
    }
}
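Together the three tests describe a simple line reader: read the stream line by line, skip blanks without touching the deserializer, and emit one message per non-empty line. A sketch under those assumptions (illustrative, not the actual class):

```kotlin
import java.io.InputStream
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

// Sketch only — line handling inferred from the three tests above.
fun destinationMessageInputFlowSketch(
    inputStream: InputStream,
    deserializer: ProtocolMessageDeserializer,
): Flow<DestinationMessage> = flow {
    inputStream.bufferedReader().useLines { lines ->
        for (line in lines) {
            if (line.isEmpty()) continue // blank lines are never deserialized
            emit(deserializer.deserialize(line))
        }
    }
}
```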
@@ -0,0 +1,166 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.stages

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.aggregate.AggregateEntry
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.cdk.load.message.DestinationRecordRaw
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test

class AggregateStageTest {

    private val store: AggregateStore = mockk(relaxed = true)
    private val stage = AggregateStage(store)

    @Test
    fun `test aggregation`() = runTest {
        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
        val emittedAtMs = 1L
        val partitionKey = PartitionKey("partition_id_1")
        val recordDto =
            RecordDTO(
                fields = emptyMap(),
                partitionKey = partitionKey,
                sizeBytes = 100L,
                emittedAtMs = emittedAtMs
            )

        val streamMock =
            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)

        val mockAggregate = mockk<Aggregate>()
        val mockPartitionHistogram = mockk<PartitionHistogram>()
        val aggregateEntry =
            mockk<AggregateEntry> {
                every { value } returns mockAggregate
                every { partitionHistogram } returns mockPartitionHistogram
            }
        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
        coEvery { store.removeNextComplete(emittedAtMs) } returns aggregateEntry andThen null

        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)

        stage.apply(input, outputFlow)

        coVerify { store.acceptFor(streamDescriptor, recordDto) }
        coVerify { store.removeNextComplete(emittedAtMs) }
        coVerify {
            outputFlow.emit(
                DataFlowStageIO(
                    aggregate = mockAggregate,
                    partitionHistogram = mockPartitionHistogram
                )
            )
        }
    }

    @Test
    fun `should not emit if no complete aggregate is ready`() = runTest {
        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
        val emittedAtMs = 1L
        val partitionKey = PartitionKey("partition_id_1")
        val recordDto =
            RecordDTO(
                fields = emptyMap(),
                partitionKey = partitionKey,
                sizeBytes = 100L,
                emittedAtMs = emittedAtMs
            )

        val streamMock =
            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)

        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
        coEvery { store.removeNextComplete(emittedAtMs) } returns null // No aggregate ready

        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)

        stage.apply(input, outputFlow)

        coVerify { store.acceptFor(streamDescriptor, recordDto) }
        coVerify { store.removeNextComplete(emittedAtMs) }
        coVerify(exactly = 0) { outputFlow.emit(any()) } // Verify no emission
    }

    @Test
    fun `should emit multiple times if multiple complete aggregates are ready`() = runTest {
        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
        val emittedAtMs = 1L
        val partitionKey = PartitionKey("partition_id_1")
        val recordDto =
            RecordDTO(
                fields = emptyMap(),
                partitionKey = partitionKey,
                sizeBytes = 100L,
                emittedAtMs = emittedAtMs
            )

        val streamMock =
            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)

        val mockAggregate1 = mockk<Aggregate>()
        val mockPartitionHistogram1 = mockk<PartitionHistogram>()
        val aggregateEntry1 =
            mockk<AggregateEntry> {
                every { value } returns mockAggregate1
                every { partitionHistogram } returns mockPartitionHistogram1
            }

        val mockAggregate2 = mockk<Aggregate>()
        val mockPartitionHistogram2 = mockk<PartitionHistogram>()
        val aggregateEntry2 =
            mockk<AggregateEntry> {
                every { value } returns mockAggregate2
                every { partitionHistogram } returns mockPartitionHistogram2
            }

        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
        coEvery { store.removeNextComplete(emittedAtMs) } returns
            aggregateEntry1 andThen
            aggregateEntry2 andThen
            null

        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)

        stage.apply(input, outputFlow)

        coVerify { store.acceptFor(streamDescriptor, recordDto) }
        coVerify { store.removeNextComplete(emittedAtMs) }
        coVerify(exactly = 1) {
            outputFlow.emit(
                DataFlowStageIO(
                    aggregate = mockAggregate1,
                    partitionHistogram = mockPartitionHistogram1
                )
            )
        }
        coVerify(exactly = 1) {
            outputFlow.emit(
                DataFlowStageIO(
                    aggregate = mockAggregate2,
                    partitionHistogram = mockPartitionHistogram2
                )
            )
        }
    }
}
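The three tests cover zero, one, and many completed aggregates, which implies an accept-then-drain loop. A sketch of `apply` consistent with all three (illustrative, not the actual stage):

```kotlin
import kotlinx.coroutines.flow.FlowCollector

// Sketch only — the accept/drain loop the three tests above describe.
class AggregateStageSketch(private val store: AggregateStore) {
    suspend fun apply(input: DataFlowStageIO, output: FlowCollector<DataFlowStageIO>) {
        val record = input.munged!!
        // Fold the record into the open aggregate for its stream.
        store.acceptFor(input.raw!!.stream.mappedDescriptor, record)
        // Drain every aggregate complete as of the record's timestamp:
        // zero, one, or many emissions, matching the three test cases.
        var entry = store.removeNextComplete(record.emittedAtMs)
        while (entry != null) {
            output.emit(
                DataFlowStageIO(
                    aggregate = entry.value,
                    partitionHistogram = entry.partitionHistogram,
                )
            )
            entry = store.removeNextComplete(record.emittedAtMs)
        }
    }
}
```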
@@ -0,0 +1,43 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.stages

import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.mockk.coVerify
import io.mockk.mockk
import kotlin.test.assertSame
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

class FlushStageTest {
    private val flushStage = FlushStage()

    @Test
    fun `given input with an aggregate, when apply is called, then it flushes the aggregate`() =
        runTest {
            // Given
            val mockAggregate = mockk<Aggregate>(relaxed = true)
            val input = DataFlowStageIO(aggregate = mockAggregate)

            // When
            val result = flushStage.apply(input)

            // Then
            coVerify(exactly = 1) { mockAggregate.flush() }
            assertSame(input, result, "The output should be the same instance as the input")
        }

    @Test
    fun `given input with a null aggregate, when apply is called, then it throws NullPointerException`() =
        runTest {
            // Given
            val input = DataFlowStageIO(aggregate = null)

            // When & Then
            assertThrows<NullPointerException> { flushStage.apply(input) }
        }
}
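Both tests together pin down a two-line stage: flush the aggregate and pass the IO through unchanged, failing fast when the aggregate is missing. A sketch (illustrative):

```kotlin
// Sketch only — consistent with both FlushStage tests above.
class FlushStageSketch {
    suspend fun apply(input: DataFlowStageIO): DataFlowStageIO {
        // A missing aggregate is a programming error; !! surfaces it as the NPE the test expects.
        input.aggregate!!.flush()
        return input
    }
}
```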
@@ -0,0 +1,112 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.stages

import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.NamespaceMapper
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.DataMunger
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.message.DestinationRecordJsonSource
import io.airbyte.cdk.load.message.DestinationRecordRaw
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.verify
import java.util.UUID
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Assertions.assertSame
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith

@ExtendWith(MockKExtension::class)
class ParseStageTest {
    @MockK private lateinit var munger: DataMunger

    private lateinit var stage: ParseStage
    private lateinit var stream: DestinationStream
    private lateinit var rawRecord: DestinationRecordRaw

    @BeforeEach
    fun setup() {
        stage = ParseStage(munger)
        stream =
            DestinationStream(
                unmappedNamespace = "test-namespace",
                unmappedName = "test-stream",
                importType = Append,
                schema = ObjectType(linkedMapOf()),
                generationId = 1L,
                minimumGenerationId = 1L,
                syncId = 1L,
                namespaceMapper = NamespaceMapper()
            )
        rawRecord =
            DestinationRecordRaw(
                stream = stream,
                rawData =
                    DestinationRecordJsonSource(
                        AirbyteMessage().withRecord(AirbyteRecordMessage().withEmittedAt(12345L))
                    ),
                serializedSizeBytes = 100,
                airbyteRawId = UUID.randomUUID(),
            )
    }

    @Test
    fun `given valid input, when apply is called, then it should munge the raw record and update the IO object`() =
        runTest {
            // Given
            val input =
                DataFlowStageIO(raw = rawRecord, partitionKey = PartitionKey("test-partition"))
            val transformedFields =
                mapOf("field1" to StringValue("value1"), "field2" to StringValue("42"))
            every { munger.transformForDest(rawRecord) } returns transformedFields

            // When
            val result = stage.apply(input)

            // Then
            assertSame(input, result, "The stage should return the same input object instance")
            assertNotNull(result.munged)

            val mungedRecord = result.munged!!
            assertEquals(transformedFields, mungedRecord.fields)
            assertEquals(PartitionKey("test-partition"), mungedRecord.partitionKey)
            assertEquals(100L, mungedRecord.sizeBytes)
            assertEquals(12345L, mungedRecord.emittedAtMs)

            verify(exactly = 1) { munger.transformForDest(rawRecord) }
        }

    @Test
    fun `given input with null raw record, when apply is called, then it should throw NullPointerException`() {
        // Given
        val input = DataFlowStageIO(raw = null, partitionKey = PartitionKey("test-partition"))

        // When & Then
        assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
    }

    @Test
    fun `given input with null partition key, when apply is called, then it should throw NullPointerException`() {
        // Given
        val input = DataFlowStageIO(raw = rawRecord, partitionKey = null)
        val transformedFields = mapOf("field1" to StringValue("value1"))
        every { munger.transformForDest(rawRecord) } returns transformedFields

        // When & Then
        assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
    }
}
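The assertions above determine most of `ParseStage.apply`: it delegates field transformation to the `DataMunger`, copies the partition key, size, and emitted-at timestamp onto a `RecordDTO`, and mutates the same IO instance. A sketch under those assumptions — illustrative, and the timestamp accessor on the raw record is a guess:

```kotlin
// Sketch only — field sources inferred from ParseStageTest's assertions.
class ParseStageSketch(private val munger: DataMunger) {
    suspend fun apply(input: DataFlowStageIO): DataFlowStageIO {
        val raw = input.raw!! // null raw record -> NPE, as tested
        input.munged =
            RecordDTO(
                fields = munger.transformForDest(raw),
                partitionKey = input.partitionKey!!, // null partition key -> NPE, as tested
                sizeBytes = raw.serializedSizeBytes, // 100 -> 100L in the test
                emittedAtMs = raw.emittedAtMs, // hypothetical accessor for the 12345L timestamp
            )
        return input
    }
}
```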
@@ -0,0 +1,47 @@
/*
 * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
 */

package io.airbyte.cdk.load.dataflow.stages

import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.mockk.mockk
import io.mockk.verify
import kotlin.test.assertFailsWith
import kotlin.test.assertSame
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class StateStageTest {
    private val stateStore: StateHistogramStore = mockk(relaxed = true)
    private lateinit var stateStage: StateStage

    @BeforeEach
    fun setup() {
        stateStage = StateStage(stateStore)
    }

    @Test
    fun `apply happy path`() = runTest {
        // Arrange
        val histogram = mockk<PartitionHistogram>()
        val input = DataFlowStageIO(partitionHistogram = histogram)

        // Act
        val result = stateStage.apply(input)

        // Assert
        verify(exactly = 1) { stateStore.acceptFlushedCounts(histogram) }
        assertSame(input, result, "The output should be the same as the input object")
    }

    @Test
    fun `apply with null partition histogram throws exception`() = runTest {
        val input = DataFlowStageIO(partitionHistogram = null)
        assertFailsWith<NullPointerException> { stateStage.apply(input) }
        verify(exactly = 0) { stateStore.acceptFlushedCounts(any()) }
    }
}
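Like the flush stage, this is a pass-through with one side effect. A sketch consistent with both tests (illustrative):

```kotlin
// Sketch only — consistent with both StateStage tests above.
class StateStageSketch(private val stateStore: StateHistogramStore) {
    suspend fun apply(input: DataFlowStageIO): DataFlowStageIO {
        // Record flushed counts; a null histogram surfaces as the NPE the test expects.
        stateStore.acceptFlushedCounts(input.partitionHistogram!!)
        return input
    }
}
```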