]> git.codecow.com Git - Monocypher.git/commitdiff
Added the fe (field element) type for readability
authorLoup Vaillant <loup@loup-vaillant.fr>
Wed, 19 Feb 2020 20:30:47 +0000 (21:30 +0100)
committerLoup Vaillant <loup@loup-vaillant.fr>
Wed, 19 Feb 2020 20:30:47 +0000 (21:30 +0100)
Having to write those modulo operators everywhere was tiresome. Having
an explicit field element type allows a more direct writing. It also
helps Python throw type errors if we misuse anything.

tests/gen/elligator.py

index 7634bf2e8e6bc576ba96b1144b27f2bd78bb7c8b..5c47222644e0cf3336520c1951d0294f05891d87 100755 (executable)
 # with this software.  If not, see
 # <https://creativecommons.org/publicdomain/zero/1.0/>
 
+class fe:
+    """Prime field over 2^255 - 19"""
+    p = 2**255 - 19
+    def __init__(self, x):
+        self.val = x % self.p
+
+    # Basic arithmetic operations
+    def __neg__     (self   ): return fe(-self.val                            )
+    def __add__     (self, o): return fe( self.val +  o.val                   )
+    def __sub__     (self, o): return fe( self.val -  o.val                   )
+    def __mul__     (self, o): return fe((self.val *  o.val         ) % self.p)
+    def __truediv__ (self, o): return fe((self.val *  o.invert().val) % self.p)
+    def __floordiv__(self, o): return fe( self.val // o                       )
+    def __pow__     (self, s): return fe(pow(self.val, s       , self.p))
+    def invert      (self   ): return fe(pow(self.val, self.p-2, self.p))
+
+    def __eq__(self, other): return self.val % self.p == other.val % self.p
+    def __ne__(self, other): return self.val % self.p != other.val % self.p
+    def isPositive(self)   : return self.val % self.p <= (p-1) // 2
+
+    def abs(self):
+        if self.isPositive(): return  self
+        else                : return -self
+
+    def print(self):
+        """prints a field element in little endian"""
+        m = self.val % self.p
+        for _ in range(32):
+            print(format(m % 256, '02x'), end='')
+            m //= 256
+        if m != 0: raise ValueError('number is too big!!')
+        print(':')
+
 # Curve25519 constants
-p = 2**255 - 19 # prime field (note that p % 8 == 5)
-A = 486662
+p = fe.p
+A = fe(486662)
+# chosen non-square: 2
 # B = 1
-# chosen non-square = 2
-
-def print_little(n):
-    """prints a field element in little endian"""
-    m = n % p
-    for _ in range(32):
-        print(format(m % 256, '02x'), end='')
-        m //= 256
-    if m != 0: raise ValueError('number is too big!!')
-    print(':')
-
-def binary(b):
-    return [int(c) for c in list(format(b, 'b'))]
-
-def exp(a, b): return pow(a, b, p)
-def invert(n): return exp(n, p-2)
 
-def m_abs(n):
-    """Modular absolute value of n, to canonicalise square roots."""
-    m = n%p
-    if m <= (p-1) // 2: return m
-    else              : return -m % p
+def chi      (n): return n**((p-1)//2)
+def is_square(n): return n == fe(0) or chi(n) == fe(1)
 
-def chi      (n): return exp(n, (p-1)//2)
-def is_square(n): return n%p == 0 or chi(n) == 1
-
-sqrt1 = m_abs(exp(2, (p-1) // 4) * exp(-1, (p+3) // 8))
+sqrt1 = ((fe(2)**((p-1) // 4)) * fe(-1)**((p+3) // 8)).abs()
 
 def sqrt(n):
     if not(is_square(n)) : raise ValueError('Not a square!')
-    root = exp(n, (p+3) // 8)
-    if (root * root) % p == n % p : return m_abs(root)
-    else                          : return m_abs(root * sqrt1)
-
+    root = n**((p+3) // 8)
+    if root * root != n: root = (root * sqrt1)
+    if root * root != n: raise ValueError('Should be a square!!')
+    return root.abs()
 
 # Elligator 2
 def hash_to_curve(r):
-    w = (-A * invert(1 + 2 * r**2)   ) % p
-    e = (chi(w**3 + A*w**2 + w)      ) % p
-    u = (e*w - (1-e)*A//2            ) % p
-    v = (-e * sqrt(u**3 + A*u**2 + u)) % p
+    w = -A / (fe(1) + fe(2) * r**2)
+    e = chi(w**3 + A*w**2 + w)
+    u = e*w - (fe(1)-e)*(A//2)
+    v = -e * sqrt(u**3 + A*u**2 + u)
     return (u, v)
 
 def can_curve_to_hash(point):
     x = point[0]
-    return x != -A and is_square(-2 * x * (x + A))
+    return x != -A and is_square(-fe(2) * x * (x + A))
 
 def curve_to_hash(point):
+    if not can_curve_to_hash(point):
+        raise ValueError('cannot curve to hash')
     u   = point[0]
     v   = point[1]
-    sq1 = sqrt(-u * invert(2 * (u+A)))
-    sq2 = sqrt(-(u+A) * invert(2 * u))
-    if v % p <= (p-1) // 2: return sq1
-    else                  : return sq2
+    sq1 = sqrt(-u     / (fe(2) * (u+A)))
+    sq2 = sqrt(-(u+A) / (fe(2) * u    ))
+    if v.isPositive(): return sq1
+    else             : return sq2
+
+# round trip test
+for i in range(50):
+    h   = fe(1234567890 * i).invert() # "random" hash
+    pp  = hash_to_curve(h)
+    hh  = curve_to_hash(pp)
+    ppp = hash_to_curve(hh)
+    if hh != h.abs() : raise ValueError('h != hh')
+    if pp != ppp     : raise ValueError('pp != ppp')
+
 
 # Edwards (Edwards25519)
 # -x^2 + y^2 = 1 + d*x^2*y^2
-d = (-121665 * invert(121666)) % p
+d = fe(-121665) / fe(121666)
 
 def point_add(a, b):
-    x1 = a[0]; y1 = a[1];
-    x2 = b[0]; y2 = b[1];
-    x  = ((x1*y2 + x2*y1) * invert(1 + d*x1*x2*y1*y2)) % p
-    y  = ((y1*y2 + x1*x2) * invert(1 - d*x1*x2*y1*y2)) % p
+    x1    = a[0];  y1 = a[1];
+    x2    = b[0];  y2 = b[1];
+    denum = d*x1*x2*y1*y2
+    x     = (x1*y2 + x2*y1) / (fe(1) + denum)
+    y     = (y1*y2 + x1*x2) / (fe(1) - denum)
     return (x, y)
 
 def trim(scalar):
@@ -128,15 +153,17 @@ def trim(scalar):
     return trimmed
 
 def scalarmult(point, scalar):
-    acc = (0, 1)
-    for i in binary(trim(scalar)):
+    acc     = (fe(0), fe(1))
+    trimmed = trim(scalar)
+    binary  = [int(c) for c in list(format(trimmed, 'b'))]
+    for i in binary:
         acc = point_add(acc, acc)
         if i == 1:
             acc = point_add(acc, point)
     return acc
 
-eby = (4 * invert(5)) % p
-ebx = sqrt((eby**2 - 1) * invert(1 + d * eby**2))
+eby = fe(4) / fe(5)
+ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2))
 edwards_base = (ebx, eby)
 
 def scalarbase(scalar):
@@ -148,31 +175,31 @@ def scalarbase(scalar):
 def from_edwards(point):
     x = point[0]
     y = point[1]
-    u = ((1 + y) * invert(1 - y)) % p
-    v = m_abs(sqrt(-486664) * u * invert(x))
+    u = (fe(1) + y) / (fe(1) - y)
+    v = (sqrt(fe(-486664)) * u / x).abs()
     return (u, v)
 
 # entire key generation chain
 def private_to_hash(scalar):
-    xy = scalarbase(private)
+    xy = scalarbase(scalar)
     uv = from_edwards(xy)
     if can_curve_to_hash(uv):
         return curve_to_hash(uv)
     return None
 
 def full_cycle_check(scalar):
-    print_little(scalar)
-    xy = scalarbase(private)
+    fe(scalar).print()
+    xy = scalarbase(scalar)
     uv = from_edwards(xy)
     h  = private_to_hash(scalar)
-    print_little(uv[0])
-    print_little(uv[1])
-    if h == None:
+    uv[0].print()
+    uv[1].print()
+    if h is None:
         print('00:')    # Failure
         print('00:')    # dummy value for the hash
     else:
         print('01:')    # Success
-        print_little(h) # actual value for the hash
+        h.print()       # actual value for the hash
         c = hash_to_curve(h)
         if c != uv:
             print('Round trip failure')