Skip to content

Commit 1831e60

Browse files
authored
Merge pull request #2864 from lewis6991/feat/narrowlit
feat: type narrow types with literal fields
2 parents 8f96025 + 5c3086a commit 1831e60

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* `NEW` Infer function parameter types when overriding the same-named class function in an instance of that class [#2158](https://github.com/LuaLS/lua-language-server/issues/2158)
77
* `FIX` Eliminate floating point error in test benchmark output
88
* `FIX` Remove luamake install from make scripts
9+
* `NEW` Types with literal fields can be narrowed.
910

1011
## 3.10.6
1112
`2024-9-10`

script/vm/tracer.lua

+59
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,44 @@ function mt:fastWardCasts(pos, node)
256256
return node
257257
end
258258

259+
--- Return types of source which have a field with the value of literal.
260+
--- @param uri uri
261+
--- @param source parser.object
262+
--- @param fieldName string
263+
--- @param literal parser.object
264+
--- @return string[]?
265+
local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
266+
local loc = vm.getVariable(source)
267+
if not loc then
268+
return
269+
end
270+
271+
local tys
272+
273+
for _, c in ipairs(vm.compileNode(loc)) do
274+
if c.cate == 'type' then
275+
for _, set in ipairs(c:getSets(uri)) do
276+
if set.type == 'doc.class' then
277+
for _, f in ipairs(set.fields) do
278+
if f.field[1] == fieldName then
279+
for _, t in ipairs(f.extends.types) do
280+
if t[1] == literal[1] then
281+
tys = tys or {}
282+
table.insert(tys, set.class[1])
283+
break
284+
end
285+
end
286+
break
287+
end
288+
end
289+
end
290+
end
291+
end
292+
end
293+
294+
return tys
295+
end
296+
259297
local lookIntoChild = util.switch()
260298
: case 'getlocal'
261299
: case 'getglobal'
@@ -637,6 +675,27 @@ local lookIntoChild = util.switch()
637675
end
638676
end
639677
end
678+
elseif handler.type == 'getfield'
679+
and handler.node.type == 'getlocal' then
680+
local tys = getNodeTypesWithLiteralField(
681+
tracer.uri, handler.node, handler.field[1], checker)
682+
683+
-- TODO: handle more types
684+
if tys and #tys == 1 then
685+
local ty = tys[1]
686+
topNode = topNode:copy()
687+
if action.op.type == '==' then
688+
topNode:narrow(tracer.uri, ty)
689+
if outNode then
690+
outNode:remove(ty)
691+
end
692+
else
693+
topNode:remove(ty)
694+
if outNode then
695+
outNode:narrow(tracer.uri, ty)
696+
end
697+
end
698+
end
640699
elseif handler.type == 'call'
641700
and checker.type == 'string'
642701
and handler.node.special == 'type'

test/type_inference/common.lua

+50
Original file line numberDiff line numberDiff line change
@@ -4453,3 +4453,53 @@ function A:func(x) end
44534453
local a = {}
44544454
function a:func(<?x?>) end
44554455
]]
4456+
4457+
TEST 'A' [[
4458+
---@class A
4459+
---@field type 'a'
4460+
---@field field1 integer
4461+
4462+
---@class B
4463+
---@field type 'b'
4464+
4465+
local obj --- @type A|B
4466+
4467+
if obj.type == 'a' and obj.field1 > 0 then
4468+
local <?r?> = obj
4469+
end
4470+
]]
4471+
4472+
TEST 'B' [[
4473+
---@class A
4474+
---@field type 'a'
4475+
4476+
---@class B
4477+
---@field type 'b'
4478+
4479+
local obj --- @type A|B
4480+
4481+
if obj.type == 'a' then
4482+
---
4483+
else
4484+
local <?r?> = obj
4485+
end
4486+
]]
4487+
4488+
TEST 'A' [[
4489+
---@class A
4490+
---@field type 'a'
4491+
4492+
---@class B
4493+
---@field type 'b'
4494+
4495+
---@class C
4496+
---@field type 'c'
4497+
4498+
---@alias AB A|B
4499+
4500+
local obj --- @type C|AB
4501+
4502+
if obj.type == 'a' then
4503+
local <?r?> = obj
4504+
end
4505+
]]

0 commit comments

Comments
 (0)