implement spills when expressions used across suspend points

closes #3077
This commit is contained in:
Andrew Kelley 2019-09-07 00:12:15 -04:00
parent 9ca8d9e21a
commit d1a98ccff4
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
7 changed files with 218 additions and 29 deletions

View File

@ -2124,6 +2124,7 @@ enum ScopeId {
ScopeIdCompTime,
ScopeIdRuntime,
ScopeIdTypeOf,
ScopeIdExpr,
};
struct Scope {
@ -2271,6 +2272,24 @@ struct ScopeTypeOf {
Scope base;
};
// Tri-state boolean used to cache the answer to a yes/no question that is
// computed lazily (see ScopeExpr::need_spill).
enum MemoizedBool {
// The answer has not been computed yet.
MemoizedBoolUnknown,
// The answer has been computed and is "no".
MemoizedBoolFalse,
// The answer has been computed and is "yes".
MemoizedBoolTrue,
};
// This scope is created for each expression.
// It's used to identify when an instruction needs to be spilled,
// so that it can be accessed after a suspend point.
struct ScopeExpr {
// Common scope header; its id is ScopeIdExpr for this kind of scope.
Scope base;
// Expression scopes whose nearest enclosing expression scope is this one,
// in the order they were created. create_expr_scope appends to this array.
ScopeExpr **children_ptr;
// Number of entries in children_ptr.
size_t children_len;
// Whether instructions in this scope must be spilled. Set to true by
// mark_suspension_point; otherwise resolved lazily from the enclosing
// expression scope and cached by scope_needs_spill.
MemoizedBool need_spill;
};
// synchronized with code in define_builtin_compile_vars
enum AtomicOrder {
AtomicOrderUnordered,
@ -2510,6 +2529,10 @@ struct IrInstruction {
// with this child field.
IrInstruction *child;
IrBasicBlock *owner_bb;
// Nearly any instruction may need to be stored as a local variable before suspending
// and then loaded after resuming, in case there is an expression with a suspend point
// in it, such as: x + await y
IrInstruction *spill;
IrInstructionId id;
// true if this instruction was generated by zig and not from user code
bool is_gen;

View File

@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) {
zig_unreachable();
}
// Returns the expression scope that `scope` belongs to, or nullptr if there is none.
// Loop and runtime scopes are looked through; every other non-expression scope
// kind ends the search with no result.
static ScopeExpr *find_expr_scope(Scope *scope) {
    Scope *cursor = scope;
    while (true) {
        // No default case on purpose: every ScopeId is listed explicitly.
        switch (cursor->id) {
            case ScopeIdRuntime:
            case ScopeIdLoop:
                // Transparent for this search; keep climbing.
                cursor = cursor->parent;
                break;
            case ScopeIdExpr:
                return reinterpret_cast<ScopeExpr *>(cursor);
            case ScopeIdBlock:
            case ScopeIdTypeOf:
            case ScopeIdSuspend:
            case ScopeIdCImport:
            case ScopeIdVarDecl:
            case ScopeIdCompTime:
            case ScopeIdFnDef:
            case ScopeIdDecls:
            case ScopeIdDeferExpr:
            case ScopeIdDefer:
                // These scope kinds stop the search.
                return nullptr;
        }
    }
}
// Returns the declaration scope stored for the container type `type_entry`,
// read through the slot located by get_container_scope_ptr.
ScopeDecls *get_container_scope(ZigType *type_entry) {
    ScopeDecls **scope_slot = get_container_scope_ptr(type_entry);
    return *scope_slot;
}
@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
return &scope->base;
}
// Creates a new expression scope for `node` under `parent` and returns its base.
// If `parent` belongs to an expression scope (as determined by find_expr_scope),
// the new scope is appended to that scope's list of children.
Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
    ScopeExpr *expr_scope = allocate<ScopeExpr>(1);
    init_scope(g, &expr_scope->base, ScopeIdExpr, node, parent);

    ScopeExpr *enclosing = find_expr_scope(parent);
    if (enclosing == nullptr) {
        // No enclosing expression; nothing to register with.
        return &expr_scope->base;
    }

    // Grow the enclosing scope's child array by one slot and store the new scope last.
    size_t old_len = enclosing->children_len;
    size_t grown_len = old_len + 1;
    enclosing->children_ptr = reallocate_nonzero<ScopeExpr *>(enclosing->children_ptr, old_len, grown_len);
    enclosing->children_ptr[old_len] = expr_scope;
    enclosing->children_len = grown_len;

