From 103b885fc6660cad4bc596b6f43fad3905f4c1aa Mon Sep 17 00:00:00 2001 From: expikr <77922942+expikr@users.noreply.github.com> Date: Thu, 30 May 2024 03:58:05 -0600 Subject: [PATCH] math.hypot: fix incorrect over/underflow behavior (#19472) --- lib/std/math.zig | 1 + lib/std/math/float.zig | 13 +++ lib/std/math/hypot.zig | 245 +++++++++++++++++------------------------ 3 files changed, 114 insertions(+), 145 deletions(-) diff --git a/lib/std/math.zig b/lib/std/math.zig index 7d9b400fc4..c7bd4fb9f4 100644 --- a/lib/std/math.zig +++ b/lib/std/math.zig @@ -52,6 +52,7 @@ pub const floatTrueMin = @import("math/float.zig").floatTrueMin; pub const floatMin = @import("math/float.zig").floatMin; pub const floatMax = @import("math/float.zig").floatMax; pub const floatEps = @import("math/float.zig").floatEps; +pub const floatEpsAt = @import("math/float.zig").floatEpsAt; pub const inf = @import("math/float.zig").inf; pub const nan = @import("math/float.zig").nan; pub const snan = @import("math/float.zig").snan; diff --git a/lib/std/math/float.zig b/lib/std/math/float.zig index f8f04e217c..1d19fdc57c 100644 --- a/lib/std/math/float.zig +++ b/lib/std/math/float.zig @@ -94,6 +94,19 @@ pub inline fn floatEps(comptime T: type) T { return reconstructFloat(T, -floatFractionalBits(T), mantissaOne(T)); } +/// Returns the local epsilon of floating point type T. +pub inline fn floatEpsAt(comptime T: type, x: T) T { + switch (@typeInfo(T)) { + .Float => |F| { + const U: type = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = F.bits } }); + const u: U = @bitCast(x); + const y: T = @bitCast(u ^ 1); + return @abs(x - y); + }, + else => @compileError("floatEpsAt only supports floats"), + } +} + /// Returns the value inf for floating point type T. pub inline fn inf(comptime T: type) T { return reconstructFloat(T, floatExponentMax(T) + 1, mantissaOne(T)); diff --git a/lib/std/math/hypot.zig b/lib/std/math/hypot.zig index c0bfabca1c..cc0dc17ab1 100644 --- a/lib/std/math/hypot.zig +++ b/lib/std/math/hypot.zig @@ -1,13 +1,14 @@ -// Ported from musl, which is licensed under the MIT license: -// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT -// -// https://git.musl-libc.org/cgit/musl/tree/src/math/hypotf.c -// https://git.musl-libc.org/cgit/musl/tree/src/math/hypot.c - const std = @import("../std.zig"); const math = std.math; const expect = std.testing.expect; -const maxInt = std.math.maxInt; +const isNan = math.isNan; +const isInf = math.isInf; +const inf = math.inf; +const nan = math.nan; +const floatEpsAt = math.floatEpsAt; +const floatEps = math.floatEps; +const floatMin = math.floatMin; +const floatMax = math.floatMax; /// Returns sqrt(x * x + y * y), avoiding unnecessary overflow and underflow. /// @@ -15,162 +16,116 @@ const maxInt = std.math.maxInt; /// /// | x | y | hypot | /// |-------|-------|-------| -/// | +inf | num | +inf | -/// | num | +-inf | +inf | -/// | nan | any | nan | -/// | any | nan | nan | +/// | +-inf | any | +inf | +/// | any | +-inf | +inf | +/// | nan | fin | nan | +/// | fin | nan | nan | pub fn hypot(x: anytype, y: anytype) @TypeOf(x, y) { const T = @TypeOf(x, y); - return switch (T) { - f32 => hypot32(x, y), - f64 => hypot64(x, y), + switch (@typeInfo(T)) { + .Float => {}, + .ComptimeFloat => return @sqrt(x * x + y * y), else => @compileError("hypot not implemented for " ++ @typeName(T)), - }; + } + const lower = @sqrt(floatMin(T)); + const upper = @sqrt(floatMax(T) / 2); + const incre = @sqrt(floatEps(T) / 2); + const scale = floatEpsAt(T, incre); + const hypfn = if (emulateFma(T)) hypotUnfused else hypotFused; + var major: T = x; + var minor: T = y; + if (isInf(major) or isInf(minor)) return inf(T); + if (isNan(major) or isNan(minor)) return nan(T); + if (T == f16) return @floatCast(@sqrt(@mulAdd(f32, x, x, @as(f32, y) * y))); + if (T == f32) return @floatCast(@sqrt(@mulAdd(f64, x, x, @as(f64, y) * y))); + major = @abs(major); + minor = @abs(minor); + if (minor > major) { + const tempo = major; + major = minor; + minor = tempo; + } + if (major * incre >= minor) return major; + if (major > upper) return hypfn(T, major * scale, minor * scale) / scale; + if (minor < lower) return hypfn(T, major / scale, minor / scale) * scale; + return hypfn(T, major, minor); } -fn hypot32(x: f32, y: f32) f32 { - var ux = @as(u32, @bitCast(x)); - var uy = @as(u32, @bitCast(y)); - - ux &= maxInt(u32) >> 1; - uy &= maxInt(u32) >> 1; - if (ux < uy) { - const tmp = ux; - ux = uy; - uy = tmp; - } - - var xx = @as(f32, @bitCast(ux)); - var yy = @as(f32, @bitCast(uy)); - if (uy == 0xFF << 23) { - return yy; - } - if (ux >= 0xFF << 23 or uy == 0 or ux - uy >= (25 << 23)) { - return xx + yy; - } - - var z: f32 = 1.0; - if (ux >= (0x7F + 60) << 23) { - z = 0x1.0p90; - xx *= 0x1.0p-90; - yy *= 0x1.0p-90; - } else if (uy < (0x7F - 60) << 23) { - z = 0x1.0p-90; - xx *= 0x1.0p-90; - yy *= 0x1.0p-90; - } - - return z * @sqrt(@as(f32, @floatCast(@as(f64, x) * x + @as(f64, y) * y))); +inline fn emulateFma(comptime T: type) bool { + // If @mulAdd lowers to the software implementation, + // hypotUnfused should be used in place of hypotFused. + // This takes an educated guess, but ideally we should + // properly detect at comptime when that fallback will + // occur. + return (T == f128 or T == f80); } -fn sq(hi: *f64, lo: *f64, x: f64) void { - const split: f64 = 0x1.0p27 + 1.0; - const xc = x * split; - const xh = x - xc + xc; - const xl = x - xh; - hi.* = x * x; - lo.* = xh * xh - hi.* + 2 * xh * xl + xl * xl; +inline fn hypotFused(comptime F: type, x: F, y: F) F { + const r = @sqrt(@mulAdd(F, x, x, y * y)); + const rr = r * r; + const xx = x * x; + const z = @mulAdd(F, -y, y, rr - xx) + @mulAdd(F, r, r, -rr) - @mulAdd(F, x, x, -xx); + return r - z / (2 * r); } -fn hypot64(x: f64, y: f64) f64 { - var ux = @as(u64, @bitCast(x)); - var uy = @as(u64, @bitCast(y)); - - ux &= maxInt(u64) >> 1; - uy &= maxInt(u64) >> 1; - if (ux < uy) { - const tmp = ux; - ux = uy; - uy = tmp; +inline fn hypotUnfused(comptime F: type, x: F, y: F) F { + const r = @sqrt(x * x + y * y); + if (r <= 2 * y) { // 30deg or steeper + const dx = r - y; + const z = x * (2 * dx - x) + (dx - 2 * (x - y)) * dx; + return r - z / (2 * r); + } else { // shallower than 30 deg + const dy = r - x; + const z = 2 * dy * (x - 2 * y) + (4 * dy - y) * y + dy * dy; + return r - z / (2 * r); } - - const ex = ux >> 52; - const ey = uy >> 52; - var xx = @as(f64, @bitCast(ux)); - var yy = @as(f64, @bitCast(uy)); - - // hypot(inf, nan) == inf - if (ey == 0x7FF) { - return yy; - } - if (ex == 0x7FF or uy == 0) { - return xx; - } - - // hypot(x, y) ~= x + y * y / x / 2 with inexact for small y/x - if (ex - ey > 64) { - return xx + yy; - } - - var z: f64 = 1; - if (ex > 0x3FF + 510) { - z = 0x1.0p700; - xx *= 0x1.0p-700; - yy *= 0x1.0p-700; - } else if (ey < 0x3FF - 450) { - z = 0x1.0p-700; - xx *= 0x1.0p700; - yy *= 0x1.0p700; - } - - var hx: f64 = undefined; - var lx: f64 = undefined; - var hy: f64 = undefined; - var ly: f64 = undefined; - - sq(&hx, &lx, x); - sq(&hy, &ly, y); - - return z * @sqrt(ly + lx + hy + hx); } +const hypot_test_cases = .{ + .{ 0.0, -1.2, 1.2 }, + .{ 0.2, -0.34, 0.3944616584663203993612799816649560759946493601889826495362 }, + .{ 0.8923, 2.636890, 2.7837722899152509525110650481670176852603253522923737962880 }, + .{ 1.5, 5.25, 5.4600824169603887033229768686452745953332522619323580787836 }, + .{ 37.45, 159.835, 164.16372840856167640478217141034363907565754072954443805164 }, + .{ 89.123, 382.028905, 392.28687638576315875933966414927490685367196874260165618371 }, + .{ 123123.234375, 529428.707813, 543556.88524707706887251269205923830745438413088753096759371 }, +}; + test hypot { - const x32: f32 = 0.0; - const y32: f32 = -1.2; - const x64: f64 = 0.0; - const y64: f64 = -1.2; - try expect(hypot(x32, y32) == hypot32(0.0, -1.2)); - try expect(hypot(x64, y64) == hypot64(0.0, -1.2)); + try expect(hypot(0.3, 0.4) == 0.5); } -test hypot32 { - const epsilon = 0.000001; - - try expect(math.approxEqAbs(f32, hypot32(0.0, -1.2), 1.2, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(0.2, -0.34), 0.394462, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(0.8923, 2.636890), 2.783772, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(1.5, 5.25), 5.460083, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(37.45, 159.835), 164.163742, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(89.123, 382.028905), 392.286865, epsilon)); - try expect(math.approxEqAbs(f32, hypot32(123123.234375, 529428.707813), 543556.875, epsilon)); +test "hypot.correct" { + inline for (.{ f16, f32, f64, f128 }) |T| { + inline for (hypot_test_cases) |v| { + const a: T, const b: T, const c: T = v; + try expect(math.approxEqRel(T, hypot(a, b), c, @sqrt(floatEps(T)))); + } + } } -test hypot64 { - const epsilon = 0.000001; - - try expect(math.approxEqAbs(f64, hypot64(0.0, -1.2), 1.2, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(0.2, -0.34), 0.394462, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(0.8923, 2.636890), 2.783772, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(1.5, 5.25), 5.460082, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(37.45, 159.835), 164.163728, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(89.123, 382.028905), 392.286876, epsilon)); - try expect(math.approxEqAbs(f64, hypot64(123123.234375, 529428.707813), 543556.885247, epsilon)); +test "hypot.precise" { + inline for (.{ f16, f32, f64 }) |T| { // f128 seems to be 5 ulp + inline for (hypot_test_cases) |v| { + const a: T, const b: T, const c: T = v; + try expect(math.approxEqRel(T, hypot(a, b), c, floatEps(T))); + } + } } -test "hypot32.special" { - try expect(math.isPositiveInf(hypot32(math.inf(f32), 0.0))); - try expect(math.isPositiveInf(hypot32(-math.inf(f32), 0.0))); - try expect(math.isPositiveInf(hypot32(0.0, math.inf(f32)))); - try expect(math.isPositiveInf(hypot32(0.0, -math.inf(f32)))); - try expect(math.isNan(hypot32(math.nan(f32), 0.0))); - try expect(math.isNan(hypot32(0.0, math.nan(f32)))); -} +test "hypot.special" { + inline for (.{ f16, f32, f64, f128 }) |T| { + try expect(math.isNan(hypot(nan(T), 0.0))); + try expect(math.isNan(hypot(0.0, nan(T)))); -test "hypot64.special" { - try expect(math.isPositiveInf(hypot64(math.inf(f64), 0.0))); - try expect(math.isPositiveInf(hypot64(-math.inf(f64), 0.0))); - try expect(math.isPositiveInf(hypot64(0.0, math.inf(f64)))); - try expect(math.isPositiveInf(hypot64(0.0, -math.inf(f64)))); - try expect(math.isNan(hypot64(math.nan(f64), 0.0))); - try expect(math.isNan(hypot64(0.0, math.nan(f64)))); + try expect(math.isPositiveInf(hypot(inf(T), 0.0))); + try expect(math.isPositiveInf(hypot(0.0, inf(T)))); + try expect(math.isPositiveInf(hypot(inf(T), nan(T)))); + try expect(math.isPositiveInf(hypot(nan(T), inf(T)))); + + try expect(math.isPositiveInf(hypot(-inf(T), 0.0))); + try expect(math.isPositiveInf(hypot(0.0, -inf(T)))); + try expect(math.isPositiveInf(hypot(-inf(T), nan(T)))); + try expect(math.isPositiveInf(hypot(nan(T), -inf(T)))); + } }