Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve HTTP validation for inlining and specialized types #5446

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 169 additions & 40 deletions vertx-core/src/main/java/io/vertx/core/http/impl/HttpUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -604,67 +604,181 @@ public static void validateHeader(CharSequence name, Iterable<? extends CharSequ
});
}

public static void validateHeaderValue(CharSequence seq) {
public static void validateHeaderValue(CharSequence value) {
if (value instanceof AsciiString) {
validateAsciiHeaderValue((AsciiString) value);
} else if (value instanceof String) {
validateStringHeaderValue((String) value);
} else {
validateSequenceHeaderValue(value);
}
}

private static void validateAsciiHeaderValue(AsciiString value) {
final int length = value.length();
if (length == 0) {
return;
}
byte[] asciiChars = value.array();
int off = value.arrayOffset();
if (off == 0 && length == asciiChars.length) {
for (int index = 0; index < asciiChars.length; index++) {
int latinChar = asciiChars[index] & 0xFF;
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index - off);
break;
}
}
} else {
validateAsciiRangeHeaderValue(value, off, length, asciiChars);
}
}

/**
* This method is the slow-path generic version of {@link #validateAsciiHeaderValue(AsciiString)} which
* is optimized for {@link AsciiString} instances which are backed by a 0-offset full-blown byte array.
*/
private static void validateAsciiRangeHeaderValue(AsciiString value, int off, int length, byte[] asciiChars) {
int end = off + length;
for (int index = off; index < end; index++) {
int latinChar = asciiChars[index] & 0xFF;
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index - off);
break;
}
}
}

int state = 0;
// Start looping through each of the character
for (int index = 0; index < seq.length(); index++) {
state = validateValueChar(seq, state, seq.charAt(index));
private static void validateStringHeaderValue(String value) {
final int length = value.length();
if (length == 0) {
return;
}

if (state != 0) {
throw new IllegalArgumentException("a header value must not end with '\\r' or '\\n':" + seq);
for (int index = 0; index < length; index++) {
char latinChar = value.charAt(index);
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index);
break;
}
}
}

private static void validateSequenceHeaderValue(CharSequence value) {
final int length = value.length();
if (length == 0) {
return;
}

for (int index = 0; index < length; index++) {
char latinChar = value.charAt(index);
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index);
break;
}
}
}

private static final int HIGHEST_INVALID_VALUE_CHAR_MASK = ~0x1F;
private static final int NO_CR_LF_STATE = 0;
private static final int CR_STATE = 1;
private static final int LF_STATE = 2;

/**
* This method is taken as we need to validate the header value for the non-printable characters.
*/
private static void validateSequenceHeaderValue(CharSequence seq, int index) {
// we already expect the very-first character to be non-printable
int state = validateValueChar(seq, NO_CR_LF_STATE, seq.charAt(index));
for (int i = index + 1; i < seq.length(); i++) {
state = validateValueChar(seq, state, seq.charAt(i));
}
if (state != NO_CR_LF_STATE) {
throw new IllegalArgumentException("a header value must not end with '\\r' or '\\n':" + seq);
}
}

