mirror of
https://github.com/ziglang/zig.git
synced 2024-11-16 09:03:12 +00:00
implement spills when expressions used across suspend points
closes #3077
This commit is contained in:
parent
9ca8d9e21a
commit
d1a98ccff4
@ -2124,6 +2124,7 @@ enum ScopeId {
|
||||
ScopeIdCompTime,
|
||||
ScopeIdRuntime,
|
||||
ScopeIdTypeOf,
|
||||
ScopeIdExpr,
|
||||
};
|
||||
|
||||
struct Scope {
|
||||
@ -2271,6 +2272,24 @@ struct ScopeTypeOf {
|
||||
Scope base;
|
||||
};
|
||||
|
||||
// Three-state boolean used to cache a yes/no answer that is computed lazily.
// NOTE(review): Unknown is the first enumerator (value 0); presumably a
// zero-initialized ScopeExpr therefore starts out as "not computed yet" --
// confirm that allocate<T>() zeroes memory.
enum MemoizedBool {
|
||||
// The answer has not been computed yet.
MemoizedBoolUnknown,
|
||||
// The answer has been computed and is "no".
MemoizedBoolFalse,
|
||||
// The answer has been computed and is "yes".
MemoizedBoolTrue,
|
||||
};
|
||||
|
||||
// This scope is created for each expression.
|
||||
// It's used to identify when an instruction needs to be spilled,
|
||||
// so that it can be accessed after a suspend point.
|
||||
// ScopeExpr nodes form a tree: create_expr_scope() appends each new scope to
// the children list of its nearest enclosing ScopeExpr (as located by
// find_expr_scope()).
struct ScopeExpr {
|
||||
// Common scope header; its id is ScopeIdExpr.
Scope base;
|
||||
|
||||
// Array of directly nested expression scopes, in the order they were created.
ScopeExpr **children_ptr;
|
||||
// Number of entries in children_ptr.
size_t children_len;
|
||||
|
||||
// Memoized "must instructions in this scope be spilled?" answer.
// Set to true by mark_suspension_point(); otherwise resolved lazily by
// scope_needs_spill(), which inherits the answer from the enclosing ScopeExpr.
MemoizedBool need_spill;
|
||||
};
|
||||
|
||||
// synchronized with code in define_builtin_compile_vars
|
||||
enum AtomicOrder {
|
||||
AtomicOrderUnordered,
|
||||
@ -2510,6 +2529,10 @@ struct IrInstruction {
|
||||
// with this child field.
|
||||
IrInstruction *child;
|
||||
IrBasicBlock *owner_bb;
|
||||
// Nearly any instruction can have to be stored as a local variable before suspending
|
||||
// and then loaded after resuming, in case there is an expression with a suspend point
|
||||
// in it, such as: x + await y
|
||||
IrInstruction *spill;
|
||||
IrInstructionId id;
|
||||
// true if this instruction was generated by zig and not from user code
|
||||
bool is_gen;
|
||||
|
159
src/analyze.cpp
159
src/analyze.cpp
@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) {
|
||||
zig_unreachable();
|
||||
}
|
||||
|
||||
// Returns the nearest ScopeExpr at or above `scope`, looking only through
// loop and runtime scopes. Any other kind of scope ends the search, in which
// case nullptr is returned.
static ScopeExpr *find_expr_scope(Scope *scope) {
    Scope *cursor = scope;
    while (true) {
        // Every ScopeId is listed explicitly (no default) so that adding a new
        // scope kind forces this function to be revisited.
        switch (cursor->id) {
            case ScopeIdLoop:
            case ScopeIdRuntime:
                // Transparent for this search: step outwards and look again.
                cursor = cursor->parent;
                break;
            case ScopeIdExpr:
                return reinterpret_cast<ScopeExpr *>(cursor);
            case ScopeIdBlock:
            case ScopeIdTypeOf:
            case ScopeIdSuspend:
            case ScopeIdCImport:
            case ScopeIdVarDecl:
            case ScopeIdCompTime:
            case ScopeIdFnDef:
            case ScopeIdDecls:
            case ScopeIdDeferExpr:
            case ScopeIdDefer:
                // Boundary scope: there is no enclosing expression scope.
                return nullptr;
        }
    }
}
|
||||
|
||||
ScopeDecls *get_container_scope(ZigType *type_entry) {
|
||||
return *get_container_scope_ptr(type_entry);
|
||||
}
|
||||
@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
|
||||
return &scope->base;
|
||||
}
|
||||
|
||||
// Allocates a new ScopeExpr for `node` under `parent` and, when `parent` has an
// enclosing ScopeExpr (see find_expr_scope), appends the new scope to that
// enclosing scope's children list. Returns the new scope's base.
Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
    ScopeExpr *expr_scope = allocate<ScopeExpr>(1);
    init_scope(g, &expr_scope->base, ScopeIdExpr, node, parent);