    return &expr_scope->base;
}
ZigType *get_scope_import(Scope *scope) {
while (scope) {
if (scope->id == ScopeIdDecls) {
@ -5654,6 +5692,69 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) {
return fn_type;
}
// Records that a suspend point occurs at `scope`.
//
// Starting from `scope`, climbs the parent chain through expression scopes
// (looking through loop and runtime scopes) and stops at the first scope of
// any other kind. At every enclosing expression scope:
//  * if we arrived from a child expression scope, the siblings created
//    *before* that child are marked as needing a spill; the child itself and
//    its later siblings are left untouched;
//  * the enclosing expression scope itself is marked as needing a spill.
//
// This distinguishes
//   x + (await y)   // x is evaluated before the suspend point
// from
//   (await y) + x   // x is evaluated after resuming
static void mark_suspension_point(Scope *scope) {
    ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr;
    for (;;) {
        scope = scope->parent;
        switch (scope->id) {
            case ScopeIdDefer:
            case ScopeIdDeferExpr:
            case ScopeIdDecls:
            case ScopeIdFnDef:
            case ScopeIdCompTime:
            case ScopeIdVarDecl:
            case ScopeIdCImport:
            case ScopeIdSuspend:
            case ScopeIdTypeOf:
            case ScopeIdBlock:
                // Left the expression; nothing further up is affected.
                return;
            case ScopeIdLoop:
            case ScopeIdRuntime:
                // Transparent: keep climbing toward the enclosing expression.
                continue;
            case ScopeIdExpr: {
                ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope);
                if (child_expr_scope != nullptr) {
                    // Mark every sibling that precedes the child we came from.
                    // The bound is checked before each element is read, so a
                    // missing child can never cause a read past the array.
                    size_t i = 0;
                    for (; i < parent_expr_scope->children_len &&
                            parent_expr_scope->children_ptr[i] != child_expr_scope; i += 1)
                    {
                        parent_expr_scope->children_ptr[i]->need_spill = MemoizedBoolTrue;
                    }
                    // The child must have been registered with this parent
                    // when it was created.
                    assert(i < parent_expr_scope->children_len);
                }
                parent_expr_scope->need_spill = MemoizedBoolTrue;
                child_expr_scope = parent_expr_scope;
                continue;
            }
        }
    }
}
// Reports whether instructions belonging to `scope` must be spilled.
// Returns false when `scope` is not inside any expression scope. Otherwise the
// answer comes from the expression scope's need_spill flag; if that flag is
// still unknown it is taken from the scope's parent and cached on the way back.
static bool scope_needs_spill(Scope *scope) {
    ScopeExpr *expr = find_expr_scope(scope);
    if (expr == nullptr)
        return false;

    switch (expr->need_spill) {
        case MemoizedBoolTrue:
            return true;
        case MemoizedBoolFalse:
            return false;
        case MemoizedBoolUnknown: {
            // Not decided yet: take the enclosing answer and remember it.
            bool inherited = scope_needs_spill(expr->base.parent);
            expr->need_spill = inherited ? MemoizedBoolTrue : MemoizedBoolFalse;
            return inherited;
        }
    }
    zig_unreachable();
}
static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
Error err;
@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
callee_frame_type, "");
}
// Since this frame is async, an await might represent a suspend point, and
// therefore need to spill.
// therefore need to spill. It also needs to mark expr scopes as having to spill.
// For example: foo() + await z
// The function call result of foo() must be spilled.
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
// TODO If this is a noasync await, it doesn't need to spill
// TODO If this is a noasync await, it doesn't suspend
// https://github.com/ziglang/zig/issues/3157
if (await->result_loc != nullptr) {
// If there's a result location, that is the spill
if (await->base.value.special != ConstValSpecialRuntime) {
// Known at comptime. No spill, no suspend.
continue;
}
if (!type_has_bits(await->base.value.type))
continue;
if (await->base.value.special != ConstValSpecialRuntime)
continue;
if (await->base.ref_count == 0)
continue;
if (await->target_fn != nullptr) {
// we might not need to suspend
analyze_fn_async(g, await->target_fn, false);
@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
return ErrorSemanticAnalyzeFail;
}
if (!fn_is_async(await->target_fn)) {
// This await does not represent a suspend point. No spill needed.
// This await does not represent a suspend point. No spill needed,
// and no need to mark ExprScope.
continue;
}
}
// This await is a suspend point, but it might not need a spill.
// We do need to mark the ExprScope as having a suspend point in it.
mark_suspension_point(await->base.scope);
if (await->result_loc != nullptr) {
// If there's a result location, that is the spill
continue;
}
if (await->base.ref_count == 0)
continue;
if (!type_has_bits(await->base.value.type))
continue;
await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
await->base.value.type, "");
}
// Now that we've marked all the expr scopes that have to spill, we go over the instructions
// and spill the relevant ones.
for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
IrInstruction *instruction = block->instruction_list.at(instr_i);
if (instruction->id == IrInstructionIdAwaitGen ||
instruction->id == IrInstructionIdVarPtr ||
instruction->id == IrInstructionIdDeclRef ||
instruction->id == IrInstructionIdAllocaGen)
{
// This instruction does its own spilling specially, or otherwise doesn't need it.
continue;
}
if (instruction->value.special != ConstValSpecialRuntime)
continue;
if (instruction->ref_count == 0)
continue;
if (!type_has_bits(instruction->value.type))
continue;
if (scope_needs_spill(instruction->scope)) {
instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node,
fn, instruction->value.type, "");
}
}
}
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);

