Skip to content

Commit

Permalink
HTTPCORE-775: The SSLIOSession::write does not handle SSLEngineResult…
Browse files Browse the repository at this point in the history
…#BUFFER_OVERFLOW

Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
reta authored and ok2c committed Jan 14, 2025
1 parent c897f5a commit a4ae27d
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,19 @@
import org.apache.hc.core5.ssl.SSLContextBuilder;

public final class SSLTestContexts {

public static SSLContext createServerSSLContext() {
return createServerSSLContext(null);
}

public static SSLContext createServerSSLContext(final String protocol) {
final URL keyStoreURL = SSLTestContexts.class.getResource("/test.p12");
final String storePassword = "nopassword";
try {
return SSLContextBuilder.create()
.setKeyStoreType("pkcs12")
.loadTrustMaterial(keyStoreURL, storePassword.toCharArray())
.loadKeyMaterial(keyStoreURL, storePassword.toCharArray(), storePassword.toCharArray())
.setProtocol(protocol)
.build();
} catch (final NoSuchAlgorithmException | KeyManagementException | KeyStoreException | CertificateException |
UnrecoverableKeyException | IOException ex) {
Expand All @@ -57,12 +61,17 @@ public static SSLContext createServerSSLContext() {
}

public static SSLContext createClientSSLContext() {
return createClientSSLContext(null);
}

public static SSLContext createClientSSLContext(final String protocol) {
final URL keyStoreURL = SSLTestContexts.class.getResource("/test.p12");
final String storePassword = "nopassword";
try {
return SSLContextBuilder.create()
.setKeyStoreType("pkcs12")
.loadTrustMaterial(keyStoreURL, storePassword.toCharArray())
.setProtocol(protocol)
.build();
} catch (final NoSuchAlgorithmException | KeyManagementException | KeyStoreException | CertificateException |
IOException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.hc.core5.concurrent.Cancellable;
import org.apache.hc.core5.concurrent.CountDownLatchFutureCallback;
Expand All @@ -48,6 +50,7 @@
import org.apache.hc.core5.http.impl.bootstrap.HttpAsyncServer;
import org.apache.hc.core5.http.impl.routing.RequestRouter;
import org.apache.hc.core5.http.nio.AsyncServerExchangeHandler;
import org.apache.hc.core5.http.nio.entity.AsyncEntityProducers;
import org.apache.hc.core5.http.nio.entity.StringAsyncEntityConsumer;
import org.apache.hc.core5.http.nio.entity.StringAsyncEntityProducer;
import org.apache.hc.core5.http.nio.support.BasicClientExchangeHandler;
Expand Down Expand Up @@ -146,6 +149,27 @@ void testSequentialRequests() throws Exception {
assertThat(body3, CoreMatchers.equalTo("some more stuff"));
}

@Test
void testLargeRequest() throws Exception {
final HttpAsyncServer server = serverResource.start();
final Future<ListenerEndpoint> future = server.listen(new InetSocketAddress(0), scheme);
final ListenerEndpoint listener = future.get();
final InetSocketAddress address = (InetSocketAddress) listener.getAddress();
final H2MultiplexingRequester requester = clientResource.start();

final HttpHost target = new HttpHost(scheme.id, "localhost", address.getPort());
final String content = IntStream.range(0, 1000).mapToObj(i -> "a lot of stuff").collect(Collectors.joining(" "));
final Future<Message<HttpResponse, String>> resultFuture = requester.execute(
new BasicRequestProducer(Method.POST, target, "/a-lot-of-stuff", AsyncEntityProducers.create(content, ContentType.TEXT_PLAIN)),
new BasicResponseConsumer<>(new StringAsyncEntityConsumer()), TIMEOUT, null);
final Message<HttpResponse, String> message = resultFuture.get(TIMEOUT.getDuration(), TIMEOUT.getTimeUnit());
assertThat(message, CoreMatchers.notNullValue());
final HttpResponse response = message.getHead();
assertThat(response.getCode(), CoreMatchers.equalTo(HttpStatus.SC_OK));
final String body = message.getBody();
assertThat(body, CoreMatchers.equalTo(content));
}

@Test
void testMultiplexedRequests() throws Exception {
final HttpAsyncServer server = serverResource.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,15 @@ abstract class H2CoreTransportTest extends HttpCoreTransportTest {
private final H2AsyncRequesterResource clientResource;

public H2CoreTransportTest(final URIScheme scheme) {
this(scheme, null);
}

public H2CoreTransportTest(final URIScheme scheme, final String tlsProtocol) {
super(scheme);
this.serverResource = new H2AsyncServerResource();
this.serverResource.configure(bootstrap -> bootstrap
.setVersionPolicy(HttpVersionPolicy.NEGOTIATE)
.setTlsStrategy(new H2ServerTlsStrategy(SSLTestContexts.createServerSSLContext()))
.setTlsStrategy(new H2ServerTlsStrategy(SSLTestContexts.createServerSSLContext(tlsProtocol)))
.setIOReactorConfig(
IOReactorConfig.custom()
.setSoTimeout(TIMEOUT)
Expand All @@ -72,7 +76,7 @@ public H2CoreTransportTest(final URIScheme scheme) {
this.clientResource = new H2AsyncRequesterResource();
this.clientResource.configure(bootstrap -> bootstrap
.setVersionPolicy(HttpVersionPolicy.NEGOTIATE)
.setTlsStrategy(new H2ClientTlsStrategy(SSLTestContexts.createClientSSLContext()))
.setTlsStrategy(new H2ClientTlsStrategy(SSLTestContexts.createClientSSLContext(tlsProtocol)))
.setIOReactorConfig(IOReactorConfig.custom()
.setSoTimeout(TIMEOUT)
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ abstract class Http1CoreTransportTest extends HttpCoreTransportTest {
private final HttpAsyncRequesterResource clientResource;

public Http1CoreTransportTest(final URIScheme scheme) {
this(scheme, null);
}

public Http1CoreTransportTest(final URIScheme scheme, final String tlsProtocol) {
super(scheme);
this.serverResource = new HttpAsyncServerResource();
this.serverResource.configure(bootstrap -> bootstrap
.setTlsStrategy(new H2ServerTlsStrategy(SSLTestContexts.createServerSSLContext()))
.setTlsStrategy(new H2ServerTlsStrategy(SSLTestContexts.createServerSSLContext(tlsProtocol)))
.setIOReactorConfig(
IOReactorConfig.custom()
.setSoTimeout(TIMEOUT)
Expand Down Expand Up @@ -121,7 +125,7 @@ public void pushPromise(
);
this.clientResource = new HttpAsyncRequesterResource();
this.clientResource.configure(bootstrap -> bootstrap
.setTlsStrategy(new H2ClientTlsStrategy(SSLTestContexts.createClientSSLContext()))
.setTlsStrategy(new H2ClientTlsStrategy(SSLTestContexts.createClientSSLContext(tlsProtocol)))
.setIOReactorConfig(IOReactorConfig.custom()
.setSoTimeout(TIMEOUT)
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.HttpHost;
Expand All @@ -47,6 +49,7 @@
import org.apache.hc.core5.http.impl.bootstrap.HttpAsyncServer;
import org.apache.hc.core5.http.message.BasicHttpRequest;
import org.apache.hc.core5.http.nio.AsyncClientEndpoint;
import org.apache.hc.core5.http.nio.entity.AsyncEntityProducers;
import org.apache.hc.core5.http.nio.entity.StringAsyncEntityConsumer;
import org.apache.hc.core5.http.nio.entity.StringAsyncEntityProducer;
import org.apache.hc.core5.http.nio.support.BasicRequestProducer;
Expand Down Expand Up @@ -113,6 +116,27 @@ void testSequentialRequests() throws Exception {
assertThat(body3, CoreMatchers.equalTo("some more stuff"));
}

@Test
void testLargeRequest() throws Exception {
final HttpAsyncServer server = serverStart();
final Future<ListenerEndpoint> future = server.listen(new InetSocketAddress(0), scheme);
final ListenerEndpoint listener = future.get();
final InetSocketAddress address = (InetSocketAddress) listener.getAddress();
final HttpAsyncRequester requester = clientStart();

final HttpHost target = new HttpHost(scheme.id, "localhost", address.getPort());
final String content = IntStream.range(0, 1000).mapToObj(i -> "a lot of stuff").collect(Collectors.joining(" "));
final Future<Message<HttpResponse, String>> resultFuture = requester.execute(
new BasicRequestProducer(Method.POST, target, "/a-lot-of-stuff", AsyncEntityProducers.create(content, ContentType.TEXT_PLAIN)),
new BasicResponseConsumer<>(new StringAsyncEntityConsumer()), TIMEOUT, null);
final Message<HttpResponse, String> message = resultFuture.get(TIMEOUT.getDuration(), TIMEOUT.getTimeUnit());
assertThat(message, CoreMatchers.notNullValue());
final HttpResponse response = message.getHead();
assertThat(response.getCode(), CoreMatchers.equalTo(HttpStatus.SC_OK));
final String body = message.getBody();
assertThat(body, CoreMatchers.equalTo(content));
}

@Test
void testSequentialRequestsNonPersistentConnection() throws Exception {
final HttpAsyncServer server = serverStart();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ public CoreTransportTls() {

}

@Nested
@DisplayName("Core transport (HTTP/1.1, TLSv1.3)")
class CoreTransportTls13 extends Http1CoreTransportTest {

public CoreTransportTls13() {
super(URIScheme.HTTPS, "TLSv1.3");
}

}

@Nested
@DisplayName("Core transport (H2)")
class CoreTransportH2 extends H2CoreTransportTest {
Expand All @@ -73,6 +83,16 @@ public CoreTransportH2Tls() {

}

@Nested
@DisplayName("Core transport (H2, TLSv1.3)")
class CoreTransportH2Tls13 extends H2CoreTransportTest {

public CoreTransportH2Tls13() {
super(URIScheme.HTTPS, "TLSv1.3");
}

}

@Nested
@DisplayName("Core transport (H2, multiplexing)")
class CoreTransportH2Multiplexing extends H2CoreTransportMultiplexingTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ enum TLSHandShakeState { READY, INITIALIZED, HANDSHAKING, COMPLETE }
private final AtomicInteger outboundClosedCount;
private final AtomicReference<TLSHandShakeState> handshakeStateRef;
private final IOEventHandler internalEventHandler;
private final int packetBufferSize;

private int appEventMask;

Expand Down Expand Up @@ -178,9 +179,9 @@ public SSLIOSession(

final SSLSession sslSession = this.sslEngine.getSession();
// Allocate buffers for network (encrypted) data
final int netBufferSize = sslSession.getPacketBufferSize();
this.inEncrypted = SSLManagedBuffer.create(sslBufferMode, netBufferSize);
this.outEncrypted = SSLManagedBuffer.create(sslBufferMode, netBufferSize);
this.packetBufferSize = sslSession.getPacketBufferSize();
this.inEncrypted = SSLManagedBuffer.create(sslBufferMode, packetBufferSize);
this.outEncrypted = SSLManagedBuffer.create(sslBufferMode, packetBufferSize);

// Allocate buffers for application (unencrypted) data
final int appBufferSize = sslSession.getApplicationBufferSize();
Expand Down Expand Up @@ -668,9 +669,18 @@ public int write(final ByteBuffer src) throws IOException {
if (this.handshakeStateRef.get() == TLSHandShakeState.READY) {
return 0;
}
final ByteBuffer outEncryptedBuf = this.outEncrypted.acquire();
final SSLEngineResult result = doWrap(src, outEncryptedBuf);
return result.bytesConsumed();

for (;;) {
final ByteBuffer outEncryptedBuf = this.outEncrypted.acquire();
final SSLEngineResult result = doWrap(src, outEncryptedBuf);
if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
// We don't release the buffer here, it will be expanded (if needed)
// and returned by the next attempt of SSLManagedBuffer#acquire() call.
this.outEncrypted.ensureWriteable(packetBufferSize);
} else {
return result.bytesConsumed();
}
}
} finally {
this.session.getLock().unlock();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,54 @@ abstract class SSLManagedBuffer {
*/
abstract boolean hasData();

/**
* Expands the underlying buffer's to make sure it has enough write capacity to accommodate
* the required amount of bytes. This method has no side effect if the buffer has enough writeable
* capacity left.
* @param size the required write capacity
*/
abstract void ensureWriteable(final int size);

/**
* Helper method to ensure additional writeable capacity with respect to the source buffer. It
* allocates a new buffer and copies all the data if needed, returning the new buffer. This method
* has no side effect if the source buffer has enough writeable capacity left.
* @param src source buffer
* @param size the required write capacity
* @return new buffer (or the source buffer of it has enough writeable capacity left)
*/
ByteBuffer ensureWriteable(final ByteBuffer src, final int size) {
if (src == null) {
// Nothing to do, the buffer is not allocated
return null;
}

// There is not enough capacity left, we need to expand
if (src.remaining() < size) {
final int additionalCapacityNeeded = size - src.remaining();
final ByteBuffer expanded = ByteBuffer.allocate(src.capacity() + additionalCapacityNeeded);

// use a duplicated buffer so we don't disrupt the limit of the original buffer
final ByteBuffer tmp = src.duplicate();
tmp.flip();

// Copy to expanded buffer
expanded.put(tmp);

// Use a new buffer
return expanded;
} else {
return src;
}
}

static SSLManagedBuffer create(final SSLBufferMode mode, final int size) {
return mode == SSLBufferMode.DYNAMIC ? new DynamicBuffer(size) : new StaticBuffer(size);
}

static final class StaticBuffer extends SSLManagedBuffer {

private final ByteBuffer buffer;
private ByteBuffer buffer;

public StaticBuffer(final int size) {
Args.positive(size, "size");
Expand All @@ -90,6 +131,10 @@ public boolean hasData() {
return buffer.position() > 0;
}

@Override
void ensureWriteable(final int size) {
buffer = ensureWriteable(buffer, size);
}
}

static final class DynamicBuffer extends SSLManagedBuffer {
Expand Down Expand Up @@ -126,6 +171,10 @@ public boolean hasData() {
return wrapped != null && wrapped.position() > 0;
}

@Override
void ensureWriteable(final int size) {
wrapped = ensureWriteable(wrapped, size);
}
}

}

0 comments on commit a4ae27d

Please sign in to comment.