]> git.codecow.com Git - Monocypher.git/commitdiff
Elligator script: clean up & comments
authorLoup Vaillant <loup@loup-vaillant.fr>
Sun, 1 Mar 2020 21:50:39 +0000 (22:50 +0100)
committerLoup Vaillant <loup@loup-vaillant.fr>
Sun, 1 Mar 2020 21:50:39 +0000 (22:50 +0100)
tests/gen/elligator.py

index ab9c7fff3dd023e17e3406ea7dd1275799fe3747..50ad19b3cb9e83bcbe796e0367f118fddc273fca 100755 (executable)
@@ -162,8 +162,8 @@ d = fe(-121665) / fe(121666)
 # To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z.
 # We can multiply Z instead of dividing X and Y.
 def point_add(a, b):
-    x1 = a[0];  y1 = a[1];  z1 = a[2];
-    x2 = b[0];  y2 = b[1];  z2 = b[2];
+    x1, y1, z1 = a
+    x2, y2, z2 = b
     denum = d*x1*x2*y1*y2
     z1z2  = z1 * z2
     z1z22 = z1z2**2
@@ -202,9 +202,7 @@ sqrt_mA2 = sqrt(fe(-486664)) # sqrt(-(A+2))
 # (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x)
 # (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1))
 def from_edwards(point):
-    x = point[0]
-    y = point[1]
-    z = point[2]
+    x, y, z = point
     u   = z + y
     zu  = z - y
     v   = u * z * sqrt_mA2
@@ -244,16 +242,15 @@ def hash_to_curve(r):
 # Test whether a point has a representative, straight from the paper.
 def can_curve_to_hash(point):
     u = point[0]
-    return u != -A and is_square(-non_square * u * (u + A))
+    return u != -A and is_square(-non_square * u * (u+A))
 
 # Computes the representative of a point, straight from the paper.
 def curve_to_hash(point):
     if not can_curve_to_hash(point):
         raise ValueError('cannot curve to hash')
-    u   = point[0]
-    v   = point[1]
-    sq1 = sqrt(-u     / (non_square * (u+A)))
-    sq2 = sqrt(-(u+A) / (non_square * u    ))
+    u, v = point
+    sq1  = sqrt(-u     / (non_square * (u+A)))
+    sq2  = sqrt(-(u+A) / (non_square * u    ))
     if v.is_positive(): return sq1
     else              : return sq2
 
@@ -262,13 +259,13 @@ def curve_to_hash(point):
 #####################
 
 # Inverse square root.
-# Returns sqrt(1/x)        if x is square.
-# Returns sqrt(sqrt(-1)/x) if x is not a square.
-# We assume x is not zero
+# Returns (sqrt(1/x)       , True ) if x is non-zero square.
+# Returns (sqrt(sqrt(-1)/x), False) if x is not a square.
+# Returns (0               , False) if x is zero.
 # We do not guarantee the sign of the square root.
 #
 # Notes:
-# let quartic = x^((p-1)/4)
+# Let quartic = x^((p-1)/4)
 #
 # x^((p-1)/2) = chi(x)
 # quartic^2   = chi(x)
@@ -324,70 +321,132 @@ def invsqrt(x):
     is_square = quartic == fe(1) or quartic == fe(-1)
     return isr, is_square
 
