]> git.codecow.com Git - Monocypher.git/commitdiff
Added Elligator2 inverse mapping
authorLoup Vaillant <loup@loup-vaillant.fr>
Mon, 16 Mar 2020 12:13:06 +0000 (13:13 +0100)
committerLoup Vaillant <loup@loup-vaillant.fr>
Mon, 16 Mar 2020 12:13:06 +0000 (13:13 +0100)
src/monocypher.c
src/monocypher.h
tests/gen/elligator-direct.py
tests/gen/elligator-inverse.py
tests/gen/elligator.py [changed mode: 0755->0644]
tests/gen/elligator_scalarmult.py [new file with mode: 0644]
tests/gen/makefile
tests/test.c

index 686ea7abfe2095aa5c85538ecdae515a70392ac3..545f1dfd4313305e92697e980a02876611f03aac 100644 (file)
@@ -1316,20 +1316,6 @@ static int fe_isodd(const fe f)
     return isodd;
 }
 
-// Returns 0 if f <= (p-1)/2, 1 otherwise.
-// "Positive" means between 0 and (p-1)/2
-// "Negative" means between (p+1)/2 and p-1
-// Since p is odd (2^255 - 19), the sign is easily tested by leveraging
-// overflow: for any f in [0..p[, (2*f)%p is odd iff 2*f > p
-static int fe_isnegative(const fe f)
-{
-    fe tmp;
-    fe_add(tmp, f, f);
-    int isneg = fe_isodd(tmp);
-    WIPE_BUFFER(tmp);
-    return isneg;
-}
-
 // Returns 0 if zero, 1 if non zero
 static int fe_isnonzero(const fe f)
 {
@@ -1617,12 +1603,13 @@ static int ge_frombytes_vartime(ge *h, const u8 s[32])
     return 0;
 }
 
+static const fe D2 = { // - 2 * 121665 / 121666
+    -21827239, -5839606, -30745221, 13898782, 229458,
+    15978800, -12551817, -6495438, 29715968, 9444199
+};
+
 static void ge_cache(ge_cached *c, const ge *p)
 {
-    static const fe D2 = { // - 2 * 121665 / 121666
-        -21827239, -5839606, -30745221, 13898782, 229458,
-        15978800, -12551817, -6495438, 29715968, 9444199
-    };
     fe_add (c->Yp, p->Y, p->X);
     fe_sub (c->Ym, p->Y, p->X);
     fe_copy(c->Z , p->Z      );
@@ -2215,6 +2202,8 @@ int crypto_check(const u8  signature[64],
 /// Elligator 2 ///
 ///////////////////
 
+static const fe A = {486662};
+
 // From the paper:
 // w = -A / (fe(1) + non_square * r^2)
 // e = chi(w^3 + A*w^2 + w)
@@ -2273,7 +2262,6 @@ void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32])
         -1917299, 15887451, -18755900, -7000830, -24778944,
         544946, -16816446, 4011309, -653372, 10741468,
     };
-    static const fe A   = {486662, 0, 0, 0, 0, 0, 0, 0, 0, 0};
     static const fe A2  = {12721188, 3529, 0, 0, 0, 0, 0, 0, 0, 0};
     static const fe one = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0};
 
@@ -2312,12 +2300,129 @@ void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32])
     WIPE_BUFFER(t3);  WIPE_BUFFER(clamped);
 }
 
