diff --git a/lib/std/builtin.zig b/lib/std/builtin.zig
index 92fa78bc39..68bbbe3b2d 100644
--- a/lib/std/builtin.zig
+++ b/lib/std/builtin.zig
@@ -98,6 +98,16 @@ pub const AtomicOrder = enum {
     SeqCst,
 };
 
+/// This data structure is used by the Zig language code generation and
+/// therefore must be kept in sync with the compiler implementation.
+pub const ReduceOp = enum {
+    And,
+    Or,
+    Xor,
+    Min,
+    Max,
+};
+
 /// This data structure is used by the Zig language code generation and
 /// therefore must be kept in sync with the compiler implementation.
 pub const AtomicRmwOp = enum {
diff --git a/src/stage1/all_types.hpp b/src/stage1/all_types.hpp
index a4d285bb1a..c9d7755942 100644
--- a/src/stage1/all_types.hpp
+++ b/src/stage1/all_types.hpp
@@ -1821,6 +1821,7 @@ enum BuiltinFnId {
     BuiltinFnIdWasmMemorySize,
     BuiltinFnIdWasmMemoryGrow,
     BuiltinFnIdSrc,
+    BuiltinFnIdReduce,
 };
 
 struct BuiltinFnEntry {
@@ -2436,6 +2437,15 @@ enum AtomicOrder {
     AtomicOrderSeqCst,
 };
 
+// synchronized with code in define_builtin_compile_vars
+enum ReduceOp {
+    ReduceOp_and,
+    ReduceOp_or,
+    ReduceOp_xor,
+    ReduceOp_min,
+    ReduceOp_max,
+};
+
 // synchronized with the code in define_builtin_compile_vars
 enum AtomicRmwOp {
     AtomicRmwOp_xchg,
@@ -2545,6 +2555,7 @@ enum IrInstSrcId {
     IrInstSrcIdEmbedFile,
     IrInstSrcIdCmpxchg,
     IrInstSrcIdFence,
+    IrInstSrcIdReduce,
     IrInstSrcIdTruncate,
     IrInstSrcIdIntCast,
     IrInstSrcIdFloatCast,
@@ -2667,6 +2678,7 @@ enum IrInstGenId {
     IrInstGenIdErrName,
     IrInstGenIdCmpxchg,
     IrInstGenIdFence,
+    IrInstGenIdReduce,
     IrInstGenIdTruncate,
     IrInstGenIdShuffleVector,
     IrInstGenIdSplat,
@@ -3516,6 +3528,20 @@ struct IrInstGenFence {
     AtomicOrder order;
 };
 
+struct IrInstSrcReduce {
+    IrInstSrc base;
+
+    IrInstSrc *op;
+    IrInstSrc *value;
+};
+
+struct IrInstGenReduce {
+    IrInstGen base;
+
+    ReduceOp op;
+    IrInstGen *value;
+};
+
 struct IrInstSrcTruncate {
     IrInstSrc base;
 
diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp
index 194e5e38fb..0920482488 100644
--- a/src/stage1/codegen.cpp
+++ b/src/stage1/codegen.cpp
@@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
     return nullptr;
 }
 
-enum class ScalarizePredicate {
-    // Returns true iff all the elements in the vector are 1.
-    // Equivalent to folding all the bits with `and`.
-    All,
-    // Returns true iff there's at least one element in the vector that is 1.
-    // Equivalent to folding all the bits with `or`.
-    Any,
-};
-
-// Collapses a vector into a single i1 according to the given predicate
-static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
-    assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
-    LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
-    LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
-
-    switch (predicate) {
-        case ScalarizePredicate::Any: {
-            LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
-            return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
-        }
-        case ScalarizePredicate::All: {
-            LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
-            return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
-        }
-    }
-
-    zig_unreachable();
-}
-
-
 static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
         LLVMValueRef val1, LLVMValueRef val2)
 {
@@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
     LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
     if (operand_type->id == ZigTypeIdVector) {
-        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
+        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
     }
     LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
     LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
     if (operand_type->id == ZigTypeIdVector) {
-        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
+        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
     }
     LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
         }
 
         if (operand_type->id == ZigTypeIdVector) {
-            is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
+            is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
         }
 
         LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
         LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
         LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
         if (operand_type->id == ZigTypeIdVector) {
-            overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
+            overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
         }
         LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
 
@@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
            LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
            LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
            if (operand_type->id == ZigTypeIdVector) {
-               ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
+               ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
            }
            LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
            LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
            LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
            if (operand_type->id == ZigTypeIdVector) {
-               ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
+               ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
            }
            LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
 
@@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
            LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
            LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
            if (operand_type->id == ZigTypeIdVector) {
-               ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
+               ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
            }
            LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
        }
 
        if (operand_type->id == ZigTypeIdVector) {
-           is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
+           is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
        }
 
        LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
        LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
        if (rhs_type->id == ZigTypeIdVector) {
-           less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
+           less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
        }
        LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
 
@@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
     return result_loc;
 }
 
+static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
+    LLVMValueRef value = ir_llvm_value(g, instruction->value);
+
+    ZigType *value_type = instruction->value->value->type;
+    assert(value_type->id == ZigTypeIdVector);
+    ZigType *scalar_type = value_type->data.vector.elem_type;
+
+    LLVMValueRef result_val;
+    switch (instruction->op) {
+        case ReduceOp_and:
+            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
+            result_val = ZigLLVMBuildAndReduce(g->builder, value);
+            break;
+        case ReduceOp_or:
+            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
+            result_val = ZigLLVMBuildOrReduce(g->builder, value);
+            break;
+        case ReduceOp_xor:
+            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
+            result_val = ZigLLVMBuildXorReduce(g->builder, value);
+            break;
+        case ReduceOp_min: {
+            if (scalar_type->id == ZigTypeIdInt) {
+                const bool is_signed = scalar_type->data.integral.is_signed;
+                result_val = ZigLLVMBuildIntMinReduce(g->builder, value, is_signed);
+            } else if (scalar_type->id == ZigTypeIdFloat) {
+                result_val = ZigLLVMBuildFPMinReduce(g->builder, value);
+            } else zig_unreachable();
+        } break;
+        case ReduceOp_max: {
+            if (scalar_type->id == ZigTypeIdInt) {
+                const bool is_signed = scalar_type->data.integral.is_signed;
+                result_val = ZigLLVMBuildIntMaxReduce(g->builder, value, is_signed);
+            } else if (scalar_type->id == ZigTypeIdFloat) {
+                result_val = ZigLLVMBuildFPMaxReduce(g->builder, value);
+            } else zig_unreachable();
+        } break;
+        default:
+            zig_unreachable();
+    }
+
+    return result_val;
+}
+
 static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
     LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
     LLVMBuildFence(g->builder, atomic_order, false, "");
@@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
             return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
         case IrInstGenIdFence:
             return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
+        case IrInstGenIdReduce:
+            return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
         case IrInstGenIdTruncate:
             return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
         case IrInstGenIdBoolNot:
@@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
     create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
     create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
+    create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
 }
 
 static const char *bool_to_str(bool b) {
diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp
index 045f1ad784..095ed301c9 100644
--- a/src/stage1/ir.cpp
+++ b/src/stage1/ir.cpp
@@ -402,6 +402,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCmpxchg *>(inst));
        case IrInstSrcIdFence:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcFence *>(inst));
+       case IrInstSrcIdReduce:
+           return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcReduce *>(inst));
        case IrInstSrcIdTruncate:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcTruncate *>(inst));
        case IrInstSrcIdIntCast:
@@ -636,6 +638,8 @@ void destroy_instruction_gen(IrInstGen *inst) {
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenCmpxchg *>(inst));
        case IrInstGenIdFence:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenFence *>(inst));
+       case IrInstGenIdReduce:
+           return heap::c_allocator.destroy(reinterpret_cast<IrInstGenReduce *>(inst));
        case IrInstGenIdTruncate:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenTruncate *>(inst));
        case IrInstGenIdShuffleVector:
@@ -1311,6 +1315,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcFence *) {
     return IrInstSrcIdFence;
 }
 
+static constexpr IrInstSrcId ir_inst_id(IrInstSrcReduce *) {
+    return IrInstSrcIdReduce;
+}
+
 static constexpr IrInstSrcId ir_inst_id(IrInstSrcTruncate *) {
     return IrInstSrcIdTruncate;
 }
@@ -1775,6 +1783,10 @@ static constexpr IrInstGenId ir_inst_id(IrInstGenFence *) {
     return IrInstGenIdFence;
 }
 
+static constexpr IrInstGenId ir_inst_id(IrInstGenReduce *) {
+    return IrInstGenIdReduce;
+}
+
 static constexpr IrInstGenId ir_inst_id(IrInstGenTruncate *) {
     return IrInstGenIdTruncate;
 }
@@ -3502,6 +3514,29 @@ static IrInstGen *ir_build_fence_gen(IrAnalyze *ira, IrInst *source_instr, Atomi
     return &instruction->base;
 }
 
+static IrInstSrc *ir_build_reduce(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, IrInstSrc *op, IrInstSrc *value) {
+    IrInstSrcReduce *instruction = ir_build_instruction<IrInstSrcReduce>(irb, scope, source_node);
+    instruction->op = op;
+    instruction->value = value;
+
+    ir_ref_instruction(op, irb->current_basic_block);
+    ir_ref_instruction(value, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
+static IrInstGen *ir_build_reduce_gen(IrAnalyze *ira, IrInst *source_instruction, ReduceOp op, IrInstGen *value, ZigType *result_type) {
+    IrInstGenReduce *instruction = ir_build_inst_gen<IrInstGenReduce>(&ira->new_irb,
+            source_instruction->scope, source_instruction->source_node);
+    instruction->base.value->type = result_type;
+    instruction->op = op;
+    instruction->value = value;
+
+    ir_ref_inst_gen(value);
+
+    return &instruction->base;
+}
+
 static IrInstSrc *ir_build_truncate(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
         IrInstSrc *dest_type, IrInstSrc *target)
 {
@@ -6580,6 +6615,21 @@ static IrInstSrc *ir_gen_builtin_fn_call(IrBuilderSrc *irb, Scope *scope, AstNod
                 IrInstSrc *fence = ir_build_fence(irb, scope, node, arg0_value);
                 return ir_lval_wrap(irb, scope, fence, lval, result_loc);
             }
+        case BuiltinFnIdReduce:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstSrc *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_inst_src)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstSrc *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_inst_src)
+                    return arg1_value;
+
+                IrInstSrc *reduce = ir_build_reduce(irb, scope, node, arg0_value, arg1_value);
+                return ir_lval_wrap(irb, scope, reduce, lval, result_loc);
+            }
         case BuiltinFnIdDivExact:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -15932,6 +15982,24 @@ static bool ir_resolve_comptime(IrAnalyze *ira, IrInstGen *value, bool *out) {
     return ir_resolve_bool(ira, value, out);
 }
 
+static bool ir_resolve_reduce_op(IrAnalyze *ira, IrInstGen *value, ReduceOp *out) {
+    if (type_is_invalid(value->value->type))
+        return false;
+
+    ZigType *reduce_op_type = get_builtin_type(ira->codegen, "ReduceOp");
+
+    IrInstGen *casted_value = ir_implicit_cast(ira, value, reduce_op_type);
+    if (type_is_invalid(casted_value->value->type))
+        return false;
+
+    ZigValue *const_val = ir_resolve_const(ira, casted_value, UndefBad);
+    if (!const_val)
+        return false;
+
+    *out = (ReduceOp)bigint_as_u32(&const_val->data.x_enum_tag);
+    return true;
+}
+
 static bool ir_resolve_atomic_order(IrAnalyze *ira, IrInstGen *value, AtomicOrder *out) {
     if (type_is_invalid(value->value->type))
         return false;
@@ -26802,6 +26870,161 @@ static IrInstGen *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstSrcCmpxch
            success_order, failure_order, instruction->is_weak, result_loc);
 }
 
+static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp op, ZigValue *value, ZigValue *out_value) {
+    assert(value->type->id == ZigTypeIdVector);
+    ZigType *scalar_type = value->type->data.vector.elem_type;
+    const size_t len = value->type->data.vector.len;
+    assert(len > 0);
+
+    out_value->type = scalar_type;
+    out_value->special = ConstValSpecialStatic;
+
+    if (scalar_type->id == ZigTypeIdBool) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        bool result = first_elem_val->data.x_bool;
+        for (size_t i = 1; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            switch (op) {
+                case ReduceOp_and:
+                    result = result && elem_val->data.x_bool;
+                    if (!result) break; // Short circuit
+                    break;
+                case ReduceOp_or:
+                    result = result || elem_val->data.x_bool;
+                    if (result) break; // Short circuit
+                    break;
+                case ReduceOp_xor:
+                    result = result != elem_val->data.x_bool;
+                    break;
+                default:
+                    zig_unreachable();
+            }
+        }
+
+        out_value->data.x_bool = result;
+        return nullptr;
+    }
+
+    if (op != ReduceOp_min && op != ReduceOp_max) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        copy_const_val(ira->codegen, out_value, first_elem_val);
+
+        for (size_t i = 1; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            IrBinOp bin_op;
+            switch (op) {
+                case ReduceOp_and: bin_op = IrBinOpBinAnd; break;
+                case ReduceOp_or: bin_op = IrBinOpBinOr; break;
+                case ReduceOp_xor: bin_op = IrBinOpBinXor; break;
+                default: zig_unreachable();
+            }
+
+            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
+                    out_value, bin_op, elem_val, out_value);
+            if (msg != nullptr)
+                return msg;
+        }
+
+        return nullptr;
+    }
+
+    ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+    ZigValue *dummy_cmp_value = ira->codegen->pass1_arena->create<ZigValue>();
+    for (size_t i = 1; i < len; i++) {
+        ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+        IrBinOp bin_op;
+        switch (op) {
+            case ReduceOp_min: bin_op = IrBinOpCmpLessThan; break;
+            case ReduceOp_max: bin_op = IrBinOpCmpGreaterThan; break;
+            default: zig_unreachable();
+        }
+
+        ErrorMsg *msg = ir_eval_bin_op_cmp_scalar(ira, source_instr,
+                elem_val, bin_op, candidate_elem_val, dummy_cmp_value);
+        if (msg != nullptr)
+            return msg;
+
+        if (dummy_cmp_value->data.x_bool)
+            candidate_elem_val = elem_val;
+    }
+
+    ira->codegen->pass1_arena->destroy(dummy_cmp_value);
+    copy_const_val(ira->codegen, out_value, candidate_elem_val);
+
+    return nullptr;
+}
+
+static IrInstGen *ir_analyze_instruction_reduce(IrAnalyze *ira, IrInstSrcReduce *instruction) {
+    IrInstGen *op_inst = instruction->op->child;
+    if (type_is_invalid(op_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    IrInstGen *value_inst = instruction->value->child;
+    if (type_is_invalid(value_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *value_type = value_inst->value->type;
+    if (value_type->id != ZigTypeIdVector) {
+        ir_add_error(ira, &value_inst->base,
+                buf_sprintf("expected vector type, found '%s'",
+                    buf_ptr(&value_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    ReduceOp op;
+    if (!ir_resolve_reduce_op(ira, op_inst, &op))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *elem_type = value_type->data.vector.elem_type;
+    switch (elem_type->id) {
+        case ZigTypeIdInt:
+            break;
+        case ZigTypeIdBool:
+            if (op > ReduceOp_xor) {
+                ir_add_error(ira, &op_inst->base,
+                        buf_sprintf("invalid operation for '%s' type",
+                            buf_ptr(&elem_type->name)));
+                return ira->codegen->invalid_inst_gen;
+            } break;
+        case ZigTypeIdFloat:
+            if (op < ReduceOp_min) {
+                ir_add_error(ira, &op_inst->base,
+                        buf_sprintf("invalid operation for '%s' type",
+                            buf_ptr(&elem_type->name)));
+                return ira->codegen->invalid_inst_gen;
+            } break;
+        default:
+            // Vectors cannot have child types other than those listed above
+            zig_unreachable();
+    }
+
+    // special case zero bit types
+    switch (type_has_one_possible_value(ira->codegen, elem_type)) {
+        case OnePossibleValueInvalid:
+            return ira->codegen->invalid_inst_gen;
+        case OnePossibleValueYes:
+            return ir_const_move(ira, &instruction->base.base,
+                    get_the_one_possible_value(ira->codegen, elem_type));
+        case OnePossibleValueNo:
+            break;
+    }
+
+    if (instr_is_comptime(value_inst)) {
+        IrInstGen *result = ir_const(ira, &instruction->base.base, elem_type);
+        if (ir_eval_reduce(ira, &instruction->base.base, op, value_inst->value, result->value))
+            return ira->codegen->invalid_inst_gen;
+        return result;
+    }
+
+    return ir_build_reduce_gen(ira, &instruction->base.base, op, value_inst, elem_type);
+}
+
 static IrInstGen *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstSrcFence *instruction) {
     IrInstGen *order_inst = instruction->order->child;
     if (type_is_invalid(order_inst->value->type))
@@ -31550,6 +31773,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
            return ir_analyze_instruction_cmpxchg(ira, (IrInstSrcCmpxchg *)instruction);
        case IrInstSrcIdFence:
            return ir_analyze_instruction_fence(ira, (IrInstSrcFence *)instruction);
+       case IrInstSrcIdReduce:
+           return ir_analyze_instruction_reduce(ira, (IrInstSrcReduce *)instruction);
        case IrInstSrcIdTruncate:
            return ir_analyze_instruction_truncate(ira, (IrInstSrcTruncate *)instruction);
        case IrInstSrcIdIntCast:
@@ -31937,6 +32162,7 @@ bool ir_inst_gen_has_side_effects(IrInstGen *instruction) {
        case IrInstGenIdNegation:
        case IrInstGenIdNegationWrapping:
        case IrInstGenIdWasmMemorySize:
+       case IrInstGenIdReduce:
            return false;
 
        case IrInstGenIdAsm:
@@ -32106,6 +32332,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
        case IrInstSrcIdSpillEnd:
        case IrInstSrcIdWasmMemorySize:
        case IrInstSrcIdSrc:
+       case IrInstSrcIdReduce:
            return false;
 
        case IrInstSrcIdAsm:
diff --git a/src/stage1/ir_print.cpp b/src/stage1/ir_print.cpp
index 18c2ca99f7..7d7fd4c9ea 100644
--- a/src/stage1/ir_print.cpp
+++ b/src/stage1/ir_print.cpp
@@ -200,6 +200,8 @@ const char* ir_inst_src_type_str(IrInstSrcId id) {
            return "SrcCmpxchg";
        case IrInstSrcIdFence:
            return "SrcFence";
+       case IrInstSrcIdReduce:
+           return "SrcReduce";
        case IrInstSrcIdTruncate:
            return "SrcTruncate";
        case IrInstSrcIdIntCast:
@@ -436,6 +438,8 @@ const char* ir_inst_gen_type_str(IrInstGenId id) {
            return "GenCmpxchg";
        case IrInstGenIdFence:
            return "GenFence";
+       case IrInstGenIdReduce:
+           return "GenReduce";
        case IrInstGenIdTruncate:
            return "GenTruncate";
        case IrInstGenIdBoolNot:
@@ -1584,6 +1588,14 @@ static void ir_print_fence(IrPrintSrc *irp, IrInstSrcFence *instruction) {
     fprintf(irp->f, ")");
 }
 
+static void ir_print_reduce(IrPrintSrc *irp, IrInstSrcReduce *instruction) {
+    fprintf(irp->f, "@reduce(");
+    ir_print_other_inst_src(irp, instruction->op);
+    fprintf(irp->f, ", ");
+    ir_print_other_inst_src(irp, instruction->value);
+    fprintf(irp->f, ")");
+}
+
 static const char *atomic_order_str(AtomicOrder order) {
     switch (order) {
         case AtomicOrderUnordered: return "Unordered";
@@ -1600,6 +1612,23 @@ static void ir_print_fence(IrPrintGen *irp, IrInstGenFence *instruction) {
     fprintf(irp->f, "fence %s", atomic_order_str(instruction->order));
 }
 
+static const char *reduce_op_str(ReduceOp op) {
+    switch (op) {
+        case ReduceOp_and: return "And";
+        case ReduceOp_or: return "Or";
+        case ReduceOp_xor: return "Xor";
+        case ReduceOp_min: return "Min";
+        case ReduceOp_max: return "Max";
+    }
+    zig_unreachable();
+}
+
+static void ir_print_reduce(IrPrintGen *irp, IrInstGenReduce *instruction) {
+    fprintf(irp->f, "@reduce(.%s, ", reduce_op_str(instruction->op));
+    ir_print_other_inst_gen(irp, instruction->value);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_truncate(IrPrintSrc *irp, IrInstSrcTruncate *instruction) {
     fprintf(irp->f, "@truncate(");
     ir_print_other_inst_src(irp, instruction->dest_type);
@@ -2749,6 +2778,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, IrInstSrc *instruction, bool trai
        case IrInstSrcIdFence:
            ir_print_fence(irp, (IrInstSrcFence *)instruction);
            break;
+       case IrInstSrcIdReduce:
+           ir_print_reduce(irp, (IrInstSrcReduce *)instruction);
+           break;
        case IrInstSrcIdTruncate:
            ir_print_truncate(irp, (IrInstSrcTruncate *)instruction);
            break;
@@ -3097,6 +3129,9 @@ static void ir_print_inst_gen(IrPrintGen *irp, IrInstGen *instruction, bool trai
        case IrInstGenIdFence:
            ir_print_fence(irp, (IrInstGenFence *)instruction);
            break;
+       case IrInstGenIdReduce:
+           ir_print_reduce(irp, (IrInstGenReduce *)instruction);
+           break;
        case IrInstGenIdTruncate:
            ir_print_truncate(irp, (IrInstGenTruncate *)instruction);
            break;
diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp
index 08823050ad..78082d16ba 100644
--- a/src/zig_llvm.cpp
+++ b/src/zig_llvm.cpp
@@ -1123,6 +1123,34 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
            singleThread ? SyncScope::SingleThread : SyncScope::System));
 }
 
+LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateAndReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateOrReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateXorReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
+    return wrap(unwrap(B)->CreateIntMaxReduce(unwrap(Val), is_signed));
+}
+
+LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
+    return wrap(unwrap(B)->CreateIntMinReduce(unwrap(Val), is_signed));
+}
+
+LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateFPMaxReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateFPMinReduce(unwrap(Val)));
+}
+
 static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, "");
 static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, "");
 static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, "");
diff --git a/src/zig_llvm.h b/src/zig_llvm.h
index 007d8afc1f..966f142e03 100644
--- a/src/zig_llvm.h
+++ b/src/zig_llvm.h
@@ -455,6 +455,14 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
     LLVMValueRef PTR, LLVMValueRef Val,
     LLVMAtomicOrdering ordering, LLVMBool singleThread);
 
+LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
+LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
+LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val);
+
 #define ZigLLVM_DIFlags_Zero 0U
 #define ZigLLVM_DIFlags_Private 1U
 #define ZigLLVM_DIFlags_Protected 2U
diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig
index dc9e49da43..aeb98f28fd 100644
--- a/test/stage1/behavior/vector.zig
+++ b/test/stage1/behavior/vector.zig
@@ -484,3 +484,43 @@ test "vector shift operators" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "vector reduce operation" {
+    const S = struct {
+        fn doTheTestReduce(comptime op: builtin.ReduceOp, x: anytype, expected: anytype) void {
+            const N = @typeInfo(@TypeOf(x)).Array.len;
+            const TX = @typeInfo(@TypeOf(x)).Array.child;
+
+            var r = @reduce(op, @as(Vector(N, TX), x));
+            expectEqual(expected, r);
+        }
+        fn doTheTest() void {
+            doTheTestReduce(.And, [4]bool{ true, false, true, true }, @as(bool, false));
+            doTheTestReduce(.Or, [4]bool{ false, true, false, false }, @as(bool, true));
+            doTheTestReduce(.Xor, [4]bool{ true, true, true, false }, @as(bool, true));
+
+            doTheTestReduce(.And, [4]u1{ 1, 0, 1, 1 }, @as(u1, 0));
+            doTheTestReduce(.Or, [4]u1{ 0, 1, 0, 0 }, @as(u1, 1));
+            doTheTestReduce(.Xor, [4]u1{ 1, 1, 1, 0 }, @as(u1, 1));
+
+            doTheTestReduce(.And, [4]u32{ 0xffffffff, 0xffff5555, 0xaaaaffff, 0x10101010 }, @as(u32, 0x1010));
+            doTheTestReduce(.Or, [4]u32{ 0xffff0000, 0xff00, 0xf0, 0xf }, ~@as(u32, 0));
+            doTheTestReduce(.Xor, [4]u32{ 0x00000000, 0x33333333, 0x88888888, 0x44444444 }, ~@as(u32, 0));
+
+            doTheTestReduce(.Min, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, -386));
+            doTheTestReduce(.Max, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, 1234567));
+
+            doTheTestReduce(.Min, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 9));
+            doTheTestReduce(.Max, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 99999));
+
+            doTheTestReduce(.Min, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, -100.0));
+            doTheTestReduce(.Max, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, 10.0e9));
+
+            doTheTestReduce(.Min, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, -100.0));
+            doTheTestReduce(.Max, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, 10.0e9));
+        }
+    };
+
+    S.doTheTest();
+    comptime S.doTheTest();
+}
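
The builtin added by this diff is `@reduce(op, vector)`: it takes a comptime `std.builtin.ReduceOp` and a vector operand and returns a scalar of the vector's element type. Per `ir_analyze_instruction_reduce`, `.And`/`.Or`/`.Xor` are accepted only for integer and bool vectors, and `.Min`/`.Max` only for integer and float vectors. A minimal usage sketch, outside the diff itself; it assumes `std.meta.Vector` for the vector type and the void-returning `std.testing.expectEqual` used by the behavior test above:

    const std = @import("std");
    const expectEqual = std.testing.expectEqual;

    test "reduce usage sketch" {
        // Horizontal min/max over a 4-element integer vector.
        const v: std.meta.Vector(4, i32) = [4]i32{ 1234567, -386, 0, 3 };
        expectEqual(@as(i32, -386), @reduce(.Min, v));
        expectEqual(@as(i32, 1234567), @reduce(.Max, v));

        // Bitwise folds are valid for integer and bool element types.
        const b: std.meta.Vector(4, bool) = [4]bool{ true, false, true, true };
        expectEqual(@as(bool, false), @reduce(.And, b));
        expectEqual(@as(bool, true), @reduce(.Or, b));
    }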