View File

@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
ZigType *var_type, const char *name_hint);
#endif

View File

@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
case ScopeIdExpr:
return get_di_scope(g, scope->parent);
}
zig_unreachable();
@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type,
LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, "");
gen_store(g, ored_value, ptr, ptr_type);
return;
}
static void gen_var_debug_decl(CodeGen *g, ZigVar *var) {
@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
if (instruction->id == IrInstructionIdAwaitGen) {
IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
if (await->result_loc != nullptr) {
instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
return get_handle_value(g, ir_llvm_value(g, await->result_loc),
await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
return instruction->llvm_value;
}
}
if (instruction->spill != nullptr) {
ZigType *ptr_type = instruction->spill->value.type;
src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node);
return get_handle_value(g, ir_llvm_value(g, instruction->spill),
ptr_type->data.pointer.child_type, instruction->spill->value.type);
}
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
assert(instruction->value.type);
render_const_val(g, &instruction->value, "");
@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
case ScopeIdExpr:
scope = scope->parent;
continue;
}
@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
set_debug_location(g, instruction);
}
instruction->llvm_value = ir_render_instruction(g, executable, instruction);
if (instruction->spill != nullptr) {
LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill);
gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value);
instruction->llvm_value = nullptr;
}
}
current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder);
}

View File

@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
result_loc = no_result_loc();
ir_build_reset_result(irb, scope, node, result_loc);
}
IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc);
Scope *child_scope;
if (irb->exec->is_inline ||
(irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope))
{
child_scope = scope;
} else {
child_scope = create_expr_scope(irb->codegen, node, scope);
}
IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
if (result == irb->codegen->invalid_instruction) {
if (irb->exec->first_err_trace_msg == nullptr) {
irb->exec->first_err_trace_msg = irb->codegen->trace_err;

View File

@ -104,11 +104,7 @@ fn testFuture(loop: *Loop) void {
var b = async waitOnFuture(&future);
resolveFuture(&future);
// TODO https://github.com/ziglang/zig/issues/3077
//const result = (await a) + (await b);
const a_result = await a;
const b_result = await b;
const result = a_result + b_result;
const result = (await a) + (await b);
testing.expect(result == 12);
}

View File

@ -921,12 +921,10 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
var sum: u32 = 0;
f1_awaited = true;
const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
sum += try result_f1;
sum += try await f1;
f2_awaited = true;
const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
sum += try result_f2;
sum += try await f2;
return sum;
}
@ -943,8 +941,7 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
fn amain(result: *u32) void {
var x = async fib(std.heap.direct_allocator, 10);
const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
result.* = res catch unreachable;
result.* = (await x) catch unreachable;
}
};
}
@ -1002,8 +999,7 @@ test "@asyncCall using the result location inside the frame" {
return 1234;
}
fn getAnswer(f: anyframe->i32, out: *i32) void {
var res = await f; // TODO https://github.com/ziglang/zig/issues/3077
out.* = res;
out.* = await f;
}
};
var data: i32 = 1;
@ -1124,3 +1120,19 @@ test "await used in expression and awaiting fn with no suspend but async calling
};
_ = async S.atest();
}
// Regression test: a value computed before an `await` in the same expression
// must still be available after the await. Here the result of foo() is the
// left operand of `+` and the await is the right operand.
test "await used in expression after a fn call" {
const S = struct {
fn atest() void {
// Start add(3, 4) as an async call; its result is collected below.
var f1 = async add(3, 4);
var sum: i32 = 0;
// The expression under test: foo() appears before the await.
// NOTE(review): presumably this await is treated as a suspend point because
// add is declared `async fn` - confirm.
sum = foo() + await f1;
// foo() returns 1 and add(3, 4) returns 7.
expect(sum == 8);
}
async fn add(a: i32, b: i32) i32 {
return a + b;
}
fn foo() i32 { return 1; }
};
// The frame returned by the async call is discarded.
_ = async S.atest();
}