diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/CoroutineContextProvider.java b/src/main/kotlin/com/coxautodev/graphql/tools/CoroutineContextProvider.java new file mode 100644 index 00000000..daa0edae --- /dev/null +++ b/src/main/kotlin/com/coxautodev/graphql/tools/CoroutineContextProvider.java @@ -0,0 +1,7 @@ +package com.coxautodev.graphql.tools; + +import kotlin.coroutines.CoroutineContext; + +public interface CoroutineContextProvider { + CoroutineContext provide(); +} diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt index 02b22773..03695a3d 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt @@ -156,7 +156,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso val args = this.args.map { it(environment) }.toTypedArray() return if (isSuspendFunction) { - GlobalScope.future(options.coroutineContext) { + GlobalScope.future(options.coroutineContextProvider.provide()) { methodAccess.invokeSuspend(source, methodIndex, args)?.transformWithGenericWrapper(environment) } } else { diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt index 848968d1..749eabab 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt @@ -7,11 +7,6 @@ import com.google.common.collect.HashBiMap import com.google.common.collect.Maps import graphql.language.Definition import graphql.language.Document -import graphql.language.FieldDefinition -import graphql.language.ListType -import graphql.language.NonNullType -import graphql.language.ObjectTypeDefinition -import graphql.language.TypeName import graphql.parser.Parser import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType @@ -288,7 +283,7 @@ data class SchemaParserOptions internal constructor( val proxyHandlers: List, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean, - val coroutineContext: CoroutineContext, + val coroutineContextProvider: CoroutineContextProvider, val typeDefinitionFactories: List ) { companion object { @@ -299,6 +294,9 @@ data class SchemaParserOptions internal constructor( fun defaultOptions() = Builder().build() } + val coroutineContext: CoroutineContext + get() = coroutineContextProvider.provide() + class Builder { private var contextClass: Class<*>? = null private val genericWrappers: MutableList = mutableListOf() @@ -308,6 +306,7 @@ data class SchemaParserOptions internal constructor( private val proxyHandlers: MutableList = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler(), WeldProxyHandler()) private var preferGraphQLResolver = false private var introspectionEnabled = true + private var coroutineContextProvider: CoroutineContextProvider? = null private var coroutineContext: CoroutineContext? = null private var typeDefinitionFactories: MutableList = mutableListOf(RelayConnectionFactory()) @@ -360,7 +359,11 @@ data class SchemaParserOptions internal constructor( } fun coroutineContext(context: CoroutineContext) = this.apply { - this.coroutineContext = context + this.coroutineContextProvider = DefaultCoroutineContextProvider(context) + } + + fun coroutineContextProvider(contextProvider: CoroutineContextProvider) = this.apply { + this.coroutineContextProvider = contextProvider } fun typeDefinitionFactory(factory: TypeDefinitionFactory) = this.apply { @@ -369,7 +372,7 @@ data class SchemaParserOptions internal constructor( @ExperimentalCoroutinesApi fun build(): SchemaParserOptions { - val coroutineContext = coroutineContext ?: Dispatchers.Default + val coroutineContextProvider = coroutineContextProvider ?: DefaultCoroutineContextProvider(Dispatchers.Default) val wrappers = if (useDefaultGenericWrappers) { genericWrappers + listOf( GenericWrapper(Future::class, 0), @@ -377,7 +380,7 @@ data class SchemaParserOptions internal constructor( GenericWrapper(CompletionStage::class, 0), GenericWrapper(Publisher::class, 0), GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel -> - GlobalScope.publish(coroutineContext) { + GlobalScope.publish(coroutineContextProvider.provide()) { try { for (item in receiveChannel) { send(item) @@ -393,7 +396,13 @@ data class SchemaParserOptions internal constructor( } return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, - proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContext, typeDefinitionFactories) + proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContextProvider, typeDefinitionFactories) + } + } + + internal class DefaultCoroutineContextProvider(val coroutineContext: CoroutineContext): CoroutineContextProvider { + override fun provide(): CoroutineContext { + return coroutineContext } }