Skip to content

Commit

Permalink
chore: wirering IoC with TaskRepository #1626
Browse files Browse the repository at this point in the history
  • Loading branch information
bamthomas committed Jan 17, 2025
1 parent ac6e7b6 commit f9aceee
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.icij.datashare.Repository;
import org.icij.datashare.asynctasks.TaskManager;
import org.icij.datashare.asynctasks.TaskModifier;
import org.icij.datashare.asynctasks.TaskRepository;
import org.icij.datashare.tasks.TaskRepositoryRedis;
import org.icij.datashare.asynctasks.TaskSupplier;
import org.icij.datashare.batch.BatchSearchRepository;
import org.icij.datashare.cli.Mode;
Expand Down Expand Up @@ -138,11 +140,13 @@ protected void configure() {
QueueType batchQueueType = getQueueType(propertiesProvider, BATCH_QUEUE_TYPE_OPT, QueueType.MEMORY);
switch ( batchQueueType ) {
case REDIS:
bind(TaskRepository.class).to(TaskRepositoryRedis.class);
bind(TaskManager.class).to(TaskManagerRedis.class);
bind(TaskModifier.class).to(TaskSupplierRedis.class);
bind(TaskSupplier.class).to(TaskSupplierRedis.class);
break;
case AMQP:
bind(TaskRepository.class).to(TaskRepositoryRedis.class);
bind(TaskManager.class).to(TaskManagerAmqp.class);
bind(TaskSupplier.class).to(TaskSupplierAmqp.class);
bind(TaskModifier.class).to(TaskSupplierAmqp.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.TaskManagerRedis;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskRepository;
import org.icij.datashare.asynctasks.bus.amqp.AmqpInterlocutor;
import org.icij.datashare.cli.DatashareCliOptions;
import org.icij.datashare.mode.CommonMode;
import org.jetbrains.annotations.NotNull;
import org.redisson.Redisson;
import org.redisson.RedissonMap;
import org.redisson.api.RedissonClient;
Expand All @@ -19,27 +18,13 @@

@Singleton
public class TaskManagerAmqp extends org.icij.datashare.asynctasks.TaskManagerAmqp {

@Inject
public TaskManagerAmqp(AmqpInterlocutor amqp, RedissonClient redissonClient, PropertiesProvider propertiesProvider)
public TaskManagerAmqp(AmqpInterlocutor amqp, TaskRepository taskRepository, PropertiesProvider propertiesProvider)
throws IOException {
this(amqp, redissonClient, propertiesProvider, null);
}

TaskManagerAmqp(AmqpInterlocutor amqp, RedissonClient redissonClient, PropertiesProvider propertiesProvider, Runnable eventCallback) throws IOException {
// We start with a fresh list of known task everytime, we could decide to allow inheriting
// existing tasks
super(amqp, createTaskQueue(redissonClient), Utils.getRoutingStrategy(propertiesProvider), eventCallback);
this(amqp, taskRepository, propertiesProvider, null);
}

private static RedissonMap<String, Task<?>> createTaskQueue(RedissonClient redissonClient) {
return new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(),
new CommandSyncService(((Redisson) redissonClient).getConnectionManager(),
new RedissonObjectBuilder(redissonClient)),
CommonMode.DS_TASK_MANAGER_QUEUE_NAME,
redissonClient,
null,
null
);
TaskManagerAmqp(AmqpInterlocutor amqp, TaskRepository taskRepository, PropertiesProvider propertiesProvider, Runnable eventCallback) throws IOException {
super(amqp, taskRepository, Utils.getRoutingStrategy(propertiesProvider), eventCallback);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.inject.Inject;
import com.google.inject.Singleton;
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.TaskRepositoryMemory;

import java.util.concurrent.CountDownLatch;

Expand All @@ -16,6 +17,6 @@ public TaskManagerMemory(DatashareTaskFactory taskFactory, PropertiesProvider pr
}

TaskManagerMemory(DatashareTaskFactory taskFactory, PropertiesProvider propertiesProvider, CountDownLatch latch) {
super(taskFactory, propertiesProvider, latch);
super(taskFactory, propertiesProvider, new TaskRepositoryMemory(), latch);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.icij.datashare.tasks;

import com.google.inject.Inject;
import com.google.inject.Singleton;
import org.redisson.api.RedissonClient;

@Singleton
public class TaskRepositoryRedis extends org.icij.datashare.asynctasks.TaskRepositoryRedis{
@Inject
public TaskRepositoryRedis(RedissonClient redisson) {
super(redisson);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.icij.datashare.mode;

import org.icij.datashare.asynctasks.TaskManager;
import org.icij.datashare.asynctasks.TaskModifier;
import org.icij.datashare.asynctasks.TaskSupplier;
import org.icij.datashare.asynctasks.TaskWorkerLoop;
import org.icij.datashare.cli.QueueType;
Expand All @@ -25,7 +26,7 @@ public class CliModeWorkerAcceptanceTest {
public static Collection<Object[]> mode() throws Exception {
return asList(new Object[][]{
{
CommonMode.create(Map.of(
CommonMode.create(Map.of(
"dataDir", "/tmp",
"mode", "TASK_WORKER",
"batchQueueType", QueueType.AMQP.name(),
Expand All @@ -51,7 +52,7 @@ public CliModeWorkerAcceptanceTest(CommonMode mode) {
@Test(timeout = 30000)
public void test_task_worker() throws Exception {
CountDownLatch workerStarted = new CountDownLatch(1);
TaskWorkerLoop taskWorkerLoop = new TaskWorkerLoop(mode.get(DatashareTaskFactory.class), mode.get(TaskSupplier.class), workerStarted);
TaskWorkerLoop taskWorkerLoop = new TaskWorkerLoop(mode.get(DatashareTaskFactory.class), mode.get(TaskSupplier.class), workerStarted, 1000);
Thread workerApp = new Thread(taskWorkerLoop::call);
workerApp.start();
workerStarted.await();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,31 @@
import org.icij.datashare.asynctasks.bus.amqp.TaskEvent;

import org.icij.datashare.tasks.RoutingStrategy;
import org.icij.datashare.user.User;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.regex.Pattern;

import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toList;

public class TaskManagerAmqp implements TaskManager {
private final Map<String, Task<?>> tasks;
private final TaskRepository tasks;
private final RoutingStrategy routingStrategy;
private final AmqpInterlocutor amqp;
private final AmqpConsumer<TaskEvent, Consumer<TaskEvent>> eventConsumer;

public TaskManagerAmqp(AmqpInterlocutor amqp, Map<String, Task<?>> tasks) throws IOException {
this(amqp, tasks, RoutingStrategy.UNIQUE);
public TaskManagerAmqp(AmqpInterlocutor amqp, TaskRepository taskRepository) throws IOException {
this(amqp, taskRepository, RoutingStrategy.UNIQUE);
}

public TaskManagerAmqp(AmqpInterlocutor amqp, Map<String, Task<?>> tasks, RoutingStrategy routingStrategy) throws IOException {
public TaskManagerAmqp(AmqpInterlocutor amqp, TaskRepository tasks, RoutingStrategy routingStrategy) throws IOException {
this(amqp, tasks, routingStrategy, null);
}

public TaskManagerAmqp(AmqpInterlocutor amqp, Map<String, Task<?>> tasks, RoutingStrategy routingStrategy, Runnable eventCallback) throws IOException {
public TaskManagerAmqp(AmqpInterlocutor amqp, TaskRepository tasks, RoutingStrategy routingStrategy, Runnable eventCallback) throws IOException {
this.amqp = amqp;
this.tasks = tasks;
this.routingStrategy = routingStrategy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.bus.amqp.Event;
import org.icij.datashare.asynctasks.bus.amqp.TaskError;
import org.icij.datashare.user.User;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -28,24 +25,25 @@
public class TaskManagerMemory implements TaskManager, TaskSupplier {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final ExecutorService executor;
private final ConcurrentMap<String, Task<?>> tasks = new ConcurrentHashMap<>();
private final TaskRepository tasks;
private final BlockingQueue<Task<?>> taskQueue;
private final List<TaskWorkerLoop> loops;
private final AtomicInteger executedTasks = new AtomicInteger(0);
private final int pollingInterval;

public TaskManagerMemory(TaskFactory taskFactory) {
this(taskFactory, new PropertiesProvider(), new CountDownLatch(1));
this(taskFactory, new PropertiesProvider(), new TaskRepositoryMemory(), new CountDownLatch(1));
}

public TaskManagerMemory(TaskFactory taskFactory, PropertiesProvider propertiesProvider, CountDownLatch latch) {
public TaskManagerMemory(TaskFactory taskFactory, PropertiesProvider propertiesProvider, TaskRepository tasks, CountDownLatch latch) {
this.taskQueue = new LinkedBlockingQueue<>();
int parallelism = parseInt(propertiesProvider.get("parallelism").orElse("1"));
pollingInterval = Integer.parseInt(propertiesProvider.get("pollingInterval").orElse("60"));
logger.info("running TaskManager {} with {} workers", this, parallelism);
executor = Executors.newFixedThreadPool(parallelism);
loops = IntStream.range(0, parallelism).mapToObj(i -> new TaskWorkerLoop(taskFactory, this, latch, pollingInterval)).collect(Collectors.toList());
loops.forEach(executor::submit);
this.tasks = tasks;
}

public <V> Task<V> getTask(final String taskId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
public class TaskManagerRedis implements TaskManager {
private final Runnable eventCallback; // for test
public static final String EVENT_CHANNEL_NAME = "EVENT";
private final RedissonMap<String, Task<?>> tasks;
private final TaskRepository tasks;
private final RTopic eventTopic;
private final RedissonClient redissonClient;
private final RoutingStrategy routingStrategy;
Expand All @@ -55,8 +55,7 @@ public TaskManagerRedis(RedissonClient redissonClient, String taskMapName) {
public TaskManagerRedis(RedissonClient redissonClient, String taskMapName, RoutingStrategy routingStrategy, Runnable eventCallback) {
this.redissonClient = redissonClient;
this.routingStrategy = routingStrategy;
CommandSyncService commandSyncService = getCommandSyncService();
this.tasks = new RedissonMap<>(new TaskViewCodec(), commandSyncService, taskMapName, redissonClient, null, null);
this.tasks = new TaskRepositoryRedis(redissonClient, taskMapName);
this.eventTopic = redissonClient.getTopic(EVENT_CHANNEL_NAME);
this.eventCallback = eventCallback;
eventTopic.addListener(TaskEvent.class, (channelString, message) -> handleEvent(message));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package org.icij.datashare.asynctasks;

import java.util.concurrent.ConcurrentHashMap;

public class TaskRepositoryMemory extends ConcurrentHashMap<String, Task<?>> implements TaskRepository { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.icij.datashare.asynctasks;

import org.redisson.Redisson;
import org.redisson.RedissonMap;
import org.redisson.api.RedissonClient;
import org.redisson.command.CommandSyncService;
import org.redisson.liveobject.core.RedissonObjectBuilder;

public class TaskRepositoryRedis extends RedissonMap<String, Task<?>> implements TaskRepository {
public TaskRepositoryRedis(RedissonClient redisson) {
this(redisson, "ds:task:manager");
}

public TaskRepositoryRedis(RedissonClient redisson, String name) {
super(new TaskManagerRedis.TaskViewCodec(), new CommandSyncService(((Redisson) redisson).getConnectionManager(), new RedissonObjectBuilder(redisson)),
name, redisson, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public TaskWorkerLoop(TaskFactory factory, TaskSupplier taskSupplier, CountDownL
this(factory, taskSupplier, countDownLatch, 60_000);
}

TaskWorkerLoop(TaskFactory factory, TaskSupplier taskSupplier, CountDownLatch countDownLatch, int pollTimeMillis) {
public TaskWorkerLoop(TaskFactory factory, TaskSupplier taskSupplier, CountDownLatch countDownLatch, int pollTimeMillis) {
this.factory = factory;
this.taskSupplier = taskSupplier;
this.waitForMainLoopCalled = countDownLatch;
Expand Down Expand Up @@ -95,6 +95,7 @@ private Integer mainLoop() {
}
}
logger.info("Exiting loop after {} tasks", nbTasks);
loopThread.interrupt();
return nbTasks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void test_new_task() throws Exception {
@Test(timeout = 2000)
public void test_new_task_with_group_routing() throws Exception {
String key = "Key";
try (TaskManagerAmqp groupTaskManager = new TaskManagerAmqp(AMQP, new ConcurrentHashMap<>(), RoutingStrategy.GROUP, () -> nextMessage.countDown());
try (TaskManagerAmqp groupTaskManager = new TaskManagerAmqp(AMQP, new TaskRepositoryMemory(), RoutingStrategy.GROUP, () -> nextMessage.countDown());
TaskSupplierAmqp groupTaskSupplier = new TaskSupplierAmqp(AMQP, key)) {
groupTaskSupplier.consumeTasks(t -> taskQueue.add(t));
String expectedTaskViewId = groupTaskManager.startTask("taskName", User.local(), new Group(key), Map.of());
Expand All @@ -59,7 +59,7 @@ public void test_new_task_with_group_routing() throws Exception {

@Test(timeout = 2000)
public void test_new_task_with_name_routing() throws Exception {
try (TaskManagerAmqp groupTaskManager = new TaskManagerAmqp(AMQP, new ConcurrentHashMap<>(), RoutingStrategy.NAME, () -> nextMessage.countDown());
try (TaskManagerAmqp groupTaskManager = new TaskManagerAmqp(AMQP, new TaskRepositoryMemory(), RoutingStrategy.NAME, () -> nextMessage.countDown());
TaskSupplierAmqp groupTaskSupplier = new TaskSupplierAmqp(AMQP, "TaskName")) {
groupTaskSupplier.consumeTasks(t -> taskQueue.add(t));
String expectedTaskViewId = groupTaskManager.startTask("TaskName", User.local(), Map.of());
Expand Down Expand Up @@ -181,16 +181,8 @@ public static void beforeClass() throws Exception {
public void setUp() throws IOException {
nextMessage = new CountDownLatch(1);
final RedissonClient redissonClient = new RedissonClientFactory().withOptions(
Options.from(new PropertiesProvider(Map.of("redisAddress", "redis://redis:6379")).getProperties())).create();
Map<String, Task<?>> tasks = new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(),
new CommandSyncService(((Redisson) redissonClient).getConnectionManager(),
new RedissonObjectBuilder(redissonClient)),
"tasks:queue:test",
redissonClient,
null,
null
);
taskManager = new TaskManagerAmqp(AMQP, tasks, RoutingStrategy.UNIQUE, () -> nextMessage.countDown());
Options.from(new PropertiesProvider(Map.of("redisAddress", "redis://redis:6379")).getProperties())).create();
taskManager = new TaskManagerAmqp(AMQP, new TaskRepositoryRedis(redissonClient, "tasks:queue:test"), RoutingStrategy.UNIQUE, () -> nextMessage.countDown());
taskSupplier = new TaskSupplierAmqp(AMQP);
taskSupplier.consumeTasks(t -> taskQueue.add(t));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ public class TaskManagerMemoryTest {

@Before
public void setUp() throws Exception {
LinkedBlockingQueue<Task<?>> taskViews = new LinkedBlockingQueue<>();
taskManager = new TaskManagerMemory(factory, new PropertiesProvider(), waitForLoop);
taskManager = new TaskManagerMemory(factory, new PropertiesProvider(), new TaskRepositoryMemory(), waitForLoop);
taskInspector = new TaskInspector(taskManager);
waitForLoop.await();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public void test_clear_running_task_should_throw_exception() throws Exception {
taskManager.clearTask(taskViewId);
}

@Test
@Test(timeout = 10000)
public void test_done_task_result_for_file() throws Exception {
String taskViewId = taskManager.startTask("HelloWorld", User.local(), new HashMap<>() {{
put("greeted", "world");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ public static Collection<Object[]> taskServices() throws Exception {
"messageBusAddress", "amqp://admin:admin@rabbitmq"));
final RedissonClient redissonClient = new RedissonClientFactory().withOptions(
Options.from(propertiesProvider.getProperties())).create();
Map<String, Task<?>> amqpTasks = new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(),
new CommandSyncService(((Redisson) redissonClient).getConnectionManager(),
new RedissonObjectBuilder(redissonClient)),
"tasks:queue:test",
redissonClient,
null,
null
);

AMQP = new AmqpInterlocutor(propertiesProvider);
AMQP.deleteQueues(AmqpQueue.MANAGER_EVENT, AmqpQueue.WORKER_EVENT, AmqpQueue.TASK);
AMQP.createAmqpChannelForPublish(AmqpQueue.TASK);
Expand All @@ -79,7 +72,7 @@ public static Collection<Object[]> taskServices() throws Exception {

return asList(new Object[][]{
{
(Creator<TaskManager>) () -> new TaskManagerAmqp(AMQP, amqpTasks, RoutingStrategy.UNIQUE, amqpWaiter::countDown),
(Creator<TaskManager>) () -> new TaskManagerAmqp(AMQP, new TaskRepositoryRedis(redissonClient), RoutingStrategy.UNIQUE, amqpWaiter::countDown),
(Creator<TaskSupplier>) () -> new TaskSupplierAmqp(AMQP),
amqpWaiter
},
Expand Down

0 comments on commit f9aceee

Please sign in to comment.