]> git.codecow.com Git - Monocypher.git/commitdiff
Simplified crypto_x25519_dirty_small() a tiny bit
authorLoup Vaillant <loup@loup-vaillant.fr>
Sat, 22 May 2021 13:46:44 +0000 (15:46 +0200)
committerLoup Vaillant <loup@loup-vaillant.fr>
Sat, 22 May 2021 13:54:38 +0000 (15:54 +0200)
To give the same results as crypto_x25519_dirty_fast(), we originally
multiplied the cofactor by 5 before we multiplied it by L. I noticed
however that this multiplication by 5 could be baked in the base point
itself, and simplifies the computation a little bit.

This also saves a single MUL instruction.

src/monocypher.c
tests/gen/elligator_scalarmult.py

index 06f9522a89a295a6b09ad1dfc21e709387717ada..8b8478712b22f5c6a6fd20f1185a5d23bd346727 100644 (file)
@@ -2549,10 +2549,14 @@ void crypto_x25519_dirty_small(u8 public_key[32], const u8 secret_key[32])
     // Base point of order 8*L
     // Raw scalar multiplication with it does not clear the cofactor,
     // and the resulting public key will reveal 3 bits of the scalar.
+    //
+    // The low order component of this base point  has been chosen
+    // to yield the same results as crypto_x25519_dirty_fast().
     static const u8 dirty_base_point[32] = {
-        0x34, 0xfc, 0x6c, 0xb7, 0xc8, 0xde, 0x58, 0x97, 0x77, 0x70, 0xd9, 0x52,
-        0x16, 0xcc, 0xdc, 0x6c, 0x85, 0x90, 0xbe, 0xcd, 0x91, 0x9c, 0x07, 0x59,
-        0x94, 0x14, 0x56, 0x3b, 0x4b, 0xa4, 0x47, 0x0f, };
+        0xd8, 0x86, 0x1a, 0xa2, 0x78, 0x7a, 0xd9, 0x26, 0x8b, 0x74, 0x74, 0xb6,
+        0x82, 0xe3, 0xbe, 0xc3, 0xce, 0x36, 0x9a, 0x1e, 0x5e, 0x31, 0x47, 0xa2,
+        0x6d, 0x37, 0x7c, 0xfd, 0x20, 0xb5, 0xdf, 0x75,
+    };
     // separate the main factor & the cofactor of the scalar
     u8 scalar[32];
     COPY(scalar, secret_key, 32);
@@ -2564,11 +2568,9 @@ void crypto_x25519_dirty_small(u8 public_key[32], const u8 secret_key[32])
     // least significant bits however still have a main factor.  We must
     // remove it for X25519 compatibility.
     //
-    // We exploit the fact that 5*L = 1 (modulo 8)
-    //   cofactor = lsb * 5 * L             (modulo 8*L)
-    //   combined = scalar + cofactor       (modulo 8*L)
-    //   combined = scalar + (lsb * 5 * L)  (modulo 8*L)
-    add_xl(scalar, secret_key[0] * 5);
+    //   cofactor = lsb * L            (modulo 8*L)
+    //   combined = scalar + cofactor  (modulo 8*L)
+    add_xl(scalar, secret_key[0]);
     scalarmult(public_key, scalar, dirty_base_point, 256);
     WIPE_BUFFER(scalar);
 }
index 2ee08604de79c254140be0e5ddd32f4c285ed96a..a0d84e020163f78d09d1bc26fe7488d8abe4fc6f 100644 (file)
@@ -153,12 +153,16 @@ montgomery_base = 9
 # Point of order 8, used to add the cofactor component
 low_order_point_x = sqrt((sqrt(d + fe(1)) + fe(1)) / d)
 low_order_point_y = -low_order_point_x * sqrtm1
-low_order_point   = (low_order_point_x, low_order_point_y)
+low_order_point_1 = (low_order_point_x, low_order_point_y)
+low_order_point_2 = point_add2(low_order_point_1, low_order_point_1)
+low_order_point_4 = point_add2(low_order_point_2, low_order_point_2)
+low_order_point_8 = point_add2(low_order_point_4, low_order_point_4)
+low_order_point_5 = point_add2(low_order_point_1, low_order_point_4)
 
 def check_low_order_point():
