Merge pull request #19239 from jedisct1/ml-kem

std.crypto: add support for ML-KEM
This commit is contained in:
Andrew Kelley 2024-03-11 18:48:08 -07:00 committed by GitHub
commit cb4e087fda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 133 additions and 85 deletions

View File

@ -72,7 +72,8 @@ pub const dh = struct {
/// Key Encapsulation Mechanisms.
pub const kem = struct {
pub const kyber_d00 = @import("crypto/kyber_d00.zig");
pub const kyber_d00 = @import("crypto/ml_kem.zig").kyber_d00;
pub const ml_kem_01 = @import("crypto/ml_kem.zig").ml_kem_01;
};
/// Elliptic-curve arithmetic.

View File

@ -1,14 +1,15 @@
//! Implementation of the IND-CCA2 post-quantum secure key encapsulation
//! mechanism (KEM) CRYSTALS-Kyber, as submitted to the third round of the NIST
//! Post-Quantum Cryptography (v3.02/"draft00"), and selected for standardisation.
//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM)
//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft).
//!
//! Kyber will likely change before final standardisation.
//! The schemes are not finalized yet, and are still subject to breaking changes.
//!
//! The namespace suffix (currently `_d00`) refers to the version currently
//! implemented, in accordance with the draft. It may not be updated if new
//! versions of the draft only include editorial changes.
//! The Kyber namespace suffix (currently `_d00`) refers to the version currently
//! implemented, in accordance with the draft.
//! The ML-KEM namespace suffix (currently `_01`) refers to the NIST FIPS-203 draft
//! published on August 24, 2023, with the unintentional transposition of  having been reverted.
//!
//! The suffix will eventually be removed once Kyber is finalized.
//! Suffixes may not be updated if new versions of the documents only include editorial changes.
//! The suffixes will be removed once the schemes are finalized.
//!
//! Quoting from the CFRG I-D:
//!
@ -108,6 +109,7 @@ const builtin = @import("builtin");
const testing = std.testing;
const assert = std.debug.assert;
const crypto = std.crypto;
const errors = std.crypto.errors;
const math = std.math;
const mem = std.mem;
const RndGen = std.Random.DefaultPrng;
@ -128,6 +130,9 @@ const eta2: u8 = 2;
const Params = struct {
name: []const u8,
// NIST ML-KEM variant instead of Kyber as originally submitted.
ml_kem: bool = false,
// Width and height of the matrix A.
k: u8,
@ -143,31 +148,69 @@ const Params = struct {
dv: u8,
};
pub const Kyber512 = Kyber(.{
.name = "Kyber512",
.k = 2,
.eta1 = 3,
.du = 10,
.dv = 4,
});
pub const kyber_d00 = struct {
pub const Kyber512 = Kyber(.{
.name = "Kyber512",
.k = 2,
.eta1 = 3,
.du = 10,
.dv = 4,
});
pub const Kyber768 = Kyber(.{
.name = "Kyber768",
.k = 3,
.eta1 = 2,
.du = 10,
.dv = 4,
});
pub const Kyber768 = Kyber(.{
.name = "Kyber768",
.k = 3,
.eta1 = 2,
.du = 10,
.dv = 4,
});
pub const Kyber1024 = Kyber(.{
.name = "Kyber1024",
.k = 4,
.eta1 = 2,
.du = 11,
.dv = 5,
});
pub const Kyber1024 = Kyber(.{
.name = "Kyber1024",
.k = 4,
.eta1 = 2,
.du = 11,
.dv = 5,
});
};
const modes = [_]type{ Kyber512, Kyber768, Kyber1024 };
pub const ml_kem_01 = struct {
pub const MLKem512 = Kyber(.{
.name = "ML-KEM-512",
.ml_kem = true,
.k = 2,
.eta1 = 3,
.du = 10,
.dv = 4,
});
pub const MLKem768 = Kyber(.{
.name = "ML-KEM-768",
.ml_kem = true,
.k = 3,
.eta1 = 2,
.du = 10,
.dv = 4,
});
pub const MLKem1024 = Kyber(.{
.name = "ML-KEM-1024",
.ml_kem = true,
.k = 4,
.eta1 = 2,
.du = 11,
.dv = 5,
});
};
const modes = [_]type{
kyber_d00.Kyber512,
kyber_d00.Kyber768,
kyber_d00.Kyber1024,
ml_kem_01.MLKem512,
ml_kem_01.MLKem768,
ml_kem_01.MLKem1024,
};
const h_length: usize = 32;
const inner_seed_length: usize = 32;
const common_encaps_seed_length: usize = 32;
@ -211,18 +254,18 @@ fn Kyber(comptime p: Params) type {
/// If `seed` is `null`, a random seed is used. This is recommended.
/// If `seed` is set, encapsulation is deterministic.
pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret {
const seed = seed_ orelse seed: {
var random_seed: [encaps_seed_length]u8 = undefined;
crypto.random.bytes(&random_seed);
break :seed random_seed;
};
var m: [inner_plaintext_length]u8 = undefined;
// m = H(seed)
var h = sha3.Sha3_256.init(.{});
h.update(&seed);
h.final(&m);
if (seed_) |seed| {
if (p.ml_kem) {
@memcpy(&m, &seed);
} else {
// m = H(seed)
sha3.Sha3_256.hash(&seed, &m, .{});
}
} else {
crypto.random.bytes(&m);
}
// (K', r) = G(m H(pk))
var kr: [inner_plaintext_length + h_length]u8 = undefined;
@ -231,24 +274,25 @@ fn Kyber(comptime p: Params) type {
g.update(&pk.hpk);
g.final(&kr);
// c = innerEncrypy(pk, m, r)
// c = innerEncrypt(pk, m, r)
const ct = pk.pk.encrypt(&m, kr[32..64]);
// Compute H(c) and put in second slot of kr, which will be (K', H(c)).
h = sha3.Sha3_256.init(.{});
h.update(&ct);
h.final(kr[32..64]);
if (p.ml_kem) {
return EncapsulatedSecret{
.shared_secret = kr[0..shared_length].*, // ML-KEM: K = K'
.ciphertext = ct,
};
} else {
// Compute H(c) and put in second slot of kr, which will be (K', H(c)).
sha3.Sha3_256.hash(&ct, kr[32..], .{});
// K = KDF(K' H(c))
var kdf = sha3.Shake256.init(.{});
kdf.update(&kr);
var ss: [shared_length]u8 = undefined;
kdf.squeeze(&ss);
return EncapsulatedSecret{
.shared_secret = ss,
.ciphertext = ct,
};
var ss: [shared_length]u8 = undefined;
sha3.Shake256.hash(&kr, &ss, .{});
return EncapsulatedSecret{
.shared_secret = ss, // Kyber: K = KDF(K' H(c))
.ciphertext = ct,
};
}
}
/// Serializes the key into a byte array.
@ -257,13 +301,10 @@ fn Kyber(comptime p: Params) type {
}
/// Deserializes the key from a byte array.
pub fn fromBytes(buf: *const [bytes_length]u8) !PublicKey {
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
var ret: PublicKey = undefined;
ret.pk = InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
var h = sha3.Sha3_256.init(.{});
h.update(buf);
h.final(&ret.hpk);
ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
sha3.Sha3_256.hash(buf, &ret.hpk, .{});
return ret;
}
};
@ -295,19 +336,20 @@ fn Kyber(comptime p: Params) type {
const ct2 = sk.pk.encrypt(&m2, kr2[32..64]);
// Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)).
var h = sha3.Sha3_256.init(.{});
h.update(ct);
h.final(kr2[32..64]);
sha3.Sha3_256.hash(ct, kr2[32..], .{});
// Replace K'' by z in the first slot of kr2 if ct ct'.
cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2));
// K = KDF(K''/z, H(c))
var kdf = sha3.Shake256.init(.{});
var ss: [shared_length]u8 = undefined;
kdf.update(&kr2);
kdf.squeeze(&ss);
return ss;
if (p.ml_kem) {
// ML-KEM: K = K''/z
return kr2[0..shared_length].*;
} else {
// Kyber: K = KDF(K''/z H(c))
var ss: [shared_length]u8 = undefined;
sha3.Shake256.hash(&kr2, &ss, .{});
return ss;
}
}
/// Serializes the key into a byte array.
@ -316,12 +358,12 @@ fn Kyber(comptime p: Params) type {
}
/// Deserializes the key from a byte array.
pub fn fromBytes(buf: *const [bytes_length]u8) !SecretKey {
pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
var ret: SecretKey = undefined;
comptime var s: usize = 0;
ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
s += InnerSk.bytes_length;
ret.pk = InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
s += InnerPk.bytes_length;
ret.hpk = buf[s..][0..h_length].*;
s += h_length;
@ -359,9 +401,7 @@ fn Kyber(comptime p: Params) type {
ret.secret_key.z = seed[inner_seed_length..seed_length].*;
// Compute H(pk)
var h = sha3.Sha3_256.init(.{});
h.update(&ret.public_key.pk.toBytes());
h.final(&ret.secret_key.hpk);
sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{});
ret.public_key.hpk = ret.secret_key.hpk;
return ret;
@ -415,9 +455,19 @@ fn Kyber(comptime p: Params) type {
return pk.th.toBytes() ++ pk.rho;
}
fn fromBytes(buf: *const [bytes_length]u8) InnerPk {
fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
var ret: InnerPk = undefined;
ret.th = V.fromBytes(buf[0..V.bytes_length]).normalize();
const th_bytes = buf[0..V.bytes_length];
ret.th = V.fromBytes(th_bytes).normalize();
if (p.ml_kem) {
// Verify that the coefficients used a canonical representation.
if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) {
return error.NonCanonical;
}
}
ret.rho = buf[V.bytes_length..bytes_length].*;
ret.aT = M.uniform(ret.rho, true);
return ret;
@ -455,10 +505,7 @@ fn Kyber(comptime p: Params) type {
// Derives inner PKE keypair from given seed.
fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void {
var expanded_seed: [64]u8 = undefined;
var h = sha3.Sha3_512.init(.{});
h.update(&seed);
h.final(&expanded_seed);
sha3.Sha3_512.hash(&seed, &expanded_seed, .{});
pk.rho = expanded_seed[0..32].*;
const sigma = expanded_seed[32..64];
pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on
@ -1675,9 +1722,9 @@ const sha2 = crypto.hash.sha2;
test "NIST KAT test" {
inline for (.{
.{ Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" },
.{ Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" },
.{ Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" },
.{ kyber_d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" },
.{ kyber_d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" },
.{ kyber_d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" },
}) |modeHash| {
const mode = modeHash[0];
var seed: [48]u8 = undefined;