From 092e8908dc173c535e00d0f3deb500324ac261de Mon Sep 17 00:00:00 2001 From: Loup Vaillant Date: Mon, 16 Mar 2020 13:13:06 +0100 Subject: [PATCH] Added Elligator2 inverse mapping --- src/monocypher.c | 149 ++++++++++++++++--- src/monocypher.h | 5 +- tests/gen/elligator-direct.py | 2 +- tests/gen/elligator-inverse.py | 93 ++++++------ tests/gen/elligator.py | 94 ++---------- tests/gen/elligator_scalarmult.py | 230 ++++++++++++++++++++++++++++++ tests/gen/makefile | 4 +- tests/test.c | 15 ++ 8 files changed, 430 insertions(+), 162 deletions(-) mode change 100755 => 100644 tests/gen/elligator.py create mode 100644 tests/gen/elligator_scalarmult.py diff --git a/src/monocypher.c b/src/monocypher.c index 686ea7a..545f1df 100644 --- a/src/monocypher.c +++ b/src/monocypher.c @@ -1316,20 +1316,6 @@ static int fe_isodd(const fe f) return isodd; } -// Returns 0 if f <= (p-1)/2, 1 otherwise. -// "Positive" means between 0 and (p-1)/2 -// "Negative" means between (p+1)/2 and p-1 -// Since p is odd (2^255 - 19), the sign is easily tested by leveraging -// overflow: for any f in [0..p[, (2*f)%p is odd iff 2*f > p -static int fe_isnegative(const fe f) -{ - fe tmp; - fe_add(tmp, f, f); - int isneg = fe_isodd(tmp); - WIPE_BUFFER(tmp); - return isneg; -} - // Returns 0 if zero, 1 if non zero static int fe_isnonzero(const fe f) { @@ -1617,12 +1603,13 @@ static int ge_frombytes_vartime(ge *h, const u8 s[32]) return 0; } +static const fe D2 = { // - 2 * 121665 / 121666 + -21827239, -5839606, -30745221, 13898782, 229458, + 15978800, -12551817, -6495438, 29715968, 9444199 +}; + static void ge_cache(ge_cached *c, const ge *p) { - static const fe D2 = { // - 2 * 121665 / 121666 - -21827239, -5839606, -30745221, 13898782, 229458, - 15978800, -12551817, -6495438, 29715968, 9444199 - }; fe_add (c->Yp, p->Y, p->X); fe_sub (c->Ym, p->Y, p->X); fe_copy(c->Z , p->Z ); @@ -2215,6 +2202,8 @@ int crypto_check(const u8 signature[64], /// Elligator 2 /// /////////////////// +static const fe A = {486662}; + // From the paper: // w = -A / (fe(1) + non_square * r^2) // e = chi(w^3 + A*w^2 + w) @@ -2273,7 +2262,6 @@ void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32]) -1917299, 15887451, -18755900, -7000830, -24778944, 544946, -16816446, 4011309, -653372, 10741468, }; - static const fe A = {486662, 0, 0, 0, 0, 0, 0, 0, 0, 0}; static const fe A2 = {12721188, 3529, 0, 0, 0, 0, 0, 0, 0, 0}; static const fe one = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}; @@ -2312,12 +2300,129 @@ void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32]) WIPE_BUFFER(t3); WIPE_BUFFER(clamped); } -int crypto_curve_to_hash(uint8_t hash[32], const uint8_t curve[32]) + +// Compute the representative of a point (defined by the secret key and +// tweak), if possible. If not it does nothing and returns -1 +// The tweak comprises 3 parts: +// - Bits 4-5: random padding +// - Bit 3 : sign of the v coordinate (0 if positive, 1 if negative) +// - Bits 0-2: cofactor +// The bits 6-7 are ignored. +// +// Note that to ensure the representative is fully random, we do *not* +// clear the cofactor. +int crypto_elligator2_inverse(u8 hash[32], const u8 secret_key [32], u8 tweak) { + static const fe lop_x = { + 21352778, 5345713, 4660180, -8347857, 24143090, + 14568123, 30185756, -12247770, -33528939, 8345319, + }; + static const fe lop_y = { + -6952922, -1265500, 6862341, -7057498, -4037696, + -5447722, 31680899, -15325402, -19365852, 1569102, + }; + + u8 scalar[32]; FOR (i, 0, 32) { - hash[i] = curve[i]; + scalar[i] = secret_key[i]; + } + trim_scalar(scalar); + ge pk; + ge_scalarmult_base(&pk, scalar); + + // Select low order point + // We're computing the [cofactor]lop scalar multiplication, where: + // cofactor = tweak & 7. + // lop = (lop_x, lop_y) + // lop_x = sqrt((sqrt(d + 1) + 1) / d) + // lop_y = -lop_x * sqrtm1 + // Notes: + // - A (single) Montgomery ladder would be twice as slow. + // - An actual scalar multiplication would hurt performance. + // - A full table lookup would take more code. + int a = (tweak >> 2) & 1; + int b = (tweak >> 1) & 1; + int c = (tweak >> 0) & 1; + fe t1, t2, t3; + fe_0(t1); + fe_ccopy(t1, sqrtm1, b); + fe_ccopy(t1, lop_x , c); + fe_neg (t3, t1); + fe_ccopy(t1, t3, a); + fe_1(t2); + fe_0(t3); + fe_ccopy(t2, t3 , b); + fe_ccopy(t2, lop_y, c); + fe_neg (t3, t2); + fe_ccopy(t2, t3, a^b); + ge_precomp low_order_point; + fe_add(low_order_point.Yp, t2, t1); + fe_sub(low_order_point.Ym, t2, t1); + fe_mul(low_order_point.T2, t2, t1); + fe_mul(low_order_point.T2, low_order_point.T2, D2); + + // Add low order point to the public key + ge_madd(&pk, &pk, &low_order_point, t1, t2); + + // Convert to Montgomery u coordinate (we ignore the sign) + fe_add(t1, pk.Z, pk.Y); + fe_sub(t2, pk.Z, pk.Y); + fe_invert(t2, t2); + fe_mul(t1, t1, t2); + + // Convert to representative + // From the paper: + // Let sq = -non_square * u * (u+A) + // if sq is not a square, or u = -A, there is no mapping + // Assuming there is a mapping: + // if v is positive: r = sqrt(-(u+A) / u) + // if v is negative: r = sqrt(-u / (u+A)) + // + // We compute isr = invsqrt(-non_square * u * (u+A)) + // if it wasn't a non-zero square, abort. + // else, isr = sqrt(-1 / (non_square * u * (u+A)) + // + // This causes us to abort if u is zero, even though we shouldn't. This + // never happens in practice, because (i) a random point in the curve has + // a negligible chance of being zero, and (ii) scalar multiplication with + // a trimmed scalar *never* yields zero. + // + // Since: + // isr * (u+A) = sqrt(-1 / (non_square * u * (u+A)) * (u+A) + // isr * (u+A) = sqrt(-(u+A) / (non_square * u * (u+A)) + // and: + // isr = u = sqrt(-1 / (non_square * u * (u+A)) * u + // isr = u = sqrt(-u / (non_square * u * (u+A)) + // Therefore: + // if v is positive: r = isr * (u+A) + // if v is negative: r = isr * u + fe_add(t2, t1, A); + fe_mul(t3, t1, t2); + fe_mul_small(t3, t3, -2); + int is_square = invsqrt(t3, t3); + if (!is_square) { + // The only variable time bit. This ultimately reveals how many + // tries it took us to find a representable key. + // This does not affect security as long as we try keys at random. + WIPE_BUFFER(t1); WIPE_BUFFER(scalar); + WIPE_BUFFER(t2); WIPE_CTX(&pk); + WIPE_BUFFER(t3); WIPE_CTX(&low_order_point); + return -1; } - return -1; + fe_ccopy(t1, t2, (tweak >> 3) & 1); + fe_mul (t3, t1, t3); + fe_add (t1, t3, t3); + fe_neg (t2, t3); + fe_ccopy(t3, t2, fe_isodd(t1)); + fe_tobytes(hash, t3); + + // Pad with two random bits + hash[31] |= (tweak << 2) & 0xc0; + + WIPE_BUFFER(t1); WIPE_BUFFER(scalar); + WIPE_BUFFER(t2); WIPE_CTX(&pk); + WIPE_BUFFER(t3); WIPE_CTX(&low_order_point); + return 0; } //////////////////// diff --git a/src/monocypher.h b/src/monocypher.h index 24d65fa..0b76939 100644 --- a/src/monocypher.h +++ b/src/monocypher.h @@ -254,8 +254,9 @@ void crypto_check_init_custom_hash(crypto_check_ctx_abstract *ctx, // Elligator 2 // ----------- -void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash [32]); - +void crypto_elligator2_direct(uint8_t curve[32], const uint8_t hash[32]); +int crypto_elligator2_inverse(uint8_t hash[32], const uint8_t secret_key[32], + uint8_t tweak); //////////////////////////// /// Low level primitives /// diff --git a/tests/gen/elligator-direct.py b/tests/gen/elligator-direct.py index 72a4337..7cc12b5 100755 --- a/tests/gen/elligator-direct.py +++ b/tests/gen/elligator-direct.py @@ -61,7 +61,7 @@ from random import randrange def direct(r1): q1 = hash_to_curve(r1) q2 = fast_hash_to_curve(r1) - r2 = curve_to_hash(q1) + r2 = curve_to_hash(q1[0], q1[1].is_negative()) if q1 != q2: raise ValueError('Incorrect fast_hash_to_curve') if r1 != r2: raise ValueError('Round trip failure') r1 .print() diff --git a/tests/gen/elligator-inverse.py b/tests/gen/elligator-inverse.py index a9306ba..361c085 100755 --- a/tests/gen/elligator-inverse.py +++ b/tests/gen/elligator-inverse.py @@ -52,60 +52,55 @@ # from elligator import fe -from elligator import x25519_public_key from elligator import can_curve_to_hash from elligator import curve_to_hash from elligator import fast_curve_to_hash from elligator import hash_to_curve -from elligator import fast_hash_to_curve -from sys import stdin -# Test a full round trip, and print the relevant test vectors -def full_cycle_check(private_key, u): - fe(private_key).print() - uv = x25519_public_key(private_key) - if uv [0] != u: raise ValueError('Test vector failure') - uv[0].print() - uv[1].print() - if can_curve_to_hash(uv): - h = curve_to_hash(uv) - if h.is_negative(): raise ValueError('Non Canonical representative') - fh = fast_curve_to_hash(uv) - if fh != h: raise ValueError('Incorrect fast_curve_to_hash()') - print('01:') # Success - h.print() # actual value for the hash - c = hash_to_curve(h) - f = fast_hash_to_curve(h) - if f != c : raise ValueError('Incorrect fast_hash_to_curve()') - if c != uv : raise ValueError('Round trip failure') - else: - fh = fast_curve_to_hash(uv) - if not (fh is None): raise ValueError('Fast Curve to Hash did not fail') - print('00:') # Failure - print('00:') # dummy value for the hash +from elligator_scalarmult import scalarmult +from elligator_scalarmult import print_scalar -# read test vectors: -def read_vector(vector): # vector: little endian hex number - cut = vector[:64] # remove final ':' character - acc = 0 # final sum - pos = 1 # power of 256 - for b in bytes.fromhex(cut): - acc += b * pos - pos *= 256 - return acc +from random import randrange -def read_test_vectors(): - vectors = [] - lines = [x.strip() for x in stdin.readlines() if x.strip()] - for i in range(len(lines) // 2): - private = read_vector(lines[i*2 ]) - public = read_vector(lines[i*2 + 1]) - vectors.append((private, fe(public))) - return vectors +def private_to_hash(scalar, tweak): + cofactor = tweak % 8 ; tweak = tweak // 8 + v_is_negative = tweak % 2 == 1; tweak = tweak // 2 + msb = tweak * 2**254 + u = scalarmult(scalar, cofactor) + r1 = None + if can_curve_to_hash(u): + r1 = curve_to_hash(u, v_is_negative) + r2 = fast_curve_to_hash(u, v_is_negative) + if r1 != r2: raise ValueError('Incoherent hash_to_curve') + if r1 is None: + return None + if r1.val > 2**254: raise ValueError('Representative too big') + u2, v2 = hash_to_curve(r1) + if u2 != u: raise ValueError('Round trip failure') + return r1.val + msb -vectors = read_test_vectors() -for v in vectors: - private = v[0] - public = v[1] - full_cycle_check(private, public) - print('') +# All possible failures +for tweak in range(2**4): + while True: + scalar = randrange(0, 2**256) + r = private_to_hash(scalar, tweak) + if r is None: + print_scalar(scalar) + print(format(tweak, '02x') + ":") + print('ff:') # Failure + print('00:') # dummy value for the hash + print() + break + +# All possible successes +for tweak in range(2**6): + while True: + scalar = randrange(0, 2**256) + r = private_to_hash(scalar, tweak) + if r is not None: + print_scalar(scalar) + print(format(tweak, '02x') + ":") + print('00:') # Success + print_scalar(r) + print() + break diff --git a/tests/gen/elligator.py b/tests/gen/elligator.py old mode 100755 new mode 100644 index 4fef6bf..e7a28d2 --- a/tests/gen/elligator.py +++ b/tests/gen/elligator.py @@ -137,81 +137,6 @@ def sqrt(n): if root * root != n: raise ValueError('Should be a square!!') return root.abs() -######################### -# scalar multiplication # -######################### -# Clamp the scalar. -# % 8 stops subgroup attacks -# Clearing bit 255 and setting bit 254 facilitates constant time ladders. -def trim(scalar): - trimmed = scalar - scalar % 8 - trimmed = trimmed % 2**254 - trimmed = trimmed + 2**254 - return trimmed - -# Edwards25519 equation (d defined below): -# -x^2 + y^2 = 1 + d*x^2*y^2 -d = fe(-121665) / fe(121666) - -# Point addition: -# denum = d*x1*x2*y1*y2 -# x = (x1*y2 + x2*y1) / (1 + denum) -# y = (y1*y2 + x1*x2) / (1 - denum) -# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z. -# We can multiply Z instead of dividing X and Y. -def point_add(a, b): - x1, y1, z1 = a - x2, y2, z2 = b - denum = d*x1*x2*y1*y2 - z1z2 = z1 * z2 - z1z22 = z1z2**2 - xt = z1z2 * (x1*y2 + x2*y1) - yt = z1z2 * (y1*y2 + x1*x2) - zx = z1z22 + denum - zy = z1z22 - denum - return (xt*zy, yt*zx, zx*zy) - -# scalar multiplication: -# point + point + ... + point, scalar times -# (using a double and add ladder for speed) -def scalarmult(point, scalar): - affine = (point[0], point[1], fe(1)) - acc = (fe(0), fe(1), fe(1)) - trimmed = trim(scalar) - binary = [int(c) for c in list(format(trimmed, 'b'))] - for i in binary: - acc = point_add(acc, acc) - if i == 1: - acc = point_add(acc, affine) - return acc - -# edwards base point -eby = fe(4) / fe(5) -ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2)) -edwards_base = (ebx, eby) - -# scalar multiplication of the base point -def scalarbase(scalar): - return scalarmult(edwards_base, scalar) - -sqrt_mA2 = sqrt(fe(-486664)) # sqrt(-(A+2)) - -# conversion to Montgomery -# (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x) -# (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1)) -def from_edwards(point): - x, y, z = point - u = z + y - zu = z - y - v = u * z * sqrt_mA2 - zv = zu * x - div = (zu * zv).invert() # now we have to divide - return (u*zv*div, v*zu*div) - -# Generates an X25519 public key from the given private key. -def x25519_public_key(private_key): - return from_edwards(scalarbase(private_key)) - ########################### # Elligator 2 (reference) # ########################### @@ -238,19 +163,17 @@ def hash_to_curve(r): return (u, v) # Test whether a point has a representative, straight from the paper. -def can_curve_to_hash(point): - u = point[0] +def can_curve_to_hash(u): return u != -A and is_square(-non_square * u * (u+A)) # Computes the representative of a point, straight from the paper. -def curve_to_hash(point): - if not can_curve_to_hash(point): +def curve_to_hash(u, v_is_negative): + if not can_curve_to_hash(u): raise ValueError('cannot curve to hash') - u, v = point sq1 = sqrt(-u / (non_square * (u+A))) sq2 = sqrt(-(u+A) / (non_square * u )) - if v.is_positive(): return sq1 - else : return sq2 + if v_is_negative: return sq2 + else : return sq1 ##################### # Elligator2 (fast) # @@ -434,14 +357,13 @@ def fast_hash_to_curve(r): # Therefore: # if v is positive: r = isr * (u+A) # if v is negative: r = isr * u -def fast_curve_to_hash(point): - u, v = point +def fast_curve_to_hash(u, v_is_negative): t = u + A r = -non_square * u * t isr, is_square = invsqrt(r) if not is_square: return None - if v.is_positive(): t = u - r = t * isr + if v_is_negative: u = t + r = u * isr r = r.abs() return r diff --git a/tests/gen/elligator_scalarmult.py b/tests/gen/elligator_scalarmult.py new file mode 100644 index 0000000..05da5eb --- /dev/null +++ b/tests/gen/elligator_scalarmult.py @@ -0,0 +1,230 @@ +#! /usr/bin/env python3 + +# This file is dual-licensed. Choose whichever licence you want from +# the two licences listed below. +# +# The first licence is a regular 2-clause BSD licence. The second licence +# is the CC-0 from Creative Commons. It is intended to release Monocypher +# to the public domain. The BSD licence serves as a fallback option. +# +# SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0 +# +# ------------------------------------------------------------------------ +# +# Copyright (c) 2020, Loup Vaillant +# All rights reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the +# distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ------------------------------------------------------------------------ +# +# Written in 2020 by Loup Vaillant +# +# To the extent possible under law, the author(s) have dedicated all copyright +# and related neighboring rights to this software to the public domain +# worldwide. This software is distributed without any warranty. +# +# You should have received a copy of the CC0 Public Domain Dedication along +# with this software. If not, see +# + +from elligator import fe +from elligator import sqrt +from elligator import sqrtm1 + +def print_scalar(scalar): + """prints a scalar element in little endian""" + while scalar != 0: + print(format(scalar % 256, '02x'), end='') + scalar //= 256 + print(':') + +######################################### +# scalar multiplication (Edwards space) # +######################################### + +# Edwards25519 equation (d defined below): +# -x^2 + y^2 = 1 + d*x^2*y^2 +d = fe(-121665) / fe(121666) + +# Point addition: +# denum = d*x1*x2*y1*y2 +# x = (x1*y2 + x2*y1) / (1 + denum) +# y = (y1*y2 + x1*x2) / (1 - denum) +# To avoid divisions, we use affine coordinates: x = X/Z, y = Y/Z. +# We can multiply Z instead of dividing X and Y. +def point_add(a, b): + x1, y1, z1 = a + x2, y2, z2 = b + denum = d*x1*x2*y1*y2 + z1z2 = z1 * z2 + z1z22 = z1z2**2 + xt = z1z2 * (x1*y2 + x2*y1) + yt = z1z2 * (y1*y2 + x1*x2) + zx = z1z22 + denum + zy = z1z22 - denum + return (xt*zy, yt*zx, zx*zy) + +# Point addition, with the final division +def point_add2(p1, p2): + x1, y1 = p1 + x2, y2 = p2 + z1, z2 = (fe(1), fe(1)) + x, y, z = point_add((x1, y1, z1), (x2, y2, z2)) + div = z.invert() + return (x*div, y*div) + +# scalar multiplication in edwards space: +# point + point + ... + point, scalar times +# (using a double and add ladder for speed) +def ed_scalarmult(point, scalar): + affine = (point[0], point[1], fe(1)) + acc = (fe(0), fe(1), fe(1)) + binary = [int(c) for c in list(format(scalar, 'b'))] + for i in binary: + acc = point_add(acc, acc) + if i == 1: + acc = point_add(acc, affine) + return acc + +# convert the point to Montgomery (u coordinate only) +# (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x) +# (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1)) +def from_edwards(point): + x, y, z = point + return (z + y) / (z - y) + +# edwards base point +eby = fe(4) / fe(5) +ebx = sqrt((eby**2 - fe(1)) / (fe(1) + d * eby**2)) +edwards_base = (ebx, eby) + +############################################ +# scalar multiplication (Montgomery space) # +############################################ +def mt_scalarmult(u, scalar): + x1 = u + x2, z2 = fe(1), fe(0) # "zero" point + x3, z3 = x1 , fe(1) # "one" point + binary = [int(c) for c in list(format(scalar, 'b'))] + for b in binary: + # Montgomery ladder step: + # if b == 0, then (P2, P3) == (P2*2 , P2+P3) + # if b == 1, then (P2, P3) == (P2+P3, P3*2 ) + if b == 1: + x2, x3 = x3, x2 + z2, z3 = z3, z2 + t0 = x3 - z3 + t1 = x2 - z2 + x2 = x2 + z2 + z2 = x3 + z3 + z3 = t0 * x2 + z2 = z2 * t1 + t0 = t1**2 + t1 = x2**2 + x3 = z3 + z2 + z2 = z3 - z2 + x2 = t1 * t0 + t1 = t1 - t0 + z2 = z2**2 + z3 = t1 * fe(121666) + x3 = x3**2 + t0 = t0 + z3 + z3 = x1 * z2 + z2 = t1 * t0 + if b == 1: + x2, x3 = x3, x2 + z2, z3 = z3, z2 + return x2 / z2 + +montgomery_base = 9 + +############################ +# Scalarmult with cofactor # +############################ + +# Keeping a random cofactor is important to keep points +# indistinguishable from random. (Else we'd notice all representatives +# represent points with cleared cofactor. Not exactly random.) + +# Point of order 8, used to add the cofactor component +low_order_point_x = sqrt((sqrt(d + fe(1)) + fe(1)) / d) +low_order_point_y = -low_order_point_x * sqrtm1 +low_order_point = (low_order_point_x, low_order_point_y) + +def check_low_order_point(): + lop2 = point_add2(low_order_point, low_order_point) + lop4 = point_add2(lop2, lop2) + lop8 = point_add2(lop4, lop4) + zero = (fe(0), fe(1)) + if lop8 != zero: raise ValueError('low_order_point does not have low order') + if lop2 == zero: raise ValueError('low_order_point only has order 2') + if lop4 == zero: raise ValueError('low_order_point only has order 4') +check_low_order_point() + +# base point + low order point +ed_base = point_add2(low_order_point, edwards_base) # in Edwards space +mt_base = (fe(1) + ed_base[1]) / (fe(1) - ed_base[1]) # in Montgomery space + +# Clamp the scalar. +# % 8 stops subgroup attacks +# Clearing bit 255 and setting bit 254 facilitates constant time ladders. +# We're not supposed to clear the cofactor, but scalar multiplication +# usually does, and we want to reuse existing code as much as possible. +def trim(scalar): + trimmed = scalar - scalar % 8 + trimmed = trimmed % 2**254 + trimmed = trimmed + 2**254 + return trimmed + +order = 2**252 + 27742317777372353535851937790883648493 + +# Single scalar multiplication (in Edwards space) +def scalarmult1(scalar, cofactor): + co_cleared = cofactor * (5 * order) # cleared main factor + combined = trim(scalar) + co_cleared + return from_edwards(ed_scalarmult(ed_base, combined)) + +# Single scalar multiplication (in Montgomery space) +def scalarmult2(scalar, cofactor): + co_cleared = cofactor * (5 * order) # cleared main factor + combined = trim(scalar) + co_cleared + return mt_scalarmult(mt_base, combined) + +# Double scalar multiplication (reuses EdDSA code) +def scalarmult3(scalar, cofactor): + main_point = ed_scalarmult(edwards_base , trim(scalar)) + low_order = ed_scalarmult(low_order_point, cofactor ) + return from_edwards(point_add(main_point, low_order)) + +# Combine and compare all ways ofd doing the scalar multiplication +def scalarmult(scalar, cofactor): + p1 = scalarmult1(scalar, cofactor) + p2 = scalarmult2(scalar, cofactor) + p3 = scalarmult3(scalar, cofactor) + if p1 != p2 or p1 != p3: + raise ValueError('Incoherent scalarmult') + return p1 diff --git a/tests/gen/makefile b/tests/gen/makefile index 07179e8..63fa6a9 100644 --- a/tests/gen/makefile +++ b/tests/gen/makefile @@ -68,8 +68,8 @@ clean: rm -f *.out *.vec *.o rm -f $(VECTORS) -elligator_inv.vec: elligator-inverse.py elligator.py x25519_pk.all.vec - ./$< $@ +elligator_inv.vec: elligator-inverse.py elligator.py elligator_scalarmult.py + ./$< >$@ elligator_dir.vec: elligator-direct.py elligator.py ./$< >$@ diff --git a/tests/test.c b/tests/test.c index f23ce22..f56e56d 100644 --- a/tests/test.c +++ b/tests/test.c @@ -292,6 +292,20 @@ static void elligator_dir(const vector in[], vector *out) crypto_elligator2_direct(out->buf, in->buf); } +static void elligator_inv(const vector in[], vector *out) +{ + const vector *sk = in; + u8 tweak = in[1].buf[0]; + u8 failure = in[2].buf[0]; + int check = crypto_elligator2_inverse(out->buf, sk->buf, tweak); + if ((u8)check != failure) { + fprintf(stderr, "Elligator inverse map: failure mismatch\n"); + } + if (check) { + out->buf[0] = 0; + } +} + ////////////////////////////// /// Self consistency tests /// ////////////////////////////// @@ -916,6 +930,7 @@ int main(int argc, char *argv[]) status |= TEST(ed_25519_check, 3); status |= test_x25519(); status |= TEST(elligator_dir , 1); + status |= TEST(elligator_inv , 3); printf("\nProperty based tests"); printf("\n--------------------\n"); -- 2.47.3