]> git.codecow.com Git - Monocypher.git/commitdiff
Replaced fast mappings by even better ones
authorLoup Vaillant <loup@loup-vaillant.fr>
Wed, 26 Feb 2020 21:14:57 +0000 (22:14 +0100)
committerLoup Vaillant <loup@loup-vaillant.fr>
Wed, 26 Feb 2020 21:14:57 +0000 (22:14 +0100)
Turned out there were much simpler ways to compute the mapping, thanks
to the fact that when the prime p is congruent to 5 modulo 8, we have
this nice equality:

    x^((p-5)/8) = sqrt(1/x)          if x is square,
    x^((p-5)/8) = sqrt(sqrt(-1)/x)   otherwise

The code was kindly given by Andrew Moon, who got the original trick
from Mike Hamburg.

tests/gen/elligator.py

index eea2b7f4637881e2f5333e8a9b49e167de36b01f..ec1448bafbb9a7c37b31397bd0393e7ff12c2e32 100755 (executable)
@@ -219,87 +219,81 @@ def pow_p58(f): return f ** ((p-5)//8)
 sqrt_half  = (fe(-1) / fe(2)) ** ((p+3)//8)
 chi_minus2 = chi(fe(-2))
 
-def fast_curve_to_hash(point):
-    u = point[0]
-    v = point[1]
-    c   = pow_p58(u * (u+A)**7)
-    sq1 = u    * (u+A)**3  * c    * sqrt_half
-    sq2 = u**3 * (u+A)**25 * c**7 * sqrt_half
-    sqv = u**2 * (u+A)**14 * c**4 * chi_minus2
-    if (sqv == fe(-1)):
-        return None
-    if fe(2) * (u + A) * sq1**2 + u      != fe(0): sq1 = (sq1 * sqrt1)
-    if fe(2) *  u      * sq2**2 + u + A  != fe(0): sq2 = (sq2 * sqrt1)
-    sq = sq1 if v.is_positive() else sq2
-    return sq.abs()
+def invsqrt(x):
+    isr = x**((p - 5) // 8)
+    quartic = x * isr**2
+    if quartic == fe(-1) or quartic == -sqrt1:
+        isr = isr * sqrt1
+    is_square = quartic == fe(1) or quartic == fe(-1)
+    return isr, is_square
+
+def fast_hash_to_curve(q):
+    u = fe(2)
+    ufactor = -u * sqrt1
+    ufactor_sqrt = sqrt(ufactor)
+
+    r = u * q**2
+    r1 = (r + fe(1))
+    num = A * (A**2 * r - r1**2)
+    den = r1**3
+    # x = -A / (r + 1)
+    # y = x^3 + A*x^2 + x
+    # y = A^3/(r + 1)^2 - A^3/(r + 1)^3 - A/(r + 1)
+    # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3
+    isr, is_square = invsqrt(num * den)
+    # if is_square: isr = sqrt(1 / (num * den))
+    # if not is_square: isr = sqrt(sqrt1 / (num * den))
+    x = -A * (num * r1**2 * isr**2)
+    # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2
+    # x = -A * num * (r + 1)^2 * 1 / (num * den)
+    # x = -A * (r + 1)^2 * 1 / den
+    # x = -A / (r + 1)
+    y = num * isr
+    # y = num * sqrt(1 / (num * den))
+    # y = sqrt(num^2 / (num * den))
+    # y = sqrt(num / den)
+    qx = q**2 * ufactor
+    qy = q * ufactor_sqrt
+    if is_square: qx = fe(1)
+    if is_square: qy = fe(1)
+    x = qx * x
+    # x = q^2 * -u * sqrt(-1) * -A * sqrt(-1) / (r + 1)
+    # x = -A * u * q^2 / (r + 1)
+    # x = -A * r / (r + 1)
+    y = qy * y
+    # y = q * sqrt(-u * sqrt1) * sqrt(sqrt(-1) * num / den)
+    # y = sqrt(q^2 * u * num / den)
+    # y = sqrt(r * num / den)
+    y = y.abs()
+    if is_square: y = -y
+    return (x, y)
 
-# Explicit formula for curve_to_hash
-def explicit_curve_to_hash(point):
-    u = point[0]
-    v = point[1]
-    ua  = u + A
-    t1  = ua**2     #  2ua
-    sq1 = t1  * ua  #  3ua
-    t2  = t1**2     #  4ua
-    t1  = t2  * sq1 #  7ua
-    sq2 = t1**2     # 14ua
-    c   = pow_p58(u * t1)
-    t3  = c**2      #  2c
-    t4  = t3**2     #  4c
-    t5  = u**2      #  2u
-    sqv = t5  * sq2
-    sqv = sqv * t4
-    sqv = sqv * chi_minus2
-    if (sqv == fe(-1)): # not constant time, don't have a choice
+def fast_curve_to_hash(point):
+    u = fe(2)
+    ufactor = -u * sqrt1
+    ufactor_sqrt = sqrt(ufactor)
+
+    x, y = point
+    t0 = A + x
+    t1 = x
+    # if is_positive(y): r = u*q^2 = -(A + x)/x
+    # if is_negative(y): r = u*q^2 = -x/(A + x)
+    isr, is_square = invsqrt(-t0 * t1 * u)
+    # isr = sqrt(-1 / ((A + x) * x * u))
+    if not is_square:
         return None
-    sq1 = u   * sq1
-    sq1 = sq1 * c
-    sq1 = sq1 * sqrt_half
-    t5  = u   * t5 #  3u
-    t3  = c   * t3 #  3c
-    c   = t4  * t3 #  7c
-    sq2 = sq2 * t1 # 21ua
-    sq2 = sq2 * t2 # 25ua
-    sq2 = t5  * sq2
-    sq2 = sq2 * c
-    sq2 = sq2 * sqrt_half
-    t1  = sq1 * sqrt1
-    sqv = fe(2) * sq1**2
-    sqv = sqv * ua
-    sqv = sqv + u
-    if sqv != fe(0)     : sq1 = t1  # constant time move
-    t2  = sq2 * sqrt1
-    sqv = fe(2) * sq2**2
-    sqv = sqv * u
-    sqv = sqv + ua
-    if sqv != fe(0)     : sq2 = t2  # constant time move
-    if v  .is_negative(): sq1 = sq2 # constant time move
-    t1 = -sq1
-    if sq1.is_negative(): sq1 = t1  # constant time move
-    # wipe temporaries: ua, c, sq1, sq2, sqv, t1, t2, t3, t4, t5
-    return sq1
+    num = t0
+    if y.is_positive(): num = t1
+    q = num * isr
+    # if is_positive(y): q = (A + x) * sqrt(1 / (-x * (A + x) * u)) = sqrt(-(A + x) / (x * u))
+    # if is_positive(y): q = sqrt(-(A + x) / (x * u))
+    # if is_negative(y): q = x * sqrt(1 / (-x * (A + x) * u)) = sqrt(-x / ((A + x) * u))
+    # if is_negative(y): q = sqrt(-x / ((A + x) * u))
+    q = q.abs()
+    return q
 
 half_A = A // 2
 
-# Explicit formula for hash_to_curve
-# We don't need the v coordinate for X25519, so it is omited
-def explicit_hash_to_curve(r):
-    w = fe(2) * r**2  # fe_sq2()
-    w = w + fe(1)
-    w = w.invert()
-    w = w * A
-    w = -w
-    e = A + w
-    e = e * w
-    e = e + fe(1)
-    e = e * w
-    e = chi(e)
-    u = fe(1) - e
-    u = u * half_A
-    w = e * w
-    u = w - u
-    return u
-
 # entire key generation chain
 def full_cycle_check(scalar, u):
     fe(scalar).print()
@@ -314,20 +308,16 @@ def full_cycle_check(scalar, u):
         h  = curve_to_hash(uv)
         if h.is_negative(): raise ValueError('Non Canonical representative')
         fh = fast_curve_to_hash(uv)
-        eh = explicit_curve_to_hash(uv)
         if fh != h: raise ValueError('Incorrect fast_curve_to_hash()')
-        if eh != h: raise ValueError('Incorrect explicit_curve_to_hash()')
         print('01:')    # Success
         h.print()       # actual value for the hash
         c = hash_to_curve(h)
-        u = explicit_hash_to_curve(h)
-        if u != c[0]: raise ValueError('Incorrect explicit_hash_to_curve()')
+        f = fast_hash_to_curve(h)
+        if f != c   : raise ValueError('Incorrect fast_hash_to_curve()')
         if c != uv  : raise ValueError('Round trip failure')
     else:
         fh = fast_curve_to_hash(uv)
-        eh = explicit_curve_to_hash(uv)
         if not (fh is None): raise ValueError('Fast Curve to Hash did not fail')
-        if not (eh is None): raise ValueError('Explicit Curve to Hash did not fail')
         print('00:')    # Failure
         print('00:')    # dummy value for the hash