From 9c3b859fb59b096bde8dd8b70f758f6fc815dd44 Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Sat, 22 May 2021 15:46:44 +0200 Subject: [PATCH] Simplified crypto_x25519_dirty_small() a tiny bit To give the same results as crypto_x25519_dirty_fast(), we originally multiplied the cofactor by 5 before we multiplied it by L. I noticed however that this multiplication by 5 could be baked into the base point itself, which simplifies the computation a little bit. This also saves a single MUL instruction. --- src/monocypher.c | 18 ++++++------ tests/gen/elligator_scalarmult.py | 48 ++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/monocypher.c b/src/monocypher.c index 06f9522..8b84787 100644 --- a/src/monocypher.c +++ b/src/monocypher.c @@ -2549,10 +2549,14 @@ void crypto_x25519_dirty_small(u8 public_key[32], const u8 secret_key[32]) // Base point of order 8*L // Raw scalar multiplication with it does not clear the cofactor, // and the resulting public key will reveal 3 bits of the scalar. + // + // The low order component of this base point has been chosen + // to yield the same results as crypto_x25519_dirty_fast(). static const u8 dirty_base_point[32] = { - 0x34, 0xfc, 0x6c, 0xb7, 0xc8, 0xde, 0x58, 0x97, 0x77, 0x70, 0xd9, 0x52, - 0x16, 0xcc, 0xdc, 0x6c, 0x85, 0x90, 0xbe, 0xcd, 0x91, 0x9c, 0x07, 0x59, - 0x94, 0x14, 0x56, 0x3b, 0x4b, 0xa4, 0x47, 0x0f, }; + 0xd8, 0x86, 0x1a, 0xa2, 0x78, 0x7a, 0xd9, 0x26, 0x8b, 0x74, 0x74, 0xb6, + 0x82, 0xe3, 0xbe, 0xc3, 0xce, 0x36, 0x9a, 0x1e, 0x5e, 0x31, 0x47, 0xa2, + 0x6d, 0x37, 0x7c, 0xfd, 0x20, 0xb5, 0xdf, 0x75, + }; // separate the main factor & the cofactor of the scalar u8 scalar[32]; COPY(scalar, secret_key, 32); @@ -2564,11 +2568,9 @@ void crypto_x25519_dirty_small(u8 public_key[32], const u8 secret_key[32]) // least significant bits however still have a main factor. We must // remove it for X25519 compatibility. 
// - // We exploit the fact that 5*L = 1 (modulo 8) - // cofactor = lsb * 5 * L (modulo 8*L) - // combined = scalar + cofactor (modulo 8*L) - // combined = scalar + (lsb * 5 * L) (modulo 8*L) - add_xl(scalar, secret_key[0] * 5); + // cofactor = lsb * L (modulo 8*L) + // combined = scalar + cofactor (modulo 8*L) + add_xl(scalar, secret_key[0]); scalarmult(public_key, scalar, dirty_base_point, 256); WIPE_BUFFER(scalar); } diff --git a/tests/gen/elligator_scalarmult.py b/tests/gen/elligator_scalarmult.py index 2ee0860..a0d84e0 100644 --- a/tests/gen/elligator_scalarmult.py +++ b/tests/gen/elligator_scalarmult.py @@ -153,12 +153,16 @@ montgomery_base = 9 # Point of order 8, used to add the cofactor component low_order_point_x = sqrt((sqrt(d + fe(1)) + fe(1)) / d) low_order_point_y = -low_order_point_x * sqrtm1 -low_order_point = (low_order_point_x, low_order_point_y) +low_order_point_1 = (low_order_point_x, low_order_point_y) +low_order_point_2 = point_add2(low_order_point_1, low_order_point_1) +low_order_point_4 = point_add2(low_order_point_2, low_order_point_2) +low_order_point_8 = point_add2(low_order_point_4, low_order_point_4) +low_order_point_5 = point_add2(low_order_point_1, low_order_point_4) def check_low_order_point(): - lop2 = point_add2(low_order_point, low_order_point) - lop4 = point_add2(lop2, lop2) - lop8 = point_add2(lop4, lop4) + lop2 = low_order_point_2 + lop4 = low_order_point_4 + lop8 = low_order_point_8 zero = (fe(0), fe(1)) if lop8 != zero: raise ValueError('low_order_point does not have low order') if lop2 == zero: raise ValueError('low_order_point only has order 2') @@ -166,8 +170,10 @@ def check_low_order_point(): check_low_order_point() # base point + low order point -ed_base = point_add2(low_order_point, edwards_base) # in Edwards space -mt_base = (fe(1) + ed_base[1]) / (fe(1) - ed_base[1]) # in Montgomery space +ed_base_1 = point_add2(low_order_point_1, edwards_base) # in Edwards space +ed_base_5 = point_add2(low_order_point_5, edwards_base) 
# in Edwards space +mt_base_1 = (fe(1)+ed_base_1[1]) / (fe(1)-ed_base_1[1]) # in Montgomery space +mt_base_5 = (fe(1)+ed_base_5[1]) / (fe(1)-ed_base_5[1]) # in Montgomery space # Clamp the scalar. # % 8 stops subgroup attacks @@ -186,25 +192,39 @@ order = 2**252 + 27742317777372353535851937790883648493 def scalarmult1(scalar, cofactor): co_cleared = ((cofactor * 5) % 8) * order # cleared main factor combined = trim(scalar) + co_cleared - return from_edwards(ed_scalarmult(ed_base, combined)) + return from_edwards(ed_scalarmult(ed_base_1, combined)) -# Single scalar multiplication (in Montgomery space) +# Single scalar multiplication (in Edwards space, simplified) def scalarmult2(scalar, cofactor): + co_cleared = (cofactor % 8) * order # cleared main factor + combined = trim(scalar) + co_cleared + return from_edwards(ed_scalarmult(ed_base_5, combined)) + +# Single scalar multiplication (in Montgomery space) +def scalarmult3(scalar, cofactor): co_cleared = ((cofactor * 5) % 8) * order # cleared main factor combined = trim(scalar) + co_cleared - return mt_scalarmult(mt_base, combined) + return mt_scalarmult(mt_base_1, combined) + +# Single scalar multiplication (in Montgomery space, simplified) +def scalarmult4(scalar, cofactor): + co_cleared = (cofactor % 8) * order # cleared main factor + combined = trim(scalar) + co_cleared + return mt_scalarmult(mt_base_5, combined) # Double scalar multiplication (reuses EdDSA code) -def scalarmult3(scalar, cofactor): - main_point = ed_scalarmult(edwards_base , trim(scalar)) - low_order = ed_scalarmult(low_order_point, cofactor ) +def scalarmult5(scalar, cofactor): + main_point = ed_scalarmult(edwards_base , trim(scalar)) + low_order = ed_scalarmult(low_order_point_1, cofactor ) return from_edwards(point_add(main_point, low_order)) -# Combine and compare all ways ofd doing the scalar multiplication +# Combine and compare all ways of doing the scalar multiplication def scalarmult(scalar, cofactor): p1 = scalarmult1(scalar, cofactor) 
p2 = scalarmult2(scalar, cofactor) p3 = scalarmult3(scalar, cofactor) - if p1 != p2 or p1 != p3: + p4 = scalarmult4(scalar, cofactor) + p5 = scalarmult5(scalar, cofactor) + if p1 != p2 or p1 != p3 or p1 != p4 or p1 != p5: raise ValueError('Incoherent scalarmult') return p1 -- 2.47.3