diff --git a/curve25519/solana-ed25519/src/backend/serial/scalar_mul/vartime_triple_base.rs b/curve25519/solana-ed25519/src/backend/serial/scalar_mul/vartime_triple_base.rs index cecc6b9..36ff876 100644 --- a/curve25519/solana-ed25519/src/backend/serial/scalar_mul/vartime_triple_base.rs +++ b/curve25519/solana-ed25519/src/backend/serial/scalar_mul/vartime_triple_base.rs @@ -71,10 +71,10 @@ pub fn mul_128_128_256( let b_hi = Scalar::from_canonical_bytes(b_hi_bytes).unwrap(); // Compute NAF representations (all scalars are now ~128 bits) - let a1_naf = a1.non_adjacent_form(5); - let a2_naf = a2.non_adjacent_form(5); - let b_lo_naf = b_lo.non_adjacent_form(5); - let b_hi_naf = b_hi.non_adjacent_form(5); + let a1_naf = a1.non_adjacent_form_128(5); + let a2_naf = a2.non_adjacent_form_128(5); + let b_lo_naf = b_lo.non_adjacent_form_128(5); + let b_hi_naf = b_hi.non_adjacent_form_128(5); // Find starting index - check all NAFs up to bit 127 // (with potential carry to bit 128 or 129) diff --git a/curve25519/solana-ed25519/src/backend/vector/scalar_mul/vartime_triple_base.rs b/curve25519/solana-ed25519/src/backend/vector/scalar_mul/vartime_triple_base.rs index 4508051..4b1ec17 100644 --- a/curve25519/solana-ed25519/src/backend/vector/scalar_mul/vartime_triple_base.rs +++ b/curve25519/solana-ed25519/src/backend/vector/scalar_mul/vartime_triple_base.rs @@ -82,15 +82,15 @@ pub mod spec { let b_hi = Scalar::from_canonical_bytes(b_hi_bytes).unwrap(); // Compute NAF representations (all scalars are now ~128 bits) - let a1_naf = a1.non_adjacent_form(5); - let a2_naf = a2.non_adjacent_form(5); + let a1_naf = a1.non_adjacent_form_128(5); + let a2_naf = a2.non_adjacent_form_128(5); #[cfg(feature = "precomputed-tables")] - let b_lo_naf = b_lo.non_adjacent_form(8); + let b_lo_naf = b_lo.non_adjacent_form_128(8); #[cfg(not(feature = "precomputed-tables"))] - let b_lo_naf = b_lo.non_adjacent_form(5); + let b_lo_naf = b_lo.non_adjacent_form_128(5); - let b_hi_naf = b_hi.non_adjacent_form(5); + let b_hi_naf = b_hi.non_adjacent_form_128(5); // Find starting index - check all NAFs up to bit 127 // (with potential carry to bit 128 or 129) diff --git a/curve25519/solana-ed25519/src/scalar.rs b/curve25519/solana-ed25519/src/scalar.rs index 56d4b3c..cb101d6 100644 --- a/curve25519/solana-ed25519/src/scalar.rs +++ b/curve25519/solana-ed25519/src/scalar.rs @@ -152,6 +152,8 @@ use crate::traits::HEEADecomposition; mod heea; pub(crate) use heea::HEEA_MAX_INDEX; +pub(crate) const NAF_128_SIZE: usize = HEEA_MAX_INDEX + 1; + /// An `UnpackedScalar` represents an element of the field GF(l), optimized for speed. /// /// This is pinned to the 64-bit serial scalar backend. @@ -1013,6 +1015,55 @@ impl Scalar { naf } + /// Compute a width-\\(w\\) non-adjacent form for scalars known to fit in 128 bits. + pub(crate) fn non_adjacent_form_128(&self, w: usize) -> [i8; NAF_128_SIZE] { + // required by the NAF definition + debug_assert!(w >= 2); + // required so that the NAF digits fit in i8 + debug_assert!(w <= 8); + debug_assert!(self.bytes[16..32].iter().all(|&b| b == 0)); + + let mut naf = [0i8; NAF_128_SIZE]; + + let mut x_u64 = [0u64; 3]; + read_le_u64_into(&self.bytes[..16], &mut x_u64[0..2]); + + let width = 1u64 << w; + let window_mask = width - 1; + + let mut pos = 0; + let mut carry = 0; + while pos < HEEA_MAX_INDEX { + let u64_idx = pos / 64; + let bit_idx = pos % 64; + let bit_buf: u64 = if bit_idx < 64 - w { + x_u64[u64_idx] >> bit_idx + } else { + (x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx)) + }; + + let window = carry + (bit_buf & window_mask); + + if window & 1 == 0 { + pos += 1; + continue; + } + + if window < width / 2 { + carry = 0; + naf[pos] = window as i8; + } else { + carry = 1; + naf[pos] = (window as i8).wrapping_sub(width as i8); + } + + pos += w; + } + + debug_assert_eq!(carry, 0); + naf + } + /// Write this scalar in radix 16, with coefficients in \\([-8,8)\\), /// i.e., compute \\(a\_i\\) such that /// $$ @@ -1583,6 +1634,40 @@ pub(crate) mod test { } } + #[test] + fn non_adjacent_form_128_matches_generic() { + let mut high_bit = [0u8; 32]; + high_bit[15] = 0x80; + + let cases = [ + Scalar::ZERO, + Scalar::ONE, + Scalar::from(u64::MAX), + Scalar::from(0xfedc_ba98_7654_3210_0123_4567_89ab_cdefu128), + Scalar::from(u128::MAX), + Scalar { bytes: high_bit }, + ]; + + for scalar in cases { + for w in 2..=8 { + let generic = scalar.non_adjacent_form(w); + let naf_128 = scalar.non_adjacent_form_128(w); + + for i in 0..=HEEA_MAX_INDEX { + assert_eq!( + naf_128[i], generic[i], + "NAF mismatch at index {i} for width {w}" + ); + } + + assert!( + generic[NAF_128_SIZE..].iter().all(|&digit| digit == 0), + "generic NAF has non-zero digits above 128-bit range for width {w}" + ); + } + } + } + #[cfg(feature = "rand_core")] fn non_adjacent_form_iter(w: usize, x: &Scalar) { let naf = x.non_adjacent_form(w);