1
0
mirror of synced 2025-12-25 02:09:19 -05:00

Add bytes to dataflow cdk emitted states (#65953)

This commit is contained in:
Ryan Br...
2025-09-09 10:08:19 -07:00
committed by GitHub
parent 1bffb2dfc4
commit 1502f68b52
17 changed files with 313 additions and 85 deletions

View File

@@ -6,7 +6,7 @@ import javax.xml.xpath.XPathFactory
import org.w3c.dom.Document
allprojects {
version = "0.1.23"
version = "0.1.24"
apply plugin: 'java-library'
apply plugin: 'maven-publish'

View File

@@ -1,5 +1,9 @@
**Load CDK**
## Version 0.1.24
* **Changed:** Adds byte counts to emitted state stats.
## Version 0.1.23
* **Changed:** Dataflow CDK fails syncs if there are unflushed states at the end of a sync.

View File

@@ -29,10 +29,11 @@ class AggregateStore(
private val aggregates = ConcurrentHashMap<StoreKey, AggregateEntry>()
fun acceptFor(key: StoreKey, record: RecordDTO) {
val (agg, histogram, timeTrigger, countTrigger, bytesTrigger) = getOrCreate(key)
val (agg, counts, bytes, timeTrigger, countTrigger, bytesTrigger) = getOrCreate(key)
agg.accept(record)
histogram.increment(record.partitionKey)
counts.increment(record.partitionKey, 1)
bytes.increment(record.partitionKey, record.sizeBytes)
countTrigger.increment(1)
bytesTrigger.increment(record.sizeBytes)
timeTrigger.update(record.emittedAtMs)
@@ -69,7 +70,8 @@ class AggregateStore(
aggregates.computeIfAbsent(key) {
AggregateEntry(
value = aggFactory.create(it),
partitionHistogram = PartitionHistogram(),
partitionCountsHistogram = PartitionHistogram(),
partitionBytesHistogram = PartitionHistogram(),
stalenessTrigger = TimeTrigger(stalenessDeadlinePerAggMs),
recordCountTrigger = SizeTrigger(maxRecordsPerAgg),
estimatedBytesTrigger = SizeTrigger(maxEstBytesPerAgg),
@@ -87,7 +89,8 @@ class AggregateStore(
data class AggregateEntry(
val value: Aggregate,
val partitionHistogram: PartitionHistogram,
val partitionCountsHistogram: PartitionHistogram,
val partitionBytesHistogram: PartitionHistogram,
val stalenessTrigger: TimeTrigger,
val recordCountTrigger: SizeTrigger,
val estimatedBytesTrigger: SizeTrigger,

View File

@@ -15,5 +15,6 @@ data class DataFlowStageIO(
var partitionKey: PartitionKey? = null,
var munged: RecordDTO? = null,
var aggregate: Aggregate? = null,
var partitionHistogram: PartitionHistogram? = null,
var partitionCountsHistogram: PartitionHistogram? = null,
var partitionBytesHistogram: PartitionHistogram? = null,
)

View File

@@ -32,7 +32,8 @@ class PipelineCompletionHandler(
.map {
async {
it.value.flush()
stateHistogramStore.acceptFlushedCounts(it.partitionHistogram)
stateHistogramStore.acceptFlushedCounts(it.partitionCountsHistogram)
stateHistogramStore.acceptFlushedBytes(it.partitionBytesHistogram)
}
}
.awaitAll()

View File

@@ -26,7 +26,8 @@ class AggregateStage(
outputFlow.emit(
DataFlowStageIO(
aggregate = next.value,
partitionHistogram = next.partitionHistogram,
partitionCountsHistogram = next.partitionCountsHistogram,
partitionBytesHistogram = next.partitionBytesHistogram,
)
)
next = store.removeNextComplete(rec.emittedAtMs)

View File

@@ -19,9 +19,11 @@ class StateStage(
private val log = KotlinLogging.logger {}
override suspend fun apply(input: DataFlowStageIO): DataFlowStageIO {
val stateUpdates = input.partitionHistogram!!
val countUpdates = input.partitionCountsHistogram!!
val byteUpdates = input.partitionBytesHistogram!!
stateHistogramStore.acceptFlushedCounts(stateUpdates)
stateHistogramStore.acceptFlushedCounts(countUpdates)
stateHistogramStore.acceptFlushedBytes(byteUpdates)
return input
}

View File

@@ -23,13 +23,11 @@ data class PartitionKey(
)
open class Histogram<T>(private val map: ConcurrentMap<T, Long> = ConcurrentHashMap()) {
fun increment(key: T): Histogram<T> {
return this.apply { map.merge(key, 1, Long::plus) }
}
fun increment(key: T, quantity: Long): Histogram<T> =
this.apply { map.merge(key, quantity, Long::plus) }
fun merge(other: Histogram<T>): Histogram<T> {
return this.apply { other.map.forEach { map.merge(it.key, it.value, Long::plus) } }
}
fun merge(other: Histogram<T>): Histogram<T> =
this.apply { other.map.forEach { map.merge(it.key, it.value, Long::plus) } }
fun get(key: T): Long? = map[key]

View File

@@ -13,11 +13,17 @@ class StateHistogramStore {
private val flushed: PartitionHistogram = PartitionHistogram(ConcurrentHashMap())
// Counts of expected messages by state id
private val expected: StateHistogram = StateHistogram(ConcurrentHashMap())
// Counts of flushed bytes by partition id
private val bytes: PartitionHistogram = PartitionHistogram(ConcurrentHashMap())
/**
 * Merges [value] into this store's `flushed` histogram.
 *
 * @param value per-partition counts to add to the running totals.
 * @return the result of merging into `flushed`.
 */
fun acceptFlushedCounts(value: PartitionHistogram): PartitionHistogram = flushed.merge(value)
/**
 * Merges [value] into this store's `bytes` histogram.
 *
 * @param value per-partition byte totals to add to the running totals.
 * @return the result of merging into `bytes`.
 */
fun acceptFlushedBytes(value: PartitionHistogram): PartitionHistogram = bytes.merge(value)
fun acceptExpectedCounts(key: StateKey, count: Long): StateHistogram {
val inner = ConcurrentHashMap<StateKey, Long>()
inner[key] = count
@@ -32,8 +38,16 @@ class StateHistogramStore {
return expectedCount == flushedCount
}
fun remove(key: StateKey): Long? {
key.partitionKeys.forEach { flushed.remove(it) }
return expected.remove(key)
/**
 * Drops the bookkeeping held for [key] and reports what was removed.
 *
 * For every partition referenced by the key, the entry in `flushed` and the entry in `bytes`
 * are removed. The removed byte entries are summed; a partition with no byte entry
 * contributes 0. The entry for the key in `expected` is removed afterwards and defaults to 0
 * when absent.
 *
 * @param key state whose partition entries and expected count should be cleared.
 * @return the removed expected count together with the summed bytes.
 */
fun remove(key: StateKey): StateHistogramStats {
    var removedBytes = 0L
    for (partition in key.partitionKeys) {
        flushed.remove(partition)
        removedBytes += bytes.remove(partition) ?: 0L
    }
    val expectedCount = expected.remove(key) ?: 0L
    return StateHistogramStats(expectedCount, removedBytes)
}
}
/**
 * Pair of totals describing a single state.
 *
 * @property count record count for the state.
 * @property bytes byte total for the state.
 */
data class StateHistogramStats(
    val count: Long,
    val bytes: Long,
)

View File

@@ -39,13 +39,14 @@ class StateStore(
stateSequence.incrementAndGet()
val msg = states.remove(key)
val count = histogramStore.remove(key)!!
val stats = histogramStore.remove(key)
// Add count to stats (will always equal source stats)
// TODO: decide what we want to do with dest stats
msg!!.updateStats(
destinationStats = CheckpointMessage.Stats(count),
totalRecords = count,
destinationStats = CheckpointMessage.Stats(stats.count),
totalRecords = stats.count,
totalBytes = stats.bytes,
)
return msg

View File

@@ -111,7 +111,8 @@ class AggregateStoreTest {
val entry = aggregateStore.getOrCreate(testKey)
assertEquals(1L, entry.recordCountTrigger.watermark())
assertEquals(50L, entry.estimatedBytesTrigger.watermark())
assertEquals(1L, entry.partitionHistogram.get(PartitionKey("partition1")))
assertEquals(1L, entry.partitionCountsHistogram.get(PartitionKey("partition1")))
assertEquals(50L, entry.partitionBytesHistogram.get(PartitionKey("partition1")))
}
@Test
@@ -210,7 +211,8 @@ class AggregateStoreTest {
val entry =
AggregateEntry(
value = mockAggregate,
partitionHistogram = PartitionHistogram(),
partitionCountsHistogram = PartitionHistogram(),
partitionBytesHistogram = PartitionHistogram(),
stalenessTrigger = TimeTrigger(10000),
recordCountTrigger = SizeTrigger(10).apply { repeat(10) { increment(1) } },
estimatedBytesTrigger = SizeTrigger(1000)
@@ -224,7 +226,8 @@ class AggregateStoreTest {
val entry =
AggregateEntry(
value = mockAggregate,
partitionHistogram = PartitionHistogram(),
partitionCountsHistogram = PartitionHistogram(),
partitionBytesHistogram = PartitionHistogram(),
stalenessTrigger = TimeTrigger(10000),
recordCountTrigger = SizeTrigger(100),
estimatedBytesTrigger = SizeTrigger(1000).apply { increment(1000) }
@@ -238,7 +241,8 @@ class AggregateStoreTest {
val entry =
AggregateEntry(
value = mockAggregate,
partitionHistogram = PartitionHistogram(),
partitionCountsHistogram = PartitionHistogram(),
partitionBytesHistogram = PartitionHistogram(),
stalenessTrigger = TimeTrigger(10000),
recordCountTrigger = SizeTrigger(100),
estimatedBytesTrigger = SizeTrigger(1000)
@@ -252,7 +256,8 @@ class AggregateStoreTest {
val entry =
AggregateEntry(
value = mockAggregate,
partitionHistogram = PartitionHistogram(),
partitionCountsHistogram = PartitionHistogram(),
partitionBytesHistogram = PartitionHistogram(),
stalenessTrigger = TimeTrigger(1000).apply { update(5000) },
recordCountTrigger = SizeTrigger(100),
estimatedBytesTrigger = SizeTrigger(1000)

View File

@@ -60,13 +60,16 @@ class PipelineCompletionHandlerTest {
// Given
val mockAggregate1 = mockk<Aggregate>()
val mockAggregate2 = mockk<Aggregate>()
val mockHistogram1 = mockk<PartitionHistogram>()
val mockHistogram2 = mockk<PartitionHistogram>()
val mockCountsHistogram1 = mockk<PartitionHistogram>()
val mockCountsHistogram2 = mockk<PartitionHistogram>()
val mockBytesHistogram1 = mockk<PartitionHistogram>()
val mockBytesHistogram2 = mockk<PartitionHistogram>()
val aggregateEntry1 =
AggregateEntry(
value = mockAggregate1,
partitionHistogram = mockHistogram1,
partitionCountsHistogram = mockCountsHistogram1,
partitionBytesHistogram = mockBytesHistogram1,
stalenessTrigger = mockk(),
recordCountTrigger = mockk(),
estimatedBytesTrigger = mockk()
@@ -75,7 +78,8 @@ class PipelineCompletionHandlerTest {
val aggregateEntry2 =
AggregateEntry(
value = mockAggregate2,
partitionHistogram = mockHistogram2,
partitionCountsHistogram = mockCountsHistogram2,
partitionBytesHistogram = mockBytesHistogram2,
stalenessTrigger = mockk(),
recordCountTrigger = mockk(),
estimatedBytesTrigger = mockk()
@@ -85,6 +89,7 @@ class PipelineCompletionHandlerTest {
coEvery { mockAggregate1.flush() } just Runs
coEvery { mockAggregate2.flush() } just Runs
every { stateHistogramStore.acceptFlushedCounts(any()) } returns mockk()
every { stateHistogramStore.acceptFlushedBytes(any()) } returns mockk()
// When
pipelineCompletionHandler.apply(null)
@@ -92,8 +97,10 @@ class PipelineCompletionHandlerTest {
// Then
coVerify(exactly = 1) { mockAggregate1.flush() }
coVerify(exactly = 1) { mockAggregate2.flush() }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram1) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram2) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockCountsHistogram1) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockCountsHistogram2) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedBytes(mockBytesHistogram1) }
verify(exactly = 1) { stateHistogramStore.acceptFlushedBytes(mockBytesHistogram2) }
}
@Test
@@ -113,13 +120,15 @@ class PipelineCompletionHandlerTest {
fun `apply should handle aggregate flush failure`() = runTest {
// Given
val mockAggregate = mockk<Aggregate>()
val mockHistogram = mockk<PartitionHistogram>()
val mockCountsHistogram = mockk<PartitionHistogram>()
val mockBytesHistogram = mockk<PartitionHistogram>()
val flushException = RuntimeException("Flush failed")
val aggregateEntry =
AggregateEntry(
value = mockAggregate,
partitionHistogram = mockHistogram,
partitionCountsHistogram = mockCountsHistogram,
partitionBytesHistogram = mockBytesHistogram,
stalenessTrigger = mockk(),
recordCountTrigger = mockk(),
estimatedBytesTrigger = mockk()
@@ -134,5 +143,6 @@ class PipelineCompletionHandlerTest {
coVerify(exactly = 1) { mockAggregate.flush() }
// Note: neither acceptFlushedCounts nor acceptFlushedBytes should be called if flush fails
verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }
verify(exactly = 0) { stateHistogramStore.acceptFlushedBytes(any()) }
}
}

View File

@@ -45,11 +45,13 @@ class AggregateStageTest {
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
val mockAggregate = mockk<Aggregate>()
val mockPartitionHistogram = mockk<PartitionHistogram>()
val mockCountsHistogram = mockk<PartitionHistogram>()
val mockBytesHistogram = mockk<PartitionHistogram>()
val aggregateEntry =
mockk<AggregateEntry> {
every { value } returns mockAggregate
every { partitionHistogram } returns mockPartitionHistogram
every { partitionCountsHistogram } returns mockCountsHistogram
every { partitionBytesHistogram } returns mockBytesHistogram
}
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
coEvery { store.removeNextComplete(emittedAtMs) } returns aggregateEntry andThen null
@@ -64,7 +66,8 @@ class AggregateStageTest {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate,
partitionHistogram = mockPartitionHistogram
partitionCountsHistogram = mockCountsHistogram,
partitionBytesHistogram = mockBytesHistogram,
)
)
}
@@ -119,19 +122,23 @@ class AggregateStageTest {
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
val mockAggregate1 = mockk<Aggregate>()
val mockPartitionHistogram1 = mockk<PartitionHistogram>()
val mockCounts1 = mockk<PartitionHistogram>()
val mockBytes1 = mockk<PartitionHistogram>()
val aggregateEntry1 =
mockk<AggregateEntry> {
every { value } returns mockAggregate1
every { partitionHistogram } returns mockPartitionHistogram1
every { partitionCountsHistogram } returns mockCounts1
every { partitionBytesHistogram } returns mockBytes1
}
val mockAggregate2 = mockk<Aggregate>()
val mockPartitionHistogram2 = mockk<PartitionHistogram>()
val mockCounts2 = mockk<PartitionHistogram>()
val mockBytes2 = mockk<PartitionHistogram>()
val aggregateEntry2 =
mockk<AggregateEntry> {
every { value } returns mockAggregate2
every { partitionHistogram } returns mockPartitionHistogram2
every { partitionCountsHistogram } returns mockCounts2
every { partitionBytesHistogram } returns mockBytes2
}
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
@@ -150,7 +157,8 @@ class AggregateStageTest {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate1,
partitionHistogram = mockPartitionHistogram1
partitionCountsHistogram = mockCounts1,
partitionBytesHistogram = mockBytes1,
)
)
}
@@ -158,7 +166,8 @@ class AggregateStageTest {
outputFlow.emit(
DataFlowStageIO(
aggregate = mockAggregate2,
partitionHistogram = mockPartitionHistogram2
partitionCountsHistogram = mockCounts2,
partitionBytesHistogram = mockBytes2,
)
)
}

View File

@@ -27,21 +27,34 @@ class StateStageTest {
@Test
fun `apply happy path`() = runTest {
// Arrange
val histogram = mockk<PartitionHistogram>()
val input = DataFlowStageIO(partitionHistogram = histogram)
val countsHistogram = mockk<PartitionHistogram>()
val bytesHistogram = mockk<PartitionHistogram>()
val input =
DataFlowStageIO(
partitionCountsHistogram = countsHistogram,
partitionBytesHistogram = bytesHistogram,
)
// Act
val result = stateStage.apply(input)
// Assert
verify(exactly = 1) { stateStore.acceptFlushedCounts(histogram) }
verify(exactly = 1) { stateStore.acceptFlushedCounts(countsHistogram) }
verify(exactly = 1) { stateStore.acceptFlushedBytes(bytesHistogram) }
assertSame(input, result, "The output should be the same as the input object")
}
@Test
fun `apply with null partition histogram throws exception`() = runTest {
val input = DataFlowStageIO(partitionHistogram = null)
fun `apply with null counts histogram throws exception`() = runTest {
val input = DataFlowStageIO(partitionCountsHistogram = null)
assertFailsWith<NullPointerException> { stateStage.apply(input) }
verify(exactly = 0) { stateStore.acceptFlushedCounts(any()) }
}
@Test
fun `apply with null bytes histogram throws exception`() = runTest {
    // Supply a counts histogram so the failure can only come from the missing bytes
    // histogram. With both left at their null defaults, the counts check fails first and
    // the bytes check is never reached.
    val input =
        DataFlowStageIO(
            partitionCountsHistogram = mockk<PartitionHistogram>(),
            partitionBytesHistogram = null,
        )

    assertFailsWith<NullPointerException> { stateStage.apply(input) }

    // The failure happens before anything is forwarded to the store.
    verify(exactly = 0) { stateStore.acceptFlushedBytes(any()) }
}
}

View File

@@ -29,7 +29,7 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey, 5L)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionHistogram.increment(partitionKey) }
repeat(5) { partitionHistogram.increment(partitionKey, 1) }
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When
@@ -50,9 +50,9 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey, 15L) // 5 + 3 + 7 = 15
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionHistogram.increment(partitionKey1) }
repeat(3) { partitionHistogram.increment(partitionKey2) }
repeat(7) { partitionHistogram.increment(partitionKey3) }
repeat(5) { partitionHistogram.increment(partitionKey1, 1) }
repeat(3) { partitionHistogram.increment(partitionKey2, 1) }
repeat(7) { partitionHistogram.increment(partitionKey3, 1) }
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When
@@ -71,7 +71,7 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey, 10L)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(7) { partitionHistogram.increment(partitionKey) } // Less than expected
repeat(7) { partitionHistogram.increment(partitionKey, 1) } // Less than expected
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When
@@ -90,7 +90,7 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey, 5L)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(8) { partitionHistogram.increment(partitionKey) } // More than expected
repeat(8) { partitionHistogram.increment(partitionKey, 1) } // More than expected
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When
@@ -107,7 +107,7 @@ class StateHistogramStoreTest {
val stateKey = StateKey(1L, listOf(partitionKey))
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionHistogram.increment(partitionKey) }
repeat(5) { partitionHistogram.increment(partitionKey, 1) }
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When
@@ -142,7 +142,7 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey, 3L)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(3) { partitionHistogram.increment(partitionKey1) }
repeat(3) { partitionHistogram.increment(partitionKey1, 1) }
// partitionKey2 has no flushed counts, should be treated as 0
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
@@ -154,28 +154,123 @@ class StateHistogramStoreTest {
}
@Test
fun `remove should delete both expected and flushed counts for state key`() {
fun `remove should delete both expected and flushed counts for state key and return stats`() {
// Given
val partitionKey1 = PartitionKey("partition-1")
val partitionKey2 = PartitionKey("partition-2")
val stateKey = StateKey(1L, listOf(partitionKey1, partitionKey2))
val expectedCount = 10L
val bytes1 = 1000L
val bytes2 = 2000L
stateHistogramStore.acceptExpectedCounts(stateKey, expectedCount)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionHistogram.increment(partitionKey1) }
repeat(3) { partitionHistogram.increment(partitionKey2) }
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
val partitionCountsHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionCountsHistogram.increment(partitionKey1, 1) }
repeat(3) { partitionCountsHistogram.increment(partitionKey2, 1) }
stateHistogramStore.acceptFlushedCounts(partitionCountsHistogram)
val partitionBytesHistogram = PartitionHistogram(ConcurrentHashMap())
partitionBytesHistogram.increment(partitionKey1, bytes1)
partitionBytesHistogram.increment(partitionKey2, bytes2)
stateHistogramStore.acceptFlushedBytes(partitionBytesHistogram)
// When
val count = stateHistogramStore.remove(stateKey)
val stats = stateHistogramStore.remove(stateKey)
// Then
assertEquals(expectedCount, stats.count)
assertEquals(bytes1 + bytes2, stats.bytes)
assertFalse(
stateHistogramStore.isComplete(stateKey)
) // Should be false due to missing expected count
assertEquals(expectedCount, count)
}
@Test
fun `remove should handle missing byte counts as zero`() {
    // Arrange: one partition with five flushed records and no byte histogram at all.
    val partition = PartitionKey("partition-1")
    val key = StateKey(1L, listOf(partition))
    val expected = 5L
    val flushedCounts =
        PartitionHistogram(ConcurrentHashMap()).also { counts ->
            repeat(5) { counts.increment(partition, 1) }
        }

    stateHistogramStore.acceptExpectedCounts(key, expected)
    stateHistogramStore.acceptFlushedCounts(flushedCounts)

    // Act
    val removed = stateHistogramStore.remove(key)

    // Assert: the byte total falls back to zero when nothing was recorded.
    assertEquals(expected, removed.count)
    assertEquals(0L, removed.bytes)
}
@Test
fun `remove should sum bytes from multiple partitions correctly`() {
    // Arrange: three partitions, each with its own record count and byte total.
    val first = PartitionKey("partition-1")
    val second = PartitionKey("partition-2")
    val third = PartitionKey("partition-3")
    val key = StateKey(1L, listOf(first, second, third))
    val expected = 15L
    val recordsByPartition = listOf(first to 7, second to 3, third to 5)
    val bytesByPartition = listOf(first to 5000L, second to 3000L, third to 7000L)

    val flushedCounts = PartitionHistogram(ConcurrentHashMap())
    recordsByPartition.forEach { (partition, records) ->
        repeat(records) { flushedCounts.increment(partition, 1) }
    }
    val flushedBytes = PartitionHistogram(ConcurrentHashMap())
    bytesByPartition.forEach { (partition, size) -> flushedBytes.increment(partition, size) }

    stateHistogramStore.acceptExpectedCounts(key, expected)
    stateHistogramStore.acceptFlushedCounts(flushedCounts)
    stateHistogramStore.acceptFlushedBytes(flushedBytes)

    // Act
    val removed = stateHistogramStore.remove(key)

    // Assert: bytes from every partition of the key are added together.
    assertEquals(expected, removed.count)
    assertEquals(bytesByPartition.sumOf { it.second }, removed.bytes)
}
@Test
fun `remove should handle partial byte counts for partitions`() {
    // Arrange: both partitions have flushed records, but only the first has bytes recorded.
    val withBytes = PartitionKey("partition-1")
    val withoutBytes = PartitionKey("partition-2")
    val key = StateKey(1L, listOf(withBytes, withoutBytes))
    val expected = 10L
    val recordedBytes = 2500L

    val flushedCounts =
        PartitionHistogram(ConcurrentHashMap()).also { counts ->
            repeat(6) { counts.increment(withBytes, 1) }
            repeat(4) { counts.increment(withoutBytes, 1) }
        }
    val flushedBytes =
        PartitionHistogram(ConcurrentHashMap()).also { sizes ->
            sizes.increment(withBytes, recordedBytes)
        }

    stateHistogramStore.acceptExpectedCounts(key, expected)
    stateHistogramStore.acceptFlushedCounts(flushedCounts)
    stateHistogramStore.acceptFlushedBytes(flushedBytes)

    // Act
    val removed = stateHistogramStore.remove(key)

    // Assert: the partition without a byte entry adds nothing to the total.
    assertEquals(expected, removed.count)
    assertEquals(recordedBytes, removed.bytes)
}
@Test
@@ -192,9 +287,9 @@ class StateHistogramStoreTest {
stateHistogramStore.acceptExpectedCounts(stateKey2, 4L)
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
repeat(5) { partitionHistogram.increment(partitionKey1) }
repeat(3) { partitionHistogram.increment(partitionKey2) }
repeat(4) { partitionHistogram.increment(partitionKey3) }
repeat(5) { partitionHistogram.increment(partitionKey1, 1) }
repeat(3) { partitionHistogram.increment(partitionKey2, 1) }
repeat(4) { partitionHistogram.increment(partitionKey3, 1) }
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
// When

View File

@@ -17,7 +17,7 @@ class StateHistogramTest {
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
// When
histogram.increment(stateKey)
histogram.increment(stateKey, 1)
// Then
assertEquals(1L, histogram.get(stateKey))
@@ -30,9 +30,9 @@ class StateHistogramTest {
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
// When
histogram.increment(stateKey)
histogram.increment(stateKey)
histogram.increment(stateKey)
histogram.increment(stateKey, 1)
histogram.increment(stateKey, 1)
histogram.increment(stateKey, 1)
// Then
assertEquals(3L, histogram.get(stateKey))
@@ -46,9 +46,9 @@ class StateHistogramTest {
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
// When
histogram.increment(stateKey1)
histogram.increment(stateKey1)
histogram.increment(stateKey2)
histogram.increment(stateKey1, 1)
histogram.increment(stateKey1, 1)
histogram.increment(stateKey2, 1)
// Then
assertEquals(2L, histogram.get(stateKey1))
@@ -77,13 +77,13 @@ class StateHistogramTest {
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
val sharedKey = StateKey(3L, listOf(PartitionKey("partition-3")))
histogram1.increment(stateKey1)
histogram1.increment(stateKey1)
histogram1.increment(sharedKey)
histogram1.increment(stateKey1, 1)
histogram1.increment(stateKey1, 1)
histogram1.increment(sharedKey, 1)
histogram2.increment(stateKey2)
histogram2.increment(sharedKey)
histogram2.increment(sharedKey)
histogram2.increment(stateKey2, 1)
histogram2.increment(sharedKey, 1)
histogram2.increment(sharedKey, 1)
// When
histogram1.merge(histogram2)
@@ -102,8 +102,8 @@ class StateHistogramTest {
val stateKey1 = StateKey(1L, listOf(PartitionKey("partition-1")))
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
histogram1.increment(stateKey1)
histogram2.increment(stateKey2)
histogram1.increment(stateKey1, 1)
histogram2.increment(stateKey2, 1)
// When
histogram1.merge(histogram2)
@@ -120,8 +120,8 @@ class StateHistogramTest {
// Given
val histogram = StateHistogram()
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
histogram.increment(stateKey)
histogram.increment(stateKey)
histogram.increment(stateKey, 1)
histogram.increment(stateKey, 1)
// When
val removedValue = histogram.remove(stateKey)

View File

@@ -5,6 +5,7 @@
package io.airbyte.cdk.load.dataflow.state
import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.StreamCheckpoint
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
@@ -12,6 +13,7 @@ import io.mockk.mockk
import io.mockk.verify
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import org.junit.jupiter.api.BeforeEach
@@ -30,7 +32,7 @@ class StateStoreTest {
@BeforeEach
fun setUp() {
stateStore = StateStore(keyClient, histogramStore)
every { histogramStore.remove(any()) } returns 1L
every { histogramStore.remove(any()) } returns StateHistogramStats(1, 1)
}
@Test
@@ -189,6 +191,75 @@ class StateStoreTest {
assertNull(stateStore.getNextComplete()) // no more states
}
@Test
fun `getNextComplete should add byte counts and record counts to destination stats`() {
    // Arrange: a checkpoint whose histogram entry is complete and carries both totals.
    val records = 150L
    val bytes = 50000L
    val key = StateKey(1L, listOf(PartitionKey("partition-1")))
    val checkpoint =
        StreamCheckpoint(
            checkpoint = mockk(),
            sourceStats = CheckpointMessage.Stats(recordCount = records),
            serializedSizeBytes = 1L,
        )
    val removedStats = StateHistogramStats(count = records, bytes = bytes)

    every { keyClient.getStateKey(checkpoint) } returns key
    every { histogramStore.acceptExpectedCounts(key, records) } returns mockk()
    every { histogramStore.isComplete(key) } returns true
    every { histogramStore.remove(key) } returns removedStats
    stateStore.accept(checkpoint)

    // Act
    val emitted = stateStore.getNextComplete()

    // Assert: the totals reported by the histogram store end up on the emitted message.
    assertNotNull(emitted)
    assertEquals(records, emitted.destinationStats?.recordCount)
    assertEquals(records, emitted.totalRecords)
    assertEquals(bytes, emitted.totalBytes)

    // The histogram entry for the state must have been removed.
    verify { histogramStore.remove(key) }
}
@Test
fun `getNextComplete should handle different byte and record counts from histogram`() {
    // Arrange: the histogram store reports a record count that differs from the source's.
    val sourceRecords = 100L
    val histogramRecords = 98L
    val histogramBytes = 32768L
    val key = StateKey(1L, listOf(PartitionKey("partition-1")))
    val checkpoint =
        StreamCheckpoint(
            checkpoint = mockk(),
            sourceStats = CheckpointMessage.Stats(recordCount = sourceRecords),
            serializedSizeBytes = 1L,
        )
    val removedStats = StateHistogramStats(count = histogramRecords, bytes = histogramBytes)

    every { keyClient.getStateKey(checkpoint) } returns key
    every { histogramStore.acceptExpectedCounts(key, sourceRecords) } returns mockk()
    every { histogramStore.isComplete(key) } returns true
    every { histogramStore.remove(key) } returns removedStats
    stateStore.accept(checkpoint)

    // Act
    val emitted = stateStore.getNextComplete()

    // Assert: the emitted message carries the histogram store's numbers, not the source's.
    assertNotNull(emitted)
    assertEquals(histogramRecords, emitted.destinationStats?.recordCount)
    assertEquals(histogramRecords, emitted.totalRecords)
    assertEquals(histogramBytes, emitted.totalBytes)
}
@Test
fun `getNextComplete should skip incomplete states and not advance sequence`() {
// Given