From: Loup Vaillant Date: Wed, 19 Feb 2020 20:30:47 +0000 (+0100) Subject: Added the fe (field element) type for readability X-Git-Url: https://git.codecow.com/?a=commitdiff_plain;h=a57d39eb28f8849076b8475cdab3cfc929a4704a;p=Monocypher.git Added the fe (field element) type for readability Having to write those modulo operators everywhere was tiresome. Having an explicit field element type allows a more direct writing. It also helps Python throw type errors if we misuse anything. --- diff --git a/tests/gen/elligator.py b/tests/gen/elligator.py index 7634bf2..5c47222 100755 --- a/tests/gen/elligator.py +++ b/tests/gen/elligator.py @@ -51,74 +51,99 @@ # with this software. If not, see # +class fe: + """Prime field over 2^255 - 19""" + p = 2**255 - 19 + def __init__(self, x): + self.val = x % self.p + + # Basic arithmetic operations + def __neg__ (self ): return fe(-self.val ) + def __add__ (self, o): return fe( self.val + o.val ) + def __sub__ (self, o): return fe( self.val - o.val ) + def __mul__ (self, o): return fe((self.val * o.val ) % self.p) + def __truediv__ (self, o): return fe((self.val * o.invert().val) % self.p) + def __floordiv__(self, o): return fe( self.val // o ) + def __pow__ (self, s): return fe(pow(self.val, s , self.p)) + def invert (self ): return fe(pow(self.val, self.p-2, self.p)) + + def __eq__(self, other): return self.val % self.p == other.val % self.p + def __ne__(self, other): return self.val % self.p != other.val % self.p + def isPositive(self) : return self.val % self.p <= (p-1) // 2 + + def abs(self): + if self.isPositive(): return self + else : return -self + + def print(self): + """prints a field element in little endian""" + m = self.val % self.p + for _ in range(32): + print(format(m % 256, '02x'), end='') + m //= 256 + if m != 0: raise ValueError('number is too big!!') + print(':') + # Curve25519 constants -p = 2**255 - 19 # prime field (note that p % 8 == 5) -A = 486662 +p = fe.p +A = fe(486662) +# chosen non-square: 2 # B = 1 -# chosen non-square = 2 - -def print_little(n): - """prints a field element in little endian""" - m = n % p - for _ in range(32): - print(format(m % 256, '02x'), end='') - m //= 256 - if m != 0: raise ValueError('number is too big!!') - print(':') - -def binary(b): - return [int(c) for c in list(format(b, 'b'))] - -def exp(a, b): return pow(a, b, p) -def invert(n): return exp(n, p-2) -def m_abs(n): - """Modular absolute value of n, to canonicalise square roots.""" - m = n%p - if m <= (p-1) // 2: return m - else : return -m % p +def chi (n): return n**((p-1)//2) +def is_square(n): return n == fe(0) or chi(n) == fe(1) -def chi (n): return exp(n, (p-1)//2) -def is_square(n): return n%p == 0 or chi(n) == 1 - -sqrt1 = m_abs(exp(2, (p-1) // 4) * exp(-1, (p+3) // 8)) +sqrt1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs() def sqrt(n): if not(is_square(n)) : raise ValueError('Not a square!') - root = exp(n, (p+3) // 8) - if (root * root) % p == n % p : return m_abs(root) - else : return m_abs(root * sqrt1) - + root = n**((p+3) // 8) + if root * root != n: root = (root * sqrt1) + if root * root != n: raise ValueError('Should be a square!!') + return root.abs() # Elligator 2 def hash_to_curve(r): - w = (-A * invert(1 + 2 * r**2) ) % p - e = (chi(w**3 + A*w**2 + w) ) % p - u = (e*w - (1-e)*A//2 ) % p - v = (-e * sqrt(u**3 + A*u**2 + u)) % p + w = -A / (fe(1) + fe(2) * r**2) + e = chi(w**3 + A*w**2 + w) + u = e*w - (fe(1)-e)*(A//2) + v = -e * sqrt(u**3 + A*u**2 + u) return (u, v) def can_curve_to_hash(point): x = point[0] - return x != -A and is_square(-2 * x * (x + A)) + return x != -A and is_square(-fe(2) * x * (x + A)) def curve_to_hash(point): + if not can_curve_to_hash(point): + raise ValueError('cannot curve to hash') u = point[0] v = point[1] - sq1 = sqrt(-u * invert(2 * (u+A))) - sq2 = sqrt(-(u+A) * invert(2 * u)) - if v % p <= (p-1) // 2: return sq1 - else : return sq2 + sq1 = sqrt(-u / (fe(2) * (u+A))) + sq2 = sqrt(-(u+A) / (fe(2) * u )) + if v.isPositive(): return sq1 + else : return sq2 + +# round trip test +for i in range(50): + h = fe(1234567890 * i).invert() # "random" hash + pp = hash_to_curve(h) + hh = curve_to_hash(pp) + ppp = hash_to_curve(hh) + if hh != h.abs() : raise ValueError('h != hh') + if pp != ppp : raise ValueError('pp != ppp') + # Edwards (Edwards25519) # -x^2 + y^2 = 1 + d*x^2*y^2 -d = (-121665 * invert(121666)) % p +d = fe(-121665) / fe(121666) def point_add(a, b): - x1 = a[0]; y1 = a[1]; - x2 = b[0]; y2 = b[1]; - x = ((x1*y2 + x2*y1) * invert(1 + d*x1*x2*y1*y2)) % p - y = ((y1*y2 + x1*x2) * invert(1 - d*x1*x2*y1*y2)) % p + x1 = a[0]; y1 = a[1]; + x2 = b[0]; y2 = b[1]; + denum = d*x1*x2*y1*y2 + x = (x1*y2 + x2*y1) / (fe(1) + denum) + y = (y1*y2 + x1*x2) / (fe(1) - denum) return (x, y) def trim(scalar): @@ -128,15 +153,17 @@ def trim(scalar): return trimmed def scalarmult(point, scalar): - acc = (0, 1) - for i in binary(trim(scalar)): + acc = (fe(0), fe(1)) + trimmed = trim(scalar) + binary = [int(c) for c in list(format(trimmed, 'b'))] + for i in binary: acc = point_add(acc, acc) if i == 1: acc = point_add(acc, point) return acc -eby = (4 * invert(5)) % p -ebx = sqrt((eby**2 - 1) * invert(1 + d * eby**2)) +eby = fe(4) / fe(5) +ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2)) edwards_base = (ebx, eby) def scalarbase(scalar): @@ -148,31 +175,31 @@ def scalarbase(scalar): def from_edwards(point): x = point[0] y = point[1] - u = ((1 + y) * invert(1 - y)) % p - v = m_abs(sqrt(-486664) * u * invert(x)) + u = (fe(1) + y) / (fe(1) - y) + v = (sqrt(fe(-486664)) * u / x).abs() return (u, v) # entire key generation chain def private_to_hash(scalar): - xy = scalarbase(private) + xy = scalarbase(scalar) uv = from_edwards(xy) if can_curve_to_hash(uv): return curve_to_hash(uv) return None def full_cycle_check(scalar): - print_little(scalar) - xy = scalarbase(private) + fe(scalar).print() + xy = scalarbase(scalar) uv = from_edwards(xy) h = private_to_hash(scalar) - print_little(uv[0]) - print_little(uv[1]) - if h == None: + uv[0].print() + uv[1].print() + if h is None: print('00:') # Failure print('00:') # dummy value for the hash else: print('01:') # Success - print_little(h) # actual value for the hash + h.print() # actual value for the hash c = hash_to_curve(h) if c != uv: print('Round trip failure')