stage1: Implement @reduce builtin for vector types

The builtin folds a Vector(N,T) into a scalar T using a specified
operator.

Closes #2698
This commit is contained in:
LemonBoy 2020-10-04 18:23:52 +02:00 committed by Andrew Kelley
parent 7c5a24e08c
commit 22b5e47839
8 changed files with 430 additions and 39 deletions

View File

@ -98,6 +98,16 @@ pub const AtomicOrder = enum {
SeqCst,
};
/// Operator used by the `@reduce` builtin to fold a vector into a scalar.
/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const ReduceOp = enum {
/// AND of all elements; accepted for integer and bool vectors.
And,
/// OR of all elements; accepted for integer and bool vectors.
Or,
/// XOR of all elements; accepted for integer and bool vectors.
Xor,
/// Smallest element; accepted for integer and float vectors.
Min,
/// Largest element; accepted for integer and float vectors.
Max,
};
/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const AtomicRmwOp = enum {

View File

@ -1821,6 +1821,7 @@ enum BuiltinFnId {
BuiltinFnIdWasmMemorySize,
BuiltinFnIdWasmMemoryGrow,
BuiltinFnIdSrc,
BuiltinFnIdReduce,
};
struct BuiltinFnEntry {
@ -2436,6 +2437,15 @@ enum AtomicOrder {
AtomicOrderSeqCst,
};
// Reduction operator of the @reduce builtin.
// synchronized with code in define_builtin_compile_vars
// The enumerator order is significant:
//  - ir_resolve_reduce_op casts the tag of the Zig `ReduceOp` enum value
//    (And, Or, Xor, Min, Max) directly to this type, so both must list the
//    operators in the same order.
//  - ir_analyze_instruction_reduce uses `op > ReduceOp_xor` and
//    `op < ReduceOp_min` to tell the bitwise operators from min/max, so the
//    bitwise operators must stay in front of min/max.
enum ReduceOp {
ReduceOp_and,
ReduceOp_or,
ReduceOp_xor,
ReduceOp_min,
ReduceOp_max,
};
// synchronized with the code in define_builtin_compile_vars
enum AtomicRmwOp {
AtomicRmwOp_xchg,
@ -2545,6 +2555,7 @@ enum IrInstSrcId {
IrInstSrcIdEmbedFile,
IrInstSrcIdCmpxchg,
IrInstSrcIdFence,
IrInstSrcIdReduce,
IrInstSrcIdTruncate,
IrInstSrcIdIntCast,
IrInstSrcIdFloatCast,
@ -2667,6 +2678,7 @@ enum IrInstGenId {
IrInstGenIdErrName,
IrInstGenIdCmpxchg,
IrInstGenIdFence,
IrInstGenIdReduce,
IrInstGenIdTruncate,
IrInstGenIdShuffleVector,
IrInstGenIdSplat,
@ -3516,6 +3528,20 @@ struct IrInstGenFence {
AtomicOrder order;
};
// Source-level (not yet analyzed) IR instruction for a `@reduce(op, value)` call.
struct IrInstSrcReduce {
IrInstSrc base;
// First builtin argument: the reduction operator, still an unresolved instruction.
IrInstSrc *op;
// Second builtin argument: the operand to fold; analysis requires it to be a vector.
IrInstSrc *value;
};
// Analyzed IR instruction for a `@reduce` that must be evaluated at runtime;
// ir_render_reduce lowers it to an LLVM vector reduction.
struct IrInstGenReduce {
IrInstGen base;
// Reduction operator, already resolved to a constant during analysis.
ReduceOp op;
// Vector operand to fold into a scalar.
IrInstGen *value;
};
struct IrInstSrcTruncate {
IrInstSrc base;

View File

@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}
// Selects how scalarize_cmp_result collapses a vector of i1 into a single i1.
enum class ScalarizePredicate {
// Returns true iff all the elements in the vector are 1.
// Equivalent to folding all the bits with `and`.
All,
// Returns true iff there's at least one element in the vector that is 1.
// Equivalent to folding all the bits with `or`.
Any,
};
// Folds a <N x i1> vector into one i1: the vector is reinterpreted as a single
// N-bit integer and compared against all-zeros (Any) or all-ones (All).
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
    LLVMTypeRef vector_type = LLVMTypeOf(val);
    assert(LLVMGetTypeKind(vector_type) == LLVMVectorTypeKind);

    // One bit per lane.
    LLVMTypeRef bits_type = LLVMIntType(LLVMGetVectorSize(vector_type));
    LLVMValueRef bits = LLVMBuildBitCast(g->builder, val, bits_type, "");

    if (predicate == ScalarizePredicate::Any) {
        // At least one lane set <=> the integer is non-zero.
        return LLVMBuildICmp(g->builder, LLVMIntNE, bits, LLVMConstNull(bits_type), "");
    }
    if (predicate == ScalarizePredicate::All) {
        // Every lane set <=> the integer has all bits set.
        return LLVMBuildICmp(g->builder, LLVMIntEQ, bits, LLVMConstAllOnes(bits_type), "");
    }
    zig_unreachable();
}
static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMValueRef val1, LLVMValueRef val2)
{
@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}
LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}
LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
if (rhs_type->id == ZigTypeIdVector) {
less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
}
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
return result_loc;
}
// Lowers a runtime @reduce instruction to the LLVM vector-reduction helper that
// matches the operator and the element type of the operand vector.
static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
    LLVMValueRef operand = ir_llvm_value(g, instruction->value);

    ZigType *vector_type = instruction->value->value->type;
    assert(vector_type->id == ZigTypeIdVector);
    ZigType *elem_type = vector_type->data.vector.elem_type;

    switch (instruction->op) {
        // The bitwise operators are only generated for integer and bool vectors.
        case ReduceOp_and:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildAndReduce(g->builder, operand);
        case ReduceOp_or:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildOrReduce(g->builder, operand);
        case ReduceOp_xor:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildXorReduce(g->builder, operand);

        // Min/max pick the integer or floating point flavour by element type.
        case ReduceOp_min:
            if (elem_type->id == ZigTypeIdInt) {
                return ZigLLVMBuildIntMinReduce(g->builder, operand,
                        elem_type->data.integral.is_signed);
            } else if (elem_type->id == ZigTypeIdFloat) {
                return ZigLLVMBuildFPMinReduce(g->builder, operand);
            } else {
                zig_unreachable();
            }
        case ReduceOp_max:
            if (elem_type->id == ZigTypeIdInt) {
                return ZigLLVMBuildIntMaxReduce(g->builder, operand,
                        elem_type->data.integral.is_signed);
            } else if (elem_type->id == ZigTypeIdFloat) {
                return ZigLLVMBuildFPMaxReduce(g->builder, operand);
            } else {
                zig_unreachable();
            }
    }
    // Not a valid ReduceOp value.
    zig_unreachable();
}
static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
LLVMBuildFence(g->builder, atomic_order, false, "");
@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
case IrInstGenIdFence:
return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
case IrInstGenIdReduce:
return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
case IrInstGenIdTruncate:
return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
case IrInstGenIdBoolNot:
@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
}
static const char *bool_to_str(bool b) {

View File

@ -402,6 +402,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCmpxchg *>(inst));
case IrInstSrcIdFence:
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcFence *>(inst));
case IrInstSrcIdReduce:
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcReduce *>(inst));
case IrInstSrcIdTruncate:
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcTruncate *>(inst));
case IrInstSrcIdIntCast:
@ -636,6 +638,8 @@ void destroy_instruction_gen(IrInstGen *inst) {
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenCmpxchg *>(inst));
case IrInstGenIdFence:
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenFence *>(inst));
case IrInstGenIdReduce:
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenReduce *>(inst));
case IrInstGenIdTruncate:
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenTruncate *>(inst));
case IrInstGenIdShuffleVector:
@ -1311,6 +1315,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcFence *) {
return IrInstSrcIdFence;
}
// Maps the IrInstSrcReduce struct type to its IrInstSrcId tag; the pointer
// argument is only used for overload selection.
static constexpr IrInstSrcId ir_inst_id(IrInstSrcReduce *) {
return IrInstSrcIdReduce;
}
static constexpr IrInstSrcId ir_inst_id(IrInstSrcTruncate *) {
return IrInstSrcIdTruncate;
}
@ -1775,6 +1783,10 @@ static constexpr IrInstGenId ir_inst_id(IrInstGenFence *) {
return IrInstGenIdFence;
}
// Maps the IrInstGenReduce struct type to its IrInstGenId tag; the pointer
// argument is only used for overload selection.
static constexpr IrInstGenId ir_inst_id(IrInstGenReduce *) {
return IrInstGenIdReduce;
}
static constexpr IrInstGenId ir_inst_id(IrInstGenTruncate *) {
return IrInstGenIdTruncate;
}
@ -3502,6 +3514,29 @@ static IrInstGen *ir_build_fence_gen(IrAnalyze *ira, IrInst *source_instr, Atomi
return &instruction->base;
}
// Creates the source-level @reduce instruction. `op` is the operator argument
// and `value` the operand argument; both are referenced from the current block.
static IrInstSrc *ir_build_reduce(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, IrInstSrc *op, IrInstSrc *value) {
    IrInstSrcReduce *reduce_inst = ir_build_instruction<IrInstSrcReduce>(irb, scope, source_node);

    ir_ref_instruction(op, irb->current_basic_block);
    reduce_inst->op = op;

    ir_ref_instruction(value, irb->current_basic_block);
    reduce_inst->value = value;

    return &reduce_inst->base;
}
// Creates the analyzed @reduce instruction for a runtime operand. `op` is the
// resolved operator, `value` the vector operand and `result_type` the type of
// the produced scalar.
static IrInstGen *ir_build_reduce_gen(IrAnalyze *ira, IrInst *source_instruction, ReduceOp op, IrInstGen *value, ZigType *result_type) {
    Scope *scope = source_instruction->scope;
    AstNode *source_node = source_instruction->source_node;

    IrInstGenReduce *reduce_inst = ir_build_inst_gen<IrInstGenReduce>(&ira->new_irb, scope, source_node);
    reduce_inst->base.value->type = result_type;
    reduce_inst->value = value;
    reduce_inst->op = op;

    ir_ref_inst_gen(value);

    return &reduce_inst->base;
}
static IrInstSrc *ir_build_truncate(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
IrInstSrc *dest_type, IrInstSrc *target)
{
@ -6580,6 +6615,21 @@ static IrInstSrc *ir_gen_builtin_fn_call(IrBuilderSrc *irb, Scope *scope, AstNod
IrInstSrc *fence = ir_build_fence(irb, scope, node, arg0_value);
return ir_lval_wrap(irb, scope, fence, lval, result_loc);
}
case BuiltinFnIdReduce:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
IrInstSrc *arg0_value = ir_gen_node(irb, arg0_node, scope);
if (arg0_value == irb->codegen->invalid_inst_src)
return arg0_value;
AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
IrInstSrc *arg1_value = ir_gen_node(irb, arg1_node, scope);
if (arg1_value == irb->codegen->invalid_inst_src)
return arg1_value;
IrInstSrc *reduce = ir_build_reduce(irb, scope, node, arg0_value, arg1_value);
return ir_lval_wrap(irb, scope, reduce, lval, result_loc);
}
case BuiltinFnIdDivExact:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@ -15932,6 +15982,24 @@ static bool ir_resolve_comptime(IrAnalyze *ira, IrInstGen *value, bool *out) {
return ir_resolve_bool(ira, value, out);
}
// Resolves `value` to a compile-time known `ReduceOp` and stores it in `*out`.
// Returns false (leaving `*out` untouched) when the value is invalid, cannot be
// cast to the builtin `ReduceOp` enum, or is not a compile-time constant.
static bool ir_resolve_reduce_op(IrAnalyze *ira, IrInstGen *value, ReduceOp *out) {
    if (type_is_invalid(value->value->type))
        return false;

    // Coerce to the builtin enum so that e.g. enum literals are accepted.
    ZigType *op_enum_type = get_builtin_type(ira->codegen, "ReduceOp");
    IrInstGen *as_reduce_op = ir_implicit_cast(ira, value, op_enum_type);
    if (type_is_invalid(as_reduce_op->value->type))
        return false;

    ZigValue *resolved = ir_resolve_const(ira, as_reduce_op, UndefBad);
    if (resolved == nullptr)
        return false;

    // The enum tag doubles as the C++ enumerator value.
    *out = static_cast<ReduceOp>(bigint_as_u32(&resolved->data.x_enum_tag));
    return true;
}
static bool ir_resolve_atomic_order(IrAnalyze *ira, IrInstGen *value, AtomicOrder *out) {
if (type_is_invalid(value->value->type))
return false;
@ -26802,6 +26870,161 @@ static IrInstGen *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstSrcCmpxch
success_order, failure_order, instruction->is_weak, result_loc);
}
// Evaluates @reduce at compile time.
//
// Folds every element of the comptime vector `value` with `op` and stores the
// scalar result in `out_value`. Returns nullptr on success, or the ErrorMsg
// reported by the scalar evaluation helper that failed.
//
// The caller must pass a vector with at least one element whose elements are
// stored in the `s_none` array representation.
static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp op, ZigValue *value, ZigValue *out_value) {
    assert(value->type->id == ZigTypeIdVector);
    ZigType *scalar_type = value->type->data.vector.elem_type;
    const size_t len = value->type->data.vector.len;
    assert(len > 0);

    out_value->type = scalar_type;
    out_value->special = ConstValSpecialStatic;

    // Bool vectors: and/or/xor are folded directly on the C++ bools.
    if (scalar_type->id == ZigTypeIdBool) {
        bool result = value->data.x_array.data.s_none.elements[0].data.x_bool;

        for (size_t i = 1; i < len; i++) {
            const bool elem = value->data.x_array.data.s_none.elements[i].data.x_bool;
            // Set once the result can no longer change. A `break` inside the
            // switch would only leave the switch, so the loop is exited below.
            bool settled = false;

            switch (op) {
                case ReduceOp_and:
                    result = result && elem;
                    settled = !result;
                    break;
                case ReduceOp_or:
                    result = result || elem;
                    settled = result;
                    break;
                case ReduceOp_xor:
                    result = result != elem;
                    break;
                default:
                    zig_unreachable();
            }

            if (settled)
                break; // Short circuit
        }

        out_value->data.x_bool = result;
        return nullptr;
    }

    // Bitwise operators: accumulate into out_value, starting from element 0.
    if (op != ReduceOp_min && op != ReduceOp_max) {
        // The operator does not change between iterations, map it once.
        IrBinOp bin_op;
        switch (op) {
            case ReduceOp_and: bin_op = IrBinOpBinAnd; break;
            case ReduceOp_or:  bin_op = IrBinOpBinOr; break;
            case ReduceOp_xor: bin_op = IrBinOpBinXor; break;
            default: zig_unreachable();
        }

        copy_const_val(ira->codegen, out_value, &value->data.x_array.data.s_none.elements[0]);

        for (size_t i = 1; i < len; i++) {
            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];

            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
                    out_value, bin_op, elem_val, out_value);
            if (msg != nullptr)
                return msg;
        }

        return nullptr;
    }

    // Min/max: track the best element seen so far; an element replaces the
    // candidate when `elem < candidate` (min) or `elem > candidate` (max).
    const IrBinOp cmp_op = (op == ReduceOp_min) ? IrBinOpCmpLessThan : IrBinOpCmpGreaterThan;

    ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];
    // Scratch value receiving the boolean outcome of each comparison.
    ZigValue *cmp_result = ira->codegen->pass1_arena->create<ZigValue>();
    ErrorMsg *msg = nullptr;

    for (size_t i = 1; i < len; i++) {
        ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];

        msg = ir_eval_bin_op_cmp_scalar(ira, source_instr,
                elem_val, cmp_op, candidate_elem_val, cmp_result);
        if (msg != nullptr)
            break;

        if (cmp_result->data.x_bool)
            candidate_elem_val = elem_val;
    }

    // Release the scratch value on the error path as well.
    ira->codegen->pass1_arena->destroy(cmp_result);
    if (msg != nullptr)
        return msg;

    copy_const_val(ira->codegen, out_value, candidate_elem_val);
    return nullptr;
}
// Analyzes `@reduce(op, value)`.
//
// Checks that `value` is a vector and that `op` is a compile-time known
// ReduceOp which is valid for the vector's element type:
//   - integers accept every operator,
//   - bools accept only And/Or/Xor,
//   - floats accept only Min/Max.
// The result has the vector's element type. It is computed at compile time
// when the operand is comptime-known; otherwise a runtime reduce instruction
// is emitted. Returns the invalid instruction after reporting an error.
static IrInstGen *ir_analyze_instruction_reduce(IrAnalyze *ira, IrInstSrcReduce *instruction) {
    IrInstGen *op_inst = instruction->op->child;
    if (type_is_invalid(op_inst->value->type))
        return ira->codegen->invalid_inst_gen;

    IrInstGen *value_inst = instruction->value->child;
    if (type_is_invalid(value_inst->value->type))
        return ira->codegen->invalid_inst_gen;

    ZigType *value_type = value_inst->value->type;
    if (value_type->id != ZigTypeIdVector) {
        ir_add_error(ira, &value_inst->base,
            buf_sprintf("expected vector type, found '%s'",
                buf_ptr(&value_type->name)));
        return ira->codegen->invalid_inst_gen;
    }

    ReduceOp op;
    if (!ir_resolve_reduce_op(ira, op_inst, &op))
        return ira->codegen->invalid_inst_gen;

    ZigType *elem_type = value_type->data.vector.elem_type;

    // The checks below rely on the bitwise operators being declared before
    // min/max in the ReduceOp enum.
    bool op_is_valid;
    switch (elem_type->id) {
        case ZigTypeIdInt:
            op_is_valid = true;
            break;
        case ZigTypeIdBool:
            op_is_valid = op <= ReduceOp_xor;
            break;
        case ZigTypeIdFloat:
            op_is_valid = op >= ReduceOp_min;
            break;
        default:
            // Any other element type a vector may carry (e.g. pointer-like
            // elements, if vectors permit them) has no defined reduction.
            // Report it to the user instead of crashing the compiler.
            op_is_valid = false;
            break;
    }
    if (!op_is_valid) {
        ir_add_error(ira, &op_inst->base,
            buf_sprintf("invalid operation for '%s' type",
                buf_ptr(&elem_type->name)));
        return ira->codegen->invalid_inst_gen;
    }

    // special case zero bit types
    switch (type_has_one_possible_value(ira->codegen, elem_type)) {
        case OnePossibleValueInvalid:
            return ira->codegen->invalid_inst_gen;
        case OnePossibleValueYes:
            return ir_const_move(ira, &instruction->base.base,
                     get_the_one_possible_value(ira->codegen, elem_type));
        case OnePossibleValueNo:
            break;
    }

    // Comptime-known operand: fold it right away.
    if (instr_is_comptime(value_inst)) {
        IrInstGen *result = ir_const(ira, &instruction->base.base, elem_type);
        if (ir_eval_reduce(ira, &instruction->base.base, op, value_inst->value, result->value))
            return ira->codegen->invalid_inst_gen;
        return result;
    }

    return ir_build_reduce_gen(ira, &instruction->base.base, op, value_inst, elem_type);
}
static IrInstGen *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstSrcFence *instruction) {
IrInstGen *order_inst = instruction->order->child;
if (type_is_invalid(order_inst->value->type))
@ -31550,6 +31773,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
return ir_analyze_instruction_cmpxchg(ira, (IrInstSrcCmpxchg *)instruction);
case IrInstSrcIdFence:
return ir_analyze_instruction_fence(ira, (IrInstSrcFence *)instruction);
case IrInstSrcIdReduce:
return ir_analyze_instruction_reduce(ira, (IrInstSrcReduce *)instruction);
case IrInstSrcIdTruncate:
return ir_analyze_instruction_truncate(ira, (IrInstSrcTruncate *)instruction);
case IrInstSrcIdIntCast:
@ -31937,6 +32162,7 @@ bool ir_inst_gen_has_side_effects(IrInstGen *instruction) {
case IrInstGenIdNegation:
case IrInstGenIdNegationWrapping:
case IrInstGenIdWasmMemorySize:
case IrInstGenIdReduce:
return false;
case IrInstGenIdAsm:
@ -32106,6 +32332,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
case IrInstSrcIdSpillEnd:
case IrInstSrcIdWasmMemorySize:
case IrInstSrcIdSrc:
case IrInstSrcIdReduce:
return false;
case IrInstSrcIdAsm:

View File

@ -200,6 +200,8 @@ const char* ir_inst_src_type_str(IrInstSrcId id) {
return "SrcCmpxchg";
case IrInstSrcIdFence:
return "SrcFence";
case IrInstSrcIdReduce:
return "SrcReduce";
case IrInstSrcIdTruncate:
return "SrcTruncate";
case IrInstSrcIdIntCast:
@ -436,6 +438,8 @@ const char* ir_inst_gen_type_str(IrInstGenId id) {
return "GenCmpxchg";
case IrInstGenIdFence:
return "GenFence";
case IrInstGenIdReduce:
return "GenReduce";
case IrInstGenIdTruncate:
return "GenTruncate";
case IrInstGenIdBoolNot:
@ -1584,6 +1588,14 @@ static void ir_print_fence(IrPrintSrc *irp, IrInstSrcFence *instruction) {
fprintf(irp->f, ")");
}
// Prints a source-level @reduce instruction as `@reduce(<op>, <value>)`, with
// both operands printed as references to other instructions.
static void ir_print_reduce(IrPrintSrc *irp, IrInstSrcReduce *instruction) {
    FILE *out = irp->f;

    fputs("@reduce(", out);
    ir_print_other_inst_src(irp, instruction->op);
    fputs(", ", out);
    ir_print_other_inst_src(irp, instruction->value);
    fputs(")", out);
}
static const char *atomic_order_str(AtomicOrder order) {
switch (order) {
case AtomicOrderUnordered: return "Unordered";
@ -1600,6 +1612,23 @@ static void ir_print_fence(IrPrintGen *irp, IrInstGenFence *instruction) {
fprintf(irp->f, "fence %s", atomic_order_str(instruction->order));
}
// Returns the Zig-side spelling of a reduction operator ("And", "Or", ...).
// Any value outside the ReduceOp enumerators is a compiler bug.
static const char *reduce_op_str(ReduceOp op) {
    const char *name = nullptr;
    switch (op) {
        case ReduceOp_and:
            name = "And";
            break;
        case ReduceOp_or:
            name = "Or";
            break;
        case ReduceOp_xor:
            name = "Xor";
            break;
        case ReduceOp_min:
            name = "Min";
            break;
        case ReduceOp_max:
            name = "Max";
            break;
    }
    if (name == nullptr)
        zig_unreachable();
    return name;
}
// Prints an analyzed @reduce instruction as `@reduce(.<Op>, <value>)`, where
// the operand is printed as a reference to another instruction.
static void ir_print_reduce(IrPrintGen *irp, IrInstGenReduce *instruction) {
    FILE *out = irp->f;
    const char *op_name = reduce_op_str(instruction->op);

    fprintf(out, "@reduce(.%s, ", op_name);
    ir_print_other_inst_gen(irp, instruction->value);
    fputs(")", out);
}
static void ir_print_truncate(IrPrintSrc *irp, IrInstSrcTruncate *instruction) {
fprintf(irp->f, "@truncate(");
ir_print_other_inst_src(irp, instruction->dest_type);
@ -2749,6 +2778,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, IrInstSrc *instruction, bool trai
case IrInstSrcIdFence:
ir_print_fence(irp, (IrInstSrcFence *)instruction);
break;
case IrInstSrcIdReduce:
ir_print_reduce(irp, (IrInstSrcReduce *)instruction);
break;
case IrInstSrcIdTruncate:
ir_print_truncate(irp, (IrInstSrcTruncate *)instruction);
break;
@ -3097,6 +3129,9 @@ static void ir_print_inst_gen(IrPrintGen *irp, IrInstGen *instruction, bool trai
case IrInstGenIdFence:
ir_print_fence(irp, (IrInstGenFence *)instruction);
break;
case IrInstGenIdReduce:
ir_print_reduce(irp, (IrInstGenReduce *)instruction);
break;
case IrInstGenIdTruncate:
ir_print_truncate(irp, (IrInstGenTruncate *)instruction);
break;

View File

@ -1123,6 +1123,34 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
singleThread ? SyncScope::SingleThread : SyncScope::System));
}
// C wrapper around IRBuilder::CreateAndReduce for the vector operand `Val`.
LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateAndReduce(vector));
}
// C wrapper around IRBuilder::CreateOrReduce for the vector operand `Val`.
LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateOrReduce(vector));
}
// C wrapper around IRBuilder::CreateXorReduce for the vector operand `Val`.
LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateXorReduce(vector));
}
// C wrapper around IRBuilder::CreateIntMaxReduce for the vector operand `Val`;
// `is_signed` is forwarded unchanged to the builder.
LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateIntMaxReduce(vector, is_signed));
}
// C wrapper around IRBuilder::CreateIntMinReduce for the vector operand `Val`;
// `is_signed` is forwarded unchanged to the builder.
LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateIntMinReduce(vector, is_signed));
}
// C wrapper around IRBuilder::CreateFPMaxReduce for the vector operand `Val`,
// using the builder's default for any further parameters.
LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateFPMaxReduce(vector));
}
// C wrapper around IRBuilder::CreateFPMinReduce for the vector operand `Val`,
// using the builder's default for any further parameters.
LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    auto *builder = unwrap(B);
    auto *vector = unwrap(Val);
    return wrap(builder->CreateFPMinReduce(vector));
}
static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, "");
static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, "");
static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, "");

