diff --git a/airbyte-cdk/bulk/AGENT.md b/airbyte-cdk/bulk/AGENT.md
index 62d743698ec..d08af0ad2e1 100644
--- a/airbyte-cdk/bulk/AGENT.md
+++ b/airbyte-cdk/bulk/AGENT.md
@@ -18,5 +18,6 @@ When module is mentioned, it refers to this [folder](.)
 - update the [changelog.md](changelog.md) where we need to add a description for the new version
   - If the version has already been bumped on the local branch, we shouldn't bump it again
 - We format our code by running the command `pre-commit run --all-files` from the root of the project
+- when writing a test function, the function name must not start with `test`, because the prefix is redundant
 
 When you are done with a change always run the format and update the changelog if needed.
\ No newline at end of file
diff --git a/airbyte-cdk/bulk/changelog.md b/airbyte-cdk/bulk/changelog.md
index c6ea5e4bbb9..1628547a32f 100644
--- a/airbyte-cdk/bulk/changelog.md
+++ b/airbyte-cdk/bulk/changelog.md
@@ -12,6 +12,8 @@
 
 ## Version 0.1.14
 
+**Load CDK**
+
 * **Changed:** Add agent.
 
 ## Version 0.1.13
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DataFlowPipelineTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DataFlowPipelineTest.kt
new file mode 100644
index 00000000000..bd3c58abb48
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DataFlowPipelineTest.kt
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow
+
+import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
+import io.airbyte.cdk.load.dataflow.stages.AggregateStage
+import io.mockk.coEvery
+import io.mockk.coVerifySequence
+import io.mockk.mockk
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.flow.FlowCollector
+import kotlinx.coroutines.flow.flowOf
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Test
+
+@ExperimentalCoroutinesApi
+class DataFlowPipelineTest {
+    private val parse = mockk<DataFlowStage>()
+    private val aggregate = mockk<AggregateStage>()
+    private val flush = mockk<DataFlowStage>()
+    private val state = mockk<DataFlowStage>()
+    private val startHandler = mockk<PipelineStartHandler>()
+    private val completionHandler = mockk<PipelineCompletionHandler>()
+    private val memoryAndParallelismConfig =
+        MemoryAndParallelismConfig(
+            maxOpenAggregates = 2,
+            maxBufferedAggregates = 2,
+        )
+
+    @Test
+    fun `pipeline execution flow`() = runTest {
+        // Given
+        val initialIO = mockk<DataFlowStageIO>()
+        val input = flowOf(initialIO)
+        val pipeline =
+            DataFlowPipeline(
+                input,
+                parse,
+                aggregate,
+                flush,
+                state,
+                startHandler,
+                completionHandler,
+                memoryAndParallelismConfig
+            )
+
+        val parsedIO = mockk<DataFlowStageIO>()
+        val aggregatedIO = mockk<DataFlowStageIO>()
+        val flushedIO = mockk<DataFlowStageIO>()
+        val stateIO = mockk<DataFlowStageIO>()
+
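+        // The aggregate stage emits downstream through a FlowCollector rather than a
+        // return value, so its stub relays the aggregated IO through the captured collector.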
+        coEvery { startHandler.run() } returns Unit
+        coEvery { parse.apply(initialIO) } returns parsedIO
+        coEvery { aggregate.apply(parsedIO, any()) } coAnswers
+            {
+                val collector = secondArg<FlowCollector<DataFlowStageIO>>()
+                collector.emit(aggregatedIO)
+            }
+        coEvery { flush.apply(aggregatedIO) } returns flushedIO
+        coEvery { state.apply(flushedIO) } returns stateIO
+        coEvery { completionHandler.apply(null) } returns Unit
+
+        // When
+        pipeline.run()
+
+        // Then
+        coVerifySequence {
+            startHandler.run()
+            parse.apply(initialIO)
+            aggregate.apply(parsedIO, any())
+            flush.apply(aggregatedIO)
+            state.apply(flushedIO)
+            completionHandler.apply(null)
+        }
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DestinationLifecycleTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DestinationLifecycleTest.kt
new file mode 100644
index 00000000000..af91140d9dc
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/DestinationLifecycleTest.kt
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow
+
+import io.airbyte.cdk.load.command.DestinationCatalog
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.dataflow.config.MemoryAndParallelismConfig
+import io.airbyte.cdk.load.write.DestinationWriter
+import io.airbyte.cdk.load.write.StreamLoader
+import io.mockk.coEvery
+import io.mockk.coVerify
+import io.mockk.mockk
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Test
+
+class DestinationLifecycleTest {
+
+    private val destinationInitializer: DestinationWriter = mockk(relaxed = true)
+    private val destinationCatalog: DestinationCatalog = mockk(relaxed = true)
+    private val pipeline: DataFlowPipeline = mockk(relaxed = true)
+    private val memoryAndParallelismConfig: MemoryAndParallelismConfig =
+        MemoryAndParallelismConfig(maxOpenAggregates = 1, maxBufferedAggregates = 1)
+
+    private val destinationLifecycle =
+        DestinationLifecycle(
+            destinationInitializer,
+            destinationCatalog,
+            pipeline,
+            memoryAndParallelismConfig,
+        )
+
+    @Test
+    fun `should execute full lifecycle in correct order`() = runTest {
+        // Given
+        val streamLoader1 = mockk<StreamLoader>(relaxed = true)
+        val streamLoader2 = mockk<StreamLoader>(relaxed = true)
+        val stream1 = mockk<DestinationStream>(relaxed = true)
+        val stream2 = mockk<DestinationStream>(relaxed = true)
+
+        coEvery { destinationCatalog.streams } returns listOf(stream1, stream2)
+        coEvery { destinationInitializer.createStreamLoader(stream1) } returns streamLoader1
+        coEvery { destinationInitializer.createStreamLoader(stream2) } returns streamLoader2
+
+        // When
+        destinationLifecycle.run()
+
+        // Then
+        coVerify(exactly = 1) { destinationInitializer.setup() }
+        coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream1) }
+        coVerify(exactly = 1) { streamLoader1.start() }
+        coVerify(exactly = 1) { destinationInitializer.createStreamLoader(stream2) }
+        coVerify(exactly = 1) { streamLoader2.start() }
+        coVerify(exactly = 1) { pipeline.run() }
+        coVerify(exactly = 1) { streamLoader1.close(true) }
+        coVerify(exactly = 1) { streamLoader2.close(true) }
+        coVerify(exactly = 1) { destinationInitializer.teardown() }
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/SizeTriggerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/SizeTriggerTest.kt
new file mode 100644
index 00000000000..97e0726b596
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/SizeTriggerTest.kt
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.aggregate
+
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertFalse
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+
+class SizeTriggerTest {
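+    // SizeTrigger accumulates byte counts via increment(); isComplete() fires once the
+    // running total (the watermark) reaches the configured threshold.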
+    @Test
+    fun `trigger not fired when under threshold`() {
+        val trigger = SizeTrigger(100)
+        trigger.increment(50)
+        assertFalse(trigger.isComplete())
+        assertEquals(50L, trigger.watermark())
+    }
+
+    @Test
+    fun `trigger fired when at threshold`() {
+        val trigger = SizeTrigger(100)
+        trigger.increment(100)
+        assertTrue(trigger.isComplete())
+        assertEquals(100L, trigger.watermark())
+    }
+
+    @Test
+    fun `trigger fired when over threshold`() {
+        val trigger = SizeTrigger(100)
+        trigger.increment(150)
+        assertTrue(trigger.isComplete())
+        assertEquals(150L, trigger.watermark())
+    }
+
+    @Test
+    fun `multiple increments`() {
+        val trigger = SizeTrigger(100)
+        trigger.increment(50)
+        assertFalse(trigger.isComplete())
+        assertEquals(50L, trigger.watermark())
+        trigger.increment(49)
+        assertFalse(trigger.isComplete())
+        assertEquals(99L, trigger.watermark())
+        trigger.increment(1)
+        assertTrue(trigger.isComplete())
+        assertEquals(100L, trigger.watermark())
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/TimeTriggerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/TimeTriggerTest.kt
new file mode 100644
index 00000000000..113ef9e9362
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/aggregate/TimeTriggerTest.kt
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.aggregate
+
+import org.junit.jupiter.api.Assertions.assertFalse
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+
+class TimeTriggerTest {
+
+    @Test
+    fun `time trigger logic`() {
+        val triggerSize = 1000L // 1 second
+        val trigger = TimeTrigger(triggerSize)
+
+        val startTime = System.currentTimeMillis()
+
+        // Update the trigger with the start time
+        trigger.update(startTime)
+
+        // Immediately after update, it should not be complete
+        assertFalse(trigger.isComplete(startTime))
+
+        // Just before the trigger time, it should not be complete
+        assertFalse(trigger.isComplete(startTime + triggerSize - 1))
+
+        // Exactly at the trigger time, it should be complete
+        assertTrue(trigger.isComplete(startTime + triggerSize))
+
+        // After the trigger time, it should be complete
+        assertTrue(trigger.isComplete(startTime + triggerSize + 1))
+    }
+
+    @Test
+    fun `updating timestamp resets the trigger`() {
+        val triggerSize = 1000L
+        val trigger = TimeTrigger(triggerSize)
+
+        val startTime = System.currentTimeMillis()
+        trigger.update(startTime)
+
+        // It should be complete after triggerSize
+        assertTrue(trigger.isComplete(startTime + triggerSize))
+
+        // Update the timestamp to a later time
+        val newStartTime = startTime + 500L
+        trigger.update(newStartTime)
+
+        // Now it should not be complete relative to the new start time
+        assertFalse(trigger.isComplete(newStartTime + triggerSize - 1))
+        assertTrue(trigger.isComplete(newStartTime + triggerSize))
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DataFlowPipelineInputFlowTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DataFlowPipelineInputFlowTest.kt
new file mode 100644
index 00000000000..5fb3cf9364b
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DataFlowPipelineInputFlowTest.kt
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.input
+
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.dataflow.DataFlowStageIO
+import io.airbyte.cdk.load.dataflow.state.PartitionKey
+import io.airbyte.cdk.load.dataflow.state.StateKeyClient
+import io.airbyte.cdk.load.dataflow.state.StateStore
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationMessage
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordSource
+import io.airbyte.cdk.load.message.Undefined
+import io.mockk.coVerify
+import io.mockk.every
+import io.mockk.mockk
+import java.util.UUID
+import kotlinx.coroutines.flow.flowOf
+import kotlinx.coroutines.flow.toList
+import kotlinx.coroutines.runBlocking
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+
+class DataFlowPipelineInputFlowTest {
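+    // The input flow routes protocol messages by type: checkpoints are handed to the
+    // StateStore, records are wrapped in DataFlowStageIO with their partition key, and
+    // all other messages are dropped.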
+    @Test
+    fun `checkpoint message`() = runBlocking {
+        // Given
+        val checkpointMessage = mockk<CheckpointMessage>()
+        val inputFlow = flowOf(checkpointMessage)
+        val stateStore = mockk<StateStore>(relaxed = true)
+        val stateKeyClient = mockk<StateKeyClient>()
+        val dataFlowPipelineInputFlow =
+            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
+
+        // When
+        val result = dataFlowPipelineInputFlow.toList()
+
+        // Then
+        coVerify(exactly = 1) { stateStore.accept(checkpointMessage) }
+        assertEquals(0, result.size)
+    }
+
+    @Test
+    fun `destination record`() = runBlocking {
+        // Given
+        val stream = mockk<DestinationStream>()
+        every { stream.schema } returns mockk()
+        every { stream.airbyteValueProxyFieldAccessors } returns emptyArray()
+        val message = mockk<DestinationRecordSource>()
+        every { message.fileReference } returns null
+        val destinationRecord =
+            DestinationRecord(
+                stream,
+                message,
+                1L,
+                null,
+                UUID.randomUUID(),
+            )
+        val inputFlow = flowOf(destinationRecord)
+        val stateStore = mockk<StateStore>()
+        val stateKeyClient = mockk<StateKeyClient>()
+        val partitionKey = PartitionKey("partitionKey")
+        every { stateKeyClient.getPartitionKey(any()) } returns partitionKey
+        val dataFlowPipelineInputFlow =
+            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
+
+        // When
+        val result = dataFlowPipelineInputFlow.toList()
+
+        // Then
+        assertEquals(1, result.size)
+        val expected =
+            DataFlowStageIO(
+                raw = destinationRecord.asDestinationRecordRaw(),
+                partitionKey = partitionKey,
+            )
+        assertEquals(expected, result[0])
+    }
+
+    @Test
+    fun `other message`() = runBlocking {
+        // Given
+        val undefinedMessage = Undefined
+        val inputFlow = flowOf(undefinedMessage)
+        val stateStore = mockk<StateStore>()
+        val stateKeyClient = mockk<StateKeyClient>()
+        val dataFlowPipelineInputFlow =
+            DataFlowPipelineInputFlow(inputFlow, stateStore, stateKeyClient)
+
+        // When
+        val result = dataFlowPipelineInputFlow.toList()
+
+        // Then
+        assertEquals(0, result.size)
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DestinationMessageInputFlowTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DestinationMessageInputFlowTest.kt
new file mode 100644
index 00000000000..8a7cfa57102
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/input/DestinationMessageInputFlowTest.kt
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.input
+
+import io.airbyte.cdk.load.message.DestinationMessage
+import io.airbyte.cdk.load.message.ProtocolMessageDeserializer
+import io.mockk.coVerify
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.verify
+import java.io.ByteArrayInputStream
+import kotlinx.coroutines.flow.FlowCollector
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Test
+
+class DestinationMessageInputFlowTest {
+
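+    // The flow reads newline-delimited protocol messages from the InputStream,
+    // deserializing each non-empty line and emitting the result downstream.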
+    @Test
+    fun `should deserialize and emit messages from input stream`() = runTest {
+        // Given
+        val line1 = "message1"
+        val line2 = "message2"
+        val inputStream = ByteArrayInputStream("$line1\n$line2\n".toByteArray())
+        val deserializer = mockk<ProtocolMessageDeserializer>()
+        val message1 = mockk<DestinationMessage>()
+        val message2 = mockk<DestinationMessage>()
+
+        every { deserializer.deserialize(line1) } returns message1
+        every { deserializer.deserialize(line2) } returns message2
+
+        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
+        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
+
+        // When
+        inputFlow.collect(collector)
+
+        // Then
+        verify(exactly = 1) { deserializer.deserialize(line1) }
+        verify(exactly = 1) { deserializer.deserialize(line2) }
+        coVerify(exactly = 1) { collector.emit(message1) }
+        coVerify(exactly = 1) { collector.emit(message2) }
+    }
+
+    @Test
+    fun `should ignore empty lines`() = runTest {
+        // Given
+        val line1 = "message1"
+        val emptyLine = ""
+        val line2 = "message2"
+        val inputStream = ByteArrayInputStream("$line1\n$emptyLine\n$line2\n".toByteArray())
+        val deserializer = mockk<ProtocolMessageDeserializer>()
+        val message1 = mockk<DestinationMessage>()
+        val message2 = mockk<DestinationMessage>()
+
+        every { deserializer.deserialize(line1) } returns message1
+        every { deserializer.deserialize(line2) } returns message2
+
+        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
+        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
+
+        // When
+        inputFlow.collect(collector)
+
+        // Then
+        verify(exactly = 1) { deserializer.deserialize(line1) }
+        verify(exactly = 0) { deserializer.deserialize(emptyLine) } // Ensure empty line is ignored
+        verify(exactly = 1) { deserializer.deserialize(line2) }
+        coVerify(exactly = 1) { collector.emit(message1) }
+        coVerify(exactly = 1) { collector.emit(message2) }
+    }
+
+    @Test
+    fun `should not emit if no record is present`() = runTest {
+        // Given
+        val inputStream = ByteArrayInputStream("".toByteArray())
+        val deserializer = mockk<ProtocolMessageDeserializer>()
+
+        val inputFlow = DestinationMessageInputFlow(inputStream, deserializer)
+        val collector = mockk<FlowCollector<DestinationMessage>>(relaxed = true)
+
+        // When
+        inputFlow.collect(collector)
+
+        // Then
+        coVerify(exactly = 0) { collector.emit(any()) } // Ensure no message is emitted
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/AggregateStageTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/AggregateStageTest.kt
new file mode 100644
index 00000000000..82c39e46b8d
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/AggregateStageTest.kt
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.stages
+
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.dataflow.DataFlowStageIO
+import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
+import io.airbyte.cdk.load.dataflow.aggregate.AggregateEntry
+import io.airbyte.cdk.load.dataflow.aggregate.AggregateStore
+import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
+import io.airbyte.cdk.load.dataflow.state.PartitionKey
+import io.airbyte.cdk.load.dataflow.transform.RecordDTO
+import io.airbyte.cdk.load.message.DestinationRecordRaw
+import io.mockk.coEvery
+import io.mockk.coVerify
+import io.mockk.every
+import io.mockk.mockk
+import kotlinx.coroutines.flow.FlowCollector
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Test
+
+class AggregateStageTest {
+
+    private val store: AggregateStore = mockk(relaxed = true)
+    private val stage = AggregateStage(store)
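+
+    // The stage folds each record into the store's per-stream aggregate, then drains
+    // and emits an entry downstream for every aggregate the store reports as complete.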
+
+    @Test
+    fun `aggregation`() = runTest {
+        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
+        val emittedAtMs = 1L
+        val partitionKey = PartitionKey("partition_id_1")
+        val recordDto =
+            RecordDTO(
+                fields = emptyMap(),
+                partitionKey = partitionKey,
+                sizeBytes = 100L,
+                emittedAtMs = emittedAtMs
+            )
+
+        val streamMock =
+            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
+        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
+        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
+
+        val mockAggregate = mockk<Aggregate>()
+        val mockPartitionHistogram = mockk<PartitionHistogram>()
+        val aggregateEntry =
+            mockk<AggregateEntry> {
+                every { value } returns mockAggregate
+                every { partitionHistogram } returns mockPartitionHistogram
+            }
+        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
+        coEvery { store.removeNextComplete(emittedAtMs) } returns aggregateEntry andThen null
+
+        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
+
+        stage.apply(input, outputFlow)
+
+        coVerify { store.acceptFor(streamDescriptor, recordDto) }
+        coVerify { store.removeNextComplete(emittedAtMs) }
+        coVerify {
+            outputFlow.emit(
+                DataFlowStageIO(
+                    aggregate = mockAggregate,
+                    partitionHistogram = mockPartitionHistogram
+                )
+            )
+        }
+    }
+
+    @Test
+    fun `should not emit if no complete aggregate is ready`() = runTest {
+        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
+        val emittedAtMs = 1L
+        val partitionKey = PartitionKey("partition_id_1")
+        val recordDto =
+            RecordDTO(
+                fields = emptyMap(),
+                partitionKey = partitionKey,
+                sizeBytes = 100L,
+                emittedAtMs = emittedAtMs
+            )
+
+        val streamMock =
+            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
+        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
+        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
+
+        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
+        coEvery { store.removeNextComplete(emittedAtMs) } returns null // No aggregate ready
+
+        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
+
+        stage.apply(input, outputFlow)
+
+        coVerify { store.acceptFor(streamDescriptor, recordDto) }
+        coVerify { store.removeNextComplete(emittedAtMs) }
+        coVerify(exactly = 0) { outputFlow.emit(any()) } // Verify no emission
+    }
+
+    @Test
+    fun `should emit multiple times if multiple complete aggregates are ready`() = runTest {
+        val streamDescriptor = DestinationStream.Descriptor("test_namespace", "test_name")
+        val emittedAtMs = 1L
+        val partitionKey = PartitionKey("partition_id_1")
+        val recordDto =
+            RecordDTO(
+                fields = emptyMap(),
+                partitionKey = partitionKey,
+                sizeBytes = 100L,
+                emittedAtMs = emittedAtMs
+            )
+
+        val streamMock =
+            mockk<DestinationStream> { every { mappedDescriptor } returns streamDescriptor }
+        val rawMock = mockk<DestinationRecordRaw> { every { stream } returns streamMock }
+        val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
+
+        val mockAggregate1 = mockk<Aggregate>()
+        val mockPartitionHistogram1 = mockk<PartitionHistogram>()
+        val aggregateEntry1 =
+            mockk<AggregateEntry> {
+                every { value } returns mockAggregate1
+                every { partitionHistogram } returns mockPartitionHistogram1
+            }
+
+        val mockAggregate2 = mockk<Aggregate>()
+        val mockPartitionHistogram2 = mockk<PartitionHistogram>()
+        val aggregateEntry2 =
+            mockk<AggregateEntry> {
+                every { value } returns mockAggregate2
+                every { partitionHistogram } returns mockPartitionHistogram2
+            }
+
+        coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
+        coEvery { store.removeNextComplete(emittedAtMs) } returns
+            aggregateEntry1 andThen
+            aggregateEntry2 andThen
+            null
+
+        val outputFlow = mockk<FlowCollector<DataFlowStageIO>>(relaxed = true)
+
+        stage.apply(input, outputFlow)
+
+        coVerify { store.acceptFor(streamDescriptor, recordDto) }
+        coVerify { store.removeNextComplete(emittedAtMs) }
+        coVerify(exactly = 1) {
+            outputFlow.emit(
+                DataFlowStageIO(
+                    aggregate = mockAggregate1,
+                    partitionHistogram = mockPartitionHistogram1
+                )
+            )
+        }
+        coVerify(exactly = 1) {
+            outputFlow.emit(
+                DataFlowStageIO(
+                    aggregate = mockAggregate2,
+                    partitionHistogram = mockPartitionHistogram2
+                )
+            )
+        }
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/FlushStageTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/FlushStageTest.kt
new file mode 100644
index 00000000000..9ee65dce656
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/FlushStageTest.kt
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.stages
+
+import io.airbyte.cdk.load.dataflow.DataFlowStageIO
+import io.airbyte.cdk.load.dataflow.aggregate.Aggregate
+import io.mockk.coVerify
+import io.mockk.mockk
+import kotlin.test.assertSame
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.assertThrows
+
+class FlushStageTest {
+    private val flushStage = FlushStage()
+
+    @Test
+    fun `given input with an aggregate, when apply is called, then it flushes the aggregate`() =
+        runTest {
+            // Given
+            val mockAggregate = mockk<Aggregate>(relaxed = true)
+            val input = DataFlowStageIO(aggregate = mockAggregate)
+
+            // When
+            val result = flushStage.apply(input)
+
+            // Then
+            coVerify(exactly = 1) { mockAggregate.flush() }
+            assertSame(input, result, "The output should be the same instance as the input")
+        }
+
+    @Test
+    fun `given input with a null aggregate, when apply is called, then it throws NullPointerException`() =
+        runTest {
+            // Given
+            val input = DataFlowStageIO(aggregate = null)
+
+            // When & Then
+            assertThrows<NullPointerException> { flushStage.apply(input) }
+        }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/ParseStageTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/ParseStageTest.kt
new file mode 100644
index 00000000000..f2a3612c7c1
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/ParseStageTest.kt
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.stages
+
+import io.airbyte.cdk.load.command.Append
+import io.airbyte.cdk.load.command.DestinationStream
+import io.airbyte.cdk.load.command.NamespaceMapper
+import io.airbyte.cdk.load.data.StringValue
+import io.airbyte.cdk.load.dataflow.DataFlowStageIO
+import io.airbyte.cdk.load.dataflow.DataMunger
+import io.airbyte.cdk.load.dataflow.state.PartitionKey
+import io.airbyte.cdk.load.message.DestinationRecordJsonSource
+import io.airbyte.cdk.load.message.DestinationRecordRaw
+import io.airbyte.protocol.models.v0.AirbyteMessage
+import io.airbyte.protocol.models.v0.AirbyteRecordMessage
+import io.mockk.every
+import io.mockk.impl.annotations.MockK
+import io.mockk.junit5.MockKExtension
+import io.mockk.verify
+import java.util.UUID
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertNotNull
+import org.junit.jupiter.api.Assertions.assertSame
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.assertThrows
+import org.junit.jupiter.api.extension.ExtendWith
+
+@ExtendWith(MockKExtension::class)
+class ParseStageTest {
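+    // ParseStage munges the raw record via the DataMunger and carries the partition
+    // key, serialized size, and emitted-at timestamp onto the resulting DTO.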
+    @MockK private lateinit var munger: DataMunger
+
+    private lateinit var stage: ParseStage
+    private lateinit var stream: DestinationStream
+    private lateinit var rawRecord: DestinationRecordRaw
+
+    @BeforeEach
+    fun setup() {
+        stage = ParseStage(munger)
+        stream =
+            DestinationStream(
+                unmappedNamespace = "test-namespace",
+                unmappedName = "test-stream",
+                importType = Append,
+                schema = io.airbyte.cdk.load.data.ObjectType(linkedMapOf()),
+                generationId = 1L,
+                minimumGenerationId = 1L,
+                syncId = 1L,
+                namespaceMapper = NamespaceMapper()
+            )
+        rawRecord =
+            DestinationRecordRaw(
+                stream = stream,
+                rawData =
+                    DestinationRecordJsonSource(
+                        AirbyteMessage().withRecord(AirbyteRecordMessage().withEmittedAt(12345L))
+                    ),
+                serializedSizeBytes = 100,
+                airbyteRawId = UUID.randomUUID(),
+            )
+    }
+
+    @Test
+    fun `given valid input, when apply is called, then it should munge the raw record and update the IO object`() =
+        runTest {
+            // Given
+            val input =
+                DataFlowStageIO(raw = rawRecord, partitionKey = PartitionKey("test-partition"))
+            val transformedFields =
+                mapOf("field1" to StringValue("value1"), "field2" to StringValue("42"))
+            every { munger.transformForDest(rawRecord) } returns transformedFields
+
+            // When
+            val result = stage.apply(input)
+
+            // Then
+            assertSame(input, result, "The stage should return the same input object instance")
+            assertNotNull(result.munged)
+
+            val mungedRecord = result.munged!!
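+            // The metadata assertions below reflect the fixture built in setup():
+            // serializedSizeBytes = 100 and an emittedAt of 12345 on the Airbyte record.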
+            assertEquals(transformedFields, mungedRecord.fields)
+            assertEquals(PartitionKey("test-partition"), mungedRecord.partitionKey)
+            assertEquals(100L, mungedRecord.sizeBytes)
+            assertEquals(12345L, mungedRecord.emittedAtMs)
+
+            verify(exactly = 1) { munger.transformForDest(rawRecord) }
+        }
+
+    @Test
+    fun `given input with null raw record, when apply is called, then it should throw NullPointerException`() {
+        // Given
+        val input = DataFlowStageIO(raw = null, partitionKey = PartitionKey("test-partition"))
+
+        // When & Then
+        assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
+    }
+
+    @Test
+    fun `given input with null partition key, when apply is called, then it should throw NullPointerException`() {
+        // Given
+        val input = DataFlowStageIO(raw = rawRecord, partitionKey = null)
+        val transformedFields = mapOf("field1" to StringValue("value1"))
+        every { munger.transformForDest(rawRecord) } returns transformedFields
+
+        // When & Then
+        assertThrows<NullPointerException> { runBlocking { stage.apply(input) } }
+    }
+}
diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/StateStageTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/StateStageTest.kt
new file mode 100644
index 00000000000..8048d6c8867
--- /dev/null
+++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/dataflow/stages/StateStageTest.kt
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+ */
+
+package io.airbyte.cdk.load.dataflow.stages
+
+import io.airbyte.cdk.load.dataflow.DataFlowStageIO
+import io.airbyte.cdk.load.dataflow.state.PartitionHistogram
+import io.airbyte.cdk.load.dataflow.state.StateHistogramStore
+import io.mockk.mockk
+import io.mockk.verify
+import kotlin.test.assertFailsWith
+import kotlin.test.assertSame
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+
+class StateStageTest {
+    private val stateStore: StateHistogramStore = mockk(relaxed = true)
+    private lateinit var stateStage: StateStage
+
+    @BeforeEach
+    fun setup() {
+        stateStage = StateStage(stateStore)
+    }
+
+    @Test
+    fun `apply happy path`() = runTest {
+        // Arrange
+        val histogram = mockk<PartitionHistogram>()
+        val input = DataFlowStageIO(partitionHistogram = histogram)
+
+        // Act
+        val result = stateStage.apply(input)
+
+        // Assert
+        verify(exactly = 1) { stateStore.acceptFlushedCounts(histogram) }
+        assertSame(input, result, "The output should be the same as the input object")
+    }
+
+    @Test
+    fun `apply with null partition histogram throws exception`() = runTest {
+        val input = DataFlowStageIO(partitionHistogram = null)
+        assertFailsWith<NullPointerException> { stateStage.apply(input) }
+        verify(exactly = 0) { stateStore.acceptFlushedCounts(any()) }
+    }
+}