Skip to content

Commit 8b3fb55

Browse files
ankurpathakrwinch
authored andcommitted
Added methods to add filter relatively in ServerHttpSecurity
Addition of two new methods addFilterBefore and addFilterAfter in ServerHttpSecurity to allow addition of WebFilter before and after of specified order Fixes: gh-6138
1 parent 3c35f4c commit 8b3fb55

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,32 @@ public ServerHttpSecurity addFilterAt(WebFilter webFilter, SecurityWebFiltersOrd
288288
return this;
289289
}
290290

291+
/**
292+
* Adds a {@link WebFilter} before specific position.
293+
* @param webFilter the {@link WebFilter} to add
294+
* @param order the place before which to insert the {@link WebFilter}
295+
* @return the {@link ServerHttpSecurity} to continue configuring
296+
* @since 5.2.0
297+
* @author Ankur Pathak
298+
*/
299+
public ServerHttpSecurity addFilterBefore(WebFilter webFilter, SecurityWebFiltersOrder order) {
300+
this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder() - 1));
301+
return this;
302+
}
303+
304+
/**
305+
* Adds a {@link WebFilter} after specific position.
306+
* @param webFilter the {@link WebFilter} to add
307+
* @param order the place after which to insert the {@link WebFilter}
308+
* @return the {@link ServerHttpSecurity} to continue configuring
309+
* @since 5.2.0
310+
* @author Ankur Pathak
311+
*/
312+
public ServerHttpSecurity addFilterAfter(WebFilter webFilter, SecurityWebFiltersOrder order) {
313+
this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder() + 1));
314+
return this;
315+
}
316+
291317
/**
292318
* Gets the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance.
293319
* @return the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance.

config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.mockito.Mock;
3535
import org.mockito.junit.MockitoJUnitRunner;
3636

37+
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
38+
import org.springframework.web.server.WebFilterChain;
3739
import reactor.core.publisher.Mono;
3840
import reactor.test.publisher.TestPublisher;
3941

@@ -190,6 +192,30 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
190192
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
191193
}
192194

195+
@Test
196+
@SuppressWarnings("unchecked")
197+
public void addFilterAfterIsApplied(){
198+
SecurityWebFilterChain securityWebFilterChain = this.http.addFilterAfter(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build();
199+
List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block();
200+
201+
assertThat(filters).isNotNull()
202+
.isNotEmpty()
203+
.containsSequence(SecurityContextServerWebExchangeWebFilter.class, TestWebFilter.class);
204+
205+
}
206+
207+
@Test
208+
@SuppressWarnings("unchecked")
209+
public void addFilterBeforeIsApplied(){
210+
SecurityWebFilterChain securityWebFilterChain = this.http.addFilterBefore(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build();
211+
List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block();
212+
213+
assertThat(filters).isNotNull()
214+
.isNotEmpty()
215+
.containsSequence(TestWebFilter.class, SecurityContextServerWebExchangeWebFilter.class);
216+
217+
}
218+
193219
private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
194220
return (Optional<T>) filterChain.getWebFilters()
195221
.filter(Objects::nonNull)
@@ -214,4 +240,12 @@ Mono<String> pathWithinApplicationFromContext() {
214240
.map(e -> e.getRequest().getPath().pathWithinApplication().value());
215241
}
216242
}
243+
244+
private static class TestWebFilter implements WebFilter {
245+
246+
@Override
247+
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
248+
return chain.filter(exchange);
249+
}
250+
}
217251
}

0 commit comments

Comments
 (0)