From 7abf1b0b291ae1352661cbd8f45c8e5bd275a5dc Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Sun, 1 Mar 2020 22:50:39 +0100 Subject: [PATCH] Elligator script: clean up & comments --- tests/gen/elligator.py | 207 ++++++++++++++++++++++++++--------------- 1 file changed, 133 insertions(+), 74 deletions(-) diff --git a/tests/gen/elligator.py b/tests/gen/elligator.py index ab9c7ff..50ad19b 100755 --- a/tests/gen/elligator.py +++ b/tests/gen/elligator.py @@ -162,8 +162,8 @@ d = fe(-121665) / fe(121666) # To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z. # We can multiply Z instead of dividing X and Y. def point_add(a, b): - x1 = a[0]; y1 = a[1]; z1 = a[2]; - x2 = b[0]; y2 = b[1]; z2 = b[2]; + x1, y1, z1 = a + x2, y2, z2 = b denum = d*x1*x2*y1*y2 z1z2 = z1 * z2 z1z22 = z1z2**2 @@ -202,9 +202,7 @@ sqrt_mA2 = sqrt(fe(-486664)) # sqrt(-(A+2)) # (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x) # (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1)) def from_edwards(point): - x = point[0] - y = point[1] - z = point[2] + x, y, z = point u = z + y zu = z - y v = u * z * sqrt_mA2 @@ -244,16 +242,15 @@ def hash_to_curve(r): # Test whether a point has a representative, straight from the paper. def can_curve_to_hash(point): u = point[0] - return u != -A and is_square(-non_square * u * (u + A)) + return u != -A and is_square(-non_square * u * (u+A)) # Computes the representative of a point, straight from the paper. def curve_to_hash(point): if not can_curve_to_hash(point): raise ValueError('cannot curve to hash') - u = point[0] - v = point[1] - sq1 = sqrt(-u / (non_square * (u+A))) - sq2 = sqrt(-(u+A) / (non_square * u )) + u, v = point + sq1 = sqrt(-u / (non_square * (u+A))) + sq2 = sqrt(-(u+A) / (non_square * u )) if v.is_positive(): return sq1 else : return sq2 @@ -262,13 +259,13 @@ def curve_to_hash(point): ##################### # Inverse square root. -# Returns sqrt(1/x) if x is square. -# Returns sqrt(sqrt(-1)/x) if x is not a square. -# We assume x is not zero +# Returns (sqrt(1/x) , True ) if x is non-zero square. +# Returns (sqrt(sqrt(-1)/x), False) if x is not a square. +# Returns (0 , False) if x is zero. # We do not guarantee the sign of the square root. # # Notes: -# let quartic = x^((p-1)/4) +# Let quartic = x^((p-1)/4) # # x^((p-1)/2) = chi(x) # quartic^2 = chi(x) @@ -324,70 +321,132 @@ def invsqrt(x): is_square = quartic == fe(1) or quartic == fe(-1) return isr, is_square -def fast_hash_to_curve(q): - u = non_square - ufactor = -u * sqrtm1 - ufactor_sqrt = sqrt(ufactor) - - r = u * q**2 - r1 = (r + fe(1)) - num = A * (A**2 * r - r1**2) - den = r1**3 - # x = -A / (r + 1) - # y = x^3 + A*x^2 + x - # y = A^3/(r + 1)^2 - A^3/(r + 1)^3 - A/(r + 1) - # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3 - isr, is_square = invsqrt(num * den) - # if is_square: isr = sqrt(1 / (num * den)) - # if not is_square: isr = sqrt(sqrtm1 / (num * den)) - x = -A * (num * r1**2 * isr**2) - # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2 - # x = -A * num * (r + 1)^2 * 1 / (num * den) - # x = -A * (r + 1)^2 * 1 / den - # x = -A / (r + 1) - y = num * isr - # y = num * sqrt(1 / (num * den)) - # y = sqrt(num^2 / (num * den)) - # y = sqrt(num / den) - qx = q**2 * ufactor - qy = q * ufactor_sqrt - if is_square: qx = fe(1) - if is_square: qy = fe(1) - x = qx * x - # x = q^2 * -u * sqrt(-1) * -A * sqrt(-1) / (r + 1) - # x = -A * u * q^2 / (r + 1) - # x = -A * r / (r + 1) - y = qy * y - # y = q * sqrt(-u * sqrtm1) * sqrt(sqrt(-1) * num / den) - # y = sqrt(q^2 * u * num / den) - # y = sqrt(r * num / den) - y = y.abs() - if is_square: y = -y - return (x, y) +# From the paper: +# w = -A / (fe(1) + non_square * r^2) +# e = chi(w^3 + A*w^2 + w) +# u = e*w - (fe(1)-e)*(A//2) +# v = -e * sqrt(u^3 + A*u^2 + u) +# +# Note that e is eiter 0, 1 or -1 +# if e = 0 +# (u, v) = (0, 0) +# if e = 1 +# u = w +# v = -sqrt(u^3 + A*u^2 + u) +# if e = -1 +# u = -w - A = w * non_square * r^2 +# v = sqrt(u^3 + A*u^2 + u) +# +# Let r1 = non_square * r^2 +# Let r2 = 1 + r1 +# Note that r2 cannot be zero, -1/non_square is not a square. +# We can (tediously) verify that: +# w^3 + A*w^2 + w = (A^2*r1 - r2^2) * A / r2^3 +# Therefore: +# chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) +# chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * 1 +# chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * chi(r2^6) +# chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3) * r2^6) +# chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * A * r2^3) +# Corollary: +# e = 1 if (A^2*r1 - r2^2) * A * r2^3) is a non-zero square +# e = -1 if (A^2*r1 - r2^2) * A * r2^3) is not a square +# Note that w^3 + A*w^2 + w (and therefore e) can never be zero: +# w^3 + A*w^2 + w = w * (w^2 + A*w + 1) +# w^3 + A*w^2 + w = w * (w^2 + A*w + A^2/4 - A^2/4 + 1) +# w^3 + A*w^2 + w = w * (w + A/2)^2 - A^2/4 + 1) +# which is zero only if: +# w = 0 (impossible) +# (w + A/2)^2 = A^2/4 - 1 (impossible, because A^2/4-1 is not a square) +# +# Let isr = invsqrt((A^2*r1 - r2^2) * A * r2^3) +# isr = sqrt(1 / ((A^2*r1 - r2^2) * A * r2^3)) if e = 1 +# isr = strt(sqrt(-1) / ((A^2*r1 - r2^2) * A * r2^3)) if e = -1 +# +# if e = 1 +# let u1 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2 +# u1 = w +# u1 = u +# let v1 = -(A^2*r1 - r2^2) * A * isr +# v1 = -sqrt((A^2*r1 - r2^2) * A / r2^3) +# v1 = -sqrt(w^3 + A*w^2 + w) +# v1 = -sqrt(u^3 + A*u^2 + u) (because u = w) +# v1 = v +# +# if e = -1 +# let ufactor = -non_square * sqrt(-1) * r^2 +# let vfactor = sqrt(ufactor) +# let u2 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2 * ufactor +# u2 = w * -1 * -non_square * r^2 +# u2 = w * non_square * r^2 +# u2 = u +# let v2 = (A^2*r1 - r2^2) * A * isr * vfactor +# v2 = sqrt(non_square * r^2 * (A^2*r1 - r2^2) * A / r2^3) +# v2 = sqrt(non_square * r^2 * (w^3 + A*w^2 + w)) +# v2 = sqrt(non_square * r^2 * w * (w^2 + A*w + 1)) +# v2 = sqrt(u (w^2 + A*w + 1)) +# v2 = sqrt(u ((-u-A)^2 + A*(-u-A) + 1)) +# v2 = sqrt(u (u^2 + A^2 + 2*A*u - A*u -A^2) + 1)) +# v2 = sqrt(u (u^2 + A*u + 1)) +# v2 = sqrt(u^3 + A*u^2 + u) +# v2 = v +ufactor = -non_square * sqrtm1 +vfactor = sqrt(ufactor) + +def fast_hash_to_curve(r): + t1 = r**2 * non_square # r1 + t2 = t1 + fe(1) # r2 + t3 = t2**2 + t4 = (A**2 * t1 - t3) * A # numerator + t1 = t3 * t2 # denominator + t1, is_square = invsqrt(t4 * t1) + u = r**2 * ufactor + v = r * vfactor + if is_square: u = fe(1) + if is_square: v = fe(1) + v = v * t4 * t1 + t1 = t1**2 + u = u * -A * t4 * t3 * t1 + if is_square != v.is_negative(): # XOR + v = -v + return (u, v) +# From the paper: +# Let sq = -non_square * u * (u+A) +# if sq is not a square, or u = -A, there is no mapping +# Assuming there is a mapping: +# if v is positive: r = sqrt(-(u+A) / u) +# if v is negative: r = sqrt(-u / (u+A)) +# +# We compute isr = invsqrt(-non_square * u * (u+A)) +# if it wasn't a non-zero square, abort. +# else, isr = sqrt(-1 / (non_square * u * (u+A)) +# +# This causes us to abort if u is zero, even though we shouldn't. This +# never happens in practice, because (i) a random point in the curve has +# a negligible chance of being zero, and (ii) scalar multiplication with +# a trimmed scalar *never* yields zero. +# +# Since: +# isr * (u+A) = sqrt(-1 / (non_square * u * (u+A)) * (u+A) +# isr * (u+A) = sqrt(-(u+A) / (non_square * u * (u+A)) +# and: +# isr = u = sqrt(-1 / (non_square * u * (u+A)) * u +# isr = u = sqrt(-u / (non_square * u * (u+A)) +# Therefore: +# if v is positive: r = isr * (u+A) +# if v is negative: r = isr * u def fast_curve_to_hash(point): - u = non_square - ufactor = -u * sqrtm1 - ufactor_sqrt = sqrt(ufactor) - - x, y = point - t0 = A + x - t1 = x - # if is_positive(y): r = u*q^2 = -(A + x)/x - # if is_negative(y): r = u*q^2 = -x/(A + x) - isr, is_square = invsqrt(-t0 * t1 * u) - # isr = sqrt(-1 / ((A + x) * x * u)) + u, v = point + t = u + A + r = -non_square * u * t + isr, is_square = invsqrt(r) if not is_square: return None - num = t0 - if y.is_positive(): num = t1 - q = num * isr - # if is_positive(y): q = (A + x) * sqrt(1 / (-x * (A + x) * u)) = sqrt(-(A + x) / (x * u)) - # if is_positive(y): q = sqrt(-(A + x) / (x * u)) - # if is_negative(y): q = x * sqrt(1 / (-x * (A + x) * u)) = sqrt(-x / ((A + x) * u)) - # if is_negative(y): q = sqrt(-x / ((A + x) * u)) - q = q.abs() - return q + if v.is_positive(): t = u + r = t * isr + r = r.abs() + return r ############## # Test suite # -- 2.47.3