-    lop2 = point_add2(low_order_point, low_order_point)
-    lop4 = point_add2(lop2, lop2)
-    lop8 = point_add2(lop4, lop4)
+    lop2 = low_order_point_2
+    lop4 = low_order_point_4
+    lop8 = low_order_point_8
     zero = (fe(0), fe(1))
     if lop8 != zero: raise ValueError('low_order_point does not have low order')
     if lop2 == zero: raise ValueError('low_order_point only has order 2')
@@ -166,8 +170,10 @@ def check_low_order_point():
 check_low_order_point()
 
 # base point + low order point
-ed_base = point_add2(low_order_point, edwards_base)   # in Edwards space
-mt_base = (fe(1) + ed_base[1]) / (fe(1) - ed_base[1]) # in Montgomery space
+ed_base_1 = point_add2(low_order_point_1, edwards_base) # in Edwards space
+ed_base_5 = point_add2(low_order_point_5, edwards_base) # in Edwards space
+mt_base_1 = (fe(1)+ed_base_1[1]) / (fe(1)-ed_base_1[1]) # in Montgomery space
+mt_base_5 = (fe(1)+ed_base_5[1]) / (fe(1)-ed_base_5[1]) # in Montgomery space
 
 # Clamp the scalar.
 # % 8 stops subgroup attacks
@@ -186,25 +192,39 @@ order = 2**252 + 27742317777372353535851937790883648493
 def scalarmult1(scalar, cofactor):
     co_cleared = ((cofactor * 5) % 8) * order  # cleared main factor
     combined   = trim(scalar) + co_cleared
-    return from_edwards(ed_scalarmult(ed_base, combined))
+    return from_edwards(ed_scalarmult(ed_base_1, combined))
 
-# Single scalar multiplication (in Montgomery space)
+# Single scalar multiplication (in Edwards space, simplified)
 def scalarmult2(scalar, cofactor):
+    co_cleared = (cofactor % 8) * order  # cleared main factor
+    combined   = trim(scalar) + co_cleared
+    return from_edwards(ed_scalarmult(ed_base_5, combined))
+
+# Single scalar multiplication (in Montgomery space)
+def scalarmult3(scalar, cofactor):
     co_cleared = ((cofactor * 5) % 8) * order  # cleared main factor
     combined   = trim(scalar) + co_cleared
-    return mt_scalarmult(mt_base, combined)
+    return mt_scalarmult(mt_base_1, combined)
+
+# Single scalar multiplication (in Montgomery space, simplified)
+def scalarmult4(scalar, cofactor):
+    co_cleared = (cofactor % 8) * order  # cleared main factor
+    combined   = trim(scalar) + co_cleared
+    return mt_scalarmult(mt_base_5, combined)
 
 # Double scalar multiplication (reuses EdDSA code)
-def scalarmult3(scalar, cofactor):
-    main_point = ed_scalarmult(edwards_base   , trim(scalar))
-    low_order  = ed_scalarmult(low_order_point, cofactor    )
+def scalarmult5(scalar, cofactor):
+    main_point = ed_scalarmult(edwards_base     , trim(scalar))
+    low_order  = ed_scalarmult(low_order_point_1, cofactor    )
     return from_edwards(point_add(main_point, low_order))
 
-# Combine and compare all ways ofd doing the scalar multiplication
+# Combine and compare all ways of doing the scalar multiplication
 def scalarmult(scalar, cofactor):
     p1 = scalarmult1(scalar, cofactor)
     p2 = scalarmult2(scalar, cofactor)
     p3 = scalarmult3(scalar, cofactor)
-    if p1 != p2 or p1 != p3:
+    p4 = scalarmult4(scalar, cofactor)
+    p5 = scalarmult5(scalar, cofactor)
+    if p1 != p2 or p1 != p3 or p1 != p4 or p1 != p5:
         raise ValueError('Incoherent scalarmult')
     return p1