34
34
import org .mockito .Mock ;
35
35
import org .mockito .junit .MockitoJUnitRunner ;
36
36
37
+ import org .springframework .security .web .server .context .SecurityContextServerWebExchangeWebFilter ;
38
+ import org .springframework .web .server .WebFilterChain ;
37
39
import reactor .core .publisher .Mono ;
38
40
import reactor .test .publisher .TestPublisher ;
39
41
@@ -190,6 +192,30 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
190
192
.isEqualTo (Arrays .asList (SecurityContextServerLogoutHandler .class , CsrfServerLogoutHandler .class ));
191
193
}
192
194
195
+ @ Test
196
+ @ SuppressWarnings ("unchecked" )
197
+ public void addFilterAfterIsApplied (){
198
+ SecurityWebFilterChain securityWebFilterChain = this .http .addFilterAfter (new TestWebFilter (), SecurityWebFiltersOrder .SECURITY_CONTEXT_SERVER_WEB_EXCHANGE ).build ();
199
+ List filters = securityWebFilterChain .getWebFilters ().map (WebFilter ::getClass ).collectList ().block ();
200
+
201
+ assertThat (filters ).isNotNull ()
202
+ .isNotEmpty ()
203
+ .containsSequence (SecurityContextServerWebExchangeWebFilter .class , TestWebFilter .class );
204
+
205
+ }
206
+
207
+ @ Test
208
+ @ SuppressWarnings ("unchecked" )
209
+ public void addFilterBeforeIsApplied (){
210
+ SecurityWebFilterChain securityWebFilterChain = this .http .addFilterBefore (new TestWebFilter (), SecurityWebFiltersOrder .SECURITY_CONTEXT_SERVER_WEB_EXCHANGE ).build ();
211
+ List filters = securityWebFilterChain .getWebFilters ().map (WebFilter ::getClass ).collectList ().block ();
212
+
213
+ assertThat (filters ).isNotNull ()
214
+ .isNotEmpty ()
215
+ .containsSequence (TestWebFilter .class , SecurityContextServerWebExchangeWebFilter .class );
216
+
217
+ }
218
+
193
219
private <T extends WebFilter > Optional <T > getWebFilter (SecurityWebFilterChain filterChain , Class <T > filterClass ) {
194
220
return (Optional <T >) filterChain .getWebFilters ()
195
221
.filter (Objects ::nonNull )
@@ -214,4 +240,12 @@ Mono<String> pathWithinApplicationFromContext() {
214
240
.map (e -> e .getRequest ().getPath ().pathWithinApplication ().value ());
215
241
}
216
242
}
243
+
244
+ private static class TestWebFilter implements WebFilter {
245
+
246
+ @ Override
247
+ public Mono <Void > filter (ServerWebExchange exchange , WebFilterChain chain ) {
248
+ return chain .filter (exchange );
249
+ }
250
+ }
217
251
}
0 commit comments