refactor(system): extract JdbcQueuePoller class from JdbcQueue

Extract a JdbcQueueConfiguration and JdbcQueuePoller classes from
JdbcQueue to improve clarity, testability and reuse of the code.
This commit is contained in:
Florian Hussonnois
2025-10-23 17:42:05 +02:00
committed by Loïc Mathieu
parent 4e3a786c3b
commit ed8e810791
8 changed files with 298 additions and 109 deletions

View File

@@ -42,7 +42,7 @@ public class H2Queue<T> extends JdbcQueue<T> {
var limitSelect = select
.orderBy(AbstractJdbcRepository.field("offset").asc())
.limit(configuration.getPollSize());
.limit(configuration.pollSize());
ResultQuery<Record2<Object, Object>> configuredSelect = limitSelect;
if (forUpdate) {

View File

@@ -51,7 +51,7 @@ public class MysqlQueue<T> extends JdbcQueue<T> {
var limitSelect = select
.orderBy(AbstractJdbcRepository.field("offset").asc())
.limit(configuration.getPollSize());
.limit(configuration.pollSize());
ResultQuery<Record2<Object, Object>> configuredSelect = limitSelect;
if (forUpdate) {

View File

@@ -70,7 +70,7 @@ public class PostgresQueue<T> extends JdbcQueue<T> {
var limitSelect = select
.orderBy(AbstractJdbcRepository.field("offset").asc())
.limit(configuration.getPollSize());
.limit(configuration.pollSize());
ResultQuery<Record2<Object, Object>> configuredSelect = limitSelect;
if (forUpdate) {

View File

@@ -18,10 +18,6 @@ import io.kestra.jdbc.repository.AbstractJdbcRepository;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import io.micronaut.context.ApplicationContext;
import io.micronaut.context.annotation.ConfigurationProperties;
import io.micronaut.context.annotation.Value;
import io.micronaut.transaction.exceptions.CannotCreateTransactionException;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.jooq.*;
import org.jooq.Record;
@@ -29,8 +25,6 @@ import org.jooq.exception.DataException;
import org.jooq.impl.DSL;
import java.io.IOException;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -56,7 +50,7 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
protected final JooqDSLContextWrapper dslContextWrapper;
protected final Configuration configuration;
protected final JdbcQueueConfiguration configuration;
protected final MessageProtectionConfiguration messageProtectionConfiguration;
@@ -65,11 +59,10 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
protected final Table<Record> table;
protected final JdbcQueueIndexer jdbcQueueIndexer;
private final boolean immediateRepoll;
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private final AtomicBoolean isPaused = new AtomicBoolean(false);
private final List<JdbcQueuePoller> pollers = new ArrayList<>();
private final Counter bigMessageCounter;
@@ -81,7 +74,7 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
this.queueService = applicationContext.getBean(QueueService.class);
this.cls = cls;
this.dslContextWrapper = applicationContext.getBean(JooqDSLContextWrapper.class);
this.configuration = applicationContext.getBean(Configuration.class);
this.configuration = applicationContext.getBean(JdbcQueueConfiguration.class);
this.messageProtectionConfiguration = applicationContext.getBean(MessageProtectionConfiguration.class);
this.metricRegistry = applicationContext.getBean(MetricRegistry.class);
@@ -90,9 +83,7 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
this.table = DSL.table(jdbcTableConfigs.tableConfig("queues").table());
this.jdbcQueueIndexer = applicationContext.getBean(JdbcQueueIndexer.class);
this.immediateRepoll = applicationContext.getProperty("kestra.jdbc.queues.immediate-repoll", Boolean.class).orElse(true);
// init metrics we can at post construct to avoid costly Metric.Id computation
this.bigMessageCounter = metricRegistry
.counter(MetricRegistry.METRIC_QUEUE_BIG_MESSAGE_COUNT, MetricRegistry.METRIC_QUEUE_BIG_MESSAGE_COUNT_DESCRIPTION, MetricRegistry.TAG_CLASS_NAME, queueType());
@@ -248,7 +239,7 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
var limitSelect = select
.orderBy(AbstractJdbcRepository.field("offset").asc())
.limit(configuration.getPollSize());
.limit(configuration.pollSize());
ResultQuery<Record2<Object, Object>> configuredSelect = limitSelect;
if (forUpdate) {
@@ -423,57 +414,17 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
queueType.getSimpleName()
);
}
@SuppressWarnings("BusyWait")
protected Runnable poll(Supplier<Integer> runnable) {
AtomicBoolean running = new AtomicBoolean(true);
poolExecutor.execute(() -> {
List<Configuration.Step> steps = configuration.computeSteps();
Duration sleep = configuration.minPollInterval;
ZonedDateTime lastPoll = ZonedDateTime.now();
while (running.get() && !this.isClosed.get()) {
if (!this.isPaused.get()) {
try {
Integer count = runnable.get();
if (count > 0) {
lastPoll = ZonedDateTime.now();
sleep = configuration.minPollInterval;
if (immediateRepoll) {
continue;
} else if (count.equals(configuration.pollSize)) {
// Note: this provides better latency on high throughput: when Kestra is a top capacity,
// it will not do a sleep and immediately poll again.
// We can even have better latency at even higher latency by continuing for positive count,
// but at higher database cost.
// Current impl balance database cost with latency.
continue;
}
} else {
ZonedDateTime finalLastPoll = lastPoll;
// get all poll steps which duration is less than the duration between last poll and now
List<Configuration.Step> selectedSteps = steps.stream()
.takeWhile(step -> finalLastPoll.plus(step.switchInterval()).compareTo(ZonedDateTime.now()) < 0)
.toList();
// then select the last one (longest) or minPoll if all are beyond while means we are under the first interval
sleep = selectedSteps.isEmpty() ? configuration.minPollInterval : selectedSteps.getLast().pollInterval();
}
} catch (CannotCreateTransactionException e) {
if (log.isDebugEnabled()) {
log.debug("Can't poll on receive", e);
}
}
}
try {
Thread.sleep(sleep);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});
return () -> running.set(false);
JdbcQueuePoller queuePoller = new JdbcQueuePoller(configuration, runnable::get);
pollers.add(queuePoller);
poolExecutor.execute(queuePoller);
return () -> {
pollers.remove(queuePoller);
queuePoller.stop();
};
}
protected List<Either<T, DeserializationException>> map(Result<Record> fetch) {
@@ -494,12 +445,12 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
@Override
public void pause() {
this.isPaused.set(true);
this.pollers.forEach(JdbcQueuePoller::pause);
}
@Override
public void resume() {
this.isPaused.set(false);
this.pollers.forEach(JdbcQueuePoller::resume);
}
@Override
@@ -507,45 +458,9 @@ public abstract class JdbcQueue<T> implements QueueInterface<T> {
if (!this.isClosed.compareAndSet(false, true)) {
return;
}
this.pollers.forEach(JdbcQueuePoller::stop);
this.poolExecutor.shutdown();
this.asyncPoolExecutor.shutdown();
}
@ConfigurationProperties("kestra.jdbc.queues")
@Getter
public static class Configuration {
Duration minPollInterval = Duration.ofMillis(25);
Duration maxPollInterval = Duration.ofMillis(500);
Duration pollSwitchInterval = Duration.ofSeconds(60);
Integer pollSize = 100;
Integer switchSteps = 5;
public List<Step> computeSteps() {
if (this.maxPollInterval.compareTo(this.minPollInterval) <= 0) {
throw new IllegalArgumentException("'maxPollInterval' (" + this.maxPollInterval + ") must be greater than 'minPollInterval' (" + this.minPollInterval + ")");
}
List<Step> steps = new ArrayList<>();
Step currentStep = new Step(this.maxPollInterval, this.pollSwitchInterval);
steps.add(currentStep);
for (int i = 0; i < switchSteps; i++) {
Duration stepPollInterval = Duration.ofMillis(currentStep.pollInterval().toMillis() / 2);
if (stepPollInterval.compareTo(minPollInterval) < 0) {
stepPollInterval = minPollInterval;
}
Duration stepSwitchInterval = Duration.ofMillis(currentStep.switchInterval().toMillis() / 2);
currentStep = new Step(stepPollInterval, stepSwitchInterval);
steps.add(currentStep);
}
Collections.sort(steps);
return steps;
}
public record Step (Duration pollInterval, Duration switchInterval) implements Comparable<Step> {
@Override
public int compareTo(Step o) {
return this.switchInterval.compareTo(o.switchInterval);
}
}
}
}

View File

@@ -0,0 +1,54 @@
package io.kestra.jdbc.runner;
import io.micronaut.context.annotation.ConfigurationProperties;
import io.micronaut.core.bind.annotation.Bindable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ConfigurationProperties("kestra.jdbc.queues")
public record JdbcQueueConfiguration(
@Bindable(defaultValue = "PT0.025S")
Duration minPollInterval,
@Bindable(defaultValue = "PT0.5S")
Duration maxPollInterval,
@Bindable(defaultValue = "PT60S")
Duration pollSwitchInterval,
@Bindable(defaultValue = "100")
Integer pollSize,
@Bindable(defaultValue = "5")
Integer switchSteps,
@Bindable(defaultValue = "true")
Boolean immediateRepoll
) {
public List<Step> computeSteps() {
if (this.maxPollInterval.compareTo(this.minPollInterval) <= 0) {
throw new IllegalArgumentException("'maxPollInterval' (" + this.maxPollInterval + ") must be greater than 'minPollInterval' (" + this.minPollInterval + ")");
}
List<Step> steps = new ArrayList<>();
Step currentStep = new Step(this.maxPollInterval, this.pollSwitchInterval);
steps.add(currentStep);
for (int i = 0; i < switchSteps; i++) {
Duration stepPollInterval = Duration.ofMillis(currentStep.pollInterval().toMillis() / 2);
if (stepPollInterval.compareTo(minPollInterval) < 0) {
stepPollInterval = minPollInterval;
}
Duration stepSwitchInterval = Duration.ofMillis(currentStep.switchInterval().toMillis() / 2);
currentStep = new Step(stepPollInterval, stepSwitchInterval);
steps.add(currentStep);
}
Collections.sort(steps);
return steps;
}
public record Step(Duration pollInterval, Duration switchInterval) implements Comparable<Step> {
@Override
public int compareTo(Step o) {
return this.switchInterval.compareTo(o.switchInterval);
}
}
}

View File

@@ -0,0 +1,175 @@
package io.kestra.jdbc.runner;
import com.google.common.annotations.VisibleForTesting;
import io.kestra.core.exceptions.KestraRuntimeException;
import io.micronaut.transaction.exceptions.CannotCreateTransactionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
/**
* Class responsible for continuously executing a polling query.
*/
public final class JdbcQueuePoller implements Runnable {
private static final Logger log = LoggerFactory.getLogger(JdbcQueuePoller.class);
private final JdbcQueueConfiguration configuration;
private final AtomicBoolean running = new AtomicBoolean(true);
private final Callable<Integer> pollingQuery;
// Pause
private final AtomicBoolean paused = new AtomicBoolean(false);
private final ReentrantLock pauseLock = new ReentrantLock();
private final Condition unpaused = pauseLock.newCondition();
private final CountDownLatch stopped = new CountDownLatch(1);
/**
* Creates a new {@link JdbcQueuePoller} instance.
*
* @param configuration the {@link JdbcQueueConfiguration}.
* @param pollingQuery the query to be executed.
*/
public JdbcQueuePoller(final JdbcQueueConfiguration configuration,
final Callable<Integer> pollingQuery) {
this.configuration = Objects.requireNonNull(configuration);
this.pollingQuery = Objects.requireNonNull(pollingQuery);
}
@Override
public void run() {
List<JdbcQueueConfiguration.Step> steps = configuration.computeSteps();
ZonedDateTime lastPoll = ZonedDateTime.now();
try {
while (running.get()) {
ZonedDateTime poll = pollOnce(lastPoll, steps);
if (poll != null) {
lastPoll = poll;
}
}
} finally {
stopped.countDown();
}
}
@VisibleForTesting
ZonedDateTime pollOnce(ZonedDateTime lastPoll, List<JdbcQueueConfiguration.Step> steps) {
Duration sleep;
try {
// Check pause before starting any query
waitIfPaused();
// Check if the loop was stopped while being paused
if (!running.get()) {
return null;
}
Integer count = pollingQuery.call();
if (count > 0) {
lastPoll = ZonedDateTime.now();
sleep = configuration.minPollInterval();
if (configuration.immediateRepoll()) {
return null;
} else if (count.equals(configuration.pollSize())) {
// Note: this provides better latency on high throughput: when Kestra is a top capacity,
// it will not do a sleep and immediately poll again.
// We can even have better latency at even higher latency by continuing for positive count,
// but at higher database cost.
// Current impl balance database cost with latency.
return null;
}
} else {
ZonedDateTime finalLastPoll = lastPoll;
// get all poll steps which duration is less than the duration between last poll and now
List<JdbcQueueConfiguration.Step> selectedSteps = steps.stream()
.takeWhile(step -> finalLastPoll.plus(step.switchInterval()).compareTo(ZonedDateTime.now()) < 0)
.toList();
// then select the last one (longest) or minPoll if all are beyond while means we are under the first interval
sleep = selectedSteps.isEmpty() ? configuration.minPollInterval() : selectedSteps.getLast().pollInterval();
}
Thread.sleep(sleep);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.warn("Interrupted while waiting. Stopping.");
running.set(false);
} catch (CannotCreateTransactionException e) {
if (log.isDebugEnabled()) {
log.debug("Can't poll on receive", e);
}
} catch (Exception e) {
throw new KestraRuntimeException("Unexpected error while executing queue polling query", e);
}
return lastPoll;
}
private void waitIfPaused() throws InterruptedException {
if (!paused.get()) {
return; // return immediately if not paused.
}
pauseLock.lock();
try {
while (paused.get() && running.get()) {
log.debug("Paused. Waiting for {} to resume", JdbcQueuePoller.class.getSimpleName());
unpaused.await(); // Wait until resume() signals
log.debug("Resumed");
}
} finally {
pauseLock.unlock();
}
}
/**
* Pauses this poller.
*/
public void pause() {
paused.set(true);
}
/**
* Resumes this poller if currently paused.
*/
public void resume() {
pauseLock.lock();
try {
if (paused.compareAndSet(true, false)) {
unpaused.signalAll();
}
} finally {
pauseLock.unlock();
}
}
/**
* Stops this poller.
*/
public void stop() {
if (!this.running.compareAndSet(true, false)) {
return; // already stopped
}
resume(); // In case it's paused and blocked
try {
// wait for the poller to be stooped
stopped.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.warn("Interrupted while waiting for {} to be stopped.", this.getClass().getSimpleName());
}
}
}

View File

@@ -19,4 +19,5 @@ processResources.dependsOn copyGradleProperties
dependencies {
jmh project(':core')
jmh project(':jdbc')
}

View File

@@ -0,0 +1,44 @@
package io.kestra.jdbc.runner;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class JdbcQueuePollerBenchmark {
public static final JdbcQueueConfiguration DEFAULT_POLLER_CONFIG = new JdbcQueueConfiguration(
Duration.ofMillis(25),
Duration.ofMillis(500),
Duration.ofSeconds(60),
100,
5,
true
);
List<JdbcQueueConfiguration.Step> STEPS = DEFAULT_POLLER_CONFIG.computeSteps();
private JdbcQueuePoller poller;
@Setup(Level.Invocation)
public void setup() {
poller = new JdbcQueuePoller(DEFAULT_POLLER_CONFIG, () -> 1);
}
@Benchmark
public void testPollOnce() {
poller.pollOnce(ZonedDateTime.now(), STEPS);
}
}