-int crypto_curve_to_hash(uint8_t hash[32], const uint8_t curve[32])
+
+// Compute the representative of a point (defined by the secret key and
+// tweak), if possible. If not it does nothing and returns -1
+// The tweak comprises 3 parts:
+// - Bits 4-5: random padding
+// - Bit  3  : sign of the v coordinate (0 if positive, 1 if negative)
+// - Bits 0-2: cofactor
+// The bits 6-7 are ignored.
+//
+// Note that to ensure the representative is fully random, we do *not*
+// clear the cofactor.
+int crypto_elligator2_inverse(u8 hash[32], const u8 secret_key [32], u8 tweak)
 {
+    static const fe lop_x = {
+        21352778, 5345713, 4660180, -8347857, 24143090,
+        14568123, 30185756, -12247770, -33528939, 8345319,
+    };
+    static const fe lop_y = {
+        -6952922, -1265500, 6862341, -7057498, -4037696,
+        -5447722, 31680899, -15325402, -19365852, 1569102,
+    };
+
+    u8 scalar[32];
     FOR (i, 0, 32) {
-        hash[i] = curve[i];
+        scalar[i] = secret_key[i];
+    }
+    trim_scalar(scalar);
+    ge pk;
+    ge_scalarmult_base(&pk, scalar);
+
+    // Select low order point
+    // We're computing the [cofactor]lop scalar multiplication, where:
+    //   cofactor = tweak & 7.
+    //   lop      = (lop_x, lop_y)
+    //   lop_x    = sqrt((sqrt(d + 1) + 1) / d)
+    //   lop_y    = -lop_x * sqrtm1
+    // Notes:
+    // - A (single) Montgomery ladder would be twice as slow.
+    // - An actual scalar multiplication would hurt performance.
+    // - A full table lookup would take more code.
+    int a = (tweak >> 2) & 1;
+    int b = (tweak >> 1) & 1;
+    int c = (tweak >> 0) & 1;
+    fe t1, t2, t3;
+    fe_0(t1);
+    fe_ccopy(t1, sqrtm1, b);
+    fe_ccopy(t1, lop_x , c);
+    fe_neg  (t3, t1);
+    fe_ccopy(t1, t3, a);
+    fe_1(t2);
+    fe_0(t3);
+    fe_ccopy(t2, t3   , b);
+    fe_ccopy(t2, lop_y, c);
+    fe_neg  (t3, t2);
+    fe_ccopy(t2, t3, a^b);
+    ge_precomp low_order_point;
+    fe_add(low_order_point.Yp, t2, t1);
+    fe_sub(low_order_point.Ym, t2, t1);
+    fe_mul(low_order_point.T2, t2, t1);
+    fe_mul(low_order_point.T2, low_order_point.T2, D2);
+
+    // Add low order point to the public key
+    ge_madd(&pk, &pk, &low_order_point, t1, t2);
+
+    // Convert to Montgomery u coordinate (we ignore the sign)
+    fe_add(t1, pk.Z, pk.Y);
+    fe_sub(t2, pk.Z, pk.Y);
+    fe_invert(t2, t2);
+    fe_mul(t1, t1, t2);
+
+    // Convert to representative
+    // From the paper:
+    // Let sq = -non_square * u * (u+A)
+    // if sq is not a square, or u = -A, there is no mapping
+    // Assuming there is a mapping:
+    //   if v is positive: r = sqrt(-(u+A) / u)
+    //   if v is negative: r = sqrt(-u / (u+A))
+    //
+    // We compute isr = invsqrt(-non_square * u * (u+A))
+    // if it wasn't a non-zero square, abort.
+    // else, isr = sqrt(-1 / (non_square * u * (u+A))
+    //
+    // This causes us to abort if u is zero, even though we shouldn't. This
+    // never happens in practice, because (i) a random point in the curve has
+    // a negligible chance of being zero, and (ii) scalar multiplication with
+    // a trimmed scalar *never* yields zero.
+    //
+    // Since:
+    //   isr * (u+A) = sqrt(-1     / (non_square * u * (u+A)) * (u+A)
+    //   isr * (u+A) = sqrt(-(u+A) / (non_square * u * (u+A))
+    // and:
+    //   isr = u = sqrt(-1 / (non_square * u * (u+A)) * u
+    //   isr = u = sqrt(-u / (non_square * u * (u+A))
+    // Therefore:
+    //   if v is positive: r = isr * (u+A)
+    //   if v is negative: r = isr * u
+    fe_add(t2, t1, A);
+    fe_mul(t3, t1, t2);
+    fe_mul_small(t3, t3, -2);
+    int is_square = invsqrt(t3, t3);
+    if (!is_square) {
+        // The only variable time bit.  This ultimately reveals how many
+        // tries it took us to find a representable key.
+        // This does not affect security as long as we try keys at random.
+        WIPE_BUFFER(t1);  WIPE_BUFFER(scalar);
+        WIPE_BUFFER(t2);  WIPE_CTX(&pk);
+        WIPE_BUFFER(t3);  WIPE_CTX(&low_order_point);
+        return -1;
     }
-    return -1;
+    fe_ccopy(t1, t2, (tweak >> 3) & 1);
+    fe_mul  (t3, t1, t3);
+    fe_add  (t1, t3, t3);
+    fe_neg  (t2, t3);
+    fe_ccopy(t3, t2, fe_isodd(t1));
+    fe_tobytes(hash, t3);
+
+    // Pad with two random bits
+    hash[31] |= (tweak << 2) & 0xc0;
+
+    WIPE_BUFFER(t1);  WIPE_BUFFER(scalar);
+    WIPE_BUFFER(t2);  WIPE_CTX(&pk);
+    WIPE_BUFFER(t3);  WIPE_CTX(&low_order_point);
+    return 0;
 }
 
 ////////////////////
