Skip to content

Improve JWT parse / decode performance #620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
13958ad
Optimise parsing of token for well-defined JWT format
noetro Sep 30, 2022
e2a8314
Update error message in test to match new code
noetro Sep 30, 2022
a24c18b
Fixing checkstyle issues
noetro Oct 6, 2022
3940358
Merge branch 'master' into master
jimmyjames Oct 6, 2022
8c24018
Added missing test case for no parts
noetro Oct 7, 2022
298d522
Return new JWTDecodeException
jimmyjames Oct 7, 2022
8bf1d27
Add JMH support to build script
noetro Oct 8, 2022
20dbf7c
Add benchmark for decoder and cleanup build file
noetro Oct 8, 2022
21e4d7e
Merge remote-tracking branch 'origin/jmh' into jackson-optimisations
noetro Oct 8, 2022
693fdc3
Optimise JWT deserialisation by re-using threadsafe Jackson objects
noetro Oct 8, 2022
21bd831
Merge branch 'master' into master
jimmyjames Oct 12, 2022
1c2d3c9
Merge remote-tracking branch 'origin/master' into jmh
noetro Oct 17, 2022
8f3d270
Merge remote-tracking branch 'origin/jmh'
noetro Oct 17, 2022
1df1df1
Merge remote-tracking branch 'origin/master' into jackson-optimisations
noetro Oct 17, 2022
4aa2dc2
Disable lint checks on JMH source set that is for testing
noetro Oct 17, 2022
e7fd3aa
Remove extra line break
noetro Oct 17, 2022
75e7db7
Merge remote-tracking branch 'origin/jackson-optimisations'
noetro Oct 17, 2022
cd2e858
Merge branch 'auth0:master' into master
noetro Oct 17, 2022
e363b33
Merge branch 'master' into master
noetro Oct 18, 2022
1d6b07e
Merge branch 'master' into master
jimmyjames Oct 20, 2022
d2dd8a5
Merge branch 'master' into master
noetro Oct 24, 2022
9d017c6
Merge branch 'master' into master
noetro Oct 25, 2022
fe9c137
Merge branch 'master' into master
noetro Oct 25, 2022
f96e4f4
Merge branch 'master' into master
noetro Oct 26, 2022
0b63418
Merge branch 'master' into master
noetro Oct 27, 2022
37122ec
Merge branch 'master' into master
noetro Nov 25, 2022
6a8ef68
Merge branch 'master' into master
jimmyjames Jan 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,28 @@ plugins {
id 'checkstyle'
}

sourceSets {
jmh {

}
}

configurations {
jmhImplementation {
extendsFrom implementation
}
}

checkstyle {
toolVersion '10.0'
checkstyleTest.enabled = false //We are disabling lint checks for tests
}
//We are disabling lint checks for tests
tasks.named("checkstyleTest").configure({
enabled = false
})
tasks.named("checkstyleJmh").configure({
enabled = false
})

logger.lifecycle("Using version ${version} for ${group}.${name}")

Expand Down Expand Up @@ -61,6 +79,10 @@ dependencies {
testImplementation 'net.jodah:concurrentunit:0.4.6'
testImplementation 'org.hamcrest:hamcrest:2.2'
testImplementation 'org.mockito:mockito-core:4.4.0'

jmhImplementation sourceSets.main.output
jmhImplementation 'org.openjdk.jmh:jmh-core:1.35'
jmhAnnotationProcessor 'org.openjdk.jmh:jmh-generator-annprocess:1.35'
}

jacoco {
Expand Down Expand Up @@ -143,3 +165,25 @@ task exportVersion() {
new File(rootDir, "version.txt").text = "$version"
}
}