-def fast_hash_to_curve(q):
-    u = non_square
-    ufactor = -u * sqrtm1
-    ufactor_sqrt = sqrt(ufactor)
-
-    r = u * q**2
-    r1 = (r + fe(1))
-    num = A * (A**2 * r - r1**2)
-    den = r1**3
-    # x = -A / (r + 1)
-    # y = x^3 + A*x^2 + x
-    # y = A^3/(r + 1)^2 - A^3/(r + 1)^3 - A/(r + 1)
-    # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3
-    isr, is_square = invsqrt(num * den)
-    # if is_square: isr = sqrt(1 / (num * den))
-    # if not is_square: isr = sqrt(sqrtm1 / (num * den))
-    x = -A * (num * r1**2 * isr**2)
-    # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2
-    # x = -A * num * (r + 1)^2 * 1 / (num * den)
-    # x = -A * (r + 1)^2 * 1 / den
-    # x = -A / (r + 1)
-    y = num * isr
-    # y = num * sqrt(1 / (num * den))
-    # y = sqrt(num^2 / (num * den))
-    # y = sqrt(num / den)
-    qx = q**2 * ufactor
-    qy = q * ufactor_sqrt
-    if is_square: qx = fe(1)
-    if is_square: qy = fe(1)
-    x = qx * x
-    # x = q^2 * -u * sqrt(-1) * -A * sqrt(-1) / (r + 1)
-    # x = -A * u * q^2 / (r + 1)
-    # x = -A * r / (r + 1)
-    y = qy * y
-    # y = q * sqrt(-u * sqrtm1) * sqrt(sqrt(-1) * num / den)
-    # y = sqrt(q^2 * u * num / den)
-    # y = sqrt(r * num / den)
-    y = y.abs()
-    if is_square: y = -y
-    return (x, y)
+# From the paper:
+# w = -A / (fe(1) + non_square * r^2)
+# e = chi(w^3 + A*w^2 + w)
+# u = e*w - (fe(1)-e)*(A//2)
+# v = -e * sqrt(u^3 + A*u^2 + u)
+#
+# Note that e is eiter 0, 1 or -1
+# if e = 0
+#   (u, v) = (0, 0)
+# if e = 1
+#   u = w
+#   v = -sqrt(u^3 + A*u^2 + u)
+# if e = -1
+#   u = -w - A = w * non_square * r^2
+#   v = sqrt(u^3 + A*u^2 + u)
+#
+# Let r1 = non_square * r^2
+# Let r2 = 1 + r1
+# Note that r2 cannot be zero, -1/non_square is not a square.
+# We can (tediously) verify that:
+#   w^3 + A*w^2 + w = (A^2*r1 - r2^2) * A / r2^3
+# Therefore:
+#   chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3))
+#   chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * 1
+#   chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * chi(r2^6)
+#   chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)  *     r2^6)
+#   chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) *  A * r2^3)
+# Corollary:
+#   e =  1 if (A^2*r1 - r2^2) *  A * r2^3) is a non-zero square
+#   e = -1 if (A^2*r1 - r2^2) *  A * r2^3) is not a square
+#   Note that w^3 + A*w^2 + w (and therefore e) can never be zero:
+#     w^3 + A*w^2 + w = w * (w^2 + A*w + 1)
+#     w^3 + A*w^2 + w = w * (w^2 + A*w + A^2/4 - A^2/4 + 1)
+#     w^3 + A*w^2 + w = w * (w + A/2)^2        - A^2/4 + 1)
+#     which is zero only if:
+#       w = 0                   (impossible)
+#       (w + A/2)^2 = A^2/4 - 1 (impossible, because A^2/4-1 is not a square)
+#
+# Let isr   = invsqrt((A^2*r1 - r2^2) *  A * r2^3)
+#     isr   = sqrt(1        / ((A^2*r1 - r2^2) *  A * r2^3)) if e =  1
+#     isr   = strt(sqrt(-1) / ((A^2*r1 - r2^2) *  A * r2^3)) if e = -1
+#
+# if e = 1
+#   let u1 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2
+#       u1 = w
+#       u1 = u
+#   let v1 = -(A^2*r1 - r2^2) * A * isr
+#       v1 = -sqrt((A^2*r1 - r2^2) * A / r2^3)
+#       v1 = -sqrt(w^3 + A*w^2 + w)
+#       v1 = -sqrt(u^3 + A*u^2 + u)   (because u = w)
+#       v1 = v
+#
+# if e = -1
+#   let ufactor = -non_square * sqrt(-1) * r^2
+#   let vfactor = sqrt(ufactor)
+#   let u2 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2 * ufactor
+#       u2 = w * -1 * -non_square * r^2
+#       u2 = w * non_square * r^2
+#       u2 = u
+#   let v2 = (A^2*r1 - r2^2) * A * isr * vfactor
+#       v2 = sqrt(non_square * r^2 * (A^2*r1 - r2^2) * A / r2^3)
+#       v2 = sqrt(non_square * r^2 * (w^3 + A*w^2 + w))
+#       v2 = sqrt(non_square * r^2 * w * (w^2 + A*w + 1))
+#       v2 = sqrt(u (w^2 + A*w + 1))
+#       v2 = sqrt(u ((-u-A)^2 + A*(-u-A) + 1))
+#       v2 = sqrt(u (u^2 + A^2 + 2*A*u - A*u -A^2) + 1))
+#       v2 = sqrt(u (u^2 + A*u + 1))
+#       v2 = sqrt(u^3 + A*u^2 + u)
+#       v2 = v
+ufactor = -non_square * sqrtm1
+vfactor = sqrt(ufactor)
+
+def fast_hash_to_curve(r):
+    t1  = r**2 * non_square   # r1
+    t2  = t1 + fe(1)          # r2
+    t3  = t2**2
+    t4 = (A**2 * t1 - t3) * A # numerator
+    t1 = t3 * t2              # denominator
+    t1, is_square = invsqrt(t4 * t1)
+    u  = r**2 * ufactor
+    v  = r    * vfactor
+    if is_square: u = fe(1)
+    if is_square: v = fe(1)
+    v   = v * t4 * t1
+    t1 = t1**2
+    u   = u * -A * t4 * t3 * t1
+    if is_square != v.is_negative(): # XOR
+        v = -v
+    return (u, v)
 
