Skip to content

Commit bcb673d

Browse files
committed
Sema: resolve union payload switch captures with peer type resolution
This is a bit harder than it seems at first glance. Resolving the type is the easy part; the interesting part is getting the capture value. We split this into three cases:

* If all payload types are the same (as is required in status quo), we can just do what we already do: get the first field value.
* If all payloads are in-memory coercible to the resolved type, we still fetch the first field, but we also emit a `bitcast` to convert to the resolved type.
* Otherwise, we need to handle each case separately. We emit a nested `switch_br` which, for each possible case, gets the corresponding union field and coerces it to the resolved type. As an optimization, the inner switch's 'else' prong is used for any peer which is in-memory coercible to the target type, and the bitcast approach described above is used.

Pointer captures have the additional constraint that all payload types must be in-memory coercible to the resolved type.

Resolves: #2812
1 parent 85e94fe commit bcb673d

File tree

3 files changed

+347
-37
lines changed

3 files changed

+347
-37
lines changed

src/Sema.zig

Lines changed: 252 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,6 +2392,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError {
23922392
return error.AnalysisFail;
23932393
}
23942394

2395+
/// Given an ErrorMsg, modify its message and source location to the given values, turning the
/// original message into a note. Notes on the original message are preserved as further notes.
/// Reference trace is preserved.
///
/// `src` is resolved against the decl of `block`. `format` and `args` produce the new top-level
/// message text, which is allocated with `mod.gpa` and stored in `msg`. If any allocation fails,
/// the error is returned and `msg` is left unmodified.
fn reparentOwnedErrorMsg(
    sema: *Sema,
    block: *Block,
    src: LazySrcLoc,
    msg: *Module.ErrorMsg,
    comptime format: []const u8,
    args: anytype,
) !void {
    const mod = sema.mod;
    const src_decl = mod.declPtr(block.src_decl);
    const resolved_src = src.toSrcLoc(src_decl, mod);
    const msg_str = try std.fmt.allocPrint(mod.gpa, format, args);
    // `msg` does not take the new string until the end of this function, so it must be
    // released here if growing the note list below fails.
    errdefer mod.gpa.free(msg_str);

    // Grow the note list by one and shift the existing notes up a slot, so that the original
    // message can be inserted as the first note.
    const orig_notes = msg.notes.len;
    msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1);
    std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]);
    msg.notes[0] = .{
        .src_loc = msg.src_loc,
        .msg = msg.msg,
    };

    // Nothing below can fail: install the new location and message text.
    msg.src_loc = resolved_src;
    msg.msg = msg_str;
}
2422+
23952423
const align_ty = Type.u29;
23962424

23972425
fn analyzeAsAlign(
@@ -10082,6 +10110,8 @@ const SwitchProngAnalysis = struct {
1008210110
operand: Air.Inst.Ref,
1008310111
/// May be `undefined` if no prong has a by-ref capture.
1008410112
operand_ptr: Air.Inst.Ref,
10113+
/// The switch condition value. For unions, `operand` is the union and `cond` is its tag.
10114+
cond: Air.Inst.Ref,
1008510115
/// If this switch is on an error set, this is the type to assign to the
1008610116
/// `else` prong. If `null`, the prong should be unreachable.
1008710117
else_error_ty: ?Type,
@@ -10315,61 +10345,245 @@ const SwitchProngAnalysis = struct {
1031510345
const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?);
1031610346
const first_field = union_obj.fields.values()[first_field_index];
1031710347

10318-
for (case_vals[1..], 0..) |item, i| {
10348+
const field_tys = try sema.arena.alloc(Type, case_vals.len);
10349+
for (case_vals, field_tys) |item, *field_ty| {
1031910350
const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable;
10351+
const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?);
10352+
field_ty.* = union_obj.fields.values()[field_idx].ty;
10353+
}
1032010354

10321-
const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?;
10322-
const field = union_obj.fields.values()[field_index];
10323-
if (!field.ty.eql(first_field.ty, mod)) {
10324-
const msg = msg: {
10325-
const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none);
10355+
// Fast path: if all the operands are the same type already, we don't need to hit
10356+
// PTR! This will also allow us to emit simpler code.
10357+
const same_types = for (field_tys[1..]) |field_ty| {
10358+
if (!field_ty.eql(field_tys[0], sema.mod)) break false;
10359+
} else true;
1032610360

10327-
const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
10328-
errdefer msg.destroy(sema.gpa);
10361+
const capture_ty = if (same_types) field_tys[0] else capture_ty: {
10362+
// We need values to run PTR on, so make a bunch of undef constants.
10363+
const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
10364+
for (dummy_captures, field_tys) |*dummy, field_ty| {
10365+
dummy.* = try sema.addConstUndef(field_ty);
10366+
}
10367+
10368+
const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len);
10369+
@memset(case_srcs, .unneeded);
1032910370

10371+
break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) {
10372+
error.NeededSourceLocation => {
1033010373
// This must be a multi-prong so this must be a `multi_capture` src
1033110374
const multi_idx = raw_capture_src.multi_capture;
10375+
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
10376+
for (case_srcs, 0..) |*case_src, i| {
10377+
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
10378+
case_src.* = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
10379+
}
10380+
const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
10381+
_ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) {
10382+
error.AnalysisFail => {
10383+
const msg = sema.err orelse return error.AnalysisFail;
10384+
try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{});
10385+
return error.AnalysisFail;
10386+
},
10387+
else => |e| return e,
10388+
};
10389+
unreachable;
10390+
},
10391+
else => |e| return e,
10392+
};
10393+
};
1033210394