index 24d65fa7245829bea8183c8430f0eb70844dcb3f..0b76939a1f2805e7748bd09560fc293682583542 100644 (file)
@@ -254,8 +254,9 @@ void crypto_check_init_custom_hash(crypto_check_ctx_abstract *ctx,
 
 // Elligator 2
 // -----------
-void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash [32]);
-
+void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32]);
+int crypto_elligator2_inverse(uint8_t hash[32], const uint8_t secret_key[32],
+                              uint8_t tweak);
 
 ////////////////////////////
 /// Low level primitives ///
index 72a43373da3fb5de242eee4e68a8e30022f372bd..7cc12b549c9f231062ed60727acd9e060eddb1b9 100755 (executable)
@@ -61,7 +61,7 @@ from random    import randrange
 def direct(r1):
     q1 = hash_to_curve(r1)
     q2 = fast_hash_to_curve(r1)
-    r2 = curve_to_hash(q1)
+    r2 = curve_to_hash(q1[0], q1[1].is_negative())
     if q1 != q2: raise ValueError('Incorrect fast_hash_to_curve')
     if r1 != r2: raise ValueError('Round trip failure')
     r1   .print()
index a9306baddf939f6c09b62dd24ead881ff3a2b174..361c0852135f3ecf837deb0a6fe1fb4471412e65 100755 (executable)
 # <https://creativecommons.org/publicdomain/zero/1.0/>
 
 from elligator import fe
-from elligator import x25519_public_key
 from elligator import can_curve_to_hash
 from elligator import curve_to_hash
 from elligator import fast_curve_to_hash
 from elligator import hash_to_curve
-from elligator import fast_hash_to_curve
-from sys       import stdin
 
-# Test a full round trip, and print the relevant test vectors
-def full_cycle_check(private_key, u):
-    fe(private_key).print()
-    uv = x25519_public_key(private_key)
-    if uv [0] != u: raise ValueError('Test vector failure')
-    uv[0].print()
-    uv[1].print()
-    if can_curve_to_hash(uv):
-        h  = curve_to_hash(uv)
-        if h.is_negative(): raise ValueError('Non Canonical representative')
-        fh = fast_curve_to_hash(uv)
-        if fh != h: raise ValueError('Incorrect fast_curve_to_hash()')
-        print('01:')    # Success
-        h.print()       # actual value for the hash
-        c = hash_to_curve(h)
-        f = fast_hash_to_curve(h)
-        if f != c   : raise ValueError('Incorrect fast_hash_to_curve()')
-        if c != uv  : raise ValueError('Round trip failure')
-    else:
-        fh = fast_curve_to_hash(uv)
-        if not (fh is None): raise ValueError('Fast Curve to Hash did not fail')
-        print('00:')    # Failure
-        print('00:')    # dummy value for the hash
+from elligator_scalarmult import scalarmult
+from elligator_scalarmult import print_scalar
 
-# read test vectors:
-def read_vector(vector): # vector: little endian hex number
-    cut = vector[:64]    # remove final ':' character
-    acc = 0              # final sum
-    pos = 1              # power of 256
-    for b in bytes.fromhex(cut):
-        acc += b * pos
-        pos *= 256
-    return acc
+from random import randrange
 
