std: Add/Fix/Change parts of big.int

* Add an optimized squaring routine under the `sqr` name.
  Algorithms for squaring bigger numbers efficiently will come in a
  PR later.
* Fix a bug where a multiplication was done twice if the threshold for
  the use of Karatsuba algorithm was crossed. Add a test to make sure
  this won't happen again.
* Streamline `pow` method, take a `Const` parameter.
* Minor tweaks to `pow`, avoid bit-reversing the exponent.
This commit is contained in:
LemonBoy 2020-10-10 00:46:53 +02:00 committed by Andrew Kelley
parent fbc6a00b0a
commit a31b70c4b8
2 changed files with 137 additions and 37 deletions

View File

@ -446,6 +446,26 @@ pub const Mutable = struct {
rma.positive = (a.positive == b.positive);
}
/// rma = a * a
///
/// `rma` may not alias with `a`.
///
/// Asserts the result fits in `rma`. An upper bound on the number of limbs needed by
/// rma is given by `2 * a.limbs.len + 1`.
///
/// If `allocator` is provided, it will be used for temporary storage to improve
/// multiplication performance. `error.OutOfMemory` is handled with a fallback algorithm.
pub fn sqrNoAlias(rma: *Mutable, a: Const, opt_allocator: ?*Allocator) void {
assert(rma.limbs.ptr != a.limbs.ptr); // illegal aliasing
mem.set(Limb, rma.limbs, 0);
llsquare_basecase(rma.limbs, a.limbs);
rma.normalize(2 * a.limbs.len + 1);
rma.positive = true;
}
/// q = a / b (rem r)
///
/// a / b are floored (rounded towards 0).
@ -1827,7 +1847,28 @@ pub const Managed = struct {
rma.setMetadata(m.positive, m.len);
}
pub fn pow(rma: *Managed, a: Managed, b: u32) !void {
/// r = a * a
pub fn sqr(rma: *Managed, a: Const) !void {
const needed_limbs = 2 * a.limbs.len + 1;
if (rma.limbs.ptr == a.limbs.ptr) {
var m = try Managed.initCapacity(rma.allocator, needed_limbs);
errdefer m.deinit();
var m_mut = m.toMutable();
m_mut.sqrNoAlias(a, rma.allocator);
m.setMetadata(m_mut.positive, m_mut.len);
rma.deinit();
rma.swap(&m);
} else {
try rma.ensureCapacity(needed_limbs);
var rma_mut = rma.toMutable();
rma_mut.sqrNoAlias(a, rma.allocator);
rma.setMetadata(rma_mut.positive, rma_mut.len);
}
}
pub fn pow(rma: *Managed, a: Const, b: u32) !void {
const needed_limbs = calcPowLimbsBufferLen(a.bitCountAbs(), b);
const limbs_buffer = try rma.allocator.alloc(Limb, needed_limbs);
@ -1837,7 +1878,7 @@ pub const Managed = struct {
var m = try Managed.initCapacity(rma.allocator, needed_limbs);
errdefer m.deinit();
var m_mut = m.toMutable();
try m_mut.pow(a.toConst(), b, limbs_buffer);
try m_mut.pow(a, b, limbs_buffer);
m.setMetadata(m_mut.positive, m_mut.len);
rma.deinit();
@ -1845,7 +1886,7 @@ pub const Managed = struct {
} else {
try rma.ensureCapacity(needed_limbs);
var rma_mut = rma.toMutable();
try rma_mut.pow(a.toConst(), b, limbs_buffer);
try rma_mut.pow(a, b, limbs_buffer);
rma.setMetadata(rma_mut.positive, rma_mut.len);
}
}
@ -1869,11 +1910,14 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L
assert(r.len >= x.len + y.len + 1);
// 48 is a pretty abitrary size chosen based on performance of a factorial program.
if (x.len > 48) {
if (opt_allocator) |allocator| {
llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
error.OutOfMemory => {}, // handled below
};
k_mul: {
if (x.len > 48) {
if (opt_allocator) |allocator| {
llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
error.OutOfMemory => break :k_mul, // handled below
};
return;
}
}
}
@ -2203,6 +2247,42 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void {
}
}
/// r MUST NOT alias x.
fn llsquare_basecase(r: []Limb, x: []const Limb) void {
@setRuntimeSafety(debug_safety);
const x_norm = x;
assert(r.len >= 2 * x_norm.len + 1);
// Compute the square of a N-limb bigint with only (N^2 + N)/2
// multiplications by exploting the symmetry of the coefficients around the
// diagonal:
//
// a b c *
// a b c =
// -------------------
// ca cb cc +
// ba bb bc +
// aa ab ac
//
// Note that:
// - Each mixed-product term appears twice for each column,
// - Squares are always in the 2k (0 <= k < N) column
for (x_norm) |v, i| {
// Accumulate all the x[i]*x[j] (with x!=j) products
llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v);
}
// Each product appears twice, multiply by 2
llshl(r, r[0 .. 2 * x_norm.len], 1);
for (x_norm) |v, i| {
// Compute and add the squares
llmulDigit(r[2 * i ..], x[i .. i + 1], v);
}
}
/// Knuth 4.6.3
fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
var tmp1: []Limb = undefined;
@ -2212,9 +2292,9 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
// variable, use the output limbs and another temporary set to overcome this
// limitation.
// The initial assignment makes the result end in `r` so an extra memory
// copy is saved, each 1 flips the index twice so it's a no-op so count the
// 0.
const b_leading_zeros = @intCast(u5, @clz(u32, b));
// copy is saved, each 1 flips the index twice so it's only the zeros that
// matter.
const b_leading_zeros = @clz(u32, b);
const exp_zeros = @popCount(u32, ~b) - b_leading_zeros;
if (exp_zeros & 1 != 0) {
tmp1 = tmp_limbs;
@ -2224,32 +2304,28 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
tmp2 = tmp_limbs;
}
const a_norm = a[0..llnormalize(a)];
mem.copy(Limb, tmp1, a_norm);
mem.set(Limb, tmp1[a_norm.len..], 0);
mem.copy(Limb, tmp1, a);
mem.set(Limb, tmp1[a.len..], 0);
// Scan the exponent as a binary number, from left to right, dropping the
// most significant bit set.
const exp_bits = @intCast(u5, 31 - b_leading_zeros);
var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros;
// Square the result if the current bit is zero, square and multiply by a if
// it is one.
var exp_bits = 32 - 1 - b_leading_zeros;
var exp = b << @intCast(u5, 1 + b_leading_zeros);
var i: u5 = 0;
var i: usize = 0;
while (i < exp_bits) : (i += 1) {
// Square
{
mem.set(Limb, tmp2, 0);
const op = tmp1[0..llnormalize(tmp1)];
llmulacc(null, tmp2, op, op);
mem.swap([]Limb, &tmp1, &tmp2);
}
mem.set(Limb, tmp2, 0);
llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]);
mem.swap([]Limb, &tmp1, &tmp2);
// Multiply by a
if (exp & 1 != 0) {
if (@shlWithOverflow(u32, exp, 1, &exp)) {
mem.set(Limb, tmp2, 0);
llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm);
llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a);
mem.swap([]Limb, &tmp1, &tmp2);
}
exp >>= 1;
}
}

