Skip to content

Commit 892b627

Browse files
committed
Add custom directives
1 parent e555789 commit 892b627

File tree

10 files changed

+504
-78
lines changed

10 files changed

+504
-78
lines changed

graphql/execute.lua

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,48 @@ local function getFieldEntry(objectType, object, fields, context)
278278
if argument and argument.value then
279279
positions[pos] = {
280280
name=argument.name.value,
281-
value=arguments[argument.name.value]
281+
value=arguments[argument.name.value],
282282
}
283283
pos = pos + 1
284284
end
285285
end
286286

287287
arguments = setmetatable(arguments, {__index=positions,})
288288

289+
local directiveMap = {}
290+
for _, directive in ipairs(firstField.directives or {}) do
291+
directiveMap[directive.name.value] = directive
292+
end
293+
294+
local directives = {}
295+
local directivesDefaultValues = {}
296+
297+
if next(directiveMap) then
298+
util.map_name(context.schema.directives or {}, function(directive, directive_name)
299+
local supplied_directive = directiveMap[directive_name]
300+
if supplied_directive ~= nil then
301+
local directiveArgumentMap = {}
302+
for _, argument in ipairs(supplied_directive.arguments or {}) do
303+
directiveArgumentMap[argument.name.value] = argument
304+
end
305+
306+
directives[directive_name] = util.map(directive.arguments or {}, function(argument, name)
307+
local supplied = directiveArgumentMap[name] and directiveArgumentMap[name].value
308+
local defaultValue = argument.defaultValue
309+
if argument.kind then argument = argument.kind end
310+
directivesDefaultValues[directive_name] = directivesDefaultValues[directive_name] or {}
311+
if defaultValue ~= nil then directivesDefaultValues[directive_name][name] = defaultValue end
312+
local res = util.coerceValue(supplied, argument, context.variables, {
313+
strict_non_null = true,
314+
defaultValues = defaultValues,
315+
})
316+
317+
return res
318+
end)
319+
end
320+
end)
321+
end
322+
289323
local info = {
290324
context = context,
291325
fieldName = fieldName,
@@ -298,6 +332,7 @@ local function getFieldEntry(objectType, object, fields, context)
298332
operation = context.operation,
299333
variableValues = context.variables,
300334
defaultValues = context.defaultValues,
335+
directives = directives,
301336
}
302337

303338
local resolvedObject, err = (fieldType.resolve or defaultResolver)(object, arguments, info)

graphql/introspection.lua

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ __Directive = types.object({
109109
if directive.onFragmentDefinition then table.insert(res, 'FRAGMENT_DEFINITION') end
110110
if directive.onFragmentSpread then table.insert(res, 'FRAGMENT_SPREAD') end
111111
if directive.onInlineFragment then table.insert(res, 'INLINE_FRAGMENT') end
112+
if directive.onVariableDefinition then table.insert(res, 'VARIABLE_DEFINITION') end
113+
if directive.onSchema then table.insert(res, 'SCHEMA') end
114+
if directive.onScalar then table.insert(res, 'SCALAR') end
115+
if directive.onObject then table.insert(res, 'OBJECT') end
116+
if directive.onFieldDefinition then table.insert(res, 'FIELD_DEFINITION') end
117+
if directive.onArgumentDefinition then table.insert(res, 'ARGUMENT_DEFINITION') end
118+
if directive.onInterface then table.insert(res, 'INTERFACE') end
119+
if directive.onUnion then table.insert(res, 'UNION') end
120+
if directive.onEnum then table.insert(res, 'ENUM') end
121+
if directive.onEnumValue then table.insert(res, 'ENUM_VALUE') end
122+
if directive.onInputObject then table.insert(res, 'INPUT_OBJECT') end
123+
if directive.onInputFieldDefinition then table.insert(res, 'INPUT_FIELD_DEFINITION') end
112124

113125
return res
114126
end,
@@ -117,7 +129,14 @@ __Directive = types.object({
117129
args = {
118130
kind = types.nonNull(types.list(types.nonNull(__InputValue))),
119131
resolve = resolveArgs,
120-
}
132+
},
133+
134+
isRepeatable = {
135+
kind = types.nonNull(types.boolean),
136+
resolve = function(directive)
137+
return directive.isRepeatable == true
138+
end
139+
},
121140
}
122141
end
123142
})
@@ -160,6 +179,66 @@ __DirectiveLocation = types.enum({
160179
value = 'INLINE_FRAGMENT',
161180
description = 'Location adjacent to an inline fragment.',
162181
},
182+
183+
VARIABLE_DEFINITION = {
184+
value = 'VARIABLE_DEFINITION',
185+
description = 'Location adjacent to a variable definition.',
186+
},
187+
188+
SCHEMA = {
189+
value = 'SCHEMA',
190+
description = 'Location adjacent to schema.',
191+
},
192+
193+
SCALAR = {
194+
value = 'SCALAR',
195+
description = 'Location adjacent to a scalar.',
196+
},
197+
198+
OBJECT = {
199+
value = 'OBJECT',
200+
description = 'Location adjacent to an object.',
201+
},
202+
203+
FIELD_DEFINITION = {
204+
value = 'FIELD_DEFINITION',
205+
description = 'Location adjacent to a field definition.',
206+
},
207+
208+
ARGUMENT_DEFINITION = {
209+
value = 'ARGUMENT_DEFINITION',
210+
description = 'Location adjacent to an argument definition.',
211+
},
212+
213+
INTERFACE = {
214+
value = 'INTERFACE',
215+
description = 'Location adjacent to an interface.',
216+
},
217+
218+
UNION = {
219+
value = 'UNION',
220+
description = 'Location adjacent to an union.',
221+
},
222+
223+
ENUM = {
224+
value = 'ENUM',
225+
description = 'Location adjacent to an enum.',
226+
},
227+
228+
ENUM_VALUE = {
229+
value = 'ENUM_VALUE',
230+
description = 'Location adjacent to an enum value.',
231+
},
232+
233+
INPUT_OBJECT = {
234+
value = 'INPUT_OBJECT',
235+
description = 'Location adjacent to an input object.',
236+
},
237+
238+
INPUT_FIELD_DEFINITION = {
239+
value = 'INPUT_FIELD_DEFINITION',
240+
description = 'Location adjacent to an input field definition.',
241+
},
163242
}
164243
})
165244