-def read_test_vectors():
-    vectors = []
-    lines = [x.strip() for x in stdin.readlines() if x.strip()]
-    for i in range(len(lines) // 2):
-        private = read_vector(lines[i*2    ])
-        public  = read_vector(lines[i*2 + 1])
-        vectors.append((private, fe(public)))
-    return vectors
+def private_to_hash(scalar, tweak):
+    cofactor      = tweak % 8     ; tweak = tweak // 8
+    v_is_negative = tweak % 2 == 1; tweak = tweak // 2
+    msb           = tweak * 2**254
+    u             = scalarmult(scalar, cofactor)
+    r1 = None
+    if can_curve_to_hash(u):
+        r1 = curve_to_hash(u, v_is_negative)
+    r2 = fast_curve_to_hash(u, v_is_negative)
+    if r1 != r2: raise ValueError('Incoherent hash_to_curve')
+    if r1 is None:
+        return None
+    if r1.val > 2**254: raise ValueError('Representative too big')
+    u2, v2 = hash_to_curve(r1)
+    if u2 != u: raise ValueError('Round trip failure')
+    return r1.val + msb
 
-vectors = read_test_vectors()
-for v in vectors:
-    private = v[0]
-    public  = v[1]
-    full_cycle_check(private, public)
-    print('')
+# All possible failures
+for tweak in range(2**4):
+    while True:
+        scalar = randrange(0, 2**256)
+        r      = private_to_hash(scalar, tweak)
+        if r is None:
+            print_scalar(scalar)
+            print(format(tweak, '02x') + ":")
+            print('ff:') # Failure
+            print('00:') # dummy value for the hash
+            print()
+            break
+
+# All possible successes
+for tweak in range(2**6):
+    while True:
+        scalar = randrange(0, 2**256)
+        r      = private_to_hash(scalar, tweak)
+        if r is not None:
+            print_scalar(scalar)
+            print(format(tweak, '02x') + ":")
+            print('00:') # Success
+            print_scalar(r)
+            print()
+            break
old mode 100755 (executable)
new mode 100644 (file)
index 4fef6bf..e7a28d2
@@ -137,81 +137,6 @@ def sqrt(n):
     if root * root != n: raise ValueError('Should be a square!!')
     return root.abs()
 
-#########################
-# scalar multiplication #
-#########################
-# Clamp the scalar.
-# % 8 stops subgroup attacks
-# Clearing bit 255 and setting bit 254 facilitates constant time ladders.
-def trim(scalar):
-    trimmed = scalar - scalar % 8
-    trimmed = trimmed % 2**254
-    trimmed = trimmed + 2**254
-    return trimmed
-
-# Edwards25519 equation (d defined below):
-# -x^2 + y^2 = 1 + d*x^2*y^2
-d = fe(-121665) / fe(121666)
-
-# Point addition:
-# denum = d*x1*x2*y1*y2
-# x     = (x1*y2 + x2*y1) / (1 + denum)
-# y     = (y1*y2 + x1*x2) / (1 - denum)
-# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z.
-# We can multiply Z instead of dividing X and Y.
-def point_add(a, b):
-    x1, y1, z1 = a
-    x2, y2, z2 = b
-    denum = d*x1*x2*y1*y2
-    z1z2  = z1 * z2
-    z1z22 = z1z2**2
-    xt    = z1z2 * (x1*y2 + x2*y1)
-    yt    = z1z2 * (y1*y2 + x1*x2)
-    zx    = z1z22 + denum
-    zy    = z1z22 - denum
-    return (xt*zy, yt*zx, zx*zy)
-
-# scalar multiplication:
-# point + point + ... + point, scalar times
-# (using a double and add ladder for speed)
-def scalarmult(point, scalar):
-    affine  = (point[0], point[1], fe(1))
-    acc     = (fe(0), fe(1), fe(1))
-    trimmed = trim(scalar)
-    binary  = [int(c) for c in list(format(trimmed, 'b'))]
-    for i in binary:
-        acc = point_add(acc, acc)
-        if i == 1:
-            acc = point_add(acc, affine)
-    return acc
-
-# edwards base point
-eby = fe(4) / fe(5)
-ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2))
-edwards_base = (ebx, eby)
-
-# scalar multiplication of the base point
-def scalarbase(scalar):
-    return scalarmult(edwards_base, scalar)
-
-sqrt_mA2 = sqrt(fe(-486664)) # sqrt(-(A+2))
-
-# conversion to Montgomery
-# (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x)
-# (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1))
-def from_edwards(point):
-    x, y, z = point
-    u   = z + y
-    zu  = z - y
-    v   = u * z * sqrt_mA2
-    zv  = zu * x
-    div = (zu * zv).invert() # now we have to divide
-    return (u*zv*div, v*zu*div)
-
-# Generates an X25519 public key from the given private key.
-def x25519_public_key(private_key):
-    return from_edwards(scalarbase(private_key))
-
 ###########################
 # Elligator 2 (reference) #
 ###########################
