diff --git a/graphql/execute.lua b/graphql/execute.lua index 181b0c9..6cf053d 100644 --- a/graphql/execute.lua +++ b/graphql/execute.lua @@ -2,20 +2,8 @@ local path = (...):gsub('%.[^%.]+$', '') local types = require(path .. '.types') local util = require(path .. '.util') local introspection = require(path .. '.introspection') - -local function typeFromAST(node, schema) - local innerType - if node.kind == 'listType' then - innerType = typeFromAST(node.type) - return innerType and types.list(innerType) - elseif node.kind == 'nonNullType' then - innerType = typeFromAST(node.type) - return innerType and types.nonNull(innerType) - else - assert(node.kind == 'namedType', 'Variable must be a named type') - return schema:getType(node.name.value) - end -end +local query_util = require(path .. '.query_util') +local validate_variables = require(path .. '.validate_variables') local function getFieldResponseKey(field) return field.alias and field.alias.name.value or field.name.value @@ -49,7 +37,7 @@ end local function doesFragmentApply(fragment, type, context) if not fragment.typeCondition then return true end - local innerType = typeFromAST(fragment.typeCondition, context.schema) + local innerType = query_util.typeFromAST(fragment.typeCondition, context.schema) if innerType == type then return true @@ -82,38 +70,69 @@ local function defaultResolver(object, arguments, info) return object[info.fieldASTs[1].name.value] end -local function buildContext(schema, tree, rootValue, variables, operationName) - local context = { - schema = schema, - rootValue = rootValue, - variables = variables, - operation = nil, - fragmentMap = {} - } +local function getOperation(tree, operationName) + local operation - for _, definition in ipairs(tree.definitions) do - if definition.kind == 'operation' then - if not operationName and context.operation then - error('Operation name must be specified if more than one operation exists.') - end + for _, definition in ipairs(tree.definitions) do + if definition.kind == 'operation' then + if not operationName and operation then + error('Operation name must be specified if more than one operation exists.') + end - if not operationName or definition.name.value == operationName then - context.operation = definition - end - elseif definition.kind == 'fragmentDefinition' then - context.fragmentMap[definition.name.value] = definition + if not operationName or definition.name.value == operationName then + operation = definition + end + end end - end - if not context.operation then - if operationName then - error('Unknown operation "' .. operationName .. '"') - else - error('Must provide an operation') + if not operation then + if operationName then + error('Unknown operation "' .. operationName .. '"') + else + error('Must provide an operation') + end end - end - return context + return operation +end + +local function getFragmentDefinitions(tree) + local fragmentMap = {} + + for _, definition in ipairs(tree.definitions) do + if definition.kind == 'fragmentDefinition' then + fragmentMap[definition.name.value] = definition + end + end + + return fragmentMap +end + +-- Extract variableTypes from the operation. +local function getVariableTypes(schema, operation) + local variableTypes = {} + + for _, definition in ipairs(operation.variableDefinitions or {}) do + variableTypes[definition.variable.name.value] = + query_util.typeFromAST(definition.type, schema) + end + + return variableTypes +end + +local function buildContext(schema, tree, rootValue, variables, operationName) + local operation = getOperation(tree, operationName) + local fragmentMap = getFragmentDefinitions(tree) + local variableTypes = getVariableTypes(schema, operation) + return { + schema = schema, + rootValue = rootValue, + variables = variables, + operation = operation, + fragmentMap = fragmentMap, + variableTypes = variableTypes, + request_cache = {}, + } end local function collectFields(objectType, selections, visitedFragments, result, context) @@ -247,5 +266,7 @@ return function(schema, tree, rootValue, variables, operationName) error('Unsupported operation "' .. context.operation.operation .. '"') end + validate_variables.validate_variables(context) + return evaluateSelections(rootType, rootValue, context.operation.selectionSet.selections, context) end diff --git a/graphql/query_util.lua b/graphql/query_util.lua new file mode 100644 index 0000000..9b0524f --- /dev/null +++ b/graphql/query_util.lua @@ -0,0 +1,20 @@ +local path = (...):gsub('%.[^%.]+$', '') +local types = require(path .. '.types') + +local function typeFromAST(node, schema) + local innerType + if node.kind == 'listType' then + innerType = typeFromAST(node.type, schema) + return innerType and types.list(innerType) + elseif node.kind == 'nonNullType' then + innerType = typeFromAST(node.type, schema) + return innerType and types.nonNull(innerType) + else + assert(node.kind == 'namedType', 'Variable must be a named type') + return schema:getType(node.name.value) + end +end + +return { + typeFromAST = typeFromAST, +} diff --git a/graphql/rules.lua b/graphql/rules.lua index 61005ea..669d961 100644 --- a/graphql/rules.lua +++ b/graphql/rules.lua @@ -1,7 +1,7 @@ local path = (...):gsub('%.[^%.]+$', '') local types = require(path .. '.types') local util = require(path .. '.util') -local schema = require(path .. '.schema') +local query_util = require(path .. '.query_util') local introspection = require(path .. '.introspection') local function getParentField(context, name, count) @@ -475,22 +475,8 @@ function rules.variableUsageAllowed(node, context) local variableName = argument.value.name.value local variableDefinition = variableMap[variableName] local hasDefault = variableDefinition.defaultValue ~= nil - - local function typeFromAST(variable) - local innerType - if variable.kind == 'listType' then - innerType = typeFromAST(variable.type) - return innerType and types.list(innerType) - elseif variable.kind == 'nonNullType' then - innerType = typeFromAST(variable.type) - return innerType and types.nonNull(innerType) - else - assert(variable.kind == 'namedType', 'Variable must be a named type') - return context.schema:getType(variable.name.value) - end - end - - local variableType = typeFromAST(variableDefinition.type) + local variableType = query_util.typeFromAST(variableDefinition.type, + context.schema) if hasDefault and variableType.__type ~= 'NonNull' then variableType = types.nonNull(variableType) diff --git a/graphql/types.lua b/graphql/types.lua index e24a30d..b792f6c 100644 --- a/graphql/types.lua +++ b/graphql/types.lua @@ -1,8 +1,29 @@ +local ffi = require('ffi') local path = (...):gsub('%.[^%.]+$', '') local util = require(path .. '.util') local types = {} +local function initFields(kind, fields) + assert(type(fields) == 'table', 'fields table must be provided') + + local result = {} + + for fieldName, field in pairs(fields) do + field = field.__type and { kind = field } or field + result[fieldName] = { + name = fieldName, + kind = field.kind, + description = field.description, + deprecationReason = field.deprecationReason, + arguments = field.arguments or {}, + resolve = kind == 'Object' and field.resolve or nil + } + end + + return result +end + function types.nonNull(kind) assert(kind, 'Must provide a type') @@ -21,6 +42,15 @@ function types.list(kind) } end +function types.nullable(kind) + assert(type(kind) == 'table', 'kind must be a table, got ' .. type(kind)) + + if kind.__type ~= 'NonNull' then return kind end + + assert(kind.ofType ~= nil, 'kind.ofType must not be nil') + return types.nullable(kind.ofType) +end + function types.scalar(config) assert(type(config.name) == 'string', 'type name must be provided as a string') assert(type(config.serialize) == 'function', 'serialize must be a function') @@ -37,7 +67,8 @@ function types.scalar(config) description = config.description, serialize = config.serialize, parseValue = config.parseValue, - parseLiteral = config.parseLiteral + parseLiteral = config.parseLiteral, + isValueOfTheType = config.isValueOfTheType, } instance.nonNull = types.nonNull(instance) @@ -99,26 +130,6 @@ function types.interface(config) return instance end -function initFields(kind, fields) - assert(type(fields) == 'table', 'fields table must be provided') - - local result = {} - - for fieldName, field in pairs(fields) do - field = field.__type and { kind = field } or field - result[fieldName] = { - name = fieldName, - kind = field.kind, - description = field.description, - deprecationReason = field.deprecationReason, - arguments = field.arguments or {}, - resolve = kind == 'Object' and field.resolve or nil - } - end - - return result -end - function types.enum(config) assert(type(config.name) == 'string', 'type name must be provided as a string') assert(type(config.values) == 'table', 'values table must be provided') @@ -189,14 +200,30 @@ function types.inputObject(config) return instance end -local coerceInt = function(value) - value = tonumber(value) - - if not value then return end +-- Based on the code from tarantool/checks. +local function isInt(value) + if type(value) == 'number' then + return value >= -2^31 and value < 2^31 and math.floor(value) == value + end - if value == value and value < 2 ^ 32 and value >= -2 ^ 32 then - return value < 0 and math.ceil(value) or math.floor(value) + if type(value) == 'cdata' then + if ffi.istype('int64_t', value) then + return value >= -2^31 and value < 2^31 + elseif ffi.istype('uint64_t', value) then + return value < 2^31 + end end + + return false +end + +local function coerceInt(value) + local value = tonumber(value) + + if value == nil then return end + if not isInt(value) then return end + + return value end types.int = types.scalar({ @@ -208,7 +235,8 @@ types.int = types.scalar({ if node.kind == 'int' then return coerceInt(node.value) end - end + end, + isValueOfTheType = isInt, }) types.float = types.scalar({ @@ -219,7 +247,10 @@ types.float = types.scalar({ if node.kind == 'float' or node.kind == 'int' then return tonumber(node.value) end - end + end, + isValueOfTheType = function(value) + return type(value) == 'number' + end, }) types.string = types.scalar({ @@ -231,7 +262,10 @@ types.string = types.scalar({ if node.kind == 'string' then return node.value end - end + end, + isValueOfTheType = function(value) + return type(value) == 'string' + end, }) local function toboolean(x) @@ -249,7 +283,10 @@ types.boolean = types.scalar({ else return nil end - end + end, + isValueOfTheType = function(value) + return type(value) == 'boolean' + end, }) types.id = types.scalar({ @@ -258,7 +295,10 @@ types.id = types.scalar({ parseValue = tostring, parseLiteral = function(node) return node.kind == 'string' or node.kind == 'int' and node.value or nil - end + end, + isValueOfTheType = function(value) + error('Not yet implemented') + end, }) function types.directive(config) diff --git a/graphql/util.lua b/graphql/util.lua index 45aa3b1..81ce9db 100644 --- a/graphql/util.lua +++ b/graphql/util.lua @@ -1,21 +1,20 @@ -local util = {} +local yaml = require('yaml').new({encode_use_tostring = true}) -function util.map(t, fn) +local function map(t, fn) local res = {} for k, v in pairs(t) do res[k] = fn(v, k) end return res end -function util.find(t, fn) - local res = {} +local function find(t, fn) for k, v in pairs(t) do if fn(v, k) then return v end end end -function util.filter(t, fn) +local function filter(t, fn) local res = {} - for k,v in pairs(t) do + for _,v in pairs(t) do if fn(v) then table.insert(res, v) end @@ -23,7 +22,7 @@ function util.filter(t, fn) return res end -function util.values(t) +local function values(t) local res = {} for _, value in pairs(t) do table.insert(res, value) @@ -31,31 +30,56 @@ function util.values(t) return res end -function util.compose(f, g) +local function compose(f, g) return function(...) return f(g(...)) end end -function util.bind1(func, x) +local function bind1(func, x) return function(y) return func(x, y) end end -function util.trim(s) - return s:gsub('^%s+', ''):gsub('%s$', ''):gsub('%s%s+', ' ') +local function trim(s) + return s:gsub('^%s+', ''):gsub('%s+$', ''):gsub('%s%s+', ' ') +end + +local function getTypeName(t) + if t.name ~= nil then + return t.name + elseif t.__type == 'NonNull' then + return ('NonNull(%s)'):format(getTypeName(t.ofType)) + elseif t.__type == 'List' then + return ('List(%s)'):format(getTypeName(t.ofType)) + end + + local err = ('Internal error: unknown type:\n%s'):format(yaml.encode(t)) + error(err) end -function util.coerceValue(node, schemaType, variables) +local function coerceValue(node, schemaType, variables, opts) variables = variables or {} + opts = opts or {} + local strict_non_null = opts.strict_non_null or false if schemaType.__type == 'NonNull' then - return util.coerceValue(node, schemaType.ofType, variables) + local res = coerceValue(node, schemaType.ofType, variables, opts) + if strict_non_null and res == nil then + error(('Expected non-null for "%s", got null'):format( + getTypeName(schemaType))) + end + return res end if not node then return nil end + -- handle precompiled values + if node.compiled ~= nil then + return node.compiled + end + if node.kind == 'variable' then return variables[node.name.value] end @@ -65,32 +89,50 @@ function util.coerceValue(node, schemaType, variables) error('Expected a list') end - return util.map(node.values, function(value) - return util.coerceValue(value, schemaType.ofType, variables) + return map(node.values, function(value) + return coerceValue(value, schemaType.ofType, variables, opts) end) end - if schemaType.__type == 'InputObject' then + local isInputObject = schemaType.__type == 'InputObject' + if isInputObject then if node.kind ~= 'inputObject' then error('Expected an input object') end - return util.map(node.values, function(field) - if not schemaType.fields[field.name] then - error('Unknown input object field "' .. field.name .. '"') + -- check all fields: as from value as well as from schema + local fieldNameSet = {} + local fieldValues = {} + for _, field in ipairs(node.values) do + fieldNameSet[field.name] = true + fieldValues[field.name] = field.value + end + for fieldName, _ in pairs(schemaType.fields) do + fieldNameSet[fieldName] = true + end + + local inputObjectValue = {} + for fieldName, _ in pairs(fieldNameSet) do + if not schemaType.fields[fieldName] then + error(('Unknown input object field "%s"'):format(fieldName)) end - return util.coerceValue(field.value, schemaType.fields[field.name].kind, variables) - end) + local childValue = fieldValues[fieldName] + local childType = schemaType.fields[fieldName].kind + inputObjectValue[fieldName] = coerceValue(childValue, childType, + variables, opts) + end + + return inputObjectValue end if schemaType.__type == 'Enum' then if node.kind ~= 'enum' then - error('Expected enum value, got ' .. node.kind) + error(('Expected enum value, got %s'):format(node.kind)) end if not schemaType.values[node.value] then - error('Invalid enum value "' .. node.value .. '"') + error(('Invalid enum value "%s"'):format(node.value)) end return node.value @@ -98,11 +140,83 @@ function util.coerceValue(node, schemaType, variables) if schemaType.__type == 'Scalar' then if schemaType.parseLiteral(node) == nil then - error('Could not coerce "' .. tostring(node.value) .. '" to "' .. schemaType.name .. '"') + error(('Could not coerce "%s" to "%s"'):format( + tostring(node.value), schemaType.name)) end return schemaType.parseLiteral(node) end end -return util +--- Check whether passed value has one of listed types. +--- +--- @param obj value to check +--- +--- @tparam string obj_name name of the value to form an error +--- +--- @tparam string type_1 +--- @tparam[opt] string type_2 +--- @tparam[opt] string type_3 +--- +--- @return nothing +local function check(obj, obj_name, type_1, type_2, type_3) + if type(obj) == type_1 or type(obj) == type_2 or type(obj) == type_3 then + return + end + + if type_3 ~= nil then + error(('%s must be a %s or a % or a %s, got %s'):format(obj_name, + type_1, type_2, type_3, type(obj))) + elseif type_2 ~= nil then + error(('%s must be a %s or a %s, got %s'):format(obj_name, type_1, + type_2, type(obj))) + else + error(('%s must be a %s, got %s'):format(obj_name, type_1, type(obj))) + end +end + +--- Check whether table is an array. +--- +--- Based on [that][1] implementation. +--- [1]: https://github.com/mpx/lua-cjson/blob/db122676/lua/cjson/util.lua +--- +--- @tparam table table to check +--- @return[1] `true` if passed table is an array (includes the empty table +--- case) +--- @return[2] `false` otherwise +local function is_array(table) + check(table, 'table', 'table') + + local max = 0 + local count = 0 + for k, _ in pairs(table) do + if type(k) == 'number' then + if k > max then + max = k + end + count = count + 1 + else + return false + end + end + if max > count * 2 then + return false + end + + return max >= 0 +end + +return { + map = map, + find = find, + filter = filter, + values = values, + compose = compose, + bind1 = bind1, + trim = trim, + getTypeName = getTypeName, + coerceValue = coerceValue, + + is_array = is_array, + check = check, +} diff --git a/graphql/validate_variables.lua b/graphql/validate_variables.lua new file mode 100644 index 0000000..d4f6f89 --- /dev/null +++ b/graphql/validate_variables.lua @@ -0,0 +1,112 @@ +local path = (...):gsub('%.[^%.]+$', '') +local types = require(path .. '.types') +local util = require(path .. '.util') +local check = util.check + +-- Traverse type more or less likewise util.coerceValue do. +local function checkVariableValue(variableName, value, variableType) + check(variableName, 'variableName', 'string') + check(variableType, 'variableType', 'table') + + local isNonNull = variableType.__type == 'NonNull' + + if isNonNull then + variableType = types.nullable(variableType) + if value == nil then + error(('Variable "%s" expected to be non-null'):format(variableName)) + end + end + + local isList = variableType.__type == 'List' + local isScalar = variableType.__type == 'Scalar' + local isInputObject = variableType.__type == 'InputObject' + + -- Nullable variable type + null value case: value can be nil only when + -- isNonNull is false. + if value == nil then return end + + if isList then + if type(value) ~= 'table' then + error(('Variable "%s" for a List must be a Lua ' .. + 'table, got %s'):format(variableName, type(value))) + end + if not util.is_array(value) then + error(('Variable "%s" for a List must be an array, ' .. + 'got map'):format(variableName)) + end + assert(variableType.ofType ~= nil, 'variableType.ofType must not be nil') + for i, item in ipairs(value) do + local itemName = variableName .. '[' .. tostring(i) .. ']' + checkVariableValue(itemName, item, variableType.ofType) + end + return + end + + if isInputObject then + if type(value) ~= 'table' then + error(('Variable "%s" for the InputObject "%s" must ' .. + 'be a Lua table, got %s'):format(variableName, variableType.name, + type(value))) + end + + -- check all fields: as from value as well as from schema + local fieldNameSet = {} + for fieldName, _ in pairs(value) do + fieldNameSet[fieldName] = true + end + for fieldName, _ in pairs(variableType.fields) do + fieldNameSet[fieldName] = true + end + + for fieldName, _ in pairs(fieldNameSet) do + local fieldValue = value[fieldName] + if type(fieldName) ~= 'string' then + error(('Field key of the variable "%s" for the ' .. + 'InputObject "%s" must be a string, got %s'):format(variableName, + variableType.name, type(fieldName))) + end + if type(variableType.fields[fieldName]) == 'nil' then + error(('Unknown field "%s" of the variable "%s" ' .. + 'for the InputObject "%s"'):format(fieldName, variableName, + variableType.name)) + end + + local childType = variableType.fields[fieldName].kind + local childName = variableName .. '.' .. fieldName + checkVariableValue(childName, fieldValue, childType) + end + + return + end + + if isScalar then + check(variableType.isValueOfTheType, 'isValueOfTheType', 'function') + if not variableType.isValueOfTheType(value) then + error(('Wrong variable "%s" for the Scalar "%s"'):format( + variableName, variableType.name)) + end + return + end + + error(('Unknown type of the variable "%s"'):format(variableName)) +end + +local function validate_variables(context) + -- check that all variable values have corresponding variable declaration + for variableName, _ in pairs(context.variables or {}) do + if context.variableTypes[variableName] == nil then + error(('There is no declaration for the variable "%s"') + :format(variableName)) + end + end + + -- check that variable values have correct type + for variableName, variableType in pairs(context.variableTypes) do + local value = (context.variables or {})[variableName] + checkVariableValue(variableName, value, variableType) + end +end + +return { + validate_variables = validate_variables, +}