From 646d927c792dbdd6db4a5bbee3cf5847283fe861 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Thu, 20 Oct 2022 13:29:58 +0300 Subject: [PATCH] stage2: fix handling of aarch64 C ABI float array like structs Closes #11702 Closes #13125 --- src/arch/aarch64/abi.zig | 90 ++++++++++++++++++++++++++++++++-------- src/codegen/llvm.zig | 19 ++++----- test/c_abi/cfuncs.c | 27 ++++++++++++ test/c_abi/main.zig | 33 +++++++++++++++ 4 files changed, 142 insertions(+), 27 deletions(-) diff --git a/src/arch/aarch64/abi.zig b/src/arch/aarch64/abi.zig index f26a9a8a8a..7c92d4e91c 100644 --- a/src/arch/aarch64/abi.zig +++ b/src/arch/aarch64/abi.zig @@ -5,29 +5,21 @@ const Register = bits.Register; const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager; const Type = @import("../../type.zig").Type; -pub const Class = enum { memory, integer, none, float_array }; +pub const Class = enum(u8) { memory, integer, none, float_array, _ }; +/// For `float_array` the second element will be the amount of floats. pub fn classifyType(ty: Type, target: std.Target) [2]Class { + var maybe_float_bits: ?u16 = null; + const float_count = countFloats(ty, target, &maybe_float_bits); + if (float_count <= sret_float_count) return .{ .float_array, @intToEnum(Class, float_count) }; + return classifyTypeInner(ty, target); +} + +fn classifyTypeInner(ty: Type, target: std.Target) [2]Class { if (!ty.hasRuntimeBitsIgnoreComptime()) return .{ .none, .none }; switch (ty.zigTypeTag()) { .Struct => { if (ty.containerLayout() == .Packed) return .{ .integer, .none }; - - if (ty.structFieldCount() <= 4) { - const fields = ty.structFields(); - var float_size: ?u64 = null; - for (fields.values()) |field| { - if (field.ty.zigTypeTag() != .Float) break; - const field_size = field.ty.bitSize(target); - const prev_size = float_size orelse { - float_size = field_size; - continue; - }; - if (field_size != prev_size) break; - } else { - return .{ .float_array, .none }; - } - } const bit_size = ty.bitSize(target); if (bit_size > 128) return .{ .memory, .none }; if (bit_size > 64) return .{ .integer, .integer }; @@ -67,6 +59,70 @@ pub fn classifyType(ty: Type, target: std.Target) [2]Class { } } +const sret_float_count = 4; +fn countFloats(ty: Type, target: std.Target, maybe_float_bits: *?u16) u32 { + const invalid = std.math.maxInt(u32); + switch (ty.zigTypeTag()) { + .Union => { + const fields = ty.unionFields(); + var max_count: u32 = 0; + for (fields.values()) |field| { + const field_count = countFloats(field.ty, target, maybe_float_bits); + if (field_count == invalid) return invalid; + if (field_count > max_count) max_count = field_count; + if (max_count > sret_float_count) return invalid; + } + return max_count; + }, + .Struct => { + const fields_len = ty.structFieldCount(); + var count: u32 = 0; + var i: u32 = 0; + while (i < fields_len) : (i += 1) { + const field_ty = ty.structFieldType(i); + const field_count = countFloats(field_ty, target, maybe_float_bits); + if (field_count == invalid) return invalid; + count += field_count; + if (count > sret_float_count) return invalid; + } + return count; + }, + .Float => { + const float_bits = maybe_float_bits.* orelse { + maybe_float_bits.* = ty.floatBits(target); + return 1; + }; + if (ty.floatBits(target) == float_bits) return 1; + return invalid; + }, + .Void => return 0, + else => return invalid, + } +} + +pub fn getFloatArrayType(ty: Type) ?Type { + switch (ty.zigTypeTag()) { + .Union => { + const fields = ty.unionFields(); + for (fields.values()) |field| { + if (getFloatArrayType(field.ty)) |some| return some; + } + return null; + }, + .Struct => { + const fields_len = ty.structFieldCount(); + var i: u32 = 0; + while (i < fields_len) : (i += 1) { + const field_ty = ty.structFieldType(i); + if (getFloatArrayType(field_ty)) |some| return some; + } + return null; + }, + .Float => return ty, + else => return null, + } +} + const callee_preserved_regs_impl = if (builtin.os.tag.isDarwin()) struct { pub const callee_preserved_regs = [_]Register{ .x20, .x21, .x22, .x23, diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index b16fc76c01..9894f3efd6 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -3125,10 +3125,10 @@ pub const DeclGen = struct { .as_u16 => { try llvm_params.append(dg.context.intType(16)); }, - .float_array => { + .float_array => |count| { const param_ty = fn_info.param_types[it.zig_index - 1]; - const float_ty = try dg.lowerType(param_ty.structFieldType(0)); - const field_count = @intCast(c_uint, param_ty.structFieldCount()); + const float_ty = try dg.lowerType(aarch64_c_abi.getFloatArrayType(param_ty).?); + const field_count = @intCast(c_uint, count); const arr_ty = float_ty.arrayType(field_count); try llvm_params.append(arr_ty); }, @@ -4801,7 +4801,7 @@ pub const FuncGen = struct { const casted = self.builder.buildBitCast(llvm_arg, self.dg.context.intType(16), ""); try llvm_args.append(casted); }, - .float_array => { + .float_array => |count| { const arg = args[it.zig_index - 1]; const arg_ty = self.air.typeOf(arg); var llvm_arg = try self.resolveInst(arg); @@ -4812,9 +4812,8 @@ pub const FuncGen = struct { llvm_arg = store_inst; } - const float_ty = try self.dg.lowerType(arg_ty.structFieldType(0)); - const field_count = @intCast(u32, arg_ty.structFieldCount()); - const array_llvm_ty = float_ty.arrayType(field_count); + const float_ty = try self.dg.lowerType(aarch64_c_abi.getFloatArrayType(arg_ty).?); + const array_llvm_ty = float_ty.arrayType(count); const casted = self.builder.buildBitCast(llvm_arg, array_llvm_ty.pointerType(0), ""); const alignment = arg_ty.abiAlignment(target); @@ -10214,7 +10213,7 @@ const ParamTypeIterator = struct { llvm_types_buffer: [8]u16, byval_attr: bool, - const Lowering = enum { + const Lowering = union(enum) { no_bits, byval, byref, @@ -10223,7 +10222,7 @@ const ParamTypeIterator = struct { multiple_llvm_float, slice, as_u16, - float_array, + float_array: u8, }; pub fn next(it: *ParamTypeIterator) ?Lowering { @@ -10400,7 +10399,7 @@ const ParamTypeIterator = struct { return .byref; } if (classes[0] == .float_array) { - return .float_array; + return Lowering{ .float_array = @enumToInt(classes[1]) }; } if (classes[1] == .none) { it.llvm_types_len = 1; diff --git a/test/c_abi/cfuncs.c b/test/c_abi/cfuncs.c index c18accd4d6..2b560c8743 100644 --- a/test/c_abi/cfuncs.c +++ b/test/c_abi/cfuncs.c @@ -650,3 +650,30 @@ void c_struct_with_array(StructWithArray x) { StructWithArray c_ret_struct_with_array() { return (StructWithArray) { 4, {}, 155 }; } + +typedef struct { + struct Point { + double x; + double y; + } origin; + struct Size { + double width; + double height; + } size; +} FloatArrayStruct; + +void c_float_array_struct(FloatArrayStruct x) { + assert_or_panic(x.origin.x == 5); + assert_or_panic(x.origin.y == 6); + assert_or_panic(x.size.width == 7); + assert_or_panic(x.size.height == 8); +} + +FloatArrayStruct c_ret_float_array_struct() { + FloatArrayStruct x; + x.origin.x = 1; + x.origin.y = 2; + x.size.width = 3; + x.size.height = 4; + return x; +} diff --git a/test/c_abi/main.zig b/test/c_abi/main.zig index fc9db79076..9e856e0ab2 100644 --- a/test/c_abi/main.zig +++ b/test/c_abi/main.zig @@ -700,3 +700,36 @@ test "Struct with array as padding." { try std.testing.expect(x.a == 4); try std.testing.expect(x.b == 155); } + +const FloatArrayStruct = extern struct { + origin: extern struct { + x: f64, + y: f64, + }, + size: extern struct { + width: f64, + height: f64, + }, +}; + +extern fn c_float_array_struct(FloatArrayStruct) void; +extern fn c_ret_float_array_struct() FloatArrayStruct; + +test "Float array like struct" { + c_float_array_struct(.{ + .origin = .{ + .x = 5, + .y = 6, + }, + .size = .{ + .width = 7, + .height = 8, + }, + }); + + var x = c_ret_float_array_struct(); + try std.testing.expect(x.origin.x == 1); + try std.testing.expect(x.origin.y == 2); + try std.testing.expect(x.size.width == 3); + try std.testing.expect(x.size.height == 4); +}