sqrt_half = (fe(-1) / fe(2)) ** ((p+3)//8)
chi_minus2 = chi(fe(-2))
-def fast_curve_to_hash(point):
- u = point[0]
- v = point[1]
- c = pow_p58(u * (u+A)**7)
- sq1 = u * (u+A)**3 * c * sqrt_half
- sq2 = u**3 * (u+A)**25 * c**7 * sqrt_half
- sqv = u**2 * (u+A)**14 * c**4 * chi_minus2
- if (sqv == fe(-1)):
- return None
- if fe(2) * (u + A) * sq1**2 + u != fe(0): sq1 = (sq1 * sqrt1)
- if fe(2) * u * sq2**2 + u + A != fe(0): sq2 = (sq2 * sqrt1)
- sq = sq1 if v.is_positive() else sq2
- return sq.abs()
+def invsqrt(x):
+ isr = x**((p - 5) // 8)
+ quartic = x * isr**2
+ if quartic == fe(-1) or quartic == -sqrt1:
+ isr = isr * sqrt1
+ is_square = quartic == fe(1) or quartic == fe(-1)
+ return isr, is_square
+
+def fast_hash_to_curve(q):
+ u = fe(2)
+ ufactor = -u * sqrt1
+ ufactor_sqrt = sqrt(ufactor)
+
+ r = u * q**2
+ r1 = (r + fe(1))
+ num = A * (A**2 * r - r1**2)
+ den = r1**3
+ # x = -A / (r + 1)
+ # y = x^3 + A*x^2 + x
+ # y = A^3/(r + 1)^2 - A^3/(r + 1)^3 - A/(r + 1)
+ # y = (A^3*r - A*(r + 1)^2) / (r + 1)^3
+ isr, is_square = invsqrt(num * den)
+ # if is_square: isr = sqrt(1 / (num * den))
+ # if not is_square: isr = sqrt(sqrt1 / (num * den))
+ x = -A * (num * r1**2 * isr**2)
+ # x = -A * num * (r + 1)^2 * sqrt(1 / (num * den))^2
+ # x = -A * num * (r + 1)^2 * 1 / (num * den)
+ # x = -A * (r + 1)^2 * 1 / den
+ # x = -A / (r + 1)
+ y = num * isr
+ # y = num * sqrt(1 / (num * den))
+ # y = sqrt(num^2 / (num * den))
+ # y = sqrt(num / den)
+ qx = q**2 * ufactor
+ qy = q * ufactor_sqrt
+ if is_square: qx = fe(1)
+ if is_square: qy = fe(1)
+ x = qx * x
+ # x = q^2 * -u * sqrt(-1) * -A * sqrt(-1) / (r + 1)
+ # x = -A * u * q^2 / (r + 1)
+ # x = -A * r / (r + 1)
+ y = qy * y
+ # y = q * sqrt(-u * sqrt1) * sqrt(sqrt(-1) * num / den)
+ # y = sqrt(q^2 * u * num / den)
+ # y = sqrt(r * num / den)
+ y = y.abs()
+ if is_square: y = -y
+ return (x, y)
-# Explicit formula for curve_to_hash
-def explicit_curve_to_hash(point):
- u = point[0]
- v = point[1]
- ua = u + A
- t1 = ua**2 # 2ua
- sq1 = t1 * ua # 3ua
- t2 = t1**2 # 4ua
- t1 = t2 * sq1 # 7ua
- sq2 = t1**2 # 14ua
- c = pow_p58(u * t1)
- t3 = c**2 # 2c
- t4 = t3**2 # 4c
- t5 = u**2 # 2u
- sqv = t5 * sq2
- sqv = sqv * t4
- sqv = sqv * chi_minus2
- if (sqv == fe(-1)): # not constant time, don't have a choice
+def fast_curve_to_hash(point):
+ u = fe(2)
+ ufactor = -u * sqrt1
+ ufactor_sqrt = sqrt(ufactor)
+
+ x, y = point
+ t0 = A + x
+ t1 = x
+ # if is_positive(y): r = u*q^2 = -(A + x)/x
+ # if is_negative(y): r = u*q^2 = -x/(A + x)
+ isr, is_square = invsqrt(-t0 * t1 * u)
+ # isr = sqrt(-1 / ((A + x) * x * u))
+ if not is_square:
return None
- sq1 = u * sq1
- sq1 = sq1 * c
- sq1 = sq1 * sqrt_half
- t5 = u * t5 # 3u
- t3 = c * t3 # 3c
- c = t4 * t3 # 7c
- sq2 = sq2 * t1 # 21ua
- sq2 = sq2 * t2 # 25ua
- sq2 = t5 * sq2
- sq2 = sq2 * c
- sq2 = sq2 * sqrt_half
- t1 = sq1 * sqrt1
- sqv = fe(2) * sq1**2
- sqv = sqv * ua
- sqv = sqv + u
- if sqv != fe(0) : sq1 = t1 # constant time move
- t2 = sq2 * sqrt1
- sqv = fe(2) * sq2**2
- sqv = sqv * u
- sqv = sqv + ua
- if sqv != fe(0) : sq2 = t2 # constant time move
- if v .is_negative(): sq1 = sq2 # constant time move
- t1 = -sq1
- if sq1.is_negative(): sq1 = t1 # constant time move
- # wipe temporaries: ua, c, sq1, sq2, sqv, t1, t2, t3, t4, t5
- return sq1
+ num = t0
+ if y.is_positive(): num = t1
+ q = num * isr
+ # if is_positive(y): q = (A + x) * sqrt(1 / (-x * (A + x) * u)) = sqrt(-(A + x) / (x * u))
+ # if is_positive(y): q = sqrt(-(A + x) / (x * u))
+ # if is_negative(y): q = x * sqrt(1 / (-x * (A + x) * u)) = sqrt(-x / ((A + x) * u))
+ # if is_negative(y): q = sqrt(-x / ((A + x) * u))
+ q = q.abs()
+ return q
half_A = A // 2
-# Explicit formula for hash_to_curve
-# We don't need the v coordinate for X25519, so it is omited
-def explicit_hash_to_curve(r):
- w = fe(2) * r**2 # fe_sq2()
- w = w + fe(1)
- w = w.invert()
- w = w * A
- w = -w
- e = A + w
- e = e * w
- e = e + fe(1)
- e = e * w
- e = chi(e)
- u = fe(1) - e
- u = u * half_A
- w = e * w
- u = w - u
- return u
-
# entire key generation chain
def full_cycle_check(scalar, u):
fe(scalar).print()
h = curve_to_hash(uv)
if h.is_negative(): raise ValueError('Non Canonical representative')
fh = fast_curve_to_hash(uv)
- eh = explicit_curve_to_hash(uv)
if fh != h: raise ValueError('Incorrect fast_curve_to_hash()')
- if eh != h: raise ValueError('Incorrect explicit_curve_to_hash()')
print('01:') # Success
h.print() # actual value for the hash
c = hash_to_curve(h)
- u = explicit_hash_to_curve(h)
- if u != c[0]: raise ValueError('Incorrect explicit_hash_to_curve()')
+ f = fast_hash_to_curve(h)
+ if f != c : raise ValueError('Incorrect fast_hash_to_curve()')
if c != uv : raise ValueError('Round trip failure')
else:
fh = fast_curve_to_hash(uv)
- eh = explicit_curve_to_hash(uv)
if not (fh is None): raise ValueError('Fast Curve to Hash did not fail')
- if not (eh is None): raise ValueError('Explicit Curve to Hash did not fail')
print('00:') # Failure
print('00:') # dummy value for the hash