View File

@ -720,6 +720,27 @@ test "big.int mul 0*0" {
testing.expect((try c.to(u32)) == 0);
}
test "big.int mul large" {
var a = try Managed.initCapacity(testing.allocator, 50);
defer a.deinit();
var b = try Managed.initCapacity(testing.allocator, 100);
defer b.deinit();
var c = try Managed.initCapacity(testing.allocator, 100);
defer c.deinit();
// Generate a number that's large enough to cross the thresholds for the use
// of subquadratic algorithms
for (a.limbs) |*p| {
p.* = std.math.maxInt(Limb);
}
a.setMetadata(true, 50);
try b.mul(a.toConst(), a.toConst());
try c.sqr(a.toConst());
testing.expect(b.eq(c));
}
test "big.int div single-single no rem" {
var a = try Managed.initSet(testing.allocator, 50);
defer a.deinit();
@ -1483,11 +1504,14 @@ test "big.int const to managed" {
test "big.int pow" {
{
var a = try Managed.initSet(testing.allocator, 10);
var a = try Managed.initSet(testing.allocator, -3);
defer a.deinit();
try a.pow(a, 8);
testing.expectEqual(@as(u32, 100000000), try a.to(u32));
try a.pow(a.toConst(), 3);
testing.expectEqual(@as(i32, -27), try a.to(i32));
try a.pow(a.toConst(), 4);
testing.expectEqual(@as(i32, 531441), try a.to(i32));
}
{
var a = try Managed.initSet(testing.allocator, 10);
@ -1497,9 +1521,9 @@ test "big.int pow" {
defer y.deinit();
// y and a are not aliased
try y.pow(a, 123);
try y.pow(a.toConst(), 123);
// y and a are aliased
try a.pow(a, 123);
try a.pow(a.toConst(), 123);
testing.expect(a.eq(y));
@ -1517,18 +1541,18 @@ test "big.int pow" {
var a = try Managed.initSet(testing.allocator, 0);
defer a.deinit();
try a.pow(a, 100);
try a.pow(a.toConst(), 100);
testing.expectEqual(@as(i32, 0), try a.to(i32));
try a.set(1);
try a.pow(a, 0);
try a.pow(a.toConst(), 0);
testing.expectEqual(@as(i32, 1), try a.to(i32));
try a.pow(a, 100);
try a.pow(a.toConst(), 100);
testing.expectEqual(@as(i32, 1), try a.to(i32));
try a.set(-1);
try a.pow(a, 15);
try a.pow(a.toConst(), 15);
testing.expectEqual(@as(i32, -1), try a.to(i32));
try a.pow(a, 16);
try a.pow(a.toConst(), 16);
testing.expectEqual(@as(i32, 1), try a.to(i32));
}
}