    ScopeExpr *enclosing = find_expr_scope(parent);
    if (enclosing == nullptr) {
        // Top of an expression tree: nothing to register with.
        return &expr_scope->base;
    }

    // Grow the children array by one slot and store the new scope in it.
    size_t old_len = enclosing->children_len;
    size_t grown_len = old_len + 1;
    enclosing->children_ptr = reallocate_nonzero<ScopeExpr *>(enclosing->children_ptr, old_len, grown_len);
    enclosing->children_ptr[old_len] = expr_scope;
    enclosing->children_len = grown_len;

    return &expr_scope->base;
}
|
||||
|
||||
ZigType *get_scope_import(Scope *scope) {
|
||||
while (scope) {
|
||||
if (scope->id == ScopeIdDecls) {
|
||||
@ -5654,6 +5692,69 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) {
|
||||
return fn_type;
|
||||
}
|
||||
|
||||
// Traverse up to the very top ExprScope, which has children.
|
||||
// We have just arrived at the top from a child. That child,
|
||||
// and its next siblings, do not need to be marked. But the previous
|
||||
// siblings do.
|
||||
// x + (await y)
|
||||
// vs
|
||||
// (await y) + x
|
||||
static void mark_suspension_point(Scope *scope) {
|
||||
ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr;
|
||||
for (;;) {
|
||||
scope = scope->parent;
|
||||
switch (scope->id) {
|
||||
case ScopeIdDefer:
|
||||
case ScopeIdDeferExpr:
|
||||
case ScopeIdDecls:
|
||||
case ScopeIdFnDef:
|
||||
case ScopeIdCompTime:
|
||||
case ScopeIdVarDecl:
|
||||
case ScopeIdCImport:
|
||||
case ScopeIdSuspend:
|
||||
case ScopeIdTypeOf:
|
||||
case ScopeIdBlock:
|
||||
return;
|
||||
case ScopeIdLoop:
|
||||
case ScopeIdRuntime:
|
||||
continue;
|
||||
case ScopeIdExpr: {
|
||||
ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope);
|
||||
if (child_expr_scope != nullptr) {
|
||||
for (size_t i = 0; parent_expr_scope->children_ptr[i] != child_expr_scope; i += 1) {
|
||||
assert(i < parent_expr_scope->children_len);
|
||||
parent_expr_scope->children_ptr[i]->need_spill = MemoizedBoolTrue;
|
||||
}
|
||||
}
|
||||
parent_expr_scope->need_spill = MemoizedBoolTrue;
|
||||
child_expr_scope = parent_expr_scope;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether instructions belonging to `scope` have to be spilled.
//
// The answer is taken from the nearest ScopeExpr (see find_expr_scope); with no
// such scope the answer is false. A ScopeExpr whose answer is still unknown
// inherits it from the scope chain above it, and the result is cached in
// need_spill so later queries return immediately.
static bool scope_needs_spill(Scope *scope) {
    ScopeExpr *expr_scope = find_expr_scope(scope);
    if (expr_scope == nullptr)
        return false;

    if (expr_scope->need_spill == MemoizedBoolUnknown) {
        // Not decided yet: ask the enclosing scopes and memoize their answer.
        bool inherited = scope_needs_spill(expr_scope->base.parent);
        expr_scope->need_spill = inherited ? MemoizedBoolTrue : MemoizedBoolFalse;
        return inherited;
    }

    switch (expr_scope->need_spill) {
        case MemoizedBoolTrue:
            return true;
        case MemoizedBoolFalse:
            return false;
        case MemoizedBoolUnknown:
            // Handled above.
            break;
    }
    zig_unreachable();
}
|
||||
|
||||
static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
|
||||
Error err;
|
||||
|
||||
@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
|
||||
callee_frame_type, "");
|
||||
}
|
||||
// Since this frame is async, an await might represent a suspend point, and
|
||||
// therefore need to spill.
|
||||
// therefore need to spill. It also needs to mark expr scopes as having to spill.
|
||||
// For example: foo() + await z
|
||||
// The function call result of foo() must be spilled.
|
||||
for (size_t i = 0; i < fn->await_list.length; i += 1) {
|
||||
IrInstructionAwaitGen *await = fn->await_list.at(i);
|
||||
// TODO If this is a noasync await, it doesn't need to spill
|
||||
// TODO If this is a noasync await, it doesn't suspend
|
||||
// https://github.com/ziglang/zig/issues/3157
|
||||
if (await->result_loc != nullptr) {
|
||||
// If there's a result location, that is the spill
|
||||
if (await->base.value.special != ConstValSpecialRuntime) {
|
||||
// Known at comptime. No spill, no suspend.
|
||||
continue;
|
||||
}
|
||||
if (!type_has_bits(await->base.value.type))
|
||||
continue;
|
||||
if (await->base.value.special != ConstValSpecialRuntime)
|
||||
continue;
|
||||
if (await->base.ref_count == 0)
|
||||
continue;
|
||||
if (await->target_fn != nullptr) {
|
||||
// we might not need to suspend
|
||||
analyze_fn_async(g, await->target_fn, false);
|
||||
@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
|
||||
return ErrorSemanticAnalyzeFail;
|
||||
}
|
||||
if (!fn_is_async(await->target_fn)) {
|
||||
// This await does not represent a suspend point. No spill needed.
|
||||
// This await does not represent a suspend point. No spill needed,
|
||||
// and no need to mark ExprScope.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// This await is a suspend point, but it might not need a spill.
|
||||
// We do need to mark the ExprScope as having a suspend point in it.
|
||||
mark_suspension_point(await->base.scope);
|
||||
|
||||
if (await->result_loc != nullptr) {
|
||||
// If there's a result location, that is the spill
|
||||
continue;
|
||||
}
|
||||
if (await->base.ref_count == 0)
|
||||
continue;
|
||||
if (!type_has_bits(await->base.value.type))
|
||||
continue;
|
||||
await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
|
||||
await->base.value.type, "");
|
||||
}
|
||||
// Now that we've marked all the expr scopes that have to spill, we go over the instructions
|
||||
// and spill the relevant ones.
|
||||
for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
|
||||
IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
|
||||
for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
|
||||
IrInstruction *instruction = block->instruction_list.at(instr_i);
|
||||
if (instruction->id == IrInstructionIdAwaitGen ||
|
||||
instruction->id == IrInstructionIdVarPtr ||
|
||||
instruction->id == IrInstructionIdDeclRef ||
|
||||
instruction->id == IrInstructionIdAllocaGen)
|
||||
{
|
||||
// This instruction does its own spilling specially, or otherwise doesn't need it.
|
||||
continue;
|
||||
}
|
||||
if (instruction->value.special != ConstValSpecialRuntime)
|
||||
continue;
|
||||
if (instruction->ref_count == 0)
|
||||
continue;
|
||||
if (!type_has_bits(instruction->value.type))
|
||||
continue;
|
||||
if (scope_needs_spill(instruction->scope)) {
|
||||
instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node,
|
||||
fn, instruction->value.type, "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
|
||||
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
|
||||
|
||||
|
@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
|
||||
Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
|
||||
Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
|
||||
Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
|
||||
Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
|
||||
|
||||
void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
|
||||
ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
|
||||
@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
|
||||
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
|
||||
ZigType *var_type, const char *name_hint);
|
||||
|
||||
|
||||
#endif
|
||||
|
@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
|
||||
case ScopeIdCompTime:
|
||||
case ScopeIdRuntime:
|
||||
case ScopeIdTypeOf:
|
||||
case ScopeIdExpr:
|
||||
return get_di_scope(g, scope->parent);
|
||||
}
|
||||
zig_unreachable();
|
||||
@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type,
|
||||
LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, "");
|
||||
|
||||
gen_store(g, ored_value, ptr, ptr_type);
|
||||
return;
|
||||
}
|
||||
|
||||
static void gen_var_debug_decl(CodeGen *g, ZigVar *var) {
|
||||
@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
|
||||
if (instruction->id == IrInstructionIdAwaitGen) {
|
||||
IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
|
||||
if (await->result_loc != nullptr) {
|
||||
instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
|
||||
return get_handle_value(g, ir_llvm_value(g, await->result_loc),
|
||||
await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
|
||||
return instruction->llvm_value;
|
||||
}
|
||||
}
|
||||
if (instruction->spill != nullptr) {
|
||||
ZigType *ptr_type = instruction->spill->value.type;
|
||||
src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node);
|
||||
return get_handle_value(g, ir_llvm_value(g, instruction->spill),
|
||||
ptr_type->data.pointer.child_type, instruction->spill->value.type);
|
||||
}
|
||||
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
|
||||
assert(instruction->value.type);
|
||||
render_const_val(g, &instruction->value, "");
|
||||
@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) {
|
||||
case ScopeIdCompTime:
|
||||
case ScopeIdRuntime:
|
||||
case ScopeIdTypeOf:
|
||||
case ScopeIdExpr:
|
||||
scope = scope->parent;
|
||||
continue;
|
||||
}
|
||||
@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
|
||||
set_debug_location(g, instruction);
|
||||
}
|
||||
instruction->llvm_value = ir_render_instruction(g, executable, instruction);
|
||||
if (instruction->spill != nullptr) {
|
||||
LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill);
|
||||
gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value);
|
||||
instruction->llvm_value = nullptr;
|
||||
}
|
||||
}
|
||||
current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder);
|
||||
}
|
||||
|
12
src/ir.cpp
12
src/ir.cpp
@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco
|
||||
case ScopeIdCompTime:
|
||||
case ScopeIdRuntime:
|
||||
case ScopeIdTypeOf:
|
||||
case ScopeIdExpr:
|
||||
scope = scope->parent;
|
||||
continue;
|
||||
case ScopeIdDeferExpr:
|
||||
@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o
|
||||
case ScopeIdCompTime:
|
||||
case ScopeIdRuntime:
|
||||
case ScopeIdTypeOf:
|
||||
case ScopeIdExpr:
|
||||
scope = scope->parent;
|
||||
continue;
|
||||
case ScopeIdDeferExpr:
|
||||
@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
|
||||
result_loc = no_result_loc();
|
||||
ir_build_reset_result(irb, scope, node, result_loc);
|
||||
}
|
||||
IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc);
|
||||
Scope *child_scope;
|
||||
if (irb->exec->is_inline ||
|
||||
(irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope))
|
||||
{
|
||||
child_scope = scope;
|
||||
} else {
|
||||
child_scope = create_expr_scope(irb->codegen, node, scope);
|
||||
}
|
||||
IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
|
||||
if (result == irb->codegen->invalid_instruction) {
|
||||
if (irb->exec->first_err_trace_msg == nullptr) {
|
||||
irb->exec->first_err_trace_msg = irb->codegen->trace_err;
|
||||
|
@ -104,11 +104,7 @@ fn testFuture(loop: *Loop) void {
|
||||
var b = async waitOnFuture(&future);
|
||||
resolveFuture(&future);
|
||||
|
||||
// TODO https://github.com/ziglang/zig/issues/3077
|
||||
//const result = (await a) + (await b);
|
||||
const a_result = await a;
|
||||
const b_result = await b;
|
||||
const result = a_result + b_result;
|
||||
const result = (await a) + (await b);
|
||||
|
||||
testing.expect(result == 12);
|
||||
}
|
||||
|
@ -921,12 +921,10 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
|
||||
var sum: u32 = 0;
|
||||
|
||||
f1_awaited = true;
|
||||
const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
sum += try result_f1;
|
||||
sum += try await f1;
|
||||
|
||||
f2_awaited = true;
|
||||
const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
sum += try result_f2;
|
||||
sum += try await f2;
|
||||
|
||||
return sum;
|
||||
}
|
||||
@ -943,8 +941,7 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
|
||||
|
||||
fn amain(result: *u32) void {
|
||||
var x = async fib(std.heap.direct_allocator, 10);
|
||||
const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
result.* = res catch unreachable;
|
||||
result.* = (await x) catch unreachable;
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -1002,8 +999,7 @@ test "@asyncCall using the result location inside the frame" {
|
||||
return 1234;
|
||||
}
|
||||
fn getAnswer(f: anyframe->i32, out: *i32) void {
|
||||
var res = await f; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
out.* = res;
|
||||
out.* = await f;
|
||||
}
|
||||
};
|
||||
var data: i32 = 1;
|
||||
@ -1124,3 +1120,19 @@ test "await used in expression and awaiting fn with no suspend but async calling
|
||||
};
|
||||
_ = async S.atest();
|
||||
}
|
||||
|
||||
// Regression test for an ordinary call result used in the same expression as
// a later `await`: `foo() + await f1`.
test "await used in expression after a fn call" {
|
||||
const S = struct {
|
||||
fn atest() void {
|
||||
var f1 = async add(3, 4);
|
||||
var sum: i32 = 0;
|
||||
// foo() is evaluated before the await within a single expression, so
// its result has to still be available once the await has completed.
sum = foo() + await f1;
|
||||
// foo() returns 1 and add(3, 4) returns 7.
expect(sum == 8);
|
||||
}
|
||||
async fn add(a: i32, b: i32) i32 {
|
||||
return a + b;
|
||||
}
|
||||
fn foo() i32 { return 1; }
|
||||
};
|
||||
// Start atest as an async call; the returned frame is discarded.
_ = async S.atest();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user