+# From the paper:
+# Let sq = -non_square * u * (u+A)
+# if sq is not a square, or u = -A, there is no mapping
+# Assuming there is a mapping:
+#   if v is positive: r = sqrt(-(u+A) / u)
+#   if v is negative: r = sqrt(-u / (u+A))
+#
+# We compute isr = invsqrt(-non_square * u * (u+A))
+# if it wasn't a non-zero square, abort.
+# else, isr = sqrt(-1 / (non_square * u * (u+A))
+#
+# This causes us to abort if u is zero, even though we shouldn't. This
+# never happens in practice, because (i) a random point in the curve has
+# a negligible chance of being zero, and (ii) scalar multiplication with
+# a trimmed scalar *never* yields zero.
+#
+# Since:
+#   isr * (u+A) = sqrt(-1     / (non_square * u * (u+A)) * (u+A)
+#   isr * (u+A) = sqrt(-(u+A) / (non_square * u * (u+A))
+# and:
+#   isr = u = sqrt(-1 / (non_square * u * (u+A)) * u
+#   isr = u = sqrt(-u / (non_square * u * (u+A))
+# Therefore:
+#   if v is positive: r = isr * (u+A)
+#   if v is negative: r = isr * u
 def fast_curve_to_hash(point):
-    u = non_square
-    ufactor = -u * sqrtm1
-    ufactor_sqrt = sqrt(ufactor)
-
-    x, y = point
-    t0 = A + x
-    t1 = x
-    # if is_positive(y): r = u*q^2 = -(A + x)/x
-    # if is_negative(y): r = u*q^2 = -x/(A + x)
-    isr, is_square = invsqrt(-t0 * t1 * u)
-    # isr = sqrt(-1 / ((A + x) * x * u))
+    u, v = point
+    t = u + A
+    r = -non_square * u * t
+    isr, is_square = invsqrt(r)
     if not is_square:
         return None
-    num = t0
-    if y.is_positive(): num = t1
-    q = num * isr
-    # if is_positive(y): q = (A + x) * sqrt(1 / (-x * (A + x) * u)) = sqrt(-(A + x) / (x * u))
-    # if is_positive(y): q = sqrt(-(A + x) / (x * u))
-    # if is_negative(y): q = x * sqrt(1 / (-x * (A + x) * u)) = sqrt(-x / ((A + x) * u))
-    # if is_negative(y): q = sqrt(-x / ((A + x) * u))
-    q = q.abs()
-    return q
+    if v.is_positive(): t = u
+    r = t * isr
+    r = r.abs()
+    return r
 
 ##############
 # Test suite #