Skip to content

Update custom context class option to be compatible with new GraphQLContext class #566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ internal class SchemaClassScanner(
val methods = clazz.methods

val filteredMethods = methods.filter {
it.name == name || it.name == "get${name.capitalize()}"
it.name == name || it.name == "get${name.replaceFirstChar(Char::titlecase)}"
}.sortedBy { it.name.length }
return filteredMethods.find {
!it.isSynthetic
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package graphql.kickstart.tools.resolver

import graphql.GraphQLContext
import graphql.Scalars
import graphql.kickstart.tools.ResolverInfo
import graphql.kickstart.tools.RootResolverInfo
Expand All @@ -25,7 +26,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {

private val log = LoggerFactory.getLogger(javaClass)

private val allowedLastArgumentTypes = listOfNotNull(DataFetchingEnvironment::class.java, options.contextClass)
private val allowedLastArgumentTypes = listOfNotNull(DataFetchingEnvironment::class.java, GraphQLContext::class.java, options.contextClass)

fun findFieldResolver(field: FieldDefinition, resolverInfo: ResolverInfo): FieldResolver {
val searches = resolverInfo.getFieldSearches()
Expand Down Expand Up @@ -72,7 +73,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
}
}

private fun findResolverMethod(field: FieldDefinition, search: Search): java.lang.reflect.Method? {
private fun findResolverMethod(field: FieldDefinition, search: Search): Method? {
val methods = getAllMethods(search.type)
val argumentCount = field.inputValueDefinitions.size + if (search.requiredFirstParameterType != null) 1 else 0
val name = field.name
Expand Down Expand Up @@ -113,7 +114,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {

private fun isBoolean(type: GraphQLLangType) = type.unwrap().let { it is TypeName && it.name == Scalars.GraphQLBoolean.name }

private fun verifyMethodArguments(method: java.lang.reflect.Method, requiredCount: Int, search: Search): Boolean {
private fun verifyMethodArguments(method: Method, requiredCount: Int, search: Search): Boolean {
val appropriateFirstParameter = if (search.requiredFirstParameterType != null) {
method.genericParameterTypes.firstOrNull()?.let {
it == search.requiredFirstParameterType || method.declaringClass.typeParameters.contains(it)
Expand All @@ -130,15 +131,15 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
return correctParameterCount && appropriateFirstParameter
}

private fun getMethodParameterCount(method: java.lang.reflect.Method): Int {
private fun getMethodParameterCount(method: Method): Int {
return try {
method.kotlinFunction?.valueParameters?.size ?: method.parameterCount
} catch (e: InternalError) {
method.parameterCount
}
}

private fun getMethodLastParameter(method: java.lang.reflect.Method): Type? {
private fun getMethodLastParameter(method: Method): Type? {
return try {
method.kotlinFunction?.valueParameters?.lastOrNull()?.type?.javaType
?: method.parameterTypes.lastOrNull()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql.kickstart.tools.resolver

import com.fasterxml.jackson.core.type.TypeReference
import graphql.GraphQLContext
import graphql.TrivialDataFetcher
import graphql.kickstart.tools.*
import graphql.kickstart.tools.SchemaParserOptions.GenericWrapper
Expand All @@ -13,6 +14,7 @@ import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import graphql.schema.GraphQLTypeUtil.isScalar
import kotlinx.coroutines.future.future
import org.slf4j.LoggerFactory
import java.lang.reflect.Method
import java.util.*
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
Expand All @@ -30,9 +32,12 @@ internal class MethodFieldResolver(
val method: Method
) : FieldResolver(field, search, options, search.type) {

private val log = LoggerFactory.getLogger(javaClass)

private val additionalLastArgument =
try {
method.kotlinFunction?.valueParameters?.size ?: method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
(method.kotlinFunction?.valueParameters?.size
?: method.parameterCount) == (field.inputValueDefinitions.size + getIndexOffset() + 1)
} catch (e: InternalError) {
method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
}
Expand Down Expand Up @@ -94,7 +99,21 @@ internal class MethodFieldResolver(
if (this.additionalLastArgument) {
when (this.method.parameterTypes.last()) {
null -> throw ResolverError("Expected at least one argument but got none, this is most likely a bug with graphql-java-tools")
options.contextClass -> args.add { environment -> environment.getContext() }
options.contextClass -> args.add { environment ->
val context: Any? = environment.graphQlContext[options.contextClass]
if (context != null) {
context
} else {
log.warn(
"Generic context class has been deprecated by graphql-java. " +
"To continue using a custom context class as the last parameter in resolver methods " +
"please insert it into the GraphQLContext map when building the ExecutionInput. " +
"This warning will become an error in the future."
)
environment.getContext() // TODO: remove deprecated use in next major release
}
}
GraphQLContext::class.java -> args.add { environment -> environment.graphQlContext }
else -> args.add { environment -> environment }
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/kotlin/graphql/kickstart/tools/DirectiveTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ class DirectiveTest {
val parentType = environment.fieldsContainer

val originalDataFetcher = environment.codeRegistry.getDataFetcher(parentType, field)
val wrappedDataFetcher = DataFetcherFactories.wrapDataFetcher(originalDataFetcher, { _, value ->
(value as? String)?.toUpperCase()
})
val wrappedDataFetcher = DataFetcherFactories.wrapDataFetcher(originalDataFetcher) { _, value ->
(value as? String)?.uppercase()
}

environment.codeRegistry.dataFetcher(parentType, field, wrappedDataFetcher)

Expand Down
8 changes: 3 additions & 5 deletions src/test/kotlin/graphql/kickstart/tools/EndToEndSpecHelper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -366,25 +366,23 @@ class Mutation : GraphQLMutationResolver {
}
}

class OnItemCreatedContext(val newItem: Item)

class Subscription : GraphQLSubscriptionResolver {
fun onItemCreated(env: DataFetchingEnvironment) =
Publisher<Item> { subscriber ->
subscriber.onNext(env.getContext<OnItemCreatedContext>().newItem)
subscriber.onNext(env.graphQlContext["newItem"])
// subscriber.onComplete()
}

fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel<Item> {
val channel = Channel<Item>(1)
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
channel.trySend(env.graphQlContext["newItem"])
return channel
}

suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel<Item> {
return coroutineScope {
val channel = Channel<Item>(1)
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
channel.trySend(env.graphQlContext["newItem"])
channel
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class EndToEndTest {

val result = gql.execute(ExecutionInput.newExecutionInput()
.query(closure.invoke())
.context(OnItemCreatedContext(newItem))
.graphQLContext(mapOf("newItem" to newItem))
.variables(mapOf()))

val data = result.getData() as Publisher<ExecutionResult>
Expand Down Expand Up @@ -597,7 +597,7 @@ class EndToEndTest {

val result = gql.execute(ExecutionInput.newExecutionInput()
.query(closure.invoke())
.context(OnItemCreatedContext(newItem))
.graphQLContext(mapOf("newItem" to newItem))
.variables(mapOf()))

val data = result.getData() as Publisher<ExecutionResult>
Expand Down Expand Up @@ -625,7 +625,7 @@ class EndToEndTest {

val result = gql.execute(ExecutionInput.newExecutionInput()
.query(closure.invoke())
.context(OnItemCreatedContext(newItem))
.graphQLContext(mapOf("newItem" to newItem))
.variables(mapOf()))

val data = result.getData() as Publisher<ExecutionResult>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql.kickstart.tools

import graphql.ExecutionResult
import graphql.GraphQLContext
import graphql.execution.*
import graphql.execution.instrumentation.SimpleInstrumentation
import graphql.kickstart.tools.resolver.FieldResolverError
Expand Down Expand Up @@ -79,8 +80,8 @@ class MethodFieldResolverDataFetcherTest {
val channel = Channel<String>(10)

init {
channel.offer("A")
channel.offer("B")
channel.trySend("A")
channel.trySend("B")
}

@Suppress("UNUSED_PARAMETER")
Expand Down Expand Up @@ -176,21 +177,34 @@ class MethodFieldResolverDataFetcherTest {

@Test
fun `data fetcher passes environment if method has extra argument even if context is specified`() {
val options = SchemaParserOptions.newOptions().contextClass(ContextClass::class).build()
val resolver = createFetcher("active", options = options, resolver = object : GraphQLResolver<DataClass> {
val context = GraphQLContext.newContext().build()
val resolver = createFetcher("active", resolver = object : GraphQLResolver<DataClass> {
fun isActive(dataClass: DataClass, env: DataFetchingEnvironment): Boolean = env is DataFetchingEnvironment
})

assertEquals(resolver.get(createEnvironment(DataClass(), context = ContextClass())), true)
assertEquals(resolver.get(createEnvironment(DataClass(), context = context)), true)
}

@Test
fun `data fetcher passes context if method has extra argument and context is specified`() {
val context = ContextClass()
val context = GraphQLContext.newContext().build()
val resolver = createFetcher("active", resolver = object : GraphQLResolver<DataClass> {
fun isActive(dataClass: DataClass, ctx: GraphQLContext): Boolean {
return ctx == context
}
})

assertEquals(resolver.get(createEnvironment(DataClass(), context = context)), true)
}

@Test
fun `data fetcher passes custom context if method has extra argument and custom context is specified as part of GraphQLContext`() {
val customContext = ContextClass()
val context = GraphQLContext.of(mapOf(ContextClass::class.java to customContext))
val options = SchemaParserOptions.newOptions().contextClass(ContextClass::class).build()
val resolver = createFetcher("active", options = options, resolver = object : GraphQLResolver<DataClass> {
fun isActive(dataClass: DataClass, ctx: ContextClass): Boolean {
return ctx == context
return ctx == customContext
}
})

Expand Down Expand Up @@ -243,7 +257,7 @@ class MethodFieldResolverDataFetcherTest {
private val channel = Channel<String>(10)

init {
channel.offer("A")
channel.trySend("A")
channel.close(IllegalStateException("Channel error"))
}

Expand Down Expand Up @@ -281,11 +295,11 @@ class MethodFieldResolverDataFetcherTest {
return FieldResolverScanner(options).findFieldResolver(field, resolverInfo).createDataFetcher()
}

private fun createEnvironment(source: Any = Object(), arguments: Map<String, Any> = emptyMap(), context: Any? = null): DataFetchingEnvironment {
private fun createEnvironment(source: Any = Object(), arguments: Map<String, Any> = emptyMap(), context: GraphQLContext? = null): DataFetchingEnvironment {
return DataFetchingEnvironmentImpl.newDataFetchingEnvironment(buildExecutionContext())
.source(source)
.arguments(arguments)
.context(context)
.graphQLContext(context)
.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class MethodFieldResolverTest {
testNull(input: null)
}
""")
.context(Object())
.root(Object()))

assertEquals(result.getData(), mapOf(
Expand Down Expand Up @@ -88,7 +87,6 @@ class MethodFieldResolverTest {
testNull(input: null)
}
""")
.context(Object())
.root(Object()))

assertEquals(result.getData(), mapOf(
Expand Down Expand Up @@ -124,7 +122,6 @@ class MethodFieldResolverTest {
}
""")
.variables(mapOf("input" to "FooBar"))
.context(Object())
.root(Object()))

assertEquals(result.getData(), mapOf("test" to 6))
Expand Down Expand Up @@ -156,7 +153,6 @@ class MethodFieldResolverTest {
}
""")
.variables(mapOf("input" to listOf("Foo", "Bar")))
.context(Object())
.root(Object()))

assertEquals(result.getData(), mapOf("test" to 6))
Expand Down Expand Up @@ -204,7 +200,6 @@ class MethodFieldResolverTest {
}
""")
.variables(mapOf("input" to listOf("Foo", "Bar")))
.context(Object())
.root(Object()))

assertEquals(result.getData(), mapOf("test" to 6))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ public void testOmittedBooleanArgument() {
ExecutionResult result = gql
.execute(ExecutionInput.newExecutionInput()
.query("query { testOmittedBoolean }")
.context(new Object())
.root(new Object()));

assertTrue(result.getErrors().isEmpty());
Expand Down
52 changes: 20 additions & 32 deletions src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
package graphql.kickstart.tools

import graphql.kickstart.tools.resolver.FieldResolverError
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLArgument
import graphql.schema.GraphQLInputObjectType
import graphql.schema.GraphQLNonNull
import graphql.schema.*
import graphql.schema.idl.SchemaDirectiveWiring
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
import org.junit.Assert.assertThrows
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExpectedException
import org.springframework.aop.framework.ProxyFactory
import java.io.FileNotFoundException
import java.util.concurrent.Future

class SchemaParserTest {
private lateinit var builder: SchemaParserBuilder

@Rule
@JvmField
var expectedEx: ExpectedException = ExpectedException.none()

@Before
fun setup() {
builder = SchemaParser.newParser()
Expand Down Expand Up @@ -197,27 +188,24 @@ class SchemaParserTest {

@Test
fun `parser should throw descriptive exception when object is used as input type incorrectly`() {
expectedEx.expect(SchemaError::class.java)
expectedEx.expectMessage("Was a type only permitted for object types incorrectly used as an input type, or vice-versa")

SchemaParser.newParser()
.schemaString(
"""
type Query {
name(filter: Filter): [String]
}

type Filter {
filter: String
}
""")
.resolvers(object : GraphQLQueryResolver {
fun name(filter: Filter): List<String>? = null
})
.build()
.makeExecutableSchema()

throw AssertionError("should not be called")
assertThrows("Was a type only permitted for object types incorrectly used as an input type, or vice-versa", SchemaError::class.java) {
SchemaParser.newParser()
.schemaString(
"""
type Query {
name(filter: Filter): [String]
}

type Filter {
filter: String
}
""")
.resolvers(object : GraphQLQueryResolver {
fun name(filter: Filter): List<String>? = null
})
.build()
.makeExecutableSchema()
}
}

@Test
Expand Down
Loading