Skip to content

Make @SpringMock work for beans with @Primary #1503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@
package org.spockframework.spring.mock;

import org.spockframework.runtime.model.FieldInfo;
import org.spockframework.spring.*;

import java.util.*;

import org.spockframework.spring.SpringBean;
import org.spockframework.spring.SpringExtensionException;
import org.spockframework.spring.SpringSpy;
import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.*;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.*;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.support.*;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationContext;
import org.springframework.core.*;
import org.springframework.util.*;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

import static java.util.Arrays.asList;

Expand Down Expand Up @@ -102,6 +114,8 @@ private void registerMock(ConfigurableListableBeanFactory beanFactory,
String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName);
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(1, beanName);
if (registry.containsBeanDefinition(transformedBeanName)) {
BeanDefinition existing = registry.getBeanDefinition(transformedBeanName);
copyBeanDefinitionDetails(existing, beanDefinition);
registry.removeBeanDefinition(transformedBeanName);
}
registry.registerBeanDefinition(transformedBeanName, beanDefinition);
Expand Down Expand Up @@ -146,22 +160,30 @@ private String getBeanName(ConfigurableListableBeanFactory beanFactory,
if (StringUtils.hasLength(mockDefinition.getName())) {
return mockDefinition.getName();
}
Set<String> existingBeans = findCandidateBeans(beanFactory, mockDefinition);
Set<String> existingBeans = getExistingBeans(beanFactory, mockDefinition);
if (existingBeans.isEmpty()) {
return this.beanNameGenerator.generateBeanName(beanDefinition, registry);
}
if (existingBeans.size() == 1) {
return existingBeans.iterator().next();
}
String primaryCandidate = determinePrimaryCandidate(registry, existingBeans, mockDefinition.getTypeToMock());
if (primaryCandidate != null) {
return primaryCandidate;
}
throw new IllegalStateException(
"Unable to register mock bean " + mockDefinition.getTypeToMock()
+ " expected a single matching bean to replace but found "
+ existingBeans);
}

private void copyBeanDefinitionDetails(BeanDefinition from, BeanDefinition to) {
to.setPrimary(from.isPrimary());
}

private void registerSpy(ConfigurableListableBeanFactory beanFactory,
BeanDefinitionRegistry registry, SpyDefinition definition) {
String[] existingBeans = getExistingBeans(beanFactory, definition.getTypeToSpy());
Set<String> existingBeans = getExistingBeans(beanFactory, definition.getTypeToSpy());
if (ObjectUtils.isEmpty(existingBeans)) {
FieldInfo fieldInfo = definition.getFieldInfo();
throw new SpringExtensionException(String.format("No matching bean found! " +
Expand All @@ -172,8 +194,8 @@ private void registerSpy(ConfigurableListableBeanFactory beanFactory,
}
}

private Set<String> findCandidateBeans(ConfigurableListableBeanFactory beanFactory,
MockDefinition mockDefinition) {
private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory,
MockDefinition mockDefinition) {
QualifierDefinition qualifier = mockDefinition.getQualifier();
Set<String> candidates = new TreeSet<>();
for (String candidate : getExistingBeans(beanFactory,
Expand All @@ -185,7 +207,7 @@ private Set<String> findCandidateBeans(ConfigurableListableBeanFactory beanFacto
return candidates;
}

private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory,
private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory,
ResolvableType type) {
Set<String> beans = new LinkedHashSet<>(
asList(beanFactory.getBeanNamesForType(type)));
Expand All @@ -199,7 +221,7 @@ private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory,
}

beans.removeIf(this::isScopedTarget);
return beans.toArray(new String[0]);
return beans;
}

private boolean isScopedTarget(String beanName) {
Expand All @@ -211,7 +233,7 @@ private boolean isScopedTarget(String beanName) {
}

private void registerSpies(BeanDefinitionRegistry registry, SpyDefinition definition,
String[] existingBeans) {
Set<String> existingBeans) {
try {
registerSpy(definition,
determineBeanName(existingBeans, definition, registry));
Expand All @@ -221,28 +243,28 @@ private void registerSpies(BeanDefinitionRegistry registry, SpyDefinition defini
}
}

private String determineBeanName(String[] existingBeans, SpyDefinition definition,
private String determineBeanName(Collection<String> existingBeans, SpyDefinition definition,
BeanDefinitionRegistry registry) {
if (StringUtils.hasText(definition.getName())) {
return definition.getName();
}
if (existingBeans.length == 1) {
return existingBeans[0];
if (existingBeans.size() == 1) {
return existingBeans.iterator().next();
}
return determinePrimaryCandidate(registry, existingBeans, definition.getTypeToSpy());
}

private String determinePrimaryCandidate(BeanDefinitionRegistry registry,
String[] candidateBeanNames, ResolvableType type) {
Collection<String> candidateBeanNames, ResolvableType type) {
String primaryBeanName = null;
for (String candidateBeanName : candidateBeanNames) {
BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName);
if (beanDefinition.isPrimary()) {
if (primaryBeanName != null) {
throw new NoUniqueBeanDefinitionException(type.resolve(),
candidateBeanNames.length,
candidateBeanNames.size(),
"more than one 'primary' bean found among candidates: "
+ asList(candidateBeanNames));
+ candidateBeanNames);
}
primaryBeanName = candidateBeanName;
}
Expand Down Expand Up @@ -332,6 +354,8 @@ static class SpyPostProcessor extends BackwardsCompatibleInstantiationAwareBeanP

private static final String BEAN_NAME = SpyPostProcessor.class.getName();

private final Map<String, Object> earlySpyReferences = new ConcurrentHashMap<>(16);

private final SpockMockPostprocessor spockMockPostprocessor;

SpyPostProcessor(SpockMockPostprocessor spockMockPostprocessor) {
Expand All @@ -346,6 +370,10 @@ public int getOrder() {
@Override
public Object getEarlyBeanReference(Object bean, String beanName)
throws BeansException {
if (bean instanceof FactoryBean) {
return bean;
}
this.earlySpyReferences.put(getCacheKey(bean, beanName), bean);
return createSpyIfNecessary(bean, beanName);
}

Expand All @@ -355,12 +383,18 @@ public Object postProcessAfterInitialization(Object bean, String beanName)
if (bean instanceof FactoryBean) {
return bean;
}
return createSpyIfNecessary(bean, beanName);
if (this.earlySpyReferences.remove(getCacheKey(bean, beanName)) != bean) {
return this.spockMockPostprocessor.createSpyIfNecessary(bean, beanName);
}
return bean;
}

private Object createSpyIfNecessary(Object bean, String beanName) {
return this.spockMockPostprocessor.createSpyIfNecessary(bean, beanName);
}
private String getCacheKey(Object bean, String beanName) {
return StringUtils.hasLength(beanName) ? beanName : bean.getClass().getName();
}

static void register(BeanDefinitionRegistry registry) {
if (!registry.containsBeanDefinition(BEAN_NAME)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.spockframework.spring.mock

import org.spockframework.spring.IService1
import org.spockframework.spring.Service1
import org.spockframework.spring.Service2
import org.spockframework.spring.SpringBean
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Primary
import org.springframework.test.context.ContextConfiguration
import spock.lang.Specification

@ContextConfiguration(classes = TwoService2BeansWithPrimaryConfig)
class SpringBeanWIthSinglePrimaryBeanSpec extends Specification {

@SpringBean
Service2 service2 = Mock() {
generateQuickBrownFox() >> "blubb"
}

@Autowired
Service1 service1

def "injection with stubbing works"() {
expect:
service1.generateString() == "blubb"
}

def "mocking works was well"() {
when:
def result = service1.generateString()

then:
result == "Foo"
1 * service2.generateQuickBrownFox() >> "Foo"
}

static class TwoService2BeansWithPrimaryConfig {

@Bean
Service2 service2NotPrimary() {
new Service2()
}

@Primary
@Bean
Service2 service2Primary() {
new Service2()
}

@Bean
IService1 service1(Service2 service2) {
return new Service1(service2)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.spockframework.spring.mock

import spock.lang.Specification
import spock.util.EmbeddedSpecRunner

class SpringBeanWithManyOrNoPrimarySpec extends Specification {

def "rejects an attempt to mock a bean with more than one instance without primary"() {
setup:
def runner = new EmbeddedSpecRunner()
runner.throwFailure = false

when:
def result = runner.run("""
import org.spockframework.spring.Service2
import org.spockframework.spring.SpringBean
import org.springframework.context.annotation.Bean
import org.springframework.test.context.ContextConfiguration
import spock.lang.Specification

@ContextConfiguration(classes = TwoService2BeansNoPrimaryConfig)
class Foo extends Specification {

@SpringBean
Service2 service2 = Mock()

def foo() {
expect:
1==1
}

static class TwoService2BeansNoPrimaryConfig {

@Bean
Service2 service2instance1() {
new Service2()
}

@Bean
Service2 service2instance2() {
new Service2()
}
}
}
""")

then:
result.totalFailureCount > 0
result.failures[0].exception.message.contains("Failed to load ApplicationContext")
}

def "rejects an attempt to mock a bean with more than one primary instance"() {
setup:
def runner = new EmbeddedSpecRunner()
runner.throwFailure = false

when:
def result = runner.run("""
import org.spockframework.spring.Service2
import org.spockframework.spring.SpringBean
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Primary
import org.springframework.test.context.ContextConfiguration
import spock.lang.Specification

@ContextConfiguration(classes = TwoService2BeansBothPrimaryConfig)
class Foo extends Specification {

@SpringBean
Service2 service2 = Mock()

def foo() {
expect:
1==1
}

static class TwoService2BeansBothPrimaryConfig {

@Primary
@Bean
Service2 service2Primary1() {
new Service2()
}

@Primary
@Bean
Service2 service2Primary2() {
new Service2()
}
}
}
""")

then:
result.totalFailureCount > 0
result.failures[0].exception.message.contains("Failed to load ApplicationContext")
}

}
Loading