Skip to content

Commit b22c327

Browse files
committed
fix: type narrow on fields with multiple literals
1 parent 9b33baa commit b22c327

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

script/vm/tracer.lua

+7-5
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ end
261261
--- @param source parser.object
262262
--- @param fieldName string
263263
--- @param literal parser.object
264-
--- @return string[]?
264+
--- @return [string, boolean][]?
265265
local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
266266
local loc = vm.getVariable(source)
267267
if not loc then
@@ -279,7 +279,9 @@ local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
279279
for _, t in ipairs(f.extends.types) do
280280
if t[1] == literal[1] then
281281
tys = tys or {}
282-
table.insert(tys, set.class[1])
282+
-- If the type is in a union (e.g. 'lit' | foo), then the outNode
283+
-- cannot be narrowed.
284+
table.insert(tys, {set.class[1], #f.extends.types > 1})
283285
break
284286
end
285287
end
@@ -682,16 +684,16 @@ local lookIntoChild = util.switch()
682684

683685
-- TODO: handle more types
684686
if tys and #tys == 1 then
685-
local ty = tys[1]
687+
local ty, tyInUnion = tys[1][1], tys[1][2]
686688
topNode = topNode:copy()
687689
if action.op.type == '==' then
688690
topNode:narrow(tracer.uri, ty)
689-
if outNode then
691+
if not tyInUnion and outNode then
690692
outNode:remove(ty)
691693
end
692694
else
693695
topNode:remove(ty)
694-
if outNode then
696+
if not tyInUnion and outNode then
695697
outNode:narrow(tracer.uri, ty)
696698
end
697699
end

test/type_inference/common.lua

+15
Original file line numberDiff line numberDiff line change
@@ -4529,3 +4529,18 @@ if obj.type == 'a' then
45294529
local <?r?> = obj
45304530
end
45314531
]]
4532+
4533+
TEST 'A|B' [[
4534+
--- @class A
4535+
--- @field mode? 'a' | 'b'
4536+
4537+
--- @class B
4538+
4539+
local a --- @type A | B
4540+
4541+
if a.mode == 'a' then
4542+
local b = a
4543+
else
4544+
local <?b?> = a
4545+
end
4546+
]]

0 commit comments

Comments
 (0)