diff --git a/changelog.md b/changelog.md index 891ebd5a0..7dc8d0ce6 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,14 @@ * `FIX` Improve type narrow with **literal alias type** during completion and signature help * `NEW` Setting: `Lua.type.inferTableSize`: A Small Table array can be infered * `NEW` Add custom repository support for addonManager. New configuration setting: `Lua.addonManager.repositoryBranch` and `Lua.addonManager.repositoryPath` +* `NEW` Infer function parameter types when the function is used as an callback argument and that argument has a `fun()` annotation. Enable with `Lua.type.inferParamType` setting. [#2695](https://github.com/LuaLS/lua-language-server/pull/2695) + ```lua + ---@param callback fun(a: integer) + function register(callback) end + + local function callback(a) end --> a: integer + register(callback) + ``` ## 3.12.0 `2024-10-30` diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index a3053f589..55406da68 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1209,7 +1209,56 @@ local function compileFunctionParam(func, source) end ::continue:: end - return found + if found then + return true + end + -- infer local callback function param type + --[[ + ---@param callback fun(a: integer) + function register(callback) end + + local function callback(a) end --> a: integer + register(callback) + ]] + for _, ref in ipairs(refs) do + if ref.parent.type ~= 'callargs' then + goto continue + end + -- the parent function is a variable used as callback param, find the callback arg index first + local call = ref.parent.parent + local cbIndex + for i, arg in ipairs(call.args) do + if arg == ref then + cbIndex = i + break + end + end + ---@cast cbIndex integer + + -- simulate a completion at `cbIndex` to infer this callback function type + ---@diagnostic disable-next-line: missing-fields + local node = vm.compileCallArg({ type = 'dummyarg', uri = guide.getUri(call) }, call, cbIndex) + if not node then + goto continue + end + for n in node:eachObject() do + -- check if the inferred function has arg at `aindex` + if n.type == 'doc.type.function' and n.args and n.args[aindex] then + -- use type info on this `aindex` arg + local argNode = vm.compileNode(n.args[aindex]) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + found = true + end + end + end + end + ::continue:: + end + if found then + return true + end end do diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index fa56fbe56..69e83508a 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -1135,6 +1135,80 @@ xpcall(work, debug.traceback, function () end) ]] +config.set(nil, "Lua.type.inferParamType", true) + +TEST 'Class' [[ +---@class Class + +---@param callback fun(value: Class) +function work(callback) end + +local function cb() end +work(cb) +]] + +TEST 'any' [[ +---@class Class + +---@param callback fun(value: Class) +function work(callback) end + +---@param value any +local function cb() end +work(cb) +]] + +TEST 'any' [[ +---@class Class + +function work(callback) end + +local function cb() end +work(cb) +]] + +TEST 'string' [[ +---@class Class + +function work(callback) end + +---@param value string +local function cb() end +work(cb) +]] + + +TEST 'Parent' [[ +---@class Parent +local Parent + +---@generic T +---@param self T +---@param callback fun(self: T) +function Parent:work(callback) end + +local function cb() end +Parent:work(cb) +]] + +TEST 'Child' [[ +---@class Parent +local Parent + +---@generic T +---@param self T +---@param callback fun(self: T) +function Parent:work(callback) end + +---@class Child: Parent +local Child + +local function cb() end +Child:work(cb) +]] + +config.set(nil, "Lua.type.inferParamType", false) + TEST 'string' [[ ---@generic T ---@param x T