@@ -238,19 +163,17 @@ def hash_to_curve(r):
     return (u, v)
 
 # Test whether a point has a representative, straight from the paper.
-def can_curve_to_hash(point):
-    u = point[0]
+def can_curve_to_hash(u):
     return u != -A and is_square(-non_square * u * (u+A))
 
 # Computes the representative of a point, straight from the paper.
-def curve_to_hash(point):
-    if not can_curve_to_hash(point):
+def curve_to_hash(u, v_is_negative):
+    if not can_curve_to_hash(u):
         raise ValueError('cannot curve to hash')
-    u, v = point
     sq1  = sqrt(-u     / (non_square * (u+A)))
     sq2  = sqrt(-(u+A) / (non_square * u    ))
-    if v.is_positive(): return sq1
-    else              : return sq2
+    if v_is_negative: return sq2
+    else            : return sq1
 
 #####################
 # Elligator2 (fast) #
@@ -434,14 +357,13 @@ def fast_hash_to_curve(r):
 # Therefore:
 #   if v is positive: r = isr * (u+A)
 #   if v is negative: r = isr * u
-def fast_curve_to_hash(point):
-    u, v = point
+def fast_curve_to_hash(u, v_is_negative):
     t = u + A
     r = -non_square * u * t
     isr, is_square = invsqrt(r)
     if not is_square:
         return None
-    if v.is_positive(): t = u
-    r = t * isr
+    if v_is_negative: u = t
+    r = u * isr
     r = r.abs()
     return r