View File

@ -455,6 +455,14 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
LLVMValueRef PTR, LLVMValueRef Val,
LLVMAtomicOrdering ordering, LLVMBool singleThread);
// Vector reduction helpers. Each one forwards the vector operand `Val` to the
// IRBuilder Create*Reduce method of the same name and returns its result.
LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val);
// `is_signed` is passed through to the IRBuilder call; presumably it selects a
// signed vs. unsigned element comparison (confirm against the LLVM API).
LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
// Floating point variants, called with the IRBuilder's default extra arguments.
LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val);
#define ZigLLVM_DIFlags_Zero 0U
#define ZigLLVM_DIFlags_Private 1U
#define ZigLLVM_DIFlags_Protected 2U

View File

@ -484,3 +484,43 @@ test "vector shift operators" {
S.doTheTest();
comptime S.doTheTest();
}
test "vector reduce operation" {
    const S = struct {
        // Turns the array `x` into a vector with the same length and element
        // type, reduces it with `op` and checks the result against `expected`.
        fn expectReduce(comptime op: builtin.ReduceOp, x: anytype, expected: anytype) void {
            const len = @typeInfo(@TypeOf(x)).Array.len;
            const Elem = @typeInfo(@TypeOf(x)).Array.child;
            var r = @reduce(op, @as(Vector(len, Elem), x));
            expectEqual(expected, r);
        }

        fn run() void {
            // Bitwise operators on bool vectors.
            expectReduce(.And, [4]bool{ true, false, true, true }, @as(bool, false));
            expectReduce(.Or, [4]bool{ false, true, false, false }, @as(bool, true));
            expectReduce(.Xor, [4]bool{ true, true, true, false }, @as(bool, true));

            // Bitwise operators on single-bit integers.
            expectReduce(.And, [4]u1{ 1, 0, 1, 1 }, @as(u1, 0));
            expectReduce(.Or, [4]u1{ 0, 1, 0, 0 }, @as(u1, 1));
            expectReduce(.Xor, [4]u1{ 1, 1, 1, 0 }, @as(u1, 1));

            // Bitwise operators on 32-bit integers.
            expectReduce(.And, [4]u32{ 0xffffffff, 0xffff5555, 0xaaaaffff, 0x10101010 }, @as(u32, 0x1010));
            expectReduce(.Or, [4]u32{ 0xffff0000, 0xff00, 0xf0, 0xf }, ~@as(u32, 0));
            expectReduce(.Xor, [4]u32{ 0x00000000, 0x33333333, 0x88888888, 0x44444444 }, ~@as(u32, 0));

            // Min/max on signed and unsigned integers.
            expectReduce(.Min, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, -386));
            expectReduce(.Max, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, 1234567));

            expectReduce(.Min, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 9));
            expectReduce(.Max, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 99999));

            // Min/max on floats.
            expectReduce(.Min, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, -100.0));
            expectReduce(.Max, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, 10.0e9));

            expectReduce(.Min, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, -100.0));
            expectReduce(.Max, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, 10.0e9));
        }
    };

    // Exercise both the runtime and the compile-time implementation.
    S.run();
    comptime S.run();
}