diff --git a/doc/langref.html.in b/doc/langref.html.in index 818e0b5fe4..b006544f00 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -3016,6 +3016,7 @@ test "switch on tagged union" { A: u32, C: Point, D, + E: u32, }; var a = Item{ .C = Point{ .x = 1, .y = 2 } }; @@ -3023,8 +3024,9 @@ test "switch on tagged union" { // Switching on more complex enums is allowed. const b = switch (a) { // A capture group is allowed on a match, and will return the enum - // value matched. - Item.A => |item| item, + // value matched. If the payloads of both cases are the same + // they can be put into the same switch prong. + Item.A, Item.E => |item| item, // A reference to the matched value can be obtained using `*` syntax. Item.C => |*item| blk: { diff --git a/src/ir.cpp b/src/ir.cpp index 6b19ce2909..995c993ab4 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -19230,10 +19230,6 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru assert(enum_type != nullptr); assert(enum_type->id == ZigTypeIdEnum); - if (instruction->prongs_len != 1) { - return target_value_ptr; - } - IrInstruction *prong_value = instruction->prongs_ptr[0]->child; if (type_is_invalid(prong_value->value.type)) return ira->codegen->invalid_instruction; @@ -19248,6 +19244,40 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru TypeUnionField *field = find_union_field_by_tag(target_type, &prong_val->data.x_enum_tag); + if (instruction->prongs_len != 1) { + ErrorMsg *invalid_payload = nullptr; + Buf *invalid_payload_list = nullptr; + + for (size_t i = 1; i < instruction->prongs_len; i++) { + IrInstruction *casted_prong_value = ir_implicit_cast(ira, instruction->prongs_ptr[i]->child, enum_type); + if (type_is_invalid(casted_prong_value->value.type)) + return ira->codegen->invalid_instruction; + + ConstExprValue *next_prong = ir_resolve_const(ira, casted_prong_value, UndefBad); + if (!next_prong) + return ira->codegen->invalid_instruction; + + ZigType *payload = find_union_field_by_tag(target_type, &next_prong->data.x_enum_tag)->type_entry; + + if (field->type_entry != payload) { + if (!invalid_payload) { + invalid_payload = ir_add_error(ira, &instruction->base, + buf_sprintf("switch prong contains cases with different payloads")); + invalid_payload_list = buf_sprintf("payload types are %s", buf_ptr(&field->type_entry->name)); + } + + if (i == instruction->prongs_len - 1) + buf_append_buf(invalid_payload_list, buf_sprintf(" and %s", buf_ptr(&payload->name))); + else + buf_append_buf(invalid_payload_list, buf_sprintf(", %s", buf_ptr(&payload->name))); + } + } + + if (invalid_payload) + add_error_note(ira->codegen, invalid_payload, + ((IrInstruction*)instruction)->source_node, invalid_payload_list); + } + if (instr_is_comptime(target_value_ptr)) { ConstExprValue *target_val_ptr = ir_resolve_const(ira, target_value_ptr, UndefBad); if (!target_value_ptr) diff --git a/test/compile_errors.zig b/test/compile_errors.zig index c411ba46f6..c6852621e3 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -6073,4 +6073,21 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { "tmp.zig:5:30: error: expression value is ignored", "tmp.zig:9:30: error: expression value is ignored", ); + + cases.add( + "capture group on switch prong with different payloads", + \\const Union = union(enum) { + \\ A: usize, + \\ B: isize, + \\}; + \\comptime { + \\ var u = Union{ .A = 8 }; + \\ switch (u) { + \\ .A, .B => |e| unreachable, + \\ } + \\} + , + "tmp.zig:8:20: error: switch prong contains cases with different payloads", + "tmp.zig:8:20: note: payload types are usize and isize", + ); } diff --git a/test/stage1/behavior/switch.zig b/test/stage1/behavior/switch.zig index 12e026d0ba..2b7422fa6d 100644 --- a/test/stage1/behavior/switch.zig +++ b/test/stage1/behavior/switch.zig @@ -391,3 +391,21 @@ test "switch with null and T peer types and inferred result location type" { S.doTheTest(1); comptime S.doTheTest(1); } + +test "switch prongs with cases with identical payloads" { + const Union = union(enum) { + A: usize, + B: isize, + C: usize, + }; + const S = struct { + fn doTheTest(u: Union) void { + switch (u) { + .A, .C => |e| expect(@typeOf(e) == usize), + .B => |e| expect(@typeOf(e) == isize), + } + } + }; + S.doTheTest(Union{ .A = 8 }); + comptime S.doTheTest(Union{ .B = -8 }); +}