]> git.codecow.com Git - Monocypher.git/commitdiff
Elligator script: general organisation
authorLoup Vaillant <loup@loup-vaillant.fr>
Thu, 27 Feb 2020 23:40:23 +0000 (00:40 +0100)
committerLoup Vaillant <loup@loup-vaillant.fr>
Thu, 27 Feb 2020 23:40:23 +0000 (00:40 +0100)
tests/gen/elligator.py

index 256a237c90bbca1667bacadd7b757e320ebaefc5..c2c5986b98ed6b8f2c2f44051584facc57ea54d0 100755 (executable)
@@ -53,6 +53,9 @@
 
 import sys # stdin
 
+####################
+# Field arithmetic #
+####################
 class fe:
     """Prime field over 2^255 - 19"""
     p = 2**255 - 19
@@ -87,64 +90,50 @@ class fe:
         if m != 0: raise ValueError('number is too big!!')
         print(':')
 
-# Curve25519 constants
+########################
+# Curve25519 constants #
+########################
 p = fe.p
 A = fe(486662)
 # B = 1
 
+###############
+# Square root #
+###############
 def chi      (n): return n**((p-1)//2)
 def is_square(n): return n == fe(0) or chi(n) == fe(1)
 
-sqrt1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs()
+sqrtm1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs()
 
 def sqrt(n):
     if not is_square(n) : raise ValueError('Not a square!')
     root = n**((p+3) // 8)
-    if root * root != n: root = (root * sqrt1)
+    if root * root != n: root = (root * sqrtm1)
     if root * root != n: raise ValueError('Should be a square!!')
     return root.abs()
 
-# Elligator 2
-
-# Arbitrary non square, typically chosen to minimise computation.
-# 2 and sqrt(-1) both work fairly well
-non_square = fe(2) # that's what standards seem to agree upon
-
-def hash_to_curve(r):
-    w = -A / (fe(1) + non_square * r**2)
-    e = chi(w**3 + A*w**2 + w)
-    u = e*w - (fe(1)-e)*(A//2)
-    v = -e * sqrt(u**3 + A*u**2 + u)
-    return (u, v)
-
-def can_curve_to_hash(point):
-    u = point[0]
-    return u != -A and is_square(-non_square * u * (u + A))
-
-def curve_to_hash(point):
-    if not can_curve_to_hash(point):
-        raise ValueError('cannot curve to hash')
-    u   = point[0]
-    v   = point[1]
-    sq1 = sqrt(-u     / (non_square * (u+A)))
-    sq2 = sqrt(-(u+A) / (non_square * u    ))
-    if v.is_positive(): return sq1
-    else              : return sq2
-
-# Edwards (Edwards25519)
-# -x^2 + y^2 = 1 + d*x^2*y^2
-d = fe(-121665) / fe(121666)
-
+#########################
+# scalar multiplication #
+#########################
+# Clamp the scalar.
+# % 8 stops subgroup attacks
+# Clearing bit 255 and setting bit 254 facilitates constant time ladders.
 def trim(scalar):
     trimmed = scalar - scalar % 8
     trimmed = trimmed % 2**254
     trimmed = trimmed + 2**254
     return trimmed
 
-# fast point addition & scalar multiplication with affine coordinates:
-# x = X/Z, y = Y/Z. We can multiply Z instead of dividing X and Y.
-# The goal is to test the merging of the final inversion
-# with the exponentiations required for curve_to_hash
+# Edwards25519 equation (d defined below):
+# -x^2 + y^2 = 1 + d*x^2*y^2
+d = fe(-121665) / fe(121666)
+
+# Point addition:
+# denum = d*x1*x2*y1*y2
+# x     = (x1*y2 + x2*y1) / (1 + denum)
+# y     = (y1*y2 + x1*x2) / (1 - denum)
+# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z.
+# We can multiply Z instead of dividing X and Y.
 def point_add(a, b):
     x1 = a[0];  y1 = a[1];  z1 = a[2];
     x2 = b[0];  y2 = b[1];  z2 = b[2];
@@ -188,24 +177,55 @@ def from_edwards(point):
     zu  = z - y
     v   = u * z * sqrt_mA2
     zv  = zu * x
-    div = (zu * zv).invert()
+    div = (zu * zv).invert() # now we have to divide
     return (u*zv*div, v*zu*div)
 
-def pow_p58(f): return f ** ((p-5)//8)
-sqrt_half  = (fe(-1) / fe(2)) ** ((p+3)//8)
-chi_minus2 = chi(fe(-2))
+def x25519_public_key(private_key):
+    return from_edwards(scalarbase(private_key))
+
+###########################
+# Elligator 2 (reference) #
+###########################
+
+# Arbitrary non square, typically chosen to minimise computation.
+# 2 and sqrt(-1) both work fairly well
+non_square = fe(2) # that's what standards seem to agree upon
 
+def hash_to_curve(r):
+    w = -A / (fe(1) + non_square * r**2)
+    e = chi(w**3 + A*w**2 + w)
+    u = e*w - (fe(1)-e)*(A//2)
+    v = -e * sqrt(u**3 + A*u**2 + u)
+    return (u, v)
+
+def can_curve_to_hash(point):
+    u = point[0]
+    return u != -A and is_square(-non_square * u * (u + A))
+
+def curve_to_hash(point):
+    if not can_curve_to_hash(point):
+        raise ValueError('cannot curve to hash')
+    u   = point[0]
+    v   = point[1]
+    sq1 = sqrt(-u     / (non_square * (u+A)))
+    sq2 = sqrt(-(u+A) / (non_square * u    ))
+    if v.is_positive(): return sq1
+    else              : return sq2
+
+#####################
+# Elligator2 (fast) #
+#####################
 def invsqrt(x):
     isr = x**((p - 5) // 8)
     quartic = x * isr**2
-    if quartic == fe(-1) or quartic == -sqrt1:
-        isr = isr * sqrt1
+    if quartic == fe(-1) or quartic == -sqrtm1:
+        isr = isr * sqrtm1
     is_square = quartic == fe(1) or quartic == fe(-1)
     return isr, is_square
 
 def fast_hash_to_curve(q):
     u = non_square
-    ufactor = -u * sqrt1
+    ufactor = -u * sqrtm1
     ufactor_sqrt = sqrt(ufactor)
 
     r = u * q**2
@@ -218,7 +238,7 @@ def fast_hash_to_curve(q):
     # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3
     isr, is_square = invsqrt(num * den)
     # if is_square: isr = sqrt(1 / (num * den))
-    # if not is_square: isr = sqrt(sqrt1 / (num * den))
+    # if not is_square: isr = sqrt(sqrtm1 / (num * den))
     x = -A * (num * r1**2 * isr**2)
     # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2
     # x = -A * num * (r + 1)^2 * 1 / (num * den)
@@ -237,7 +257,7 @@ def fast_hash_to_curve(q):
     # x = -A * u * q^2 / (r + 1)
     # x = -A * r / (r + 1)
     y = qy * y
-    # y = q * sqrt(-u * sqrt1) * sqrt(sqrt(-1) * num / den)
+    # y = q * sqrt(-u * sqrtm1) * sqrt(sqrt(-1) * num / den)
     # y = sqrt(q^2 * u * num / den)
     # y = sqrt(r * num / den)
     y = y.abs()
@@ -246,7 +266,7 @@ def fast_hash_to_curve(q):
 
 def fast_curve_to_hash(point):
     u = non_square
-    ufactor = -u * sqrt1
+    ufactor = -u * sqrtm1
     ufactor_sqrt = sqrt(ufactor)
 
     x, y = point
@@ -268,12 +288,13 @@ def fast_curve_to_hash(point):
     q = q.abs()
     return q
 
-half_A = A // 2
-
-# entire key generation chain
-def full_cycle_check(scalar, u):
-    fe(scalar).print()
-    uv = from_edwards(scalarbase(scalar))
+##############
+# Test suite #
+##############
+# Test a full round trip, and print the relevant test vectors
+def full_cycle_check(private_key, u):
+    fe(private_key).print()
+    uv = x25519_public_key(private_key)
     if uv [0] != u: raise ValueError('Test vector failure')
     uv[0].print()
     uv[1].print()