Compare commits

...

1 Commits

Author: Florian Hussonnois · SHA1: 8ea8d86516 · Message: WIP · Date: 2025-08-22 14:56:24 +02:00
58 changed files with 2954 additions and 11 deletions

View File

@@ -34,6 +34,7 @@ dependencies {
implementation project(":storage-local")
implementation project(":webserver")
implementation project(":worker")
// test dependencies
testImplementation "org.wiremock:wiremock-jetty12"

View File

@@ -19,6 +19,7 @@ import picocli.CommandLine;
WebServerCommand.class,
WorkerCommand.class,
LocalCommand.class,
WorkerAgentCommand.class,
}
)
@Slf4j

View File

@@ -10,6 +10,7 @@ import io.kestra.core.runners.StandAloneRunner;
import io.kestra.core.services.SkipExecutionService;
import io.kestra.core.services.StartExecutorService;
import io.kestra.core.utils.Await;
import io.kestra.controller.Controller;
import io.micronaut.context.ApplicationContext;
import jakarta.annotation.Nullable;
import jakarta.inject.Inject;
@@ -110,6 +111,8 @@ public class StandAloneCommand extends AbstractServerCommand {
}
StandAloneRunner standAloneRunner = applicationContext.getBean(StandAloneRunner.class);
Controller controller = applicationContext.getBean(Controller.class);
if (this.workerThread == 0) {
standAloneRunner.setWorkerEnabled(false);

View File

@@ -0,0 +1,59 @@
package io.kestra.cli.commands.servers;
import com.google.common.collect.ImmutableMap;
import io.kestra.core.contexts.KestraContext;
import io.kestra.core.models.ServerType;
import io.kestra.core.utils.Await;
import io.kestra.worker.Worker;
import io.micronaut.context.ApplicationContext;
import jakarta.inject.Inject;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import java.util.Map;
@CommandLine.Command(
name = "worker-agent",
description = "Start the Kestra worker"
)
public class WorkerAgentCommand extends AbstractServerCommand {
@Inject
private ApplicationContext applicationContext;
@Option(names = {"-t", "--thread"}, description = "The max number of worker threads, defaults to four times the number of available processors")
private int thread = defaultWorkerThread();
@Option(names = {"-g", "--worker-group"}, description = "The worker group key, must match the regex [a-zA-Z0-9_-]+ (EE only)")
private String workerGroupKey = null;
@SuppressWarnings("unused")
public static Map<String, Object> propertiesOverrides() {
return ImmutableMap.of(
"kestra.server-type", ServerType.WORKER_AGENT
);
}
@Override
public Integer call() throws Exception {
if (this.workerGroupKey != null && !this.workerGroupKey.matches("[a-zA-Z0-9_-]+")) {
throw new IllegalArgumentException("The --worker-group option must match the [a-zA-Z0-9_-]+ pattern");
}
KestraContext.getContext().injectWorkerConfigs(thread, workerGroupKey);
super.call();
Worker worker = applicationContext.getBean(Worker.class);
worker.start(thread, workerGroupKey);
Await.until(() -> !this.applicationContext.isRunning());
return 0;
}
public String workerGroupKey() {
return workerGroupKey;
}
}
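
For reference, the options declared above can be exercised directly with picocli. A minimal, hypothetical sketch (the real command is launched through the Kestra CLI, which also wires the Micronaut ApplicationContext before call() runs; here we only parse, so no context is needed):

import io.kestra.cli.commands.servers.WorkerAgentCommand;
import picocli.CommandLine;
import picocli.CommandLine.ParseResult;

// Illustration only: parses the -t/--thread and -g/--worker-group options without executing the command.
public class WorkerAgentCommandOptionsExample {
    public static void main(String[] args) {
        CommandLine cmd = new CommandLine(new WorkerAgentCommand());
        ParseResult parsed = cmd.parseArgs("--thread", "8", "--worker-group", "gpu-pool");
        System.out.println("thread option set: " + parsed.hasMatchedOption('t'));        // --thread
        System.out.println("worker-group option set: " + parsed.hasMatchedOption('g'));  // --worker-group
    }
}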

View File

@@ -50,7 +50,6 @@ micronaut:
caches:
default:
maximum-weight: 10485760
http:
client:
read-idle-timeout: 60s
@@ -93,7 +92,13 @@ jackson:
deserialization:
FAIL_ON_UNKNOWN_PROPERTIES: false
endpoints:
# Disable Micronaut GRPC
grpc:
server:
enabled: false
health:
enabled: false
endpoints:
all:
port: 8081
enabled: true

View File

@@ -15,6 +15,8 @@ import jakarta.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import java.util.function.Supplier;
@Singleton
@Slf4j
public class MetricRegistry {
@@ -180,11 +182,26 @@ public class MetricRegistry {
* statement.
*/
public <T extends Number> T gauge(String name, String description, T number, String... tags) {
Gauge.builder(metricName(name), () -> number)
registerGauge(name, description, () -> number, tags);
return number;
}
/**
* Register a gauge that reports the value supplied by the given {@link Supplier}.
*
* @param name Name of the gauge being registered.
* @param description The metric description
* @param supplier A supplier that yields the current value of the gauge.
* @param tags Sequence of dimensions for breaking down the name.
* @param <T> The type of the number supplied as the gauge value.
* @return The registered {@link Gauge}.
*/
public <T extends Number> Gauge registerGauge(String name, String description, Supplier<T> supplier, String... tags) {
return Gauge.builder(metricName(name), supplier)
.description(description)
.tags(tags)
.register(this.meterRegistry);
return number;
}
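
A small usage sketch for the new Supplier-based overload; the metric name, description, and tags below are illustrative, not taken from this change:

import io.kestra.core.metrics.MetricRegistry;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative only: assumes an injected MetricRegistry instance.
class RegisterGaugeExample {
    void register(MetricRegistry metricRegistry, AtomicInteger runningJobCount) {
        // The gauge re-reads the supplier on every scrape, so the counter can be
        // updated from worker threads without re-registering anything.
        metricRegistry.registerGauge(
            "worker.job.running.count",                // hypothetical metric name
            "Number of jobs currently being executed", // description
            runningJobCount::get,                      // Supplier<Integer>
            "worker_group", "default");                // tag key/value pairs
    }
}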
/**

View File

@@ -1,10 +1,11 @@
package io.kestra.core.models;
public enum ServerType {
public enum ServerType {
EXECUTOR,
INDEXER,
SCHEDULER,
STANDALONE,
WEBSERVER,
WORKER,
WORKER_AGENT,
}

View File

@@ -9,4 +9,5 @@ import java.util.function.Consumer;
public interface WorkerJobQueueInterface extends QueueInterface<WorkerJob> {
Runnable subscribe(String workerId, String workerGroup, Consumer<Either<WorkerJob, DeserializationException>> consumer);
}

View File

@@ -6,6 +6,7 @@ import io.kestra.core.utils.Exceptions;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import lombok.Getter;
import lombok.Setter;
import lombok.Synchronized;
import org.slf4j.Logger;
@@ -32,6 +33,7 @@ public abstract class AbstractWorkerCallable implements Callable<State.Type> {
String uid;
@Getter
@Setter
Throwable exception;
private final CountDownLatch shutdownLatch = new CountDownLatch(1);
@@ -80,7 +82,7 @@ public abstract class AbstractWorkerCallable implements Callable<State.Type> {
*
* @see WorkerJobLifecycle#stop()
*/
protected abstract void signalStop();
public abstract void signalStop();
/**
* Wait for this worker task to complete stopping.

View File

@@ -17,4 +17,6 @@ public interface WorkerJobRunningStateStore {
* @param key the key of the worker job to be deleted.
*/
void deleteByKey(String key);
void put(WorkerJobRunning workerJobRunning);
}

View File

@@ -24,7 +24,7 @@ public class WorkerTaskCallable extends AbstractWorkerCallable {
@Getter
Output taskOutput;
WorkerTaskCallable(WorkerTask workerTask, RunnableTask<?> task, RunContext runContext, MetricRegistry metricRegistry) {
public WorkerTaskCallable(WorkerTask workerTask, RunnableTask<?> task, RunContext runContext, MetricRegistry metricRegistry) {
super(runContext, task.getClass().getName(), workerTask.uid(), task.getClass().getClassLoader());
this.workerTask = workerTask;
this.task = task;

View File

@@ -15,7 +15,7 @@ public class WorkerTriggerCallable extends AbstractWorkerTriggerCallable {
@Getter
Optional<Execution> evaluate;
WorkerTriggerCallable(RunContext runContext, WorkerTrigger workerTrigger, PollingTriggerInterface pollingTrigger) {
public WorkerTriggerCallable(RunContext runContext, WorkerTrigger workerTrigger, PollingTriggerInterface pollingTrigger) {
super(runContext, pollingTrigger.getClass().getName(), workerTrigger);
this.pollingTrigger = pollingTrigger;
}

View File

@@ -16,7 +16,7 @@ public class WorkerTriggerRealtimeCallable extends AbstractWorkerTriggerCallable
Consumer<? super Throwable> onError;
Consumer<Execution> onNext;
WorkerTriggerRealtimeCallable(
public WorkerTriggerRealtimeCallable(
RunContext runContext,
WorkerTrigger workerTrigger,
RealtimeTriggerInterface realtimeTrigger,

View File

@@ -18,7 +18,7 @@ import java.util.stream.Collectors;
/**
* Runtime information about a Kestra service (e.g., WORKER, EXECUTOR, etc.).
*
* @param uid The service unique identifier.
* @param uid The service unique identifier.
* @param type The service type.
* @param state The state of the service.
* @param server The server running this service.

View File

@@ -12,6 +12,8 @@ public enum ServiceType {
SCHEDULER,
WEBSERVER,
WORKER,
WORKER_AGENT,
CONTROLLER,
INVALID;
@JsonCreator

View File

@@ -1,6 +1,7 @@
package io.kestra.runner.postgres;
import io.kestra.core.exceptions.DeserializationException;
import io.kestra.core.queues.WorkerJobQueueInterface;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.queues.WorkerJobQueueInterface;
import io.kestra.core.utils.Either;
@@ -17,7 +18,7 @@ import java.util.function.Consumer;
@Slf4j
public class PostgresWorkerJobQueue extends PostgresQueue<WorkerJob> implements WorkerJobQueueInterface {
private final JdbcWorkerJobQueueService jdbcWorkerJobQueueService;
public PostgresWorkerJobQueue(ApplicationContext applicationContext) {
super(WorkerJob.class, applicationContext);
this.jdbcWorkerJobQueueService = applicationContext.getBean(JdbcWorkerJobQueueService.class);

View File

@@ -37,6 +37,11 @@ public abstract class AbstractJdbcWorkerJobRunningRepository extends AbstractJdb
.execute()
);
}
@Override
public void put(WorkerJobRunning workerJobRunning) {
this.jdbcRepository.persist(workerJobRunning);
}
@Override
public Optional<WorkerJobRunning> findByKey(String uid) {

View File

@@ -23,5 +23,6 @@ include 'model'
include 'processor'
include 'script'
include 'e2e-tests'
include 'worker'
include 'jmh-benchmarks'

worker/build.gradle Normal file
View File

@@ -0,0 +1,42 @@
plugins {
id 'com.google.protobuf' version '0.9.4'
}
ext {
grpcVersion = '1.71.0'
protobufVersion = '3.25.1'
}
configurations {
tests
implementation.extendsFrom(micronaut)
}
dependencies {
// Kestra
implementation project(':core')
annotationProcessor project(':processor')
// gRPC
implementation("io.micronaut.grpc:micronaut-grpc-server-runtime")
implementation("io.micronaut.grpc:micronaut-grpc-client-runtime")
}
protobuf {
protoc {
artifact = "com.google.protobuf:protoc:$protobufVersion"
}
plugins {
grpc {
artifact = "io.grpc:protoc-gen-grpc-java:$grpcVersion"
}
}
generateProtoTasks {
all()*.plugins {
grpc {
// avoid issues with javax packages
option '@generated=omit'
}
}
}
}

View File

@@ -0,0 +1,71 @@
package io.kestra.controller;
import io.grpc.Grpc;
import io.grpc.InsecureServerCredentials;
import io.grpc.Server;
import io.kestra.controller.grpc.server.LivenessControllerService;
import io.kestra.controller.grpc.server.WorkerControllerService;
import io.kestra.core.server.Service;
import io.kestra.core.server.ServiceStateChangeEvent;
import io.kestra.core.server.ServiceType;
import io.kestra.server.AbstractService;
import io.micronaut.context.event.ApplicationEventPublisher;
import jakarta.annotation.PostConstruct;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
@Singleton
public class Controller extends AbstractService implements Service {
private static final Logger log = LoggerFactory.getLogger(Controller.class);
private Server server;
private final WorkerControllerService workerControllerService;
private final LivenessControllerService livenessControllerService;
@Inject
public Controller(
WorkerControllerService workerControllerService,
LivenessControllerService livenessControllerService,
ApplicationEventPublisher<ServiceStateChangeEvent> eventPublisher) {
super(ServiceType.CONTROLLER, eventPublisher);
this.workerControllerService = workerControllerService;
this.livenessControllerService = livenessControllerService;
}
@PostConstruct
public void start() throws IOException {
if (getState() != ServiceState.CREATED) {
throw new IllegalStateException("Controller is already started or stopped");
}
log.info("Starting Controller");
/* The port on which the server should run */
int port = 9096; // TODO externalize this configuration
server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create())
.addService(workerControllerService)
.addService(livenessControllerService)
.build()
.start();
log.info("Controller started, listening on {}", port);
setState(ServiceState.RUNNING);
}
@Override
protected ServiceState doStop() throws InterruptedException {
if (server != null && !server.isTerminated()) {
shutdownServerAndWait();
}
return ServiceState.TERMINATED_GRACEFULLY;
}
private void shutdownServerAndWait() throws InterruptedException {
server.shutdown().awaitTermination(30, TimeUnit.SECONDS);
}
}

View File

@@ -0,0 +1,60 @@
package io.kestra.controller;
import io.kestra.controller.grpc.HeartbeatRequest;
import io.kestra.controller.grpc.HeartbeatResponse;
import io.kestra.controller.grpc.LivenessControllerServiceGrpc.LivenessControllerServiceBlockingStub;
import io.kestra.controller.messages.HeartbeatMessage;
import io.kestra.controller.messages.HeartbeatMessageReply;
import io.kestra.core.contexts.KestraContext;
import io.kestra.core.server.Service;
import io.kestra.core.server.ServiceInstance;
import io.kestra.core.server.ServiceLivenessUpdater;
import io.kestra.core.server.ServiceStateTransition;
import io.kestra.server.grpc.RequestOrResponseHeader;
import io.kestra.server.internals.MessageFormat;
import io.kestra.server.internals.MessageFormats;
import io.micronaut.context.annotation.Secondary;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import java.util.Objects;
import java.util.UUID;
public class GrpcServiceLivenessUpdater implements ServiceLivenessUpdater {
private final LivenessControllerServiceBlockingStub client;
public GrpcServiceLivenessUpdater(final LivenessControllerServiceBlockingStub client) {
this.client = Objects.requireNonNull(client, "client must not be null.");
}
/** {@inheritDoc} **/
@Override
public void update(ServiceInstance service) {
update(service, null, null);
}
/** {@inheritDoc} **/
@Override
public ServiceStateTransition.Response update(ServiceInstance instance, Service.ServiceState newState, String reason) {
HeartbeatResponse response = client.heartbeat(HeartbeatRequest
.newBuilder()
.setHeader(RequestOrResponseHeader
.newBuilder()
.setClientId(instance.uid())
.setClientVersion(KestraContext.getContext().getVersion())
.setMessageFormat(MessageFormats.JSON.name())
.setCorrelationId(UUID.randomUUID().toString())
.build()
)
.setMessage(MessageFormats.JSON.toByteString(new HeartbeatMessage(instance, newState, reason)))
.build()
);
HeartbeatMessageReply messageReply = MessageFormat
.resolve(response.getHeader().getMessageFormat())
.fromByteString(response.getMessage(), HeartbeatMessageReply.class);
return new ServiceStateTransition.Response(messageReply.result(), messageReply.instance());
}
}

View File

@@ -0,0 +1,18 @@
package io.kestra.controller.grpc.client;
import io.kestra.controller.grpc.LivenessControllerServiceGrpc;
import io.kestra.server.GrpcChannelProvider;
import io.micronaut.context.annotation.Bean;
import io.micronaut.context.annotation.Factory;
import jakarta.inject.Singleton;
@Factory
public class GrpcClientBeanFactory {
@Bean
@Singleton
public LivenessControllerServiceGrpc.LivenessControllerServiceBlockingStub workerServiceStub(GrpcChannelProvider grpcChannelProvider) {
return LivenessControllerServiceGrpc.newBlockingStub(grpcChannelProvider.createOrGetDefault());
}
}

View File

@@ -0,0 +1,54 @@
package io.kestra.controller.grpc.server;
import io.grpc.stub.StreamObserver;
import io.kestra.controller.grpc.HeartbeatRequest;
import io.kestra.controller.grpc.HeartbeatResponse;
import io.kestra.controller.grpc.LivenessControllerServiceGrpc;
import io.kestra.controller.messages.HeartbeatMessage;
import io.kestra.controller.messages.HeartbeatMessageReply;
import io.kestra.core.server.ServiceLivenessUpdater;
import io.kestra.core.server.ServiceStateTransition;
import io.kestra.server.internals.MessageFormat;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
@Singleton
public class LivenessControllerService extends LivenessControllerServiceGrpc.LivenessControllerServiceImplBase {
private final ServiceLivenessUpdater serviceLivenessUpdater;
@Inject
public LivenessControllerService(ServiceLivenessUpdater serviceLivenessUpdater) {
this.serviceLivenessUpdater = serviceLivenessUpdater;
}
/**
* {@inheritDoc}
*/
@Override
public void heartbeat(HeartbeatRequest request, StreamObserver<HeartbeatResponse> responseObserver) {
final MessageFormat messageFormat = MessageFormat.resolve(request.getHeader().getMessageFormat());
HeartbeatMessage message = messageFormat
.fromByteString(request.getMessage(), HeartbeatMessage.class);
ServiceStateTransition.Response response;
if (message.newState() != null) {
response = serviceLivenessUpdater.update(message.instance(), message.newState(), message.reason());
} else {
serviceLivenessUpdater.update(message.instance());
response = new ServiceStateTransition.Response(ServiceStateTransition.Result.SUCCEEDED, message.instance());
}
responseObserver.onNext(HeartbeatResponse
.newBuilder()
.setHeader(request.getHeader())
.setMessage(messageFormat.toByteString(new HeartbeatMessageReply(
response.instance(),
response.result()
)))
.build()
);
responseObserver.onCompleted();
}
}

View File

@@ -0,0 +1,130 @@
package io.kestra.controller.grpc.server;
import com.fasterxml.jackson.core.type.TypeReference;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import io.kestra.core.queues.QueueException;
import io.kestra.core.queues.QueueFactoryInterface;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.queues.WorkerJobQueueInterface;
import io.kestra.core.runners.Worker;
import io.kestra.core.runners.WorkerInstance;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.runners.WorkerJobRunningStateStore;
import io.kestra.core.runners.WorkerTask;
import io.kestra.core.runners.WorkerTaskResult;
import io.kestra.core.runners.WorkerTaskRunning;
import io.kestra.core.runners.WorkerTrigger;
import io.kestra.core.runners.WorkerTriggerRunning;
import io.kestra.server.internals.BatchMessage;
import io.kestra.server.internals.MessageFormat;
import io.kestra.worker.grpc.FetchWorkerJobRequest;
import io.kestra.worker.grpc.FetchWorkerJobResponse;
import io.kestra.worker.grpc.WorkerControllerServiceGrpc;
import io.kestra.worker.grpc.WorkerJobResultsRequest;
import io.kestra.worker.grpc.WorkerJobResultsResponse;
import io.kestra.worker.messages.FetchWorkerJobMessage;
import io.kestra.worker.messages.WorkerJobBatchMessage;
import jakarta.annotation.PreDestroy;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import jakarta.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
@Singleton
@Slf4j
public class WorkerControllerService extends WorkerControllerServiceGrpc.WorkerControllerServiceImplBase {
public static final TypeReference<BatchMessage<WorkerTaskResult>> WORKER_TASK_RESULT_BATCH_MESSAGE_TYPE_REFERENCE = new TypeReference<>() {
};
@Inject
@Named(QueueFactoryInterface.WORKERJOB_NAMED)
private WorkerJobQueueInterface workerJobQueue;
@Inject
@Named(QueueFactoryInterface.WORKERTASKRESULT_NAMED)
private QueueInterface<WorkerTaskResult> workerTaskResultQueue;
@Inject
private WorkerJobRunningStateStore workerJobRunningStateStore;
private final ConcurrentHashMap<String, Runnable> disposables = new ConcurrentHashMap<>();
@Override
public void fetchWorkerJobsStream(FetchWorkerJobRequest request, StreamObserver<FetchWorkerJobResponse> responseObserver) {
final MessageFormat messageFormat = MessageFormat.resolve(request.getHeader().getMessageFormat());
FetchWorkerJobMessage message = messageFormat.fromByteString(request.getMessage(), FetchWorkerJobMessage.class);
ServerCallStreamObserver<FetchWorkerJobResponse> serverObserver = (ServerCallStreamObserver<FetchWorkerJobResponse>) responseObserver;
log.info("Received worker-job request from worker [{}]", message.workerId());
serverObserver.setOnCancelHandler(() -> {
log.info("Worker [{}] disconnected or cancelled", message.workerId());
Optional.ofNullable(disposables.remove(message.workerId())).ifPresent(Runnable::run);
});
// TODO
// Currently, the consumer thread is managed directly by the WorkerJobQueue.
// It might be preferable for the WorkerControllerService to start its own polling thread
// for consuming the workerJobQueue (e.g., via a poll method) so it can be managed more cleanly on cancel.
Runnable stopReceiving = this.workerJobQueue.receive(message.workerGroup(), Worker.class, either -> {
if (either.isRight()) {
log.error("Unable to deserialize a worker job: {}", either.getRight().getMessage());
return;
}
WorkerJob job = either.getLeft();
log.info("Sending job [{}] to worker [{}]", job.uid(), message.workerId()); // TODO change to debug
serverObserver.onNext(FetchWorkerJobResponse
.newBuilder()
.setHeader(request.getHeader())
.setMessage(messageFormat.toByteString(new WorkerJobBatchMessage(List.of(job))))
.build()
);
WorkerInstance workerInstance = new WorkerInstance(message.workerId(), message.workerGroup());
if (job instanceof WorkerTask workerTask) {
workerJobRunningStateStore.put(WorkerTaskRunning.of(workerTask, workerInstance, -1));
} else if (job instanceof WorkerTrigger workerTrigger) {
workerJobRunningStateStore.put(WorkerTriggerRunning.of(workerTrigger, workerInstance, -1));
} else {
log.error("Message is of type [{}] which should never occurs", job);
}
}, false);
disposables.put(message.workerId(), () -> {
stopReceiving.run();
serverObserver.onCompleted();
});
}
@Override
public void sendWorkerJobResults(WorkerJobResultsRequest request, StreamObserver<WorkerJobResultsResponse> responseObserver) {
final MessageFormat messageFormat = MessageFormat.resolve(request.getHeader().getMessageFormat());
BatchMessage<WorkerTaskResult> message = messageFormat.fromByteString(request.getMessage(), WORKER_TASK_RESULT_BATCH_MESSAGE_TYPE_REFERENCE);
message.records().forEach(workerTaskResult -> {
try {
workerTaskResultQueue.emit(workerTaskResult);
} catch (QueueException e) {
throw new RuntimeException(e);
}
});
responseObserver.onNext(WorkerJobResultsResponse
.newBuilder()
.setHeader(request.getHeader())
.build()
);
responseObserver.onCompleted();
}
@PreDestroy
public void close() {
this.disposables.values().forEach(Runnable::run);
}
}

View File

@@ -0,0 +1,14 @@
package io.kestra.controller.messages;
import io.kestra.core.server.Service;
import io.kestra.core.server.ServiceInstance;
/**
* Message for {@link io.kestra.controller.grpc.HeartbeatRequest}.
*/
public record HeartbeatMessage(
ServiceInstance instance,
Service.ServiceState newState,
String reason
) {
}

View File

@@ -0,0 +1,14 @@
package io.kestra.controller.messages;
import io.kestra.core.server.ServiceInstance;
import io.kestra.core.server.ServiceStateTransition;
/**
* Message for {@link io.kestra.controller.grpc.HeartbeatResponse}.
*/
public record HeartbeatMessageReply(
ServiceInstance instance,
ServiceStateTransition.Result result
) {
}

View File

@@ -0,0 +1,87 @@
package io.kestra.server;
import io.kestra.core.server.Service;
import io.kestra.core.server.ServiceStateChangeEvent;
import io.kestra.core.server.ServiceType;
import io.kestra.core.utils.IdUtils;
import io.micronaut.context.event.ApplicationEventPublisher;
import jakarta.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
// TODO add it in Kestra 1.x
@Slf4j
public class AbstractService implements Service {
private final String id;
private final ServiceType serviceType;
private final ApplicationEventPublisher<ServiceStateChangeEvent> eventPublisher;
private final AtomicReference<ServiceState> state = new AtomicReference<>();
private final AtomicBoolean stopped = new AtomicBoolean(false);
public AbstractService(ServiceType serviceType, ApplicationEventPublisher<ServiceStateChangeEvent> eventPublisher) {
this.id = IdUtils.create();
this.serviceType = serviceType;
this.eventPublisher = eventPublisher;
setState(ServiceState.CREATED);
}
protected void setState(final ServiceState state) {
this.state.set(state);
this.eventPublisher.publishEvent(new ServiceStateChangeEvent(this, getProperties()));
}
@Override
public String getId() {
return id;
}
@Override
public ServiceType getType() {
return serviceType;
}
@Override
public ServiceState getState() {
return state.get();
}
protected Map<String, Object> getProperties() {
return Map.of();
}
/**
* {@inheritDoc}
*/
@Override
public void close() {
stop();
}
@PreDestroy
public void stop() {
if (stopped.compareAndSet(false, true)) {
setState(ServiceState.TERMINATING);
log.info("Terminating");
try {
ServiceState serviceState = doStop();
setState(serviceState);
} catch (Exception e) {
log.debug("Error while stopping service [{}]", this.getClass().getSimpleName(), e);
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
setState(ServiceState.TERMINATED_FORCED);
}
log.info("Service [{}] stopped {}", this.getClass().getSimpleName(), getState());
}
}
protected ServiceState doStop() throws Exception {
return ServiceState.TERMINATED_GRACEFULLY;
}
}
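
The lifecycle above (CREATED published from the constructor, a state-change event on every transition, doStop() overridden by subclasses) can be summarized with a minimal, hypothetical subclass; it is not part of this change:

import io.kestra.core.server.ServiceStateChangeEvent;
import io.kestra.core.server.ServiceType;
import io.kestra.server.AbstractService;
import io.micronaut.context.event.ApplicationEventPublisher;

// Hypothetical subclass for illustration.
class NoopService extends AbstractService {
    NoopService(ApplicationEventPublisher<ServiceStateChangeEvent> publisher) {
        super(ServiceType.WORKER_AGENT, publisher); // constructor publishes CREATED
    }

    void start() {
        setState(ServiceState.RUNNING); // publishes RUNNING
    }

    @Override
    protected ServiceState doStop() {
        // Release resources here; stop() wraps this call with TERMINATING and
        // falls back to TERMINATED_FORCED if an exception is thrown.
        return ServiceState.TERMINATED_GRACEFULLY;
    }
}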

View File

@@ -0,0 +1,78 @@
package io.kestra.server;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.kestra.core.contexts.KestraContext;
import jakarta.annotation.PreDestroy;
import jakarta.inject.Singleton;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@Singleton
@Slf4j
public class GrpcChannelProvider {
private volatile ManagedChannel defaultChannel;
private volatile ExecutorService defaultExecutorService;
private final AtomicBoolean stopped = new AtomicBoolean(false);
/**
* Return a shared gRPC Channel.
* <p>
* This method will create the channel if necessary.
*
* @return the {@link Channel}
*/
public Channel createOrGetDefault() {
// TODO externalize all config
if (this.defaultChannel == null) {
synchronized (this) {
if (this.defaultChannel == null) {
defaultExecutorService = Executors.newSingleThreadExecutor();
defaultChannel = ManagedChannelBuilder.forAddress("localhost", 9096)
.usePlaintext()
.enableRetry()
.maxRetryAttempts(10)
.userAgent(getUserAgent())
.keepAliveTime(1, TimeUnit.HOURS)
.keepAliveWithoutCalls(true)
.executor(defaultExecutorService)
.build();
}
}
}
return defaultChannel;
}
@PreDestroy
public void close() {
if (!stopped.compareAndSet(false, true)) {
return; // Method called twice
}
if (this.defaultChannel != null && !this.defaultChannel.isShutdown()) {
try {
shutdownServerAndWait();
} catch (Exception e) {
log.debug("Error while stopping default gRPC channel", e);
if (e instanceof InterruptedException)
Thread.currentThread().interrupt();
}
this.defaultExecutorService.shutdownNow();
}
}
private void shutdownServerAndWait() throws InterruptedException {
this.defaultChannel.shutdown().awaitTermination(30, TimeUnit.SECONDS);
}
private static String getUserAgent() {
return "Kestra/" + KestraContext.getContext().getVersion();
}
}

View File

@@ -0,0 +1,24 @@
package io.kestra.server.internals;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
/**
* A generic batch message.
*
* @param records the records of the batch.
* @param <T> the record type.
*/
public record BatchMessage<T>(
List<T> records
) {
@SafeVarargs
public static <T> BatchMessage<T> of(T... records) {
return new BatchMessage<>(Arrays.asList(records));
}
public List<T> records() {
return Optional.ofNullable(records).orElse(List.of());
}
}

View File

@@ -0,0 +1,27 @@
package io.kestra.server.internals;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.protobuf.ByteString;
import io.kestra.core.utils.Enums;
/**
* Represents a specific message format.
* <p>
* Each gRPC message contains a generic byte array `message` field.
*/
public interface MessageFormat {
<T> T fromByteString(ByteString data, Class<T> type);
<T> T fromByteString(ByteString data, TypeReference<T> type);
ByteString toByteString(Object value);
static MessageFormat resolve(final String format) {
try {
return Enums.getForNameIgnoreCase(format, MessageFormats.class);
} catch (IllegalArgumentException e) {
return MessageFormats.JSON; // default
}
}
}

View File

@@ -0,0 +1,64 @@
package io.kestra.server.internals;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.ByteString;
import io.kestra.core.serializers.JacksonMapper;
import java.io.IOException;
import java.util.Optional;
/**
* Supported formats for serialized messages in Protocol Buffer message.
*/
public enum MessageFormats implements MessageFormat {
JSON() {
private static final ObjectMapper OBJECT_MAPPER = JacksonMapper.ofJson(false);
/** {@inheritDoc} **/
@Override
public <T> T fromByteString(ByteString data, Class<T> type) {
byte[] bytes = toByteArray(data);
if (bytes == null || bytes.length == 0) {
return null;
}
try {
return OBJECT_MAPPER.readValue(bytes, type);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public <T> T fromByteString(ByteString data, TypeReference<T> type) {
byte[] bytes = toByteArray(data);
if (bytes == null || bytes.length == 0) {
return null;
}
try {
return OBJECT_MAPPER.readValue(bytes, type);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/** {@inheritDoc} **/
@Override
public ByteString toByteString(Object value) {
if (value == null) {
return ByteString.EMPTY;
}
try {
return ByteString.copyFrom(OBJECT_MAPPER.writeValueAsBytes(value));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
};
private static byte[] toByteArray(ByteString data) {
return Optional.ofNullable(data).map(ByteString::toByteArray).orElse(null);
}
}
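
To make the envelope format concrete, here is a small round-trip sketch using the types above; the payload is illustrative. In the gRPC services the resulting ByteString travels inside the generic `message` field, next to a header that names the format:

import com.fasterxml.jackson.core.type.TypeReference;
import com.google.protobuf.ByteString;
import io.kestra.server.internals.BatchMessage;
import io.kestra.server.internals.MessageFormat;
import io.kestra.server.internals.MessageFormats;

// Illustrative round trip through the JSON message format.
class MessageFormatExample {
    public static void main(String[] args) {
        BatchMessage<String> batch = BatchMessage.of("job-1", "job-2");

        // Serialize with the JSON format...
        ByteString payload = MessageFormats.JSON.toByteString(batch);

        // ...and deserialize on the other side, resolving the format from the header value.
        BatchMessage<String> decoded = MessageFormat
            .resolve("JSON")
            .fromByteString(payload, new TypeReference<BatchMessage<String>>() {});

        System.out.println(decoded.records()); // [job-1, job-2]
    }
}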

View File

@@ -0,0 +1,33 @@
package io.kestra.worker;
import io.kestra.controller.GrpcServiceLivenessUpdater;
import io.kestra.controller.grpc.LivenessControllerServiceGrpc;
import io.kestra.core.server.ServiceLivenessUpdater;
import io.kestra.server.GrpcChannelProvider;
import io.kestra.worker.grpc.WorkerControllerServiceGrpc;
import io.micronaut.context.annotation.Bean;
import io.micronaut.context.annotation.Factory;
import io.micronaut.context.annotation.Primary;
import jakarta.inject.Singleton;
@Factory
public class BeanFactory {
@Singleton
@Primary
public ServiceLivenessUpdater serviceLivenessUpdater(LivenessControllerServiceGrpc.LivenessControllerServiceBlockingStub client) {
return new GrpcServiceLivenessUpdater(client);
}
@Bean
@Singleton
public WorkerControllerServiceGrpc.WorkerControllerServiceBlockingStub blockingWorkerServiceStub(GrpcChannelProvider grpcChannelProvider) {
return WorkerControllerServiceGrpc.newBlockingStub(grpcChannelProvider.createOrGetDefault());
}
@Bean
@Singleton
public WorkerControllerServiceGrpc.WorkerControllerServiceStub asyncWorkerServiceStub(GrpcChannelProvider grpcChannelProvider) {
return WorkerControllerServiceGrpc.newStub(grpcChannelProvider.createOrGetDefault());
}
}

View File

@@ -0,0 +1,210 @@
package io.kestra.worker;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.server.Metric;
import io.kestra.core.server.ServerConfig;
import io.kestra.core.server.Service;
import io.kestra.core.server.ServiceStateChangeEvent;
import io.kestra.core.server.ServiceType;
import io.kestra.core.services.WorkerGroupService;
import io.kestra.core.utils.Await;
import io.kestra.server.AbstractService;
import io.micronaut.context.annotation.Prototype;
import io.micronaut.context.event.ApplicationEventPublisher;
import jakarta.inject.Inject;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static io.kestra.core.server.Service.ServiceState.TERMINATED_FORCED;
import static io.kestra.core.server.Service.ServiceState.TERMINATED_GRACEFULLY;
@SuppressWarnings("this-escape")
@Slf4j
@Prototype
public class Worker extends AbstractService implements Service {
private static final String SERVICE_PROPS_WORKER_GROUP = "worker.group";
@Inject
private MetricRegistry metricRegistry;
@Inject
private ServerConfig serverConfig;
@Getter
private final Map<Long, AtomicInteger> metricRunningCount = new ConcurrentHashMap<>();
private final AtomicBoolean skipGracefulTermination = new AtomicBoolean(false);
private final WorkerGroupService workerGroupService;
private final AtomicBoolean initialized = new AtomicBoolean(false);
private final AtomicInteger pendingJobCount = new AtomicInteger(0);
private final AtomicInteger runningJobCount = new AtomicInteger(0);
private final WorkerJobExecutor workerJobExecutor;
private final WorkerJobFetcher workerJobFetcher;
private final WorkerTaskResultSender workerTaskResultSender;
private String workerGroup;
/**
* Creates a new {@link Worker} instance.
*/
@Inject
public Worker(
ApplicationEventPublisher<ServiceStateChangeEvent> eventPublisher,
WorkerGroupService workerGroupService,
WorkerJobExecutor workerJobExecutor,
WorkerJobFetcher workerJobFetcher,
WorkerTaskResultSender workerTaskResultSender
) {
super(ServiceType.WORKER_AGENT, eventPublisher);
this.workerGroupService = workerGroupService;
this.workerJobExecutor = workerJobExecutor;
this.workerJobFetcher = workerJobFetcher;
this.workerTaskResultSender = workerTaskResultSender;
this.setState(ServiceState.CREATED);
}
@Override
public Set<Metric> getMetrics() {
if (this.metricRegistry == null) {
// can happen if called before the instance is fully initialized
return Collections.emptySet();
}
Stream<String> metrics = Stream.of(
MetricRegistry.METRIC_WORKER_JOB_THREAD_COUNT,
MetricRegistry.METRIC_WORKER_JOB_PENDING_COUNT,
MetricRegistry.METRIC_WORKER_JOB_RUNNING_COUNT
);
return metrics
.flatMap(metric -> Optional.ofNullable(metricRegistry.findGauge(metric)).stream())
.map(Metric::of)
.collect(Collectors.toSet());
}
public void start(int numThreads, String workerGroupKey) {
if (!this.initialized.compareAndSet(false, true)) {
throw new IllegalStateException("Worker already started");
}
this.workerGroup = workerGroupService.resolveGroupFromKey(workerGroupKey);
String[] tags = workerGroup == null ? new String[0] : new String[]{MetricRegistry.TAG_WORKER_GROUP, workerGroup};
// create metrics to store thread count, pending jobs and running jobs, so we can have autoscaling easily
this.metricRegistry.gauge(MetricRegistry.METRIC_WORKER_JOB_THREAD_COUNT, MetricRegistry.METRIC_WORKER_JOB_THREAD_COUNT_DESCRIPTION, numThreads, tags);
this.metricRegistry.gauge(MetricRegistry.METRIC_WORKER_JOB_PENDING_COUNT, MetricRegistry.METRIC_WORKER_JOB_PENDING_COUNT_DESCRIPTION, pendingJobCount, tags);
this.metricRegistry.gauge(MetricRegistry.METRIC_WORKER_JOB_RUNNING_COUNT, MetricRegistry.METRIC_WORKER_JOB_RUNNING_COUNT_DESCRIPTION, runningJobCount, tags);
workerTaskResultSender.start(getId(), workerGroup);
workerJobFetcher.start(getId(), workerGroup);
workerJobExecutor.start(getId(), workerGroup, numThreads);
if (workerGroupKey != null) {
log.info("Worker started with {} thread(s) in group '{}'", numThreads, workerGroupKey);
} else {
log.info("Worker started with {} thread(s)", numThreads);
}
setState(ServiceState.RUNNING);
}
/**
* {@inheritDoc}
**/
@Override
protected Map<String, Object> getProperties() {
Map<String, Object> properties = new HashMap<>();
properties.put(SERVICE_PROPS_WORKER_GROUP, workerGroup);
return properties;
}
@Override
protected ServiceState doStop() {
this.workerJobFetcher.stop();
this.workerJobExecutor.pause();
this.workerTaskResultSender.stop();
final boolean terminatedGracefully;
if (!skipGracefulTermination.get()) {
terminatedGracefully = waitForTasksCompletion(serverConfig.terminationGracePeriod());
} else {
log.info("Terminating now and skip waiting for tasks completions.");
this.workerJobExecutor.shutdownNow();
terminatedGracefully = false;
}
return terminatedGracefully ? TERMINATED_GRACEFULLY : TERMINATED_FORCED;
}
private boolean waitForTasksCompletion(final Duration timeout) {
final Instant deadline = Instant.now().plus(timeout);
AtomicReference<ServiceState> shutdownState = new AtomicReference<>();
// start shutdown
Thread.ofVirtual().name("worker-shutdown").start(
() -> {
try {
long remaining = Math.max(0, Instant.now().until(deadline, ChronoUnit.MILLIS));
boolean gracefullyShutdown = this.workerJobExecutor.shutdown(Duration.ofMillis(remaining));
shutdownState.set(gracefullyShutdown ? TERMINATED_GRACEFULLY : TERMINATED_FORCED);
} catch (InterruptedException e) {
log.error("Failed to shutdown. Thread was interrupted");
shutdownState.set(TERMINATED_FORCED);
}
}
);
// wait for task completion
Await.until(
() -> {
ServiceState serviceState = shutdownState.get();
if (serviceState == TERMINATED_FORCED || serviceState == TERMINATED_GRACEFULLY) {
log.info("All working threads are terminated.");
return true;
}
long runningJobs = this.workerJobExecutor.getRunningJobCount();
if (runningJobs == 0) {
log.debug("All worker threads is terminated.");
} else {
log.warn("Waiting for all worker threads to terminate (remaining: {}).", runningJobs);
}
return false;
},
Duration.ofSeconds(1)
);
return shutdownState.get() == TERMINATED_GRACEFULLY;
}
/**
* Specify whether to skip graceful termination on shutdown.
*
* @param skipGracefulTermination {@code true} to skip graceful termination on shutdown.
*/
@Override
public void skipGracefulTermination(final boolean skipGracefulTermination) {
this.skipGracefulTermination.set(skipGracefulTermination);
}
}

View File

@@ -0,0 +1,95 @@
package io.kestra.worker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A WorkerIO thread is responsible for processing data flowing to and from the worker.
* <p>
* A WorkerIO thread mostly performs network operations.
*/
public abstract class WorkerIOThread implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(WorkerIOThread.class);
private final String name;
protected String workerId;
protected String workerGroup;
private volatile Thread thread;
private final AtomicBoolean running = new AtomicBoolean(false);
private final CountDownLatch stopped = new CountDownLatch(1);
public WorkerIOThread(final String name) {
this.name = Objects.requireNonNull(name, "name must not be null");
}
public synchronized void start(final String workerId, final String workerGroup) {
if (!running.compareAndSet(false, true)) {
throw new IllegalStateException("[%s] already started".formatted(getClass().getSimpleName()));
}
this.workerId = workerId;
this.workerGroup = workerGroup;
// TODO could probably be replaced with a virtual-thread executor
this.thread = new Thread(this, "worker-" + this.name + "-" + workerId);
this.thread.setDaemon(false);
this.thread.start();
LOG.info("[{}] started with workerId={} group={}", getClass().getSimpleName(), workerId, workerGroup);
}
@Override
public void run() {
try {
while (running.get()) {
try {
doOnLoop();
} catch (InterruptedException ie) {
LOG.info("[{}] interrupted, stopping", getClass().getSimpleName());
Thread.currentThread().interrupt();
break; // exit loop
} catch (Exception e) {
LOG.error("Error in IO worker loop", e);
}
}
} finally {
stopped.countDown();
LOG.info("[{}] stopped", getClass().getSimpleName());
}
}
protected abstract void doOnLoop() throws Exception;
protected void doOnStop() {
//noop
}
public synchronized void stop() {
if (!running.compareAndSet(true, false)) {
LOG.debug("[{}] stop() called but not running", getClass().getSimpleName());
return;
}
if (thread != null) {
try {
doOnStop();
} catch (Exception e) {
LOG.error("Error in IO worker loop", e);
}
thread.interrupt();
try {
if (!stopped.await(1, TimeUnit.MINUTES)) {
LOG.warn("Timeout while waiting for {} to complete", thread.getName());
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
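
As a reading aid, a minimal, hypothetical subclass showing the contract (doOnLoop is called repeatedly until stop() interrupts the thread); the real implementations are WorkerJobFetcher and WorkerTaskResultSender below:

// Hypothetical subclass for illustration.
class HeartbeatIOThread extends WorkerIOThread {

    HeartbeatIOThread() {
        super("heartbeat");
    }

    @Override
    protected void doOnLoop() throws Exception {
        // One iteration of the IO loop; regular exceptions are logged and the loop
        // continues, while InterruptedException ends the loop.
        System.out.printf("heartbeat from worker %s (group=%s)%n", workerId, workerGroup);
        Thread.sleep(1_000);
    }

    @Override
    protected void doOnStop() {
        // Close sockets/streams here so a blocked doOnLoop() can return promptly.
    }
}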

View File

@@ -0,0 +1,272 @@
package io.kestra.worker;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.utils.ExecutorsUtils;
import io.kestra.worker.processors.WorkerJobProcessor;
import io.kestra.worker.processors.WorkerJobProcessorFactory;
import io.kestra.worker.queues.WorkerJobQueue;
import io.kestra.worker.queues.WorkerQueueFactory;
import io.micronaut.context.annotation.Prototype;
import jakarta.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
/**
* Components responsible for executing {@link io.kestra.core.runners.WorkerJob}
*/
@Prototype
@Slf4j
public class WorkerJobExecutor {
private static final String EXECUTOR_NAME = "worker";
private final WorkerQueueFactory workerQueueFactory;
private final WorkerJobProcessorFactory workerJobProcessorFactory;
private final ExecutorsUtils executorsUtils;
private ExecutorService executorService;
private List<WorkerJobConsumer> workerJobConsumers;
private final AtomicBoolean started = new AtomicBoolean(false);
@Inject
public WorkerJobExecutor(final WorkerQueueFactory workerQueueFactory,
final ExecutorsUtils executorsUtils,
final WorkerJobProcessorFactory workerJobProcessorFactory) {
this.workerJobProcessorFactory = workerJobProcessorFactory;
this.workerQueueFactory = workerQueueFactory;
this.executorsUtils = executorsUtils;
}
public void start(final String workerId,
final String workerGroup,
int threads) {
WorkerJobQueue workerJobQueue = new WorkerJobQueue.Default(workerQueueFactory.getOrCreate(workerId, WorkerJob.class));
if (this.started.compareAndSet(false, true)) {
this.executorService = executorsUtils.maxCachedThreadPool(threads, EXECUTOR_NAME);
this.workerJobConsumers = new ArrayList<>(threads);
for (int i = 0; i < threads; i++) {
WorkerJobConsumer consumer = new WorkerJobConsumer(
workerJobQueue,
workerJobProcessorFactory,
workerId,
workerGroup
);
this.workerJobConsumers.add(consumer);
executorService.submit(consumer);
}
} else {
throw new IllegalStateException("already started");
}
}
/**
* Returns the number of jobs currently being processed.
*
* @return the number of jobs being processed
*/
public long getRunningJobCount() {
return workerJobConsumers.stream()
.filter(WorkerJobConsumer::isProcessing)
.count();
}
/**
* Notify all underlying WorkerJob consumers to pause.
*/
public void pause() {
workerJobConsumers.forEach(WorkerJobConsumer::pause);
}
/**
* Notify all underlying WorkerJob consumers to resume.
*/
public void resume() {
checkIsStarted();
workerJobConsumers.forEach(WorkerJobConsumer::resume);
}
private void checkIsStarted() {
if (!this.started.get()) {
throw new IllegalStateException("WorkerJobExecutor not started");
}
}
/**
* Immediately initiates shutdown of all consumers and halts the processing of waiting jobs.
* <p>
* This is a convenience method that calls {@link #shutdown(Duration)} with {@code Duration.ZERO}
* and ignores any {@link InterruptedException} by resetting the interrupt flag.
*/
public void shutdownNow() {
try {
shutdown(Duration.ZERO);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
/**
* Initiates a graceful shutdown by notifying all consumers to stop and waiting for termination.
* <p>
* If the specified {@code terminationGracePeriod} is {@code null} or {@code Duration.ZERO},
* the executor will skip graceful shutdown and immediately attempt to forcefully stop all
* running tasks.
*
* @param terminationGracePeriod the maximum duration to wait for graceful shutdown
* @return {@code true} if the executor terminated within the timeout; {@code false} if forced shutdown was required
* @throws InterruptedException if the current thread is interrupted while waiting
*/
public boolean shutdown(final Duration terminationGracePeriod) throws InterruptedException {
if (!this.started.compareAndSet(true, false)) {
return true; // Already shut down or not started.
}
// Initiate graceful shutdown
this.executorService.shutdown();
// Notify all WorkerJobConsumers to stop
this.workerJobConsumers.forEach(WorkerJobConsumer::stop);
if (terminationGracePeriod == null || terminationGracePeriod.equals(Duration.ZERO)) {
this.executorService.shutdownNow();
return false;
}
// Wait for all WorkerJobConsumers to terminate
boolean terminated = this.executorService.awaitTermination(
terminationGracePeriod.toMillis(), TimeUnit.MILLISECONDS);
if (!terminated) {
log.warn("Worker still has pending jobs after the termination grace period. Forcing shutdown.");
this.executorService.shutdownNow();
}
return terminated;
}
private static class WorkerJobConsumer implements Runnable {
private final AtomicBoolean stopped = new AtomicBoolean(false);
private final AtomicBoolean paused = new AtomicBoolean(false);
private final ReentrantLock pauseLock = new ReentrantLock();
private final Condition unpaused = pauseLock.newCondition();
private final AtomicReference<WorkerJobProcessor<WorkerJob>> running = new AtomicReference<>(null);
private final WorkerJobQueue workerJobQueue;
private final WorkerJobProcessorFactory workerJobProcessorFactory;
private final String workerId;
private final String workerGroup;
public WorkerJobConsumer(WorkerJobQueue workerJobQueue,
WorkerJobProcessorFactory workerJobProcessorFactory,
String workerId,
String workerGroup) {
this.workerJobQueue = workerJobQueue;
this.workerJobProcessorFactory = workerJobProcessorFactory;
this.workerId = workerId;
this.workerGroup = workerGroup;
}
/**
* Continuously polls for new {@link WorkerJob} and processes them sequentially.
* <p>
* It blocks while waiting for new jobs and ensures that only one job is processed
* at a time. This method will not return unless interrupted or explicitly stopped.
*/
@Override
public void run() {
try {
while (!stopped.get()) {
waitIfPaused();
WorkerJob job = workerJobQueue.poll(Duration.ofSeconds(1));
if (job == null || stopped.get()) {
continue;
}
try {
WorkerJobProcessor<WorkerJob> processor =
workerJobProcessorFactory.create(workerId, workerGroup, job);
running.set(processor);
processor.process(job);
} finally {
running.set(null);
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
/**
* Check whether a job is currently being processed
*
* @return {@code true} if a {@link WorkerJob} is actively being processed; {@code false} otherwise.
*/
public boolean isProcessing() {
return running.get() != null;
}
private void waitIfPaused() throws InterruptedException {
pauseLock.lock();
try {
while (paused.get() && !stopped.get()) {
unpaused.await(); // Wait until resume() signals
}
} finally {
pauseLock.unlock();
}
}
/**
* Pauses polling for new {@link WorkerJob} instances.
* <p>
* If a job is currently running, it will continue to completion.
* No new jobs will be polled until resumed.
*/
public void pause() {
paused.set(true);
}
/**
* Resumes polling and processing of {@link WorkerJob} instances if currently paused.
*/
public void resume() {
pauseLock.lock();
try {
if (paused.compareAndSet(true, false)) {
unpaused.signalAll();
}
} finally {
pauseLock.unlock();
}
}
/**
* Stops polling and processing of {@link WorkerJob} instances.
*/
public void stop() {
if (this.stopped.compareAndSet(false, true)) {
resume(); // In case it's paused and blocked
WorkerJobProcessor<WorkerJob> processor = running.get();
if (processor != null) {
processor.stop();
}
}
}
}
}
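
The pause/resume handshake used by WorkerJobConsumer (an AtomicBoolean guarded by a ReentrantLock/Condition pair) can be isolated into a small, self-contained sketch; the class and method names here are hypothetical:

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

// Stand-alone illustration of the gate used at the top of the consumer loop.
class PauseGate {
    private final AtomicBoolean paused = new AtomicBoolean(false);
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition unpaused = lock.newCondition();

    void pause() {
        paused.set(true);
    }

    void resume() {
        lock.lock();
        try {
            if (paused.compareAndSet(true, false)) {
                unpaused.signalAll();
            }
        } finally {
            lock.unlock();
        }
    }

    // Blocks the calling consumer while paused; resume() signals the condition.
    void awaitIfPaused() throws InterruptedException {
        lock.lock();
        try {
            while (paused.get()) {
                unpaused.await();
            }
        } finally {
            lock.unlock();
        }
    }
}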

View File

@@ -0,0 +1,122 @@
package io.kestra.worker;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import io.kestra.core.contexts.KestraContext;
import io.kestra.core.runners.WorkerJob;
import io.kestra.server.grpc.RequestOrResponseHeader;
import io.kestra.server.internals.MessageFormat;
import io.kestra.server.internals.MessageFormats;
import io.kestra.worker.grpc.FetchWorkerJobRequest;
import io.kestra.worker.grpc.FetchWorkerJobResponse;
import io.kestra.worker.grpc.WorkerControllerServiceGrpc.WorkerControllerServiceStub;
import io.kestra.worker.messages.FetchWorkerJobMessage;
import io.kestra.worker.messages.WorkerJobBatchMessage;
import io.kestra.worker.queues.WorkerJobQueue;
import io.kestra.worker.queues.WorkerQueueFactory;
import io.micronaut.context.annotation.Prototype;
import jakarta.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
/**
* Component responsible for fetching worker jobs.
*/
@Prototype
@Slf4j
public class WorkerJobFetcher extends WorkerIOThread {
private final WorkerControllerServiceStub workerControllerServiceStub;
private final WorkerQueueFactory workerQueueFactory;
private WorkerJobQueue workerJobQueue;
private final AtomicReference<ClientCallStreamObserver<FetchWorkerJobRequest>> currentStreamObserver = new AtomicReference<>();
@Inject
public WorkerJobFetcher(
final WorkerControllerServiceStub workerControllerServiceStub,
final WorkerQueueFactory workerQueueFactory) {
super(WorkerJobFetcher.class.getSimpleName());
this.workerQueueFactory = workerQueueFactory;
this.workerControllerServiceStub = workerControllerServiceStub;
}
/**
* {@inheritDoc}
**/
@Override
public synchronized void start(String workerId, String workerGroup) {
this.workerJobQueue = new WorkerJobQueue.Default(workerQueueFactory.getOrCreate(workerId, WorkerJob.class));
super.start(workerId, workerGroup);
}
/**
* {@inheritDoc}
**/
@Override
protected void doOnLoop() throws Exception {
FetchWorkerJobRequest request = FetchWorkerJobRequest.newBuilder()
.setHeader(newRequestOrResponseHeader())
.setMessage(MessageFormats.JSON.toByteString(new FetchWorkerJobMessage(workerId, workerGroup)))
.build();
CountDownLatch completed = new CountDownLatch(1);
// Start the streaming call
ClientResponseObserver<FetchWorkerJobRequest, FetchWorkerJobResponse> streamCompleted = new ClientResponseObserver<>() {
@Override
public void beforeStart(ClientCallStreamObserver<FetchWorkerJobRequest> requestStream) {
currentStreamObserver.set(requestStream);
}
@Override
public void onNext(FetchWorkerJobResponse response) {
log.info("Stream onNext: {}", response);
String messageFormat = response.getHeader().getMessageFormat();
WorkerJobBatchMessage workerJobBatch = MessageFormat
.resolve(messageFormat)
.fromByteString(response.getMessage(), WorkerJobBatchMessage.class);
if (workerJobBatch != null && !workerJobBatch.jobs().isEmpty()) {
workerJobBatch.jobs().forEach(workerJobQueue::put);
}
}
@Override
public void onError(Throwable t) {
log.error("Stream error: {}", t.getMessage(), t);
completed.countDown();
}
@Override
public void onCompleted() {
log.error("Stream completed");
completed.countDown();
}
};
workerControllerServiceStub.fetchWorkerJobsStream(request, streamCompleted);
completed.await(); // Block until stream ends
}
@Override
protected void doOnStop() {
ClientCallStreamObserver<FetchWorkerJobRequest> active = currentStreamObserver.getAndSet(null);
if (active != null) {
active.cancel("Worker stopping", null);
}
}
private RequestOrResponseHeader newRequestOrResponseHeader() {
return RequestOrResponseHeader
.newBuilder()
.setClientId(workerId)
.setClientVersion(KestraContext.getContext().getVersion())
.setMessageFormat(MessageFormats.JSON.name())
.setCorrelationId(UUID.randomUUID().toString())
.build();
}
}

View File

@@ -0,0 +1,95 @@
package io.kestra.worker;
import io.grpc.stub.StreamObserver;
import io.kestra.core.contexts.KestraContext;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.runners.WorkerTaskResult;
import io.kestra.server.GrpcChannelProvider;
import io.kestra.server.grpc.RequestOrResponseHeader;
import io.kestra.server.internals.BatchMessage;
import io.kestra.server.internals.MessageFormats;
import io.kestra.worker.grpc.WorkerControllerServiceGrpc;
import io.kestra.worker.grpc.WorkerJobResultsRequest;
import io.kestra.worker.grpc.WorkerJobResultsResponse;
import io.kestra.worker.queues.WorkerJobQueue;
import io.kestra.worker.queues.WorkerQueueFactory;
import io.kestra.worker.queues.WorkerTaskResultQueue;
import io.kestra.worker.queues.WorkerTriggerResultQueue;
import io.micronaut.context.annotation.Prototype;
import jakarta.inject.Inject;
import lombok.extern.slf4j.Slf4j;
import java.time.Duration;
import java.util.UUID;
/**
* Component responsible for sending worker task results back to the controller.
*/
@Prototype
@Slf4j
public class WorkerTaskResultSender extends WorkerIOThread {
private final WorkerControllerServiceGrpc.WorkerControllerServiceStub controllerServiceStub;
private final WorkerQueueFactory workerQueueFactory;
private WorkerTaskResultQueue queue;
@Inject
public WorkerTaskResultSender(
final WorkerControllerServiceGrpc.WorkerControllerServiceStub controllerServiceStub,
final WorkerQueueFactory workerQueueFactory) {
super(WorkerTaskResultSender.class.getSimpleName());
this.workerQueueFactory = workerQueueFactory;
this.controllerServiceStub = controllerServiceStub;
}
/**
* {@inheritDoc}
**/
@Override
public synchronized void start(String workerId, String workerGroup) {
this.queue = new WorkerTaskResultQueue.Default(workerQueueFactory.getOrCreate(workerId, WorkerTaskResult.class));
super.start(workerId, workerGroup);
}
/**
* {@inheritDoc}
*/
@Override
protected void doOnLoop() throws Exception {
WorkerTaskResult result = queue.poll(Duration.ofMillis(Long.MAX_VALUE));
if (result == null) return;
WorkerJobResultsRequest request = WorkerJobResultsRequest
.newBuilder()
.setHeader(newRequestOrResponseHeader())
.setMessage(MessageFormats.JSON.toByteString(BatchMessage.of(result)))
.build();
controllerServiceStub.sendWorkerJobResults(request, new StreamObserver<>() {
@Override
public void onNext(WorkerJobResultsResponse value) {
log.info("onNext {}", value);
}
@Override
public void onError(Throwable t) {
log.error("Error while sending worker job results", t);
}
@Override
public void onCompleted() {
log.info("onCompleted");
}
}
);
}
private RequestOrResponseHeader newRequestOrResponseHeader() {
return RequestOrResponseHeader
.newBuilder()
.setClientId(workerId)
.setClientVersion(KestraContext.getContext().getVersion())
.setMessageFormat(MessageFormats.JSON.name())
.setCorrelationId(UUID.randomUUID().toString())
.build();
}
}

View File

@@ -0,0 +1,7 @@
package io.kestra.worker.messages;
public record FetchWorkerJobMessage(
String workerId,
String workerGroup
) {
}

View File

@@ -0,0 +1,15 @@
package io.kestra.worker.messages;
import io.kestra.core.runners.WorkerJob;
import java.util.List;
import java.util.Optional;
public record WorkerJobBatchMessage(
List<WorkerJob> jobs
) {
public List<WorkerJob> jobs() {
return Optional.ofNullable(jobs).orElse(List.of());
}
}

View File

@@ -0,0 +1,6 @@
@Configuration
@Requires(property = "kestra.server-type", pattern = "WORKER_AGENT")
package io.kestra.worker;
import io.micronaut.context.annotation.Configuration;
import io.micronaut.context.annotation.Requires;

View File

@@ -0,0 +1,89 @@
package io.kestra.worker.processors;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.flows.State;
import io.kestra.core.runners.AbstractWorkerCallable;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.runners.WorkerSecurityService;
import io.kestra.core.services.LogService;
import io.kestra.core.trace.TraceUtils;
import io.kestra.core.trace.Tracer;
import io.opentelemetry.api.common.Attributes;
import java.util.ConcurrentModificationException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
public abstract class AbstractWorkerJobProcessor<T extends WorkerJob> implements WorkerJobProcessor<T> {
protected final String workerGroup;
protected final MetricRegistry metricRegistry;
protected final LogService logService;
private final WorkerSecurityService workerSecurityService;
private final Tracer tracer;
private final AtomicReference<WorkerJob> currentWorkerJob = new AtomicReference<>();
private final AtomicReference<AbstractWorkerCallable> currentWorkerCallable = new AtomicReference<>();
private final AtomicBoolean stopped = new AtomicBoolean(false);
public AbstractWorkerJobProcessor(String workerGroup,
LogService logService,
MetricRegistry metricRegistry,
WorkerSecurityService workerSecurityService,
Tracer tracer) {
this.workerGroup = workerGroup;
this.tracer = tracer;
this.metricRegistry = metricRegistry;
this.logService = logService;
this.workerSecurityService = workerSecurityService;
}
@Override
public void process(final T job) {
if (currentWorkerJob.compareAndSet(null, job)) {
try {
doProcess(job);
} finally {
currentWorkerJob.set(null);
}
} else {
// avoid misuse of this class
throw new ConcurrentModificationException("Processor can only process one job at a time.");
}
}
protected abstract void doProcess(final T job);
protected io.kestra.core.models.flows.State.Type callJob(AbstractWorkerCallable workerJobCallable) {
this.currentWorkerCallable.set(workerJobCallable);
try {
return tracer.inCurrentContext(
workerJobCallable.getRunContext(),
workerJobCallable.getType(),
Attributes.of(TraceUtils.ATTR_UID, workerJobCallable.getUid()),
() -> workerSecurityService.callInSecurityContext(workerJobCallable)
);
} catch (Exception e) {
// should only occur if the tracing code itself fails, which is unexpected
// we attach the exception so that it is at least logged in that case
workerJobCallable.setException(e);
return State.Type.FAILED;
} finally {
this.currentWorkerCallable.set(null);
}
}
@Override
public void stop() {
if (this.stopped.compareAndSet(false, true)) {
Optional.ofNullable(currentWorkerCallable.get()).ifPresent(AbstractWorkerCallable::signalStop);
}
}
protected boolean isStopped() {
return this.stopped.get();
}
}
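As a rough sketch of the intended extension contract (the subclass below is hypothetical and only for illustration, imports as in WorkerTaskProcessor below): doProcess builds a callable for the job and hands it to callJob, which runs it inside the tracer and the worker security context and returns the terminal state.

public class LoggingWorkerTaskProcessor extends AbstractWorkerJobProcessor<WorkerTask> {
    public LoggingWorkerTaskProcessor(String workerGroup, LogService logService, MetricRegistry metricRegistry,
                                      WorkerSecurityService workerSecurityService, Tracer tracer) {
        super(workerGroup, logService, metricRegistry, workerSecurityService, tracer);
    }

    @Override
    protected void doProcess(WorkerTask workerTask) {
        // assumes the task is runnable; WorkerTaskProcessor below handles the other cases
        RunnableTask<?> task = (RunnableTask<?>) workerTask.getTask();
        State.Type state = callJob(new WorkerTaskCallable(workerTask, task, workerTask.getRunContext(), metricRegistry));
        logService.logTaskRun(workerTask.getTaskRun(), Level.INFO, "Terminal state: {}", state);
    }
}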

View File

@@ -0,0 +1,30 @@
package io.kestra.worker.processors;
import io.kestra.core.runners.WorkerJob;
import io.micronaut.core.annotation.Blocking;
/**
* A processor responsible for executing a specific {@link WorkerJob}.
*
* @param <T> the type of {@link WorkerJob} to be processed
*/
public interface WorkerJobProcessor<T extends WorkerJob> {
/**
* Processes the given {@link WorkerJob}.
* <p>
* This method will block the calling thread until the job has been completed or terminated.
* Only one job may be processed at a time per {@code WorkerJobProcessor} instance.
*
* @param workerJob the {@link WorkerJob} to be executed
*/
@Blocking
void process(T workerJob);
/**
* Signals the currently running job to stop, if any.
* <p>
* If no job is currently running, the method returns immediately without any side effects.
*/
void stop();
}
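In practice a processor is obtained from the WorkerJobProcessorFactory introduced below and driven from a dedicated worker thread. A minimal sketch of the expected call pattern, assuming workerJobProcessorFactory, workerId, workerGroup and workerTask are in scope:

WorkerJobProcessor<WorkerTask> processor = workerJobProcessorFactory.create(workerId, workerGroup, workerTask);
Thread runner = new Thread(() -> processor.process(workerTask)); // blocks this thread until the job terminates
runner.start();
// later, e.g. on shutdown, from another thread:
processor.stop(); // signals the running callable to stop; a no-op if nothing is running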

View File

@@ -0,0 +1,92 @@
package io.kestra.worker.processors;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.executions.LogEntry;
import io.kestra.core.models.executions.MetricEntry;
import io.kestra.core.runners.RunContextInitializer;
import io.kestra.core.runners.RunContextLoggerFactory;
import io.kestra.core.runners.Worker;
import io.kestra.core.runners.WorkerJob;
import io.kestra.core.runners.WorkerSecurityService;
import io.kestra.core.runners.WorkerTask;
import io.kestra.core.runners.WorkerTaskResult;
import io.kestra.core.runners.WorkerTrigger;
import io.kestra.core.runners.WorkerTriggerResult;
import io.kestra.core.server.Metric;
import io.kestra.core.services.LogService;
import io.kestra.core.services.VariablesService;
import io.kestra.core.trace.Tracer;
import io.kestra.core.trace.TracerFactory;
import io.kestra.worker.queues.WorkerLogQueue;
import io.kestra.worker.queues.WorkerMetricQueue;
import io.kestra.worker.queues.WorkerQueueFactory;
import io.kestra.worker.queues.WorkerTaskResultQueue;
import io.kestra.worker.queues.WorkerTriggerResultQueue;
import jakarta.annotation.PostConstruct;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
@Singleton
public class WorkerJobProcessorFactory {
@Inject
private LogService logService;
@Inject
private MetricRegistry metricRegistry;
@Inject
private WorkerSecurityService workerSecurityService;
@Inject
private RunContextInitializer runContextInitializer;
@Inject
private RunContextLoggerFactory runContextLoggerFactory;
@Inject
private VariablesService variablesService;
// QUEUES
@Inject
private WorkerQueueFactory workerQueueFactory;
@Inject
private TracerFactory tracerFactory;
private Tracer tracer;
@PostConstruct
public void init() {
this.tracer = tracerFactory.getTracer(Worker.class, "WORKER");
}
@SuppressWarnings("unchecked")
public <T extends WorkerJob> WorkerJobProcessor<T> create(String workerId,
String workerGroup,
T job) {
if (job instanceof WorkerTask) {
return (WorkerJobProcessor<T>) new WorkerTaskProcessor(
workerId,
workerGroup,
logService,
metricRegistry,
workerSecurityService,
tracer,
variablesService,
runContextInitializer,
runContextLoggerFactory,
new WorkerTaskResultQueue.Default(workerQueueFactory.getOrCreate(workerId, WorkerTaskResult.class)),
new WorkerMetricQueue.Default(workerQueueFactory.getOrCreate(workerId, MetricEntry.class))
);
} else if (job instanceof WorkerTrigger) {
return (WorkerJobProcessor<T>) new WorkerTriggerProcessor(
workerGroup,
logService,
metricRegistry,
workerSecurityService,
tracer,
runContextInitializer,
new WorkerLogQueue.Default(workerQueueFactory.getOrCreate(workerId, LogEntry.class)),
new WorkerTriggerResultQueue.Default(workerQueueFactory.getOrCreate(workerId, WorkerTriggerResult.class))
);
}
throw new IllegalArgumentException("Unsupported worker job type [" + job.getClass().getName() + "]");
}
}

View File

@@ -0,0 +1,447 @@
package io.kestra.worker.processors;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableList;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.executions.MetricEntry;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.executions.TaskRunAttempt;
import io.kestra.core.models.executions.Variables;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.DefaultRunContext;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.RunContextInitializer;
import io.kestra.core.runners.RunContextLogger;
import io.kestra.core.runners.RunContextLoggerFactory;
import io.kestra.core.runners.WorkerSecurityService;
import io.kestra.core.runners.WorkerTask;
import io.kestra.core.runners.WorkerTaskCallable;
import io.kestra.core.runners.WorkerTaskResult;
import io.kestra.core.serializers.JacksonMapper;
import io.kestra.core.services.LogService;
import io.kestra.core.services.VariablesService;
import io.kestra.core.storages.StorageContext;
import io.kestra.core.trace.Tracer;
import io.kestra.core.utils.Hashing;
import io.kestra.core.utils.TruthUtils;
import io.kestra.plugin.core.flow.WorkingDirectory;
import io.kestra.worker.queues.WorkerMetricQueue;
import io.kestra.worker.queues.WorkerTaskResultQueue;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.event.Level;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import static io.kestra.core.models.flows.State.Type.CREATED;
import static io.kestra.core.models.flows.State.Type.RUNNING;
import static io.kestra.core.models.flows.State.Type.SKIPPED;
import static io.kestra.core.models.flows.State.Type.SUCCESS;
import static io.kestra.core.models.flows.State.Type.WARNING;
@Slf4j
public class WorkerTaskProcessor extends AbstractWorkerJobProcessor<WorkerTask> {
private final RunContextInitializer runContextInitializer;
private final RunContextLoggerFactory runContextLoggerFactory;
private final String workerId;
private final String workerGroup;
private final VariablesService variablesService;
private final Map<Long, AtomicInteger> metricRunningCount = new ConcurrentHashMap<>();
private final WorkerTaskResultQueue workerTaskResultQueue;
private final WorkerMetricQueue workerMetricQueue;
public WorkerTaskProcessor(final String workerId,
final String workerGroup,
final LogService logService,
final MetricRegistry metricRegistry,
final WorkerSecurityService workerSecurityService,
final Tracer tracer,
final VariablesService variablesService,
final RunContextInitializer runContextInitializer,
final RunContextLoggerFactory runContextLoggerFactory,
final WorkerTaskResultQueue workerTaskResultQueue,
final WorkerMetricQueue workerMetricQueue) {
super(workerGroup, logService, metricRegistry, workerSecurityService, tracer);
this.runContextInitializer = runContextInitializer;
this.runContextLoggerFactory = runContextLoggerFactory;
this.workerGroup = workerGroup;
this.workerId = workerId;
this.variablesService = variablesService;
this.workerTaskResultQueue = workerTaskResultQueue;
this.workerMetricQueue = workerMetricQueue;
}
@Override
protected void doProcess(final WorkerTask workerTask) {
Task task = workerTask.getTask();
if (task instanceof RunnableTask) {
runTask(workerTask, true);
} else if (task instanceof WorkingDirectory workingDirectory) {
runWorkingDirectory(workerTask, workingDirectory);
} else {
throw new IllegalArgumentException("Unable to process the task '" + task.getId() + "' as it's not a runnable task");
}
}
private void runWorkingDirectory(WorkerTask workerTask, WorkingDirectory workingDirectory) {
DefaultRunContext runContext = runContextInitializer.forWorkingDirectory(((DefaultRunContext) workerTask.getRunContext()), workerTask);
final RunContext workingDirectoryRunContext = runContext.clone();
try {
// preExecuteTasks
try {
workingDirectory.preExecuteTasks(workingDirectoryRunContext, workerTask.getTaskRun());
} catch (Exception e) {
workingDirectoryRunContext.logger().error("Failed preExecuteTasks on WorkingDirectory: {}", e.getMessage(), e);
workerTask = workerTask.withTaskRun(workerTask.fail());
workerTaskResultQueue.put(new WorkerTaskResult(workerTask.getTaskRun()));
return;
}
// execute all tasks
for (Task currentTask : workingDirectory.getTasks()) {
if (Boolean.TRUE.equals(currentTask.getDisabled())) {
continue;
}
WorkerTask currentWorkerTask = workingDirectory.workerTask(
workerTask.getTaskRun(),
currentTask,
runContextInitializer.forPlugin(runContext, currentTask)
);
// all tasks will be handled immediately by the worker
WorkerTaskResult workerTaskResult = null;
try {
if (!TruthUtils.isTruthy(runContext.render(currentWorkerTask.getTask().getRunIf()))) {
workerTaskResult = new WorkerTaskResult(currentWorkerTask.getTaskRun().withState(SKIPPED));
workerTaskResultQueue.put(workerTaskResult);
} else {
workerTaskResult = this.runTask(currentWorkerTask, false);
}
} catch (IllegalVariableEvaluationException e) {
RunContextLogger contextLogger = runContextLoggerFactory.create(currentWorkerTask);
contextLogger.logger().error("Failed evaluating runIf: {}", e.getMessage(), e);
workerTaskResultQueue.put(new WorkerTaskResult(workerTask.fail()));
}
if (workerTaskResult == null || workerTaskResult.getTaskRun().getState().isFailed() && !currentWorkerTask.getTask().isAllowFailure()) {
break;
}
// create the next RunContext populated with the previous WorkerTaskResult
runContext = runContextInitializer.forWorker(runContext.clone(), workerTaskResult, workerTask.getTaskRun());
}
// postExecuteTasks
try {
workingDirectory.postExecuteTasks(workingDirectoryRunContext, workerTask.getTaskRun());
} catch (Exception e) {
workingDirectoryRunContext.logger().error("Failed postExecuteTasks on WorkingDirectory: {}", e.getMessage(), e);
workerTaskResultQueue.put(new WorkerTaskResult(workerTask.fail()));
}
} finally {
this.logTerminated(workerTask);
runContext.cleanup();
}
}
private WorkerTaskResult runTask(WorkerTask workerTask, boolean cleanUp) {
String[] metricTags = metricRegistry.tags(workerTask, workerGroup);
this.metricRegistry
.counter(MetricRegistry.METRIC_WORKER_STARTED_COUNT, MetricRegistry.METRIC_WORKER_STARTED_COUNT_DESCRIPTION, metricTags)
.increment();
if (workerTask.getTaskRun().getState().getCurrent() == CREATED) {
this.metricRegistry
.timer(MetricRegistry.METRIC_WORKER_QUEUED_DURATION, MetricRegistry.METRIC_WORKER_QUEUED_DURATION_DESCRIPTION, metricTags)
.record(Duration.between(
workerTask.getTaskRun().getState().getStartDate(), Instant.now()
));
}
try {
// TODO
/**
if (!Boolean.TRUE.equals(workerTask.getTaskRun().getForceExecution()) && killedExecution.contains(workerTask.getTaskRun().getExecutionId())) {
WorkerTaskResult workerTaskResult = new WorkerTaskResult(workerTask.getTaskRun().withState(KILLED));
workerTaskResultQueue.produce(workerTaskResult);
// We cannot remove the execution ID from the killedExecution in case the worker is processing multiple tasks of the execution
// which can happen due to parallel processing.
return workerTaskResult;
}
**/
logService.logTaskRun(
workerTask.getTaskRun(),
Level.INFO,
"Type {} started",
workerTask.getTask().getClass().getSimpleName()
);
workerTask = workerTask.withTaskRun(workerTask.getTaskRun().withState(RUNNING));
DefaultRunContext runContext = runContextInitializer.forWorker((DefaultRunContext) workerTask.getRunContext(), workerTask);
Optional<String> hash = Optional.empty();
if (workerTask.getTask().getTaskCache() != null && workerTask.getTask().getTaskCache().getEnabled()) {
runContext.logger().debug("Task output caching is enabled for task '{}''", workerTask.getTask().getId());
hash = hashTask(runContext, workerTask.getTask());
if (hash.isPresent()) {
try {
Optional<InputStream> cacheFile = runContext.storage().getCacheFile(hash.get(), workerTask.getTaskRun().getValue(), workerTask.getTask().getTaskCache().getTtl());
if (cacheFile.isPresent()) {
runContext.logger().info("Skipping task execution for task '{}' as there is an existing cache entry for it", workerTask.getTask().getId());
try (ZipInputStream archive = new ZipInputStream(cacheFile.get())) {
if (archive.getNextEntry() != null) {
byte[] cache = archive.readAllBytes();
Map<String, Object> outputMap = JacksonMapper.ofIon().readValue(cache, JacksonMapper.MAP_TYPE_REFERENCE);
Variables variables = variablesService.of(StorageContext.forTask(workerTask.getTaskRun()), outputMap);
TaskRunAttempt attempt = TaskRunAttempt.builder()
.state(new io.kestra.core.models.flows.State().withState(SUCCESS))
.workerId(this.workerId)
.build();
List<TaskRunAttempt> attempts = this.addAttempt(workerTask, attempt);
TaskRun taskRun = workerTask.getTaskRun().withAttempts(attempts).withOutputs(variables).withState(SUCCESS);
WorkerTaskResult workerTaskResult = new WorkerTaskResult(taskRun);
workerTaskResultQueue.put(workerTaskResult);
return workerTaskResult;
}
}
}
} catch (IOException | RuntimeException e) {
// in case of any exception, log an error and continue
runContext.logger().error("Unexpected exception while loading the cache for task '{}', the task will be executed instead.", workerTask.getTask().getId(), e);
}
}
}
// run
workerTask = this.runAttempt(runContext, workerTask);
// get last state
TaskRunAttempt lastAttempt = workerTask.getTaskRun().lastAttempt();
if (lastAttempt == null) {
throw new IllegalStateException("Can find lastAttempt on taskRun '" +
workerTask.getTaskRun().toString(true) + "'"
);
}
io.kestra.core.models.flows.State.Type state = lastAttempt.getState().getCurrent();
if (workerTask.getTask().getRetry() != null &&
workerTask.getTask().getRetry().getWarningOnRetry() &&
workerTask.getTaskRun().attemptNumber() > 1 &&
state == SUCCESS
) {
state = WARNING;
}
if (workerTask.getTask().isAllowFailure() && !workerTask.getTaskRun().shouldBeRetried(workerTask.getTask().getRetry()) && state.isFailed()) {
state = WARNING;
}
if (workerTask.getTask().isAllowWarning() && WARNING.equals(state)) {
state = SUCCESS;
}
// emit
List<WorkerTaskResult> dynamicWorkerResults = workerTask.getRunContext().dynamicWorkerResults();
List<TaskRun> dynamicTaskRuns = dynamicWorkerResults(dynamicWorkerResults);
workerTask = workerTask.withTaskRun(workerTask.getTaskRun().withState(state));
WorkerTaskResult workerTaskResult = new WorkerTaskResult(workerTask.getTaskRun(), dynamicTaskRuns);
workerTaskResultQueue.put(workerTaskResult);
// upload the cache file, hash may not be present if we didn't succeed in computing it
if (workerTask.getTask().getTaskCache() != null && workerTask.getTask().getTaskCache().getEnabled() && hash.isPresent() &&
(state == State.Type.SUCCESS || state == State.Type.WARNING)) {
runContext.logger().info("Uploading a cache entry for task '{}'", workerTask.getTask().getId());
try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ZipOutputStream archive = new ZipOutputStream(bos)) {
var zipEntry = new ZipEntry("outputs.ion");
archive.putNextEntry(zipEntry);
archive.write(JacksonMapper.ofIon().writeValueAsBytes(workerTask.getTaskRun().getOutputs()));
archive.closeEntry();
archive.finish();
Path archiveFile = runContext.workingDir().createTempFile(".zip");
Files.write(archiveFile, bos.toByteArray());
URI uri = runContext.storage().putCacheFile(archiveFile.toFile(), hash.get(), workerTask.getTaskRun().getValue());
runContext.logger().debug("Caching entry uploaded in URI {}", uri);
} catch (IOException | RuntimeException e) {
// in case of any exception, log an error and continue
runContext.logger().error("Unexpected exception while uploading the cache entry for task '{}', the task not be cached.", workerTask.getTask().getId(), e);
}
}
return workerTaskResult;
} finally {
this.logTerminated(workerTask);
// remove tmp directory
if (cleanUp) {
workerTask.getRunContext().cleanup();
}
}
}
private void logTerminated(WorkerTask workerTask) {
final String[] tags = metricRegistry.tags(workerTask, workerGroup);
metricRegistry
.counter(MetricRegistry.METRIC_WORKER_ENDED_COUNT, MetricRegistry.METRIC_WORKER_ENDED_COUNT_DESCRIPTION, tags)
.increment();
metricRegistry
.timer(MetricRegistry.METRIC_WORKER_ENDED_DURATION, MetricRegistry.METRIC_WORKER_ENDED_DURATION_DESCRIPTION, tags)
.record(workerTask.getTaskRun().getState().getDuration());
logService.logTaskRun(
workerTask.getTaskRun(),
Level.INFO,
"Type {} with state {} completed in {}",
workerTask.getTask().getClass().getSimpleName(),
workerTask.getTaskRun().getState().getCurrent(),
workerTask.getTaskRun().getState().humanDuration()
);
}
private WorkerTask runAttempt(final RunContext runContext, final WorkerTask workerTask) {
Logger logger = runContext.logger();
if (!(workerTask.getTask() instanceof RunnableTask<?> task)) {
// This should never happen but better to deal with it than crashing the Worker
var state = State.Type.fail(workerTask.getTask());
TaskRunAttempt attempt = TaskRunAttempt.builder()
.state(new io.kestra.core.models.flows.State().withState(state))
.workerId(this.workerId)
.build();
List<TaskRunAttempt> attempts = this.addAttempt(workerTask, attempt);
TaskRun taskRun = workerTask.getTaskRun().withAttempts(attempts);
logger.error("Unable to execute the task '{}': only runnable tasks can be executed by the worker but the task is of type {}", workerTask.getTask().getId(), workerTask.getTask().getClass());
return workerTask.withTaskRun(taskRun);
}
TaskRunAttempt.TaskRunAttemptBuilder builder = TaskRunAttempt.builder()
.state(new io.kestra.core.models.flows.State().withState(RUNNING))
.workerId(this.workerId);
// emit the attempt so the execution knows that the task is in RUNNING
workerTaskResultQueue.put(new WorkerTaskResult(
workerTask.getTaskRun()
.withAttempts(this.addAttempt(workerTask, builder.build()))
)
);
AtomicInteger metricRunningCount = getMetricRunningCount(workerTask);
metricRunningCount.incrementAndGet();
// run it
WorkerTaskCallable workerTaskCallable = new WorkerTaskCallable(workerTask, task, runContext, metricRegistry);
io.kestra.core.models.flows.State.Type state = callJob(workerTaskCallable);
metricRunningCount.decrementAndGet();
// attempt
TaskRunAttempt taskRunAttempt = builder
.build()
.withState(state)
.withLogFile(runContext.logFileURI());
// metrics
runContext.metrics()
.stream()
.map(metric -> MetricEntry.of(workerTask.getTaskRun(), metric, workerTask.getExecutionKind()))
.forEach(workerMetricQueue::put);
// save outputs
List<TaskRunAttempt> attempts = this.addAttempt(workerTask, taskRunAttempt);
TaskRun taskRun = workerTask.getTaskRun()
.withAttempts(attempts);
try {
Variables variables = variablesService.of(StorageContext.forTask(taskRun), workerTaskCallable.getTaskOutput());
taskRun = taskRun.withOutputs(variables);
} catch (Exception e) {
logger.warn("Unable to save output on taskRun '{}'", taskRun, e);
}
return workerTask
.withTaskRun(taskRun);
}
private List<TaskRunAttempt> addAttempt(WorkerTask workerTask, TaskRunAttempt taskRunAttempt) {
return ImmutableList.<TaskRunAttempt>builder()
.addAll(workerTask.getTaskRun().getAttempts() == null ? new ArrayList<>() : workerTask.getTaskRun().getAttempts())
.add(taskRunAttempt)
.build();
}
private Optional<String> hashTask(RunContext runContext, Task task) {
try {
var map = JacksonMapper.toMap(task);
var rMap = runContext.render(map);
var json = JacksonMapper.ofJson().writeValueAsBytes(rMap);
MessageDigest digest = MessageDigest.getInstance("SHA-256");
digest.update(json);
byte[] bytes = digest.digest();
return Optional.of(HexFormat.of().formatHex(bytes));
} catch (RuntimeException | IllegalVariableEvaluationException | JsonProcessingException |
NoSuchAlgorithmException e) {
runContext.logger().error("Unable to create the cache key for the task '{}'", task.getId(), e);
return Optional.empty();
}
}
private List<TaskRun> dynamicWorkerResults(List<WorkerTaskResult> dynamicWorkerResults) {
return dynamicWorkerResults
.stream()
.map(WorkerTaskResult::getTaskRun)
.map(taskRun -> taskRun.withDynamic(true))
.toList();
}
public AtomicInteger getMetricRunningCount(final WorkerTask workerTask) {
String[] tags = this.metricRegistry.tags(workerTask, workerGroup);
Arrays.sort(tags);
long index = Hashing.hashToLong(String.join("-", tags));
return this.metricRunningCount
.computeIfAbsent(index, l -> metricRegistry.gauge(
MetricRegistry.METRIC_WORKER_RUNNING_COUNT,
MetricRegistry.METRIC_WORKER_RUNNING_COUNT_DESCRIPTION,
new AtomicInteger(0),
tags
));
}
}

View File

@@ -0,0 +1,268 @@
package io.kestra.worker.processors;
import com.google.common.base.Throwables;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.Label;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.LogEntry;
import io.kestra.core.models.tasks.Output;
import io.kestra.core.models.triggers.PollingTriggerInterface;
import io.kestra.core.models.triggers.RealtimeTriggerInterface;
import io.kestra.core.models.triggers.TriggerService;
import io.kestra.core.runners.DefaultRunContext;
import io.kestra.core.runners.RunContextInitializer;
import io.kestra.core.runners.RunContextLogger;
import io.kestra.core.runners.WorkerSecurityService;
import io.kestra.core.runners.WorkerTrigger;
import io.kestra.core.runners.WorkerTriggerCallable;
import io.kestra.core.runners.WorkerTriggerRealtimeCallable;
import io.kestra.core.runners.WorkerTriggerResult;
import io.kestra.core.services.LabelService;
import io.kestra.core.services.LogService;
import io.kestra.core.trace.Tracer;
import io.kestra.worker.queues.WorkerLogQueue;
import io.kestra.worker.queues.WorkerTriggerResultQueue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.Logger;
import org.slf4j.event.Level;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static io.kestra.core.models.flows.State.Type.FAILED;
import static io.kestra.core.models.flows.State.Type.SUCCESS;
@Slf4j
public class WorkerTriggerProcessor extends AbstractWorkerJobProcessor<WorkerTrigger> {
private final Map<String, AtomicInteger> evaluateTriggerRunningCount = new ConcurrentHashMap<>();
private final WorkerLogQueue workerLogQueue;
private final WorkerTriggerResultQueue workerTriggerResultQueue;
private final RunContextInitializer runContextInitializer;
public WorkerTriggerProcessor(String workerGroup,
LogService logService,
MetricRegistry metricRegistry,
WorkerSecurityService workerSecurityService,
Tracer tracer,
final RunContextInitializer runContextInitializer,
WorkerLogQueue workerLogQueue,
WorkerTriggerResultQueue workerTriggerResultQueue) {
super(workerGroup, logService, metricRegistry, workerSecurityService, tracer);
this.workerLogQueue = workerLogQueue;
this.workerTriggerResultQueue = workerTriggerResultQueue;
this.runContextInitializer = runContextInitializer;
}
@Override
protected void doProcess(WorkerTrigger workerTrigger) {
final String[] metricsTags = metricRegistry.tags(workerTrigger, workerGroup);
this.metricRegistry
.counter(MetricRegistry.METRIC_WORKER_TRIGGER_STARTED_COUNT, MetricRegistry.METRIC_WORKER_TRIGGER_STARTED_COUNT_DESCRIPTION, metricsTags)
.increment();
this.metricRegistry
.timer(MetricRegistry.METRIC_WORKER_TRIGGER_DURATION, MetricRegistry.METRIC_WORKER_TRIGGER_DURATION_DESCRIPTION, metricsTags)
.record(() -> {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
this.evaluateTriggerRunningCount.computeIfAbsent(workerTrigger.getTriggerContext().uid(), s -> metricRegistry
.gauge(MetricRegistry.METRIC_WORKER_TRIGGER_RUNNING_COUNT, MetricRegistry.METRIC_WORKER_TRIGGER_RUNNING_COUNT_DESCRIPTION, new AtomicInteger(0), metricsTags));
this.evaluateTriggerRunningCount.get(workerTrigger.getTriggerContext().uid()).addAndGet(1);
DefaultRunContext runContext = (DefaultRunContext) workerTrigger.getConditionContext().getRunContext();
runContextInitializer.forWorker(runContext, workerTrigger);
try {
logService.logTrigger(
workerTrigger.getTriggerContext(),
runContext.logger(),
Level.INFO,
"Type {} started",
workerTrigger.getTrigger().getType()
);
if (workerTrigger.getTrigger() instanceof PollingTriggerInterface pollingTrigger) {
WorkerTriggerCallable workerCallable = new WorkerTriggerCallable(runContext, workerTrigger, pollingTrigger);
io.kestra.core.models.flows.State.Type state = callJob(workerCallable);
if (workerCallable.getException() != null || !state.equals(SUCCESS)) {
this.handleTriggerError(workerTrigger, workerCallable.getException());
}
if (!state.equals(FAILED)) {
this.publishTriggerExecution(workerTrigger, workerCallable.getEvaluate());
}
} else if (workerTrigger.getTrigger() instanceof RealtimeTriggerInterface streamingTrigger) {
WorkerTriggerRealtimeCallable workerCallable = new WorkerTriggerRealtimeCallable(
runContext,
workerTrigger,
streamingTrigger,
throwable -> this.handleTriggerError(workerTrigger, throwable),
execution -> this.publishTriggerExecution(workerTrigger, Optional.of(execution))
);
io.kestra.core.models.flows.State.Type state = callJob(workerCallable);
// here the realtime trigger failed before the publisher was called, so we create a failed execution
if (workerCallable.getException() != null || !state.equals(SUCCESS)) {
this.handleRealtimeTriggerError(workerTrigger, workerCallable.getException());
}
}
} catch (Exception e) {
this.handleTriggerError(workerTrigger, e);
} finally {
logService.logTrigger(
workerTrigger.getTriggerContext(),
runContext.logger(),
Level.INFO,
"Type {} completed in {}",
workerTrigger.getTrigger().getType(),
DurationFormatUtils.formatDurationHMS(stopWatch.getTime(TimeUnit.MILLISECONDS))
);
workerTrigger.getConditionContext().getRunContext().cleanup();
}
this.evaluateTriggerRunningCount.get(workerTrigger.getTriggerContext().uid()).addAndGet(-1);
}
);
metricRegistry
.counter(MetricRegistry.METRIC_WORKER_TRIGGER_ENDED_COUNT, MetricRegistry.METRIC_WORKER_TRIGGER_ENDED_COUNT_DESCRIPTION, metricsTags)
.increment();
}
private void handleTriggerError(WorkerTrigger workerTrigger, Throwable e) {
String[] tags = metricRegistry.tags(workerTrigger, workerGroup);
metricRegistry
.counter(MetricRegistry.METRIC_WORKER_TRIGGER_ERROR_COUNT, MetricRegistry.METRIC_WORKER_TRIGGER_ERROR_COUNT_DESCRIPTION, tags)
.increment();
logError(workerTrigger, e);
Execution execution = workerTrigger.getTrigger().isFailOnTriggerError() ? TriggerService.generateExecution(workerTrigger.getTrigger(), workerTrigger.getConditionContext(), workerTrigger.getTriggerContext(), (Output) null)
.withState(FAILED) : null;
if (execution != null) {
RunContextLogger.logEntries(Execution.loggingEventFromException(e), LogEntry.of(execution)).forEach(workerLogQueue::put);
}
this.workerTriggerResultQueue.put(
WorkerTriggerResult.builder()
.triggerContext(workerTrigger.getTriggerContext())
.trigger(workerTrigger.getTrigger())
.execution(Optional.ofNullable(execution))
.build()
);
}
private void handleRealtimeTriggerError(WorkerTrigger workerTrigger, Throwable e) {
String[] tags = metricRegistry.tags(workerTrigger, workerGroup);
this.metricRegistry
.counter(MetricRegistry.METRIC_WORKER_TRIGGER_ERROR_COUNT, MetricRegistry.METRIC_WORKER_TRIGGER_ERROR_COUNT_DESCRIPTION, tags)
.increment();
// We create a FAILED execution, so the user is aware that the realtime trigger failed to be created
var execution = TriggerService
.generateRealtimeExecution(workerTrigger.getTrigger(), workerTrigger.getConditionContext(), workerTrigger.getTriggerContext(), null)
.withState(FAILED);
// We create an ERROR log attached to the execution
Logger logger = workerTrigger.getConditionContext().getRunContext().logger();
logService.logExecution(
execution,
logger,
Level.ERROR,
"[date: {}] Realtime trigger failed to be created in the worker with error: {}",
workerTrigger.getTriggerContext().getDate(),
e != null ? e.getMessage() : "unknown",
e
);
if (logger.isTraceEnabled() && e != null) {
logger.trace(Throwables.getStackTraceAsString(e));
}
this.workerTriggerResultQueue.put(
WorkerTriggerResult.builder()
.execution(Optional.of(execution))
.triggerContext(workerTrigger.getTriggerContext())
.trigger(workerTrigger.getTrigger())
.build()
);
}
private void publishTriggerExecution(WorkerTrigger workerTrigger, Optional<Execution> evaluate) {
metricRegistry
.counter(
MetricRegistry.METRIC_WORKER_TRIGGER_EXECUTION_COUNT,
MetricRegistry.METRIC_WORKER_TRIGGER_EXECUTION_COUNT_DESCRIPTION,
metricRegistry.tags(workerTrigger, workerGroup)
).increment();
if (log.isDebugEnabled()) {
logService.logTrigger(
workerTrigger.getTriggerContext(),
Level.DEBUG,
"[type: {}] {}",
workerTrigger.getTrigger().getType(),
evaluate.map(execution -> "New execution '" + execution.getId() + "'").orElse("Empty evaluation")
);
}
var flow = workerTrigger.getConditionContext().getFlow();
if (flow.getLabels() != null) {
evaluate = evaluate.map(execution -> {
List<Label> executionLabels = execution.getLabels() != null ? execution.getLabels() : new ArrayList<>();
executionLabels.addAll(LabelService.labelsExcludingSystem(flow));
return execution.withLabels(executionLabels);
}
);
}
this.workerTriggerResultQueue.put(
WorkerTriggerResult.builder()
.execution(evaluate)
.triggerContext(workerTrigger.getTriggerContext())
.trigger(workerTrigger.getTrigger())
.build()
);
}
private void logError(WorkerTrigger workerTrigger, Throwable e) {
Logger logger = workerTrigger.getConditionContext().getRunContext().logger();
if (e instanceof InterruptedException || (e != null && e.getCause() instanceof InterruptedException)) {
logService.logTrigger(
workerTrigger.getTriggerContext(),
logger,
Level.WARN,
"[date: {}] Trigger evaluation interrupted in the worker",
workerTrigger.getTriggerContext().getDate()
);
} else {
logService.logTrigger(
workerTrigger.getTriggerContext(),
logger,
Level.WARN,
"[date: {}] Trigger evaluation failed in the worker with error: {}",
workerTrigger.getTriggerContext().getDate(),
e != null ? e.getMessage() : "unknown",
e
);
}
if (logger.isTraceEnabled() && e != null) {
logger.trace(Throwables.getStackTraceAsString(e));
}
}
}

View File

@@ -0,0 +1,53 @@
package io.kestra.worker.queues;
import java.time.Duration;
import java.util.Objects;
public abstract class AbstractDelegateWorkerQueue<T> implements WorkerQueue<T> {
private final WorkerQueue<T> queue;
public AbstractDelegateWorkerQueue(final WorkerQueue<T> queue) {
this.queue = Objects.requireNonNull(queue, "queue must not be null.");
}
/**
* {@inheritDoc}
*/
@Override
public T poll(Duration timeout) throws InterruptedException {
return queue.poll(timeout);
}
/**
* {@inheritDoc}
*/
@Override
public void put(T event) {
queue.put(event);
}
/**
* {@inheritDoc}
*/
@Override
public int remainingCapacity() {
return queue.remainingCapacity();
}
/**
* {@inheritDoc}
*/
@Override
public int capacity() {
return queue.capacity();
}
/**
* {@inheritDoc}
*/
@Override
public int size() {
return queue.size();
}
}

View File

@@ -0,0 +1,61 @@
package io.kestra.worker.queues;
import java.time.Duration;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class InMemoryWorkerQueue<T> implements WorkerQueue<T> {
private final int capacity;
private final LinkedBlockingQueue<T> queue;
public InMemoryWorkerQueue(int capacity) {
this.capacity = capacity;
this.queue = new LinkedBlockingQueue<>(capacity);
}
/**
* {@inheritDoc}
*/
@Override
public T poll(Duration timeout) throws InterruptedException {
return queue.poll(timeout.toMillis(), TimeUnit.MILLISECONDS);
}
/**
* {@inheritDoc}
*/
@Override
public void put(T event) {
try {
this.queue.put(event);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
/**
* {@inheritDoc}
*/
@Override
public int remainingCapacity() {
return this.queue.remainingCapacity();
}
/**
* {@inheritDoc}
*/
@Override
public int capacity() {
return this.capacity;
}
/**
* {@inheritDoc}
*/
@Override
public int size() {
return this.queue.size();
}
}

View File

@@ -0,0 +1,48 @@
package io.kestra.worker.queues;
import io.kestra.core.metrics.MetricRegistry;
import io.micrometer.core.instrument.Counter;
import java.time.Duration;
public class MonitoredWorkerQueue<T> extends AbstractDelegateWorkerQueue<T> {
public static final String QUEUE_SIZE = "queue.size";
public static final String QUEUE_REMAINING_CAPACITY = "queue.remaining.capacity";
public static final String QUEUE_ENQUEUED = "queue.enqueued";
public static final String QUEUE_DEQUEUED = "queue.dequeued";
private final MetricRegistry metricRegistry;
private final Counter enqueuedCounter;
private final Counter dequeuedCounter;
public MonitoredWorkerQueue(MetricRegistry metricRegistry, String queueName, WorkerQueue<T> queue) {
super(queue);
this.metricRegistry = metricRegistry;
String[] tags = new String[]{"name", queueName};
this.metricRegistry.registerGauge(QUEUE_SIZE, "Current number of items in the queue", this::size, tags);
this.metricRegistry.registerGauge(QUEUE_REMAINING_CAPACITY, "Remaining capacity in the queue", this::remainingCapacity, tags);
this.enqueuedCounter = this.metricRegistry.counter(QUEUE_ENQUEUED, "Number of items enqueued", tags);
this.dequeuedCounter = this.metricRegistry.counter(QUEUE_DEQUEUED, "Number of items dequeued", tags);
}
/**
* {@inheritDoc}
*/
@Override
public T poll(Duration timeout) throws InterruptedException {
T item = super.poll(timeout);
if (item != null) {
// only count items actually dequeued, not poll timeouts
dequeuedCounter.increment();
}
return item;
}
/**
* {@inheritDoc}
*/
@Override
public void put(T event) {
super.put(event);
enqueuedCounter.increment();
}
}

View File

@@ -0,0 +1,18 @@
package io.kestra.worker.queues;
import io.kestra.core.runners.WorkerJob;
/**
* Typed worker queue for {@link WorkerJob}.
*/
public interface WorkerJobQueue extends WorkerQueue<WorkerJob> {
/**
* The default {@link WorkerJobQueue} implementation
*/
class Default extends AbstractDelegateWorkerQueue<WorkerJob> implements WorkerJobQueue {
public Default(final WorkerQueue<WorkerJob> queue) {
super(queue);
}
}
}

View File

@@ -0,0 +1,20 @@
package io.kestra.worker.queues;
import io.kestra.core.models.executions.LogEntry;
public interface WorkerLogQueue extends WorkerQueue<LogEntry> {
/**
* The default {@link WorkerLogQueue} implementation
*/
class Default extends AbstractDelegateWorkerQueue<LogEntry> implements WorkerLogQueue {
public Default(final WorkerQueue<LogEntry> queue) {
super(queue);
}
}
}

View File

@@ -0,0 +1,15 @@
package io.kestra.worker.queues;
import io.kestra.core.models.executions.MetricEntry;
public interface WorkerMetricQueue extends WorkerQueue<MetricEntry> {
/**
* The default {@link WorkerMetricQueue} implementation
*/
class Default extends AbstractDelegateWorkerQueue<MetricEntry> implements WorkerMetricQueue {
public Default(final WorkerQueue<MetricEntry> queue) {
super(queue);
}
}
}

View File

@@ -0,0 +1,23 @@
package io.kestra.worker.queues;
import java.time.Duration;
/**
* Represents an event queue used for intra-process communication within the worker.
* <p>
* Implementations of this interface are expected to be in-memory oriented.
*
* @param <T> the type of elements held in the queue.
*/
public interface WorkerQueue<T> {
T poll(Duration timeout) throws InterruptedException;
void put(T event);
int remainingCapacity();
int capacity();
int size();
}
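A minimal sketch of the intended producer/consumer usage, based on the InMemoryWorkerQueue implementation above (variable names are illustrative; InterruptedException handling omitted):

WorkerQueue<WorkerTaskResult> results = new InMemoryWorkerQueue<>(100);
results.put(workerTaskResult); // blocks while the queue is full
WorkerTaskResult next = results.poll(Duration.ofSeconds(1)); // returns null if nothing arrived within the timeout
if (next != null) {
    // forward the result, e.g. over gRPC as in doOnLoop above
}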

View File

@@ -0,0 +1,37 @@
package io.kestra.worker.queues;
import io.kestra.core.metrics.MetricRegistry;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Singleton
public class WorkerQueueFactory {
public static final int DEFAULT_QUEUE_SIZE = 5000;
private final Map<QueueKey, WorkerQueue<?>> queues;
private final MetricRegistry metricRegistry;
@Inject
public WorkerQueueFactory(final MetricRegistry metricRegistry) {
this.queues = new ConcurrentHashMap<>();
this.metricRegistry = metricRegistry;
}
@SuppressWarnings("unchecked")
public synchronized <T> WorkerQueue<T> getOrCreate(final String workerId, final Class<T> type) {
QueueKey key = new QueueKey(workerId, type);
return (WorkerQueue<T>) queues.computeIfAbsent(key, unused ->
new MonitoredWorkerQueue<T>(metricRegistry, type.getSimpleName().toLowerCase(),
new InMemoryWorkerQueue<>(DEFAULT_QUEUE_SIZE)
)
);
}
private record QueueKey(String workerId, Class<?> type) {
}
}
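The factory hands out one monitored in-memory queue per (workerId, type) pair, so the producing and consuming sides of the worker always share the same instance; a short sketch:

WorkerQueue<WorkerTaskResult> a = workerQueueFactory.getOrCreate(workerId, WorkerTaskResult.class);
WorkerQueue<WorkerTaskResult> b = workerQueueFactory.getOrCreate(workerId, WorkerTaskResult.class);
// a == b: the same MonitoredWorkerQueue wrapping an InMemoryWorkerQueue of DEFAULT_QUEUE_SIZE capacity,
// typically wrapped again in a typed queue, e.g. new WorkerTaskResultQueue.Default(a)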

View File

@@ -0,0 +1,15 @@
package io.kestra.worker.queues;
import io.kestra.core.runners.WorkerTaskResult;
public interface WorkerTaskResultQueue extends WorkerQueue<WorkerTaskResult> {
/**
* The default {@link WorkerTaskResultQueue} implementation
*/
class Default extends AbstractDelegateWorkerQueue<WorkerTaskResult> implements WorkerTaskResultQueue {
public Default(final WorkerQueue<WorkerTaskResult> queue) {
super(queue);
}
}
}

View File

@@ -0,0 +1,18 @@
package io.kestra.worker.queues;
import io.kestra.core.runners.WorkerTriggerResult;
/**
* Typed worker queue for {@link WorkerTriggerResult}.
*/
public interface WorkerTriggerResultQueue extends WorkerQueue<WorkerTriggerResult> {
/**
* The default {@link WorkerTriggerResultQueue} implementation
*/
class Default extends AbstractDelegateWorkerQueue<WorkerTriggerResult> implements WorkerTriggerResultQueue {
public Default(final WorkerQueue<WorkerTriggerResult> queue) {
super(queue);
}
}
}

View File

@@ -0,0 +1,20 @@
syntax = "proto3";
option java_package = "io.kestra.controller.grpc";
option java_multiple_files = true;
import "request.proto";
service LivenessControllerService {
rpc heartbeat(HeartbeatRequest) returns (HeartbeatResponse);
}
message HeartbeatRequest {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
message HeartbeatResponse {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
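For illustration, a worker-side heartbeat call might look as follows; the generated stub name assumes the standard protoc-gen-grpc-java naming, and the empty payload reflects that the heartbeat message body is not yet defined in this WIP:

HeartbeatRequest request = HeartbeatRequest
    .newBuilder()
    .setHeader(newRequestOrResponseHeader()) // same header builder used for worker job results
    .setMessage(com.google.protobuf.ByteString.EMPTY) // payload to be defined
    .build();
// assumption: stub generated as LivenessControllerServiceGrpc
HeartbeatResponse response = LivenessControllerServiceGrpc.newBlockingStub(channel).heartbeat(request);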

View File

@@ -0,0 +1,16 @@
syntax = "proto3";
option java_package = "io.kestra.server.grpc";
option java_multiple_files = true;
// Common request and response header
message RequestOrResponseHeader {
// The client ID string.
string clientId = 1;
// The client version.
string clientVersion = 2;
// The correlation ID of this request.
string correlationId = 3;
// The format of the message
string messageFormat = 4;
}

View File

@@ -0,0 +1,34 @@
syntax = "proto3";
option java_package = "io.kestra.worker.grpc";
option java_multiple_files = true;
import "request.proto";
service WorkerControllerService {
rpc fetchWorkerJobs(FetchWorkerJobRequest) returns (FetchWorkerJobResponse);
rpc fetchWorkerJobsStream(FetchWorkerJobRequest) returns (stream FetchWorkerJobResponse);
rpc sendWorkerJobResults(WorkerJobResultsRequest) returns (WorkerJobResultsResponse);
}
message FetchWorkerJobRequest {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
message FetchWorkerJobResponse {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
message WorkerJobResultsRequest {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
message WorkerJobResultsResponse {
RequestOrResponseHeader header = 1;
bytes message = 2;
}
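A hedged sketch of the worker-side fetch loop against this service, mirroring the serialization pattern used for sendWorkerJobResults earlier; the generated stub name and the blocking server-streaming call are assumptions based on standard protoc-gen-grpc-java conventions, and workerId, workerGroupKey and channel are assumed to be in scope:

FetchWorkerJobRequest request = FetchWorkerJobRequest
    .newBuilder()
    .setHeader(newRequestOrResponseHeader())
    .setMessage(MessageFormats.JSON.toByteString(new FetchWorkerJobMessage(workerId, workerGroupKey)))
    .build();
Iterator<FetchWorkerJobResponse> responses = WorkerControllerServiceGrpc
    .newBlockingStub(channel)
    .fetchWorkerJobsStream(request); // server-streaming call, one response per batch of jobs
while (responses.hasNext()) {
    FetchWorkerJobResponse response = responses.next();
    // decode response.getMessage() into a WorkerJobBatchMessage and hand each job to a WorkerJobProcessor
}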