diff --git a/tests/gen/elligator_scalarmult.py b/tests/gen/elligator_scalarmult.py
new file mode 100644 (file)
index 0000000..05da5eb
--- /dev/null
@@ -0,0 +1,230 @@
+#! /usr/bin/env python3
+
+# This file is dual-licensed.  Choose whichever licence you want from
+# the two licences listed below.
+#
+# The first licence is a regular 2-clause BSD licence.  The second licence
+# is the CC-0 from Creative Commons. It is intended to release Monocypher
+# to the public domain.  The BSD licence serves as a fallback option.
+#
+# SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0
+#
+# ------------------------------------------------------------------------
+#
+# Copyright (c) 2020, Loup Vaillant
+# All rights reserved.
+#
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the
+#    distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# ------------------------------------------------------------------------
+#
+# Written in 2020 by Loup Vaillant
+#
+# To the extent possible under law, the author(s) have dedicated all copyright
+# and related neighboring rights to this software to the public domain
+# worldwide.  This software is distributed without any warranty.
+#
+# You should have received a copy of the CC0 Public Domain Dedication along
+# with this software.  If not, see
+# <https://creativecommons.org/publicdomain/zero/1.0/>
+
+from elligator import fe
+from elligator import sqrt
+from elligator import sqrtm1
+
+def print_scalar(scalar):
+    """prints a scalar element in little endian"""
+    while scalar != 0:
+        print(format(scalar % 256, '02x'), end='')
+        scalar //= 256
+    print(':')
+
+#########################################
+# scalar multiplication (Edwards space) #
+#########################################
+
+# Edwards25519 equation (d defined below):
+# -x^2 + y^2 = 1 + d*x^2*y^2
+d = fe(-121665) / fe(121666)
+
+# Point addition:
+# denum = d*x1*x2*y1*y2
+# x     = (x1*y2 + x2*y1) / (1 + denum)
+# y     = (y1*y2 + x1*x2) / (1 - denum)
+# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z.
+# We can multiply Z instead of dividing X and Y.
+def point_add(a, b):
+    x1, y1, z1 = a
+    x2, y2, z2 = b
+    denum = d*x1*x2*y1*y2
+    z1z2  = z1 * z2
+    z1z22 = z1z2**2
+    xt    = z1z2 * (x1*y2 + x2*y1)
+    yt    = z1z2 * (y1*y2 + x1*x2)
+    zx    = z1z22 + denum
+    zy    = z1z22 - denum
+    return (xt*zy, yt*zx, zx*zy)
+
+# Point addition, with the final division
+def point_add2(p1, p2):
+    x1, y1  = p1
+    x2, y2  = p2
+    z1, z2  = (fe(1), fe(1))
+    x, y, z = point_add((x1, y1, z1), (x2, y2, z2))
+    div     = z.invert()
+    return (x*div, y*div)
+
+# scalar multiplication in edwards space:
+# point + point + ... + point, scalar times
+# (using a double and add ladder for speed)
+def ed_scalarmult(point, scalar):
+    affine  = (point[0], point[1], fe(1))
+    acc     = (fe(0), fe(1), fe(1))
+    binary  = [int(c) for c in list(format(scalar, 'b'))]
+    for i in binary:
+        acc = point_add(acc, acc)
+        if i == 1:
+            acc = point_add(acc, affine)
+    return acc
+
+# convert the point to Montgomery (u coordinate only)
+# (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x)
+# (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1))
+def from_edwards(point):
+    x, y, z = point
+    return (z + y) / (z - y)
+
+# edwards base point
+eby = fe(4) / fe(5)
+ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2))
+edwards_base = (ebx, eby)
+
+############################################
+# scalar multiplication (Montgomery space) #
+############################################
+def mt_scalarmult(u, scalar):
+    x1     = u
+    x2, z2 = fe(1), fe(0) # "zero" point
+    x3, z3 = x1   , fe(1) # "one"  point
+    binary = [int(c) for c in list(format(scalar, 'b'))]
+    for b in binary:
+        # Montgomery ladder step:
+        # if b == 0, then (P2, P3) == (P2*2 , P2+P3)
+        # if b == 1, then (P2, P3) == (P2+P3, P3*2 )
+        if b == 1:
+            x2, x3 = x3, x2
+            z2, z3 = z3, z2
+        t0 = x3 - z3
+        t1 = x2 - z2
+        x2 = x2 + z2
+        z2 = x3 + z3
+        z3 = t0 * x2
+        z2 = z2 * t1
+        t0 = t1**2
+        t1 = x2**2
+        x3 = z3 + z2
+        z2 = z3 - z2
+        x2 = t1 * t0
+        t1 = t1 - t0
+        z2 = z2**2
+        z3 = t1 * fe(121666)
+        x3 = x3**2
+        t0 = t0 + z3
+        z3 = x1 * z2
+        z2 = t1 * t0
+        if b == 1:
+            x2, x3 = x3, x2
+            z2, z3 = z3, z2
+    return x2 / z2
+
+montgomery_base = 9
+
+############################
+# Scalarmult with cofactor #
+############################
+
+# Keeping a random cofactor is important to keep points
+# indistinguishable from random.  (Else we'd notice all representatives
+# represent points with cleared cofactor.  Not exactly random.)
+
+# Point of order 8, used to add the cofactor component
+low_order_point_x = sqrt((sqrt(d + fe(1)) + fe(1)) / d)
+low_order_point_y = -low_order_point_x * sqrtm1
+low_order_point   = (low_order_point_x, low_order_point_y)
+
+def check_low_order_point():
+    lop2 = point_add2(low_order_point, low_order_point)
+    lop4 = point_add2(lop2, lop2)
+    lop8 = point_add2(lop4, lop4)
+    zero = (fe(0), fe(1))
+    if lop8 != zero: raise ValueError('low_order_point does not have low order')
+    if lop2 == zero: raise ValueError('low_order_point only has order 2')
+    if lop4 == zero: raise ValueError('low_order_point only has order 4')
+check_low_order_point()
+
+# base point + low order point
+ed_base = point_add2(low_order_point, edwards_base)   # in Edwards space
+mt_base = (fe(1) + ed_base[1]) / (fe(1) - ed_base[1]) # in Montgomery space
+
+# Clamp the scalar.
+# % 8 stops subgroup attacks
+# Clearing bit 255 and setting bit 254 facilitates constant time ladders.
+# We're not supposed to clear the cofactor, but scalar multiplication
+# usually does, and we want to reuse existing code as much as possible.
+def trim(scalar):
+    trimmed = scalar - scalar % 8
+    trimmed = trimmed % 2**254
+    trimmed = trimmed + 2**254
+    return trimmed
+
+order = 2**252 + 27742317777372353535851937790883648493
+
+# Single scalar multiplication (in Edwards space)
+def scalarmult1(scalar, cofactor):
+    co_cleared = cofactor * (5 * order)  # cleared main factor
+    combined   = trim(scalar) + co_cleared
+    return from_edwards(ed_scalarmult(ed_base, combined))
+
+# Single scalar multiplication (in Montgomery space)
+def scalarmult2(scalar, cofactor):
+    co_cleared = cofactor * (5 * order) # cleared main factor
+    combined   = trim(scalar) + co_cleared
+    return mt_scalarmult(mt_base, combined)
+
+# Double scalar multiplication (reuses EdDSA code)
+def scalarmult3(scalar, cofactor):
+    main_point = ed_scalarmult(edwards_base   , trim(scalar))
+    low_order  = ed_scalarmult(low_order_point, cofactor    )
+    return from_edwards(point_add(main_point, low_order))
+
+# Combine and compare all ways ofd doing the scalar multiplication
+def scalarmult(scalar, cofactor):
+    p1 = scalarmult1(scalar, cofactor)
+    p2 = scalarmult2(scalar, cofactor)
+    p3 = scalarmult3(scalar, cofactor)
+    if p1 != p2 or p1 != p3:
+        raise ValueError('Incoherent scalarmult')
+    return p1
index 07179e8a78038f6f33bec8fb7bbd16e30e13a98c..63fa6a9266b5a123abaf3b9b03b508e2bc1f4518 100644 (file)
@@ -68,8 +68,8 @@ clean:
        rm -f *.out *.vec *.o
        rm -f $(VECTORS)
 
