Sema: do not assume switch item indices align with union field indices

Resolves: #17754
This commit is contained in:
mlugg 2023-10-28 01:22:30 +01:00 committed by Matthew Lugg
parent 5257643d3d
commit c1c9bc0c41
2 changed files with 47 additions and 16 deletions

View File

@ -10789,23 +10789,24 @@ const SwitchProngAnalysis = struct {
const first_field_index: u32 = mod.unionTagFieldIndex(union_obj, first_item_val).?;
const first_field_ty = union_obj.field_types.get(ip)[first_field_index].toType();
const field_tys = try sema.arena.alloc(Type, case_vals.len);
for (case_vals, field_tys) |item, *field_ty| {
const field_indices = try sema.arena.alloc(u32, case_vals.len);
for (case_vals, field_indices) |item, *field_idx| {
const item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable;
const field_idx = mod.unionTagFieldIndex(union_obj, item_val).?;
field_ty.* = union_obj.field_types.get(ip)[field_idx].toType();
field_idx.* = mod.unionTagFieldIndex(union_obj, item_val).?;
}
// Fast path: if all the operands are the same type already, we don't need to hit
// PTR! This will also allow us to emit simpler code.
const same_types = for (field_tys[1..]) |field_ty| {
if (!field_ty.eql(field_tys[0], sema.mod)) break false;
const same_types = for (field_indices[1..]) |field_idx| {
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
if (!field_ty.eql(first_field_ty, sema.mod)) break false;
} else true;
const capture_ty = if (same_types) field_tys[0] else capture_ty: {
const capture_ty = if (same_types) first_field_ty else capture_ty: {
// We need values to run PTR on, so make a bunch of undef constants.
const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
for (dummy_captures, field_tys) |*dummy, field_ty| {
for (dummy_captures, field_indices) |*dummy, field_idx| {
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
dummy.* = try mod.undefRef(field_ty);
}
@ -10852,7 +10853,8 @@ const SwitchProngAnalysis = struct {
// By-ref captures of hetereogeneous types are only allowed if each field
// pointer type is in-memory coercible to the capture pointer type.
if (!same_types) {
for (field_tys, 0..) |field_ty, i| {
for (field_indices, 0..) |field_idx, i| {
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
const field_ptr_ty = try sema.ptrType(.{
.child = field_ty.toIntern(),
.flags = .{
@ -10915,7 +10917,8 @@ const SwitchProngAnalysis = struct {
// We may have to emit a switch block which coerces the operand to the capture type.
// If we can, try to avoid that using in-memory coercions.
const first_non_imc = in_mem: {
for (field_tys, 0..) |field_ty, i| {
for (field_indices, 0..) |field_idx, i| {
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
break :in_mem i;
}
@ -10933,11 +10936,12 @@ const SwitchProngAnalysis = struct {
// be several, and we can squash all of these cases into the same switch prong using
// a simple bitcast. We'll make this the 'else' prong.
var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_indices.len);
in_mem_coercible.unset(first_non_imc);
{
const next = first_non_imc + 1;
for (field_tys[next..], next..) |field_ty, i| {
for (field_indices[next..], next..) |field_idx, i| {
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
in_mem_coercible.unset(i);
}
@ -10954,7 +10958,7 @@ const SwitchProngAnalysis = struct {
},
});
const prong_count = field_tys.len - in_mem_coercible.count();
const prong_count = field_indices.len - in_mem_coercible.count();
const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
@ -10967,7 +10971,9 @@ const SwitchProngAnalysis = struct {
var coerce_block = block.makeSubBlock();
defer coerce_block.instructions.deinit(sema.gpa);
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(idx), field_tys[idx]);
const field_idx = field_indices[idx];
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, field_idx, field_ty);
const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
error.NeededSourceLocation => {
const multi_idx = raw_capture_src.multi_capture;
@ -10993,8 +10999,10 @@ const SwitchProngAnalysis = struct {
var coerce_block = block.makeSubBlock();
defer coerce_block.instructions.deinit(sema.gpa);
const first_imc = in_mem_coercible.findFirstSet().?;
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(first_imc), field_tys[first_imc]);
const first_imc_item_idx = in_mem_coercible.findFirstSet().?;
const first_imc_field_idx = field_indices[first_imc_item_idx];
const first_imc_field_ty = union_obj.field_types.get(ip)[first_imc_field_idx].toType();
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, first_imc_field_idx, first_imc_field_ty);
const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
_ = try coerce_block.addBr(capture_block_inst, coerced);

View File

@ -800,3 +800,26 @@ test "nested break ignores switch conditions and breaks instead" {
// Originally reported at https://github.com/ziglang/zig/issues/10196
try expect(0x01 == try S.register_to_address("a0"));
}
test "peer type resolution on switch captures ignores unused payload bits" {
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
const Foo = union(enum) {
a: u32,
b: u64,
};
var val: Foo = undefined;
@memset(std.mem.asBytes(&val), 0xFF);
// This is runtime-known so the following store isn't comptime-known.
var rt: u32 = 123;
val = .{ .a = rt }; // will not necessarily zero remaning payload memory
// Fields intentionally backwards here
const x = switch (val) {
.b, .a => |x| x,
};
try expect(x == 123);
}