|
20 | 20 | import java.security.PublicKey;
|
21 | 21 | import java.security.cert.X509Certificate;
|
22 | 22 | import java.text.ParseException;
|
| 23 | +import java.time.Clock; |
| 24 | +import java.time.Instant; |
| 25 | +import java.time.temporal.ChronoUnit; |
23 | 26 | import java.util.Arrays;
|
24 | 27 | import java.util.Map;
|
25 | 28 | import java.util.concurrent.ConcurrentHashMap;
|
| 29 | +import java.util.concurrent.locks.ReentrantReadWriteLock; |
26 | 30 | import java.util.function.Consumer;
|
27 | 31 | import java.util.function.Function;
|
28 | 32 | import java.util.function.Supplier;
|
@@ -158,19 +162,44 @@ private JWKSet retrieve(String jwkSetUrl) {
|
158 | 162 | }
|
159 | 163 |
|
160 | 164 | private class JwkSetHolder implements Supplier<JWKSet> {
|
| 165 | + private final ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock(); |
| 166 | + private final Clock clock = Clock.systemUTC(); |
161 | 167 | private final String jwkSetUrl;
|
162 | 168 | private JWKSet jwkSet;
|
| 169 | + private Instant lastUpdatedAt; |
163 | 170 |
|
164 | 171 | private JwkSetHolder(String jwkSetUrl) {
|
165 | 172 | this.jwkSetUrl = jwkSetUrl;
|
166 | 173 | }
|
167 | 174 |
|
168 | 175 | @Override
|
169 | 176 | public JWKSet get() {
|
170 |
| - if (this.jwkSet == null) { |
171 |
| - this.jwkSet = retrieve(this.jwkSetUrl); |
| 177 | + this.rwLock.readLock().lock(); |
| 178 | + if (shouldRefresh()) { |
| 179 | + this.rwLock.readLock().unlock(); |
| 180 | + this.rwLock.writeLock().lock(); |
| 181 | + try { |
| 182 | + if (shouldRefresh()) { |
| 183 | + this.jwkSet = retrieve(this.jwkSetUrl); |
| 184 | + this.lastUpdatedAt = Instant.now(); |
| 185 | + } |
| 186 | + this.rwLock.readLock().lock(); |
| 187 | + } finally { |
| 188 | + this.rwLock.writeLock().unlock(); |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + try { |
| 193 | + return this.jwkSet; |
| 194 | + } finally { |
| 195 | + this.rwLock.readLock().unlock(); |
172 | 196 | }
|
173 |
| - return this.jwkSet; |
| 197 | + } |
| 198 | + |
| 199 | + private boolean shouldRefresh() { |
| 200 | + // Refresh every 5 minutes |
| 201 | + return (this.jwkSet == null || |
| 202 | + this.clock.instant().isAfter(this.lastUpdatedAt.plus(5, ChronoUnit.MINUTES))); |
174 | 203 | }
|
175 | 204 |
|
176 | 205 | }
|
|
0 commit comments