Skip to content

Commit b9faa87

Browse files
authored
Merge pull request #2871 from lewis6991/fix/narrowlit
fix: type narrow on fields with multiple literals
2 parents 6ba0c93 + 1543737 commit b9faa87

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

script/vm/tracer.lua

+9-5
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ end
261261
--- @param source parser.object
262262
--- @param fieldName string
263263
--- @param literal parser.object
264-
--- @return string[]?
264+
--- @return [string, boolean][]?
265265
local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
266266
local loc = vm.getVariable(source)
267267
if not loc then
@@ -279,7 +279,7 @@ local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
279279
for _, t in ipairs(f.extends.types) do
280280
if t[1] == literal[1] then
281281
tys = tys or {}
282-
table.insert(tys, set.class[1])
282+
table.insert(tys, {set.class[1], #f.extends.types > 1})
283283
break
284284
end
285285
end
@@ -682,15 +682,19 @@ local lookIntoChild = util.switch()
682682

683683
-- TODO: handle more types
684684
if tys and #tys == 1 then
685-
local ty = tys[1]
685+
-- If the type is in a union (e.g. 'lit' | foo), then the type
686+
-- cannot be removed from the node.
687+
local ty, tyInUnion = tys[1][1], tys[1][2]
686688
topNode = topNode:copy()
687689
if action.op.type == '==' then
688690
topNode:narrow(tracer.uri, ty)
689-
if outNode then
691+
if not tyInUnion and outNode then
690692
outNode:remove(ty)
691693
end
692694
else
693-
topNode:remove(ty)
695+
if not tyInUnion then
696+
topNode:remove(ty)
697+
end
694698
if outNode then
695699
outNode:narrow(tracer.uri, ty)
696700
end

test/type_inference/common.lua

+42
Original file line numberDiff line numberDiff line change
@@ -4529,3 +4529,45 @@ if obj.type == 'a' then
45294529
local <?r?> = obj
45304530
end
45314531
]]
4532+
4533+
TEST 'A|B' [[
4534+
--- @class A
4535+
--- @field mode? 'a' | 'b'
4536+
4537+
--- @class B
4538+
4539+
local a --- @type A | B
4540+
4541+
if a.mode == 'a' then
4542+
local b = a
4543+
else
4544+
local <?b?> = a
4545+
end
4546+
]]
4547+
4548+
TEST 'A|B' [[
4549+
--- @class A
4550+
--- @field mode? 'a' | 'b'
4551+
4552+
--- @class B
4553+
4554+
local a --- @type A | B
4555+
4556+
if a.mode ~= 'a' then
4557+
local <?b?> = a
4558+
end
4559+
]]
4560+
4561+
TEST 'A' [[
4562+
--- @class A
4563+
--- @field mode? 'a' | 'b'
4564+
4565+
--- @class B
4566+
4567+
local a --- @type A | B
4568+
4569+
if a.mode ~= 'a' then
4570+
else
4571+
local <?b?> = a
4572+
end
4573+
]]

0 commit comments

Comments
 (0)