
Add tests to the dataflow pipeline (#64887)

Benoit Moriceau
2025-08-13 14:30:56 -07:00
committed by GitHub
parent b1094cca43
commit f361d660c2
12 changed files with 801 additions and 0 deletions

View File

@@ -18,5 +18,6 @@ When module is mentioned, it refers to this [folder](.)
- update the [changelog.md](changelog.md), adding a description for the new version
- If the version has already been bumped on the local branch, we shouldn't bump it again
- We format our code by running the command `pre-commit run --all-files` from the root of the project
- When writing a test function, the name must not start with `test`; the prefix is redundant (a short illustrative sketch follows this excerpt)
When you are done with a change, always run the formatter and update the changelog if needed.
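
A minimal sketch of the test-naming convention above, assuming only JUnit 5 on the classpath; the class and assertion are illustrative and not part of this commit. The test name is a backtick-quoted description rather than a redundant `test` prefix:

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class NamingConventionExampleTest {
    // Descriptive backtick-quoted name; no redundant `test` prefix.
    @Test
    fun `sums two numbers correctly`() {
        assertEquals(5, 2 + 3)
    }
}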

View File

@@ -12,6 +12,8 @@
## Version 0.1.14
**Load CDK**
* **Changed:** Add agent.
## Version 0.1.13

View File

@@ -0,0 +1,77 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.dataflow.stages.AggregateStage
import io.mockk.coEvery
import io.mockk.coVerifySequence
import io.mockk.mockk
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
@ExperimentalCoroutinesApi
class DataFlowPipelineTest {
private val parse = mockk<DataFlowStage>()
private val aggregate = mockk<AggregateStage>()
private val flush = mockk<DataFlowStage>()
private val state = mockk<DataFlowStage>()
private val startHandler = mockk<PipelineStartHandler>()
private val completionHandler = mockk<PipelineCompletionHandler>()
private val memoryAndParallelismConfig =
MemoryAndParallelismConfig(
maxOpenAggregates = 2,
maxBufferedAggregates = 2,
)
@Test
fun `pipeline execution flow`() = runTest {
// Given
val initialIO = mockk<DataFlowStageIO>()
val input = flowOf(initialIO)
val pipeline =
DataFlowPipeline(
input,
parse,
aggregate,
flush,
state,
startHandler,
completionHandler,
memoryAndParallelismConfig
)
val parsedIO = mockk<DataFlowStageIO>()
val aggregatedIO = mockk<DataFlowStageIO>()
val flushedIO = mockk<DataFlowStageIO>()
val stateIO = mockk<DataFlowStageIO>()
coEvery { startHandler.run() } returns Unit
coEvery { parse.apply(initialIO) } returns parsedIO
coEvery { aggregate.apply(parsedIO, any()) } coAnswers
{
val collector = secondArg<kotlinx.coroutines.flow.FlowCollector<DataFlowStageIO>>()
collector.emit(aggregatedIO)
}
coEvery { flush.apply(aggregatedIO) } returns flushedIO
coEvery { state.apply(flushedIO) } returns stateIO
coEvery { completionHandler.apply(null) } returns Unit
// When
pipeline.run()
// Then
coVerifySequence {
startHandler.run()
parse.apply(initialIO)
aggregate.apply(parsedIO, any())
flush.apply(aggregatedIO)
state.apply(flushedIO)
completionHandler.apply(null)
}
}
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
import io.airbyte.cdk.load.write.DestinationWriter
import io.airbyte.cdk.load.write.StreamLoader
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.mockk
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
class DestinationLifecycleTest {
private val destinationInitializer: DestinationWriter = mockk(relaxed = true)
private val destinationCatalog: DestinationCatalog = mockk(relaxed = true)
private val pipeline: DataFlowPipeline = mockk(relaxed = true)
private val memoryAndParallelismConfig: MemoryAndParallelismConfig =
MemoryAndParallelismConfig(maxOpenAggregates = 1, maxBufferedAggregates = 1)
private val destinationLifecycle =
DestinationLifecycle(
destinationInitializer,
destinationCatalog,
pipeline,
memoryAndParallelismConfig,
)
@Test
fun `should execute full lifecycle in correct order`() = runTest {
// Given
val streamLoader1 = mockk<StreamLoader>(relaxed = true)
val streamLoader2 = mockk<StreamLoader>(relaxed = true)
val stream1 = mockk<DestinationStream>(relaxed = true)
val stream2 = mockk<DestinationStream>(relaxed = true)
coEvery { destinationCatalog.streams } returns listOf(stream1, stream2)
coEvery { destinationInitializer.createStreamLoader(stream1) } returns streamLoader1
coEvery { destinationInitializer.createStreamLoader(stream2) } returns streamLoader2
// When
destinationLifecycle.run()
// Then
coVerify(exactly = 1) { destinationInitializer.setup() }
coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream1) }
coVerify(exactly = 1) { streamLoader1.start() }
coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream2) }
coVerify(exactly = 1) { streamLoader2.start() }
coVerify(exactly = 1) { pipeline.run() }
coVerify(exactly = 1) { streamLoader1.close(true) }
coVerify(exactly = 1) { streamLoader2.close(true) }
coVerify(exactly = 1) { destinationInitializer.teardown() }
}
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.aggregate
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
class SizeTriggerTest {
@Test
fun `trigger not fired when under threshold`() {
val trigger = SizeTrigger(100)
trigger.increment(50)
assertFalse(trigger.isComplete())
assertEquals(50L, trigger.watermark())
}
@Test
fun `trigger fired when at threshold`() {
val trigger = SizeTrigger(100)
trigger.increment(100)
assertTrue(trigger.isComplete())
assertEquals(100L, trigger.watermark())
}
@Test
fun `trigger fired when over threshold`() {
val trigger = SizeTrigger(100)
trigger.increment(150)
assertTrue(trigger.isComplete())
assertEquals(150L, trigger.watermark())
}
@Test
fun `multiple increments`() {
val trigger = SizeTrigger(100)
trigger.increment(50)
assertFalse(trigger.isComplete())
assertEquals(50L, trigger.watermark())
trigger.increment(49)
assertFalse(trigger.isComplete())
assertEquals(99L, trigger.watermark())
trigger.increment(1)
assertTrue(trigger.isComplete())
assertEquals(100L, trigger.watermark())
}
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.aggregate
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
class TimeTriggerTest {
@Test
fun `time trigger logic`() {
val triggerSize = 1000L // 1 second
val trigger = TimeTrigger(triggerSize)
val startTime = System.currentTimeMillis()
// Update the trigger with the start time
trigger.update(startTime)
// Immediately after update, it should not be complete
assertFalse(trigger.isComplete(startTime))
// Just before the trigger time, it should not be complete
assertFalse(trigger.isComplete(startTime + triggerSize - 1))
// Exactly at the trigger time, it should be complete
assertTrue(trigger.isComplete(startTime + triggerSize))
// After the trigger time, it should be complete
assertTrue(trigger.isComplete(startTime + triggerSize + 1))
}
@Test
fun `updating timestamp resets the trigger`() {
val triggerSize = 1000L
val trigger = TimeTrigger(triggerSize)
val startTime = System.currentTimeMillis()
trigger.update(startTime)
// It should be complete after triggerSize
assertTrue(trigger.isComplete(startTime + triggerSize))
// Update the timestamp to a later time
val newStartTime = startTime + 500L
trigger.update(newStartTime)
// Now it should not be complete relative to the new start time
assertFalse(trigger.isComplete(newStartTime + triggerSize - 1))
assertTrue(trigger.isComplete(newStartTime + triggerSize))
}
}

View File

@@ -0,0 +1,99 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.input
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.state.StateKeyClient
import io.airbyte.cdk.load.dataflow.state.StateStore
import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordSource
import io.airbyte.cdk.load.message.Undefined
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.util.UUID
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
class DataFlowPipelineInputFlowTest {
@Test
fun `checkpoint message`() = runBlocking {
// Given
val checkpointMessage = mockk<CheckpointMessage>()
val inputFlow = flowOf<DestinationMessage>(checkpointMessage)
val stateStore = mockk<StateStore>(relaxed = true)
val stateKeyClient = mockk<StateKeyClient>()
val dataFlowPipelineInputFlow =
DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
// When
val result = dataFlowPipelineInputFlow.toList()
// Then
coVerify(exactly = 1) { stateStore.accept(checkpointMessage) }
assertEquals(0, result.size)
}
@Test
fun `destination record`() = runBlocking {
// Given
val stream = mockk<DestinationStream>()
every { stream.schema } returns mockk()
every { stream.airbyteValueProxyFieldAccessors } returns emptyArray()
val message = mockk<DestinationRecordSource>()
every { message.fileReference } returns null
val destinationRecord =
DestinationRecord(
stream,
message,
1L,
null,
UUID.randomUUID(),
)
val inputFlow = flowOf<DestinationMessage>(destinationRecord)
val stateStore = mockk<StateStore>()
val stateKeyClient = mockk<StateKeyClient>()
val partitionKey = PartitionKey("partitionKey")
every { stateKeyClient.getPartitionKey(any()) } returns partitionKey
val dataFlowPipelineInputFlow =
DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
// When
val result = dataFlowPipelineInputFlow.toList()
// Then
assertEquals(1, result.size)
val expected =
DataFlowStageIO(
raw = destinationRecord.asDestinationRecordRaw(),
partitionKey = partitionKey,
)
assertEquals(expected, result[0])
}
@Test
fun `other message`() = runBlocking {
// Given
val undefinedMessage = Undefined
val inputFlow = flowOf<DestinationMessage>(undefinedMessage)
val stateStore = mockk<StateStore>()
val stateKeyClient = mockk<StateKeyClient>()
val dataFlowPipelineInputFlow =
DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
// When
val result = dataFlowPipelineInputFlow.toList()
// Then
assertEquals(0, result.size)
}
}

View File

@@ -0,0 +1,89 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.input
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import java.io.ByteArrayInputStream
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
class DestinationMessageInputFlowTest {
@Test
fun `should deserialize and emit messages from input stream`() = runTest {
// Given
val line1 = "message1"
val line2 = "message2"
val inputStream = ByteArrayInputStream("$line1\n$line2\n".toByteArray())
val deserializer = mockk<ProtocolMessageDeserializer>()
val message1 = mockk<DestinationMessage>()
val message2 = mockk<DestinationMessage>()
every { deserializer.deserialize(line1) } returns message1
every { deserializer.deserialize(line2) } returns message2
val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
// When
inputFlow.collect(collector)
// Then
verify(exactly = 1) { deserializer.deserialize(line1) }
verify(exactly = 1) { deserializer.deserialize(line2) }
coVerify(exactly = 1) { collector.emit(message1) }
coVerify(exactly = 1) { collector.emit(message2) }
}
@Test
fun `should ignore empty lines`() = runTest {
// Given
val line1 = "message1"
val emptyLine = ""
val line2 = "message2"
val inputStream = ByteArrayInputStream("$line1\n$emptyLine\n$line2\n".toByteArray())
val deserializer = mockk<ProtocolMessageDeserializer>()
val message1 = mockk<DestinationMessage>()
val message2 = mockk<DestinationMessage>()
every { deserializer.deserialize(line1) } returns message1
every { deserializer.deserialize(line2) } returns message2
val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
// When
inputFlow.collect(collector)
// Then
verify(exactly = 1) { deserializer.deserialize(line1) }
verify(exactly = 0) { deserializer.deserialize(emptyLine) } // Ensure empty line is ignored
verify(exactly = 1) { deserializer.deserialize(line2) }
coVerify(exactly = 1) { collector.emit(message1) }
coVerify(exactly = 1) { collector.emit(message2) }
}
@Test
fun `should not emit if no record is present`() = runTest {
// Given
val inputStream = ByteArrayInputStream("".toByteArray())
val deserializer = mockk<ProtocolMessageDeserializer>()
val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
// When
inputFlow.collect(collector)
// Then
coVerify(exactly = 0) { collector.emit(any()) } // Ensure no message is emitted
}
}

View File

@@ -0,0 +1,166 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.stages
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.airbyte.cdk.load.dataflow.aggregate.AggregateEntry
import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.dataflow.transform.RecordDTO
import io.airbyte.cdk.load.message.DestinationRecordRaw
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
class AggregateStageTest {
private val store: AggregateStore = mockk(relaxed = true)
private val stage = AggregateStage(store)
@Test
fun `test aggregation`() = runTest {
val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
val emittedAtMs = 1L
val partitionKey = PartitionKey("partition_id_1")
val recordDto =
RecordDTO(
fields = emptyMap(),
partitionKey = partitionKey,
sizeBytes = 100L,
emittedAtMs = emittedAtMs
)
val streamMock =
mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
val mockAggregate = mockk<Aggregate>()
val mockPartitionHistogram = mockk<PartitionHistogram>()
val aggregateEntry =
mockk<AggregateEntry> {
every { value } returns mockAggregate
every { partitionHistogram } returns mockPartitionHistogram
}
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
coEvery { store.removeNextComplete(emittedAtMs) } returns aggregateEntry andThen null
val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
stage.apply(input, outputFlow)
coVerify { store.acceptFor(streamDescriptor, recordDto) }
coVerify { store.removeNextComplete(emittedAtMs) }
coVerify {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate,
partitionHistogram = mockPartitionHistogram
)
)
}
}
@Test
fun `should not emit if no complete aggregate is ready`() = runTest {
val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
val emittedAtMs = 1L
val partitionKey = PartitionKey("partition_id_1")
val recordDto =
RecordDTO(
fields = emptyMap(),
partitionKey = partitionKey,
sizeBytes = 100L,
emittedAtMs = emittedAtMs
)
val streamMock =
mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
coEvery { store.removeNextComplete(emittedAtMs) } returns null // No aggregate ready
val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
stage.apply(input, outputFlow)
coVerify { store.acceptFor(streamDescriptor, recordDto) }
coVerify { store.removeNextComplete(emittedAtMs) }
coVerify(exactly = 0) { outputFlow.emit(any()) } // Verify no emission
}
@Test
fun `should emit multiple times if multiple complete aggregates are ready`() = runTest {
val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
val emittedAtMs = 1L
val partitionKey = PartitionKey("partition_id_1")
val recordDto =
RecordDTO(
fields = emptyMap(),
partitionKey = partitionKey,
sizeBytes = 100L,
emittedAtMs = emittedAtMs
)
val streamMock =
mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
val mockAggregate1 = mockk<Aggregate>()
val mockPartitionHistogram1 = mockk<PartitionHistogram>()
val aggregateEntry1 =
mockk<AggregateEntry> {
every { value } returns mockAggregate1
every { partitionHistogram } returns mockPartitionHistogram1
}
val mockAggregate2 = mockk<Aggregate>()
val mockPartitionHistogram2 = mockk<PartitionHistogram>()
val aggregateEntry2 =
mockk<AggregateEntry> {
every { value } returns mockAggregate2
every { partitionHistogram } returns mockPartitionHistogram2
}
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
coEvery { store.removeNextComplete(emittedAtMs) } returns
aggregateEntry1 andThen
aggregateEntry2 andThen
null
val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
stage.apply(input, outputFlow)
coVerify { store.acceptFor(streamDescriptor, recordDto) }
coVerify { store.removeNextComplete(emittedAtMs) }
coVerify(exactly = 1) {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate1,
partitionHistogram = mockPartitionHistogram1
)
)
}
coVerify(exactly = 1) {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate2,
partitionHistogram = mockPartitionHistogram2
)
)
}
}
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.stages
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
import io.mockk.coVerify
import io.mockk.mockk
import kotlin.test.assertSame
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
class FlushStageTest {
private val flushStage = FlushStage()
@Test
fun `given input with an aggregate, when apply is called, then it flushes the aggregate`() =
runTest {
// Given
val mockAggregate = mockk<Aggregate>(relaxed = true)
val input = DataFlowStageIO(aggregate = mockAggregate)
// When
val result = flushStage.apply(input)
// Then
coVerify(exactly = 1) { mockAggregate.flush() }
assertSame(input, result, "The output should be the same instance as the input")
}
@Test
fun `given input with a null aggregate, when apply is called, then it throws NullPointerException`() =
runTest {
// Given
val input = DataFlowStageIO(aggregate = null)
// When & Then
assertThrows<NullPointerException> { flushStage.apply(input) }
}
}

View File

@@ -0,0 +1,112 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.stages
import io.airbyte.cdk.load.command.Append
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.NamespaceMapper
import io.airbyte.cdk.load.data.StringValue
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.DataMunger
import io.airbyte.cdk.load.dataflow.state.PartitionKey
import io.airbyte.cdk.load.message.DestinationRecordJsonSource
import io.airbyte.cdk.load.message.DestinationRecordRaw
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.verify
import java.util.UUID
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Assertions.assertSame
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith
@ExtendWith(MockKExtension::class)
class ParseStageTest {
@MockK private lateinit var munger: DataMunger
private lateinit var stage: ParseStage
private lateinit var stream: DestinationStream
private lateinit var rawRecord: DestinationRecordRaw
@BeforeEach
fun setup() {
stage = ParseStage(munger)
stream =
DestinationStream(
unmappedNamespace = "test-namespace",
unmappedName = "test-stream",
importType = Append,
schema = io.airbyte.cdk.load.data.ObjectType(linkedMapOf()),
generationId = 1L,
minimumGenerationId = 1L,
syncId = 1L,
namespaceMapper = NamespaceMapper()
)
rawRecord =
DestinationRecordRaw(
stream = stream,
rawData =
DestinationRecordJsonSource(
AirbyteMessage().withRecord(AirbyteRecordMessage().withEmittedAt(12345L))
),
serializedSizeBytes = 100,
airbyteRawId = UUID.randomUUID(),
)
}
@Test
fun `given valid input, when apply is called, then it should munge the raw record and update the IO object`() =
runTest {
// Given
val input =
DataFlowStageIO(raw = rawRecord, partitionKey = PartitionKey("test-partition"))
val transformedFields =
mapOf("field1" to StringValue("value1"), "field2" to StringValue("42"))
every { munger.transformForDest(rawRecord) } returns transformedFields
// When
val result = stage.apply(input)
// Then
assertSame(input, result, "The stage should return the same input object instance")
assertNotNull(result.munged)
val mungedRecord = result.munged!!
assertEquals(transformedFields, mungedRecord.fields)
assertEquals(PartitionKey("test-partition"), mungedRecord.partitionKey)
assertEquals(100L, mungedRecord.sizeBytes)
assertEquals(12345L, mungedRecord.emittedAtMs)
verify(exactly = 1) { munger.transformForDest(rawRecord) }
}
@Test
fun `given input with null raw record, when apply is called, then it should throw NullPointerException`() {
// Given
val input = DataFlowStageIO(raw = null, partitionKey = PartitionKey("test-partition"))
// When & Then
assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
}
@Test
fun `given input with null partition key, when apply is called, then it should throw NullPointerException`() {
// Given
val input = DataFlowStageIO(raw = rawRecord, partitionKey = null)
val transformedFields = mapOf("field1" to StringValue("value1"))
every { munger.transformForDest(rawRecord) } returns transformedFields
// When & Then
assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
}
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright (c) 2025 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.load.dataflow.stages
import io.airbyte.cdk.load.dataflow.DataFlowStageIO
import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
import io.mockk.mockk
import io.mockk.verify
import kotlin.test.assertFailsWith
import kotlin.test.assertSame
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
class StateStageTest {
private val stateStore: StateHistogramStore = mockk(relaxed = true)
private lateinit var stateStage: StateStage
@BeforeEach
fun setup() {
stateStage = StateStage(stateStore)
}
@Test
fun `apply happy path`() = runTest {
// Arrange
val histogram = mockk<PartitionHistogram>()
val input = DataFlowStageIO(partitionHistogram = histogram)
// Act
val result = stateStage.apply(input)
// Assert
verify(exactly = 1) { stateStore.acceptFlushedCounts(histogram) }
assertSame(input, result, "The output should be the same as the input object")
}
@Test
fun `apply with null partition histogram throws exception`() = runTest {
val input = DataFlowStageIO(partitionHistogram = null)
assertFailsWith<NullPointerException> { stateStage.apply(input) }
verify(exactly = 0) { stateStore.acceptFlushedCounts(any()) }
}
}