Skip to content

Commit

Permalink
Merge pull request #448 from rubanm/rubanm/check_read_length
Browse files Browse the repository at this point in the history
Add container size check in ThriftBinaryProtocol
  • Loading branch information
rubanm committed Sep 2, 2015
2 parents 5657c75 + e882509 commit fb94ed1
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 0 deletions.
5 changes: 5 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ant</groupId>
<artifactId>ant</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TType;
Expand Down Expand Up @@ -58,6 +59,7 @@ public static void checkContainerElemType(byte type) throws TException {
@Override
public TMap readMapBegin() throws TException {
TMap map = super.readMapBegin();
checkContainerSize(map.size);
checkContainerElemType(map.keyType);
checkContainerElemType(map.valueType);
return map;
Expand All @@ -66,17 +68,39 @@ public TMap readMapBegin() throws TException {
@Override
public TList readListBegin() throws TException {
TList list = super.readListBegin();
checkContainerSize(list.size);
checkContainerElemType(list.elemType);
return list;
}

@Override
public TSet readSetBegin() throws TException {
TSet set = super.readSetBegin();
checkContainerSize(set.size);
checkContainerElemType(set.elemType);
return set;
}

/**
* Check if the container size if valid.
*
* NOTE: This assumes that the elements are one byte each.
* So this does not catch all cases, but does increase the chances of
* handling malformed lengths when the number of remaining bytes in
* the underlying Transport is clearly less than the container size
* that the Transport provides.
*/
protected void checkContainerSize(int size) throws TProtocolException {
if (size < 0) {
throw new TProtocolException("Negative container size: " + size);
}
if (checkReadLength_) {
if ((readLength_ - size) < 0) {
throw new TProtocolException("Remaining message length is " + readLength_ + " but container size in underlying TTransport is set to at least: " + size);
}
}
}

public static class Factory implements TProtocolFactory {

public TProtocol getProtocol(TTransport trans) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package com.twitter.elephantbird.thrift;

import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.thrift.TException;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.easymock.IAnswer;
import org.easymock.EasyMock;
import org.junit.Test;

import static org.easymock.EasyMock.anyInt;
import static org.easymock.EasyMock.createStrictMock;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.isA;
import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.verify;

public class TestThriftBinaryProtocol {

int METADATA_BYTES = 5; // type(1) + size(4)
int MAP_METADATA_BYTES = 6; // key type(1) + value type(1) + size(4)

// helper method to set container size correctly in the supplied byte array
protected void setContainerSize(byte[] buf, int n) {
byte[] b = ByteBuffer.allocate(4).putInt(n).array();
for (int i = 0; i < 4; i++) {
buf[i] = b[i];
}
}

protected void setDataType(byte[] buf) {
buf[0] = TType.BYTE;
}

// mock transport for Set and List container types
private TTransport getMockTransport(final int containerSize) throws TException {
TTransport transport = createStrictMock(TTransport.class);
// not using buffered mode for tests, so return -1 per the contract
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// first call, set data type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// second call, set container size
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setContainerSize(buf, containerSize);
return 4;
}
}
);
return transport;
}

// mock transport for Map container type
private TTransport getMockMapTransport(final int containerSize) throws TException {
TTransport transport = createStrictMock(TTransport.class);
// not using buffered mode for tests, so return -1 per the contract
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// first call, set key type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// second call, set value type
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setDataType(buf);
return 1;
}
}
);
expect(transport.getBytesRemainingInBuffer()).andReturn(-1);
// third call, set container size
expect(transport.readAll(isA(byte[].class), anyInt(), anyInt()))
.andAnswer(
new IAnswer<Integer>() {
public Integer answer() {
byte[] buf = (byte[])(EasyMock.getCurrentArguments()[0]);
setContainerSize(buf, containerSize);
return 4;
}
}
);
return transport;
}

@Test
public void testCheckContainerSizeValid() throws TException {
// any non-negative value is considered valid when checkReadLength is not enabled
TTransport transport;
ThriftBinaryProtocol protocol;

transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readListBegin();
verify(transport);

transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readSetBegin();
verify(transport);

transport = getMockMapTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.readMapBegin();
verify(transport);
}

@Test
public void testCheckContainerSizeValidWhenCheckReadLength() throws TException {
TTransport transport;
ThriftBinaryProtocol protocol;

transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.setReadLength(METADATA_BYTES + 3);
protocol.readListBegin();
verify(transport);

transport = getMockTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.setReadLength(METADATA_BYTES + 3);
protocol.readSetBegin();
verify(transport);

transport = getMockMapTransport(3);
replay(transport);
protocol = new ThriftBinaryProtocol(transport);
protocol.setReadLength(MAP_METADATA_BYTES + 3);
protocol.readMapBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckListContainerSizeInvalid() throws TException {
// any negative value is considered invalid when checkReadLength is not enabled
TTransport transport = getMockTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readListBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckSetContainerSizeInvalid() throws TException {
TTransport transport = getMockTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readSetBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckMapContainerSizeInvalid() throws TException {
TTransport transport = getMockMapTransport(-1);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.readMapBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckListContainerSizeInvalidWhenCheckReadLength() throws TException {
TTransport transport = getMockTransport(400);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
protocol.setReadLength(METADATA_BYTES + 3);
// this throws because size returned by Transport (400) > size per readLength (3)
protocol.readListBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckSetContainerSizeInvalidWhenCheckReadLength() throws TException {
TTransport transport = getMockTransport(400);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
// this throws because size returned by Transport (400) > size per readLength (3)
protocol.setReadLength(METADATA_BYTES + 3);
protocol.readSetBegin();
verify(transport);
}

@Test(expected=TProtocolException.class)
public void testCheckMapContainerSizeInvalidWhenCheckReadLength() throws TException {
TTransport transport = getMockMapTransport(400);
replay(transport);
ThriftBinaryProtocol protocol = new ThriftBinaryProtocol(transport);
// this throws because size returned by Transport (400) > size per readLength (3)
protocol.setReadLength(MAP_METADATA_BYTES + 3);
protocol.readMapBegin();
verify(transport);
}
}

0 comments on commit fb94ed1

Please sign in to comment.