# with this software. If not, see
# <https://creativecommons.org/publicdomain/zero/1.0/>
+class fe:
+ """Prime field over 2^255 - 19"""
+ p = 2**255 - 19
+ def __init__(self, x):
+ self.val = x % self.p
+
+ # Basic arithmetic operations
+ def __neg__ (self ): return fe(-self.val )
+ def __add__ (self, o): return fe( self.val + o.val )
+ def __sub__ (self, o): return fe( self.val - o.val )
+ def __mul__ (self, o): return fe((self.val * o.val ) % self.p)
+ def __truediv__ (self, o): return fe((self.val * o.invert().val) % self.p)
+ def __floordiv__(self, o): return fe( self.val // o )
+ def __pow__ (self, s): return fe(pow(self.val, s , self.p))
+ def invert (self ): return fe(pow(self.val, self.p-2, self.p))
+
+ def __eq__(self, other): return self.val % self.p == other.val % self.p
+ def __ne__(self, other): return self.val % self.p != other.val % self.p
+ def isPositive(self) : return self.val % self.p <= (p-1) // 2
+
+ def abs(self):
+ if self.isPositive(): return self
+ else : return -self
+
+ def print(self):
+ """prints a field element in little endian"""
+ m = self.val % self.p
+ for _ in range(32):
+ print(format(m % 256, '02x'), end='')
+ m //= 256
+ if m != 0: raise ValueError('number is too big!!')
+ print(':')
+
# Curve25519 constants
-p = 2**255 - 19 # prime field (note that p % 8 == 5)
-A = 486662
+p = fe.p
+A = fe(486662)
+# chosen non-square: 2
# B = 1
-# chosen non-square = 2
-
-def print_little(n):
- """prints a field element in little endian"""
- m = n % p
- for _ in range(32):
- print(format(m % 256, '02x'), end='')
- m //= 256
- if m != 0: raise ValueError('number is too big!!')
- print(':')
-
-def binary(b):
- return [int(c) for c in list(format(b, 'b'))]
-
-def exp(a, b): return pow(a, b, p)
-def invert(n): return exp(n, p-2)
-def m_abs(n):
- """Modular absolute value of n, to canonicalise square roots."""
- m = n%p
- if m <= (p-1) // 2: return m
- else : return -m % p
+def chi (n): return n**((p-1)//2)
+def is_square(n): return n == fe(0) or chi(n) == fe(1)
-def chi (n): return exp(n, (p-1)//2)
-def is_square(n): return n%p == 0 or chi(n) == 1
-
-sqrt1 = m_abs(exp(2, (p-1) // 4) * exp(-1, (p+3) // 8))
+sqrt1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs()
def sqrt(n):
if not(is_square(n)) : raise ValueError('Not a square!')
- root = exp(n, (p+3) // 8)
- if (root * root) % p == n % p : return m_abs(root)
- else : return m_abs(root * sqrt1)
-
+ root = n**((p+3) // 8)
+ if root * root != n: root = (root * sqrt1)
+ if root * root != n: raise ValueError('Should be a square!!')
+ return root.abs()
# Elligator 2
def hash_to_curve(r):
- w = (-A * invert(1 + 2 * r**2) ) % p
- e = (chi(w**3 + A*w**2 + w) ) % p
- u = (e*w - (1-e)*A//2 ) % p
- v = (-e * sqrt(u**3 + A*u**2 + u)) % p
+ w = -A / (fe(1) + fe(2) * r**2)
+ e = chi(w**3 + A*w**2 + w)
+ u = e*w - (fe(1)-e)*(A//2)
+ v = -e * sqrt(u**3 + A*u**2 + u)
return (u, v)
def can_curve_to_hash(point):
x = point[0]
- return x != -A and is_square(-2 * x * (x + A))
+ return x != -A and is_square(-fe(2) * x * (x + A))
def curve_to_hash(point):
+ if not can_curve_to_hash(point):
+ raise ValueError('cannot curve to hash')
u = point[0]
v = point[1]
- sq1 = sqrt(-u * invert(2 * (u+A)))
- sq2 = sqrt(-(u+A) * invert(2 * u))
- if v % p <= (p-1) // 2: return sq1
- else : return sq2
+ sq1 = sqrt(-u / (fe(2) * (u+A)))
+ sq2 = sqrt(-(u+A) / (fe(2) * u ))
+ if v.isPositive(): return sq1
+ else : return sq2
+
+# round trip test
+for i in range(50):
+ h = fe(1234567890 * i).invert() # "random" hash
+ pp = hash_to_curve(h)
+ hh = curve_to_hash(pp)
+ ppp = hash_to_curve(hh)
+ if hh != h.abs() : raise ValueError('h != hh')
+ if pp != ppp : raise ValueError('pp != ppp')
+
# Edwards (Edwards25519)
# -x^2 + y^2 = 1 + d*x^2*y^2
-d = (-121665 * invert(121666)) % p
+d = fe(-121665) / fe(121666)
def point_add(a, b):
- x1 = a[0]; y1 = a[1];
- x2 = b[0]; y2 = b[1];
- x = ((x1*y2 + x2*y1) * invert(1 + d*x1*x2*y1*y2)) % p
- y = ((y1*y2 + x1*x2) * invert(1 - d*x1*x2*y1*y2)) % p
+ x1 = a[0]; y1 = a[1];
+ x2 = b[0]; y2 = b[1];
+ denum = d*x1*x2*y1*y2
+ x = (x1*y2 + x2*y1) / (fe(1) + denum)
+ y = (y1*y2 + x1*x2) / (fe(1) - denum)
return (x, y)
def trim(scalar):
return trimmed
def scalarmult(point, scalar):
- acc = (0, 1)
- for i in binary(trim(scalar)):
+ acc = (fe(0), fe(1))
+ trimmed = trim(scalar)
+ binary = [int(c) for c in list(format(trimmed, 'b'))]
+ for i in binary:
acc = point_add(acc, acc)
if i == 1:
acc = point_add(acc, point)
return acc
-eby = (4 * invert(5)) % p
-ebx = sqrt((eby**2 - 1) * invert(1 + d * eby**2))
+eby = fe(4) / fe(5)
+ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2))
edwards_base = (ebx, eby)
def scalarbase(scalar):
def from_edwards(point):
x = point[0]
y = point[1]
- u = ((1 + y) * invert(1 - y)) % p
- v = m_abs(sqrt(-486664) * u * invert(x))
+ u = (fe(1) + y) / (fe(1) - y)
+ v = (sqrt(fe(-486664)) * u / x).abs()
return (u, v)
# entire key generation chain
def private_to_hash(scalar):
- xy = scalarbase(private)
+ xy = scalarbase(scalar)
uv = from_edwards(xy)
if can_curve_to_hash(uv):
return curve_to_hash(uv)
return None
def full_cycle_check(scalar):
- print_little(scalar)
- xy = scalarbase(private)
+ fe(scalar).print()
+ xy = scalarbase(scalar)
uv = from_edwards(xy)
h = private_to_hash(scalar)
- print_little(uv[0])
- print_little(uv[1])
- if h == None:
+ uv[0].print()
+ uv[1].print()
+ if h is None:
print('00:') # Failure
print('00:') # dummy value for the hash
else:
print('01:') # Success
- print_little(h) # actual value for the hash
+ h.print() # actual value for the hash
c = hash_to_curve(h)
if c != uv:
print('Round trip failure')