Add bytes to dataflow cdk emitted states (#65953)
This commit is contained in:
@@ -6,7 +6,7 @@ import javax.xml.xpath.XPathFactory
|
||||
import org.w3c.dom.Document
|
||||
|
||||
allprojects {
|
||||
version = "0.1.23"
|
||||
version = "0.1.24"
|
||||
apply plugin: 'java-library'
|
||||
apply plugin: 'maven-publish'
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
**Load CDK**
|
||||
|
||||
## Version 0.1.24
|
||||
|
||||
* **Changed:** Adds byte counts to emitted state stats.
|
||||
|
||||
## Version 0.1.23
|
||||
|
||||
* **Changed:** Dataflow CDK fails syncs if there are unflushed states at the end of a sync.
|
||||
|
||||
@@ -29,10 +29,11 @@ class AggregateStore(
|
||||
private val aggregates = ConcurrentHashMap<StoreKey, AggregateEntry>()
|
||||
|
||||
fun acceptFor(key: StoreKey, record: RecordDTO) {
|
||||
val (agg, histogram, timeTrigger, countTrigger, bytesTrigger) = getOrCreate(key)
|
||||
val (agg, counts, bytes, timeTrigger, countTrigger, bytesTrigger) = getOrCreate(key)
|
||||
|
||||
agg.accept(record)
|
||||
histogram.increment(record.partitionKey)
|
||||
counts.increment(record.partitionKey, 1)
|
||||
bytes.increment(record.partitionKey, record.sizeBytes)
|
||||
countTrigger.increment(1)
|
||||
bytesTrigger.increment(record.sizeBytes)
|
||||
timeTrigger.update(record.emittedAtMs)
|
||||
@@ -69,7 +70,8 @@ class AggregateStore(
|
||||
aggregates.computeIfAbsent(key) {
|
||||
AggregateEntry(
|
||||
value = aggFactory.create(it),
|
||||
partitionHistogram = PartitionHistogram(),
|
||||
partitionCountsHistogram = PartitionHistogram(),
|
||||
partitionBytesHistogram = PartitionHistogram(),
|
||||
stalenessTrigger = TimeTrigger(stalenessDeadlinePerAggMs),
|
||||
recordCountTrigger = SizeTrigger(maxRecordsPerAgg),
|
||||
estimatedBytesTrigger = SizeTrigger(maxEstBytesPerAgg),
|
||||
@@ -87,7 +89,8 @@ class AggregateStore(
|
||||
|
||||
data class AggregateEntry(
|
||||
val value: Aggregate,
|
||||
val partitionHistogram: PartitionHistogram,
|
||||
val partitionCountsHistogram: PartitionHistogram,
|
||||
val partitionBytesHistogram: PartitionHistogram,
|
||||
val stalenessTrigger: TimeTrigger,
|
||||
val recordCountTrigger: SizeTrigger,
|
||||
val estimatedBytesTrigger: SizeTrigger,
|
||||
|
||||
@@ -15,5 +15,6 @@ data class DataFlowStageIO(
|
||||
var partitionKey: PartitionKey? = null,
|
||||
var munged: RecordDTO? = null,
|
||||
var aggregate: Aggregate? = null,
|
||||
var partitionHistogram: PartitionHistogram? = null,
|
||||
var partitionCountsHistogram: PartitionHistogram? = null,
|
||||
var partitionBytesHistogram: PartitionHistogram? = null,
|
||||
)
|
||||
|
||||
@@ -32,7 +32,8 @@ class PipelineCompletionHandler(
|
||||
.map {
|
||||
async {
|
||||
it.value.flush()
|
||||
stateHistogramStore.acceptFlushedCounts(it.partitionHistogram)
|
||||
stateHistogramStore.acceptFlushedCounts(it.partitionCountsHistogram)
|
||||
stateHistogramStore.acceptFlushedBytes(it.partitionBytesHistogram)
|
||||
}
|
||||
}
|
||||
.awaitAll()
|
||||
|
||||
@@ -26,7 +26,8 @@ class AggregateStage(
|
||||
outputFlow.emit(
|
||||
DataFlowStageIO(
|
||||
aggregate = next.value,
|
||||
partitionHistogram = next.partitionHistogram,
|
||||
partitionCountsHistogram = next.partitionCountsHistogram,
|
||||
partitionBytesHistogram = next.partitionBytesHistogram,
|
||||
)
|
||||
)
|
||||
next = store.removeNextComplete(rec.emittedAtMs)
|
||||
|
||||
@@ -19,9 +19,11 @@ class StateStage(
|
||||
private val log = KotlinLogging.logger {}
|
||||
|
||||
override suspend fun apply(input: DataFlowStageIO): DataFlowStageIO {
|
||||
val stateUpdates = input.partitionHistogram!!
|
||||
val countUpdates = input.partitionCountsHistogram!!
|
||||
val byteUpdates = input.partitionBytesHistogram!!
|
||||
|
||||
stateHistogramStore.acceptFlushedCounts(stateUpdates)
|
||||
stateHistogramStore.acceptFlushedCounts(countUpdates)
|
||||
stateHistogramStore.acceptFlushedBytes(byteUpdates)
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
@@ -23,13 +23,11 @@ data class PartitionKey(
|
||||
)
|
||||
|
||||
open class Histogram<T>(private val map: ConcurrentMap<T, Long> = ConcurrentHashMap()) {
|
||||
fun increment(key: T): Histogram<T> {
|
||||
return this.apply { map.merge(key, 1, Long::plus) }
|
||||
}
|
||||
fun increment(key: T, quantity: Long): Histogram<T> =
|
||||
this.apply { map.merge(key, quantity, Long::plus) }
|
||||
|
||||
fun merge(other: Histogram<T>): Histogram<T> {
|
||||
return this.apply { other.map.forEach { map.merge(it.key, it.value, Long::plus) } }
|
||||
}
|
||||
fun merge(other: Histogram<T>): Histogram<T> =
|
||||
this.apply { other.map.forEach { map.merge(it.key, it.value, Long::plus) } }
|
||||
|
||||
fun get(key: T): Long? = map[key]
|
||||
|
||||
|
||||
@@ -13,11 +13,17 @@ class StateHistogramStore {
|
||||
private val flushed: PartitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
// Counts of expected messages by state id
|
||||
private val expected: StateHistogram = StateHistogram(ConcurrentHashMap())
|
||||
// Counts of flushed bytes by partition id
|
||||
private val bytes: PartitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
|
||||
fun acceptFlushedCounts(value: PartitionHistogram): PartitionHistogram {
|
||||
return flushed.merge(value)
|
||||
}
|
||||
|
||||
fun acceptFlushedBytes(value: PartitionHistogram): PartitionHistogram {
|
||||
return bytes.merge(value)
|
||||
}
|
||||
|
||||
fun acceptExpectedCounts(key: StateKey, count: Long): StateHistogram {
|
||||
val inner = ConcurrentHashMap<StateKey, Long>()
|
||||
inner[key] = count
|
||||
@@ -32,8 +38,16 @@ class StateHistogramStore {
|
||||
return expectedCount == flushedCount
|
||||
}
|
||||
|
||||
fun remove(key: StateKey): Long? {
|
||||
key.partitionKeys.forEach { flushed.remove(it) }
|
||||
return expected.remove(key)
|
||||
fun remove(key: StateKey): StateHistogramStats {
|
||||
val bytes =
|
||||
key.partitionKeys.sumOf {
|
||||
flushed.remove(it)
|
||||
bytes.remove(it) ?: 0
|
||||
}
|
||||
val count = expected.remove(key) ?: 0
|
||||
|
||||
return StateHistogramStats(count, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
data class StateHistogramStats(val count: Long, val bytes: Long)
|
||||
|
||||
@@ -39,13 +39,14 @@ class StateStore(
|
||||
|
||||
stateSequence.incrementAndGet()
|
||||
val msg = states.remove(key)
|
||||
val count = histogramStore.remove(key)!!
|
||||
val stats = histogramStore.remove(key)
|
||||
|
||||
// Add record count and byte total to stats (the record count will always equal source stats)
|
||||
// TODO: decide what we want to do with dest stats
|
||||
msg!!.updateStats(
|
||||
destinationStats = CheckpointMessage.Stats(count),
|
||||
totalRecords = count,
|
||||
destinationStats = CheckpointMessage.Stats(stats.count),
|
||||
totalRecords = stats.count,
|
||||
totalBytes = stats.bytes,
|
||||
)
|
||||
|
||||
return msg
|
||||
|
||||
@@ -111,7 +111,8 @@ class AggregateStoreTest {
|
||||
val entry = aggregateStore.getOrCreate(testKey)
|
||||
assertEquals(1L, entry.recordCountTrigger.watermark())
|
||||
assertEquals(50L, entry.estimatedBytesTrigger.watermark())
|
||||
assertEquals(1L, entry.partitionHistogram.get(PartitionKey("partition1")))
|
||||
assertEquals(1L, entry.partitionCountsHistogram.get(PartitionKey("partition1")))
|
||||
assertEquals(50L, entry.partitionBytesHistogram.get(PartitionKey("partition1")))
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -210,7 +211,8 @@ class AggregateStoreTest {
|
||||
val entry =
|
||||
AggregateEntry(
|
||||
value = mockAggregate,
|
||||
partitionHistogram = PartitionHistogram(),
|
||||
partitionCountsHistogram = PartitionHistogram(),
|
||||
partitionBytesHistogram = PartitionHistogram(),
|
||||
stalenessTrigger = TimeTrigger(10000),
|
||||
recordCountTrigger = SizeTrigger(10).apply { repeat(10) { increment(1) } },
|
||||
estimatedBytesTrigger = SizeTrigger(1000)
|
||||
@@ -224,7 +226,8 @@ class AggregateStoreTest {
|
||||
val entry =
|
||||
AggregateEntry(
|
||||
value = mockAggregate,
|
||||
partitionHistogram = PartitionHistogram(),
|
||||
partitionCountsHistogram = PartitionHistogram(),
|
||||
partitionBytesHistogram = PartitionHistogram(),
|
||||
stalenessTrigger = TimeTrigger(10000),
|
||||
recordCountTrigger = SizeTrigger(100),
|
||||
estimatedBytesTrigger = SizeTrigger(1000).apply { increment(1000) }
|
||||
@@ -238,7 +241,8 @@ class AggregateStoreTest {
|
||||
val entry =
|
||||
AggregateEntry(
|
||||
value = mockAggregate,
|
||||
partitionHistogram = PartitionHistogram(),
|
||||
partitionCountsHistogram = PartitionHistogram(),
|
||||
partitionBytesHistogram = PartitionHistogram(),
|
||||
stalenessTrigger = TimeTrigger(10000),
|
||||
recordCountTrigger = SizeTrigger(100),
|
||||
estimatedBytesTrigger = SizeTrigger(1000)
|
||||
@@ -252,7 +256,8 @@ class AggregateStoreTest {
|
||||
val entry =
|
||||
AggregateEntry(
|
||||
value = mockAggregate,
|
||||
partitionHistogram = PartitionHistogram(),
|
||||
partitionCountsHistogram = PartitionHistogram(),
|
||||
partitionBytesHistogram = PartitionHistogram(),
|
||||
stalenessTrigger = TimeTrigger(1000).apply { update(5000) },
|
||||
recordCountTrigger = SizeTrigger(100),
|
||||
estimatedBytesTrigger = SizeTrigger(1000)
|
||||
|
||||
@@ -60,13 +60,16 @@ class PipelineCompletionHandlerTest {
|
||||
// Given
|
||||
val mockAggregate1 = mockk<Aggregate>()
|
||||
val mockAggregate2 = mockk<Aggregate>()
|
||||
val mockHistogram1 = mockk<PartitionHistogram>()
|
||||
val mockHistogram2 = mockk<PartitionHistogram>()
|
||||
val mockCountsHistogram1 = mockk<PartitionHistogram>()
|
||||
val mockCountsHistogram2 = mockk<PartitionHistogram>()
|
||||
val mockBytesHistogram1 = mockk<PartitionHistogram>()
|
||||
val mockBytesHistogram2 = mockk<PartitionHistogram>()
|
||||
|
||||
val aggregateEntry1 =
|
||||
AggregateEntry(
|
||||
value = mockAggregate1,
|
||||
partitionHistogram = mockHistogram1,
|
||||
partitionCountsHistogram = mockCountsHistogram1,
|
||||
partitionBytesHistogram = mockBytesHistogram1,
|
||||
stalenessTrigger = mockk(),
|
||||
recordCountTrigger = mockk(),
|
||||
estimatedBytesTrigger = mockk()
|
||||
@@ -75,7 +78,8 @@ class PipelineCompletionHandlerTest {
|
||||
val aggregateEntry2 =
|
||||
AggregateEntry(
|
||||
value = mockAggregate2,
|
||||
partitionHistogram = mockHistogram2,
|
||||
partitionCountsHistogram = mockCountsHistogram2,
|
||||
partitionBytesHistogram = mockBytesHistogram2,
|
||||
stalenessTrigger = mockk(),
|
||||
recordCountTrigger = mockk(),
|
||||
estimatedBytesTrigger = mockk()
|
||||
@@ -85,6 +89,7 @@ class PipelineCompletionHandlerTest {
|
||||
coEvery { mockAggregate1.flush() } just Runs
|
||||
coEvery { mockAggregate2.flush() } just Runs
|
||||
every { stateHistogramStore.acceptFlushedCounts(any()) } returns mockk()
|
||||
every { stateHistogramStore.acceptFlushedBytes(any()) } returns mockk()
|
||||
|
||||
// When
|
||||
pipelineCompletionHandler.apply(null)
|
||||
@@ -92,8 +97,10 @@ class PipelineCompletionHandlerTest {
|
||||
// Then
|
||||
coVerify(exactly = 1) { mockAggregate1.flush() }
|
||||
coVerify(exactly = 1) { mockAggregate2.flush() }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram1) }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockHistogram2) }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockCountsHistogram1) }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedCounts(mockCountsHistogram2) }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedBytes(mockBytesHistogram1) }
|
||||
verify(exactly = 1) { stateHistogramStore.acceptFlushedBytes(mockBytesHistogram2) }
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -113,13 +120,15 @@ class PipelineCompletionHandlerTest {
|
||||
fun `apply should handle aggregate flush failure`() = runTest {
|
||||
// Given
|
||||
val mockAggregate = mockk<Aggregate>()
|
||||
val mockHistogram = mockk<PartitionHistogram>()
|
||||
val mockCountsHistogram = mockk<PartitionHistogram>()
|
||||
val mockBytesHistogram = mockk<PartitionHistogram>()
|
||||
val flushException = RuntimeException("Flush failed")
|
||||
|
||||
val aggregateEntry =
|
||||
AggregateEntry(
|
||||
value = mockAggregate,
|
||||
partitionHistogram = mockHistogram,
|
||||
partitionCountsHistogram = mockCountsHistogram,
|
||||
partitionBytesHistogram = mockBytesHistogram,
|
||||
stalenessTrigger = mockk(),
|
||||
recordCountTrigger = mockk(),
|
||||
estimatedBytesTrigger = mockk()
|
||||
@@ -134,5 +143,6 @@ class PipelineCompletionHandlerTest {
|
||||
coVerify(exactly = 1) { mockAggregate.flush() }
|
||||
// Note: neither acceptFlushedCounts nor acceptFlushedBytes should be called if flush fails
|
||||
verify(exactly = 0) { stateHistogramStore.acceptFlushedCounts(any()) }
|
||||
verify(exactly = 0) { stateHistogramStore.acceptFlushedBytes(any()) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,11 +45,13 @@ class AggregateStageTest {
|
||||
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
|
||||
|
||||
val mockAggregate = mockk<Aggregate>()
|
||||
val mockPartitionHistogram = mockk<PartitionHistogram>()
|
||||
val mockCountsHistogram = mockk<PartitionHistogram>()
|
||||
val mockBytesHistogram = mockk<PartitionHistogram>()
|
||||
val aggregateEntry =
|
||||
mockk<AggregateEntry> {
|
||||
every { value } returns mockAggregate
|
||||
every { partitionHistogram } returns mockPartitionHistogram
|
||||
every { partitionCountsHistogram } returns mockCountsHistogram
|
||||
every { partitionBytesHistogram } returns mockBytesHistogram
|
||||
}
|
||||
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
|
||||
coEvery { store.removeNextComplete(emittedAtMs) } returns aggregateEntry andThen null
|
||||
@@ -64,7 +66,8 @@ class AggregateStageTest {
|
||||
outputFlow.emit(
|
||||
DataFlowStageIO(
|
||||
aggregate = mockAggregate,
|
||||
partitionHistogram = mockPartitionHistogram
|
||||
partitionCountsHistogram = mockCountsHistogram,
|
||||
partitionBytesHistogram = mockBytesHistogram,
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -119,19 +122,23 @@ class AggregateStageTest {
|
||||
val input = DataFlowStageIO(raw = rawMock, munged = recordDto)
|
||||
|
||||
val mockAggregate1 = mockk<Aggregate>()
|
||||
val mockPartitionHistogram1 = mockk<PartitionHistogram>()
|
||||
val mockCounts1 = mockk<PartitionHistogram>()
|
||||
val mockBytes1 = mockk<PartitionHistogram>()
|
||||
val aggregateEntry1 =
|
||||
mockk<AggregateEntry> {
|
||||
every { value } returns mockAggregate1
|
||||
every { partitionHistogram } returns mockPartitionHistogram1
|
||||
every { partitionCountsHistogram } returns mockCounts1
|
||||
every { partitionBytesHistogram } returns mockBytes1
|
||||
}
|
||||
|
||||
val mockAggregate2 = mockk<Aggregate>()
|
||||
val mockPartitionHistogram2 = mockk<PartitionHistogram>()
|
||||
val mockCounts2 = mockk<PartitionHistogram>()
|
||||
val mockBytes2 = mockk<PartitionHistogram>()
|
||||
val aggregateEntry2 =
|
||||
mockk<AggregateEntry> {
|
||||
every { value } returns mockAggregate2
|
||||
every { partitionHistogram } returns mockPartitionHistogram2
|
||||
every { partitionCountsHistogram } returns mockCounts2
|
||||
every { partitionBytesHistogram } returns mockBytes2
|
||||
}
|
||||
|
||||
coEvery { store.acceptFor(streamDescriptor, recordDto) } returns Unit
|
||||
@@ -150,7 +157,8 @@ class AggregateStageTest {
|
||||
outputFlow.emit(
|
||||
DataFlowStageIO(
|
||||
aggregate = mockAggregate1,
|
||||
partitionHistogram = mockPartitionHistogram1
|
||||
partitionCountsHistogram = mockCounts1,
|
||||
partitionBytesHistogram = mockBytes1,
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -158,7 +166,8 @@ class AggregateStageTest {
|
||||
outputFlow.emit(
|
||||
DataFlowStageIO(
|
||||
aggregate = mockAggregate2,
|
||||
partitionHistogram = mockPartitionHistogram2
|
||||
partitionCountsHistogram = mockCounts2,
|
||||
partitionBytesHistogram = mockBytes2,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -27,21 +27,34 @@ class StateStageTest {
|
||||
@Test
|
||||
fun `apply happy path`() = runTest {
|
||||
// Arrange
|
||||
val histogram = mockk<PartitionHistogram>()
|
||||
val input = DataFlowStageIO(partitionHistogram = histogram)
|
||||
val countsHistogram = mockk<PartitionHistogram>()
|
||||
val bytesHistogram = mockk<PartitionHistogram>()
|
||||
val input =
|
||||
DataFlowStageIO(
|
||||
partitionCountsHistogram = countsHistogram,
|
||||
partitionBytesHistogram = bytesHistogram,
|
||||
)
|
||||
|
||||
// Act
|
||||
val result = stateStage.apply(input)
|
||||
|
||||
// Assert
|
||||
verify(exactly = 1) { stateStore.acceptFlushedCounts(histogram) }
|
||||
verify(exactly = 1) { stateStore.acceptFlushedCounts(countsHistogram) }
|
||||
verify(exactly = 1) { stateStore.acceptFlushedBytes(bytesHistogram) }
|
||||
assertSame(input, result, "The output should be the same as the input object")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `apply with null partition histogram throws exception`() = runTest {
|
||||
val input = DataFlowStageIO(partitionHistogram = null)
|
||||
fun `apply with null counts histogram throws exception`() = runTest {
|
||||
val input = DataFlowStageIO(partitionCountsHistogram = null)
|
||||
assertFailsWith<NullPointerException> { stateStage.apply(input) }
|
||||
verify(exactly = 0) { stateStore.acceptFlushedCounts(any()) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `apply with null bytes histogram throws exception`() = runTest {
|
||||
val input = DataFlowStageIO(partitionBytesHistogram = null)
|
||||
assertFailsWith<NullPointerException> { stateStage.apply(input) }
|
||||
verify(exactly = 0) { stateStore.acceptFlushedBytes(any()) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, 5L)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionHistogram.increment(partitionKey) }
|
||||
repeat(5) { partitionHistogram.increment(partitionKey, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
@@ -50,9 +50,9 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, 15L) // 5 + 3 + 7 = 15
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionHistogram.increment(partitionKey1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey2) }
|
||||
repeat(7) { partitionHistogram.increment(partitionKey3) }
|
||||
repeat(5) { partitionHistogram.increment(partitionKey1, 1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey2, 1) }
|
||||
repeat(7) { partitionHistogram.increment(partitionKey3, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
@@ -71,7 +71,7 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, 10L)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(7) { partitionHistogram.increment(partitionKey) } // Less than expected
|
||||
repeat(7) { partitionHistogram.increment(partitionKey, 1) } // Less than expected
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
@@ -90,7 +90,7 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, 5L)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(8) { partitionHistogram.increment(partitionKey) } // More than expected
|
||||
repeat(8) { partitionHistogram.increment(partitionKey, 1) } // More than expected
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
@@ -107,7 +107,7 @@ class StateHistogramStoreTest {
|
||||
val stateKey = StateKey(1L, listOf(partitionKey))
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionHistogram.increment(partitionKey) }
|
||||
repeat(5) { partitionHistogram.increment(partitionKey, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
@@ -142,7 +142,7 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, 3L)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(3) { partitionHistogram.increment(partitionKey1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey1, 1) }
|
||||
// partitionKey2 has no flushed counts, should be treated as 0
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
@@ -154,28 +154,123 @@ class StateHistogramStoreTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `remove should delete both expected and flushed counts for state key`() {
|
||||
fun `remove should delete both expected and flushed counts for state key and return stats`() {
|
||||
// Given
|
||||
val partitionKey1 = PartitionKey("partition-1")
|
||||
val partitionKey2 = PartitionKey("partition-2")
|
||||
val stateKey = StateKey(1L, listOf(partitionKey1, partitionKey2))
|
||||
val expectedCount = 10L
|
||||
val bytes1 = 1000L
|
||||
val bytes2 = 2000L
|
||||
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, expectedCount)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionHistogram.increment(partitionKey1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey2) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
val partitionCountsHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionCountsHistogram.increment(partitionKey1, 1) }
|
||||
repeat(3) { partitionCountsHistogram.increment(partitionKey2, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionCountsHistogram)
|
||||
|
||||
val partitionBytesHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
partitionBytesHistogram.increment(partitionKey1, bytes1)
|
||||
partitionBytesHistogram.increment(partitionKey2, bytes2)
|
||||
stateHistogramStore.acceptFlushedBytes(partitionBytesHistogram)
|
||||
|
||||
// When
|
||||
val count = stateHistogramStore.remove(stateKey)
|
||||
val stats = stateHistogramStore.remove(stateKey)
|
||||
|
||||
// Then
|
||||
assertEquals(expectedCount, stats.count)
|
||||
assertEquals(bytes1 + bytes2, stats.bytes)
|
||||
assertFalse(
|
||||
stateHistogramStore.isComplete(stateKey)
|
||||
) // Should be false due to missing expected count
|
||||
assertEquals(expectedCount, count)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `remove should handle missing byte counts as zero`() {
|
||||
// Given
|
||||
val partitionKey = PartitionKey("partition-1")
|
||||
val stateKey = StateKey(1L, listOf(partitionKey))
|
||||
val expectedCount = 5L
|
||||
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, expectedCount)
|
||||
|
||||
val partitionCountsHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionCountsHistogram.increment(partitionKey, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionCountsHistogram)
|
||||
|
||||
// No bytes histogram added - should be treated as 0
|
||||
|
||||
// When
|
||||
val stats = stateHistogramStore.remove(stateKey)
|
||||
|
||||
// Then
|
||||
assertEquals(expectedCount, stats.count)
|
||||
assertEquals(0L, stats.bytes) // No bytes were added
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `remove should sum bytes from multiple partitions correctly`() {
|
||||
// Given
|
||||
val partitionKey1 = PartitionKey("partition-1")
|
||||
val partitionKey2 = PartitionKey("partition-2")
|
||||
val partitionKey3 = PartitionKey("partition-3")
|
||||
val stateKey = StateKey(1L, listOf(partitionKey1, partitionKey2, partitionKey3))
|
||||
val expectedCount = 15L
|
||||
val bytes1 = 5000L
|
||||
val bytes2 = 3000L
|
||||
val bytes3 = 7000L
|
||||
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, expectedCount)
|
||||
|
||||
val partitionCountsHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(7) { partitionCountsHistogram.increment(partitionKey1, 1) }
|
||||
repeat(3) { partitionCountsHistogram.increment(partitionKey2, 1) }
|
||||
repeat(5) { partitionCountsHistogram.increment(partitionKey3, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionCountsHistogram)
|
||||
|
||||
val partitionBytesHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
partitionBytesHistogram.increment(partitionKey1, bytes1)
|
||||
partitionBytesHistogram.increment(partitionKey2, bytes2)
|
||||
partitionBytesHistogram.increment(partitionKey3, bytes3)
|
||||
stateHistogramStore.acceptFlushedBytes(partitionBytesHistogram)
|
||||
|
||||
// When
|
||||
val stats = stateHistogramStore.remove(stateKey)
|
||||
|
||||
// Then
|
||||
assertEquals(expectedCount, stats.count)
|
||||
assertEquals(bytes1 + bytes2 + bytes3, stats.bytes)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `remove should handle partial byte counts for partitions`() {
|
||||
// Given
|
||||
val partitionKey1 = PartitionKey("partition-1")
|
||||
val partitionKey2 = PartitionKey("partition-2")
|
||||
val stateKey = StateKey(1L, listOf(partitionKey1, partitionKey2))
|
||||
val expectedCount = 10L
|
||||
val bytes1 = 2500L
|
||||
// partitionKey2 will have no bytes recorded
|
||||
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey, expectedCount)
|
||||
|
||||
val partitionCountsHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(6) { partitionCountsHistogram.increment(partitionKey1, 1) }
|
||||
repeat(4) { partitionCountsHistogram.increment(partitionKey2, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionCountsHistogram)
|
||||
|
||||
val partitionBytesHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
partitionBytesHistogram.increment(partitionKey1, bytes1)
|
||||
// No bytes for partitionKey2
|
||||
stateHistogramStore.acceptFlushedBytes(partitionBytesHistogram)
|
||||
|
||||
// When
|
||||
val stats = stateHistogramStore.remove(stateKey)
|
||||
|
||||
// Then
|
||||
assertEquals(expectedCount, stats.count)
|
||||
assertEquals(bytes1, stats.bytes) // Only bytes from partition1
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -192,9 +287,9 @@ class StateHistogramStoreTest {
|
||||
stateHistogramStore.acceptExpectedCounts(stateKey2, 4L)
|
||||
|
||||
val partitionHistogram = PartitionHistogram(ConcurrentHashMap())
|
||||
repeat(5) { partitionHistogram.increment(partitionKey1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey2) }
|
||||
repeat(4) { partitionHistogram.increment(partitionKey3) }
|
||||
repeat(5) { partitionHistogram.increment(partitionKey1, 1) }
|
||||
repeat(3) { partitionHistogram.increment(partitionKey2, 1) }
|
||||
repeat(4) { partitionHistogram.increment(partitionKey3, 1) }
|
||||
stateHistogramStore.acceptFlushedCounts(partitionHistogram)
|
||||
|
||||
// When
|
||||
|
||||
@@ -17,7 +17,7 @@ class StateHistogramTest {
|
||||
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
|
||||
// When
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey, 1)
|
||||
|
||||
// Then
|
||||
assertEquals(1L, histogram.get(stateKey))
|
||||
@@ -30,9 +30,9 @@ class StateHistogramTest {
|
||||
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
|
||||
// When
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey, 1)
|
||||
histogram.increment(stateKey, 1)
|
||||
histogram.increment(stateKey, 1)
|
||||
|
||||
// Then
|
||||
assertEquals(3L, histogram.get(stateKey))
|
||||
@@ -46,9 +46,9 @@ class StateHistogramTest {
|
||||
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
|
||||
|
||||
// When
|
||||
histogram.increment(stateKey1)
|
||||
histogram.increment(stateKey1)
|
||||
histogram.increment(stateKey2)
|
||||
histogram.increment(stateKey1, 1)
|
||||
histogram.increment(stateKey1, 1)
|
||||
histogram.increment(stateKey2, 1)
|
||||
|
||||
// Then
|
||||
assertEquals(2L, histogram.get(stateKey1))
|
||||
@@ -77,13 +77,13 @@ class StateHistogramTest {
|
||||
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
|
||||
val sharedKey = StateKey(3L, listOf(PartitionKey("partition-3")))
|
||||
|
||||
histogram1.increment(stateKey1)
|
||||
histogram1.increment(stateKey1)
|
||||
histogram1.increment(sharedKey)
|
||||
histogram1.increment(stateKey1, 1)
|
||||
histogram1.increment(stateKey1, 1)
|
||||
histogram1.increment(sharedKey, 1)
|
||||
|
||||
histogram2.increment(stateKey2)
|
||||
histogram2.increment(sharedKey)
|
||||
histogram2.increment(sharedKey)
|
||||
histogram2.increment(stateKey2, 1)
|
||||
histogram2.increment(sharedKey, 1)
|
||||
histogram2.increment(sharedKey, 1)
|
||||
|
||||
// When
|
||||
histogram1.merge(histogram2)
|
||||
@@ -102,8 +102,8 @@ class StateHistogramTest {
|
||||
val stateKey1 = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
val stateKey2 = StateKey(2L, listOf(PartitionKey("partition-2")))
|
||||
|
||||
histogram1.increment(stateKey1)
|
||||
histogram2.increment(stateKey2)
|
||||
histogram1.increment(stateKey1, 1)
|
||||
histogram2.increment(stateKey2, 1)
|
||||
|
||||
// When
|
||||
histogram1.merge(histogram2)
|
||||
@@ -120,8 +120,8 @@ class StateHistogramTest {
|
||||
// Given
|
||||
val histogram = StateHistogram()
|
||||
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey)
|
||||
histogram.increment(stateKey, 1)
|
||||
histogram.increment(stateKey, 1)
|
||||
|
||||
// When
|
||||
val removedValue = histogram.remove(stateKey)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package io.airbyte.cdk.load.dataflow.state
|
||||
|
||||
import io.airbyte.cdk.load.message.CheckpointMessage
|
||||
import io.airbyte.cdk.load.message.StreamCheckpoint
|
||||
import io.mockk.every
|
||||
import io.mockk.impl.annotations.MockK
|
||||
import io.mockk.junit5.MockKExtension
|
||||
@@ -12,6 +13,7 @@ import io.mockk.mockk
|
||||
import io.mockk.verify
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
import kotlin.test.assertNotNull
|
||||
import kotlin.test.assertNull
|
||||
import kotlin.test.assertTrue
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
@@ -30,7 +32,7 @@ class StateStoreTest {
|
||||
@BeforeEach
|
||||
fun setUp() {
|
||||
stateStore = StateStore(keyClient, histogramStore)
|
||||
every { histogramStore.remove(any()) } returns 1L
|
||||
every { histogramStore.remove(any()) } returns StateHistogramStats(1, 1)
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -189,6 +191,75 @@ class StateStoreTest {
|
||||
assertNull(stateStore.getNextComplete()) // no more states
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getNextComplete should add byte counts and record counts to destination stats`() {
|
||||
// Given
|
||||
val sourceStats = CheckpointMessage.Stats(recordCount = 150L)
|
||||
val checkpointMessage =
|
||||
StreamCheckpoint(
|
||||
checkpoint = mockk(),
|
||||
sourceStats = sourceStats,
|
||||
serializedSizeBytes = 1L,
|
||||
)
|
||||
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
val recordCount = 150L
|
||||
val byteCount = 50000L
|
||||
|
||||
every { keyClient.getStateKey(checkpointMessage) } returns stateKey
|
||||
every { histogramStore.acceptExpectedCounts(stateKey, recordCount) } returns mockk()
|
||||
every { histogramStore.isComplete(stateKey) } returns true
|
||||
every { histogramStore.remove(stateKey) } returns
|
||||
StateHistogramStats(count = recordCount, bytes = byteCount)
|
||||
|
||||
stateStore.accept(checkpointMessage)
|
||||
|
||||
// When
|
||||
val result = stateStore.getNextComplete()
|
||||
|
||||
// Then
|
||||
assertNotNull(result)
|
||||
assertEquals(recordCount, result.destinationStats?.recordCount)
|
||||
assertEquals(recordCount, result.totalRecords)
|
||||
assertEquals(byteCount, result.totalBytes)
|
||||
|
||||
// Verify histogram stats were removed
|
||||
verify { histogramStore.remove(stateKey) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getNextComplete should handle different byte and record counts from histogram`() {
|
||||
// Given
|
||||
val expectedRecordCount = 100L
|
||||
val actualRecordCount = 98L // Slightly different
|
||||
val actualByteCount = 32768L
|
||||
|
||||
val sourceStats = CheckpointMessage.Stats(recordCount = expectedRecordCount)
|
||||
val checkpointMessage =
|
||||
StreamCheckpoint(
|
||||
checkpoint = mockk(),
|
||||
sourceStats = sourceStats,
|
||||
serializedSizeBytes = 1L,
|
||||
)
|
||||
val stateKey = StateKey(1L, listOf(PartitionKey("partition-1")))
|
||||
|
||||
every { keyClient.getStateKey(checkpointMessage) } returns stateKey
|
||||
every { histogramStore.acceptExpectedCounts(stateKey, expectedRecordCount) } returns mockk()
|
||||
every { histogramStore.isComplete(stateKey) } returns true
|
||||
every { histogramStore.remove(stateKey) } returns
|
||||
StateHistogramStats(count = actualRecordCount, bytes = actualByteCount)
|
||||
|
||||
stateStore.accept(checkpointMessage)
|
||||
|
||||
// When
|
||||
val result = stateStore.getNextComplete()
|
||||
|
||||
// Then
|
||||
assertNotNull(result)
|
||||
assertEquals(actualRecordCount, result.destinationStats?.recordCount)
|
||||
assertEquals(actualRecordCount, result.totalRecords)
|
||||
assertEquals(actualByteCount, result.totalBytes)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getNextComplete should skip incomplete states and not advance sequence`() {
|
||||
// Given
|
||||
|
||||
Reference in New Issue
Block a user