diff --git a/doc/langref.html.in b/doc/langref.html.in
index 818e0b5fe4..b006544f00 100644
--- a/doc/langref.html.in
+++ b/doc/langref.html.in
@@ -3016,6 +3016,7 @@ test "switch on tagged union" {
A: u32,
C: Point,
D,
+ E: u32,
};
var a = Item{ .C = Point{ .x = 1, .y = 2 } };
@@ -3023,8 +3024,9 @@ test "switch on tagged union" {
// Switching on more complex enums is allowed.
const b = switch (a) {
// A capture group is allowed on a match, and will return the enum
- // value matched.
- Item.A => |item| item,
+ // value matched. If the payloads of both cases are the same
+ // they can be put into the same switch prong.
+ Item.A, Item.E => |item| item,
// A reference to the matched value can be obtained using `*` syntax.
Item.C => |*item| blk: {
diff --git a/src/ir.cpp b/src/ir.cpp
index 6b19ce2909..995c993ab4 100644
--- a/src/ir.cpp
+++ b/src/ir.cpp
@@ -19230,10 +19230,6 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
assert(enum_type != nullptr);
assert(enum_type->id == ZigTypeIdEnum);
- if (instruction->prongs_len != 1) {
- return target_value_ptr;
- }
-
IrInstruction *prong_value = instruction->prongs_ptr[0]->child;
if (type_is_invalid(prong_value->value.type))
return ira->codegen->invalid_instruction;
@@ -19248,6 +19244,40 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
TypeUnionField *field = find_union_field_by_tag(target_type, &prong_val->data.x_enum_tag);
+ if (instruction->prongs_len != 1) {
+ ErrorMsg *invalid_payload = nullptr;
+ Buf *invalid_payload_list = nullptr;
+
+ for (size_t i = 1; i < instruction->prongs_len; i++) {
+ IrInstruction *casted_prong_value = ir_implicit_cast(ira, instruction->prongs_ptr[i]->child, enum_type);
+ if (type_is_invalid(casted_prong_value->value.type))
+ return ira->codegen->invalid_instruction;
+
+ ConstExprValue *next_prong = ir_resolve_const(ira, casted_prong_value, UndefBad);
+ if (!next_prong)
+ return ira->codegen->invalid_instruction;
+
+ ZigType *payload = find_union_field_by_tag(target_type, &next_prong->data.x_enum_tag)->type_entry;
+
+ if (field->type_entry != payload) {
+ if (!invalid_payload) {
+ invalid_payload = ir_add_error(ira, &instruction->base,
+ buf_sprintf("switch prong contains cases with different payloads"));
+ invalid_payload_list = buf_sprintf("payload types are %s", buf_ptr(&field->type_entry->name));
+ }
+
+ if (i == instruction->prongs_len - 1)
+ buf_append_buf(invalid_payload_list, buf_sprintf(" and %s", buf_ptr(&payload->name)));
+ else
+ buf_append_buf(invalid_payload_list, buf_sprintf(", %s", buf_ptr(&payload->name)));
+ }
+ }
+
+ if (invalid_payload)
+ add_error_note(ira->codegen, invalid_payload,
+ ((IrInstruction*)instruction)->source_node, invalid_payload_list);
+ }
+
if (instr_is_comptime(target_value_ptr)) {
ConstExprValue *target_val_ptr = ir_resolve_const(ira, target_value_ptr, UndefBad);
if (!target_value_ptr)
diff --git a/test/compile_errors.zig b/test/compile_errors.zig
index c411ba46f6..c6852621e3 100644
--- a/test/compile_errors.zig
+++ b/test/compile_errors.zig
@@ -6073,4 +6073,21 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
"tmp.zig:5:30: error: expression value is ignored",
"tmp.zig:9:30: error: expression value is ignored",
);
+
+ cases.add(
+ "capture group on switch prong with different payloads",
+ \\const Union = union(enum) {
+ \\ A: usize,
+ \\ B: isize,
+ \\};
+ \\comptime {
+ \\ var u = Union{ .A = 8 };
+ \\ switch (u) {
+ \\ .A, .B => |e| unreachable,
+ \\ }
+ \\}
+ ,
+ "tmp.zig:8:20: error: switch prong contains cases with different payloads",
+ "tmp.zig:8:20: note: payload types are usize and isize",
+ );
}
diff --git a/test/stage1/behavior/switch.zig b/test/stage1/behavior/switch.zig
index 12e026d0ba..2b7422fa6d 100644
--- a/test/stage1/behavior/switch.zig
+++ b/test/stage1/behavior/switch.zig
@@ -391,3 +391,21 @@ test "switch with null and T peer types and inferred result location type" {
S.doTheTest(1);
comptime S.doTheTest(1);
}
+
+test "switch prongs with cases with identical payloads" {
+ const Union = union(enum) {
+ A: usize,
+ B: isize,
+ C: usize,
+ };
+ const S = struct {
+ fn doTheTest(u: Union) void {
+ switch (u) {
+ .A, .C => |e| expect(@typeOf(e) == usize),
+ .B => |e| expect(@typeOf(e) == isize),
+ }
+ }
+ };
+ S.doTheTest(Union{ .A = 8 });
+ comptime S.doTheTest(Union{ .B = -8 });
+}