diff --git a/curve25519/solana-ed25519/src/backend/vector/avx2/field.rs b/curve25519/solana-ed25519/src/backend/vector/avx2/field.rs index a276def..9926fd8 100644 --- a/curve25519/solana-ed25519/src/backend/vector/avx2/field.rs +++ b/curve25519/solana-ed25519/src/backend/vector/avx2/field.rs @@ -792,6 +792,28 @@ impl Mul<&FieldElement2625x4> for &FieldElement2625x4 { x.mul32(y).into() } + #[inline(always)] + fn twice(x: u32x8) -> u32x8 { + x + x + } + + #[inline(always)] + #[allow(clippy::too_many_arguments)] + fn sum10( + a: u64x4, + b: u64x4, + c: u64x4, + d: u64x4, + e: u64x4, + f: u64x4, + g: u64x4, + h: u64x4, + i: u64x4, + j: u64x4, + ) -> u64x4 { + ((a + b) + (c + d)) + ((e + f) + (g + h)) + (i + j) + } + let (x0, x1) = unpack_pair(self.0[0]); let (x2, x3) = unpack_pair(self.0[1]); let (x4, x5) = unpack_pair(self.0[2]); @@ -816,22 +838,17 @@ impl Mul<&FieldElement2625x4> for &FieldElement2625x4 { let y8_19 = m_lo(v19, y8); let y9_19 = m_lo(v19, y9); - let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32 - let x3_2 = x3 + x3; // iff b < 6 - let x5_2 = x5 + x5; - let x7_2 = x7 + x7; - let x9_2 = x9 + x9; - - let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19); - let z1 = m(x0, y1) + m(x1, y0) + m(x2, y9_19) + m(x3, y8_19) + m(x4, y7_19) + m(x5, y6_19) + m(x6, y5_19) + m(x7, y4_19) + m(x8, y3_19) + m(x9, y2_19); - let z2 = m(x0, y2) + m(x1_2, y1) + m(x2, y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19); - let z3 = m(x0, y3) + m(x1, y2) + m(x2, y1) + m(x3, y0) + m(x4, y9_19) + m(x5, y8_19) + m(x6, y7_19) + m(x7, y6_19) + m(x8, y5_19) + m(x9, y4_19); - let z4 = m(x0, y4) + m(x1_2, y3) + m(x2, y2) + m(x3_2, y1) + m(x4, y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19); - let z5 = m(x0, y5) + m(x1, y4) + m(x2, y3) + m(x3, y2) + m(x4, y1) + m(x5, y0) + m(x6, y9_19) + m(x7, y8_19) + m(x8, y7_19) + m(x9, y6_19); - let z6 = m(x0, y6) + m(x1_2, y5) + m(x2, y4) + m(x3_2, y3) + m(x4, y2) + m(x5_2, y1) + m(x6, y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19); - let z7 = m(x0, y7) + m(x1, y6) + m(x2, y5) + m(x3, y4) + m(x4, y3) + m(x5, y2) + m(x6, y1) + m(x7, y0) + m(x8, y9_19) + m(x9, y8_19); - let z8 = m(x0, y8) + m(x1_2, y7) + m(x2, y6) + m(x3_2, y5) + m(x4, y4) + m(x5_2, y3) + m(x6, y2) + m(x7_2, y1) + m(x8, y0) + m(x9_2, y9_19); - let z9 = m(x0, y9) + m(x1, y8) + m(x2, y7) + m(x3, y6) + m(x4, y5) + m(x5, y4) + m(x6, y3) + m(x7, y2) + m(x8, y1) + m(x9, y0); + // Doubling odd x limbs fits in a u32 iff 25 + b + 1 < 32, i.e. b < 6. + let z0 = sum10(m(x0, y0), m(twice(x1), y9_19), m(x2, y8_19), m(twice(x3), y7_19), m(x4, y6_19), m(twice(x5), y5_19), m(x6, y4_19), m(twice(x7), y3_19), m(x8, y2_19), m(twice(x9), y1_19)); + let z1 = sum10(m(x0, y1), m(x1, y0), m(x2, y9_19), m(x3, y8_19), m(x4, y7_19), m(x5, y6_19), m(x6, y5_19), m(x7, y4_19), m(x8, y3_19), m(x9, y2_19)); + let z2 = sum10(m(x0, y2), m(twice(x1), y1), m(x2, y0), m(twice(x3), y9_19), m(x4, y8_19), m(twice(x5), y7_19), m(x6, y6_19), m(twice(x7), y5_19), m(x8, y4_19), m(twice(x9), y3_19)); + let z3 = sum10(m(x0, y3), m(x1, y2), m(x2, y1), m(x3, y0), m(x4, y9_19), m(x5, y8_19), m(x6, y7_19), m(x7, y6_19), m(x8, y5_19), m(x9, y4_19)); + let z4 = sum10(m(x0, y4), m(twice(x1), y3), m(x2, y2), m(twice(x3), y1), m(x4, y0), m(twice(x5), y9_19), m(x6, y8_19), m(twice(x7), y7_19), m(x8, y6_19), m(twice(x9), y5_19)); + let z5 = sum10(m(x0, y5), m(x1, y4), m(x2, y3), m(x3, y2), m(x4, y1), m(x5, y0), m(x6, y9_19), m(x7, y8_19), m(x8, y7_19), m(x9, y6_19)); + let z6 = sum10(m(x0, y6), m(twice(x1), y5), m(x2, y4), m(twice(x3), y3), m(x4, y2), m(twice(x5), y1), m(x6, y0), m(twice(x7), y9_19), m(x8, y8_19), m(twice(x9), y7_19)); + let z7 = sum10(m(x0, y7), m(x1, y6), m(x2, y5), m(x3, y4), m(x4, y3), m(x5, y2), m(x6, y1), m(x7, y0), m(x8, y9_19), m(x9, y8_19)); + let z8 = sum10(m(x0, y8), m(twice(x1), y7), m(x2, y6), m(twice(x3), y5), m(x4, y4), m(twice(x5), y3), m(x6, y2), m(twice(x7), y1), m(x8, y0), m(twice(x9), y9_19)); + let z9 = sum10(m(x0, y9), m(x1, y8), m(x2, y7), m(x3, y6), m(x4, y5), m(x5, y4), m(x6, y3), m(x7, y2), m(x8, y1), m(x9, y0)); // The bounds on z[i] are the same as in the serial 32-bit code // and the comment below is copied from there: