Skip to content

Commit

Permalink
increase AbstractRetrieverToolTests code coverage (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl authored Dec 27, 2023
1 parent 0ccb019 commit 8f746aa
Showing 1 changed file with 60 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@

package org.opensearch.agent.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.opensearch.agent.tools.AbstractRetrieverTool.DEFAULT_DESCRIPTION;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
Expand All @@ -23,6 +19,7 @@

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -32,6 +29,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchModule;

import lombok.SneakyThrows;
Expand Down Expand Up @@ -180,4 +178,59 @@ public void testValidate() {
assertFalse(mockedImpl.validate(new HashMap<>()));
assertFalse(mockedImpl.validate(null));
}

@Test
public void testGetAttributes() {
assertEquals(mockedImpl.getVersion(), null);
assertEquals(mockedImpl.getIndex(), TEST_INDEX);
assertEquals(mockedImpl.getDocSize(), TEST_DOC_SIZE);
assertEquals(mockedImpl.getSourceFields(), TEST_SOURCE_FIELDS);
assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY);
}

@Test
public void testGetQueryBodySuccess() {
assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY);
}

@Test
@SneakyThrows
public void testRunWithRuntimeException() {
Client client = mock(Client.class);
mockedImpl.setClient(client);
ActionListener listener = mock(ActionListener.class);
doAnswer(invocation -> {
SearchRequest searchRequest = invocation.getArgument(0);
assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size());
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new RuntimeException("Failed to search index"));
return null;
}).when(client).search(any(), any());
mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener);
verify(listener).onFailure(any(RuntimeException.class));
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(argumentCaptor.capture());
assertEquals("Failed to search index", argumentCaptor.getValue().getMessage());
}

@Test
public void testFactory() {
// Create a mock object of the abstract Factory class
Client client = mock(Client.class);
AbstractRetrieverTool.Factory<Tool> factoryMock = new AbstractRetrieverTool.Factory<>() {
public PPLTool create(Map<String, Object> params) {
return null;
}
};

factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY);

assertNotNull(factoryMock.client);
assertNotNull(factoryMock.xContentRegistry);
assertEquals(client, factoryMock.client);
assertEquals(TEST_XCONTENT_REGISTRY_FOR_QUERY, factoryMock.xContentRegistry);

String defaultDescription = factoryMock.getDefaultDescription();
assertEquals(DEFAULT_DESCRIPTION, defaultDescription);
}
}

0 comments on commit 8f746aa

Please sign in to comment.