10333-
const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } };
10334-
const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
10335-
const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } };
10336-
const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
10337-
try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)});
10338-
try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)});
10339-
break :msg msg;
10340-
};
10341-
return sema.failWithOwnedErrorMsg(msg);
10342-
}
10343-
}
10344-
10395+
// By-reference captures have some further restrictions which make them easier to emit
1034510396
if (capture_byref) {
10346-
const field_ty_ptr = try Type.ptr(sema.arena, mod, .{
10347-
.pointee_type = first_field.ty,
10348-
.@"addrspace" = .generic,
10349-
.mutable = operand_ptr_ty.ptrIsMutable(mod),
10397+
const operand_ptr_info = operand_ptr_ty.ptrInfo(mod);
10398+
const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
10399+
.pointee_type = capture_ty,
10400+
.@"addrspace" = operand_ptr_info.@"addrspace",
10401+
.mutable = operand_ptr_info.mutable,
10402+
.@"volatile" = operand_ptr_info.@"volatile",
10403+
// TODO: alignment!
1035010404
});
1035110405

10406+
// By-ref captures of heterogeneous types are only allowed if each field
10407+
// pointer type is in-memory coercible to the capture pointer type.
10408+
if (!same_types) {
10409+
for (field_tys, 0..) |field_ty, i| {
10410+
const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
10411+
.pointee_type = field_ty,
10412+
.@"addrspace" = operand_ptr_info.@"addrspace",
10413+
.mutable = operand_ptr_info.mutable,
10414+
.@"volatile" = operand_ptr_info.@"volatile",
10415+
// TODO: alignment!
10416+
});
10417+
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
10418+
const multi_idx = raw_capture_src.multi_capture;
10419+
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
10420+
const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
10421+
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
10422+
const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
10423+
const msg = msg: {
10424+
const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
10425+
errdefer msg.destroy(sema.gpa);
10426+
try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{
10427+
field_ty.fmt(sema.mod),
10428+
capture_ty.fmt(sema.mod),
10429+
});
10430+
try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{});
10431+
break :msg msg;
10432+
};
10433+
return sema.failWithOwnedErrorMsg(msg);
10434+
}
10435+
}
10436+
}
10437+
1035210438
if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| {
10353-
return sema.addConstant(field_ty_ptr, (try mod.intern(.{ .ptr = .{
10354-
.ty = field_ty_ptr.toIntern(),
10355-
.addr = .{ .field = .{
10356-
.base = op_ptr_val.toIntern(),
10357-
.index = first_field_index,
10358-
} },
10359-
} })).toValue());
10439+
if (op_ptr_val.isUndef(mod)) return sema.addConstUndef(capture_ptr_ty);
10440+
return sema.addConstant(
10441+
capture_ptr_ty,
10442+
(try mod.intern(.{ .ptr = .{
10443+
.ty = capture_ptr_ty.toIntern(),
10444+
.addr = .{ .field = .{
10445+
.base = op_ptr_val.toIntern(),
10446+
.index = first_field_index,
10447+
} },
10448+
} })).toValue(),
10449+
);
1036010450
}
10451+
1036110452
try sema.requireRuntimeBlock(block, operand_src, null);
10362-
return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr);
10453+
return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty);
1036310454
}
1036410455

