Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poll workers of a partitioned step with a database query #4705

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.batch.core;

import java.util.Set;

/**
* Enumeration representing the status of an execution.
*
Expand Down Expand Up @@ -71,6 +73,8 @@ public enum BatchStatus {
*/
UNKNOWN;

public static final Set<BatchStatus> RUNNING_STATUSES = Set.of(STARTING, STARTED, STOPPING);

/**
* Convenience method to return the higher value status of the statuses passed to the
* method.
Expand All @@ -87,7 +91,7 @@ public static BatchStatus max(BatchStatus status1, BatchStatus status2) {
* @return true if the status is STARTING, STARTED, STOPPING
*/
public boolean isRunning() {
return this == STARTING || this == STARTED || this == STOPPING;
return RUNNING_STATUSES.contains(this);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;
import java.util.Set;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobInstance;
import org.springframework.batch.core.JobParameters;
Expand Down Expand Up @@ -87,6 +88,14 @@ default JobInstance getLastJobInstance(String jobName) {
@Nullable
StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable Long stepExecutionId);

/**
* Find {@link StepExecution}s by IDs and parent {@link JobExecution} ID
* @param jobExecutionId given job execution id
* @param stepExecutionIds given step execution ids
* @return collection of {@link StepExecution}
*/
Set<StepExecution> getStepExecutions(Long jobExecutionId, Set<Long> stepExecutionIds);

/**
* @param instanceId {@link Long} The ID for the {@link JobInstance} to obtain.
* @return the {@code JobInstance} that has this ID, or {@code null} if not found.
Expand Down Expand Up @@ -170,4 +179,13 @@ default JobExecution getLastJobExecution(JobInstance jobInstance) {
*/
long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobException;

/**
* Retrieve number of step executions that match the step execution ids and the batch
* statuses
* @param stepExecutionIds given step execution ids
* @param matchingBatchStatuses given batch statuses to match against
* @return number of {@link StepExecution} matching the criteria
*/
long getStepExecutionCount(Set<Long> stepExecutionIds, Set<BatchStatus> matchingBatchStatuses);

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.batch.core.explore.support;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobInstance;
import org.springframework.batch.core.JobParameters;
Expand Down Expand Up @@ -147,6 +148,19 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L
return stepExecution;
}

@Nullable
@Override
public Set<StepExecution> getStepExecutions(Long jobExecutionId, Set<Long> stepExecutionIds) {
JobExecution jobExecution = jobExecutionDao.getJobExecution(jobExecutionId);
if (jobExecution == null) {
return null;
}
getJobExecutionDependencies(jobExecution);
Set<StepExecution> stepExecutions = stepExecutionDao.getStepExecutions(jobExecution, stepExecutionIds);
stepExecutions.forEach(this::getStepExecutionDependencies);
return stepExecutions;
}

@Nullable
@Override
public JobInstance getJobInstance(@Nullable Long instanceId) {
Expand Down Expand Up @@ -180,6 +194,14 @@ public long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcept
return jobInstanceDao.getJobInstanceCount(jobName);
}

@Override
public long getStepExecutionCount(Set<Long> stepExecutionIds, Set<BatchStatus> matchingBatchStatuses) {
if (stepExecutionIds.isEmpty() || matchingBatchStatuses.isEmpty()) {
return 0;
}
return stepExecutionDao.countStepExecutions(stepExecutionIds, matchingBatchStatuses);
}

/**
* @return instance of {@link JobInstanceDao}.
* @since 5.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package org.springframework.batch.core.repository.dao;

import java.sql.Types;
import java.util.Collection;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.JdbcOperations;
Expand Down Expand Up @@ -51,6 +54,14 @@ protected String getQuery(String base) {
return StringUtils.replace(base, "%PREFIX%", tablePrefix);
}

protected String getQuery(String base, Map<String, Collection<?>> collectionParams) {
String query = getQuery(base);
for (Map.Entry<String, Collection<?>> collectionParam : collectionParams.entrySet()) {
query = createParameterizedQuery(query, collectionParam.getKey(), collectionParam.getValue());
}
return query;
}

protected String getTablePrefix() {
return tablePrefix;
}
Expand Down Expand Up @@ -80,6 +91,18 @@ public void setClobTypeToUse(int clobTypeToUse) {
this.clobTypeToUse = clobTypeToUse;
}

/**
* Replaces a given placeholder with a number of parameters (i.e. "?").
* @param sqlTemplate given sql template
* @param placeholder placeholder that is being used for parameters
* @param parameters collection of parameters with variable size
* @return sql query replaced with a number of parameters
*/
private static String createParameterizedQuery(String sqlTemplate, String placeholder, Collection<?> parameters) {
String params = parameters.stream().map(p -> "?").collect(Collectors.joining(", "));
return sqlTemplate.replace(placeholder, params);
}

