From 5cfb772a7440bcbc1005213d0483d7d1cc770367 Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Wed, 26 Feb 2020 22:14:57 +0100 Subject: [PATCH] Replaced fast mappings by even better ones Turned out there were much simpler ways to compute the mapping, thanks to the fact that when the prime p is congruent to 5 modulo 8, we have this nice equality: x^((p-5)/8) = sqrt(1/x) if x is square, x^((p-5)/8) = sqrt(sqrt(-1)/x) otherwise The code was kindly given by Andrew Moon, who got the original trick from Mike Hamburg. --- tests/gen/elligator.py | 154 +++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 82 deletions(-) diff --git a/tests/gen/elligator.py b/tests/gen/elligator.py index eea2b7f..ec1448b 100755 --- a/tests/gen/elligator.py +++ b/tests/gen/elligator.py @@ -219,87 +219,81 @@ def pow_p58(f): return f ** ((p-5)//8) sqrt_half = (fe(-1) / fe(2)) ** ((p+3)//8) chi_minus2 = chi(fe(-2)) -def fast_curve_to_hash(point): - u = point[0] - v = point[1] - c = pow_p58(u * (u+A)**7) - sq1 = u * (u+A)**3 * c * sqrt_half - sq2 = u**3 * (u+A)**25 * c**7 * sqrt_half - sqv = u**2 * (u+A)**14 * c**4 * chi_minus2 - if (sqv == fe(-1)): - return None - if fe(2) * (u + A) * sq1**2 + u != fe(0): sq1 = (sq1 * sqrt1) - if fe(2) * u * sq2**2 + u + A != fe(0): sq2 = (sq2 * sqrt1) - sq = sq1 if v.is_positive() else sq2 - return sq.abs() +def invsqrt(x): + isr = x**((p - 5) // 8) + quartic = x * isr**2 + if quartic == fe(-1) or quartic == -sqrt1: + isr = isr * sqrt1 + is_square = quartic == fe(1) or quartic == fe(-1) + return isr, is_square + +def fast_hash_to_curve(q): + u = fe(2) + ufactor = -u * sqrt1 + ufactor_sqrt = sqrt(ufactor) + + r = u * q**2 + r1 = (r + fe(1)) + num = A * (A**2 * r - r1**2) + den = r1**3 + # x = -A / (r + 1) + # y = x^3 + A*x^2 + x + # y = A^3/(r + 1)^2 - A^3/(r + 1)^3 - A/(r + 1) + # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3 + isr, is_square = invsqrt(num * den) + # if is_square: isr = sqrt(1 / (num * den)) + # if not is_square: isr = sqrt(sqrt1 / (num * den)) + x = -A * (num * r1**2 * isr**2) + # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2 + # x = -A * num * (r + 1)^2 * 1 / (num * den) + # x = -A * (r + 1)^2 * 1 / den + # x = -A / (r + 1) + y = num * isr + # y = num * sqrt(1 / (num * den)) + # y = sqrt(num^2 / (num * den)) + # y = sqrt(num / den) + qx = q**2 * ufactor + qy = q * ufactor_sqrt + if is_square: qx = fe(1) + if is_square: qy = fe(1) + x = qx * x + # x = q^2 * -u * sqrt(-1) * -A * sqrt(-1) / (r + 1) + # x = -A * u * q^2 / (r + 1) + # x = -A * r / (r + 1) + y = qy * y + # y = q * sqrt(-u * sqrt1) * sqrt(sqrt(-1) * num / den) + # y = sqrt(q^2 * u * num / den) + # y = sqrt(r * num / den) + y = y.abs() + if is_square: y = -y + return (x, y) -# Explicit formula for curve_to_hash -def explicit_curve_to_hash(point): - u = point[0] - v = point[1] - ua = u + A - t1 = ua**2 # 2ua - sq1 = t1 * ua # 3ua - t2 = t1**2 # 4ua - t1 = t2 * sq1 # 7ua - sq2 = t1**2 # 14ua - c = pow_p58(u * t1) - t3 = c**2 # 2c - t4 = t3**2 # 4c - t5 = u**2 # 2u - sqv = t5 * sq2 - sqv = sqv * t4 - sqv = sqv * chi_minus2 - if (sqv == fe(-1)): # not constant time, don't have a choice +def fast_curve_to_hash(point): + u = fe(2) + ufactor = -u * sqrt1 + ufactor_sqrt = sqrt(ufactor) + + x, y = point + t0 = A + x + t1 = x + # if is_positive(y): r = u*q^2 = -(A + x)/x + # if is_negative(y): r = u*q^2 = -x/(A + x) + isr, is_square = invsqrt(-t0 * t1 * u) + # isr = sqrt(-1 / ((A + x) * x * u)) + if not is_square: return None - sq1 = u * sq1 - sq1 = sq1 * c - sq1 = sq1 * sqrt_half - t5 = u * t5 # 3u - t3 = c * t3 # 3c - c = t4 * t3 # 7c - sq2 = sq2 * t1 # 21ua - sq2 = sq2 * t2 # 25ua - sq2 = t5 * sq2 - sq2 = sq2 * c - sq2 = sq2 * sqrt_half - t1 = sq1 * sqrt1 - sqv = fe(2) * sq1**2 - sqv = sqv * ua - sqv = sqv + u - if sqv != fe(0) : sq1 = t1 # constant time move - t2 = sq2 * sqrt1 - sqv = fe(2) * sq2**2 - sqv = sqv * u - sqv = sqv + ua - if sqv != fe(0) : sq2 = t2 # constant time move - if v .is_negative(): sq1 = sq2 # constant time move - t1 = -sq1 - if sq1.is_negative(): sq1 = t1 # constant time move - # wipe temporaries: ua, c, sq1, sq2, sqv, t1, t2, t3, t4, t5 - return sq1 + num = t0 + if y.is_positive(): num = t1 + q = num * isr + # if is_positive(y): q = (A + x) * sqrt(1 / (-x * (A + x) * u)) = sqrt(-(A + x) / (x * u)) + # if is_positive(y): q = sqrt(-(A + x) / (x * u)) + # if is_negative(y): q = x * sqrt(1 / (-x * (A + x) * u)) = sqrt(-x / ((A + x) * u)) + # if is_negative(y): q = sqrt(-x / ((A + x) * u)) + q = q.abs() + return q half_A = A // 2 -# Explicit formula for hash_to_curve -# We don't need the v coordinate for X25519, so it is omited -def explicit_hash_to_curve(r): - w = fe(2) * r**2 # fe_sq2() - w = w + fe(1) - w = w.invert() - w = w * A - w = -w - e = A + w - e = e * w - e = e + fe(1) - e = e * w - e = chi(e) - u = fe(1) - e - u = u * half_A - w = e * w - u = w - u - return u - # entire key generation chain def full_cycle_check(scalar, u): fe(scalar).print() @@ -314,20 +308,16 @@ def full_cycle_check(scalar, u): h = curve_to_hash(uv) if h.is_negative(): raise ValueError('Non Canonical representative') fh = fast_curve_to_hash(uv) - eh = explicit_curve_to_hash(uv) if fh != h: raise ValueError('Incorrect fast_curve_to_hash()') - if eh != h: raise ValueError('Incorrect explicit_curve_to_hash()') print('01:') # Success h.print() # actual value for the hash c = hash_to_curve(h) - u = explicit_hash_to_curve(h) - if u != c[0]: raise ValueError('Incorrect explicit_hash_to_curve()') + f = fast_hash_to_curve(h) + if f != c : raise ValueError('Incorrect fast_hash_to_curve()') if c != uv : raise ValueError('Round trip failure') else: fh = fast_curve_to_hash(uv) - eh = explicit_curve_to_hash(uv) if not (fh is None): raise ValueError('Fast Curve to Hash did not fail') - if not (eh is None): raise ValueError('Explicit Curve to Hash did not fail') print('00:') # Failure print('00:') # dummy value for the hash -- 2.47.3