From a31b70c4b8d0bed67463b2f54e74198baa93329f Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Sat, 10 Oct 2020 00:46:53 +0200 Subject: [PATCH] std: Add/Fix/Change parts of big.int * Add an optimized squaring routine under the `sqr` name. Algorithms for squaring bigger numbers efficiently will come in a PR later. * Fix a bug where a multiplication was done twice if the threshold for the use of the Karatsuba algorithm was crossed. Add a test to make sure this won't happen again. * Streamline `pow` method, take a `Const` parameter. * Minor tweaks to `pow`, avoid bit-reversing the exponent. --- lib/std/math/big/int.zig | 130 +++++++++++++++++++++++++++------- lib/std/math/big/int_test.zig | 44 +++++++++--- 2 files changed, 137 insertions(+), 37 deletions(-) diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index 54ad2f55d0..25cafda9ac 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -446,6 +446,26 @@ pub const Mutable = struct { rma.positive = (a.positive == b.positive); } + /// rma = a * a + /// + /// `rma` may not alias with `a`. + /// + /// Asserts the result fits in `rma`. An upper bound on the number of limbs needed by + /// rma is given by `2 * a.limbs.len + 1`. + /// + /// `opt_allocator` is currently unused: only the basecase squaring routine is implemented, + /// which needs no temporary storage. It is accepted for future subquadratic algorithms. + pub fn sqrNoAlias(rma: *Mutable, a: Const, opt_allocator: ?*Allocator) void { + assert(rma.limbs.ptr != a.limbs.ptr); // illegal aliasing + + mem.set(Limb, rma.limbs, 0); + + llsquare_basecase(rma.limbs, a.limbs); + + rma.normalize(2 * a.limbs.len + 1); + rma.positive = true; + } + /// q = a / b (rem r) /// /// a / b are floored (rounded towards 0). 
@@ -1827,7 +1847,28 @@ pub const Managed = struct { rma.setMetadata(m.positive, m.len); } - pub fn pow(rma: *Managed, a: Managed, b: u32) !void { + /// r = a * a + pub fn sqr(rma: *Managed, a: Const) !void { + const needed_limbs = 2 * a.limbs.len + 1; + + if (rma.limbs.ptr == a.limbs.ptr) { + var m = try Managed.initCapacity(rma.allocator, needed_limbs); + errdefer m.deinit(); + var m_mut = m.toMutable(); + m_mut.sqrNoAlias(a, rma.allocator); + m.setMetadata(m_mut.positive, m_mut.len); + + rma.deinit(); + rma.swap(&m); + } else { + try rma.ensureCapacity(needed_limbs); + var rma_mut = rma.toMutable(); + rma_mut.sqrNoAlias(a, rma.allocator); + rma.setMetadata(rma_mut.positive, rma_mut.len); + } + } + + pub fn pow(rma: *Managed, a: Const, b: u32) !void { const needed_limbs = calcPowLimbsBufferLen(a.bitCountAbs(), b); const limbs_buffer = try rma.allocator.alloc(Limb, needed_limbs); @@ -1837,7 +1878,7 @@ pub const Managed = struct { var m = try Managed.initCapacity(rma.allocator, needed_limbs); errdefer m.deinit(); var m_mut = m.toMutable(); - try m_mut.pow(a.toConst(), b, limbs_buffer); + try m_mut.pow(a, b, limbs_buffer); m.setMetadata(m_mut.positive, m_mut.len); rma.deinit(); @@ -1845,7 +1886,7 @@ pub const Managed = struct { } else { try rma.ensureCapacity(needed_limbs); var rma_mut = rma.toMutable(); - try rma_mut.pow(a.toConst(), b, limbs_buffer); + try rma_mut.pow(a, b, limbs_buffer); rma.setMetadata(rma_mut.positive, rma_mut.len); } } @@ -1869,11 +1910,14 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L assert(r.len >= x.len + y.len + 1); // 48 is a pretty abitrary size chosen based on performance of a factorial program. 
- if (x.len > 48) { - if (opt_allocator) |allocator| { - llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) { - error.OutOfMemory => {}, // handled below - }; + k_mul: { + if (x.len > 48) { + if (opt_allocator) |allocator| { + llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) { + error.OutOfMemory => break :k_mul, // handled below + }; + return; + } } } @@ -2203,6 +2247,42 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void { } } +/// r MUST NOT alias x. +fn llsquare_basecase(r: []Limb, x: []const Limb) void { + @setRuntimeSafety(debug_safety); + + const x_norm = x; + assert(r.len >= 2 * x_norm.len + 1); + + // Compute the square of an N-limb bigint with only (N^2 + N)/2 + // multiplications by exploiting the symmetry of the coefficients around the + // diagonal: + // + // a b c * + // a b c = + // ------------------- + // ca cb cc + + // ba bb bc + + // aa ab ac + // + // Note that: + // - Each mixed-product term appears twice for each column, + // - Squares are always in the 2k (0 <= k < N) column + + for (x_norm) |v, i| { + // Accumulate all the x[i]*x[j] (with j > i) products + llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v); + } + + // Each product appears twice, multiply by 2 + llshl(r, r[0 .. 2 * x_norm.len], 1); + + for (x_norm) |v, i| { + // Compute and add the squares + llmulDigit(r[2 * i ..], x[i .. i + 1], v); + } +} + /// Knuth 4.6.3 fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void { var tmp1: []Limb = undefined; @@ -2212,9 +2292,9 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void { // variable, use the output limbs and another temporary set to overcome this // limitation. // The initial assignment makes the result end in `r` so an extra memory - // copy is saved, each 1 flips the index twice so it's a no-op so count the - // 0. - const b_leading_zeros = @intCast(u5, @clz(u32, b)); + // copy is saved, each 1 flips the index twice so it's only the zeros that + // matter. 
+ const b_leading_zeros = @clz(u32, b); const exp_zeros = @popCount(u32, ~b) - b_leading_zeros; if (exp_zeros & 1 != 0) { tmp1 = tmp_limbs; @@ -2224,32 +2304,28 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void { tmp2 = tmp_limbs; } - const a_norm = a[0..llnormalize(a)]; - - mem.copy(Limb, tmp1, a_norm); - mem.set(Limb, tmp1[a_norm.len..], 0); + mem.copy(Limb, tmp1, a); + mem.set(Limb, tmp1[a.len..], 0); // Scan the exponent as a binary number, from left to right, dropping the // most significant bit set. - const exp_bits = @intCast(u5, 31 - b_leading_zeros); - var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros; + // Square the result if the current bit is zero, square and multiply by a if + // it is one. + var exp_bits = 32 - 1 - b_leading_zeros; + var exp = b << @intCast(u5, 1 + b_leading_zeros); - var i: u5 = 0; + var i: usize = 0; while (i < exp_bits) : (i += 1) { // Square - { - mem.set(Limb, tmp2, 0); - const op = tmp1[0..llnormalize(tmp1)]; - llmulacc(null, tmp2, op, op); - mem.swap([]Limb, &tmp1, &tmp2); - } + mem.set(Limb, tmp2, 0); + llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]); + mem.swap([]Limb, &tmp1, &tmp2); // Multiply by a - if (exp & 1 != 0) { + if (@shlWithOverflow(u32, exp, 1, &exp)) { mem.set(Limb, tmp2, 0); - llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm); + llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a); mem.swap([]Limb, &tmp1, &tmp2); } - exp >>= 1; } } diff --git a/lib/std/math/big/int_test.zig b/lib/std/math/big/int_test.zig index 5d07bee9b5..1f4bd65974 100644 --- a/lib/std/math/big/int_test.zig +++ b/lib/std/math/big/int_test.zig @@ -720,6 +720,27 @@ test "big.int mul 0*0" { testing.expect((try c.to(u32)) == 0); } +test "big.int mul large" { + var a = try Managed.initCapacity(testing.allocator, 50); + defer a.deinit(); + var b = try Managed.initCapacity(testing.allocator, 100); + defer b.deinit(); + var c = try Managed.initCapacity(testing.allocator, 100); + defer c.deinit(); + + // Generate a 
number that's large enough to cross the thresholds for the use + // of subquadratic algorithms + for (a.limbs) |*p| { + p.* = std.math.maxInt(Limb); + } + a.setMetadata(true, 50); + + try b.mul(a.toConst(), a.toConst()); + try c.sqr(a.toConst()); + + testing.expect(b.eq(c)); +} + test "big.int div single-single no rem" { var a = try Managed.initSet(testing.allocator, 50); defer a.deinit(); @@ -1483,11 +1504,14 @@ test "big.int const to managed" { test "big.int pow" { { - var a = try Managed.initSet(testing.allocator, 10); + var a = try Managed.initSet(testing.allocator, -3); defer a.deinit(); - try a.pow(a, 8); - testing.expectEqual(@as(u32, 100000000), try a.to(u32)); + try a.pow(a.toConst(), 3); + testing.expectEqual(@as(i32, -27), try a.to(i32)); + + try a.pow(a.toConst(), 4); + testing.expectEqual(@as(i32, 531441), try a.to(i32)); } { var a = try Managed.initSet(testing.allocator, 10); @@ -1497,9 +1521,9 @@ test "big.int pow" { defer y.deinit(); // y and a are not aliased - try y.pow(a, 123); + try y.pow(a.toConst(), 123); // y and a are aliased - try a.pow(a, 123); + try a.pow(a.toConst(), 123); testing.expect(a.eq(y)); @@ -1517,18 +1541,18 @@ test "big.int pow" { var a = try Managed.initSet(testing.allocator, 0); defer a.deinit(); - try a.pow(a, 100); + try a.pow(a.toConst(), 100); testing.expectEqual(@as(i32, 0), try a.to(i32)); try a.set(1); - try a.pow(a, 0); + try a.pow(a.toConst(), 0); testing.expectEqual(@as(i32, 1), try a.to(i32)); - try a.pow(a, 100); + try a.pow(a.toConst(), 100); testing.expectEqual(@as(i32, 1), try a.to(i32)); try a.set(-1); - try a.pow(a, 15); + try a.pow(a.toConst(), 15); testing.expectEqual(@as(i32, -1), try a.to(i32)); - try a.pow(a, 16); + try a.pow(a.toConst(), 16); testing.expectEqual(@as(i32, 1), try a.to(i32)); } }