@@ -272,7 +351,7 @@ __Type = types.object({
272351
kind = __Type,
273352
},
274353
}
275-
end
354+
end,
276355
})
277356

278357
__Field = types.object({
@@ -309,7 +388,7 @@ __Field = types.object({
309388

310389
deprecationReason = types.string,
311390
}
312-
end
391+
end,
313392
})
314393

315394
__InputValue = types.object({
@@ -341,7 +420,7 @@ __InputValue = types.object({
341420
end,
342421
},
343422
}
344-
end
423+
end,
345424
})
346425

347426
__EnumValue = types.object({
@@ -359,11 +438,11 @@ __EnumValue = types.object({
359438
description = types.string,
360439
isDeprecated = {
361440
kind = types.boolean.nonNull,
362-
resolve = function(enumValue) return enumValue.deprecationReason ~= nil end
441+
resolve = function(enumValue) return enumValue.deprecationReason ~= nil end,
363442
},
364443
deprecationReason = types.string,
365444
}
366-
end
445+
end,
367446
})
368447

369448
__TypeKind = types.enum({
@@ -409,7 +488,7 @@ __TypeKind = types.enum({
409488
value = 'NON_NULL',
410489
description = 'Indicates this type is a non-null. `ofType` is a valid field.',
411490
},
412-
}
491+
},
413492
})
414493

415494
local Schema = {

graphql/rules.lua

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function rules.unambiguousSelections(node, context)
146146

147147
table.insert(selectionMap[key], entry)
148148
else
149-
selectionMap[key] = { entry }
149+
selectionMap[key] = { entry, }
150150
end
151151
end
152152

@@ -314,7 +314,7 @@ function rules.fragmentSpreadIsPossible(node, context)
314314

315315
local function getTypes(kind)
316316
if kind.__type == 'Object' then
317-
return { [kind] = kind }
317+
return { [kind] = kind, }
318318
elseif kind.__type == 'Interface' then
319319
return context.schema:getImplementors(kind.name)
320320
elseif kind.__type == 'Union' then
@@ -332,7 +332,6 @@ function rules.fragmentSpreadIsPossible(node, context)
332332
local fragmentTypes = getTypes(fragmentType)
333333

334334
local valid = util.find(parentTypes, function(kind)
335-
local kind = kind
336335
-- Here is the check that type, mentioned in '... on some_type'
337336
-- conditional fragment expression is type of some field of parent object.
338337
-- In case of Union parent object and NonNull wrapped inner types

graphql/schema.lua

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,26 @@ end
9999
function schema:generateDirectiveMap()
100100
for _, directive in ipairs(self.directives) do
101101
self.directiveMap[directive.name] = directive
102+
if directive.arguments then
103+
for name, argument in pairs(directive.arguments) do
104+
105+
-- BEGIN_HACK: resolve type names to real types
106+
if type(argument) == 'string' then
107+
argument = types.resolve(argument, self.name)
108+
directive.arguments[name] = argument
109+
end
110+
111+
if type(argument.kind) == 'string' then
112+
argument.kind = types.resolve(argument.kind, self.name)
113+
end
114+
-- END_HACK: resolve type names to real types
115+
116+
local argumentType = argument.__type and argument or argument.kind
117+
assert(argumentType, 'Must supply type for argument "' .. name .. '" on "' .. directive.name .. '"')
118+
argumentType.defaultValue = argument.defaultValue
119+
self:generateTypeMap(argumentType)
120+
end
121+
end
102122
end
103123
end
104124

graphql/types.lua

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,19 @@ function types.directive(config)
419419
onFragmentDefinition = config.onFragmentDefinition,
420420
onFragmentSpread = config.onFragmentSpread,
421421
onInlineFragment = config.onInlineFragment,
422+
onVariableDefinition = config.onVariableDefinition,
423+
onSchema = config.onSchema,
424+
onScalar = config.onScalar,
425+
onObject = config.onObject,
426+
onFieldDefinition = config.onFieldDefinition,
427+
onArgumentDefinition = config.onArgumentDefinition,
428+
onInterface = config.onInterface,
429+
onUnion = config.onUnion,
430+
onEnum = config.onEnum,
431+
onEnumValue = config.onEnumValue,
432+
onInputObject = config.onInputObject,
433+
onInputFieldDefinition = config.onInputFieldDefinition,
434+
isRepeatable = config.isRepeatable or false
422435
}
423436

424437
return instance
@@ -428,7 +441,7 @@ types.include = types.directive({
428441
name = 'include',
429442
description = 'Directs the executor to include this field or fragment only when the `if` argument is true.',
430443
arguments = {
431-
['if'] = { kind = types.boolean.nonNull, description = 'Included when true.'}
444+
['if'] = { kind = types.boolean.nonNull, description = 'Included when true.', },
432445
},
433446
onField = true,
434447
onFragmentSpread = true,
@@ -439,7 +452,7 @@ types.skip = types.directive({
439452
name = 'skip',
440453
description = 'Directs the executor to skip this field or fragment when the `if` argument is true.',
441454
arguments = {
442-
['if'] = { kind = types.boolean.nonNull, description = 'Skipped when true.' }
455+
['if'] = { kind = types.boolean.nonNull, description = 'Skipped when true.', },
443456
},
444457
onField = true,
445458
onFragmentSpread = true,

graphql/util.lua

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ local function map(t, fn)
1111
return res
1212
end
1313

14+
local function map_name(t, fn)
15+
local res = {}
16+
for _, v in ipairs(t or {}) do
17+
if v.name then
18+
res[v.name] = fn(v, v.name)
19+
end
20+
end
21+
return res
22+
end
23+
1424
local function find(t, fn)
1525
for k, v in pairs(t) do
1626
if fn(v, k) then return v end
@@ -169,17 +179,17 @@ local function coerceValue(node, schemaType, variables, opts)
169179
end
170180
end
171181

172-
--- Check whether passed value has one of listed types.
173-
---
174-
--- @param obj value to check
175-
---
176-
--- @tparam string obj_name name of the value to form an error
177-
---
178-
--- @tparam string type_1
179-
--- @tparam[opt] string type_2
180-
--- @tparam[opt] string type_3
181-
---
182-
--- @return nothing
182+
-- Check whether passed value has one of listed types.
183+
--
184+
-- @param obj value to check
185+
--
186+
-- @tparam string obj_name name of the value to form an error
187+
--
188+
-- @tparam string type_1
189+
-- @tparam[opt] string type_2
190+
-- @tparam[opt] string type_3
191+
--
192+
-- @return nothing
183193
local function check(obj, obj_name, type_1, type_2, type_3)
184194
if type(obj) == type_1 or type(obj) == type_2 or type(obj) == type_3 then
185195
return
@@ -196,15 +206,15 @@ local function check(obj, obj_name, type_1, type_2, type_3)
196206
end
197207
end
198208

199-
--- Check whether table is an array.
200-
---
201-
--- Based on [that][1] implementation.
202-
--- [1]: https://github.com/mpx/lua-cjson/blob/db122676/lua/cjson/util.lua
203-
---
204-
--- @tparam table table to check
205-
--- @return[1] `true` if passed table is an array (includes the empty table
206-
--- case)
207-
--- @return[2] `false` otherwise
209+
-- Check whether table is an array.
210+
--
211+
-- Based on [that][1] implementation.
212+
-- [1]: https://github.com/mpx/lua-cjson/blob/db122676/lua/cjson/util.lua
213+
--
214+
-- @tparam table table to check
215+
-- @return[1] `true` if passed table is an array (includes the empty table
216+
-- case)
217+
-- @return[2] `false` otherwise
208218
local function is_array(table)
209219
if type(table) ~= 'table' then
210220
return false
@@ -270,6 +280,7 @@ end
270280

271281
return {
272282
map = map,
283+
map_name = map_name,
273284
find = find,
274285
filter = filter,
275286
values = values,

0 commit comments

Comments
 (0)