From 740be8d06557c735541f7d2add9af80603325ba0 Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Fri, 28 Feb 2020 00:40:23 +0100 Subject: [PATCH] Elligator script: general organisation --- tests/gen/elligator.py | 129 ++++++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 54 deletions(-) diff --git a/tests/gen/elligator.py b/tests/gen/elligator.py index 256a237..c2c5986 100755 --- a/tests/gen/elligator.py +++ b/tests/gen/elligator.py @@ -53,6 +53,9 @@ import sys # stdin +#################### +# Field arithmetic # +#################### class fe: """Prime field over 2^255 - 19""" p = 2**255 - 19 @@ -87,64 +90,50 @@ class fe: if m != 0: raise ValueError('number is too big!!') print(':') -# Curve25519 constants +######################## +# Curve25519 constants # +######################## p = fe.p A = fe(486662) # B = 1 +############### +# Square root # +############### def chi (n): return n**((p-1)//2) def is_square(n): return n == fe(0) or chi(n) == fe(1) -sqrt1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs() +sqrtm1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs() def sqrt(n): if not is_square(n) : raise ValueError('Not a square!') root = n**((p+3) // 8) - if root * root != n: root = (root * sqrt1) + if root * root != n: root = (root * sqrtm1) if root * root != n: raise ValueError('Should be a square!!') return root.abs() -# Elligator 2 - -# Arbitrary non square, typically chosen to minimise computation. -# 2 and sqrt(-1) both work fairly well -non_square = fe(2) # that's what standards seem to agree upon - -def hash_to_curve(r): - w = -A / (fe(1) + non_square * r**2) - e = chi(w**3 + A*w**2 + w) - u = e*w - (fe(1)-e)*(A//2) - v = -e * sqrt(u**3 + A*u**2 + u) - return (u, v) - -def can_curve_to_hash(point): - u = point[0] - return u != -A and is_square(-non_square * u * (u + A)) - -def curve_to_hash(point): - if not can_curve_to_hash(point): - raise ValueError('cannot curve to hash') - u = point[0] - v = point[1] - sq1 = sqrt(-u / (non_square * (u+A))) - sq2 = sqrt(-(u+A) / (non_square * u )) - if v.is_positive(): return sq1 - else : return sq2 - -# Edwards (Edwards25519) -# -x^2 + y^2 = 1 + d*x^2*y^2 -d = fe(-121665) / fe(121666) - +######################### +# scalar multiplication # +######################### +# Clamp the scalar. +# % 8 stops subgroup attacks +# Clearing bit 255 and setting bit 254 facilitates constant time ladders. def trim(scalar): trimmed = scalar - scalar % 8 trimmed = trimmed % 2**254 trimmed = trimmed + 2**254 return trimmed -# fast point addition & scalar multiplication with affine coordinates: -# x = X/Z, y = Y/Z. We can multiply Z instead of dividing X and Y. -# The goal is to test the merging of the final inversion -# with the exponentiations required for curve_to_hash +# Edwards25519 equation (d defined below): +# -x^2 + y^2 = 1 + d*x^2*y^2 +d = fe(-121665) / fe(121666) + +# Point addition: +# denum = d*x1*x2*y1*y2 +# x = (x1*y2 + x2*y1) / (1 + denum) +# y = (y1*y2 + x1*x2) / (1 - denum) +# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z. +# We can multiply Z instead of dividing X and Y. def point_add(a, b): x1 = a[0]; y1 = a[1]; z1 = a[2]; x2 = b[0]; y2 = b[1]; z2 = b[2]; @@ -188,24 +177,55 @@ def from_edwards(point): zu = z - y v = u * z * sqrt_mA2 zv = zu * x - div = (zu * zv).invert() + div = (zu * zv).invert() # now we have to divide return (u*zv*div, v*zu*div) -def pow_p58(f): return f ** ((p-5)//8) -sqrt_half = (fe(-1) / fe(2)) ** ((p+3)//8) -chi_minus2 = chi(fe(-2)) +def x25519_public_key(private_key): + return from_edwards(scalarbase(private_key)) + +########################### +# Elligator 2 (reference) # +########################### + +# Arbitrary non square, typically chosen to minimise computation. +# 2 and sqrt(-1) both work fairly well +non_square = fe(2) # that's what standards seem to agree upon +def hash_to_curve(r): + w = -A / (fe(1) + non_square * r**2) + e = chi(w**3 + A*w**2 + w) + u = e*w - (fe(1)-e)*(A//2) + v = -e * sqrt(u**3 + A*u**2 + u) + return (u, v) + +def can_curve_to_hash(point): + u = point[0] + return u != -A and is_square(-non_square * u * (u + A)) + +def curve_to_hash(point): + if not can_curve_to_hash(point): + raise ValueError('cannot curve to hash') + u = point[0] + v = point[1] + sq1 = sqrt(-u / (non_square * (u+A))) + sq2 = sqrt(-(u+A) / (non_square * u )) + if v.is_positive(): return sq1 + else : return sq2 + +##################### +# Elligator2 (fast) # +##################### def invsqrt(x): isr = x**((p - 5) // 8) quartic = x * isr**2 - if quartic == fe(-1) or quartic == -sqrt1: - isr = isr * sqrt1 + if quartic == fe(-1) or quartic == -sqrtm1: + isr = isr * sqrtm1 is_square = quartic == fe(1) or quartic == fe(-1) return isr, is_square def fast_hash_to_curve(q): u = non_square - ufactor = -u * sqrt1 + ufactor = -u * sqrtm1 ufactor_sqrt = sqrt(ufactor) r = u * q**2 @@ -218,7 +238,7 @@ def fast_hash_to_curve(q): # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3 isr, is_square = invsqrt(num * den) # if is_square: isr = sqrt(1 / (num * den)) - # if not is_square: isr = sqrt(sqrt1 / (num * den)) + # if not is_square: isr = sqrt(sqrtm1 / (num * den)) x = -A * (num * r1**2 * isr**2) # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2 # x = -A * num * (r + 1)^2 * 1 / (num * den) @@ -237,7 +257,7 @@ def fast_hash_to_curve(q): # x = -A * u * q^2 / (r + 1) # x = -A * r / (r + 1) y = qy * y - # y = q * sqrt(-u * sqrt1) * sqrt(sqrt(-1) * num / den) + # y = q * sqrt(-u * sqrtm1) * sqrt(sqrt(-1) * num / den) # y = sqrt(q^2 * u * num / den) # y = sqrt(r * num / den) y = y.abs() @@ -246,7 +266,7 @@ def fast_hash_to_curve(q): def fast_curve_to_hash(point): u = non_square - ufactor = -u * sqrt1 + ufactor = -u * sqrtm1 ufactor_sqrt = sqrt(ufactor) x, y = point @@ -268,12 +288,13 @@ def fast_curve_to_hash(point): q = q.abs() return q -half_A = A // 2 - -# entire key generation chain -def full_cycle_check(scalar, u): - fe(scalar).print() - uv = from_edwards(scalarbase(scalar)) +############## +# Test suite # +############## +# Test a full round trip, and print the relevant test vectors +def full_cycle_check(private_key, u): + fe(private_key).print() + uv = x25519_public_key(private_key) if uv [0] != u: raise ValueError('Test vector failure') uv[0].print() uv[1].print() -- 2.47.3