private static int validateValueChar(CharSequence seq, int state, char character) {
private static int validateValueChar(CharSequence seq, int state, char ch) {
/*
* State:
* 0: Previous character was neither CR nor LF
* 1: The previous character was CR
* 2: The previous character was LF
*/
if ((character & HIGHEST_INVALID_VALUE_CHAR_MASK) == 0 || character == 0x7F) { // 0x7F is "DEL".
// The only characters allowed in the range 0x00-0x1F are : HTAB, LF and CR
switch (character) {
case 0x09: // Horizontal tab - HTAB
case 0x0a: // Line feed - LF
case 0x0d: // Carriage return - CR
break;
default:
throw new IllegalArgumentException("a header value contains a prohibited character '" + (int) character + "': " + seq);
if (ch == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + seq);
}
if ((ch & HIGHEST_INVALID_VALUE_CHAR_MASK) == 0) {
// this is a rare scenario
validateNonPrintableCtrlChar(seq, ch);
// this can include LF and CR as they are non-printable characters
if (state == NO_CR_LF_STATE) {
// Check the CRLF (HT | SP) pattern
switch (ch) {
case '\r':
return CR_STATE;
case '\n':
return LF_STATE;
}
return NO_CR_LF_STATE;
}
}
if (state != NO_CR_LF_STATE) {
// this is a rare scenario
return validateCrLfChar(seq, state, ch);
} else {
return NO_CR_LF_STATE;
}
}

// Check the CRLF (HT | SP) pattern
private static int validateCrLfChar(CharSequence seq, int state, char ch) {
switch (state) {
case 0:
switch (character) {
case '\r':
return 1;
case '\n':
return 2;
}
break;
case 1:
switch (character) {
case '\n':
return 2;
default:
throw new IllegalArgumentException("only '\\n' is allowed after '\\r': " + seq);
case CR_STATE:
if (ch == '\n') {
return LF_STATE;
}
case 2:
switch (character) {
throw new IllegalArgumentException("only '\\n' is allowed after '\\r': " + seq);
case LF_STATE:
switch (ch) {
case '\t':
case ' ':
return 0;
// return to the normal state
return NO_CR_LF_STATE;
default:
throw new IllegalArgumentException("only ' ' and '\\t' are allowed after '\\n': " + seq);
}
default:
// this should never happen
throw new AssertionError();
}
}

private static void validateNonPrintableCtrlChar(CharSequence seq, int ch) {
// The only characters allowed in the range 0x00-0x1F are : HTAB, LF and CR
switch (ch) {
case 0x09: // Horizontal tab - HTAB
case 0x0a: // Line feed - LF
case 0x0d: // Carriage return - CR
break;
default:
throw new IllegalArgumentException("a header value contains a prohibited character '" + (int) ch + "': " + seq);
}
return state;
}

private static final boolean[] VALID_H_NAME_ASCII_CHARS;
Expand Down Expand Up @@ -700,13 +814,15 @@ private static int validateValueChar(CharSequence seq, int state, char character
public static void validateHeaderName(CharSequence value) {
if (value instanceof AsciiString) {
// no need to check for ASCII-ness anymore
validateHeaderName((AsciiString) value);
validateAsciiHeaderName((AsciiString) value);
} else if(value instanceof String) {
validateStringHeaderName((String) value);
} else {
validateHeaderName0(value);
validateSequenceHeaderName(value);
}
}

private static void validateHeaderName(AsciiString value) {
private static void validateAsciiHeaderName(AsciiString value) {
final int len = value.length();
final int off = value.arrayOffset();
final byte[] asciiChars = value.array();
Expand All @@ -722,7 +838,20 @@ private static void validateHeaderName(AsciiString value) {
}
}

private static void validateHeaderName0(CharSequence value) {
private static void validateStringHeaderName(String value) {
for (int i = 0; i < value.length(); i++) {
final char c = value.charAt(i);
// Check to see if the character is not an ASCII character, or invalid
if (c > 0x7f) {
throw new IllegalArgumentException("a header name cannot contain non-ASCII character: " + value);
}
if (!VALID_H_NAME_ASCII_CHARS[c & 0x7F]) {
throw new IllegalArgumentException("a header name cannot contain some prohibited characters, such as : " + value);
}
}
}

private static void validateSequenceHeaderName(CharSequence value) {
for (int i = 0; i < value.length(); i++) {
final char c = value.charAt(i);
// Check to see if the character is not an ASCII character, or invalid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
@Threads(1)
@BenchmarkMode(Mode.Throughput)
@Fork(value = 1, jvmArgs = {
"-XX:+UseBiasedLocking",
"-XX:BiasedLockingStartupDelay=0",
"-XX:+AggressiveOpts",
"-Djmh.executor=CUSTOM",
"-Djmh.executor.class=io.vertx.benchmarks.VertxExecutorService"
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package io.vertx.benchmarks;

import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import io.netty.handler.codec.DefaultHeaders;
import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
import io.vertx.core.http.impl.HttpUtils;

@State(Scope.Thread)
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 400, timeUnit = TimeUnit.MILLISECONDS)
public class HeadersValidationBenchmark extends BenchmarkBase {

@Param({ "true", "false" })
public boolean ascii;

private CharSequence[] headerNames;
private CharSequence[] headerValues;
private static final DefaultHeaders.NameValidator<CharSequence> nettyNameValidator = DefaultHttpHeadersFactory.headersFactory().getNameValidator();
private static final DefaultHeaders.ValueValidator<CharSequence> nettyValueValidator = DefaultHttpHeadersFactory.headersFactory().getValueValidator();
private static final Consumer<CharSequence> vertxNameValidator = HttpUtils::validateHeaderName;
private static final Consumer<CharSequence> vertxValueValidator = HttpUtils::validateHeaderValue;

@Setup
public void setup() {
headerNames = new CharSequence[4];
headerValues = new CharSequence[4];
headerNames[0] = io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
headerValues[0] = io.vertx.core.http.HttpHeaders.createOptimized("text/plain");
headerNames[1] = io.vertx.core.http.HttpHeaders.CONTENT_LENGTH;
headerValues[1] = io.vertx.core.http.HttpHeaders.createOptimized("20");
headerNames[2] = io.vertx.core.http.HttpHeaders.SERVER;
headerValues[2] = io.vertx.core.http.HttpHeaders.createOptimized("vert.x");
headerNames[3] = io.vertx.core.http.HttpHeaders.DATE;
headerValues[3] = io.vertx.core.http.HttpHeaders.createOptimized(HeadersUtils.DATE_FORMAT.format(new java.util.Date(0)));
if (!ascii) {
for (int i = 0; i < headerNames.length; i++) {
headerNames[i] = headerNames[i].toString();
}
}
if (!ascii) {
for (int i = 0; i < headerValues.length; i++) {
headerValues[i] = headerValues[i].toString();
}
}
}

@Benchmark
public void validateNameNetty() {
for (CharSequence headerName : headerNames) {
nettyNameValidation(headerName);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void nettyNameValidation(CharSequence headerName) {
nettyNameValidator.validateName(headerName);
}

@Benchmark
public void validateNameVertx() {
for (CharSequence headerName : headerNames) {
vertxNameValidation(headerName);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void vertxNameValidation(CharSequence headerName) {
vertxNameValidator.accept(headerName);
}


@Benchmark
public void validateValueNetty() {
for (CharSequence headerValue : headerValues) {
nettyValueValidation(headerValue);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void nettyValueValidation(CharSequence headerValue) {
nettyValueValidator.validate(headerValue);
}

@Benchmark
public void validateValueVertx() {
for (CharSequence headerValue : headerValues) {
vertxValueValidation(headerValue);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void vertxValueValidation(CharSequence headerValue) {
vertxValueValidator.accept(headerValue);
}
}
Loading