stage1: Implement @reduce builtin for vector types
The builtin folds a Vector(N, T) into a single scalar of type T using the specified operator. Closes #2698
This commit is contained in:
parent 7c5a24e08c
commit 22b5e47839
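For illustration, a minimal sketch of how the new builtin reads; this is not part of the diff and assumes the std.meta.Vector helper and the test style of this era, as used in the behavior test at the end of the commit (newer Zig spells the type @Vector(N, T) and requires `try` on the expectations):

    const std = @import("std");
    const Vector = std.meta.Vector;
    const expectEqual = std.testing.expectEqual;

    test "reduce a vector to a scalar" {
        // fold four lanes into one scalar with the requested operator
        const v = @as(Vector(4, i32), [4]i32{ 1234567, -386, 0, 3 });
        expectEqual(@as(i32, -386), @reduce(.Min, v));
        expectEqual(@as(i32, 1234567), @reduce(.Max, v));
    }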
@@ -98,6 +98,16 @@ pub const AtomicOrder = enum {
    SeqCst,
};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const ReduceOp = enum {
    And,
    Or,
    Xor,
    Min,
    Max,
};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const AtomicRmwOp = enum {
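The first argument of @reduce is a value of this new std.builtin.ReduceOp enum and must be comptime-known (see ir_resolve_reduce_op further down in the diff). A hedged sketch of passing the operator around at comptime, mirroring the behavior test's helper and again assuming std.meta.Vector:

    const std = @import("std");
    const builtin = std.builtin;
    const Vector = std.meta.Vector;
    const expectEqual = std.testing.expectEqual;

    fn fold(comptime op: builtin.ReduceOp, x: [4]u32) u32 {
        // the operator is a comptime parameter; the vector operand may be runtime-known
        return @reduce(op, @as(Vector(4, u32), x));
    }

    test "ReduceOp is passed at comptime" {
        expectEqual(~@as(u32, 0), fold(.Or, [4]u32{ 0xffff0000, 0xff00, 0xf0, 0xf }));
        expectEqual(@as(u32, 0x1010), fold(.And, [4]u32{ 0xffffffff, 0xffff5555, 0xaaaaffff, 0x10101010 }));
    }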
@@ -1821,6 +1821,7 @@ enum BuiltinFnId {
    BuiltinFnIdWasmMemorySize,
    BuiltinFnIdWasmMemoryGrow,
    BuiltinFnIdSrc,
    BuiltinFnIdReduce,
};

struct BuiltinFnEntry {

@@ -2436,6 +2437,15 @@ enum AtomicOrder {
    AtomicOrderSeqCst,
};

// synchronized with code in define_builtin_compile_vars
enum ReduceOp {
    ReduceOp_and,
    ReduceOp_or,
    ReduceOp_xor,
    ReduceOp_min,
    ReduceOp_max,
};

// synchronized with the code in define_builtin_compile_vars
enum AtomicRmwOp {
    AtomicRmwOp_xchg,

@@ -2545,6 +2555,7 @@ enum IrInstSrcId {
    IrInstSrcIdEmbedFile,
    IrInstSrcIdCmpxchg,
    IrInstSrcIdFence,
    IrInstSrcIdReduce,
    IrInstSrcIdTruncate,
    IrInstSrcIdIntCast,
    IrInstSrcIdFloatCast,

@@ -2667,6 +2678,7 @@ enum IrInstGenId {
    IrInstGenIdErrName,
    IrInstGenIdCmpxchg,
    IrInstGenIdFence,
    IrInstGenIdReduce,
    IrInstGenIdTruncate,
    IrInstGenIdShuffleVector,
    IrInstGenIdSplat,

@@ -3516,6 +3528,20 @@ struct IrInstGenFence {
    AtomicOrder order;
};

struct IrInstSrcReduce {
    IrInstSrc base;

    IrInstSrc *op;
    IrInstSrc *value;
};

struct IrInstGenReduce {
    IrInstGen base;

    ReduceOp op;
    IrInstGen *value;
};

struct IrInstSrcTruncate {
    IrInstSrc base;
@@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
    return nullptr;
}

enum class ScalarizePredicate {
    // Returns true iff all the elements in the vector are 1.
    // Equivalent to folding all the bits with `and`.
    All,
    // Returns true iff there's at least one element in the vector that is 1.
    // Equivalent to folding all the bits with `or`.
    Any,
};

// Collapses a <N x i1> vector into a single i1 according to the given predicate
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
    assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
    LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
    LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");

    switch (predicate) {
        case ScalarizePredicate::Any: {
            LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
            return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
        }
        case ScalarizePredicate::All: {
            LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
            return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
        }
    }

    zig_unreachable();
}


static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
        LLVMValueRef val1, LLVMValueRef val2)
{
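The helper removed above collapsed an <N x i1> comparison vector by bitcasting it to an iN and comparing against all-ones or zero. The hunks below switch those runtime-safety checks over to LLVM's vector and/or reductions (ZigLLVMBuildAndReduce / ZigLLVMBuildOrReduce), the same wrappers that back @reduce(.And, ...) and @reduce(.Or, ...) on bool vectors. A minimal sketch of that user-visible case, assuming std.meta.Vector as above (values taken from the behavior test at the end of the diff):

    const std = @import("std");
    const Vector = std.meta.Vector;
    const expectEqual = std.testing.expectEqual;

    test "bool vectors reduce through i1 and/or reductions" {
        const b = @as(Vector(4, bool), [4]bool{ true, false, true, true });
        // "all lanes true" and "any lane true", respectively
        expectEqual(@as(bool, false), @reduce(.And, b));
        expectEqual(@as(bool, true), @reduce(.Or, b));
    }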
@@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
    if (operand_type->id == ZigTypeIdVector) {
        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
    }
    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
    if (operand_type->id == ZigTypeIdVector) {
        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
    }
    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
    }

    if (operand_type->id == ZigTypeIdVector) {
        is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
        is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
    }

    LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");

@@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
    LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
    LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
    if (operand_type->id == ZigTypeIdVector) {
        overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
        overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
    }
    LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);

@@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
    LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
    if (operand_type->id == ZigTypeIdVector) {
        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
    }
    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
    LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
    LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
    if (operand_type->id == ZigTypeIdVector) {
        ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
        ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
    }
    LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);

@@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
    LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
    if (operand_type->id == ZigTypeIdVector) {
        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
        ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
    }
    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
    }

    if (operand_type->id == ZigTypeIdVector) {
        is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
        is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
    }

    LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");

@@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
    LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
    if (rhs_type->id == ZigTypeIdVector) {
        less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
        less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
    }
    LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
@@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
    return result_loc;
}

static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
    LLVMValueRef value = ir_llvm_value(g, instruction->value);

    ZigType *value_type = instruction->value->value->type;
    assert(value_type->id == ZigTypeIdVector);
    ZigType *scalar_type = value_type->data.vector.elem_type;

    LLVMValueRef result_val;
    switch (instruction->op) {
        case ReduceOp_and:
            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
            result_val = ZigLLVMBuildAndReduce(g->builder, value);
            break;
        case ReduceOp_or:
            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
            result_val = ZigLLVMBuildOrReduce(g->builder, value);
            break;
        case ReduceOp_xor:
            assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
            result_val = ZigLLVMBuildXorReduce(g->builder, value);
            break;
        case ReduceOp_min: {
            if (scalar_type->id == ZigTypeIdInt) {
                const bool is_signed = scalar_type->data.integral.is_signed;
                result_val = ZigLLVMBuildIntMinReduce(g->builder, value, is_signed);
            } else if (scalar_type->id == ZigTypeIdFloat) {
                result_val = ZigLLVMBuildFPMinReduce(g->builder, value);
            } else zig_unreachable();
        } break;
        case ReduceOp_max: {
            if (scalar_type->id == ZigTypeIdInt) {
                const bool is_signed = scalar_type->data.integral.is_signed;
                result_val = ZigLLVMBuildIntMaxReduce(g->builder, value, is_signed);
            } else if (scalar_type->id == ZigTypeIdFloat) {
                result_val = ZigLLVMBuildFPMaxReduce(g->builder, value);
            } else zig_unreachable();
        } break;
        default:
            zig_unreachable();
    }

    return result_val;
}

static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
    LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
    LLVMBuildFence(g->builder, atomic_order, false, "");
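ir_render_reduce above picks the LLVM reduction by element type: and/or/xor for integers and bools, signed or unsigned integer min/max, and floating-point min/max. A small sketch of the unsigned-integer and float cases the switch has to distinguish (values taken from the behavior test; same std.meta.Vector assumption as earlier):

    const std = @import("std");
    const Vector = std.meta.Vector;
    const expectEqual = std.testing.expectEqual;

    test "min/max reductions dispatch on the element type" {
        // unsigned integers go through the unsigned integer min/max reduction
        const u = @as(Vector(4, u32), [4]u32{ 99, 9999, 9, 99999 });
        expectEqual(@as(u32, 9), @reduce(.Min, u));
        expectEqual(@as(u32, 99999), @reduce(.Max, u));

        // floats go through the floating-point min/max reduction
        const f = @as(Vector(4, f32), [4]f32{ -10.3, 10.0e9, 13.0, -100.0 });
        expectEqual(@as(f32, -100.0), @reduce(.Min, f));
        expectEqual(@as(f32, 10.0e9), @reduce(.Max, f));
    }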
@@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
            return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
        case IrInstGenIdFence:
            return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
        case IrInstGenIdReduce:
            return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
        case IrInstGenIdTruncate:
            return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
        case IrInstGenIdBoolNot:

@@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
    create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
    create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
    create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
    create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
}

static const char *bool_to_str(bool b) {
@@ -402,6 +402,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCmpxchg *>(inst));
        case IrInstSrcIdFence:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcFence *>(inst));
        case IrInstSrcIdReduce:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcReduce *>(inst));
        case IrInstSrcIdTruncate:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcTruncate *>(inst));
        case IrInstSrcIdIntCast:

@@ -636,6 +638,8 @@ void destroy_instruction_gen(IrInstGen *inst) {
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenCmpxchg *>(inst));
        case IrInstGenIdFence:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenFence *>(inst));
        case IrInstGenIdReduce:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenReduce *>(inst));
        case IrInstGenIdTruncate:
            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenTruncate *>(inst));
        case IrInstGenIdShuffleVector:

@@ -1311,6 +1315,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcFence *) {
    return IrInstSrcIdFence;
}

static constexpr IrInstSrcId ir_inst_id(IrInstSrcReduce *) {
    return IrInstSrcIdReduce;
}

static constexpr IrInstSrcId ir_inst_id(IrInstSrcTruncate *) {
    return IrInstSrcIdTruncate;
}

@@ -1775,6 +1783,10 @@ static constexpr IrInstGenId ir_inst_id(IrInstGenFence *) {
    return IrInstGenIdFence;
}

static constexpr IrInstGenId ir_inst_id(IrInstGenReduce *) {
    return IrInstGenIdReduce;
}

static constexpr IrInstGenId ir_inst_id(IrInstGenTruncate *) {
    return IrInstGenIdTruncate;
}
@@ -3502,6 +3514,29 @@ static IrInstGen *ir_build_fence_gen(IrAnalyze *ira, IrInst *source_instr, Atomi
    return &instruction->base;
}

static IrInstSrc *ir_build_reduce(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, IrInstSrc *op, IrInstSrc *value) {
    IrInstSrcReduce *instruction = ir_build_instruction<IrInstSrcReduce>(irb, scope, source_node);
    instruction->op = op;
    instruction->value = value;

    ir_ref_instruction(op, irb->current_basic_block);
    ir_ref_instruction(value, irb->current_basic_block);

    return &instruction->base;
}

static IrInstGen *ir_build_reduce_gen(IrAnalyze *ira, IrInst *source_instruction, ReduceOp op, IrInstGen *value, ZigType *result_type) {
    IrInstGenReduce *instruction = ir_build_inst_gen<IrInstGenReduce>(&ira->new_irb,
            source_instruction->scope, source_instruction->source_node);
    instruction->base.value->type = result_type;
    instruction->op = op;
    instruction->value = value;

    ir_ref_inst_gen(value);

    return &instruction->base;
}

static IrInstSrc *ir_build_truncate(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
        IrInstSrc *dest_type, IrInstSrc *target)
{
@@ -6580,6 +6615,21 @@ static IrInstSrc *ir_gen_builtin_fn_call(IrBuilderSrc *irb, Scope *scope, AstNod
                IrInstSrc *fence = ir_build_fence(irb, scope, node, arg0_value);
                return ir_lval_wrap(irb, scope, fence, lval, result_loc);
            }
        case BuiltinFnIdReduce:
            {
                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
                IrInstSrc *arg0_value = ir_gen_node(irb, arg0_node, scope);
                if (arg0_value == irb->codegen->invalid_inst_src)
                    return arg0_value;

                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
                IrInstSrc *arg1_value = ir_gen_node(irb, arg1_node, scope);
                if (arg1_value == irb->codegen->invalid_inst_src)
                    return arg1_value;

                IrInstSrc *reduce = ir_build_reduce(irb, scope, node, arg0_value, arg1_value);
                return ir_lval_wrap(irb, scope, reduce, lval, result_loc);
            }
        case BuiltinFnIdDivExact:
            {
                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -15932,6 +15982,24 @@ static bool ir_resolve_comptime(IrAnalyze *ira, IrInstGen *value, bool *out) {
    return ir_resolve_bool(ira, value, out);
}

static bool ir_resolve_reduce_op(IrAnalyze *ira, IrInstGen *value, ReduceOp *out) {
    if (type_is_invalid(value->value->type))
        return false;

    ZigType *reduce_op_type = get_builtin_type(ira->codegen, "ReduceOp");

    IrInstGen *casted_value = ir_implicit_cast(ira, value, reduce_op_type);
    if (type_is_invalid(casted_value->value->type))
        return false;

    ZigValue *const_val = ir_resolve_const(ira, casted_value, UndefBad);
    if (!const_val)
        return false;

    *out = (ReduceOp)bigint_as_u32(&const_val->data.x_enum_tag);
    return true;
}

static bool ir_resolve_atomic_order(IrAnalyze *ira, IrInstGen *value, AtomicOrder *out) {
    if (type_is_invalid(value->value->type))
        return false;
@@ -26802,6 +26870,161 @@ static IrInstGen *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstSrcCmpxch
            success_order, failure_order, instruction->is_weak, result_loc);
}

static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp op, ZigValue *value, ZigValue *out_value) {
    assert(value->type->id == ZigTypeIdVector);
    ZigType *scalar_type = value->type->data.vector.elem_type;
    const size_t len = value->type->data.vector.len;
    assert(len > 0);

    out_value->type = scalar_type;
    out_value->special = ConstValSpecialStatic;

    if (scalar_type->id == ZigTypeIdBool) {
        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];

        bool result = first_elem_val->data.x_bool;
        for (size_t i = 1; i < len; i++) {
            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];

            switch (op) {
                case ReduceOp_and:
                    result = result && elem_val->data.x_bool;
                    if (!result) break; // Short circuit
                    break;
                case ReduceOp_or:
                    result = result || elem_val->data.x_bool;
                    if (result) break; // Short circuit
                    break;
                case ReduceOp_xor:
                    result = result != elem_val->data.x_bool;
                    break;
                default:
                    zig_unreachable();
            }
        }

        out_value->data.x_bool = result;
        return nullptr;
    }

    if (op != ReduceOp_min && op != ReduceOp_max) {
        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];

        copy_const_val(ira->codegen, out_value, first_elem_val);

        for (size_t i = 1; i < len; i++) {
            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];

            IrBinOp bin_op;
            switch (op) {
                case ReduceOp_and: bin_op = IrBinOpBinAnd; break;
                case ReduceOp_or: bin_op = IrBinOpBinOr; break;
                case ReduceOp_xor: bin_op = IrBinOpBinXor; break;
                default: zig_unreachable();
            }

            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
                    out_value, bin_op, elem_val, out_value);
            if (msg != nullptr)
                return msg;
        }

        return nullptr;
    }

    ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];

    ZigValue *dummy_cmp_value = ira->codegen->pass1_arena->create<ZigValue>();
    for (size_t i = 1; i < len; i++) {
        ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];

        IrBinOp bin_op;
        switch (op) {
            case ReduceOp_min: bin_op = IrBinOpCmpLessThan; break;
            case ReduceOp_max: bin_op = IrBinOpCmpGreaterThan; break;
            default: zig_unreachable();
        }

        ErrorMsg *msg = ir_eval_bin_op_cmp_scalar(ira, source_instr,
                elem_val, bin_op, candidate_elem_val, dummy_cmp_value);
        if (msg != nullptr)
            return msg;

        if (dummy_cmp_value->data.x_bool)
            candidate_elem_val = elem_val;
    }

    ira->codegen->pass1_arena->destroy(dummy_cmp_value);
    copy_const_val(ira->codegen, out_value, candidate_elem_val);

    return nullptr;
}

static IrInstGen *ir_analyze_instruction_reduce(IrAnalyze *ira, IrInstSrcReduce *instruction) {
    IrInstGen *op_inst = instruction->op->child;
    if (type_is_invalid(op_inst->value->type))
        return ira->codegen->invalid_inst_gen;

    IrInstGen *value_inst = instruction->value->child;
    if (type_is_invalid(value_inst->value->type))
        return ira->codegen->invalid_inst_gen;

    ZigType *value_type = value_inst->value->type;
    if (value_type->id != ZigTypeIdVector) {
        ir_add_error(ira, &value_inst->base,
                buf_sprintf("expected vector type, found '%s'",
                    buf_ptr(&value_type->name)));
        return ira->codegen->invalid_inst_gen;
    }

    ReduceOp op;
    if (!ir_resolve_reduce_op(ira, op_inst, &op))
        return ira->codegen->invalid_inst_gen;

    ZigType *elem_type = value_type->data.vector.elem_type;
    switch (elem_type->id) {
        case ZigTypeIdInt:
            break;
        case ZigTypeIdBool:
            if (op > ReduceOp_xor) {
                ir_add_error(ira, &op_inst->base,
                        buf_sprintf("invalid operation for '%s' type",
                            buf_ptr(&elem_type->name)));
                return ira->codegen->invalid_inst_gen;
            } break;
        case ZigTypeIdFloat:
            if (op < ReduceOp_min) {
                ir_add_error(ira, &op_inst->base,
                        buf_sprintf("invalid operation for '%s' type",
                            buf_ptr(&elem_type->name)));
                return ira->codegen->invalid_inst_gen;
            } break;
        default:
            // Vectors cannot have child types other than those listed above
            zig_unreachable();
    }

    // special case zero bit types
    switch (type_has_one_possible_value(ira->codegen, elem_type)) {
        case OnePossibleValueInvalid:
            return ira->codegen->invalid_inst_gen;
        case OnePossibleValueYes:
            return ir_const_move(ira, &instruction->base.base,
                    get_the_one_possible_value(ira->codegen, elem_type));
        case OnePossibleValueNo:
            break;
    }

    if (instr_is_comptime(value_inst)) {
        IrInstGen *result = ir_const(ira, &instruction->base.base, elem_type);
        if (ir_eval_reduce(ira, &instruction->base.base, op, value_inst->value, result->value))
            return ira->codegen->invalid_inst_gen;
        return result;
    }

    return ir_build_reduce_gen(ira, &instruction->base.base, op, value_inst, elem_type);
}

static IrInstGen *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstSrcFence *instruction) {
    IrInstGen *order_inst = instruction->order->child;
    if (type_is_invalid(order_inst->value->type))
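ir_eval_reduce above is the comptime path: when the vector operand is comptime-known, the analysis folds the reduction itself (boolean lanes directly, and/or/xor via the scalar math ops, min/max via scalar comparisons) and no IrInstGenReduce is emitted. A hedged sketch of exercising that path, assuming std.meta.Vector as before:

    const std = @import("std");
    const Vector = std.meta.Vector;

    test "@reduce folds at comptime when the operand is comptime-known" {
        comptime {
            const v = @as(Vector(4, i32), [4]i32{ 1234567, -386, 0, 3 });
            // folded entirely during semantic analysis; no runtime instruction is emitted
            std.debug.assert(@reduce(.Min, v) == -386);
            std.debug.assert(@reduce(.Max, v) == 1234567);
        }
    }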
@@ -31550,6 +31773,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
            return ir_analyze_instruction_cmpxchg(ira, (IrInstSrcCmpxchg *)instruction);
        case IrInstSrcIdFence:
            return ir_analyze_instruction_fence(ira, (IrInstSrcFence *)instruction);
        case IrInstSrcIdReduce:
            return ir_analyze_instruction_reduce(ira, (IrInstSrcReduce *)instruction);
        case IrInstSrcIdTruncate:
            return ir_analyze_instruction_truncate(ira, (IrInstSrcTruncate *)instruction);
        case IrInstSrcIdIntCast:

@@ -31937,6 +32162,7 @@ bool ir_inst_gen_has_side_effects(IrInstGen *instruction) {
        case IrInstGenIdNegation:
        case IrInstGenIdNegationWrapping:
        case IrInstGenIdWasmMemorySize:
        case IrInstGenIdReduce:
            return false;

        case IrInstGenIdAsm:

@@ -32106,6 +32332,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
        case IrInstSrcIdSpillEnd:
        case IrInstSrcIdWasmMemorySize:
        case IrInstSrcIdSrc:
        case IrInstSrcIdReduce:
            return false;

        case IrInstSrcIdAsm:
@@ -200,6 +200,8 @@ const char* ir_inst_src_type_str(IrInstSrcId id) {
            return "SrcCmpxchg";
        case IrInstSrcIdFence:
            return "SrcFence";
        case IrInstSrcIdReduce:
            return "SrcReduce";
        case IrInstSrcIdTruncate:
            return "SrcTruncate";
        case IrInstSrcIdIntCast:

@@ -436,6 +438,8 @@ const char* ir_inst_gen_type_str(IrInstGenId id) {
            return "GenCmpxchg";
        case IrInstGenIdFence:
            return "GenFence";
        case IrInstGenIdReduce:
            return "GenReduce";
        case IrInstGenIdTruncate:
            return "GenTruncate";
        case IrInstGenIdBoolNot:

@@ -1584,6 +1588,14 @@ static void ir_print_fence(IrPrintSrc *irp, IrInstSrcFence *instruction) {
    fprintf(irp->f, ")");
}

static void ir_print_reduce(IrPrintSrc *irp, IrInstSrcReduce *instruction) {
    fprintf(irp->f, "@reduce(");
    ir_print_other_inst_src(irp, instruction->op);
    fprintf(irp->f, ", ");
    ir_print_other_inst_src(irp, instruction->value);
    fprintf(irp->f, ")");
}

static const char *atomic_order_str(AtomicOrder order) {
    switch (order) {
        case AtomicOrderUnordered: return "Unordered";

@@ -1600,6 +1612,23 @@ static void ir_print_fence(IrPrintGen *irp, IrInstGenFence *instruction) {
    fprintf(irp->f, "fence %s", atomic_order_str(instruction->order));
}

static const char *reduce_op_str(ReduceOp op) {
    switch (op) {
        case ReduceOp_and: return "And";
        case ReduceOp_or: return "Or";
        case ReduceOp_xor: return "Xor";
        case ReduceOp_min: return "Min";
        case ReduceOp_max: return "Max";
    }
    zig_unreachable();
}

static void ir_print_reduce(IrPrintGen *irp, IrInstGenReduce *instruction) {
    fprintf(irp->f, "@reduce(.%s, ", reduce_op_str(instruction->op));
    ir_print_other_inst_gen(irp, instruction->value);
    fprintf(irp->f, ")");
}

static void ir_print_truncate(IrPrintSrc *irp, IrInstSrcTruncate *instruction) {
    fprintf(irp->f, "@truncate(");
    ir_print_other_inst_src(irp, instruction->dest_type);

@@ -2749,6 +2778,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, IrInstSrc *instruction, bool trai
        case IrInstSrcIdFence:
            ir_print_fence(irp, (IrInstSrcFence *)instruction);
            break;
        case IrInstSrcIdReduce:
            ir_print_reduce(irp, (IrInstSrcReduce *)instruction);
            break;
        case IrInstSrcIdTruncate:
            ir_print_truncate(irp, (IrInstSrcTruncate *)instruction);
            break;

@@ -3097,6 +3129,9 @@ static void ir_print_inst_gen(IrPrintGen *irp, IrInstGen *instruction, bool trai
        case IrInstGenIdFence:
            ir_print_fence(irp, (IrInstGenFence *)instruction);
            break;
        case IrInstGenIdReduce:
            ir_print_reduce(irp, (IrInstGenReduce *)instruction);
            break;
        case IrInstGenIdTruncate:
            ir_print_truncate(irp, (IrInstGenTruncate *)instruction);
            break;
@@ -1123,6 +1123,34 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
            singleThread ? SyncScope::SingleThread : SyncScope::System));
}

LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    return wrap(unwrap(B)->CreateAndReduce(unwrap(Val)));
}

LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    return wrap(unwrap(B)->CreateOrReduce(unwrap(Val)));
}

LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    return wrap(unwrap(B)->CreateXorReduce(unwrap(Val)));
}

LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
    return wrap(unwrap(B)->CreateIntMaxReduce(unwrap(Val), is_signed));
}

LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed) {
    return wrap(unwrap(B)->CreateIntMinReduce(unwrap(Val), is_signed));
}

LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    return wrap(unwrap(B)->CreateFPMaxReduce(unwrap(Val)));
}

LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val) {
    return wrap(unwrap(B)->CreateFPMinReduce(unwrap(Val)));
}

static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, "");
static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, "");
static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, "");
@@ -455,6 +455,14 @@ LLVMValueRef ZigLLVMBuildAtomicRMW(LLVMBuilderRef B, enum ZigLLVM_AtomicRMWBinOp
        LLVMValueRef PTR, LLVMValueRef Val,
        LLVMAtomicOrdering ordering, LLVMBool singleThread);

LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val);
LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val);

#define ZigLLVM_DIFlags_Zero 0U
#define ZigLLVM_DIFlags_Private 1U
#define ZigLLVM_DIFlags_Protected 2U
@@ -484,3 +484,43 @@ test "vector shift operators" {
    S.doTheTest();
    comptime S.doTheTest();
}

test "vector reduce operation" {
    const S = struct {
        fn doTheTestReduce(comptime op: builtin.ReduceOp, x: anytype, expected: anytype) void {
            const N = @typeInfo(@TypeOf(x)).Array.len;
            const TX = @typeInfo(@TypeOf(x)).Array.child;

            var r = @reduce(op, @as(Vector(N, TX), x));
            expectEqual(expected, r);
        }
        fn doTheTest() void {
            doTheTestReduce(.And, [4]bool{ true, false, true, true }, @as(bool, false));
            doTheTestReduce(.Or, [4]bool{ false, true, false, false }, @as(bool, true));
            doTheTestReduce(.Xor, [4]bool{ true, true, true, false }, @as(bool, true));

            doTheTestReduce(.And, [4]u1{ 1, 0, 1, 1 }, @as(u1, 0));
            doTheTestReduce(.Or, [4]u1{ 0, 1, 0, 0 }, @as(u1, 1));
            doTheTestReduce(.Xor, [4]u1{ 1, 1, 1, 0 }, @as(u1, 1));

            doTheTestReduce(.And, [4]u32{ 0xffffffff, 0xffff5555, 0xaaaaffff, 0x10101010 }, @as(u32, 0x1010));
            doTheTestReduce(.Or, [4]u32{ 0xffff0000, 0xff00, 0xf0, 0xf }, ~@as(u32, 0));
            doTheTestReduce(.Xor, [4]u32{ 0x00000000, 0x33333333, 0x88888888, 0x44444444 }, ~@as(u32, 0));

            doTheTestReduce(.Min, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, -386));
            doTheTestReduce(.Max, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, 1234567));

            doTheTestReduce(.Min, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 9));
            doTheTestReduce(.Max, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 99999));

            doTheTestReduce(.Min, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, -100.0));
            doTheTestReduce(.Max, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, 10.0e9));

            doTheTestReduce(.Min, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, -100.0));
            doTheTestReduce(.Max, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, 10.0e9));
        }
    };

    S.doTheTest();
    comptime S.doTheTest();
}