-elligator_inv.vec: elligator-inverse.py elligator.py x25519_pk.all.vec
-       ./$< <x25519_pk.all.vec >$@
+elligator_inv.vec: elligator-inverse.py elligator.py elligator_scalarmult.py
+       ./$< >$@
 elligator_dir.vec: elligator-direct.py elligator.py
        ./$< >$@
 
index f23ce22fa6ba90010d915646147c5e0e2d1f1bd8..f56e56d7c8d65ba942c71d3dd4d68f51a88008ad 100644 (file)
@@ -292,6 +292,20 @@ static void elligator_dir(const vector in[], vector *out)
     crypto_elligator2_direct(out->buf, in->buf);
 }
 
+static void elligator_inv(const vector in[], vector *out)
+{
+    const vector *sk = in;
+    u8  tweak   = in[1].buf[0];
+    u8  failure = in[2].buf[0];
+    int check   = crypto_elligator2_inverse(out->buf, sk->buf, tweak);
+    if ((u8)check != failure) {
+        fprintf(stderr, "Elligator inverse map: failure mismatch\n");
+    }
+    if (check) {
+        out->buf[0] = 0;
+    }
+}
+
 //////////////////////////////
 /// Self consistency tests ///
 //////////////////////////////
@@ -916,6 +930,7 @@ int main(int argc, char *argv[])
     status |= TEST(ed_25519_check, 3);
     status |= test_x25519();
     status |= TEST(elligator_dir , 1);
+    status |= TEST(elligator_inv , 3);
 
     printf("\nProperty based tests");
     printf("\n--------------------\n");