Address concurrency/synchronization comment
Signed-off-by: Srikanth Govindarajan <[email protected]>
srikanthjg committed Jan 25, 2025
1 parent 3f70b6c commit ced7a2d
Showing 5 changed files with 343 additions and 38 deletions.
@@ -84,6 +84,11 @@ public static Map<Buffer, CompletableFuture<InvokeResponse>> sendRecords(
        List<Buffer> batchedBuffers = createBufferBatches(records, config.getBatchOptions(),
                outputCodecContext);

+       Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = invokeLambdaAndGetFutureMap(config, lambdaAsyncClient, batchedBuffers);
+       return bufferToFutureMap;
+   }
+
+   public static Map<Buffer, CompletableFuture<InvokeResponse>> invokeLambdaAndGetFutureMap(LambdaCommonConfig config, LambdaAsyncClient lambdaAsyncClient, List<Buffer> batchedBuffers) {
        Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = new HashMap<>();
        LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size());
        for (Buffer buffer : batchedBuffers) {
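For reference, a minimal caller-side sketch of the extracted method (the names config, lambdaAsyncClient, and batchedBuffers stand in for whatever the caller already holds; the blocking join is illustrative, not part of this change):

    // Hedged sketch, not from this commit: invoke Lambda once per batched
    // buffer, then block until every invocation has completed.
    Map<Buffer, CompletableFuture<InvokeResponse>> futureMap =
            LambdaCommonHandler.invokeLambdaAndGetFutureMap(config, lambdaAsyncClient, batchedBuffers);
    CompletableFuture.allOf(futureMap.values().toArray(new CompletableFuture[0]))
            .join(); // throws CompletionException if any invocation failed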
@@ -0,0 +1,156 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.lambda.common.accumlator;

import org.apache.commons.lang3.time.StopWatch;
import org.opensearch.dataprepper.model.codec.OutputCodec;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.sink.OutputCodecContext;
import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec;
import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * A buffer that holds data in memory and flushes it.
 */
public class InMemoryBufferSynchronized implements Buffer {

    private final ByteArrayOutputStream byteArrayOutputStream;

    private final List<Record<Event>> records;
    private final StopWatch bufferWatch;
    private final StopWatch lambdaLatencyWatch;
    private final OutputCodec requestCodec;
    private final OutputCodecContext outputCodecContext;
    private final long payloadResponseSize;
    private int eventCount;
    private long payloadRequestSize;


    public InMemoryBufferSynchronized(String batchOptionKeyName) {
        this(batchOptionKeyName, new OutputCodecContext());
    }

    public InMemoryBufferSynchronized(String batchOptionKeyName, OutputCodecContext outputCodecContext) {
        byteArrayOutputStream = new ByteArrayOutputStream();
        records = Collections.synchronizedList(new ArrayList<>());
        bufferWatch = new StopWatch();
        bufferWatch.start();
        lambdaLatencyWatch = new StopWatch();
        eventCount = 0;
        payloadRequestSize = 0;
        payloadResponseSize = 0;
        // Setup request codec
        JsonOutputCodecConfig jsonOutputCodecConfig = new JsonOutputCodecConfig();
        jsonOutputCodecConfig.setKeyName(batchOptionKeyName);
        requestCodec = new JsonOutputCodec(jsonOutputCodecConfig);
        this.outputCodecContext = outputCodecContext;
    }

    /*
     * Note: JsonCodec is NOT thread safe, so we need to synchronize this method
     */
    @Override
    public synchronized void addRecord(Record<Event> record) {
        records.add(record);
        Event event = record.getData();
        try {
            if (eventCount == 0) {
                requestCodec.start(this.byteArrayOutputStream, event, this.outputCodecContext);
            }
            requestCodec.writeEvent(event, this.byteArrayOutputStream);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        eventCount++;
    }

    @Override
    public List<Record<Event>> getRecords() {
        return records;
    }

    @Override
    public long getSize() {
        return byteArrayOutputStream.size();
    }

    @Override
    public int getEventCount() {
        return eventCount;
    }

    public Duration getDuration() {
        return Duration.ofMillis(bufferWatch.getTime(TimeUnit.MILLISECONDS));
    }

    @Override
    public InvokeRequest getRequestPayload(String functionName, String invocationType) {

        if (eventCount == 0) {
            // We never added any events so there is no payload
            return null;
        }

        try {
            requestCodec.complete(this.byteArrayOutputStream);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        SdkBytes payload = getPayload();
        payloadRequestSize = payload.asByteArray().length;

        // Setup an InvokeRequest.
        InvokeRequest request = InvokeRequest.builder()
                .functionName(functionName)
                .payload(payload)
                .invocationType(invocationType)
                .build();

        synchronized (this) {
            if (lambdaLatencyWatch.isStarted()) {
                lambdaLatencyWatch.reset();
            }
            lambdaLatencyWatch.start();
        }
        return request;
    }

    public synchronized Duration stopLatencyWatch() {
        if (lambdaLatencyWatch.isStarted()) {
            lambdaLatencyWatch.stop();
        }
        long timeInMillis = lambdaLatencyWatch.getTime();
        return Duration.ofMillis(timeInMillis);
    }

    @Override
    public SdkBytes getPayload() {
        byte[] bytes = byteArrayOutputStream.toByteArray();
        return SdkBytes.fromByteArray(bytes);
    }

    public Duration getFlushLambdaLatencyMetric() {
        return Duration.ofMillis(lambdaLatencyWatch.getTime(TimeUnit.MILLISECONDS));
    }

    public Long getPayloadRequestSize() {
        return payloadRequestSize;
    }

}
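For reference, a small self-contained sketch of the intended usage pattern, assuming a handful of producer threads and placeholder key and function names:

    import java.util.Collections;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;

    import org.opensearch.dataprepper.model.event.Event;
    import org.opensearch.dataprepper.model.event.JacksonEvent;
    import org.opensearch.dataprepper.model.record.Record;
    import software.amazon.awssdk.services.lambda.model.InvokeRequest;

    class InMemoryBufferSynchronizedSketch {
        public static void main(String[] args) throws InterruptedException {
            // Several producer threads append to one shared buffer; addRecord is
            // synchronized, so the JsonOutputCodec is never accessed concurrently.
            InMemoryBufferSynchronized buffer = new InMemoryBufferSynchronized("events");
            ExecutorService pool = Executors.newFixedThreadPool(4);
            for (int t = 0; t < 4; t++) {
                pool.submit(() -> {
                    Event event = JacksonEvent.builder()
                            .withData(Collections.singletonMap("message", "hello"))
                            .withEventType("TEST")
                            .build();
                    buffer.addRecord(new Record<>(event));
                });
            }
            pool.shutdown();
            pool.awaitTermination(5, TimeUnit.SECONDS);

            // A single thread finalizes the JSON payload and builds the request;
            // the function name and invocation type here are placeholders.
            InvokeRequest request = buffer.getRequestPayload("my-function", "Event");
        }
    }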

@@ -25,7 +25,7 @@
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
-import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer;
+import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferSynchronized;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck;
@@ -152,7 +152,7 @@ public void doInitialize() {

    private void doInitializeInternal() {
        // Initialize the partial buffer
-       statefulBuffer = new InMemoryBuffer(
+       statefulBuffer = new InMemoryBufferSynchronized(
                lambdaSinkConfig.getBatchOptions().getKeyName(),
                outputCodecContext
        );
@@ -164,15 +164,15 @@ private void doInitializeInternal() {
     * do a time-based flush.
     */
    @Override
-   public synchronized void shutdown() {
+   public void shutdown() {
        // Flush the partial buffer if any leftover
        if (statefulBuffer.getEventCount() > 0) {
            flushBuffers(Collections.singletonList(statefulBuffer));
        }
    }

    @Override
-   public synchronized void doOutput(final Collection<Record<Event>> records) {
+   public void doOutput(final Collection<Record<Event>> records) {
        if (!sinkInitialized) {
            LOG.warn("LambdaSink doOutput called before initialization");
            return;
@@ -193,7 +193,7 @@ public synchronized void doOutput(final Collection<Record<Event>> records) {
            // This buffer is full
            fullBuffers.add(statefulBuffer);
            // Create new partial buffer
-           statefulBuffer = new InMemoryBuffer(
+           statefulBuffer = new InMemoryBufferSynchronized(
                    lambdaSinkConfig.getBatchOptions().getKeyName(),
                    outputCodecContext
            );
@@ -225,7 +225,7 @@ private DlqObject createDlqObjectFromEvent(final Event event,
                .build();
    }

-   synchronized void handleFailure(Collection<Record<Event>> failedRecords, Throwable throwable, int statusCode) {
+   void handleFailure(Collection<Record<Event>> failedRecords, Throwable throwable, int statusCode) {
        if (failedRecords.isEmpty()) {
            return;
        }
@@ -265,7 +265,7 @@ private void releaseEventHandles(Collection<Record<Event>> records, boolean succ
        }
    }

-   private synchronized void flushBuffers(final List<Buffer> buffersToFlush) {
+   private void flushBuffers(final List<Buffer> buffersToFlush) {
        // Combine all their records for a single call to sendRecords
        List<Record<Event>> combinedRecords = new ArrayList<>();
        for (Buffer buf : buffersToFlush) {
Expand All @@ -274,11 +274,10 @@ private synchronized void flushBuffers(final List<Buffer> buffersToFlush) {

        Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap;
        try {
-           bufferToFutureMap = LambdaCommonHandler.sendRecords(
-                   combinedRecords,
+           bufferToFutureMap = LambdaCommonHandler.invokeLambdaAndGetFutureMap(
                    lambdaSinkConfig,
                    lambdaAsyncClient,
-                   outputCodecContext
+                   buffersToFlush
            );
        } catch (Exception e) {
            LOG.error(NOISY, "Error sending buffers to Lambda", e);
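For reference, one plausible continuation after obtaining the future map (this wiring is an assumption; the actual follow-up code is outside this hunk): join each buffer's future and route failures to handleFailure:

    // Assumed follow-up, sketched for illustration only.
    for (Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : bufferToFutureMap.entrySet()) {
        Buffer buffer = entry.getKey();
        try {
            InvokeResponse response = entry.getValue().join(); // block per buffer
            releaseEventHandles(buffer.getRecords(), response.statusCode() == 200);
        } catch (CompletionException e) {
            handleFailure(buffer.getRecords(), e.getCause(), 0); // 0 => no HTTP status available
        }
    }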
@@ -0,0 +1,142 @@
package org.opensearch.dataprepper.plugins.lambda.common.accumulator;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockitoAnnotations;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferSynchronized;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;

import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;


class InMemoryBufferSynchronizedTest {
    private AutoCloseable mocks;

    @BeforeEach
    void setUp() {
        mocks = MockitoAnnotations.openMocks(this);
    }

    @AfterEach
    void tearDown() throws Exception {
        mocks.close();
    }

    @Test
    void testAddRecordAndGetRecords() {
        InMemoryBufferSynchronized buffer = new InMemoryBufferSynchronized("testKey");

        // Initially empty
        assertEquals(0, buffer.getEventCount());
        assertTrue(buffer.getRecords().isEmpty());
        assertEquals(0, buffer.getSize());

        // Add a record
        Event event = createSimpleEvent("hello", 123);
        buffer.addRecord(new Record<>(event));

        assertEquals(1, buffer.getEventCount());
        assertEquals(1, buffer.getRecords().size());
        assertTrue(buffer.getSize() > 0, "ByteArrayOutputStream should have some bytes after writing an event");
    }

    @Test
    void testGetRequestPayloadWhenEmptyReturnsNull() {
        InMemoryBufferSynchronized buffer = new InMemoryBufferSynchronized("testKey");
        // No records added => eventCount=0
        InvokeRequest request = buffer.getRequestPayload("someFunction", "RequestResponse");
        assertNull(request, "Expected null request if no events are in the buffer");
    }

    @Test
    void testGetRequestPayloadNonEmpty() {
        InMemoryBufferSynchronized buffer = new InMemoryBufferSynchronized("testKey");
        buffer.addRecord(new Record<>(createSimpleEvent("k1", 111)));
        buffer.addRecord(new Record<>(createSimpleEvent("k2", 222)));

        // Now we should have 2 events
        assertEquals(2, buffer.getEventCount());

        // getRequestPayload => closes JSON, returns an InvokeRequest
        InvokeRequest request = buffer.getRequestPayload("testFunction", "RequestResponse");
        assertNotNull(request);
        // Should not be null after we finalize
        SdkBytes payload = request.payload();
        assertNotNull(payload);
        // The payload should contain some JSON array with 2 items
        String payloadString = payload.asUtf8String();
        assertTrue(payloadString.contains("\"k1\":\"111\""), "Expected 'k1' field in JSON");
        assertTrue(payloadString.contains("\"k2\":\"222\""), "Expected 'k2' field in JSON");

        // Also, verify the payloadRequestSize is set
        Long requestSize = buffer.getPayloadRequestSize();
        assertNotNull(requestSize);
        assertTrue(requestSize > 0, "Expected a non-zero payload request size");
    }

    @Test
    void testConcurrentAddRecords() throws InterruptedException {
        InMemoryBufferSynchronized buffer = new InMemoryBufferSynchronized("testKey");

        int numThreads = 5;
        int recordsPerThread = 10;
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);

        // Each thread adds 10 records => total 50
        for (int t = 0; t < numThreads; t++) {
            pool.submit(() -> {
                for (int i = 0; i < recordsPerThread; i++) {
                    buffer.addRecord(new Record<>(createSimpleEvent("thread", i)));
                }
            });
        }
        pool.shutdown();
        assertTrue(pool.awaitTermination(5, TimeUnit.SECONDS),
                "Threads did not finish in time");

        // Should now have 50 records
        assertEquals(numThreads * recordsPerThread, buffer.getEventCount());
        assertEquals(numThreads * recordsPerThread, buffer.getRecords().size());

        // ensure we get a JSON array with 50 items
        InvokeRequest request = buffer.getRequestPayload("threadFunction", "RequestResponse");
        String payloadStr = request.payload().asUtf8String();
        // Just check if it has multiple items
        long countOfThread = countOccurrences(payloadStr, "\"thread\":\"");
        assertTrue(countOfThread >= numThreads,
                "Expected multiple 'thread' fields in the JSON payload, found " + countOfThread);
    }

    // Utility to create a simple test event
    private Event createSimpleEvent(String key, int value) {
        // This is just one possible way to create a test Event
        return JacksonEvent.builder()
                .withData(Collections.singletonMap(key, String.valueOf(value)))
                .withEventType("TEST")
                .build();
    }

    // Utility to count occurrences of a substring
    private static long countOccurrences(String haystack, String needle) {
        long count = 0;
        int idx = 0;
        while ((idx = haystack.indexOf(needle, idx)) != -1) {
            count++;
            idx += needle.length();
        }
        return count;
    }
}
