]> git.codecow.com Git - Monocypher.git/commitdiff
Optimised scalar inversion with Montgomery multiplication
authorLoup Vaillant <loup@loup-vaillant.fr>
Sun, 29 Mar 2020 14:09:51 +0000 (16:09 +0200)
committerLoup Vaillant <loup@loup-vaillant.fr>
Sun, 29 Mar 2020 14:09:51 +0000 (16:09 +0200)
This causes us to overshoot the 2000 lines mark by 35 lines or so.  But
this is much faster than using the much slower mul_add() routine.

src/monocypher.c

index 1be641614480fab3b79e05272b3c9318a6d37a3b..7ae8d44c3d8d50b15a0f5c7e437819239880ffd8 100644 (file)
@@ -2650,6 +2650,80 @@ void crypto_key_exchange(u8       shared_key[32],
 ///////////////////////
 /// Scalar division ///
 ///////////////////////
+static void multiply(u32 p[16], const u32 a[8], const u32 b[8])
+{
+    ZERO(p, 16);
+    FOR (i, 0, 8) {
+        u64 carry = 0;
+        FOR (j, 0, 8) {
+            carry  += p[i+j] + (u64)a[i] * b[j];
+            p[i+j]  = (u32)carry;
+            carry >>= 32;
+        }
+        p[i+8] = (u32)carry;
+    }
+}
+
+// Montgomery reduction.
+// Divides x by (2^256), and reduces the result modulo L
+//
+// Precondition:
+//   x < L * 2^256
+// Constants:
+//   r = 2^256                 (makes division by r trivial)
+//   k = (r * (1/r) - 1) // L  (1/r is computed modulo L   )
+// Algorithm:
+//   s = (x * k) % r
+//   t = x + s*L      (t is always a multiple of r)
+//   u = (t/r) % L    (u is always below 2*L, conditional subtraction is enough)
+static void redc(u32 u[8], u32 x[16])
+{
+    static const u32 k[8]  = { 0x12547e1b, 0xd2b51da3, 0xfdba84ff, 0xb1a206f2,
+                               0xffa36bea, 0x14e75438, 0x6fe91836, 0x9db6c6f2,};
+    static const u32 l[8]  = { 0x5cf5d3ed, 0x5812631a, 0xa2f79cd6, 0x14def9de,
+                               0x00000000, 0x00000000, 0x00000000, 0x10000000,};
+    // s = x * k (modulo 2^256)
+    // This is cheaper than the full multiplication.
+    u32 s[8] = {0};
+    FOR (i, 0, 8) {
+        u64 carry = 0;
+        FOR (j, 0, 8-i) {
+            carry  += s[i+j] + (u64)x[i] * k[j];
+            s[i+j]  = (u32)carry;
+            carry >>= 32;
+        }
+    }
+    u32 t[16];
+    multiply(t, s, l);
+
+    // t = t + x
+    u64 carry = 0;
+    FOR (i, 0, 16) {
+        carry  += (u64)t[i] + x[i];
+        t[i]    = (u32)carry;
+        carry >>= 32;
+    }
+
+    // u = (t / 2^256) % L
+    // Note that t / 2^256 is always below 2*L,
+    // So a constant time conditional subtraction is enough
+    // We work with L directly, in a 2's complement encoding
+    // (-L == ~L + 1)
+    carry = 1;
+    FOR (i, 0, 8) {
+        carry  += (u64)t[i+8] + ~l[i];
+        carry >>= 32;
+    }
+    u32 mask = (u32)-carry; // carry == 0 or 1
+    FOR (i, 0, 8) {
+        carry  += (u64)t[i+8] + (~l[i] & mask);
+        u[i]    = (u32)carry;
+        carry >>= 32;
+    }
+    WIPE_BUFFER(s);
+    WIPE_BUFFER(t);
+}
+
 void crypto_x25519_inverse(u8       blind_salt [32],
                            const u8 private_key[32],
                            const u8 curve_point[32])
@@ -2660,29 +2734,53 @@ void crypto_x25519_inverse(u8       blind_salt [32],
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10,
     };
+    // 1 in Montgomery form
+    u32 m_inv [8] = {0x8d98951d, 0xd6ec3174, 0x737dcf70, 0xc6ef5bf4,
+                     0xfffffffe, 0xffffffff, 0xffffffff, 0x0fffffff,};
+
     u8 scalar[32];
     trim_scalar(scalar, private_key);
-    u8 inverse[32] = {1};
+
+    // Convert the scalar in Montgomery form
+    // m_scl = scalar * 2^256 (modulo L)
+    u32 m_scl[8];
+    i64 tmp[64];
+    ZERO(tmp, 32);
+    COPY(tmp+32, scalar, 32);
+    modL(scalar, tmp);
+    load32_le_buf(m_scl, scalar, 8);
+    WIPE_BUFFER(tmp); // Wipe ASAP to save stack space
+
+    u32 product[16];
     for (int i = 252; i >= 0; i--) {
-        mul_add(inverse, inverse, inverse, zero);
+        multiply(product, m_inv, m_inv);
+        redc(m_inv, product);
         if (scalar_bit(Lm2, i)) {
-            mul_add(inverse, inverse, scalar, zero);
+            multiply(product, m_inv, m_scl);
+            redc(m_inv, product);
         }
     }
-    // Clear the cofactor of inverse:
-    //   cleared = inverse * (3*L + 1)       (modulo 8*L)
-    //   cleared = inverse + inverse * 3 * L (modulo 8*L)
-    // Note that (inverse * 3) is reduced modulo 8, so we only need the
+    // Convert the inverse *out* of Montgomery form
+    // scalar = m_inv / 2^256 (modulo L)
+    COPY(product, m_inv, 8);
+    ZERO(product + 8, 8);
+    redc(m_inv, product);
+    store32_le_buf(scalar, m_inv, 8); // the *inverse* of the scalar
+
+    // Clear the cofactor of scalar:
+    //   cleared = scalar * (3*L + 1)       (modulo 8*L)
+    //   cleared = scalar + scalar * 3 * L (modulo 8*L)
+    // Note that (scalar * 3) is reduced modulo 8, so we only need the
     // first byte.
-    add_xl(inverse, inverse[0] * 3);
+    add_xl(scalar, scalar[0] * 3);
 
     // Recall that 8*L < 2^256. However it is also very close to
     // 2^255. If we spanned the ladder over 255 bits, random tests
     // wouldn't catch the off-by-one error.
-    scalarmult(blind_salt, inverse, curve_point, 256);
+    scalarmult(blind_salt, scalar, curve_point, 256);
 
-    WIPE_BUFFER(scalar);
-    WIPE_BUFFER(inverse);
+    WIPE_BUFFER(scalar);   WIPE_BUFFER(m_scl);
+    WIPE_BUFFER(product);  WIPE_BUFFER(m_inv);
 }
 
 ////////////////////////////////