diff --git a/jdbc-h2/src/test/java/io/kestra/runner/h2/H2RunnerConcurrencyTest.java b/jdbc-h2/src/test/java/io/kestra/runner/h2/H2RunnerConcurrencyTest.java index f4b5d590f1..a73673a687 100644 --- a/jdbc-h2/src/test/java/io/kestra/runner/h2/H2RunnerConcurrencyTest.java +++ b/jdbc-h2/src/test/java/io/kestra/runner/h2/H2RunnerConcurrencyTest.java @@ -1,6 +1,6 @@ package io.kestra.runner.h2; -import io.kestra.core.runners.AbstractRunnerConcurrencyTest; +import io.kestra.jdbc.runner.JdbcConcurrencyRunnerTest; -public class H2RunnerConcurrencyTest extends AbstractRunnerConcurrencyTest { +public class H2RunnerConcurrencyTest extends JdbcConcurrencyRunnerTest { } diff --git a/jdbc-mysql/src/test/java/io/kestra/runner/mysql/MysqlRunnerConcurrencyTest.java b/jdbc-mysql/src/test/java/io/kestra/runner/mysql/MysqlRunnerConcurrencyTest.java index e8deeb3fb2..395db0b274 100644 --- a/jdbc-mysql/src/test/java/io/kestra/runner/mysql/MysqlRunnerConcurrencyTest.java +++ b/jdbc-mysql/src/test/java/io/kestra/runner/mysql/MysqlRunnerConcurrencyTest.java @@ -1,6 +1,6 @@ package io.kestra.runner.mysql; -import io.kestra.core.runners.AbstractRunnerConcurrencyTest; +import io.kestra.jdbc.runner.JdbcConcurrencyRunnerTest; -public class MysqlRunnerConcurrencyTest extends AbstractRunnerConcurrencyTest { +public class MysqlRunnerConcurrencyTest extends JdbcConcurrencyRunnerTest { } diff --git a/jdbc-postgres/src/test/java/io/kestra/runner/postgres/PostgresRunnerConcurrencyTest.java b/jdbc-postgres/src/test/java/io/kestra/runner/postgres/PostgresRunnerConcurrencyTest.java index 1df8d26ee8..54f113872e 100644 --- a/jdbc-postgres/src/test/java/io/kestra/runner/postgres/PostgresRunnerConcurrencyTest.java +++ b/jdbc-postgres/src/test/java/io/kestra/runner/postgres/PostgresRunnerConcurrencyTest.java @@ -1,6 +1,6 @@ package io.kestra.runner.postgres; -import io.kestra.core.runners.AbstractRunnerConcurrencyTest; +import io.kestra.jdbc.runner.JdbcConcurrencyRunnerTest; -public class PostgresRunnerConcurrencyTest extends AbstractRunnerConcurrencyTest { +public class PostgresRunnerConcurrencyTest extends JdbcConcurrencyRunnerTest { } diff --git a/jdbc/src/main/java/io/kestra/jdbc/runner/AbstractJdbcConcurrencyLimitStorage.java b/jdbc/src/main/java/io/kestra/jdbc/runner/AbstractJdbcConcurrencyLimitStorage.java index 143150b349..e79882ab4e 100644 --- a/jdbc/src/main/java/io/kestra/jdbc/runner/AbstractJdbcConcurrencyLimitStorage.java +++ b/jdbc/src/main/java/io/kestra/jdbc/runner/AbstractJdbcConcurrencyLimitStorage.java @@ -74,15 +74,19 @@ public class AbstractJdbcConcurrencyLimitStorage extends AbstractJdbcRepository * Decrement the concurrency limit counter. * Must only be called when a flow having concurrency limit ends. */ - public void decrement(FlowInterface flow) { - this.jdbcRepository + public int decrement(FlowInterface flow) { + return this.jdbcRepository .getDslContextWrapper() - .transaction(configuration -> { + .transactionResult(configuration -> { var dslContext = DSL.using(configuration); - fetchOne(dslContext, flow).ifPresent( - concurrencyLimit -> update(dslContext, concurrencyLimit.withRunning(concurrencyLimit.getRunning() == 0 ? 0 : concurrencyLimit.getRunning() - 1)) - ); + return fetchOne(dslContext, flow).map( + concurrencyLimit -> { + int newLimit = concurrencyLimit.getRunning() == 0 ? 0 : concurrencyLimit.getRunning() - 1; + update(dslContext, concurrencyLimit.withRunning(newLimit)); + return newLimit; + } + ).orElse(0); }); } diff --git a/jdbc/src/main/java/io/kestra/jdbc/runner/JdbcExecutor.java b/jdbc/src/main/java/io/kestra/jdbc/runner/JdbcExecutor.java index 6472bd9c36..cac17e7dc4 100644 --- a/jdbc/src/main/java/io/kestra/jdbc/runner/JdbcExecutor.java +++ b/jdbc/src/main/java/io/kestra/jdbc/runner/JdbcExecutor.java @@ -1210,24 +1210,30 @@ public class JdbcExecutor implements ExecutorInterface { // as we may receive multiple time killed execution (one when we kill it, then one for each running worker task), we limit to the first we receive: when the state transitionned from KILLING to KILLED boolean killingThenKilled = execution.getState().getCurrent().isKilled() && executor.getOriginalState() == State.Type.KILLING; if (!queuedThenKilled && !concurrencyShortCircuitState && (!execution.getState().getCurrent().isKilled() || killingThenKilled)) { - // decrement execution concurrency limit and pop a new queued execution if needed - concurrencyLimitStorage.decrement(executor.getFlow()); + int newLimit = concurrencyLimitStorage.decrement(executor.getFlow()); if (executor.getFlow().getConcurrency().getBehavior() == Concurrency.Behavior.QUEUE) { var finalFlow = executor.getFlow(); - executionQueuedStorage.pop(executor.getFlow().getTenantId(), - executor.getFlow().getNamespace(), - executor.getFlow().getId(), - throwBiConsumer((dslContext, queued) -> { - var newExecution = queued.withState(State.Type.RUNNING); - concurrencyLimitStorage.increment(dslContext, finalFlow); - executionQueue.emit(newExecution); - metricRegistry.counter(MetricRegistry.METRIC_EXECUTOR_EXECUTION_POPPED_COUNT, MetricRegistry.METRIC_EXECUTOR_EXECUTION_POPPED_COUNT_DESCRIPTION, metricRegistry.tags(newExecution)).increment(); - // process flow triggers to allow listening on RUNNING state after a QUEUED state - processFlowTriggers(newExecution); - }) - ); + if (newLimit < finalFlow.getConcurrency().getLimit()) { + executionQueuedStorage.pop(executor.getFlow().getTenantId(), + executor.getFlow().getNamespace(), + executor.getFlow().getId(), + throwBiConsumer((dslContext, queued) -> { + var newExecution = queued.withState(State.Type.RUNNING); + concurrencyLimitStorage.increment(dslContext, finalFlow); + executionQueue.emit(newExecution); + metricRegistry.counter(MetricRegistry.METRIC_EXECUTOR_EXECUTION_POPPED_COUNT, MetricRegistry.METRIC_EXECUTOR_EXECUTION_POPPED_COUNT_DESCRIPTION, metricRegistry.tags(newExecution)).increment(); + + // process flow triggers to allow listening on RUNNING state after a QUEUED state + processFlowTriggers(newExecution); + }) + ); + } else { + log.error("Concurrency limit reached for flow {}.{} after decrementing the execution running count due to the terminated execution {}. No new executions will be dequeued.", executor.getFlow().getNamespace(), executor.getFlow().getId(), executor.getExecution().getId()); + } + } else if (newLimit >= executor.getFlow().getConcurrency().getLimit()) { + log.error("Concurrency limit reached for flow {}.{} after decrementing the execution running count due to the terminated execution {}. This should not happen.", executor.getFlow().getNamespace(), executor.getFlow().getId(), executor.getExecution().getId()); } } } diff --git a/jdbc/src/test/java/io/kestra/jdbc/runner/JdbcConcurrencyRunnerTest.java b/jdbc/src/test/java/io/kestra/jdbc/runner/JdbcConcurrencyRunnerTest.java new file mode 100644 index 0000000000..95c0190aa8 --- /dev/null +++ b/jdbc/src/test/java/io/kestra/jdbc/runner/JdbcConcurrencyRunnerTest.java @@ -0,0 +1,65 @@ +package io.kestra.jdbc.runner; + +import io.kestra.core.junit.annotations.LoadFlows; +import io.kestra.core.models.executions.Execution; +import io.kestra.core.models.flows.Flow; +import io.kestra.core.models.flows.State; +import io.kestra.core.queues.QueueException; +import io.kestra.core.repositories.ExecutionRepositoryInterface; +import io.kestra.core.repositories.FlowRepositoryInterface; +import io.kestra.core.runners.AbstractRunnerConcurrencyTest; +import io.kestra.core.runners.ConcurrencyLimit; +import io.kestra.core.runners.TestRunnerUtils; +import jakarta.inject.Inject; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Optional; + +import static io.kestra.core.tenant.TenantService.MAIN_TENANT; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class JdbcConcurrencyRunnerTest extends AbstractRunnerConcurrencyTest { + public static final String NAMESPACE = "io.kestra.tests"; + + @Inject + private AbstractJdbcConcurrencyLimitStorage concurrencyLimitStorage; + + @Inject + private FlowRepositoryInterface flowRepository; + + @Inject + private ExecutionRepositoryInterface executionRepository; + + @Inject + private TestRunnerUtils runnerUtils; + + @Test + @LoadFlows({"flows/valids/flow-concurrency-queue.yml"}) + void flowConcurrencyQueuedProtection() throws QueueException, InterruptedException { + Execution execution1 = runnerUtils.runOneUntilRunning(MAIN_TENANT, NAMESPACE, "flow-concurrency-queue", null, null, Duration.ofSeconds(30)); + assertThat(execution1.getState().isRunning()).isTrue(); + + Flow flow = flowRepository + .findById(MAIN_TENANT, NAMESPACE, "flow-concurrency-queue", Optional.empty()) + .orElseThrow(); + Execution execution2 = runnerUtils.emitAndAwaitExecution(e -> e.getState().getCurrent().equals(State.Type.QUEUED), Execution.newExecution(flow, null, null, Optional.empty())); + assertThat(execution2.getState().getCurrent()).isEqualTo(State.Type.QUEUED); + + // manually update the concurrency count so that queued protection kicks in and no new execution would be popped + ConcurrencyLimit concurrencyLimit = concurrencyLimitStorage.findById(MAIN_TENANT, NAMESPACE, "flow-concurrency-queue").orElseThrow(); + concurrencyLimit = concurrencyLimit.withRunning(concurrencyLimit.getRunning() + 1); + concurrencyLimitStorage.update(concurrencyLimit); + + Execution executionResult1 = runnerUtils.awaitExecution(e -> e.getState().getCurrent().equals(State.Type.SUCCESS), execution1); + assertThat(executionResult1.getState().getCurrent()).isEqualTo(State.Type.SUCCESS); + + // we wait for a few ms and checked that the second execution is still queued + Thread.sleep(500); + Execution executionResult2 = executionRepository.findById(MAIN_TENANT, execution2.getId()).orElseThrow(); + assertThat(executionResult2.getState().getCurrent()).isEqualTo(State.Type.QUEUED); + + // we manually reset the concurrency count to avoid messing with any other tests + concurrencyLimitStorage.update(concurrencyLimit.withRunning(concurrencyLimit.getRunning() - 1)); + } +}