@Override
public void afterPropertiesSet() throws Exception {
Assert.state(jdbcTemplate != null, "JdbcOperations is required");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -93,6 +96,16 @@ public class JdbcStepExecutionDao extends AbstractJdbcBatchMetadataDao implement

private static final String GET_STEP_EXECUTION = GET_RAW_STEP_EXECUTIONS + " AND STEP_EXECUTION_ID = ?";

private static final String GET_STEP_EXECUTIONS_BY_IDS = GET_RAW_STEP_EXECUTIONS
+ " and STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%)";

private static final String COUNT_STEP_EXECUTIONS_BY_IDS_AND_STATUSES = """
SELECT COUNT(*)
FROM %PREFIX%STEP_EXECUTION SE
WHERE SE.STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%)
AND SE.STATUS IN (%STEP_STATUSES%)
""";

private static final String GET_LAST_STEP_EXECUTION = """
SELECT SE.STEP_EXECUTION_ID, SE.STEP_NAME, SE.START_TIME, SE.END_TIME, SE.STATUS, SE.COMMIT_COUNT, SE.READ_COUNT, SE.FILTER_COUNT, SE.WRITE_COUNT, SE.EXIT_CODE, SE.EXIT_MESSAGE, SE.READ_SKIP_COUNT, SE.WRITE_SKIP_COUNT, SE.PROCESS_SKIP_COUNT, SE.ROLLBACK_COUNT, SE.LAST_UPDATED, SE.VERSION, SE.CREATE_TIME, JE.JOB_EXECUTION_ID, JE.START_TIME, JE.END_TIME, JE.STATUS, JE.EXIT_CODE, JE.EXIT_MESSAGE, JE.CREATE_TIME, JE.LAST_UPDATED, JE.VERSION
FROM %PREFIX%JOB_EXECUTION JE
Expand Down Expand Up @@ -337,6 +350,16 @@ public StepExecution getStepExecution(JobExecution jobExecution, Long stepExecut
}
}

@Override
@Nullable
public Set<StepExecution> getStepExecutions(JobExecution jobExecution, Set<Long> stepExecutionIds) {
List<StepExecution> executions = getJdbcTemplate().query(
getQuery(GET_STEP_EXECUTIONS_BY_IDS, Map.of("%STEP_EXECUTION_IDS%", stepExecutionIds)),
new StepExecutionRowMapper(jobExecution),
Stream.concat(Stream.of(jobExecution.getId()), stepExecutionIds.stream()).toArray(Object[]::new));
return Set.copyOf(executions);
}

@Override
public StepExecution getLastStepExecution(JobInstance jobInstance, String stepName) {
List<StepExecution> executions = getJdbcTemplate().query(getQuery(GET_LAST_STEP_EXECUTION), (rs, rowNum) -> {
Expand All @@ -360,6 +383,16 @@ public StepExecution getLastStepExecution(JobInstance jobInstance, String stepNa
}
}

@Override
public long countStepExecutions(Collection<Long> stepExecutionIds, Collection<BatchStatus> matchingBatchStatuses) {
return getJdbcTemplate().queryForObject(
getQuery(COUNT_STEP_EXECUTIONS_BY_IDS_AND_STATUSES,
Map.of("%STEP_EXECUTION_IDS%", stepExecutionIds, "%STEP_STATUSES%", matchingBatchStatuses)),
Long.class,
Stream.concat(stepExecutionIds.stream(), matchingBatchStatuses.stream().map(BatchStatus::name))
.toArray(Object[]::new));
}

@Override
public void addStepExecutions(JobExecution jobExecution) {
getJdbcTemplate().query(getQuery(GET_STEP_EXECUTIONS), new StepExecutionRowMapper(jobExecution),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobInstance;
import org.springframework.batch.core.StepExecution;
Expand Down Expand Up @@ -95,6 +98,17 @@ public StepExecution getStepExecution(JobExecution jobExecution, Long stepExecut
return stepExecution != null ? this.stepExecutionConverter.toStepExecution(stepExecution, jobExecution) : null;
}

@Override
public Set<StepExecution> getStepExecutions(JobExecution jobExecution, Set<Long> stepExecutionIds) {
Query query = query(where("stepExecutionId").in(stepExecutionIds));
List<org.springframework.batch.core.repository.persistence.StepExecution> stepExecutions = this.mongoOperations
.find(query, org.springframework.batch.core.repository.persistence.StepExecution.class,
STEP_EXECUTIONS_COLLECTION_NAME);
return stepExecutions.stream()
.map(stepExecution -> this.stepExecutionConverter.toStepExecution(stepExecution, jobExecution))
.collect(Collectors.toSet());
}

@Override
public StepExecution getLastStepExecution(JobInstance jobInstance, String stepName) {
// TODO optimize the query
Expand Down Expand Up @@ -160,4 +174,12 @@ public long countStepExecutions(JobInstance jobInstance, String stepName) {
return count;
}

@Override
public long countStepExecutions(Collection<Long> stepExecutionIds, Collection<BatchStatus> matchingBatchStatuses) {
Query query = query(where("jobExecutionId").is(stepExecutionIds).and("status").in(matchingBatchStatuses));
return this.mongoOperations.count(query,
org.springframework.batch.core.repository.persistence.StepExecution.class,
STEP_EXECUTIONS_COLLECTION_NAME);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.springframework.batch.core.repository.dao;

import java.util.Collection;
import java.util.Set;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobInstance;
import org.springframework.batch.core.StepExecution;
Expand Down Expand Up @@ -62,6 +64,15 @@ public interface StepExecutionDao {
@Nullable
StepExecution getStepExecution(JobExecution jobExecution, Long stepExecutionId);

/**
* Get a collection of {@link StepExecution} matching job execution and step execution
* ids.
* @param jobExecution the parent job execution
* @param stepExecutionIds the step execution ids
* @return collection of {@link StepExecution}
*/
Set<StepExecution> getStepExecutions(JobExecution jobExecution, Set<Long> stepExecutionIds);

/**
* Retrieve the last {@link StepExecution} for a given {@link JobInstance} ordered by
* creation time and then id.
Expand Down Expand Up @@ -91,6 +102,15 @@ default long countStepExecutions(JobInstance jobInstance, String stepName) {
throw new UnsupportedOperationException();
}

/**
* Count {@link StepExecution} that match the ids and statuses of them - avoid loading
* them into memory
* @param stepExecutionIds given step execution ids
* @param matchingBatchStatuses
* @return the count of matching steps
*/
long countStepExecutions(Collection<Long> stepExecutionIds, Collection<BatchStatus> matchingBatchStatuses);

/**
* Delete the given step execution.
* @param stepExecution the step execution to delete
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,11 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L
throw new UnsupportedOperationException();
}

@Override
public Set<StepExecution> getStepExecutions(Long jobExecutionId, Set<Long> stepExecutionIds) {
return Set.of();
}

@Override
public List<String> getJobNames() {
throw new UnsupportedOperationException();
Expand Down Expand Up @@ -579,6 +584,11 @@ public long getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcept
}
}

@Override
public long getStepExecutionCount(Set<Long> stepExecutionIds, Set<BatchStatus> matchingBatchStatuses) {
return 0;
}

}

public static class StubJobParametersConverter implements JobParametersConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ void testSaveStepExecutionSetsLastUpdated() {
assertNotNull(stepExecution.getLastUpdated());

LocalDateTime lastUpdated = stepExecution.getLastUpdated();
assertTrue(lastUpdated.isAfter(before));
assertFalse(lastUpdated.isBefore(before));
}

@Test
Expand Down Expand Up @@ -236,7 +236,7 @@ void testUpdateStepExecutionSetsLastUpdated() {
assertNotNull(stepExecution.getLastUpdated());

LocalDateTime lastUpdated = stepExecution.getLastUpdated();
assertTrue(lastUpdated.isAfter(before));
assertFalse(lastUpdated.isBefore(before));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.Step;
import org.springframework.batch.core.StepExecution;
import org.springframework.batch.core.explore.JobExplorer;
Expand Down Expand Up @@ -251,25 +251,20 @@ protected Set<StepExecution> doHandle(StepExecution managerStepExecution,

private Set<StepExecution> pollReplies(final StepExecution managerStepExecution, final Set<StepExecution> split)
throws Exception {
Set<Long> partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());

Callable<Set<StepExecution>> callback = () -> {
JobExecution jobExecution = jobExplorer.getJobExecution(managerStepExecution.getJobExecutionId());
Set<StepExecution> finishedStepExecutions = jobExecution.getStepExecutions()
.stream()
.filter(stepExecution -> partitionStepExecutionIds.contains(stepExecution.getId()))
.filter(stepExecution -> !stepExecution.getStatus().isRunning())
.collect(Collectors.toSet());

if (logger.isDebugEnabled()) {
logger.debug(String.format("Currently waiting on %s partitions to finish", split.size()));
}

if (finishedStepExecutions.size() == split.size()) {
return finishedStepExecutions;
Set<Long> currentStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());
long runningStepExecutions = jobExplorer.getStepExecutionCount(currentStepExecutionIds,
BatchStatus.RUNNING_STATUSES);
if (runningStepExecutions > 0 && !split.isEmpty()) {
if (logger.isDebugEnabled()) {
logger.debug(String.format("Currently waiting on %s out of %s partitions to finish",
runningStepExecutions, split.size()));
}
return null;
}
else {
return null;
return jobExplorer.getStepExecutions(managerStepExecution.getJobExecutionId(), currentStepExecutionIds);
}
};

Expand Down
Loading