Skip to content

Commit

Permalink
remove and add comment
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Nov 20, 2024
1 parent 7cc0a08 commit b4adbc7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -290,21 +293,21 @@ private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean>
private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetPipelineRequest getPipelineRequest = new GetPipelineRequest();
client.execute(GetPipelineAction.INSTANCE, getPipelineRequest, ActionListener.wrap(ingestPipelineResponse -> {
List<String> allRelevantPipelineIds = findRelevantPipelines(
List<String> allDependentPipelineIds = findDependentPipelines(
ingestPipelineResponse.pipelines(),
modelId,
org.opensearch.ingest.PipelineConfiguration::getConfigAsMap,
org.opensearch.ingest.PipelineConfiguration::getId
);
if (allRelevantPipelineIds.isEmpty()) {
if (allDependentPipelineIds.isEmpty()) {
actionListener.onResponse(true);
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
allRelevantPipelineIds.size()
allDependentPipelineIds.size()
+ " ingest pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
+ Arrays.toString(allDependentPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);
Expand All @@ -320,21 +323,21 @@ private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener
private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetSearchPipelineRequest getSearchPipelineRequest = new GetSearchPipelineRequest();
client.execute(GetSearchPipelineAction.INSTANCE, getSearchPipelineRequest, ActionListener.wrap(searchPipelineResponse -> {
List<String> allRelevantPipelineIds = findRelevantPipelines(
List<String> allDependentPipelineIds = findDependentPipelines(
searchPipelineResponse.pipelines(),
modelId,
org.opensearch.search.pipeline.PipelineConfiguration::getConfigAsMap,
org.opensearch.search.pipeline.PipelineConfiguration::getId
);
if (allRelevantPipelineIds.isEmpty()) {
if (allDependentPipelineIds.isEmpty()) {
actionListener.onResponse(true);
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
allRelevantPipelineIds.size()
allDependentPipelineIds.size()
+ " search pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
+ Arrays.toString(allDependentPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);
Expand Down Expand Up @@ -475,40 +478,57 @@ private Boolean isModelNotDeployed(MLModelState mlModelState) {
&& !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED);
}

private <T> List<String> findRelevantPipelines(
private <T> List<String> findDependentPipelines(
List<T> pipelineConfigurations,
String candidateModelId,
Function<T, Map<String, Object>> getConfigFunction,
Function<T, String> getIdFunction
) {
List<String> relevantPipelineConfigurations = new ArrayList<>();
List<String> dependentPipelineConfigurations = new ArrayList<>();
for (T pipelineConfiguration : pipelineConfigurations) {
Map<String, Object> config = getConfigFunction.apply(pipelineConfiguration);
if (searchThroughConfig(config, candidateModelId, "")) {
relevantPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
dependentPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
}
}
return relevantPipelineConfigurations;
return dependentPipelineConfigurations;
}

// This method is to go through the pipeline configs and only when the key is model id and value is
// 1. String and equal to candidate id 2. A list of String containing candidate id We will return True. Otherwise False
private Boolean searchThroughConfig(Object searchCandidate, String candidateId, String targetModelKey) {
boolean flag = false;
if (searchCandidate instanceof String
&& Objects.equals(targetModelKey, PIPELINE_TARGET_MODEL_KEY)
&& Objects.equals(candidateId, searchCandidate)) {
return true;
} else if (searchCandidate instanceof List<?>) {
for (Object v : (List<?>) searchCandidate) {
flag = flag || searchThroughConfig(v, candidateId, targetModelKey);
}
} else if (searchCandidate instanceof Map<?, ?>) {
for (Map.Entry<String, Object> entry : ((Map<String, Object>) searchCandidate).entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
flag = flag || searchThroughConfig(value, candidateId, key);
// Use a stack to store the elements to be processed
Deque<Pair<String, Object>> stack = new ArrayDeque<>();
stack.push(Pair.of(targetModelKey, searchCandidate));

while (!stack.isEmpty()) {
// Pop an item from the stack
Pair<String, Object> current = stack.pop();
String currentKey = current.getLeft();
Object currentCandidate = current.getRight();

if (currentCandidate instanceof String) {
// Check for a match
if (Objects.equals(currentKey, PIPELINE_TARGET_MODEL_KEY) && Objects.equals(candidateId, currentCandidate)) {
return true;
}
} else if (currentCandidate instanceof List<?>) {
// Push all elements in the list onto the stack
for (Object v : (List<?>) currentCandidate) {
stack.push(Pair.of(currentKey, v));
}
} else if (currentCandidate instanceof Map<?, ?>) {
// Push all values in the map onto the stack
for (Map.Entry<?, ?> entry : ((Map<?, ?>) currentCandidate).entrySet()) {
String key = (String) entry.getKey();
Object value = entry.getValue();
stack.push(Pair.of(key, value));
}
}
}
return flag;

// If no match is found
return false;
}

// this method is only to stub static method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ private void prepare() throws IOException {
return null;
}).when(client).execute(eq(GetSearchPipelineAction.INSTANCE), any(), any());
configDataMap = Map
.of("model_id", "test_id", "list_model_id", List.of("test_list_id"), "test_map_id", Map.of("test_key", "test_map_id"));
.of("single_model_id", "test_id", "list_model_id", List.of("test_id"), "test_map_id", Map.of("model_id", "test_id"));
doAnswer(invocation -> new SearchRequest()).when(agentModelsSearcher).constructQueryRequest(any());

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ interface Factory<T extends Tool> {
*/
String getDefaultVersion();

/**
* Get model id related field names
* @return the list of all model id related field names
*/
List<String> getAllModelKeys();
}
}

0 comments on commit b4adbc7

Please sign in to comment.