From b2283d6919915840002ce5e446f1a9518a95da4e Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Sun, 29 Mar 2020 16:09:51 +0200 Subject: [PATCH] Optimised scalar inversion with Montgomery multiplication This causes us to overshoot the 2000 lines mark by 35 lines or so. But this is much faster than using the much slower mul_add() routine. --- src/monocypher.c | 120 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/src/monocypher.c b/src/monocypher.c index 1be6416..7ae8d44 100644 --- a/src/monocypher.c +++ b/src/monocypher.c @@ -2650,6 +2650,80 @@ void crypto_key_exchange(u8 shared_key[32], /////////////////////// /// Scalar division /// /////////////////////// +static void multiply(u32 p[16], const u32 a[8], const u32 b[8]) +{ + ZERO(p, 16); + FOR (i, 0, 8) { + u64 carry = 0; + FOR (j, 0, 8) { + carry += p[i+j] + (u64)a[i] * b[j]; + p[i+j] = (u32)carry; + carry >>= 32; + } + p[i+8] = (u32)carry; + } +} + +// Montgomery reduction. +// Divides x by (2^256), and reduces the result modulo L +// +// Precondition: +// x < L * 2^256 +// Constants: +// r = 2^256 (makes division by r trivial) +// k = (r * (1/r) - 1) // L (1/r is computed modulo L ) +// Algorithm: +// s = (x * k) % r +// t = x + s*L (t is always a multiple of r) +// u = (t/r) % L (u is always below 2*L, conditional subtraction is enough) +static void redc(u32 u[8], u32 x[16]) +{ + static const u32 k[8] = { 0x12547e1b, 0xd2b51da3, 0xfdba84ff, 0xb1a206f2, + 0xffa36bea, 0x14e75438, 0x6fe91836, 0x9db6c6f2,}; + static const u32 l[8] = { 0x5cf5d3ed, 0x5812631a, 0xa2f79cd6, 0x14def9de, + 0x00000000, 0x00000000, 0x00000000, 0x10000000,}; + // s = x * k (modulo 2^256) + // This is cheaper than the full multiplication. + u32 s[8] = {0}; + FOR (i, 0, 8) { + u64 carry = 0; + FOR (j, 0, 8-i) { + carry += s[i+j] + (u64)x[i] * k[j]; + s[i+j] = (u32)carry; + carry >>= 32; + } + } + u32 t[16]; + multiply(t, s, l); + + // t = t + x + u64 carry = 0; + FOR (i, 0, 16) { + carry += (u64)t[i] + x[i]; + t[i] = (u32)carry; + carry >>= 32; + } + + // u = (t / 2^256) % L + // Note that t / 2^256 is always below 2*L, + // So a constant time conditional subtraction is enough + // We work with L directly, in a 2's complement encoding + // (-L == ~L + 1) + carry = 1; + FOR (i, 0, 8) { + carry += (u64)t[i+8] + ~l[i]; + carry >>= 32; + } + u32 mask = (u32)-carry; // carry == 0 or 1 + FOR (i, 0, 8) { + carry += (u64)t[i+8] + (~l[i] & mask); + u[i] = (u32)carry; + carry >>= 32; + } + WIPE_BUFFER(s); + WIPE_BUFFER(t); +} + void crypto_x25519_inverse(u8 blind_salt [32], const u8 private_key[32], const u8 curve_point[32]) @@ -2660,29 +2734,53 @@ void crypto_x25519_inverse(u8 blind_salt [32], 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, }; + // 1 in Montgomery form + u32 m_inv [8] = {0x8d98951d, 0xd6ec3174, 0x737dcf70, 0xc6ef5bf4, + 0xfffffffe, 0xffffffff, 0xffffffff, 0x0fffffff,}; + u8 scalar[32]; trim_scalar(scalar, private_key); - u8 inverse[32] = {1}; + + // Convert the scalar in Montgomery form + // m_scl = scalar * 2^256 (modulo L) + u32 m_scl[8]; + i64 tmp[64]; + ZERO(tmp, 32); + COPY(tmp+32, scalar, 32); + modL(scalar, tmp); + load32_le_buf(m_scl, scalar, 8); + WIPE_BUFFER(tmp); // Wipe ASAP to save stack space + + u32 product[16]; for (int i = 252; i >= 0; i--) { - mul_add(inverse, inverse, inverse, zero); + multiply(product, m_inv, m_inv); + redc(m_inv, product); if (scalar_bit(Lm2, i)) { - mul_add(inverse, inverse, scalar, zero); + multiply(product, m_inv, m_scl); + redc(m_inv, product); } } - // Clear the cofactor of inverse: - // cleared = inverse * (3*L + 1) (modulo 8*L) - // cleared = inverse + inverse * 3 * L (modulo 8*L) - // Note that (inverse * 3) is reduced modulo 8, so we only need the + // Convert the inverse *out* of Montgomery form + // scalar = m_inv / 2^256 (modulo L) + COPY(product, m_inv, 8); + ZERO(product + 8, 8); + redc(m_inv, product); + store32_le_buf(scalar, m_inv, 8); // the *inverse* of the scalar + + // Clear the cofactor of scalar: + // cleared = scalar * (3*L + 1) (modulo 8*L) + // cleared = scalar + scalar * 3 * L (modulo 8*L) + // Note that (scalar * 3) is reduced modulo 8, so we only need the // first byte. - add_xl(inverse, inverse[0] * 3); + add_xl(scalar, scalar[0] * 3); // Recall that 8*L < 2^256. However it is also very close to // 2^255. If we spanned the ladder over 255 bits, random tests // wouldn't catch the off-by-one error. - scalarmult(blind_salt, inverse, curve_point, 256); + scalarmult(blind_salt, scalar, curve_point, 256); - WIPE_BUFFER(scalar); - WIPE_BUFFER(inverse); + WIPE_BUFFER(scalar); WIPE_BUFFER(m_scl); + WIPE_BUFFER(product); WIPE_BUFFER(m_inv); } //////////////////////////////// -- 2.47.3