Skip to content

Commit

Permalink
feat: implementing other methods for TaskRepository #1626
Browse files Browse the repository at this point in the history
  • Loading branch information
bamthomas committed Jan 17, 2025
1 parent 3b40dae commit 6a28e58
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskRepository;
import org.icij.datashare.db.tables.records.TaskRecord;
import org.jooq.DSLContext;
import org.jooq.InsertValuesStep10;
import org.jooq.SQLDialect;
import org.jooq.impl.DSL;

import javax.sql.DataSource;
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.asynctasks.bus.amqp.Event.MAX_RETRIES_LEFT;
Expand All @@ -33,22 +37,22 @@ public class JooqTaskRepository implements TaskRepository {

@Override
public int size() {
return 0;
}

@Override
public boolean isEmpty() {
return false;
return DSL.using(connectionProvider, dialect).selectCount().from(TASK).fetchOne(0, Integer.class);
}

@Override
public boolean containsKey(Object o) {
return false;
DSLContext ctx = DSL.using(connectionProvider, dialect);
return ctx.fetchExists(
ctx.selectOne()
.from(TASK)
.where(TASK.ID.eq((String) o))
);
}

@Override
public boolean containsValue(Object o) {
return false;
return containsKey(((Task<?>) o).getId());
}

@Override
Expand All @@ -59,66 +63,83 @@ public Task<?> get(Object o) {

@Override
public Task<?> put(String s, Task<?> task) {
try {
if (task == null || s == null || !s.equals(task.getId())) {
throw new IllegalArgumentException(String.format("task is null or its id (%s) is different than the key (%s)", ofNullable(task).map(Task::getId).orElse(null), s));
}
final int inserted = DSL.using(connectionProvider, dialect)
.insertInto(TASK).columns(
TASK.ID, TASK.NAME, TASK.STATE, TASK.USER_ID, TASK.GROUP_ID, TASK.PROGRESS,
TASK.CREATED_AT, TASK.RETRIES_LEFT, TASK.MAX_RETRIES, TASK.ARGS)
.values(task.id, task.name,
task.getState().name(),
ofNullable(task.getUser()).map(u -> u.id).orElse(null),
ofNullable(task.getGroup()).map(Group::id).orElse(null),
task.getProgress(),
new Timestamp(task.createdAt.getTime()).toLocalDateTime(), task.getRetriesLeft(),
MAX_RETRIES_LEFT, new ObjectMapper().writeValueAsString(task.args)).execute();
return inserted == 1 ? task : null;
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
if (task == null || s == null || !s.equals(task.getId())) {
throw new IllegalArgumentException(String.format("task is null or its id (%s) is different than the key (%s)", ofNullable(task).map(Task::getId).orElse(null), s));
}
InsertValuesStep10<TaskRecord, String, String, String, String, String, Double, LocalDateTime, Integer, Integer, String> insertInto = insert();
insertValues(task, insertInto);
int inserted = insertInto.execute();
return inserted == 1 ? task : null;
}

@Override
public Task<?> remove(Object o) {
return null;
public Task<?> remove(Object key) {
return createTaskFrom(DSL.using(connectionProvider, dialect).deleteFrom(TASK).where(TASK.ID.eq((String) key)).returning().fetchOne());
}

@Override
public void putAll(Map<? extends String, ? extends Task<?>> map) {

ofNullable(map).orElseThrow(() -> new IllegalArgumentException("task(s) map is null"));
InsertValuesStep10<TaskRecord, String, String, String, String, String, Double, LocalDateTime, Integer, Integer, String> insert = insert();
map.values().forEach(t -> insertValues(t, insert));
insert.execute();
}

@Override
public void clear() {

DSL.using(connectionProvider, dialect).deleteFrom(TASK).execute();
}

@Override
public Set<String> keySet() {
return Set.of();
return DSL.using(connectionProvider, dialect).select(TASK.ID).from(TASK)
.stream().map(r -> r.field1().getValue(r)).collect(Collectors.toSet());
}

@Override
public Collection<Task<?>> values() {
return List.of();
return DSL.using(connectionProvider, dialect).selectFrom(TASK).stream()
.map(this::createTaskFrom).collect(Collectors.toList());
}

@Override
public Set<Entry<String, Task<?>>> entrySet() {
return Set.of();
return DSL.using(connectionProvider, dialect).selectFrom(TASK).stream()
.map(t -> new AbstractMap.SimpleEntry<String, Task<?>>(t.getId(), createTaskFrom(t)))
.collect(Collectors.toSet());
}

private Task<?> createTaskFrom(TaskRecord taskRecord) {
return ofNullable(taskRecord).map(r ->
{
try {
return new Task<>(r.getId(), r.getName(), Task.State.valueOf(r.getState()),
r.getProgress(), null, new ObjectMapper().readValue(r.getArgs(), new TypeReference<HashMap<String, Object>>(){}));
r.getProgress(), null, new ObjectMapper().readValue(r.getArgs(), new TypeReference<HashMap<String, Object>>() {
}));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}).orElse(null);
}

private InsertValuesStep10<TaskRecord, String, String, String, String, String, Double, LocalDateTime, Integer, Integer, String> insert() {
return DSL.using(connectionProvider, dialect)
.insertInto(TASK).columns(
TASK.ID, TASK.NAME, TASK.STATE, TASK.USER_ID, TASK.GROUP_ID, TASK.PROGRESS,
TASK.CREATED_AT, TASK.RETRIES_LEFT, TASK.MAX_RETRIES, TASK.ARGS);
}

private static void insertValues(Task<?> t, InsertValuesStep10<TaskRecord, String, String, String, String, String, Double, LocalDateTime, Integer, Integer, String> insert) {
try {
insert.values(t.id, t.name,
t.getState().name(),
ofNullable(t.getUser()).map(u -> u.id).orElse(null),
ofNullable(t.getGroup()).map(Group::id).orElse(null),
t.getProgress(),
new Timestamp(t.createdAt.getTime()).toLocalDateTime(), t.getRetriesLeft(),
MAX_RETRIES_LEFT, new ObjectMapper().writeValueAsString(t.args));
} catch (JsonProcessingException ex) {
throw new RuntimeException(ex);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.test.DatashareTimeRule;
import org.icij.datashare.user.User;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -22,6 +23,14 @@ public class JooqTaskRepositoryTest {
public DbSetupRule dbRule;
private final JooqTaskRepository repository;

@Parameterized.Parameters
public static Collection<Object[]> dataSources() {
return asList(new Object[][]{
{new DbSetupRule("jdbc:sqlite:file:memorydb.db?mode=memory&cache=shared")},
{new DbSetupRule("jdbc:postgresql://postgres/dstest?user=dstest&password=test")}
});
}

@Test(expected = IllegalArgumentException.class)
public void test_put_with_key_different_than_id() {
assertThat(repository.put("my_key", new Task<>("foo", User.local(), Map.of())));
Expand All @@ -47,15 +56,89 @@ public void test_put_get() {
assertThat(repository.get(foo.getId())).isEqualTo(foo); // but equals as defined by Task
}

@Parameterized.Parameters
public static Collection<Object[]> dataSources() {
return asList(new Object[][]{
{new DbSetupRule("jdbc:sqlite:file:memorydb.db?mode=memory&cache=shared")},
{new DbSetupRule("jdbc:postgresql://postgres/dstest?user=dstest&password=test")}
});
@Test
public void test_size() {
assertThat(repository.size()).isEqualTo(0);
repository.save(new Task<>("foo", User.local(), Map.of()));
assertThat(repository.size()).isEqualTo(1);
}

@Test
public void test_empty() {
assertThat(repository.isEmpty()).isTrue();
repository.save(new Task<>("foo", User.local(), Map.of()));
assertThat(repository.isEmpty()).isFalse();
}

@Test
public void test_contains_key() {
Task<Object> foo = new Task<>("foo", User.local(), Map.of());
assertThat(repository.containsKey(foo.getId())).isFalse();
repository.save(foo);
assertThat(repository.containsKey(foo.getId())).isTrue();
}

@Test
public void test_contains_value() {
Task<Object> foo = new Task<>("foo", User.local(), Map.of());
assertThat(repository.containsValue(foo)).isFalse();
repository.save(foo);
assertThat(repository.containsValue(foo)).isTrue();
}

@Test
public void test_remove() {
Task<Object> foo = new Task<>("foo", User.local(), Map.of());
assertThat(repository.remove(foo.getId())).isNull();
repository.save(foo);
assertThat(repository.remove(foo.getId())).isEqualTo(foo);
assertThat(repository.isEmpty()).isTrue();
}

@Test(expected = IllegalArgumentException.class)
public void test_putAll_null() {
repository.putAll(null);
}

@Test
public void test_putAll() {
Map<String, Task<?>> map = Map.of("id_foo", new Task<>("id_foo", "foo", User.local(), Map.of()), "id_bar", new Task<>("id_bar", "bar", User.local(), Map.of()));
assertThat(repository.isEmpty()).isTrue();

repository.putAll(map);

assertThat(repository.size()).isEqualTo(2);
assertThat(repository.containsKey("id_foo")).isTrue();
assertThat(repository.containsKey("id_bar")).isTrue();
}

@Test
public void test_keySet() {
repository.putAll(Map.of("id_foo", new Task<>("id_foo", "foo", User.local(), Map.of()), "id_bar", new Task<>("id_bar", "bar", User.local(), Map.of())));
assertThat(repository.keySet()).containsOnly("id_foo", "id_bar");
}

@Test
public void test_values() {
Task<Object> taskFoo = new Task<>("id_foo", "foo", User.local(), Map.of());
Task<Object> taskBar = new Task<>("id_bar", "bar", User.local(), Map.of());
repository.putAll(Map.of(taskFoo.getId(), taskFoo, taskBar.getId(), taskBar));
assertThat(repository.values()).containsOnly(taskFoo, taskBar);
}

@Test
public void test_entrySet() {
Task<Object> taskFoo = new Task<>("id_foo", "foo", User.local(), Map.of());
Task<Object> taskBar = new Task<>("id_bar", "bar", User.local(), Map.of());
repository.putAll(Map.of(taskFoo.getId(), taskFoo, taskBar.getId(), taskBar));
assertThat(repository.entrySet().stream().map(Map.Entry::getKey).toList()).containsOnly(taskFoo.getId(), taskBar.getId());
assertThat(repository.entrySet().stream().map(Map.Entry::getValue).toList()).containsOnly(taskFoo, taskBar);
}

@After
public void tearDown() throws Exception {
repository.clear();
}

public JooqTaskRepositoryTest(DbSetupRule rule) {
dbRule = rule;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ default Task<?> save(Task<?> task) {
default Task<?> get(String id) {
return get((Object)id);
}

default boolean isEmpty() {
return size() == 0;
}
}

0 comments on commit 6a28e58

Please sign in to comment.