Add octet streaming of sketches in MSQ (#16269)
There are a few issues with using Jackson serialization when sending datasketches between the controller and workers in MSQ. In particular, it can cause a memory blowup, because multiple copies of the sketch are held in memory during serialization.

This PR resolves this by serializing and deserializing the sketch payload without Jackson.

The PR adds a new query parameter, "sketchEncoding", used when the controller fetches sketches from a worker (see the illustrative snippet after this list):

    If the value of this parameter is OCTET_STREAM, the sketch is returned as a binary encoding produced by ClusterByStatisticsSnapshotSerde.
    If the value is anything else (or the parameter is absent), the sketch is encoded by Jackson, as before.
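A minimal sketch of how a caller targets the endpoints below with the new parameter. The class and values here are assumed for illustration only, not taken from this PR; the real requests are built by the MSQ controller's worker client, which this excerpt does not show. The path segments mirror the @Path annotations in the diff further down.

import org.apache.druid.java.util.common.StringUtils;

public class SketchEncodingParamExample
{
  public static void main(String[] args)
  {
    // Assumed values for illustration only.
    final String queryId = "query-abc";
    final int stageNumber = 0;

    // Appending sketchEncoding=OCTET_STREAM requests the binary encoding;
    // omitting the parameter keeps the old Jackson behavior.
    final String path = StringUtils.format(
        "/keyStatistics/%s/%d?sketchEncoding=%s",
        queryId,
        stageNumber,
        "OCTET_STREAM"
    );
    System.out.println(path);
  }
}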
adarshsanjeev authored May 28, 2024
1 parent 9d77ef0 commit 21f725f
Showing 21 changed files with 1,171 additions and 37 deletions.
6 changes: 6 additions & 0 deletions benchmarks/pom.xml
@@ -211,6 +211,12 @@
             <version>${project.parent.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.druid.extensions</groupId>
+            <artifactId>druid-multi-stage-query</artifactId>
+            <version>${project.parent.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <properties>
146 changes: 146 additions & 0 deletions benchmarks/src/test/java/org/apache/druid/benchmark/MsqSketchesBenchmark.java
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.benchmark;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.stream.LongStream;

@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class MsqSketchesBenchmark extends InitializedNullHandlingTest
{
  private static final int MAX_BYTES = 1_000_000_000;
  private static final int MAX_BUCKETS = 10_000;

  private static final RowSignature SIGNATURE = RowSignature.builder()
                                                            .add("x", ColumnType.LONG)
                                                            .add("y", ColumnType.LONG)
                                                            .add("z", ColumnType.STRING)
                                                            .build();

  private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy(
      ImmutableList.of(
          new KeyColumn("x", KeyOrder.ASCENDING),
          new KeyColumn("y", KeyOrder.ASCENDING),
          new KeyColumn("z", KeyOrder.ASCENDING)
      ),
      1
  );

  @Param({"1", "1000"})
  private long numBuckets;

  @Param({"100000", "1000000"})
  private long numRows;

  @Param({"true", "false"})
  private boolean aggregate;

  private ObjectMapper jsonMapper;
  private ClusterByStatisticsSnapshot snapshot;

  @Setup(Level.Trial)
  public void setup()
  {
    jsonMapper = TestHelper.makeJsonMapper();
    final Iterable<RowKey> keys = () ->
        LongStream.range(0, numRows)
                  .mapToObj(n -> createKey(numBuckets, n))
                  .iterator();

    final ClusterByStatisticsCollectorImpl collector = makeCollector(aggregate);
    keys.forEach(k -> collector.add(k, 1));
    snapshot = collector.snapshot();
  }

  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.MILLISECONDS)
  public void benchmarkJacksonSketch(Blackhole blackhole) throws IOException
  {
    final byte[] serializedSnapshot = jsonMapper.writeValueAsBytes(snapshot);

    final ClusterByStatisticsSnapshot deserializedSnapshot = jsonMapper.readValue(
        serializedSnapshot,
        ClusterByStatisticsSnapshot.class
    );
    blackhole.consume(deserializedSnapshot);
  }

  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.MILLISECONDS)
  public void benchmarkOctetSketch(Blackhole blackhole) throws IOException
  {
    final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    ClusterByStatisticsSnapshotSerde.serialize(byteArrayOutputStream, snapshot);
    final ByteBuffer serializedSnapshot = ByteBuffer.wrap(byteArrayOutputStream.toByteArray());
    final ClusterByStatisticsSnapshot deserializedSnapshot =
        ClusterByStatisticsSnapshotSerde.deserialize(serializedSnapshot);
    blackhole.consume(deserializedSnapshot);
  }

  private ClusterByStatisticsCollectorImpl makeCollector(final boolean aggregate)
  {
    return (ClusterByStatisticsCollectorImpl) ClusterByStatisticsCollectorImpl.create(
        CLUSTER_BY_XYZ_BUCKET_BY_X,
        SIGNATURE,
        MAX_BYTES,
        MAX_BUCKETS,
        aggregate,
        false
    );
  }

  private static RowKey createKey(final long numBuckets, final long keyNo)
  {
    final Object[] key = new Object[3];
    key[0] = keyNo % numBuckets;
    key[1] = keyNo % 5;
    key[2] = StringUtils.repeat("*", 67);
    return KeyTestUtils.createKey(
        KeyTestUtils.createKeySignature(CLUSTER_BY_XYZ_BUCKET_BY_X.getColumns(), SIGNATURE),
        key
    );
  }
}
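For reference, a minimal sketch of launching this benchmark through the standard JMH runner API. The runner class name is hypothetical, and in practice Druid's benchmarks are typically run through the benchmarks module's own tooling rather than a hand-written main:

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class MsqSketchesBenchmarkRunner
{
  public static void main(String[] args) throws RunnerException
  {
    // Runs both serde benchmarks across all declared @Param combinations.
    final Options opt = new OptionsBuilder()
        .include("MsqSketchesBenchmark")
        .build();
    new Runner(opt).run();
  }
}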
100 changes: 100 additions & 0 deletions extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/SketchResponseHandler.java
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.msq.indexing.client;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder;
import org.apache.druid.java.util.http.client.response.ClientResponse;
import org.apache.druid.java.util.http.client.response.HttpResponseHandler;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.handler.codec.http.HttpChunk;
import org.jboss.netty.handler.codec.http.HttpResponse;

import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import java.nio.ByteBuffer;
import java.util.function.Function;

public class SketchResponseHandler implements HttpResponseHandler<BytesFullResponseHolder, ClusterByStatisticsSnapshot>
{
  private final ObjectMapper jsonMapper;
  private Function<BytesFullResponseHolder, ClusterByStatisticsSnapshot> deserializerFunction;

  public SketchResponseHandler(ObjectMapper jsonMapper)
  {
    this.jsonMapper = jsonMapper;
  }

  @Override
  public ClientResponse<BytesFullResponseHolder> handleResponse(HttpResponse response, HttpResponseHandler.TrafficCop trafficCop)
  {
    final BytesFullResponseHolder holder = new BytesFullResponseHolder(response);
    final String contentType = response.headers().get(HttpHeaders.CONTENT_TYPE);

    // Pick the deserializer based on the Content-Type of the response: the
    // binary serde for octet-stream payloads, Jackson for everything else.
    if (MediaType.APPLICATION_OCTET_STREAM.equals(contentType)) {
      deserializerFunction = responseHolder ->
          ClusterByStatisticsSnapshotSerde.deserialize(ByteBuffer.wrap(responseHolder.getContent()));
    } else {
      deserializerFunction = responseHolder ->
          responseHolder.deserialize(jsonMapper, new TypeReference<ClusterByStatisticsSnapshot>() {});
    }
    holder.addChunk(getContentBytes(response.getContent()));

    return ClientResponse.unfinished(holder);
  }

  @Override
  public ClientResponse<BytesFullResponseHolder> handleChunk(
      ClientResponse<BytesFullResponseHolder> response,
      HttpChunk chunk,
      long chunkNum
  )
  {
    final BytesFullResponseHolder holder = response.getObj();

    if (holder == null) {
      return ClientResponse.finished(null);
    }

    holder.addChunk(getContentBytes(chunk.getContent()));
    return response;
  }

  @Override
  public ClientResponse<ClusterByStatisticsSnapshot> done(ClientResponse<BytesFullResponseHolder> response)
  {
    return ClientResponse.finished(deserializerFunction.apply(response.getObj()));
  }

  @Override
  public void exceptionCaught(ClientResponse<BytesFullResponseHolder> clientResponse, Throwable e)
  {
    // Nothing to clean up on error.
  }

  private byte[] getContentBytes(ChannelBuffer content)
  {
    final byte[] contentBytes = new byte[content.readableBytes()];
    content.readBytes(contentBytes);
    return contentBytes;
  }
}
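A minimal usage sketch under assumed wiring: the handler plugged into Druid's HttpClient. The class and method names here are hypothetical; the actual call site in this PR is the MSQ worker client, which is not part of this excerpt. Because the handler dispatches on the response Content-Type, the same call works whether the worker answers with JSON or the octet encoding:

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.msq.indexing.client.SketchResponseHandler;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.jboss.netty.handler.codec.http.HttpMethod;

import java.net.URL;
import java.util.concurrent.ExecutionException;

public class SketchFetchExample
{
  public static ClusterByStatisticsSnapshot fetchSnapshot(
      HttpClient httpClient,
      ObjectMapper jsonMapper,
      URL keyStatisticsUrl
  ) throws ExecutionException, InterruptedException
  {
    // The handler inspects the Content-Type header and picks the matching
    // deserializer, so no encoding-specific logic is needed at the call site.
    return httpClient.go(
        new Request(HttpMethod.POST, keyStatisticsUrl),
        new SketchResponseHandler(jsonMapper)
    ).get();
  }
}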
@@ -32,11 +32,13 @@
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.WorkOrder;
 import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;
 import org.apache.druid.segment.realtime.firehose.ChatHandler;
 import org.apache.druid.segment.realtime.firehose.ChatHandlers;
 import org.apache.druid.server.security.Action;
 import org.apache.druid.utils.CloseableUtils;
 
+import javax.annotation.Nullable;
 import javax.servlet.http.HttpServletRequest;
 import javax.ws.rs.Consumes;
 import javax.ws.rs.GET;
@@ -185,11 +187,12 @@ public Response httpPostResultPartitionBoundaries(
 
   @POST
   @Path("/keyStatistics/{queryId}/{stageNumber}")
-  @Produces(MediaType.APPLICATION_JSON)
+  @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM})
   @Consumes(MediaType.APPLICATION_JSON)
   public Response httpFetchKeyStatistics(
       @PathParam("queryId") final String queryId,
       @PathParam("stageNumber") final int stageNumber,
+      @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding,
       @Context final HttpServletRequest req
   )
   {
@@ -198,9 +201,17 @@ public Response httpFetchKeyStatistics(
     StageId stageId = new StageId(queryId, stageNumber);
     try {
       clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
-      return Response.status(Response.Status.ACCEPTED)
-                     .entity(clusterByStatisticsSnapshot)
-                     .build();
+      if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) {
+        return Response.status(Response.Status.ACCEPTED)
+                       .type(MediaType.APPLICATION_OCTET_STREAM)
+                       .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot))
+                       .build();
+      } else {
+        return Response.status(Response.Status.ACCEPTED)
+                       .type(MediaType.APPLICATION_JSON)
+                       .entity(clusterByStatisticsSnapshot)
+                       .build();
+      }
     }
     catch (Exception e) {
       String errorMessage = StringUtils.format(
@@ -217,12 +228,13 @@ public Response httpFetchKeyStatistics(
 
   @POST
   @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
-  @Produces(MediaType.APPLICATION_JSON)
+  @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM})
   @Consumes(MediaType.APPLICATION_JSON)
   public Response httpFetchKeyStatisticsWithSnapshot(
       @PathParam("queryId") final String queryId,
       @PathParam("stageNumber") final int stageNumber,
       @PathParam("timeChunk") final long timeChunk,
+      @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding,
       @Context final HttpServletRequest req
   )
   {
@@ -231,9 +243,17 @@ public Response httpFetchKeyStatisticsWithSnapshot(
     StageId stageId = new StageId(queryId, stageNumber);
     try {
       snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
-      return Response.status(Response.Status.ACCEPTED)
-                     .entity(snapshotForTimeChunk)
-                     .build();
+      if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) {
+        return Response.status(Response.Status.ACCEPTED)
+                       .type(MediaType.APPLICATION_OCTET_STREAM)
+                       .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk))
+                       .build();
+      } else {
+        return Response.status(Response.Status.ACCEPTED)
+                       .type(MediaType.APPLICATION_JSON)
+                       .entity(snapshotForTimeChunk)
+                       .build();
+      }
     }
     catch (Exception e) {
       String errorMessage = StringUtils.format(
@@ -289,4 +309,20 @@ public Response httpGetCounters(@Context final HttpServletRequest req)
     ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper());
     return Response.status(Response.Status.OK).entity(worker.getCounters()).build();
   }
+
+  /**
+   * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and
+   * {@link #httpFetchKeyStatisticsWithSnapshot}.
+   */
+  public enum SketchEncoding
+  {
+    /**
+     * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}.
+     */
+    OCTET_STREAM,
+    /**
+     * The key collector is encoded as JSON.
+     */
+    JSON
+  }
 }
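Worth noting for compatibility: an older controller that never sends the parameter yields a null sketchEncoding, which fails the OCTET_STREAM equality check in the endpoints above, so such requests fall back to Jackson. A minimal sketch of that selection logic (a hypothetical helper, not part of the PR):

// Mirrors the check in the endpoints above: a null sketchEncoding
// (parameter absent from the request) selects the JSON content type.
static String chooseResponseType(@Nullable SketchEncoding sketchEncoding)
{
  return SketchEncoding.OCTET_STREAM.equals(sketchEncoding)
         ? MediaType.APPLICATION_OCTET_STREAM
         : MediaType.APPLICATION_JSON;
}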