Skip to content

Commit ab99dc2

Browse files
committed
Improve ThreadLocal value handling
ThreadLocal values from a Servlet container thread maybe end up being unnecessarily restored, e.g. if DataFetcher is invoked on the same thread and then also removed, which then impacts the filter chain. The ContextManager now saves the thread id when values are extracted and ignores restore or remove calls if still on the same thread. This should also be more optimal, avoiding ThreadLocal access if threads aren't switched. See gh-58
1 parent 52572c9 commit ab99dc2

File tree

3 files changed

+100
-11
lines changed

3 files changed

+100
-11
lines changed

spring-graphql/src/main/java/org/springframework/graphql/execution/ContextManager.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ public abstract class ContextManager {
3838

3939
private static final String CONTEXT_VIEW_KEY = ContextManager.class.getName() + ".CONTEXT_VIEW";
4040

41+
private static final String THREAD_ID = ContextManager.class.getName() + ".THREAD_ID";
42+
4143
private static final String THREAD_LOCAL_VALUES_KEY = ContextManager.class.getName() + ".THREAD_VALUES_ACCESSOR";
4244

4345
private static final String THREAD_LOCAL_ACCESSOR_KEY = ContextManager.class.getName() + ".THREAD_LOCAL_ACCESSOR";
@@ -80,38 +82,39 @@ public static Context extractThreadLocalValues(ThreadLocalAccessor accessor, Con
8082
return context;
8183
}
8284
return context.putAll((ContextView) Context.of(
83-
THREAD_LOCAL_VALUES_KEY, valuesMap, THREAD_LOCAL_ACCESSOR_KEY, accessor));
85+
THREAD_LOCAL_VALUES_KEY, valuesMap,
86+
THREAD_LOCAL_ACCESSOR_KEY, accessor,
87+
THREAD_ID, Thread.currentThread().getId()));
8488
}
8589

8690
/**
87-
* Look up saved ThreadLocal values and use them to re-establish ThreadLocal context.
91+
* Look up saved ThreadLocal values and restore them if any are found.
92+
* This is a no-op if invoked on the thread that values were extracted on.
8893
* @param contextView the reactor {@link ContextView}
8994
*/
9095
static void restoreThreadLocalValues(ContextView contextView) {
9196
ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView);
9297
if (accessor != null) {
93-
accessor.restoreValues(getThreadLocalValues(contextView));
98+
accessor.restoreValues(contextView.get(THREAD_LOCAL_VALUES_KEY));
9499
}
95100
}
96101

97102
/**
98-
* Look up saved ThreadLocal values and remove associated ThreadLocal context.
103+
* Look up saved ThreadLocal values and remove the ThreadLocal values.
104+
* This is a no-op if invoked on the thread that values were extracted on.
99105
* @param contextView the reactor {@link ContextView}
100106
*/
101107
static void resetThreadLocalValues(ContextView contextView) {
102108
ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView);
103109
if (accessor != null) {
104-
accessor.resetValues(getThreadLocalValues(contextView));
110+
accessor.resetValues(contextView.get(THREAD_LOCAL_VALUES_KEY));
105111
}
106112
}
107113

108114
@Nullable
109-
private static ThreadLocalAccessor getThreadLocalAccessor(ContextView contextView) {
110-
return (contextView.hasKey(THREAD_LOCAL_ACCESSOR_KEY) ? contextView.get(THREAD_LOCAL_ACCESSOR_KEY) : null);
111-
}
112-
113-
private static Map<String, Object> getThreadLocalValues(ContextView contextView) {
114-
return contextView.get(THREAD_LOCAL_VALUES_KEY);
115+
private static ThreadLocalAccessor getThreadLocalAccessor(ContextView view) {
116+
Long id = view.getOrDefault(THREAD_ID, null);
117+
return (id != null && id != Thread.currentThread().getId() ? view.get(THREAD_LOCAL_ACCESSOR_KEY) : null);
115118
}
116119

117120
}

spring-graphql/src/test/java/org/springframework/graphql/TestThreadLocalAccessor.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,15 @@ public class TestThreadLocalAccessor<T> implements ThreadLocalAccessor {
3434
@Nullable
3535
private Long threadId;
3636

37+
private boolean suppressThreadIdCheck;
38+
3739
public TestThreadLocalAccessor(ThreadLocal<T> threadLocal) {
40+
this(threadLocal, false);
41+
}
42+
43+
public TestThreadLocalAccessor(ThreadLocal<T> threadLocal, boolean suppressThreadIdCheck) {
3844
this.threadLocal = threadLocal;
45+
this.suppressThreadIdCheck = suppressThreadIdCheck;
3946
}
4047

4148
@Override
@@ -61,10 +68,16 @@ public void resetValues(Map<String, Object> values) {
6168
}
6269

6370
private void saveThreadId() {
71+
if (this.suppressThreadIdCheck) {
72+
return;
73+
}
6474
this.threadId = Thread.currentThread().getId();
6575
}
6676

6777
private void checkThreadId() {
78+
if (this.suppressThreadIdCheck) {
79+
return;
80+
}
6881
assertThat(this.threadId).as("No threadId to check. Was extractValues not called?").isNotNull();
6982
assertThat(Thread.currentThread().getId() != this.threadId)
7083
.as("ThreadLocal value extracted and restored on the same thread. Propagation not tested effectively.")
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2002-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.graphql.execution;
17+
18+
import java.time.Duration;
19+
20+
import org.junit.jupiter.api.Test;
21+
import reactor.core.publisher.Mono;
22+
import reactor.util.context.Context;
23+
24+
import org.springframework.graphql.TestThreadLocalAccessor;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* Unit tests for {@link ContextManager}.
30+
* @author Rossen Stoyanchev
31+
*/
32+
public class ContextManagerTests {
33+
34+
@Test
35+
void restoreThreadLocaValues() {
36+
ThreadLocal<String> threadLocal = new ThreadLocal<>();
37+
threadLocal.set("myValue");
38+
39+
Context context = ContextManager.extractThreadLocalValues(
40+
new TestThreadLocalAccessor<>(threadLocal), Context.empty());
41+
try {
42+
Mono.delay(Duration.ofMillis(10))
43+
.doOnNext(aLong -> {
44+
assertThat(threadLocal.get()).isNull();
45+
ContextManager.restoreThreadLocalValues(context);
46+
assertThat(threadLocal.get()).isEqualTo("myValue");
47+
ContextManager.resetThreadLocalValues(context);
48+
})
49+
.block();
50+
}
51+
finally {
52+
threadLocal.remove();
53+
}
54+
}
55+
56+
@Test
57+
void restoreThreadLocaValuesOnSameThreadIsNoOp() {
58+
ThreadLocal<String> threadLocal = new ThreadLocal<>();
59+
threadLocal.set("myValue");
60+
61+
Context context = ContextManager.extractThreadLocalValues(
62+
new TestThreadLocalAccessor<>(threadLocal, true), Context.empty());
63+
64+
threadLocal.remove();
65+
ContextManager.restoreThreadLocalValues(context);
66+
assertThat(threadLocal.get()).isNull();
67+
68+
threadLocal.set("anotherValue");
69+
ContextManager.resetThreadLocalValues(context);
70+
assertThat(threadLocal.get()).isEqualTo("anotherValue");
71+
}
72+
73+
}

0 commit comments

Comments
 (0)