diff --git a/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java b/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java index 4c1559d74..ac57cea28 100644 --- a/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java +++ b/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java @@ -23,6 +23,8 @@ import org.icij.datashare.asynctasks.TaskManager; import org.icij.datashare.asynctasks.TaskModifier; import org.icij.datashare.asynctasks.TaskRepository; +import org.icij.datashare.cli.TaskRepositoryType; +import org.icij.datashare.db.JooqTaskRepository; import org.icij.datashare.tasks.TaskRepositoryRedis; import org.icij.datashare.asynctasks.TaskSupplier; import org.icij.datashare.batch.BatchSearchRepository; @@ -79,6 +81,7 @@ import static org.icij.datashare.cli.DatashareCliOptions.BATCH_QUEUE_TYPE_OPT; import static org.icij.datashare.cli.DatashareCliOptions.MODE_OPT; import static org.icij.datashare.cli.DatashareCliOptions.QUEUE_TYPE_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.TASK_REPOSITORY_OPT; import static org.icij.datashare.text.indexing.elasticsearch.ElasticsearchConfiguration.createESClient; public abstract class CommonMode extends AbstractModule implements Closeable { @@ -140,13 +143,11 @@ protected void configure() { QueueType batchQueueType = getQueueType(propertiesProvider, BATCH_QUEUE_TYPE_OPT, QueueType.MEMORY); switch ( batchQueueType ) { case REDIS: - bind(TaskRepository.class).to(TaskRepositoryRedis.class); bind(TaskManager.class).to(TaskManagerRedis.class); bind(TaskModifier.class).to(TaskSupplierRedis.class); bind(TaskSupplier.class).to(TaskSupplierRedis.class); break; case AMQP: - bind(TaskRepository.class).to(TaskRepositoryRedis.class); bind(TaskManager.class).to(TaskManagerAmqp.class); bind(TaskSupplier.class).to(TaskSupplierAmqp.class); bind(TaskModifier.class).to(TaskSupplierAmqp.class); @@ -268,6 +269,16 @@ void configurePersistence() { bind(Repository.class).toInstance(repositoryFactory.createRepository()); bind(ApiKeyRepository.class).toInstance(repositoryFactory.createApiKeyRepository()); bind(BatchSearchRepository.class).toInstance(repositoryFactory.createBatchSearchRepository()); + + TaskRepositoryType taskRepositoryType = TaskRepositoryType.valueOf(propertiesProvider.get(TASK_REPOSITORY_OPT).orElse("REDIS")); + switch ( taskRepositoryType ) { + case REDIS -> { + bind(TaskRepository.class).to(TaskRepositoryRedis.class); + } + case DATABASE -> { + bind(TaskRepository.class).toInstance(new JooqTaskRepository(repositoryFactory.getDataSource(), repositoryFactory.guessSqlDialect())); + } + } repositoryFactory.initDatabase(); } diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java index f226b4e42..dd1e727c4 100644 --- a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCli.java @@ -191,6 +191,7 @@ OptionParser createParser() { DatashareCliOptions.taskRoutingStrategy(parser); DatashareCliOptions.taskRoutingKey(parser); DatashareCliOptions.pollingInterval(parser); + DatashareCliOptions.taskRepositoryType(parser); return parser; } diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java index cf55a422b..64c3c797a 100644 --- a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java @@ -127,7 +127,7 @@ public final class DatashareCliOptions { public static final String TASK_ROUTING_KEY_OPT = "taskRoutingKey"; public static final String OAUTH_USER_PROJECTS_KEY_OPT = "oauthUserProjectsAttribute"; public static final String POLLING_INTERVAL_OPT = "pollingInterval"; - + public static final String TASK_REPOSITORY_OPT = "taskRepositoryType"; private static final Path DEFAULT_DATASHARE_HOME = Paths.get(System.getProperty("user.home"), ".local/share/datashare"); private static final Integer DEFAULT_NLP_PARALLELISM = 1; @@ -836,6 +836,15 @@ public static void pollingInterval(OptionParser parser) { .ofType(String.class).defaultsTo("60"); } + public static void taskRepositoryType(OptionParser parser) { + parser.acceptsAll( + singletonList(TASK_REPOSITORY_OPT), "type of task repository") + .withRequiredArg() + .ofType( TaskRepositoryType.class ) + .defaultsTo(TaskRepositoryType.REDIS); + } + + public static ValueConverter toAbsolute() { return new ValueConverter() { @Override diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/TaskRepositoryType.java b/datashare-cli/src/main/java/org/icij/datashare/cli/TaskRepositoryType.java new file mode 100644 index 000000000..c5debde08 --- /dev/null +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/TaskRepositoryType.java @@ -0,0 +1,5 @@ +package org.icij.datashare.cli; + +public enum TaskRepositoryType { + REDIS, DATABASE +} diff --git a/datashare-db/src/main/java/org/icij/datashare/db/JooqTaskRepository.java b/datashare-db/src/main/java/org/icij/datashare/db/JooqTaskRepository.java index 214048257..532c283f7 100644 --- a/datashare-db/src/main/java/org/icij/datashare/db/JooqTaskRepository.java +++ b/datashare-db/src/main/java/org/icij/datashare/db/JooqTaskRepository.java @@ -30,7 +30,7 @@ public class JooqTaskRepository implements TaskRepository { private final DataSource connectionProvider; private final SQLDialect dialect; - JooqTaskRepository(final DataSource connectionProvider, final SQLDialect dialect) { + public JooqTaskRepository(final DataSource connectionProvider, final SQLDialect dialect) { this.connectionProvider = connectionProvider; this.dialect = dialect; } diff --git a/datashare-db/src/main/java/org/icij/datashare/db/RepositoryFactoryImpl.java b/datashare-db/src/main/java/org/icij/datashare/db/RepositoryFactoryImpl.java index 0e67a7300..81a741921 100644 --- a/datashare-db/src/main/java/org/icij/datashare/db/RepositoryFactoryImpl.java +++ b/datashare-db/src/main/java/org/icij/datashare/db/RepositoryFactoryImpl.java @@ -78,7 +78,7 @@ private T createRepository(BiFunction constructor return constructor.apply(dataSource, guessSqlDialectFrom(getDataSourceUrl())); } - static SQLDialect guessSqlDialectFrom(String dataSourceUrl) { + public static SQLDialect guessSqlDialectFrom(String dataSourceUrl) { for (SQLDialect dialect: SQLDialect.values()) { if (dataSourceUrl.contains(dialect.name().toLowerCase())) { return dialect; @@ -87,6 +87,9 @@ static SQLDialect guessSqlDialectFrom(String dataSourceUrl) { throw new IllegalArgumentException("unknown SQL dialect for datasource : " + dataSourceUrl); } + public DataSource getDataSource() {return dataSource;} + public SQLDialect guessSqlDialect() {return guessSqlDialectFrom(getDataSourceUrl());} + DataSource createDatasource() { HikariConfig config = new HikariConfig(); String dataSourceUrl = getDataSourceUrl();