// you can pass any arguments JMH accepts via Gradle args.
// Example: ./gradlew runJMH --args="-lrf"
tasks.register('runJMH', JavaExec) {
description 'Run JMH benchmarks.'
group 'verification'

main 'org.openjdk.jmh.Main'
classpath sourceSets.jmh.runtimeClasspath

args project.hasProperty("args") ? project.property("args").split() : ""
}
tasks.register('jmhHelp', JavaExec) {
description 'Prints the available command line options for JMH.'
group 'help'

main 'org.openjdk.jmh.Main'
classpath sourceSets.jmh.runtimeClasspath

args '-h'
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.auth0.jwt.benchmark;

import com.auth0.jwt.JWT;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.infra.Blackhole;

/**
* This class is a JMH benchmark for decoding JWTs.
*/
public class JWTDecoderBenchmark {
private static final String TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";

@Benchmark
@BenchmarkMode(Mode.Throughput)
public void throughputDecodeTime(Blackhole blackhole) {
blackhole.consume(JWT.decode(TOKEN));
}
}
13 changes: 6 additions & 7 deletions lib/src/main/java/com/auth0/jwt/impl/BasicHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.Header;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static com.auth0.jwt.impl.JsonNodeClaim.extractClaim;
Expand All @@ -23,22 +22,22 @@ class BasicHeader implements Header, Serializable {
private final String contentType;
private final String keyId;
private final Map<String, JsonNode> tree;
private final ObjectReader objectReader;
private final ObjectCodec objectCodec;

BasicHeader(
String algorithm,
String type,
String contentType,
String keyId,
Map<String, JsonNode> tree,
ObjectReader objectReader
ObjectCodec objectCodec
) {
this.algorithm = algorithm;
this.type = type;
this.contentType = contentType;
this.keyId = keyId;
this.tree = Collections.unmodifiableMap(tree == null ? new HashMap<>() : tree);
this.objectReader = objectReader;
this.tree = tree == null ? Collections.emptyMap() : Collections.unmodifiableMap(tree);
this.objectCodec = objectCodec;
}

Map<String, JsonNode> getTree() {
Expand Down Expand Up @@ -67,6 +66,6 @@ public String getKeyId() {

@Override
public Claim getHeaderClaim(String name) {
return extractClaim(name, tree, objectReader);
return extractClaim(name, tree, objectCodec);
}
}
20 changes: 6 additions & 14 deletions lib/src/main/java/com/auth0/jwt/impl/HeaderDeserializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import com.auth0.jwt.HeaderParams;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.Header;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;

import java.io.IOException;
Expand All @@ -19,22 +19,14 @@
*
* @see JWTParser
*/
class HeaderDeserializer extends StdDeserializer<BasicHeader> {
class HeaderDeserializer extends StdDeserializer<Header> {

private final ObjectReader objectReader;

HeaderDeserializer(ObjectReader objectReader) {
this(null, objectReader);
}

private HeaderDeserializer(Class<?> vc, ObjectReader objectReader) {
super(vc);

this.objectReader = objectReader;
HeaderDeserializer() {
super(Header.class);
}

@Override
public BasicHeader deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
public Header deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
Map<String, JsonNode> tree = p.getCodec().readValue(p, new TypeReference<Map<String, JsonNode>>() {
});
if (tree == null) {
Expand All @@ -45,7 +37,7 @@ public BasicHeader deserialize(JsonParser p, DeserializationContext ctxt) throws
String type = getString(tree, HeaderParams.TYPE);
String contentType = getString(tree, HeaderParams.CONTENT_TYPE);
String keyId = getString(tree, HeaderParams.KEY_ID);
return new BasicHeader(algorithm, type, contentType, keyId, tree, objectReader);
return new BasicHeader(algorithm, type, contentType, keyId, tree, p.getCodec());
}

String getString(Map<String, JsonNode> tree, String claimName) {
Expand Down
22 changes: 17 additions & 5 deletions lib/src/main/java/com/auth0/jwt/impl/JWTParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
* {@link HeaderSerializer} and {@link PayloadSerializer}.
*/
public class JWTParser implements JWTPartsParser {
private static final ObjectMapper DEFAULT_OBJECT_MAPPER = createDefaultObjectMapper();
private static final ObjectReader DEFAULT_PAYLOAD_READER = DEFAULT_OBJECT_MAPPER.readerFor(Payload.class);
private static final ObjectReader DEFAULT_HEADER_READER = DEFAULT_OBJECT_MAPPER.readerFor(Header.class);

private final ObjectReader payloadReader;
private final ObjectReader headerReader;

public JWTParser() {
this(getDefaultObjectMapper());
this.payloadReader = DEFAULT_PAYLOAD_READER;
this.headerReader = DEFAULT_HEADER_READER;
}

JWTParser(ObjectMapper mapper) {
addDeserializers(mapper);

this.payloadReader = mapper.readerFor(Payload.class);
this.headerReader = mapper.readerFor(Header.class);
}
Expand Down Expand Up @@ -55,18 +61,24 @@ public Header parseHeader(String json) throws JWTDecodeException {
}
}

private void addDeserializers(ObjectMapper mapper) {
static void addDeserializers(ObjectMapper mapper) {
SimpleModule module = new SimpleModule();
ObjectReader reader = mapper.reader();
module.addDeserializer(Payload.class, new PayloadDeserializer(reader));
module.addDeserializer(Header.class, new HeaderDeserializer(reader));
module.addDeserializer(Payload.class, new PayloadDeserializer());
module.addDeserializer(Header.class, new HeaderDeserializer());
mapper.registerModule(module);
}

static ObjectMapper getDefaultObjectMapper() {
return DEFAULT_OBJECT_MAPPER;
}

private static ObjectMapper createDefaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);

addDeserializers(mapper);

return mapper;
}

Expand Down
36 changes: 19 additions & 17 deletions lib/src/main/java/com/auth0/jwt/impl/JsonNodeClaim.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import com.auth0.jwt.interfaces.Claim;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectReader;

import java.io.IOException;
import java.lang.reflect.Array;
Expand All @@ -21,12 +21,12 @@
*/
class JsonNodeClaim implements Claim {

private final ObjectReader objectReader;
private final ObjectCodec codec;
private final JsonNode data;

private JsonNodeClaim(JsonNode node, ObjectReader objectReader) {
private JsonNodeClaim(JsonNode node, ObjectCodec codec) {
this.data = node;
this.objectReader = objectReader;
this.codec = codec;
}

@Override
Expand Down Expand Up @@ -82,7 +82,7 @@ public <T> T[] asArray(Class<T> clazz) throws JWTDecodeException {
T[] arr = (T[]) Array.newInstance(clazz, data.size());
for (int i = 0; i < data.size(); i++) {
try {
arr[i] = objectReader.treeToValue(data.get(i), clazz);
arr[i] = codec.treeToValue(data.get(i), clazz);
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to " + clazz.getSimpleName(), e);
}
Expand All @@ -99,7 +99,7 @@ public <T> List<T> asList(Class<T> clazz) throws JWTDecodeException {
List<T> list = new ArrayList<>();
for (int i = 0; i < data.size(); i++) {
try {
list.add(objectReader.treeToValue(data.get(i), clazz));
list.add(codec.treeToValue(data.get(i), clazz));
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to " + clazz.getSimpleName(), e);
}
Expand All @@ -113,11 +113,11 @@ public Map<String, Object> asMap() throws JWTDecodeException {
return null;
}

try {
TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {
};
JsonParser thisParser = objectReader.treeAsTokens(data);
return thisParser.readValueAs(mapType);
TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {
};

try (JsonParser parser = codec.treeAsTokens(data)) {
return parser.readValueAs(mapType);
} catch (IOException e) {
throw new JWTDecodeException("Couldn't map the Claim value to Map", e);
}
Expand All @@ -129,8 +129,8 @@ public <T> T as(Class<T> clazz) throws JWTDecodeException {
if (isMissing() || isNull()) {
return null;
}
return objectReader.treeAsTokens(data).readValueAs(clazz);
} catch (IOException e) {
return codec.treeToValue(data, clazz);
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim value to " + clazz.getSimpleName(), e);
}
}
Expand Down Expand Up @@ -160,21 +160,23 @@ public String toString() {
*
* @param claimName the Claim to search for.
* @param tree the JsonNode tree to search the Claim in.
* @param objectCodec the object codec in use for deserialization
* @return a valid non-null Claim.
*/
static Claim extractClaim(String claimName, Map<String, JsonNode> tree, ObjectReader objectReader) {
static Claim extractClaim(String claimName, Map<String, JsonNode> tree, ObjectCodec objectCodec) {
JsonNode node = tree.get(claimName);
return claimFromNode(node, objectReader);
return claimFromNode(node, objectCodec);
}

/**
* Helper method to create a Claim representation from the given JsonNode.
*
* @param node the JsonNode to convert into a Claim.
* @param objectCodec the object codec in use for deserialization
* @return a valid Claim instance. If the node is null or missing, a NullClaim will be returned.
*/
static Claim claimFromNode(JsonNode node, ObjectReader objectReader) {
return new JsonNodeClaim(node, objectReader);
static Claim claimFromNode(JsonNode node, ObjectCodec objectCodec) {
return new JsonNodeClaim(node, objectCodec);
}

}
22 changes: 8 additions & 14 deletions lib/src/main/java/com/auth0/jwt/impl/PayloadDeserializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.auth0.jwt.interfaces.Payload;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
Expand All @@ -24,16 +25,8 @@
*/
class PayloadDeserializer extends StdDeserializer<Payload> {

private final ObjectReader objectReader;

PayloadDeserializer(ObjectReader reader) {
this(null, reader);
}

private PayloadDeserializer(Class<?> vc, ObjectReader reader) {
super(vc);

this.objectReader = reader;
PayloadDeserializer() {
super(Payload.class);
}

@Override
Expand All @@ -46,16 +39,17 @@ public Payload deserialize(JsonParser p, DeserializationContext ctxt) throws IOE

String issuer = getString(tree, RegisteredClaims.ISSUER);
String subject = getString(tree, RegisteredClaims.SUBJECT);
List<String> audience = getStringOrArray(tree, RegisteredClaims.AUDIENCE);
List<String> audience = getStringOrArray(p.getCodec(), tree, RegisteredClaims.AUDIENCE);
Instant expiresAt = getInstantFromSeconds(tree, RegisteredClaims.EXPIRES_AT);
Instant notBefore = getInstantFromSeconds(tree, RegisteredClaims.NOT_BEFORE);
Instant issuedAt = getInstantFromSeconds(tree, RegisteredClaims.ISSUED_AT);
String jwtId = getString(tree, RegisteredClaims.JWT_ID);

return new PayloadImpl(issuer, subject, audience, expiresAt, notBefore, issuedAt, jwtId, tree, objectReader);
return new PayloadImpl(issuer, subject, audience, expiresAt, notBefore, issuedAt, jwtId, tree, p.getCodec());
}

List<String> getStringOrArray(Map<String, JsonNode> tree, String claimName) throws JWTDecodeException {
List<String> getStringOrArray(ObjectCodec codec, Map<String, JsonNode> tree, String claimName)
throws JWTDecodeException {
JsonNode node = tree.get(claimName);
if (node == null || node.isNull() || !(node.isArray() || node.isTextual())) {
return null;
Expand All @@ -67,7 +61,7 @@ List<String> getStringOrArray(Map<String, JsonNode> tree, String claimName) thro
List<String> list = new ArrayList<>(node.size());
for (int i = 0; i < node.size(); i++) {
try {
list.add(objectReader.treeToValue(node.get(i), String.class));
list.add(codec.treeToValue(node.get(i), String.class));
} catch (JsonProcessingException e) {
throw new JWTDecodeException("Couldn't map the Claim's array contents to String", e);
}
Expand Down
Loading