1036510456
if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| {
10366-
return sema.addConstant(
10367-
first_field.ty,
10368-
mod.intern_pool.indexToKey(operand_val.toIntern()).un.val.toValue(),
10369-
);
10457+
if (operand_val.isUndef(mod)) return sema.addConstUndef(capture_ty);
10458+
const union_val = mod.intern_pool.indexToKey(operand_val.toIntern()).un;
10459+
if (union_val.tag.toValue().isUndef(mod)) return sema.addConstUndef(capture_ty);
10460+
const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag.toValue(), sema.mod).?);
10461+
const field_ty = union_obj.fields.values()[active_field_idx].ty;
10462+
const uncoerced = try sema.addConstant(field_ty, union_val.val.toValue());
10463+
return sema.coerce(block, capture_ty, uncoerced, operand_src);
1037010464
}
10465+
1037110466
try sema.requireRuntimeBlock(block, operand_src, null);
10372-
return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);
10467+
10468+
if (same_types) {
10469+
return block.addStructFieldVal(spa.operand, first_field_index, capture_ty);
10470+
}
10471+
10472+
// We may have to emit a switch block which coerces the operand to the capture type.
10473+
// If we can, try to avoid that using in-memory coercions.
10474+
const first_non_imc = in_mem: {
10475+
for (field_tys, 0..) |field_ty, i| {
10476+
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
10477+
break :in_mem i;
10478+
}
10479+
}
10480+
// All fields are in-memory coercible to the resolved type!
10481+
// Just take the first field and bitcast the result.
10482+
const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);
10483+
return block.addBitCast(capture_ty, uncoerced);
10484+
};
10485+
10486+
// By-val capture with heterogeneous types which are not all in-memory coercible to
10487+
// the resolved capture type. We finally have to fall back to the ugly method.
10488+
10489+
// However, let's first track which operands are in-memory coercible. There may well
10490+
// be several, and we can squash all of these cases into the same switch prong using
10491+
// a simple bitcast. We'll make this the 'else' prong.
10492+
10493+
var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
10494+
in_mem_coercible.unset(first_non_imc);
10495+
{
10496+
const next = first_non_imc + 1;
10497+
for (field_tys[next..], next..) |field_ty, i| {
10498+
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
10499+
in_mem_coercible.unset(i);
10500+
}
10501+
}
10502+
}
10503+
10504+
const capture_block_inst = try block.addInstAsIndex(.{
10505+
.tag = .block,
10506+
.data = .{
10507+
.ty_pl = .{
10508+
.ty = try sema.addType(capture_ty),
10509+
.payload = undefined, // updated below
10510+
},
10511+
},
10512+
});
10513+
10514+
const prong_count = field_tys.len - in_mem_coercible.count();
10515+
10516+
const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
10517+
var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
10518+
defer cases_extra.deinit();
10519+
10520+
{
10521+
// Non-bitcast cases
10522+
var it = in_mem_coercible.iterator(.{ .kind = .unset });
10523+
while (it.next()) |idx| {
10524+
var coerce_block = block.makeSubBlock();
10525+
defer coerce_block.instructions.deinit(sema.gpa);
10526+
10527+
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]);
10528+
const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
10529+
error.NeededSourceLocation => {
10530+
const multi_idx = raw_capture_src.multi_capture;
10531+
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
10532+
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } };
10533+
const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
10534+
_ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src);
10535+
unreachable;
10536+
},
10537+
else => |e| return e,
10538+
};
10539+
_ = try coerce_block.addBr(capture_block_inst, coerced);
10540+
10541+
try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len);
10542+
cases_extra.appendAssumeCapacity(1); // items_len
10543+
cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len
10544+
cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item
10545+
cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body
10546+
}
10547+
}
10548+
const else_body_len = len: {
10549+
// 'else' prong uses a bitcast
10550+
var coerce_block = block.makeSubBlock();
10551+
defer coerce_block.instructions.deinit(sema.gpa);
10552+
10553+
const first_imc = in_mem_coercible.findFirstSet().?;
10554+
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]);
10555+
const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
10556+
_ = try coerce_block.addBr(capture_block_inst, coerced);
10557+
10558+
try cases_extra.appendSlice(coerce_block.instructions.items);
10559+
break :len coerce_block.instructions.items.len;
10560+
};
10561+
10562+
try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len +
10563+
cases_extra.items.len +
10564+
@typeInfo(Air.Block).Struct.fields.len +
10565+
1);
10566+
10567+
const switch_br_inst = @intCast(u32, sema.air_instructions.len);
10568+
try sema.air_instructions.append(sema.gpa, .{
10569+
.tag = .switch_br,
10570+
.data = .{ .pl_op = .{
10571+
.operand = spa.cond,
10572+
.payload = sema.addExtraAssumeCapacity(Air.SwitchBr{
10573+
.cases_len = @intCast(u32, prong_count),
10574+
.else_body_len = @intCast(u32, else_body_len),
10575+
}),
10576+
} },
10577+
});
10578+
sema.air_extra.appendSliceAssumeCapacity(cases_extra.items);
10579+
10580+
// Set up block body
10581+
sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{
10582+
.body_len = 1,
10583+
});
10584+
sema.air_extra.appendAssumeCapacity(switch_br_inst);
10585+
10586+
return Air.indexToRef(capture_block_inst);
1037310587
},
1037410588
.ErrorSet => {
1037510589
if (capture_byref) {
@@ -11099,6 +11313,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
1109911313
.parent_block = block,
1110011314
.operand = raw_operand.val,
1110111315
.operand_ptr = raw_operand.ptr,
11316+
.cond = operand,
1110211317
.else_error_ty = else_error_ty,
1110311318
.switch_block_inst = inst,
1110411319
.tag_capture_inst = tag_capture_inst,

0 commit comments

Comments
 (0)