feat(script): Resume of container for the Docker task runner (#11964)

* feat(script): Implements resume of container for the Docker task runner

closes #4129

* feat(script): docker resume review recommendations

* feat(script): Get volume name and update resume Docker tests

* feat(script): Fix tests for docker resume

* feat(script): test same container id created/reused

* feat(script): delete container after second run

* feat(script): Docker resume should be true by default

* feat(script): fix spacing
This commit is contained in:
Julio Daniel Reyes
2025-10-16 11:01:47 -03:00
committed by GitHub
parent f7031ec596
commit 3e4eed3306
3 changed files with 292 additions and 89 deletions

View File

@@ -50,6 +50,7 @@ import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import static io.kestra.core.utils.Rethrow.throwConsumer;
import static io.kestra.core.utils.Rethrow.throwFunction;
@@ -313,6 +314,14 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
)
private Duration killGracePeriod = Duration.ZERO;
/**
 * Toggle for reusing a container left over from an earlier run of the same task.
 * Defaults to {@code true}. When it renders to {@code true}, the run logic lists containers
 * (stopped ones included) that carry this run's labels and reuses the first match; an
 * unset value is treated as {@code false} by that logic.
 */
@Builder.Default
@Schema(
title = "Whether to resume an existing matching container on restart.",
description = "If enabled, the runner will search for an existing container labeled with the current execution/task identifiers and reattach to it instead of creating a new container."
)
@PluginProperty
private Property<Boolean> resume = Property.ofValue(true);
/**
* Convenient default instance to be used as task default value for a 'taskRunner' property.
**/
@@ -364,18 +373,25 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
String image = runContext.render(this.image, additionalVars);
String resolvedHost = DockerService.findHost(runContext, this.host);
try (DockerClient dockerClient = dockerClient(runContext, image, resolvedHost)) {
// pull image
var renderedPolicy = runContext.render(this.getPullPolicy()).as(PullPolicy.class).orElseThrow();
if (!PullPolicy.NEVER.equals(renderedPolicy)) {
pullImage(dockerClient, image, renderedPolicy, logger);
}
Map<String, String> labels = ScriptService.labels(runContext, "kestra.io/");
// create container
CreateContainerCmd container = configure(taskCommands, dockerClient, runContext, additionalVars);
CreateContainerResponse exec = container.exec();
if (logger.isTraceEnabled()) {
logger.trace("Container created: {}", exec.getId());
try (DockerClient dockerClient = dockerClient(runContext, image, resolvedHost)) {
// evaluate resume (task property overrides plugin configuration if set)
Boolean resumeProp = runContext.render(this.resume).as(Boolean.class).orElse(Boolean.FALSE);
boolean resumeEnabled = Boolean.TRUE.equals(resumeProp);
String containerId = null;
if (resumeEnabled) {
List<Container> existing = dockerClient.listContainersCmd()
.withShowAll(true)
.withLabelFilter(labels)
.exec();
if (!existing.isEmpty()) {
containerId = existing.get(0).getId();
logger.debug("Resuming existing container: {}", containerId);
}
}
List<Path> relativeWorkingDirectoryFilesPaths = taskCommands.relativeWorkingDirectoryFilesPaths(true);
@@ -384,94 +400,134 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
boolean outputDirectoryEnabled = taskCommands.outputDirectoryEnabled();
boolean needVolume = hasFilesToDownload || hasFilesToUpload || outputDirectoryEnabled;
String filesVolumeName = null;
// create a volume if we need to handle files
var strategy = runContext.render(this.fileHandlingStrategy).as(FileHandlingStrategy.class).orElse(null);
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy)) {
CreateVolumeCmd files = dockerClient.createVolumeCmd()
.withLabels(ScriptService.labels(runContext, "kestra.io/"));
filesVolumeName = files.exec().getName();
// pull image only if we will create a new container
if (containerId == null) {
var renderedPolicy = runContext.render(this.getPullPolicy()).as(PullPolicy.class).orElseThrow();
if (!PullPolicy.NEVER.equals(renderedPolicy)) {
pullImage(dockerClient, image, renderedPolicy, logger);
}
// create container
CreateContainerCmd container = configure(taskCommands, dockerClient, runContext, additionalVars);
CreateContainerResponse exec = container.exec();
containerId = exec.getId();
if (logger.isTraceEnabled()) {
logger.trace("Volume created: {}", filesVolumeName);
logger.trace("Container created: {}", containerId);
}
String remotePath = windowsToUnixPath(taskCommands.getWorkingDirectory().toString());
// first, create an archive
Path fileArchive = runContext.workingDir().createFile("inputFiles.tar");
try (FileOutputStream fos = new FileOutputStream(fileArchive.toString());
TarArchiveOutputStream out = new TarArchiveOutputStream(fos)) {
out.setLongFileMode(TarArchiveOutputStream.LONGFILE_POSIX); // allow long file name
out.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_POSIX); // allow large archive name
for (Path file: relativeWorkingDirectoryFilesPaths) {
Path resolvedFile = runContext.workingDir().resolve(file);
TarArchiveEntry entry = out.createArchiveEntry(resolvedFile.toFile(), file.toString());
// Preserve POSIX permissions if supported
try {
Set<PosixFilePermission> perms = Files.getPosixFilePermissions(resolvedFile);
entry.setMode(UnixModeToPosixFilePermissions.fromPosixFilePermissions(perms));
} catch (UnsupportedOperationException | IOException ignore) {
// Skipping unix file permission
}
out.putArchiveEntry(entry);
if (!Files.isDirectory(resolvedFile)) {
try (InputStream fis = Files.newInputStream(resolvedFile)) {
IOUtils.copy(fis, out);
}
}
out.closeArchiveEntry();
// create a volume if we need to handle files
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy)) {
CreateVolumeCmd files = dockerClient.createVolumeCmd()
.withLabels(labels);
filesVolumeName = files.exec().getName();
if (logger.isTraceEnabled()) {
logger.trace("Volume created: {}", filesVolumeName);
}
String remotePath = windowsToUnixPath(taskCommands.getWorkingDirectory().toString());
// first, create an archive
Path fileArchive = runContext.workingDir().createFile("inputFiles.tar");
try (FileOutputStream fos = new FileOutputStream(fileArchive.toString());
TarArchiveOutputStream out = new TarArchiveOutputStream(fos)) {
out.setLongFileMode(TarArchiveOutputStream.LONGFILE_POSIX); // allow long file name
out.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_POSIX); // allow large archive name
for (Path file: relativeWorkingDirectoryFilesPaths) {
Path resolvedFile = runContext.workingDir().resolve(file);
TarArchiveEntry entry = out.createArchiveEntry(resolvedFile.toFile(), file.toString());
// Preserve POSIX permissions if supported
try {
Set<PosixFilePermission> perms = Files.getPosixFilePermissions(resolvedFile);
entry.setMode(UnixModeToPosixFilePermissions.fromPosixFilePermissions(perms));
} catch (UnsupportedOperationException | IOException ignore) {
// Skipping unix file permission
}
out.putArchiveEntry(entry);
if (!Files.isDirectory(resolvedFile)) {
try (InputStream fis = Files.newInputStream(resolvedFile)) {
IOUtils.copy(fis, out);
}
}
out.closeArchiveEntry();
}
out.finish();
}
// then send it to the container
try (InputStream is = new FileInputStream(fileArchive.toString())) {
CopyArchiveToContainerCmd copyArchiveToContainerCmd = dockerClient.copyArchiveToContainerCmd(containerId)
.withTarInputStream(is)
.withRemotePath(remotePath);
copyArchiveToContainerCmd.exec();
}
Files.delete(fileArchive);
// create the outputDir if needed
if (taskCommands.outputDirectoryEnabled()) {
CopyArchiveToContainerCmd copyArchiveToContainerCmd = dockerClient.copyArchiveToContainerCmd(containerId)
.withHostResource(taskCommands.getOutputDirectory().toString())
.withRemotePath(remotePath);
copyArchiveToContainerCmd.exec();
}
out.finish();
}
// then send it to the container
try (InputStream is = new FileInputStream(fileArchive.toString())) {
CopyArchiveToContainerCmd copyArchiveToContainerCmd = dockerClient.copyArchiveToContainerCmd(exec.getId())
.withTarInputStream(is)
.withRemotePath(remotePath);
copyArchiveToContainerCmd.exec();
}
// start container
dockerClient.startContainerCmd(containerId).exec();
Files.delete(fileArchive);
List<String> renderedCommands = runContext.render(taskCommands.getCommands()).asList(String.class);
// create the outputDir if needed
if (taskCommands.outputDirectoryEnabled()) {
CopyArchiveToContainerCmd copyArchiveToContainerCmd = dockerClient.copyArchiveToContainerCmd(exec.getId())
.withHostResource(taskCommands.getOutputDirectory().toString())
.withRemotePath(remotePath);
copyArchiveToContainerCmd.exec();
if (logger.isDebugEnabled()) {
logger.debug(
"Starting command with container id {} [{}]",
containerId,
String.join(" ", renderedCommands)
);
}
} else {
// resumed path: do not re-create or start the container, just attach and wait
if (logger.isDebugEnabled()) {
logger.debug("Attaching to logs of container {}", containerId);
}
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy)) {
List<String> labelsList = labels.entrySet()
.stream()
.map(entry -> String.join("=", entry.getKey(), entry.getValue()))
.toList();
var volumes = dockerClient.listVolumesCmd()
.withFilter("label", labelsList).exec();
if (volumes.getVolumes() == null || volumes.getVolumes().isEmpty()) {
logger.error("No volume found for resumed container {}", containerId);
throw new TaskException(1, defaultLogConsumer);
} else {
var volume = volumes.getVolumes().get(0);
filesVolumeName = volume.getName();
logger.debug("Volume found with name {} for resumed container {}", filesVolumeName, containerId);
}
}
}
// start container
dockerClient.startContainerCmd(exec.getId()).exec();
List<String> renderedCommands = runContext.render(taskCommands.getCommands()).asList(String.class);
if (logger.isDebugEnabled()) {
logger.debug(
"Starting command with container id {} [{}]",
exec.getId(),
String.join(" ", renderedCommands)
);
}
final String runContainerId = containerId;
if (!Boolean.TRUE.equals(runContext.render(wait).as(Boolean.class).orElseThrow())) {
return TaskRunnerResult.<DockerTaskRunnerDetailResult>builder()
.exitCode(0)
.logConsumer(defaultLogConsumer)
.details(DockerTaskRunnerDetailResult.builder().containerId(exec.getId()).build())
.details(DockerTaskRunnerDetailResult.builder().containerId(runContainerId).build())
.build();
}
// register the runnable to be used for killing the container.
onKill(() -> kill(dockerClient, exec.getId(), logger));
onKill(() -> kill(dockerClient, runContainerId, logger));
AtomicBoolean ended = new AtomicBoolean(false);
try {
dockerClient.logContainerCmd(exec.getId())
dockerClient.logContainerCmd(runContainerId)
.withFollowStream(true)
.withStdErr(true)
.withStdOut(true)
@@ -527,14 +583,15 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
}
});
WaitContainerResultCallback result = dockerClient.waitContainerCmd(exec.getId()).start();
WaitContainerResultCallback result = dockerClient.waitContainerCmd(runContainerId).start();
Integer exitCode = result.awaitStatusCode();
Await.until(ended::get);
if (exitCode != 0) {
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
downloadOutputFiles(exec.getId(), dockerClient, runContext, taskCommands);
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
// On failure, still attempt to download outputs if VOLUME strategy is used
downloadOutputFiles(runContainerId, dockerClient, runContext, taskCommands);
}
throw new TaskException(exitCode, defaultLogConsumer);
@@ -542,14 +599,14 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
logger.debug("Command succeed with exit code {}", exitCode);
}
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
downloadOutputFiles(exec.getId(), dockerClient, runContext, taskCommands);
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
downloadOutputFiles(runContainerId, dockerClient, runContext, taskCommands);
}
return TaskRunnerResult.<DockerTaskRunnerDetailResult>builder()
.exitCode(exitCode)
.logConsumer(defaultLogConsumer)
.details(DockerTaskRunnerDetailResult.builder().containerId(exec.getId()).build())
.details(DockerTaskRunnerDetailResult.builder().containerId(runContainerId).build())
.build();
} finally {
try {
@@ -558,12 +615,12 @@ public class Docker extends TaskRunner<Docker.DockerTaskRunnerDetailResult> {
kill();
if (Boolean.TRUE.equals(renderedDelete)) {
dockerClient.removeContainerCmd(exec.getId()).exec();
dockerClient.removeContainerCmd(runContainerId).exec();
if (logger.isTraceEnabled()) {
logger.trace("Container deleted: {}", exec.getId());
logger.trace("Container deleted: {}", runContainerId);
}
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
if (needVolume && FileHandlingStrategy.VOLUME.equals(strategy) && filesVolumeName != null) {
dockerClient.removeVolumeCmd(filesVolumeName).exec();
if (logger.isTraceEnabled()) {

View File

@@ -1,21 +1,44 @@
package io.kestra.plugin.scripts.runner.docker;
import com.github.dockerjava.api.model.Container;
import io.kestra.core.models.executions.LogEntry;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.runners.AbstractTaskRunnerTest;
import io.kestra.core.models.tasks.runners.TaskCommands;
import io.kestra.core.models.tasks.runners.ScriptService;
import io.kestra.core.models.tasks.runners.TaskRunner;
import io.kestra.core.queues.QueueFactoryInterface;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.runners.RunContext;
import io.kestra.core.utils.Await;
import io.kestra.core.utils.IdUtils;
import io.kestra.core.utils.TestsUtils;
import io.kestra.plugin.scripts.exec.scripts.runners.CommandsWrapper;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import org.assertj.core.api.Assertions;
import org.hamcrest.MatcherAssert;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.mockito.Mockito;
import java.util.Collections;
import java.util.List;
import static io.kestra.core.utils.Rethrow.throwRunnable;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
class DockerTest extends AbstractTaskRunnerTest {
@Inject
@Named(QueueFactoryInterface.WORKERTASKLOG_NAMED)
QueueInterface<LogEntry> workerTaskLogQueue;
@Override
protected TaskRunner<?> taskRunner() {
return Docker.builder().image("rockylinux:9.3-minimal").build();
@@ -66,4 +89,123 @@ class DockerTest extends AbstractTaskRunnerTest {
MatcherAssert.assertThat((String) result.getLogConsumer().getOutputs().get("cpuLimit"), containsString("150000"));
assertThat(result.getLogConsumer().getStdOutCount()).isEqualTo(1);
}
}
/**
 * Reflectively invokes the {@code onKill(Runnable)} method declared on {@link TaskRunner}
 * against the given runner, after making that method accessible.
 *
 * @param taskRunner the runner instance the method is invoked on
 * @param runnable the callback handed to {@code onKill}
 * @throws Exception if the method cannot be found, made accessible, or invoked
 */
public static void callOnKill(TaskRunner<?> taskRunner, Runnable runnable) throws Exception {
    final Method onKillHook = TaskRunner.class.getDeclaredMethod("onKill", Runnable.class);
    onKillHook.setAccessible(true);
    // Explicit argument array; equivalent to the varargs call with a single Runnable.
    onKillHook.invoke(taskRunner, new Object[] {runnable});
}
/**
 * Checks that a second run sharing the first run's task run id reattaches to the container
 * created by the first run instead of creating a new one, and that no container with the
 * run's labels remains once the resumed run (configured with {@code delete = true}) has
 * been interrupted.
 */
@Test
void killAfterResume() throws Exception {
    var taskRunId = IdUtils.create();
    // Create a new RunContext with a specific taskRunId
    var runContext = runContext(this.runContextFactory, null, taskRunId);
    var commands = initScriptCommands(runContext);

    // Setup log queue consumer
    List<LogEntry> logs = new CopyOnWriteArrayList<>();
    Flux<LogEntry> receive = TestsUtils.receive(workerTaskLogQueue, (logEntry) -> {
        if (logEntry.isLeft()) {
            logs.add(logEntry.getLeft());
        }
    });

    var commandsList = ScriptService.scriptCommands(List.of("/bin/sh", "-c"), Collections.emptyList(),
        List.of("echo 'sleeping for 50 seconds' && sleep 50"));
    Mockito.when(commands.getCommands()).thenReturn(Property.ofValue(commandsList));

    // Keep the container after the first run so the second run has something to resume.
    var taskRunner = ((Docker) taskRunner())
        .toBuilder()
        .delete(Property.ofValue(false))
        .build();

    // Assert that the resume property is set to true by default
    Boolean resume = runContext.render(taskRunner.getResume()).as(Boolean.class).orElseThrow();
    assertThat(resume).isEqualTo(Boolean.TRUE);

    Thread initialContainerThread = new Thread(throwRunnable(() -> taskRunner.run(runContext, commands, Collections.emptyList())));
    initialContainerThread.start();

    try (var client = DockerService.client(runContext, null, null, null, "rockylinux:9.3-minimal")) {
        Map<String, String> labels = ScriptService.labels(runContext, "kestra.io/");
        var timeout = Duration.ofSeconds(30);

        // Wait for the container to be created
        Await.until(() -> {
            List<Container> existingContainers = client.listContainersCmd()
                .withShowAll(true)
                .withLabelFilter(labels)
                .exec();
            // Constant-first comparison so a container reported without a state cannot throw.
            return !existingContainers.isEmpty() && "running".equals(existingContainers.get(0).getState());
        }, Duration.ofMillis(100), timeout); // Add timeout to avoid waiting forever for container to be created

        callOnKill(taskRunner, () -> {
            // override the kill method to not kill the container
        });
        initialContainerThread.interrupt();
        initialContainerThread.join();

        // Create a new RunContext with the same taskRunId to maintain labels AND the same method to get a similar context
        RunContext anotherRunContext = runContext(this.runContextFactory, null, taskRunId);
        var anotherTaskRunner = ((Docker) taskRunner())
            .toBuilder()
            .delete(Property.ofValue(true)) // Delete the container after the second run
            .build();

        // Start resume in a new thread
        var resumeCommands = initScriptCommands(anotherRunContext);
        Mockito.when(resumeCommands.getCommands()).thenReturn(Property.ofValue(commandsList));
        Thread resumeContainerThread = new Thread(throwRunnable(() -> anotherTaskRunner.run(anotherRunContext, resumeCommands, Collections.emptyList())));
        resumeContainerThread.start();

        // Wait for the log message indicating resume
        LogEntry awaitLog = TestsUtils
            .awaitLog(logs, logEntry -> logEntry.getMessage().contains("Resuming existing container:"));
        LogEntry createContainerLog = TestsUtils
            .awaitLog(logs, logEntry -> logEntry.getMessage().contains("Container created:"));
        receive.blockLast(timeout);

        // Assert that the log messages are present.
        // withFailMessage must precede the assertion it describes, otherwise it has no effect.
        assertThat(createContainerLog)
            .withFailMessage("create container log should not be null")
            .isNotNull();
        assertThat(createContainerLog.getMessage()).contains("Container created:");
        assertThat(awaitLog)
            .withFailMessage("await log should not be null")
            .isNotNull();
        assertThat(awaitLog.getMessage()).contains("Resuming existing container:");

        // Get container id from the logs using regex
        String createContainerId = null;
        String resumeContainerId = null;

        Matcher createContainerMatcher =
            Pattern.compile("Container created: ([\\w]+)").matcher(createContainerLog.getMessage());
        if (createContainerMatcher.find()) {
            createContainerId = createContainerMatcher.group(1);
        }
        assertThat(createContainerId)
            .withFailMessage("Could not extract container id from create container log: %s", createContainerLog.getMessage())
            .isNotNull();

        Matcher resumeContainerMatcher =
            Pattern.compile("Resuming existing container: ([\\w]+)").matcher(awaitLog.getMessage());
        if (resumeContainerMatcher.find()) {
            resumeContainerId = resumeContainerMatcher.group(1);
        }
        // Report a failed extraction explicitly rather than as a confusing id mismatch below.
        assertThat(resumeContainerId)
            .withFailMessage("Could not extract container id from resume container log: %s", awaitLog.getMessage())
            .isNotNull();

        // Assert that the container id is the same
        assertThat(resumeContainerId).isEqualTo(createContainerId);

        // Kill the container and verify cleanup
        resumeContainerThread.interrupt();
        resumeContainerThread.join();

        List<Container> existingContainers = client.listContainersCmd()
            .withShowAll(true)
            .withLabelFilter(labels)
            .exec();
        MatcherAssert.assertThat(existingContainers.isEmpty(), is(true));
    }
}
}

View File

@@ -211,6 +211,10 @@ public abstract class AbstractTaskRunnerTest {
}
/**
 * Convenience overload that builds a run context for a task run whose id is freshly
 * generated with {@code IdUtils.create()}.
 *
 * @param runContextFactory factory passed through to the three-argument overload
 * @param additionalVars extra variables passed through to the three-argument overload
 * @return the run context produced by the three-argument overload
 */
protected RunContext runContext(RunContextFactory runContextFactory, Map<String, Object> additionalVars) {
    final String generatedTaskRunId = IdUtils.create();
    return this.runContext(runContextFactory, additionalVars, generatedTaskRunId);
}
protected RunContext runContext(RunContextFactory runContextFactory, Map<String, Object> additionalVars, String taskRunId) {
// create a fake flow and execution
Task task = new Task() {
@Override
@@ -223,7 +227,7 @@ public abstract class AbstractTaskRunnerTest {
return "Task";
}
};
TaskRun taskRun = TaskRun.builder().id(IdUtils.create()).taskId("task").flowId("flow").namespace("namespace").executionId("execution")
TaskRun taskRun = TaskRun.builder().id(taskRunId).taskId("task").flowId("flow").namespace("namespace").executionId("execution")
.state(new State().withState(State.Type.RUNNING))
.build();
Flow flow = Flow.builder().id("flow").namespace("namespace").revision(1)