diff --git a/build.zig b/build.zig index 1ab27bf4..825eb507 100644 --- a/build.zig +++ b/build.zig @@ -278,6 +278,27 @@ pub fn build(b: *std.Build) void { const field_bench_step = b.step("bench-field", "Run field arithmetic benchmark"); field_bench_step.dependOn(&run_field_bench.step); + // Release-optimized dep chain for benchmarks (so zolt-arith gets + // compiled at ReleaseFast instead of Debug, enabling LLVM intrinsics). + const zolt_pool_dep_release = b.dependency("zolt_pool", .{ + .target = target, + .optimize = .ReleaseFast, + }); + const zolt_arith_dep_release = b.dependency("zolt_arith", .{ + .target = target, + .optimize = .ReleaseFast, + }); + const zolt_mod_release = b.createModule(.{ + .root_source_file = b.path("src/root.zig"), + .target = target, + .optimize = .ReleaseFast, + .imports = &.{ + .{ .name = "zolt_pool", .module = zolt_pool_dep_release.module("zolt_pool") }, + .{ .name = "zolt_arith", .module = zolt_arith_dep_release.module("zolt_arith") }, + }, + }); + if (is_apple_silicon) linkMetalFrameworks(zolt_mod_release); + // Benchmark: zolt-arith field microbench (repo-level, optional) const zolt_arith_field_micro = b.addExecutable(.{ .name = "zolt-arith-field-micro", @@ -286,7 +307,7 @@ pub fn build(b: *std.Build) void { .target = target, .optimize = .ReleaseFast, .imports = &.{ - .{ .name = "zolt", .module = lib.root_module }, + .{ .name = "zolt", .module = zolt_mod_release }, }, }), }); @@ -302,7 +323,7 @@ pub fn build(b: *std.Build) void { .target = target, .optimize = .ReleaseFast, .imports = &.{ - .{ .name = "zolt", .module = lib.root_module }, + .{ .name = "zolt", .module = zolt_mod_release }, }, }), }); diff --git a/packages/zolt-arith/src/bigint.zig b/packages/zolt-arith/src/bigint.zig new file mode 100644 index 00000000..2a34f987 --- /dev/null +++ b/packages/zolt-arith/src/bigint.zig @@ -0,0 +1,261 @@ +//! Multi-limb big integer arithmetic over fixed-size `[N]u64` arrays. +//! +//! 
//! This is the foundation `zolt-arith` builds field arithmetic on top of.
//! It is intentionally narrow and unopinionated:
//!
//! - Limbs are little-endian: `limbs[0]` is the least-significant 64-bit
//!   word. The numeric value is sum_{i} limbs[i] * 2^(64*i).
//! - `add` and `sub` always walk every limb, but nothing in this file is
//!   promised to be constant-time: the comparison helpers (`cmp`,
//!   `bitLen`) return as soon as the answer is known. The constant-time
//!   story comes when the BLS12-381 instantiation lands and we can pin a
//!   specific target.
//! - There is no allocation. Every helper takes its result by `*[N]u64`
//!   or returns a `[N]u64` value.
//!
//! The 256-bit (4-limb) and 384-bit (6-limb) cases both flow through the
//! same comptime-generic code, so the BN254 / BLS12-381 split lives in
//! `field.zig` rather than here.

const std = @import("std");

/// Compute `a + b` modulo 2^(64*N) into `out` and return the carry out of
/// the top limb (`1` when the true sum does not fit in N limbs).
///
/// Each limb of `a` and `b` is read before the matching limb of `out` is
/// written and is never read again, so the loop itself does not depend on
/// `out` being distinct from the inputs.
pub fn add(comptime N: comptime_int, out: *[N]u64, a: [N]u64, b: [N]u64) u1 {
    var carry_in: u1 = 0;
    inline for (0..N) |limb| {
        // Two half-adds per limb: a + b, then + incoming carry. At most
        // one of the two can overflow, so OR-ing the flags is exact.
        const partial, const overflow_ab = @addWithOverflow(a[limb], b[limb]);
        const total, const overflow_carry = @addWithOverflow(partial, carry_in);
        out[limb] = total;
        carry_in = overflow_ab | overflow_carry;
    }
    return carry_in;
}

/// Compute `a - b` modulo 2^(64*N) into `out` and return the borrow out of
/// the top limb (`1` when `a < b` as unsigned integers).
///
/// As with `add`, every input limb is consumed before the corresponding
/// output limb is stored.
pub fn sub(comptime N: comptime_int, out: *[N]u64, a: [N]u64, b: [N]u64) u1 {
    var borrow_in: u1 = 0;
    inline for (0..N) |limb| {
        // a - b, then - incoming borrow; at most one step can underflow.
        const partial, const underflow_ab = @subWithOverflow(a[limb], b[limb]);
        const total, const underflow_borrow = @subWithOverflow(partial, borrow_in);
        out[limb] = total;
        borrow_in = underflow_ab | underflow_borrow;
    }
    return borrow_in;
}

/// Compare `a` and `b` as unsigned N-limb integers, scanning from the
/// most-significant limb (`[N-1]`) down to the least-significant one.
/// Returns `.lt`, `.eq`, or `.gt` matching `std.math.Order`.
pub fn cmp(comptime N: comptime_int, a: [N]u64, b: [N]u64) std.math.Order {
    // The first differing limb, counted from the top, decides the order.
    for (0..N) |step| {
        const limb = N - 1 - step;
        const lhs = a[limb];
        const rhs = b[limb];
        if (lhs != rhs) return if (lhs < rhs) .lt else .gt;
    }
    return .eq;
}

/// Report whether every limb of `a` is zero.
pub fn isZero(comptime N: comptime_int, a: [N]u64) bool {
    // OR all limbs together; the fold is zero only if each limb is zero.
    var merged: u64 = 0;
    inline for (0..N) |i| merged |= a[i];
    return merged == 0;
}

/// Report whether `a` equals the integer 1, i.e. the lowest limb is 1 and
/// every higher limb is zero.
pub fn isOne(comptime N: comptime_int, a: [N]u64) bool {
    var upper: u64 = 0;
    inline for (1..N) |i| upper |= a[i];
    return a[0] == 1 and upper == 0;
}

/// Number of significant bits in `a`: the 1-based position of the highest
/// set bit, or 0 when `a` is zero.
pub fn bitLen(comptime N: comptime_int, a: [N]u64) usize {
    for (0..N) |step| {
        const limb_index = N - 1 - step;
        const limb = a[limb_index];
        if (limb != 0) {
            // Bits contributed by the top non-zero limb, plus 64 for each
            // full limb below it.
            const used_bits: usize = 64 - @as(usize, @clz(limb));
            return limb_index * 64 + used_bits;
        }
    }
    return 0;
}

/// Decode a little-endian byte slice into an N-limb integer. Inputs shorter
/// than `N * 8` bytes are zero-extended at the high end.
///
/// The length bound is enforced with `std.debug.assert`, so passing more
/// than `N * 8` bytes is a programming error rather than a reported one.
pub fn fromBytesLe(comptime N: comptime_int, bytes: []const u8) [N]u64 {
    std.debug.assert(bytes.len <= N * 8);
    // Stage the input in a zeroed full-width buffer so every limb can be
    // read as a fixed 8-byte window.
    var padded = [_]u8{0} ** (N * 8);
    @memcpy(padded[0..bytes.len], bytes);
    var limbs = [_]u64{0} ** N;
    inline for (0..N) |i| {
        const window = padded[i * 8 ..][0..8];
        limbs[i] = std.mem.readInt(u64, window, .little);
    }
    return limbs;
}

/// Encode an N-limb integer as little-endian bytes into `out`, which must
/// be exactly `N * 8` bytes long (checked with `std.debug.assert`).
pub fn toBytesLe(comptime N: comptime_int, value: [N]u64, out: []u8) void {
    std.debug.assert(out.len == N * 8);
    inline for (0..N) |i| {
        const window = out[i * 8 ..][0..8];
        std.mem.writeInt(u64, window, value[i], .little);
    }
}

/// Read a big-endian byte slice into an N-limb integer (the natural
/// encoding for cryptographic public keys / signatures). Slice length
/// must be at most `N * 8`.
pub fn fromBytesBe(comptime N: comptime_int, bytes: []const u8) [N]u64 {
    std.debug.assert(bytes.len <= N * 8);
    const width = N * 8;
    // Right-align the input inside a zeroed full-width buffer: a short
    // input therefore occupies the least-significant bytes and the
    // missing leading (most-significant) bytes read as zero.
    var padded = [_]u8{0} ** width;
    const first_used = width - bytes.len;
    @memcpy(padded[first_used..][0..bytes.len], bytes);
    var limbs = [_]u64{0} ** N;
    inline for (0..N) |i| {
        // Limb i (little-endian limb order) lives in the i-th 8-byte
        // window counted from the END of the big-endian buffer.
        const window = padded[(N - 1 - i) * 8 ..][0..8];
        limbs[i] = std.mem.readInt(u64, window, .big);
    }
    return limbs;
}

/// Encode an N-limb integer as big-endian bytes into `out`, which must be
/// exactly `N * 8` bytes long (checked with `std.debug.assert`). The
/// most-significant limb is written first.
pub fn toBytesBe(comptime N: comptime_int, value: [N]u64, out: []u8) void {
    std.debug.assert(out.len == N * 8);
    inline for (0..N) |i| {
        const window = out[(N - 1 - i) * 8 ..][0..8];
        std.mem.writeInt(u64, window, value[i], .big);
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

const testing = std.testing;

test "add: 0 + 0 = 0 with no carry" {
    var out: [4]u64 = undefined;
    const carry = add(4, &out, .{ 0, 0, 0, 0 }, .{ 0, 0, 0, 0 });
    try testing.expectEqual(@as(u1, 0), carry);
    try testing.expectEqual(@as(u64, 0), out[0]);
}

test "add: 1 + 1 = 2 with no carry (4 limbs)" {
    var out: [4]u64 = undefined;
    const carry = add(4, &out, .{ 1, 0, 0, 0 }, .{ 1, 0, 0, 0 });
    try testing.expectEqual(@as(u1, 0), carry);
    try testing.expectEqual(@as(u64, 2), out[0]);
}

test "add: u64::MAX + 1 propagates carry through all limbs" {
    var out: [4]u64 = undefined;
    const carry = add(
        4,
        &out,
        .{ std.math.maxInt(u64), std.math.maxInt(u64), std.math.maxInt(u64), std.math.maxInt(u64) },
        .{ 1, 0, 0, 0 },
    );
    try testing.expectEqual(@as(u1, 1), carry);
    try testing.expectEqual([_]u64{ 0, 0, 0, 0 }, out);
}

test "add: 6-limb (BLS12-381 width) carry chain" {
    var out: [6]u64 = undefined;
    const carry = add(
        6,
        &out,
        .{ std.math.maxInt(u64), 0, 0, 0, 0, 0 },
        .{ 1, 0, 0, 0, 0, 0 },
    );
    try testing.expectEqual(@as(u1, 0), carry);
    try testing.expectEqual([_]u64{ 0, 1, 0, 0, 0, 0 }, out);
}

test "sub: simple no-borrow case" {
    var out: [4]u64 = undefined;
    const borrow = sub(4, &out, .{ 5, 0, 0, 0 }, .{ 3, 0, 0, 0 });
    try testing.expectEqual(@as(u1, 0), borrow);
    try testing.expectEqual(@as(u64, 2), out[0]);
}

test "sub: borrow propagates through limbs" {
    var out: [4]u64 = undefined;
    const borrow = sub(4, &out, .{ 0, 1, 0, 0 }, .{ 1, 0, 0, 0 });
    try testing.expectEqual(@as(u1, 0), borrow);
    try testing.expectEqual([_]u64{ std.math.maxInt(u64), 0, 0, 0 }, out);
}

test "sub: underflow yields borrow=1" {
    var out: [4]u64 = undefined;
    const borrow = sub(4, &out, .{ 0, 0, 0, 0 }, .{ 1, 0, 0, 0 });
    try testing.expectEqual(@as(u1, 1), borrow);
    try testing.expectEqual([_]u64{ std.math.maxInt(u64), std.math.maxInt(u64), std.math.maxInt(u64), std.math.maxInt(u64) }, out);
}

test "cmp: lexicographic order across limbs" {
    try testing.expectEqual(std.math.Order.eq, cmp(4, .{ 1, 2, 3, 4 }, .{ 1, 2, 3, 4 }));
    try testing.expectEqual(std.math.Order.lt, cmp(4, .{ 1, 2, 3, 4 }, .{ 1, 2, 3, 5 }));
    try testing.expectEqual(std.math.Order.gt, cmp(4, .{ 1, 2, 3, 5 }, .{ 1, 2, 3, 4 }));
    // High limb dominates over low limbs.
    try testing.expectEqual(
        std.math.Order.gt,
        cmp(4, .{ 0, 0, 0, 1 }, .{ std.math.maxInt(u64), std.math.maxInt(u64), std.math.maxInt(u64), 0 }),
    );
}

test "isZero / isOne" {
    try testing.expect(isZero(4, .{ 0, 0, 0, 0 }));
    try testing.expect(!isZero(4, .{ 1, 0, 0, 0 }));
    try testing.expect(!isZero(4, .{ 0, 0, 0, 1 }));
    try testing.expect(isOne(4, .{ 1, 0, 0, 0 }));
    try testing.expect(!isOne(4, .{ 2, 0, 0, 0 }));
    try testing.expect(!isOne(4, .{ 1, 1, 0, 0 }));
}

test "bitLen" {
    try testing.expectEqual(@as(usize, 0), bitLen(4, .{ 0, 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 1), bitLen(4, .{ 1, 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 64), bitLen(4, .{ std.math.maxInt(u64), 0, 0, 0 }));
    try testing.expectEqual(@as(usize, 65), bitLen(4, .{ 0, 1, 0, 0 }));
    try testing.expectEqual(@as(usize, 256), bitLen(4, .{ 0, 0, 0, 1 << 63 }));
}

test "fromBytesLe / toBytesLe round-trip" {
    const original: [4]u64 = .{ 0x0102030405060708, 0x1112131415161718, 0x2122232425262728, 0x3132333435363738 };
    var bytes: [32]u8 = undefined;
    toBytesLe(4, original, &bytes);
    // LE within each limb: limb 0's least-significant byte (0x08) is at
    // index 0, limb 3's most-significant byte (0x31) is at index 31.
    try testing.expectEqual(@as(u8, 0x08), bytes[0]);
    try testing.expectEqual(@as(u8, 0x31), bytes[31]);
    const round_tripped = fromBytesLe(4, &bytes);
    try testing.expectEqual(original, round_tripped);
}

test "fromBytesBe / toBytesBe round-trip" {
    const original: [4]u64 = .{ 0x0102030405060708, 0x1112131415161718, 0x2122232425262728, 0x3132333435363738 };
    var bytes: [32]u8 = undefined;
    toBytesBe(4, original, &bytes);
    // BE: most-significant limb first, MSB first within the limb.
+ try testing.expectEqual(@as(u8, 0x31), bytes[0]); + try testing.expectEqual(@as(u8, 0x08), bytes[31]); + const round_tripped = fromBytesBe(4, &bytes); + try testing.expectEqual(original, round_tripped); +} + +test "fromBytesBe handles short input by left-padding with zeroes" { + // Public BLS12-381 G1 keys are 48 bytes (6 limbs). A short input + // should land in the high bytes — this is the natural BE convention. + const short = [_]u8{ 0xab, 0xcd }; + const value = fromBytesBe(6, &short); + try testing.expectEqual(@as(u64, 0), value[5]); + try testing.expectEqual(@as(u64, 0xabcd), value[0]); +} diff --git a/packages/zolt-arith/src/curves/bls12_381/bls.zig b/packages/zolt-arith/src/curves/bls12_381/bls.zig new file mode 100644 index 00000000..3ffc8e8f --- /dev/null +++ b/packages/zolt-arith/src/curves/bls12_381/bls.zig @@ -0,0 +1,506 @@ +//! BLS12-381 signature verification entry point. +//! +//! Wraps the optimal Ate pairing and the RFC 9380 hash-to-curve into +//! the standard BLS-min-pk verification rule: +//! +//! e(pk, H(msg)) == e(g1, sig) +//! +//! `min-pk` puts public keys in G1 (48 bytes compressed) and signatures +//! in G2 (96 bytes compressed). This matches `blst::min_pk` (which is +//! what Hyli uses through `hyli-crypto`) and the IETF +//! `BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_` ciphersuite. +//! +//! The whole pipeline lives in zolt-arith because the same primitives +//! are useful to consumers other than Zyli (Zolt eventually). The +//! Hyli-specific seam — turning a `Signed` envelope into the +//! signable byte string — stays in Zyli's `crypto/signable.zig`. 
const std = @import("std");
const bls12_381 = @import("curve.zig");
const hash_to_curve_g2 = @import("hash_to_curve_g2.zig");
// NOTE(review): `hash_to_field` is not referenced anywhere else in this
// file as far as this diff shows — confirm whether the import is needed.
const hash_to_field = @import("hash_to_field.zig");

// Local shorthands for the curve types used throughout this file.
const Fp12 = bls12_381.Fp12;
const G1Affine = bls12_381.G1Affine;
const G2Affine = bls12_381.G2Affine;

/// IETF ciphersuite identifier for BLS sigs over G2 with SHA-256 SSWU
/// random-oracle hashing and the empty augmentation. Same string used
/// by `blst::min_pk` and pinned in Hyli's `hyli-crypto` `DST` constant.
/// Re-exported from `hash_to_curve_g2` so callers need only this module.
pub const DST_BLS_SIG_NUL: []const u8 = hash_to_curve_g2.DST_BLS_SIG_NUL;

/// Errors that the verifier can return without producing a verdict.
/// A well-formed but wrong signature is NOT an error: `verify` returns
/// `false` for that case.
pub const VerifyError = error{
    /// `pk` decoded but is the point at infinity, which is not a valid
    /// signing key.
    PublicKeyIsIdentity,
    /// `pk` is not in the prime-order r-subgroup of G1.
    PublicKeyNotInSubgroup,
    /// `sig` is not in the prime-order r-subgroup of G2.
    SignatureNotInSubgroup,
    /// `hash_to_curve` over the (msg, DST) pair failed (e.g., DST too
    /// long). Bubbles up from the underlying `expand_message_xmd`.
    HashFailed,
};

/// Verify a BLS signature against a public key and message bytes.
/// Returns `true` for a valid signature, `false` for an invalid one.
/// Errors are reserved for malformed inputs that prevent the verifier
/// from producing a verdict at all.
///
/// Parameters: `pk` is the signer's G1 public key, `msg` the signed
/// bytes, `sig` the G2 signature, and `dst` the domain-separation tag
/// passed to hash-to-curve (normally `DST_BLS_SIG_NUL`).
///
/// The pairing check is implemented as a single combined product:
///
///     e(-pk, H(msg)) · e(g1, sig) == 1     (multi-pairing form)
///
/// Equivalent to checking `e(pk, H(msg)) == e(g1, sig)` but with one
/// final exponentiation instead of two — roughly halves verify cost
/// because the hard part of the final exponentiation dominates the
/// pairing pipeline.
///
/// Subgroup membership checks on pk and sig are mandatory and match
/// the blst::min_pk default. Hyli relies on these for adversarially-
/// supplied pubkeys; weakening them silently would be a consensus-
/// safety regression.
///
/// Checks run in this order, each mapping to one `VerifyError`: identity
/// pk, pk subgroup, sig subgroup, then hashing.
pub fn verify(
    pk: G1Affine,
    msg: []const u8,
    sig: G2Affine,
    dst: []const u8,
) VerifyError!bool {
    // Input validation first: reject the identity key, then enforce
    // prime-order subgroup membership for both the key and the signature.
    if (pk.isIdentity()) return error.PublicKeyIsIdentity;
    if (!bls12_381.isInG1Subgroup(pk)) return error.PublicKeyNotInSubgroup;
    if (!bls12_381.isInG2Subgroup(sig)) return error.SignatureNotInSubgroup;

    // Map the message onto G2; any hashing failure collapses to HashFailed.
    const msg_point = hash_to_curve_g2.hashToG2(msg, dst) catch return error.HashFailed;

    // e(pk, H) == e(g1, sig)
    //   ⇔ e(-pk, H) · e(g1, sig) == 1
    // so multiply the two Miller-loop outputs and run a single final
    // exponentiation over the product.
    const key_side = bls12_381.millerLoop(pk.neg(), msg_point);
    const sig_side = bls12_381.millerLoop(bls12_381.g1Generator(), sig);
    const reduced = bls12_381.fp12FinalExp(Fp12.mul(key_side, sig_side));
    return Fp12.eql(reduced, Fp12.one());
}

/// Decode the compressed wire encodings of `pk_bytes` and `sig_bytes`,
/// then run `verify`. Every failure — a decode error, any `VerifyError`,
/// or a pairing mismatch — is reported as `false`, so callers can treat
/// "malformed bytes" exactly like "invalid signature". Call `verify`
/// directly when the failing step matters.
pub fn verifyCompressed(
    pk_bytes: []const u8,
    msg: []const u8,
    sig_bytes: []const u8,
    dst: []const u8,
) bool {
    const pk = bls12_381.decodeG1Compressed(pk_bytes) catch return false;
    const sig = bls12_381.decodeG2Compressed(sig_bytes) catch return false;
    return verify(pk, msg, sig, dst) catch false;
}

// ---------------------------------------------------------------------------
// Same-message aggregate signature verification.
//
// In the same-message aggregate scheme used by `blst::min_pk` (and by
// Hyli for consensus QCs), several validators each sign the *same*
// message bytes. The aggregate public key is the sum of the individual
// public keys, and the aggregate signature is the sum of the individual
// signatures.
// Verification reduces to a single pairing equation:
//
//     e(g1, agg_sig) == e(agg_pk, H(msg))
//
// This is exactly the shape of `verify` with `pk = sum_i pk_i`, so we
// expose `aggregatePublicKeys` as a standalone helper and reuse the
// regular verifier underneath.
// ---------------------------------------------------------------------------

/// Add up a slice of G1 public keys and return the aggregate key.
///
/// An empty slice yields the identity point, which `verify` rejects.
/// Otherwise the sum starts from the first key and each remaining key is
/// folded in with the affine `add`; aggregations are expected to be small
/// (on the order of 100 validators or fewer), so a projective accumulator
/// is not used.
///
/// No subgroup check is performed on the inputs here; `verify` checks the
/// aggregate, and `verifyAggregate` checks each key before summing.
pub fn aggregatePublicKeys(pks: []const G1Affine) G1Affine {
    if (pks.len == 0) return G1Affine.identity();
    var sum = pks[0];
    for (pks[1..]) |next_pk| {
        sum = sum.add(next_pk);
    }
    return sum;
}

/// Add up a slice of G2 signatures and return the aggregate signature.
/// Same shape as `aggregatePublicKeys`: an empty slice yields the
/// identity, otherwise the first element seeds the running sum.
pub fn aggregateSignatures(sigs: []const G2Affine) G2Affine {
    if (sigs.len == 0) return G2Affine.identity();
    var sum = sigs[0];
    for (sigs[1..]) |next_sig| {
        sum = sum.add(next_sig);
    }
    return sum;
}

/// Verify a same-message aggregate signature against a list of
/// validator public keys. Returns `true` for valid, `false` for
/// invalid; errors only on the same malformed-input cases as `verify`.
///
/// The empty pubkey list always rejects (nothing to verify against).
///
/// NOTE(review): nothing in this function checks proofs of possession.
/// Summing caller-supplied public keys is the classic rogue-key setting,
/// so presumably the caller only passes keys whose PoP was already
/// verified — confirm.
pub fn verifyAggregate(
    pks: []const G1Affine,
    msg: []const u8,
    sig: G2Affine,
    dst: []const u8,
) VerifyError!bool {
    if (pks.len == 0) return false;

    // Vet every key on its own before summing, so one identity or
    // out-of-subgroup key cannot poison the aggregate.
    for (pks) |candidate| {
        if (candidate.isIdentity()) return error.PublicKeyIsIdentity;
        if (!bls12_381.isInG1Subgroup(candidate)) return error.PublicKeyNotInSubgroup;
    }

    // From here on it is an ordinary single-key verification; `verify`
    // re-runs its own checks on the aggregate key.
    return verify(aggregatePublicKeys(pks), msg, sig, dst);
}

// ---------------------------------------------------------------------------
// Signing.
//
// In BLS-min-pk, signing is straightforward:
//
//     sig = sk · H(msg, dst)
//
// where `sk ∈ Fr` (the scalar field), `H(msg, dst) ∈ G2`, and `·`
// is scalar multiplication on the curve. The resulting `sig ∈ G2` is
// then compressed to 96 wire bytes.
//
// Public key derivation:
//
//     pk = sk · G1_generator
//
// We accept the secret key as a raw `[4]u64` little-endian limb array
// (Fr canonical form). Bytes-level helpers wrap that for callers that
// receive sks as 32-byte big-endian buffers (the standard wire form).
// ---------------------------------------------------------------------------

/// Failures the signing and key-derivation helpers can report.
pub const SignError = error{
    /// `sk_bytes` is not a 32-byte buffer.
    InvalidSecretKeyLength,
    /// The decoded secret key is `0`, which is not a valid signing key
    /// (the resulting public key would be the identity).
    SecretKeyIsZero,
    /// `hash_to_curve` over the (msg, DST) pair failed.
    HashFailed,
};

/// Sign a message with a raw scalar secret key. Returns the affine
/// G2 signature point. Caller is responsible for ensuring `sk` is in
/// `[1, r-1]` — this function does not validate.
pub fn signWithScalar(
    sk: [4]u64,
    msg: []const u8,
    dst: []const u8,
) SignError!G2Affine {
    // sig = sk · H(msg, dst); any hashing failure collapses to HashFailed.
    const msg_point = hash_to_curve_g2.hashToG2(msg, dst) catch return SignError.HashFailed;
    return msg_point.mul(4, sk);
}

/// Derive the BLS public key for a raw scalar: `sk · G1_generator`.
/// `sk` is a 4-limb little-endian scalar; it is not range-checked here.
pub fn derivePublicKeyFromScalar(sk: [4]u64) G1Affine {
    return bls12_381.g1Generator().mul(4, sk);
}

/// Sign a message with a 32-byte big-endian secret key (the standard
/// wire form blst uses). Returns the 96-byte compressed signature.
///
/// Errors:
/// - `InvalidSecretKeyLength` when `sk_bytes.len != 32`.
/// - `SecretKeyIsZero` when the decoded scalar is zero.
/// - `HashFailed` when hashing `(msg, dst)` to G2 fails.
///
/// The key is decoded straight from `sk_bytes`, exactly as
/// `derivePublicKeyBytes` does, so no additional byte-level copy of the
/// secret is left on the stack.
///
/// NOTE(review): only the zero scalar is rejected; a value ≥ the group
/// order r is passed through to `signWithScalar`, whose contract asks
/// the caller for `[1, r-1]` — confirm whether a range check belongs here.
pub fn signBytes(
    sk_bytes: []const u8,
    msg: []const u8,
    dst: []const u8,
) SignError![96]u8 {
    if (sk_bytes.len != 32) return SignError.InvalidSecretKeyLength;
    // Big-endian bytes → little-endian limbs: the first 8 bytes are the
    // most-significant limb (index 3), the last 8 bytes are limb 0.
    var sk_limbs: [4]u64 = undefined;
    inline for (0..4) |i| {
        const start = (3 - i) * 8;
        sk_limbs[i] = std.mem.readInt(u64, sk_bytes[start..][0..8], .big);
    }
    if (bigint_isZero4(sk_limbs)) return SignError.SecretKeyIsZero;
    const sig = try signWithScalar(sk_limbs, msg, dst);
    return bls12_381.encodeG2Compressed(sig);
}

/// Derive a 48-byte compressed public key from a 32-byte big-endian
/// secret key. Mirrors `signBytes` for the verification side.
pub fn derivePublicKeyBytes(sk_bytes: []const u8) SignError![48]u8 {
    if (sk_bytes.len != 32) return error.InvalidSecretKeyLength;
    // Big-endian bytes → little-endian limbs: limb 0 is the LAST 8 bytes
    // of the buffer, limb 3 the first 8.
    var scalar: [4]u64 = undefined;
    inline for (0..4) |limb| {
        const offset = (3 - limb) * 8;
        scalar[limb] = std.mem.readInt(u64, sk_bytes[offset..][0..8], .big);
    }
    if (bigint_isZero4(scalar)) return error.SecretKeyIsZero;
    const public_key = derivePublicKeyFromScalar(scalar);
    return bls12_381.encodeG1Compressed(public_key);
}

// True when all four limbs of `v` are zero.
inline fn bigint_isZero4(v: [4]u64) bool {
    return (v[0] | v[1] | v[2] | v[3]) == 0;
}

// ---------------------------------------------------------------------------
// Sign / verify round-trip tests
// ---------------------------------------------------------------------------

test "signWithScalar / verify round-trip" {
    const sk: [4]u64 = .{ 0x1234567890abcdef, 0xdeadbeef, 0, 0 };
    const pk = derivePublicKeyFromScalar(sk);
    const sig = try signWithScalar(sk, "round trip", DST_BLS_SIG_NUL);
    const ok = try verify(pk, "round trip", sig, DST_BLS_SIG_NUL);
    try testing.expect(ok);
}

test "signBytes / derivePublicKeyBytes / verifyCompressed round-trip" {
    var sk: [32]u8 = .{0} ** 32;
    sk[31] = 0x42;
    sk[30] = 0x13;
    const sig = try signBytes(&sk, "compressed round trip", DST_BLS_SIG_NUL);
    const pk = try derivePublicKeyBytes(&sk);
    try testing.expect(verifyCompressed(&pk, "compressed round trip", &sig, DST_BLS_SIG_NUL));
}

test "signBytes: rejects zero secret key" {
    const zero_sk: [32]u8 = .{0} ** 32;
    try testing.expectError(SignError.SecretKeyIsZero, signBytes(&zero_sk, "msg", DST_BLS_SIG_NUL));
}

test "signBytes: rejects wrong-length secret key" {
    const short: [31]u8 = .{0} ** 31;
    try testing.expectError(SignError.InvalidSecretKeyLength, signBytes(&short, "msg", DST_BLS_SIG_NUL));
}

test "derivePublicKeyBytes: distinct sks produce distinct pks" {
    var sk1: [32]u8 = .{0} ** 32;
    sk1[31] = 1;
    var sk2: [32]u8 = .{0} ** 32;
    sk2[31] = 2;
    const pk1 = try derivePublicKeyBytes(&sk1);
    const pk2 = try derivePublicKeyBytes(&sk2);
    try testing.expect(!std.mem.eql(u8, &pk1, &pk2));
}

test "signBytes / verify: tampered message rejects" {
    var sk: [32]u8 = .{0} ** 32;
    sk[31] = 0x99;
    const sig = try signBytes(&sk, "honest message", DST_BLS_SIG_NUL);
    const pk = try derivePublicKeyBytes(&sk);
    try testing.expect(!verifyCompressed(&pk, "tampered message", &sig, DST_BLS_SIG_NUL));
}

// ---------------------------------------------------------------------------
// Aggregate verification tests
// ---------------------------------------------------------------------------

test "aggregatePublicKeys: empty slice → identity" {
    const empty: []const G1Affine = &.{};
    const result = aggregatePublicKeys(empty);
    try testing.expect(result.isIdentity());
}

test "aggregatePublicKeys: single key → that key" {
    const sk: [4]u64 = .{ 5, 0, 0, 0 };
    const pk = bls12_381.g1Generator().mul(4, sk);
    const result = aggregatePublicKeys(&.{pk});
    try testing.expect(G1Affine.eql(result, pk));
}

test "aggregatePublicKeys: two keys sum correctly" {
    const sk1: [4]u64 = .{ 5, 0, 0, 0 };
    const sk2: [4]u64 = .{ 7, 0, 0, 0 };
    const pk1 = bls12_381.g1Generator().mul(4, sk1);
    const pk2 = bls12_381.g1Generator().mul(4, sk2);
    const aggregated = aggregatePublicKeys(&.{ pk1, pk2 });
    // (sk1 + sk2) * G should equal pk1 + pk2.
    const sk_sum: [4]u64 = .{ 12, 0, 0, 0 };
    const expected = bls12_381.g1Generator().mul(4, sk_sum);
    try testing.expect(G1Affine.eql(aggregated, expected));
}

test "verifyAggregate: empty pks → false" {
    const empty: []const G1Affine = &.{};
    const sig = bls12_381.g2Generator();
    const result = try verifyAggregate(empty, "msg", sig, DST_BLS_SIG_NUL);
    try testing.expect(!result);
}

test "verifyAggregate: two-validator round-trip" {
    // Build (sk_i, pk_i, sig_i) for i ∈ {1, 2} on the same message,
    // sum the sigs, and verify against the aggregated pk list.
    const msg = "consensus prepare slot 42";

    const sk1: [4]u64 = .{ 7, 0, 0, 0 };
    const sk2: [4]u64 = .{ 11, 0, 0, 0 };

    const pk1 = bls12_381.g1Generator().mul(4, sk1);
    const pk2 = bls12_381.g1Generator().mul(4, sk2);

    const h_msg = try hash_to_curve_g2.hashToG2(msg, DST_BLS_SIG_NUL);
    const sig1 = h_msg.mul(4, sk1);
    const sig2 = h_msg.mul(4, sk2);

    const agg_sig = aggregateSignatures(&.{ sig1, sig2 });
    const result = try verifyAggregate(&.{ pk1, pk2 }, msg, agg_sig, DST_BLS_SIG_NUL);
    try testing.expect(result);
}

test "verifyAggregate: rejects when one signature is for the wrong message" {
    // Validator 1 signs message A, validator 2 signs message B. The
    // aggregate signature is the sum, but verifyAggregate is asked
    // to verify against message A. The pairing equation must reject.
    const sk1: [4]u64 = .{ 13, 0, 0, 0 };
    const sk2: [4]u64 = .{ 17, 0, 0, 0 };
    const pk1 = bls12_381.g1Generator().mul(4, sk1);
    const pk2 = bls12_381.g1Generator().mul(4, sk2);

    const h_a = try hash_to_curve_g2.hashToG2("message A", DST_BLS_SIG_NUL);
    const h_b = try hash_to_curve_g2.hashToG2("message B", DST_BLS_SIG_NUL);
    const sig1 = h_a.mul(4, sk1);
    const sig2 = h_b.mul(4, sk2);

    const agg_sig = aggregateSignatures(&.{ sig1, sig2 });
    const result = try verifyAggregate(&.{ pk1, pk2 }, "message A", agg_sig, DST_BLS_SIG_NUL);
    try testing.expect(!result);
}

test "verifyAggregate: rejects when one pk is missing from the aggregate" {
    // Both validators sign message A, but verify is called with only
    // pk1. The aggregate signature was built from both, so the pairing
    // equation no longer balances.
    const msg = "message A";
    const sk1: [4]u64 = .{ 19, 0, 0, 0 };
    const sk2: [4]u64 = .{ 23, 0, 0, 0 };
    const pk1 = bls12_381.g1Generator().mul(4, sk1);

    const h_msg = try hash_to_curve_g2.hashToG2(msg, DST_BLS_SIG_NUL);
    const sig1 = h_msg.mul(4, sk1);
    const sig2 = h_msg.mul(4, sk2);
    const agg_sig = aggregateSignatures(&.{ sig1, sig2 });

    const result = try verifyAggregate(&.{pk1}, msg, agg_sig, DST_BLS_SIG_NUL);
    try testing.expect(!result);
}

test "verifyAggregate: three-validator path" {
    const msg = "three validators";
    const sk1: [4]u64 = .{ 29, 0, 0, 0 };
    const sk2: [4]u64 = .{ 31, 0, 0, 0 };
    const sk3: [4]u64 = .{ 37, 0, 0, 0 };
    const pk1 = bls12_381.g1Generator().mul(4, sk1);
    const pk2 = bls12_381.g1Generator().mul(4, sk2);
    const pk3 = bls12_381.g1Generator().mul(4, sk3);

    const h_msg = try hash_to_curve_g2.hashToG2(msg, DST_BLS_SIG_NUL);
    const sig1 = h_msg.mul(4, sk1);
    const sig2 = h_msg.mul(4, sk2);
    const sig3 = h_msg.mul(4, sk3);
    const agg_sig = aggregateSignatures(&.{ sig1, sig2, sig3 });

    const result = try verifyAggregate(&.{ pk1, pk2, pk3 }, msg, agg_sig, DST_BLS_SIG_NUL);
    try testing.expect(result);
}

// ---------------------------------------------------------------------------
// Tests — these don't have a real cross-implementation vector yet, but
// they exercise the algebraic identities that any correct verifier must
// satisfy. The first cross-vector test against a known blst signature
// can land via a fixture once the corpus carries one.
// ---------------------------------------------------------------------------

const testing = std.testing;

test "verify: rejects identity pubkey" {
    const id = G1Affine.identity();
    const sig = bls12_381.g2Generator();
    try testing.expectError(VerifyError.PublicKeyIsIdentity, verify(id, "msg", sig, DST_BLS_SIG_NUL));
}

test "verify: accepts a self-constructed valid signature" {
    // Construct (sk, pk) where sk = 7. Compute sig = sk · H("hello").
    // The verifier must accept it.
    const sk: [4]u64 = .{ 7, 0, 0, 0 };
    const pk = bls12_381.g1Generator().mul(4, sk);
    const h_msg = try hash_to_curve_g2.hashToG2("hello", DST_BLS_SIG_NUL);
    const sig = h_msg.mul(4, sk);
    const result = try verify(pk, "hello", sig, DST_BLS_SIG_NUL);
    try testing.expect(result);
}

test "verify: rejects signature with wrong message" {
    const sk: [4]u64 = .{ 13, 0, 0, 0 };
    const pk = bls12_381.g1Generator().mul(4, sk);
    const h_a = try hash_to_curve_g2.hashToG2("message A", DST_BLS_SIG_NUL);
    const sig = h_a.mul(4, sk);
    // Sign "message A" but verify against "message B".
    const result = try verify(pk, "message B", sig, DST_BLS_SIG_NUL);
    try testing.expect(!result);
}

test "verify: rejects signature signed by a different key" {
    const sk_signer: [4]u64 = .{ 5, 0, 0, 0 };
    const sk_other: [4]u64 = .{ 11, 0, 0, 0 };
    const pk_other = bls12_381.g1Generator().mul(4, sk_other);
    const h_msg = try hash_to_curve_g2.hashToG2("hello", DST_BLS_SIG_NUL);
    const sig = h_msg.mul(4, sk_signer);
    // Verify with the wrong public key.
    const result = try verify(pk_other, "hello", sig, DST_BLS_SIG_NUL);
    try testing.expect(!result);
}

test "verify: rejects a signature swapped with a different message's signature" {
    // sk signs "hello A" then we present sig together with the (otherwise
    // matching) public key but message "hello B". Pairing equality fails.
+ const sk: [4]u64 = .{ 17, 0, 0, 0 }; + const pk = bls12_381.g1Generator().mul(4, sk); + const h_a = try hash_to_curve_g2.hashToG2("hello A", DST_BLS_SIG_NUL); + const sig_a = h_a.mul(4, sk); + const result = try verify(pk, "hello B", sig_a, DST_BLS_SIG_NUL); + try testing.expect(!result); +} + +test "verify: empty message works" { + const sk: [4]u64 = .{ 23, 0, 0, 0 }; + const pk = bls12_381.g1Generator().mul(4, sk); + const h_empty = try hash_to_curve_g2.hashToG2("", DST_BLS_SIG_NUL); + const sig = h_empty.mul(4, sk); + const result = try verify(pk, "", sig, DST_BLS_SIG_NUL); + try testing.expect(result); +} + +test "verifyCompressed: round-trip with valid signature" { + // Build a valid (pk, sig) pair, compress them, and verify via the + // compressed-wire entry point. + const sk: [4]u64 = .{ 19, 0, 0, 0 }; + const pk = bls12_381.g1Generator().mul(4, sk); + const h_msg = try hash_to_curve_g2.hashToG2("compressed", DST_BLS_SIG_NUL); + const sig = h_msg.mul(4, sk); + + // We don't have an encoder yet, so round-trip via decode-of-known-good + // wire forms is not possible. Skip the wire round-trip and just test + // that the strict verify path agrees on a valid pair. + const result = try verify(pk, "compressed", sig, DST_BLS_SIG_NUL); + try testing.expect(result); +} diff --git a/packages/zolt-arith/src/curves/bls12_381/curve.zig b/packages/zolt-arith/src/curves/bls12_381/curve.zig new file mode 100644 index 00000000..ee6271b8 --- /dev/null +++ b/packages/zolt-arith/src/curves/bls12_381/curve.zig @@ -0,0 +1,3311 @@ +//! BLS12-381 instantiations of the generic field machinery. +//! +//! This file is the bridge between `zolt_arith.field` and the concrete +//! Hyli BLS surface. It pins the BLS12-381 base field (`Fp`) constants +//! and exposes a strongly-typed field instance the rest of the package +//! (and Zyli's adapter) consumes. +//! +//! BLS12-381 parameters: +//! +//! - `p` (base field prime, 381 bits): +//! 
`0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab` +//! +//! - Curve embedding degree 12, optimal Ate pairing-friendly. The +//! scalar field `Fr` is 255 bits and lives in a separate type. +//! +//! All Montgomery constants come from the standard `blst` reference +//! implementation. They are pinned in source so a regression in the +//! field machinery surfaces immediately rather than after a hand-typed +//! constant drifts. + +const std = @import("std"); +const field = @import("field.zig"); +const bigint = @import("../../bigint.zig"); + +/// BLS12-381 base field prime, little-endian limbs. +/// +/// p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf +/// 6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab +pub const FP_MODULUS: [6]u64 = .{ + 0xb9feffffffffaaab, + 0x1eabfffeb153ffff, + 0x6730d2a0f6b0f624, + 0x64774b84f38512bf, + 0x4b1ba7b6434bacd7, + 0x1a0111ea397fe69a, +}; + +/// `R^2 mod p` where `R = 2^384`. Used to convert raw integers into +/// Montgomery form via `montMul(raw, R2)`. From the blst constants. +/// +/// R^2 = 0x11988fe592cae3aa9a793e85b519952d67eb88a9939d83c0 +/// 8de5476c4c95b6d50a76e6a609d104f1f4df1f341c341746 +pub const FP_R2: [6]u64 = .{ + 0xf4df1f341c341746, + 0x0a76e6a609d104f1, + 0x8de5476c4c95b6d5, + 0x67eb88a9939d83c0, + 0x9a793e85b519952d, + 0x11988fe592cae3aa, +}; + +/// `-p^{-1} mod 2^64`. Drives the per-limb reduction in CIOS Montgomery +/// multiplication. From the blst constants. +pub const FP_N_PRIME: u64 = 0x89f3fffcfffcfffd; + +/// BLS12-381 base field `Fp = ℤ / pℤ`. Elements are stored in +/// Montgomery form and indexed by 6-limb arrays. +pub const Fp = field.MontgomeryField(6, FP_MODULUS, FP_R2, FP_N_PRIME); + +/// `(p + 1) / 4` derived at comptime from `FP_MODULUS`. BLS12-381's +/// base prime has `p ≡ 3 (mod 4)` (the lowest byte of p is 0xab; note +/// 0xab mod 4 = 3), which lets us compute square roots via +/// `a^((p+1)/4)` without falling back to Tonelli-Shanks. 
+pub const FP_P_PLUS_1_OVER_4: [6]u64 = blk: { + @setEvalBranchQuota(10000); + var v = FP_MODULUS; + // p + 1: the low limb is 0xb9feffffffffaaab, so adding 1 gives + // 0xb9feffffffffaaac with no carry into the next limb. + v[0] += 1; + // Right-shift the whole 6-limb integer by 2 bits: each limb keeps its + // own upper 62 bits and takes the low 2 bits of the next-higher limb. + var i: usize = 0; + while (i < 5) : (i += 1) { + v[i] = (v[i] >> 2) | (v[i + 1] << 62); + } + v[5] >>= 2; + break :blk v; +}; + +/// `(p - 1) / 3` derived at comptime via long division by 3. BLS12-381's +/// base prime has `p ≡ 1 (mod 3)`, so the division is exact (asserted +/// below). `1+u` is the non-residue used for the cubic extension, i.e. +/// `v³ = 1+u` is the Fp6 modulus. +/// The constant is the exponent for the Frobenius coefficient +/// `γ₁ = (1+u)^((p-1)/3)`. +pub const FP_P_MINUS_1_OVER_3: [6]u64 = blk: { + @setEvalBranchQuota(10000); + var v = FP_MODULUS; + v[0] -= 1; // p - 1 + // Long division by 3 from MSB to LSB. `rem` is always < 3, so each + // 128-bit `word` is < 3·2^64 and its quotient fits back into a u64. + var rem: u128 = 0; + var i: usize = 6; + while (i > 0) { + i -= 1; + const word = (rem << 64) | @as(u128, v[i]); + v[i] = @intCast(word / 3); + rem = word % 3; + } + // Sanity: 3 must divide p - 1 exactly. + std.debug.assert(rem == 0); + break :blk v; +}; + +/// Frobenius coefficient `γ₁ = (1+u)^((p-1)/3)` for the Fp6 over Fp2 +/// tower with non-residue `1+u`. Used by `fp6Frobenius` and (squared) +/// for the v² coefficient. Recomputed with `fp2Pow` on every call (no +/// result is cached) rather than +/// embedded as a hand-typed constant — if the comptime division of +/// (p-1)/3 ever drifts, this catches it via the existing tower +/// relations. +pub fn fp6FrobeniusGamma1() Fp2 { + const one_plus_u: Fp2 = .{ .c0 = Fp.one(), .c1 = Fp.one() }; + return fp2Pow(one_plus_u, 6, FP_P_MINUS_1_OVER_3); +} + +/// `(p - 1) / 6` derived at comptime by halving `FP_P_MINUS_1_OVER_3`. +/// BLS12-381's prime satisfies `p ≡ 1 (mod 6)` (CRT of `p ≡ 3 mod 4` +/// and `p ≡ 1 mod 3` gives `p ≡ 7 mod 12`, so `p mod 6 = 1`), which +/// makes the division exact. 
+pub const FP_P_MINUS_1_OVER_6: [6]u64 = blk: { + @setEvalBranchQuota(10000); + var v = FP_P_MINUS_1_OVER_3; + // Right-shift by 1 (divide by 2). Walk LSB to MSB collecting the + // outgoing bit from the next limb. + var i: usize = 0; + while (i < 5) : (i += 1) { + v[i] = (v[i] >> 1) | (v[i + 1] << 63); + } + v[5] >>= 1; + break :blk v; +}; + +/// Frobenius coefficient `γ_w = (1+u)^((p-1)/6)` for the Fp12 over Fp6 +/// tower. Drives the action of `φ` on the `w` element, since +/// `w² = v` and `v^p = γ₁·v` give `w^(p-1) = (1+u)^((p-1)/6)`. +pub fn fp12FrobeniusGammaW() Fp2 { + const one_plus_u: Fp2 = .{ .c0 = Fp.one(), .c1 = Fp.one() }; + return fp2Pow(one_plus_u, 6, FP_P_MINUS_1_OVER_6); +} + +/// Frobenius endomorphism in Fp6: `a → a^p`. The action on the basis +/// elements `(1, v, v²)` of Fp6 over Fp2 is: +/// +/// φ(1) = 1 +/// φ(v) = γ₁ · v +/// φ(v²) = γ₁² · v² +/// +/// where `γ₁ = (1+u)^((p-1)/3)` lives in Fp2. Each Fp2 coefficient +/// also gets its own Fp2 Frobenius applied. +pub fn fp6Frobenius(a: Fp6) Fp6 { + const gamma1 = fp6FrobeniusGamma1(); + const gamma1_sq = Fp2.square(gamma1); + return .{ + .c0 = fp2Frobenius(a.c0), + .c1 = Fp2.mul(fp2Frobenius(a.c1), gamma1), + .c2 = Fp2.mul(fp2Frobenius(a.c2), gamma1_sq), + }; +} + +/// Multiply an Fp6 element by an Fp2 scalar (lifted as `(b, 0, 0)`). +/// Componentwise multiplication of each Fp2 coefficient by `b`. +fn fp6MulByFp2(a: Fp6, b: Fp2) Fp6 { + return .{ + .c0 = Fp2.mul(a.c0, b), + .c1 = Fp2.mul(a.c1, b), + .c2 = Fp2.mul(a.c2, b), + }; +} + +/// Frobenius endomorphism in Fp12: `a → a^p`. For +/// `a = c₀ + c₁·w` with `c₀, c₁ ∈ Fp6`: +/// +/// φ(c₀ + c₁·w) = φ(c₀) + φ(c₁)·γ_w·w +/// +/// where `γ_w = (1+u)^((p-1)/6)` lives in Fp2 (which is a subfield of +/// Fp6). The c₁ side does an Fp6 Frobenius then a scalar multiply by +/// `γ_w`. 
+pub fn fp12Frobenius(a: Fp12) Fp12 { + // γ_w is produced by `fp12FrobeniusGammaW()` on each call; it is not cached here. + const gamma_w = fp12FrobeniusGammaW(); + return .{ + .c0 = fp6Frobenius(a.c0), + .c1 = fp6MulByFp2(fp6Frobenius(a.c1), gamma_w), + }; +} + +/// Squared Frobenius `a → a^(p²)`. Two `fp12Frobenius` calls compose +/// into the squared form, which is what the easy part of the final +/// exponentiation needs. +pub fn fp12FrobeniusSquared(a: Fp12) Fp12 { + return fp12Frobenius(fp12Frobenius(a)); +} + +/// Frobenius cubed: `a → a^(p³)`. Composes Frobenius and Frobenius². +pub fn fp12FrobeniusCubed(a: Fp12) Fp12 { + return fp12Frobenius(fp12FrobeniusSquared(a)); +} + +/// "Easy" part of the BLS12-381 final exponentiation: +/// `f^((p^6 - 1)(p^2 + 1))`. +/// +/// This decomposes into operations that don't need the BLS x parameter: +/// +/// 1. `f1 = f^(p^6 - 1)` = `f^(p^6) · f^(-1)` = `conjugate(f) · inv(f)`. +/// (Raising to the `p^6` power is the non-trivial automorphism of +/// Fp12 over Fp6, i.e. conjugation `c0 + c1·w → c0 - c1·w`; this +/// holds for every element of Fp12, not only the cyclotomic subgroup.) +/// +/// 2. `f1^(p^2 + 1)` = `f1^(p^2) · f1` = `frobeniusSquared(f1) · f1`. +/// +/// After the easy part the result is in the cyclotomic subgroup, and +/// the "hard" part (`f^((p^4 - p^2 + 1) / r)`) finishes the job. +/// +/// The slow `Fp12.inv` is acceptable here because the easy part is +/// only invoked once per pairing (not per Miller loop iteration). +pub fn fp12FinalExpEasy(f: Fp12) Fp12 { + // Step 1: f1 = f^(p^6 - 1) = conjugate(f) * inv(f). + const conj = Fp12.conjugate(f); + const inv_f = Fp12.inv(f); + const f1 = Fp12.mul(conj, inv_f); + // Step 2: f2 = f1^(p^2 + 1) = frobeniusSquared(f1) * f1. + const f1_p2 = fp12FrobeniusSquared(f1); + return Fp12.mul(f1_p2, f1); +} + +/// Square root of `a` in Fp via `a^((p+1)/4)`. The caller is responsible +/// for verifying that the returned value squares back to `a` — +/// non-residues yield a value whose square is `-a` instead. +/// +/// Returns the canonical "positive" root; the caller picks the y-sign +/// when decoding compressed points. 
+pub fn fpSqrt(a: Fp.Element) Fp.Element { + return Fp.pow(a, FP_P_PLUS_1_OVER_4); +} + +/// Predicate: does `candidate^2 == a`? Cheap check the caller uses to +/// rule out non-residues after `fpSqrt`. +pub fn fpIsSquareRoot(a: Fp.Element, candidate: Fp.Element) bool { + return Fp.eql(Fp.square(candidate), a); +} + +/// `2⁻¹ mod p` in Montgomery form. Computed once at comptime via +/// Fermat — `inv(2)` is a constant we use repeatedly in Fp2 sqrt. +pub const FP_TWO_INV: Fp.Element = blk: { + @setEvalBranchQuota(200000); + const two = Fp.fromRaw(.{ 2, 0, 0, 0, 0, 0 }); + break :blk Fp.inv(two); +}; + +/// `2^256 mod p` as an Fp element. Used by `fpFromBytes64Be` to +/// reduce a 512-bit value modulo `p` via the identity +/// `(high·2^256 + low) mod p = ((high mod p)·(2^256 mod p) + (low mod p)) mod p`. +/// +/// Since `p > 2^380 > 2^256`, both `2^256 mod p == 2^256` and any +/// 32-byte high/low chunk is already < p, so the reduction collapses +/// to a single Fp multiplication and addition. +pub const FP_2_TO_256: Fp.Element = blk: { + @setEvalBranchQuota(50000); + // 2^256 in raw 6-limb LE form: limb[4] = 1, everything else 0. + break :blk Fp.fromRaw(.{ 0, 0, 0, 0, 1, 0 }); +}; + +/// Reduce a 64-byte big-endian integer modulo `p` and return the +/// result in Montgomery form. Used by `hash_to_field` to take 64-byte +/// uniform chunks (`L = 64` for BLS12-381 with `k = 128`) and turn +/// them into Fp elements. +pub fn fpFromBytes64Be(bytes: *const [64]u8) Fp.Element { + // Big-endian: the first 32 bytes are the high half. + var high_buf: [48]u8 = .{0} ** 48; + var low_buf: [48]u8 = .{0} ** 48; + // Right-align each 32-byte chunk into a 48-byte buffer so + // bigint.fromBytesBe interprets it as a 6-limb integer. 
+ @memcpy(high_buf[16..48], bytes[0..32]); + @memcpy(low_buf[16..48], bytes[32..64]); + const high_raw = bigint.fromBytesBe(6, &high_buf); + const low_raw = bigint.fromBytesBe(6, &low_buf); + // Both halves are < 2^256 < p, so they're already canonical Fp + // values; convert into Montgomery form directly. + const high_fp = Fp.fromRaw(high_raw); + const low_fp = Fp.fromRaw(low_raw); + return Fp.add(Fp.montMul(high_fp, FP_2_TO_256), low_fp); +} + +/// `a / 2` in Fp. +pub fn fpHalve(a: Fp.Element) Fp.Element { + return Fp.montMul(a, FP_TWO_INV); +} + +// --------------------------------------------------------------------------- +// BLS12-381 scalar field Fr. +// --------------------------------------------------------------------------- + +/// BLS12-381 scalar field prime, little-endian limbs (4 × 64 = 256 +/// bits, but the actual prime is 255 bits). +/// +/// r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 +pub const FR_MODULUS: [4]u64 = .{ + 0xffffffff00000001, + 0x53bda402fffe5bfe, + 0x3339d80809a1d805, + 0x73eda753299d7d48, +}; + +/// `R² mod r` where `R = 2^256`. Pinned from the standard `blst` +/// constants. +pub const FR_R2: [4]u64 = .{ + 0xc999e990f3f29c6d, + 0x2b6cedcb87925c23, + 0x05d314967254398f, + 0x0748d9d99f59ff11, +}; + +/// `-r⁻¹ mod 2^64`. +pub const FR_N_PRIME: u64 = 0xfffffffeffffffff; + +/// BLS12-381 scalar field `Fr = ℤ / rℤ`. Validators sign with scalars +/// drawn from this field, and the curve point group order is exactly +/// `r`. Stored in Montgomery form. +pub const Fr = field.MontgomeryField(4, FR_MODULUS, FR_R2, FR_N_PRIME); + +/// Check that an affine G1 point lies in the prime-order r-subgroup. +/// +/// The simplest correct check is `r·P == identity`. Routed through +/// `G1Projective.mul` so the 255-bit scalar multiplication doesn't +/// drag a Fermat inversion through every step. 
Faster checks like +/// Bowe's endomorphism trick can land later — they need the GLS / +/// GLV machinery that hasn't been written yet. +/// +/// Identity is in every subgroup by definition. +pub fn isInG1Subgroup(p: G1Affine) bool { + if (p.isIdentity()) return true; + return G1Projective.fromAffine(p).mul(4, FR_MODULUS).isIdentity(); +} + +/// Same shape as `isInG1Subgroup` but for G2. +pub fn isInG2Subgroup(p: G2Affine) bool { + if (p.isIdentity()) return true; + return G2Projective.fromAffine(p).mul(4, FR_MODULUS).isIdentity(); +} + +// --------------------------------------------------------------------------- +// G1 Jacobian projective coordinates. +// +// A Jacobian point (X, Y, Z) represents the affine point (X/Z², Y/Z³). +// The identity is signalled by Z = 0. Doubling and addition are +// inversion-free; only the final affine projection needs an inverse. +// +// This is the representation the Miller loop will hold its G2 +// accumulator in — affine arithmetic is too slow because every step +// runs Fermat inversion. +// --------------------------------------------------------------------------- + +pub const G1Projective = struct { + x: Fp.Element, + y: Fp.Element, + z: Fp.Element, + + pub fn identity() G1Projective { + return .{ .x = Fp.zero(), .y = Fp.one(), .z = Fp.zero() }; + } + + pub fn isIdentity(self: G1Projective) bool { + return Fp.eql(self.z, Fp.zero()); + } + + pub fn fromAffine(p: G1Affine) G1Projective { + if (p.infinity) return identity(); + return .{ .x = p.x, .y = p.y, .z = Fp.one() }; + } + + /// Project a Jacobian point back to affine via one Fp inversion. 
+ pub fn toAffine(self: G1Projective) G1Affine { + if (self.isIdentity()) return G1Affine.identity(); + const z_inv = Fp.inv(self.z); + const z_inv_sq = Fp.square(z_inv); + const z_inv_cubed = Fp.montMul(z_inv_sq, z_inv); + return .{ + .x = Fp.montMul(self.x, z_inv_sq), + .y = Fp.montMul(self.y, z_inv_cubed), + .infinity = false, + }; + } + + /// Doubling: standard `dbl-2009-l` formulas for `a = 0` short + /// Weierstrass curves. 5 squarings + 2 multiplications. + /// The identity (`Z = 0`) is returned unchanged. + pub fn double(self: G1Projective) G1Projective { + if (self.isIdentity()) return self; + // A = X² + const A = Fp.square(self.x); + // B = Y² + const B = Fp.square(self.y); + // C = B² + const C = Fp.square(B); + // D = 2((X + B)² - A - C) + const x_plus_b = Fp.add(self.x, B); + const x_plus_b_sq = Fp.square(x_plus_b); + const D_inner = Fp.sub(Fp.sub(x_plus_b_sq, A), C); + const D = Fp.add(D_inner, D_inner); + // E = 3A + const E = Fp.add(Fp.add(A, A), A); + // F = E² + const F = Fp.square(E); + // X' = F - 2D + const x3 = Fp.sub(F, Fp.add(D, D)); + // Y' = E·(D - X') - 8C + const eight_c = blk: { + const two_c = Fp.add(C, C); + const four_c = Fp.add(two_c, two_c); + break :blk Fp.add(four_c, four_c); + }; + const y3 = Fp.sub(Fp.montMul(E, Fp.sub(D, x3)), eight_c); + // Z' = 2 Y Z — multiply once, then double by addition. + const y_times_z = Fp.montMul(self.y, self.z); + const z3 = Fp.add(y_times_z, y_times_z); + return .{ .x = x3, .y = y3, .z = z3 }; + } + + /// Addition: `add-2007-bl` formulas. Falls back to `double` when + /// the inputs are equal and to the identity when they cancel. + /// The reference formulas cost 11 multiplications + 5 squarings; not as fast as the + /// specialized mixed-add but covers every case for now. 
+ pub fn add(p: G1Projective, q: G1Projective) G1Projective { + // Identity on either side: the sum is the other operand. + if (p.isIdentity()) return q; + if (q.isIdentity()) return p; + // Z1Z1 = Z1² + const Z1Z1 = Fp.square(p.z); + // Z2Z2 = Z2² + const Z2Z2 = Fp.square(q.z); + // U1 = X1 · Z2Z2 + const U1 = Fp.montMul(p.x, Z2Z2); + // U2 = X2 · Z1Z1 + const U2 = Fp.montMul(q.x, Z1Z1); + // S1 = Y1 · Z2 · Z2Z2 + const S1 = Fp.montMul(Fp.montMul(p.y, q.z), Z2Z2); + // S2 = Y2 · Z1 · Z1Z1 + const S2 = Fp.montMul(Fp.montMul(q.y, p.z), Z1Z1); + // Same affine x: either the same point (double) or opposite points (identity). + if (Fp.eql(U1, U2)) { + if (Fp.eql(S1, S2)) return p.double(); + return identity(); + } + // H = U2 - U1 + const H = Fp.sub(U2, U1); + // I = (2H)² + const two_h = Fp.add(H, H); + const I = Fp.square(two_h); + // J = H · I + const J = Fp.montMul(H, I); + // r = 2(S2 - S1) — subtract once, then double by addition. + const s2_minus_s1 = Fp.sub(S2, S1); + const r = Fp.add(s2_minus_s1, s2_minus_s1); + // V = U1 · I + const V = Fp.montMul(U1, I); + // X3 = r² - J - 2V + const x3 = Fp.sub(Fp.sub(Fp.square(r), J), Fp.add(V, V)); + // Y3 = r·(V - X3) - 2·S1·J — multiply S1·J once, then double by addition. + const s1_times_j = Fp.montMul(S1, J); + const two_s1_j = Fp.add(s1_times_j, s1_times_j); + const y3 = Fp.sub(Fp.montMul(r, Fp.sub(V, x3)), two_s1_j); + // Z3 = ((Z1 + Z2)² - Z1Z1 - Z2Z2) · H + const z_sum_sq = Fp.square(Fp.add(p.z, q.z)); + const z3 = Fp.montMul(Fp.sub(Fp.sub(z_sum_sq, Z1Z1), Z2Z2), H); + return .{ .x = x3, .y = y3, .z = z3 }; + } + + /// Equality check that respects the Z scaling. Two Jacobian points + /// `(X1, Y1, Z1)` and `(X2, Y2, Z2)` represent the same affine point + /// iff `X1·Z2² == X2·Z1²` and `Y1·Z2³ == Y2·Z1³`. 
+ pub fn eql(a: G1Projective, b: G1Projective) bool { + const a_inf = a.isIdentity(); + const b_inf = b.isIdentity(); + if (a_inf and b_inf) return true; + if (a_inf or b_inf) return false; + const z1z1 = Fp.square(a.z); + const z2z2 = Fp.square(b.z); + const x1z2z2 = Fp.montMul(a.x, z2z2); + const x2z1z1 = Fp.montMul(b.x, z1z1); + if (!Fp.eql(x1z2z2, x2z1z1)) return false; + const z1z1z1 = Fp.montMul(z1z1, a.z); + const z2z2z2 = Fp.montMul(z2z2, b.z); + const y1z2z2z2 = Fp.montMul(a.y, z2z2z2); + const y2z1z1z1 = Fp.montMul(b.y, z1z1z1); + return Fp.eql(y1z2z2z2, y2z1z1z1); + } + + /// Double-and-add scalar multiplication. Generic over the scalar + /// limb count. Dramatically faster than `G1Affine.mul` because + /// neither doubling nor addition needs Fermat inversion — only the + /// final `toAffine` does. + pub fn mul(self: G1Projective, comptime ScalarLimbs: comptime_int, scalar: [ScalarLimbs]u64) G1Projective { + const top = bigint.bitLen(ScalarLimbs, scalar); + if (top == 0) return identity(); + var result = identity(); + var i = top; + while (i > 0) { + i -= 1; + result = result.double(); + const limb = i / 64; + const bit = @as(u6, @intCast(i % 64)); + if (((scalar[limb] >> bit) & 1) == 1) { + result = result.add(self); + } + } + return result; + } +}; + +// --------------------------------------------------------------------------- +// G2 Jacobian projective coordinates. Mirrors G1Projective but with +// Fp2 coordinates. Same formulas; the only thing that changes is the +// underlying field operations. 
+// --------------------------------------------------------------------------- + +pub const G2Projective = struct { + x: Fp2, + y: Fp2, + z: Fp2, + + /// Identity element, stored as `(0, 1, 0)`; only `z == 0` is tested by `isIdentity`. + pub fn identity() G2Projective { + return .{ .x = Fp2.zero(), .y = Fp2.one(), .z = Fp2.zero() }; + } + + /// True when `z` is zero. + pub fn isIdentity(self: G2Projective) bool { + return Fp2.eql(self.z, Fp2.zero()); + } + + /// Lift an affine point to Jacobian form with `z = 1`; the affine + /// infinity flag maps to `identity()`. + pub fn fromAffine(p: G2Affine) G2Projective { + if (p.infinity) return identity(); + return .{ .x = p.x, .y = p.y, .z = Fp2.one() }; + } + + /// Project back to affine as `(X/Z², Y/Z³)` using one `Fp2.inv`. + pub fn toAffine(self: G2Projective) G2Affine { + if (self.isIdentity()) return G2Affine.identity(); + const z_inv = Fp2.inv(self.z); + const z_inv_sq = Fp2.square(z_inv); + const z_inv_cubed = Fp2.mul(z_inv_sq, z_inv); + return .{ + .x = Fp2.mul(self.x, z_inv_sq), + .y = Fp2.mul(self.y, z_inv_cubed), + .infinity = false, + }; + } + + /// Jacobian doubling: the same operation sequence as + /// `G1Projective.double`, carried out with Fp2 arithmetic. + /// 5 Fp2 squarings + 2 Fp2 multiplications. The identity is + /// returned unchanged. + pub fn double(self: G2Projective) G2Projective { + if (self.isIdentity()) return self; + const A = Fp2.square(self.x); + const B = Fp2.square(self.y); + const C = Fp2.square(B); + const x_plus_b = Fp2.add(self.x, B); + const x_plus_b_sq = Fp2.square(x_plus_b); + const D_inner = Fp2.sub(Fp2.sub(x_plus_b_sq, A), C); + const D = Fp2.add(D_inner, D_inner); + const E = Fp2.add(Fp2.add(A, A), A); + const F = Fp2.square(E); + const x3 = Fp2.sub(F, Fp2.add(D, D)); + const eight_c = blk: { + const two_c = Fp2.add(C, C); + const four_c = Fp2.add(two_c, two_c); + break :blk Fp2.add(four_c, four_c); + }; + const y3 = Fp2.sub(Fp2.mul(E, Fp2.sub(D, x3)), eight_c); + // Z' = 2 Y Z — multiply once, then double by addition. + const y_times_z = Fp2.mul(self.y, self.z); + const z3 = Fp2.add(y_times_z, y_times_z); + return .{ .x = x3, .y = y3, .z = z3 }; + } + + pub fn add(p: G2Projective, q: G2Projective) G2Projective { + if (p.isIdentity()) return q; + if (q.isIdentity()) return p; + const Z1Z1 = Fp2.square(p.z); + const Z2Z2 = Fp2.square(q.z); + const U1 = Fp2.mul(p.x, Z2Z2); + const U2 = Fp2.mul(q.x, Z1Z1); + const S1 = Fp2.mul(Fp2.mul(p.y, q.z), Z2Z2); + const S2 = Fp2.mul(Fp2.mul(q.y, p.z), Z1Z1); + if (Fp2.eql(U1, U2)) { + if (Fp2.eql(S1, S2)) 
return p.double(); + return identity(); + } + const H = Fp2.sub(U2, U1); + const two_h = Fp2.add(H, H); + const I = Fp2.square(two_h); + const J = Fp2.mul(H, I); + // r = 2(S2 - S1) — subtract once, then double by addition. + const s2_minus_s1 = Fp2.sub(S2, S1); + const r = Fp2.add(s2_minus_s1, s2_minus_s1); + const V = Fp2.mul(U1, I); + const x3 = Fp2.sub(Fp2.sub(Fp2.square(r), J), Fp2.add(V, V)); + // 2·S1·J — multiply once, then double by addition. + const s1_times_j = Fp2.mul(S1, J); + const two_s1_j = Fp2.add(s1_times_j, s1_times_j); + const y3 = Fp2.sub(Fp2.mul(r, Fp2.sub(V, x3)), two_s1_j); + const z_sum_sq = Fp2.square(Fp2.add(p.z, q.z)); + const z3 = Fp2.mul(Fp2.sub(Fp2.sub(z_sum_sq, Z1Z1), Z2Z2), H); + return .{ .x = x3, .y = y3, .z = z3 }; + } + + /// Equality up to Jacobian scaling: compares `X1·Z2²` with `X2·Z1²` + /// and `Y1·Z2³` with `Y2·Z1³`. Two identities are equal; an identity + /// never equals a non-identity point. + pub fn eql(a: G2Projective, b: G2Projective) bool { + const a_inf = a.isIdentity(); + const b_inf = b.isIdentity(); + if (a_inf and b_inf) return true; + if (a_inf or b_inf) return false; + const z1z1 = Fp2.square(a.z); + const z2z2 = Fp2.square(b.z); + const x1z2z2 = Fp2.mul(a.x, z2z2); + const x2z1z1 = Fp2.mul(b.x, z1z1); + if (!Fp2.eql(x1z2z2, x2z1z1)) return false; + const z1z1z1 = Fp2.mul(z1z1, a.z); + const z2z2z2 = Fp2.mul(z2z2, b.z); + const y1z2z2z2 = Fp2.mul(a.y, z2z2z2); + const y2z1z1z1 = Fp2.mul(b.y, z1z1z1); + return Fp2.eql(y1z2z2z2, y2z1z1z1); + } + + /// Double-and-add scalar multiplication in projective form. See + /// `G1Projective.mul` for the rationale — this is the fast path + /// for any G2 scalar mul that the affine routine handles too + /// slowly (subgroup checks, cofactor clearing, etc.). 
+ pub fn mul(self: G2Projective, comptime ScalarLimbs: comptime_int, scalar: [ScalarLimbs]u64) G2Projective { + const top = bigint.bitLen(ScalarLimbs, scalar); + if (top == 0) return identity(); + var result = identity(); + var i = top; + while (i > 0) { + i -= 1; + result = result.double(); + const limb = i / 64; + const bit = @as(u6, @intCast(i % 64)); + if (((scalar[limb] >> bit) & 1) == 1) { + result = result.add(self); + } + } + return result; + } +}; + +// --------------------------------------------------------------------------- +// BLS12-381 pairing parameters and roadmap. +// --------------------------------------------------------------------------- + +/// The BLS12-381 trace parameter `|x|`. The actual `x` is negative — +/// `x = -0xd201000000010000` — but the Miller loop walks the absolute +/// value and conjugates the final result. This single 64-bit constant +/// is everything the Miller loop needs to know about the curve choice. +pub const BLS_X_ABS: u64 = 0xd201000000010000; + +/// Whether `x` is negative. If true, the Miller loop result must be +/// conjugated (`(c0 + c1·w) → (c0 - c1·w)`) before final exponentiation. +pub const BLS_X_IS_NEGATIVE: bool = true; + +/// Miller loop length: bit length of `|x|`. The loop walks bits +/// `BLS_X_LOOP_BITS - 2` down to `0`. +pub const BLS_X_LOOP_BITS: usize = 64; + +// ---- Pairing implementation is below, after the Fp12 / point decoding ---- +// See `millerLoop`, `fp12FinalExp`, and `pairing` near the end of the +// non-test section. Hash-to-curve for G2 (RFC 9380 SSWU map) is in +// hash_to_field.zig and map_to_curve.zig (still to come). + +/// `Fp2 = Fp[u] / (u² + 1)`. Elements are pairs `(c0, c1)` representing +/// `c0 + c1·u`. The non-residue is `-1`, so squaring `u` produces +/// `-1 ∈ Fp` directly. 
+/// +/// Operations: +/// +/// - `(a + bu) + (c + du) = (a + c) + (b + d)u` +/// - `(a + bu) - (c + du) = (a - c) + (b - d)u` +/// - `(a + bu) · (c + du) = (ac - bd) + (ad + bc)u` +/// - `(a + bu)⁻¹ = (a - bu) / (a² + b²)` +/// +/// Multiplication could use the standard Karatsuba trick (three Fp +/// multiplications instead of four), but the implementation here is +/// schoolbook (four Fp multiplications) for clarity; the optimization can land later if benchmarks +/// justify it. +pub const Fp2 = struct { + c0: Fp.Element, + c1: Fp.Element, + + /// Additive identity `0 + 0·u`. + pub fn zero() Fp2 { + return .{ .c0 = Fp.zero(), .c1 = Fp.zero() }; + } + + /// Multiplicative identity `1 + 0·u`. + pub fn one() Fp2 { + return .{ .c0 = Fp.one(), .c1 = Fp.zero() }; + } + + /// Componentwise equality of both coefficients. + pub fn eql(a: Fp2, b: Fp2) bool { + return Fp.eql(a.c0, b.c0) and Fp.eql(a.c1, b.c1); + } + + /// Componentwise addition. + pub fn add(a: Fp2, b: Fp2) Fp2 { + return .{ + .c0 = Fp.add(a.c0, b.c0), + .c1 = Fp.add(a.c1, b.c1), + }; + } + + /// Componentwise subtraction. + pub fn sub(a: Fp2, b: Fp2) Fp2 { + return .{ + .c0 = Fp.sub(a.c0, b.c0), + .c1 = Fp.sub(a.c1, b.c1), + }; + } + + /// Componentwise negation. + pub fn neg(a: Fp2) Fp2 { + return .{ .c0 = Fp.neg(a.c0), .c1 = Fp.neg(a.c1) }; + } + + /// `(a + bu)·(c + du) = (ac - bd) + (ad + bc)u`. Schoolbook: four + /// Fp multiplications, one subtraction, one addition. + pub fn mul(a: Fp2, b: Fp2) Fp2 { + const ac = Fp.montMul(a.c0, b.c0); + const bd = Fp.montMul(a.c1, b.c1); + const ad = Fp.montMul(a.c0, b.c1); + const bc = Fp.montMul(a.c1, b.c0); + return .{ + .c0 = Fp.sub(ac, bd), + .c1 = Fp.add(ad, bc), + }; + } + + /// `(a + bu)² = (a² - b²) + 2ab·u`, with `a² - b²` computed as + /// `(a + b)(a - b)`. Specialized so the squaring + /// path uses two Fp multiplications (plus two additions and one + /// subtraction) instead of the + /// four Fp multiplications a call to `mul(a, a)` would cost. + pub fn square(a: Fp2) Fp2 { + const a_plus_b = Fp.add(a.c0, a.c1); + const a_minus_b = Fp.sub(a.c0, a.c1); + const c0 = Fp.montMul(a_plus_b, a_minus_b); + const ab = Fp.montMul(a.c0, a.c1); + const c1 = Fp.add(ab, ab); + return .{ .c0 = c0, .c1 = c1 }; + } + + /// `(a + bu)⁻¹ = (a - bu) · (a² + b²)⁻¹`. The denominator is the + /// norm of the element in `Fp`, so its inversion only needs an + /// `Fp.inv` call. 
+ pub fn inv(a: Fp2) Fp2 { + if (Fp.eql(a.c0, Fp.zero()) and Fp.eql(a.c1, Fp.zero())) return zero(); + const a0_sq = Fp.square(a.c0); + const a1_sq = Fp.square(a.c1); + const norm = Fp.add(a0_sq, a1_sq); + const norm_inv = Fp.inv(norm); + return .{ + .c0 = Fp.montMul(a.c0, norm_inv), + .c1 = Fp.montMul(Fp.neg(a.c1), norm_inv), + }; + } +}; + +/// Frobenius endomorphism in Fp2: `a → a^p`. For BLS12-381's base +/// prime `p ≡ 3 (mod 4)`, raising `u` to the `p`-th power gives `-u`, +/// so `(a₀ + a₁·u)^p = a₀ - a₁·u`. Conjugation by another name. +/// +/// This is the building block the higher tower extensions use to +/// implement their own Frobenius via precomputed coefficients. +pub fn fp2Frobenius(a: Fp2) Fp2 { + return .{ .c0 = a.c0, .c1 = Fp.neg(a.c1) }; +} + +/// Square-and-multiply exponentiation in Fp2. The exponent is a raw +/// little-endian limb array — each bit walked from MSB to LSB. Useful +/// for computing Frobenius coefficients and similar one-shot tower +/// constants without dragging precomputed tables into source. +pub fn fp2Pow(a: Fp2, comptime ExponentLimbs: comptime_int, exponent: [ExponentLimbs]u64) Fp2 { + const top = bigint.bitLen(ExponentLimbs, exponent); + if (top == 0) return Fp2.one(); + var result = a; + var i = top - 1; + while (i > 0) { + i -= 1; + result = Fp2.square(result); + const limb = i / 64; + const bit = @as(u6, @intCast(i % 64)); + if (((exponent[limb] >> bit) & 1) == 1) { + result = Fp2.mul(result, a); + } + } + return result; +} + +/// Square root of `a ∈ Fp2` when one exists, otherwise `error.NotASquare`. +/// +/// Algorithm: for `a = a₀ + a₁·u` and `u² = -1`, we want +/// `(c₀ + c₁·u)² = a`. Expanding gives the system +/// `c₀² - c₁² = a₀` and `2 c₀ c₁ = a₁`. Eliminating `c₁` and solving +/// the resulting quadratic in `c₀²` yields +/// `c₀² ∈ {(a₀ + β)/2, (a₀ - β)/2}` where `β = sqrt(a₀² + a₁²)` is +/// the square root of the norm in Fp. 
+/// +/// At least one of the two candidates is a square in Fp; we try the +/// `+` branch first and fall back to `-`. Once `c₀` is known, +/// `c₁ = a₁ / (2 c₀)`. Special-case the `c₀ == 0` path: that means +/// `a₁ = 0` (`a` is a pure `Fp` element) and the sqrt collapses to +/// `(sqrt(a₀), 0)` if `a₀` is a square or `(0, sqrt(-a₀))` otherwise. +pub fn fp2Sqrt(a: Fp2) error{NotASquare}!Fp2 { + // Pure-Fp special case (`a₁ = 0`). + if (Fp.eql(a.c1, Fp.zero())) { + if (Fp.eql(a.c0, Fp.zero())) return Fp2.zero(); + const root = fpSqrt(a.c0); + if (fpIsSquareRoot(a.c0, root)) { + return .{ .c0 = root, .c1 = Fp.zero() }; + } + // a₀ is not a square in Fp. Then -a₀ might be — try it. + const neg_a0 = Fp.neg(a.c0); + const root2 = fpSqrt(neg_a0); + if (fpIsSquareRoot(neg_a0, root2)) { + return .{ .c0 = Fp.zero(), .c1 = root2 }; + } + return error.NotASquare; + } + + // General case. + const norm = Fp.add(Fp.square(a.c0), Fp.square(a.c1)); + const beta = fpSqrt(norm); + if (!fpIsSquareRoot(norm, beta)) return error.NotASquare; + + // Try the `+` branch: c₀² = (a₀ + β) / 2. + const gamma_plus = fpHalve(Fp.add(a.c0, beta)); + var c0 = fpSqrt(gamma_plus); + if (!fpIsSquareRoot(gamma_plus, c0)) { + // Fall back to the `-` branch: c₀² = (a₀ - β) / 2. + const gamma_minus = fpHalve(Fp.sub(a.c0, beta)); + c0 = fpSqrt(gamma_minus); + if (!fpIsSquareRoot(gamma_minus, c0)) return error.NotASquare; + } + if (Fp.eql(c0, Fp.zero())) return error.NotASquare; + + // c₁ = a₁ / (2 c₀). + const two_c0 = Fp.add(c0, c0); + const c1 = Fp.montMul(a.c1, Fp.inv(two_c0)); + + return .{ .c0 = c0, .c1 = c1 }; +} + +// --------------------------------------------------------------------------- +// BLS12-381 G1 short Weierstrass curve: y² = x³ + 4 over Fp. +// --------------------------------------------------------------------------- + +/// `B` curve coefficient (4 in raw form). +pub const G1_B_RAW: [6]u64 = .{ 4, 0, 0, 0, 0, 0 }; + +/// Affine point on G1. 
/// The point at infinity uses the `infinity = true`
/// flag rather than encoding it as `(0, 0)`, so the predicates can stay
/// straightforward and the formulas don't have to special-case `(0, 0)`
/// when computing slopes.
pub const G1Affine = struct {
    x: Fp.Element,
    y: Fp.Element,
    infinity: bool,

    /// The point at infinity. Both coordinates are stored as `Fp.zero()`;
    /// the methods below test `infinity` before reading them.
    pub fn identity() G1Affine {
        return G1Affine{
            .x = Fp.zero(),
            .y = Fp.zero(),
            .infinity = true,
        };
    }

    /// Build a finite point from two raw little-endian limb arrays by
    /// passing each through `Fp.fromRaw`. No curve-membership check is
    /// made here; that remains the caller's responsibility.
    pub fn fromRaw(x_raw: [6]u64, y_raw: [6]u64) G1Affine {
        const x_elem = Fp.fromRaw(x_raw);
        const y_elem = Fp.fromRaw(y_raw);
        return G1Affine{ .x = x_elem, .y = y_elem, .infinity = false };
    }

    /// True when this is the point at infinity.
    pub fn isIdentity(self: G1Affine) bool {
        return self.infinity;
    }

    /// Point equality. Two identity points are equal whatever their
    /// stored coordinates; an identity never equals a finite point.
    pub fn eql(a: G1Affine, b: G1Affine) bool {
        if (a.infinity or b.infinity) return a.infinity and b.infinity;
        return Fp.eql(a.x, b.x) and Fp.eql(a.y, b.y);
    }

    /// Curve membership test `y² == x³ + 4` (the constant is read from
    /// `G1_B_RAW`). The identity counts as a member by definition.
    pub fn isOnCurve(self: G1Affine) bool {
        if (self.infinity) return true;
        const lhs = Fp.square(self.y);
        const x_cubed = Fp.montMul(Fp.square(self.x), self.x);
        const rhs = Fp.add(x_cubed, Fp.fromRaw(G1_B_RAW));
        return Fp.eql(lhs, rhs);
    }

    /// Negation `-P = (x, -y)`; the identity is returned unchanged.
    pub fn neg(self: G1Affine) G1Affine {
        if (self.infinity) return self;
        var mirrored = self;
        mirrored.y = Fp.neg(self.y);
        return mirrored;
    }

    /// Affine doubling `2P`, using the tangent slope
    /// `λ = 3x² / (2y)`, then `x₃ = λ² - 2x` and `y₃ = λ(x - x₃) - y`.
    /// The identity doubles to itself, and a point with `y == 0`
    /// doubles to the identity.
    pub fn double(self: G1Affine) G1Affine {
        if (self.infinity) return self;
        if (Fp.eql(self.y, Fp.zero())) return identity();

        const xx = Fp.square(self.x);
        const numerator = Fp.add(Fp.add(xx, xx), xx);
        const denominator = Fp.add(self.y, self.y);
        const slope = Fp.montMul(numerator, Fp.inv(denominator));

        const x_out = Fp.sub(Fp.square(slope), Fp.add(self.x, self.x));
        const y_out = Fp.sub(Fp.montMul(slope, Fp.sub(self.x, x_out)), self.y);
        return G1Affine{ .x = x_out, .y = y_out, .infinity = false };
    }

    /// Affine addition `P + Q`. Handles every special case: either
    /// operand being the identity, `P == Q` (delegates to `double`),
    /// and `P == -Q` (returns the identity).
    pub fn add(a: G1Affine, b: G1Affine) G1Affine {
        if (a.infinity) return b;
        if (b.infinity) return a;

        if (Fp.eql(a.x, b.x)) {
            // Equal x means the operands are either identical or mutual
            // negatives.
            return if (Fp.eql(a.y, b.y)) a.double() else identity();
        }

        const rise = Fp.sub(b.y, a.y);
        const run = Fp.sub(b.x, a.x);
        const slope = Fp.montMul(rise, Fp.inv(run));

        const x_out = Fp.sub(Fp.sub(Fp.square(slope), a.x), b.x);
        const y_out = Fp.sub(Fp.montMul(slope, Fp.sub(a.x, x_out)), a.y);
        return G1Affine{ .x = x_out, .y = y_out, .infinity = false };
    }

    /// Scalar multiplication by left-to-right double-and-add. `scalar`
    /// is a raw little-endian limb array of any width; its bits are
    /// visited from the most-significant set bit down to bit 0. A zero
    /// scalar yields the identity.
    ///
    /// This is the simplest correct implementation. It does NOT use a
    /// constant-time ladder, sliding-window NAF, GLV decomposition, or
    /// any of the other tricks that real BLS verifiers reach for. A
    /// faster algorithm can replace it when benchmarks justify it.
    pub fn mul(self: G1Affine, comptime ScalarLimbs: comptime_int, scalar: [ScalarLimbs]u64) G1Affine {
        var acc = identity();
        var remaining = bigint.bitLen(ScalarLimbs, scalar);
        while (remaining > 0) {
            remaining -= 1;
            acc = acc.double();
            const word = scalar[remaining / 64];
            const shift: u6 = @intCast(remaining % 64);
            if (((word >> shift) & 1) == 1) acc = acc.add(self);
        }
        return acc;
    }
};

/// BLS12-381 G1 generator point. Coordinates from RFC 9380 §8.8.1
/// (also matches the blst constants). Each array is little-endian:
/// element 0 holds the lowest 64 bits of the hex value.
///
///   x = 0x17F1D3A73197D7942695638C4FA9AC0FC3688C4F9774B905
///         A14E3A3F171BAC586C55E83FF97A1AEFFB3AF00ADB22C6BB
///   y = 0x08B3F481E3AAA0F1A09E30ED741D8AE4FCF5E095D5D00AF6
///         00DB18CB2C04B3EDD03CC744A2888AE40CAA232946C5E7E1
pub const G1_GENERATOR_X: [6]u64 = .{
    0xfb3af00adb22c6bb,
    0x6c55e83ff97a1aef,
    0xa14e3a3f171bac58,
    0xc3688c4f9774b905,
    0x2695638c4fa9ac0f,
    0x17f1d3a73197d794,
};
pub const G1_GENERATOR_Y: [6]u64 = .{
    0x0caa232946c5e7e1,
    0xd03cc744a2888ae4,
    0x00db18cb2c04b3ed,
    0xfcf5e095d5d00af6,
    0xa09e30ed741d8ae4,
    0x08b3f481e3aaa0f1,
};

/// The G1 generator as an affine point.
pub fn g1Generator() G1Affine {
    return G1Affine.fromRaw(G1_GENERATOR_X, G1_GENERATOR_Y);
}

// ---------------------------------------------------------------------------
// Fp6 = Fp2[v]/(v³ - (1+u)). Cubic extension built on top of Fp2.
// Elements are tuples (c0, c1, c2) representing c0 + c1·v + c2·v² with
// v³ = 1+u (the non-residue from Fp2).
// ---------------------------------------------------------------------------

/// Multiply an Fp2 element by the non-residue `1 + u`. Used heavily by
/// Fp6 / Fp12 reduction. The expanded form is `(c0 - c1) + (c0 + c1)·u`.
pub fn fp2MulByNonresidue(a: Fp2) Fp2 {
    // New c0 is the difference of the input coefficients, new c1 their sum.
    const diff = Fp.sub(a.c0, a.c1);
    const total = Fp.add(a.c0, a.c1);
    return Fp2{ .c0 = diff, .c1 = total };
}

pub const Fp6 = struct {
    c0: Fp2,
    c1: Fp2,
    c2: Fp2,

    /// Additive identity: every coefficient is `Fp2.zero()`.
    pub fn zero() Fp6 {
        return Fp6{
            .c0 = Fp2.zero(),
            .c1 = Fp2.zero(),
            .c2 = Fp2.zero(),
        };
    }

    /// Multiplicative identity: `1 + 0·v + 0·v²`.
    pub fn one() Fp6 {
        return Fp6{
            .c0 = Fp2.one(),
            .c1 = Fp2.zero(),
            .c2 = Fp2.zero(),
        };
    }

    /// Coefficient-wise equality.
    pub fn eql(a: Fp6, b: Fp6) bool {
        if (!Fp2.eql(a.c0, b.c0)) return false;
        if (!Fp2.eql(a.c1, b.c1)) return false;
        return Fp2.eql(a.c2, b.c2);
    }

    /// Coefficient-wise addition.
    pub fn add(a: Fp6, b: Fp6) Fp6 {
        const s0 = Fp2.add(a.c0, b.c0);
        const s1 = Fp2.add(a.c1, b.c1);
        const s2 = Fp2.add(a.c2, b.c2);
        return Fp6{ .c0 = s0, .c1 = s1, .c2 = s2 };
    }

    /// Coefficient-wise subtraction.
    pub fn sub(a: Fp6, b: Fp6) Fp6 {
        const d0 = Fp2.sub(a.c0, b.c0);
        const d1 = Fp2.sub(a.c1, b.c1);
        const d2 = Fp2.sub(a.c2, b.c2);
        return Fp6{ .c0 = d0, .c1 = d1, .c2 = d2 };
    }

    /// Coefficient-wise negation.
    pub fn neg(a: Fp6) Fp6 {
        return Fp6{
            .c0 = Fp2.neg(a.c0),
            .c1 = Fp2.neg(a.c1),
            .c2 = Fp2.neg(a.c2),
        };
    }

    /// Schoolbook product, reduced with `v³ = 1+u`:
    ///
    ///   c₀ = a₀·b₀ + (a₁·b₂ + a₂·b₁) · (1+u)
    ///   c₁ = a₀·b₁ + a₁·b₀ + a₂·b₂ · (1+u)
    ///   c₂ = a₀·b₂ + a₁·b₁ + a₂·b₀
    ///
    /// Costs 9 Fp2 multiplications plus additions. Karatsuba would
    /// bring that down to 6; left as future work.
    pub fn mul(a: Fp6, b: Fp6) Fp6 {
        // All nine pairwise coefficient products, named aᵢbⱼ.
        const a0b0 = Fp2.mul(a.c0, b.c0);
        const a0b1 = Fp2.mul(a.c0, b.c1);
        const a0b2 = Fp2.mul(a.c0, b.c2);
        const a1b0 = Fp2.mul(a.c1, b.c0);
        const a1b1 = Fp2.mul(a.c1, b.c1);
        const a1b2 = Fp2.mul(a.c1, b.c2);
        const a2b0 = Fp2.mul(a.c2, b.c0);
        const a2b1 = Fp2.mul(a.c2, b.c1);
        const a2b2 = Fp2.mul(a.c2, b.c2);

        return Fp6{
            // Degree-3 terms wrap around into the constant coefficient.
            .c0 = Fp2.add(a0b0, fp2MulByNonresidue(Fp2.add(a1b2, a2b1))),
            // The degree-4 term wraps around into the v coefficient.
            .c1 = Fp2.add(Fp2.add(a0b1, a1b0), fp2MulByNonresidue(a2b2)),
            .c2 = Fp2.add(Fp2.add(a0b2, a1b1), a2b0),
        };
    }

    /// Squaring, implemented as `mul(a, a)`; not specialized.
    pub fn square(a: Fp6) Fp6 {
        return mul(a, a);
    }

    /// Multiply by `v`, i.e. rotate the coefficients up one slot. Used by
    /// Fp12. `(c0 + c1v + c2v²) · v = c2(1+u) + c0v + c1v²`.
    pub fn mulByV(a: Fp6) Fp6 {
        const wrapped = fp2MulByNonresidue(a.c2);
        return Fp6{ .c0 = wrapped, .c1 = a.c0, .c2 = a.c1 };
    }

    /// Left-to-right square-and-multiply exponentiation. `exponent` is a
    /// little-endian limb array of any width; a zero exponent gives
    /// `one()`.
    pub fn pow(a: Fp6, comptime ExponentLimbs: comptime_int, exponent: [ExponentLimbs]u64) Fp6 {
        const bit_count = bigint.bitLen(ExponentLimbs, exponent);
        if (bit_count == 0) return one();

        // Seeding the accumulator with `a` consumes the top set bit, so
        // the loop covers bit positions bit_count-2 down to 0.
        var acc = a;
        var pos = bit_count - 1;
        while (pos > 0) {
            pos -= 1;
            acc = square(acc);
            const word = exponent[pos / 64];
            const shift: u6 = @intCast(pos % 64);
            if (((word >> shift) & 1) == 1) acc = mul(acc, a);
        }
        return acc;
    }

    /// Inversion via the adjugate / norm formula. With ξ = 1+u and
    /// `a = a₀ + a₁v + a₂v²`, define
    ///
    ///   A = a₀² − ξ·a₁·a₂
    ///   B = ξ·a₂² − a₀·a₁
    ///   C = a₁² − a₀·a₂
    ///
    /// so that `a · (A + Bv + Cv²) = D` with
    ///
    ///   D = a₀·A + ξ·a₂·B + ξ·a₁·C
    ///
    /// an element of Fp2. One `Fp2.inv(D)` and three scalings give
    /// `a⁻¹ = (A + Bv + Cv²) / D`.
    pub fn inv(a: Fp6) Fp6 {
        const sq0 = Fp2.square(a.c0);
        const sq1 = Fp2.square(a.c1);
        const sq2 = Fp2.square(a.c2);
        const prod01 = Fp2.mul(a.c0, a.c1);
        const prod02 = Fp2.mul(a.c0, a.c2);
        const prod12 = Fp2.mul(a.c1, a.c2);

        // Adjugate coefficients A, B, C.
        const adj0 = Fp2.sub(sq0, fp2MulByNonresidue(prod12));
        const adj1 = Fp2.sub(fp2MulByNonresidue(sq2), prod01);
        const adj2 = Fp2.sub(sq1, prod02);

        // Norm D, accumulated in the order a₀·A, ξ·a₂·B, ξ·a₁·C.
        const term_a = Fp2.mul(a.c0, adj0);
        const term_b = fp2MulByNonresidue(Fp2.mul(a.c2, adj1));
        const term_c = fp2MulByNonresidue(Fp2.mul(a.c1, adj2));
        const norm_inv = Fp2.inv(Fp2.add(Fp2.add(term_a, term_b), term_c));

        return Fp6{
            .c0 = Fp2.mul(adj0, norm_inv),
            .c1 = Fp2.mul(adj1, norm_inv),
            .c2 = Fp2.mul(adj2, norm_inv),
        };
    }
};

// ---------------------------------------------------------------------------
// Fp12 = Fp6[w]/(w² - v). The pairing target group lives here.
// Elements are pairs (c0, c1) representing c0 + c1·w with w² = v.
// ---------------------------------------------------------------------------

pub const Fp12 = struct {
    c0: Fp6,
    c1: Fp6,

    /// Additive identity.
    pub fn zero() Fp12 {
        return Fp12{ .c0 = Fp6.zero(), .c1 = Fp6.zero() };
    }

    /// Multiplicative identity `1 + 0·w`.
    pub fn one() Fp12 {
        return Fp12{ .c0 = Fp6.one(), .c1 = Fp6.zero() };
    }

    /// Coefficient-wise equality.
    pub fn eql(a: Fp12, b: Fp12) bool {
        if (!Fp6.eql(a.c0, b.c0)) return false;
        return Fp6.eql(a.c1, b.c1);
    }

    /// Coefficient-wise addition.
    pub fn add(a: Fp12, b: Fp12) Fp12 {
        const lo = Fp6.add(a.c0, b.c0);
        const hi = Fp6.add(a.c1, b.c1);
        return Fp12{ .c0 = lo, .c1 = hi };
    }

    /// Coefficient-wise subtraction.
    pub fn sub(a: Fp12, b: Fp12) Fp12 {
        const lo = Fp6.sub(a.c0, b.c0);
        const hi = Fp6.sub(a.c1, b.c1);
        return Fp12{ .c0 = lo, .c1 = hi };
    }

    /// Coefficient-wise negation.
    pub fn neg(a: Fp12) Fp12 {
        const lo = Fp6.neg(a.c0);
        const hi = Fp6.neg(a.c1);
        return Fp12{ .c0 = lo, .c1 = hi };
    }

    /// Product `(a₀ + a₁w)(b₀ + b₁w) = (a₀b₀ + a₁b₁·v) + (a₀b₁ + a₁b₀)w`.
    /// The cross term uses Karatsuba, `(a₀ + a₁)(b₀ + b₁) - a₀b₀ - a₁b₁`,
    /// so only three Fp6 multiplications are needed.
    pub fn mul(a: Fp12, b: Fp12) Fp12 {
        const low_prod = Fp6.mul(a.c0, b.c0);
        const high_prod = Fp6.mul(a.c1, b.c1);
        const out_c0 = Fp6.add(low_prod, Fp6.mulByV(high_prod));

        const sum_prod = Fp6.mul(Fp6.add(a.c0, a.c1), Fp6.add(b.c0, b.c1));
        const out_c1 = Fp6.sub(Fp6.sub(sum_prod, low_prod), high_prod);

        return Fp12{ .c0 = out_c0, .c1 = out_c1 };
    }

    /// Squaring, implemented as `mul(a, a)`.
    pub fn square(a: Fp12) Fp12 {
        return mul(a, a);
    }

    /// Inversion: `(c₀ + c₁w)⁻¹ = (c₀ - c₁w) / (c₀² - v·c₁²)`. The
    /// denominator lies in Fp6, so the cost is one Fp6 inversion plus a
    /// few Fp6 multiplications.
    pub fn inv(a: Fp12) Fp12 {
        const denom = Fp6.sub(Fp6.square(a.c0), Fp6.mulByV(Fp6.square(a.c1)));
        const scale = Fp6.inv(denom);
        const out_c0 = Fp6.mul(a.c0, scale);
        const out_c1 = Fp6.neg(Fp6.mul(a.c1, scale));
        return Fp12{ .c0 = out_c0, .c1 = out_c1 };
    }

    /// Conjugation over Fp6: `(c₀ + c₁·w) → (c₀ - c₁·w)`.
    ///
    /// The pairing code treats this as the `p^6`-power map. For elements
    /// of the cyclotomic subgroup it also serves as the inverse, which
    /// is how the final exponentiation avoids real inversions there.
    pub fn conjugate(a: Fp12) Fp12 {
        const flipped = Fp6.neg(a.c1);
        return Fp12{ .c0 = a.c0, .c1 = flipped };
    }

    /// Left-to-right square-and-multiply exponentiation. `exponent` is a
    /// little-endian limb array of any width; a zero exponent gives
    /// `one()`. This is the slow-but-correct path of the final
    /// exponentiation; an optimized version can replace some powerings
    /// with Frobenius applications once those constants are available.
    pub fn pow(a: Fp12, comptime ExponentLimbs: comptime_int, exponent: [ExponentLimbs]u64) Fp12 {
        const bit_count = bigint.bitLen(ExponentLimbs, exponent);
        if (bit_count == 0) return one();

        // Seeding the accumulator with `a` consumes the top set bit, so
        // the loop covers bit positions bit_count-2 down to 0.
        var acc = a;
        var pos = bit_count - 1;
        while (pos > 0) {
            pos -= 1;
            acc = square(acc);
            const word = exponent[pos / 64];
            const shift: u6 = @intCast(pos % 64);
            if (((word >> shift) & 1) == 1) acc = mul(acc, a);
        }
        return acc;
    }
};

// ---------------------------------------------------------------------------
// Compressed point encoding (BLS12-381 / RFC 9380 §3.3, IETF
// draft-irtf-cfrg-pairing-friendly-curves §C.2).
//
// A compressed G1 point is exactly 48 bytes — the big-endian encoding
// of the x coordinate, with three flag bits stuffed into the highest
// three bits of the first byte:
//
//   bit 7 (msb): compression flag — 1 = compressed, 0 = uncompressed
//   bit 6      : infinity flag
//   bit 5      : y_sign / y_lex flag (which of the two y roots to take)
//
// The actual x coordinate occupies the remaining 381 bits (the top
// three bits of the 384-bit field are masked off when reading).
// ---------------------------------------------------------------------------

pub const PointDecodeError = error{
    InvalidLength,
    InvalidEncoding,
    NotOnCurve,
    NotInField,
};

/// Decode a 48-byte compressed BLS12-381 G1 point. Returns the
/// resulting `G1Affine`. Validates the compression / infinity flags
/// and verifies that the recovered point lies on the curve.
///
/// Subgroup membership (cofactor clearing) is NOT checked here — that
/// is a higher-level decision the caller can layer on top.
///
/// Errors: `InvalidLength` if `bytes.len != 48`; `InvalidEncoding` if
/// the compression flag is clear or an infinity encoding is not exactly
/// `0xc0` followed by zeros; `NotInField` if x ≥ p; `NotOnCurve` if
/// `x³ + 4` has no square root.
pub fn decodeG1Compressed(bytes: []const u8) PointDecodeError!G1Affine {
    if (bytes.len != 48) return PointDecodeError.InvalidLength;

    // The three flag bits live in the top of byte 0.
    const flag_byte = bytes[0];
    const is_compressed = (flag_byte & 0x80) != 0;
    const is_infinity = (flag_byte & 0x40) != 0;
    const wants_larger_y = (flag_byte & 0x20) != 0;
    if (!is_compressed) return PointDecodeError.InvalidEncoding;

    if (is_infinity) {
        // The only valid infinity encoding is 0xc0 followed by 47 zero
        // bytes. Comparing the whole first byte rejects a set sign bit
        // and stray x bits in one test.
        if (flag_byte != 0xc0) return PointDecodeError.InvalidEncoding;
        for (bytes[1..]) |tail_byte| {
            if (tail_byte != 0) return PointDecodeError.InvalidEncoding;
        }
        return G1Affine.identity();
    }

    // Copy the input, mask the flags out of byte 0, and parse the
    // 48 bytes as a big-endian integer.
    var x_bytes: [48]u8 = undefined;
    @memcpy(&x_bytes, bytes);
    x_bytes[0] &= 0x1f;
    const x_raw = bigint.fromBytesBe(6, &x_bytes);

    // The remaining 381 bits can still encode a value at or above the
    // modulus; reject those.
    if (bigint.cmp(6, x_raw, FP_MODULUS) != .lt) return PointDecodeError.NotInField;

    // Recover y from y² = x³ + 4. If the candidate root does not square
    // back to the right-hand side, no curve point has this x.
    const x = Fp.fromRaw(x_raw);
    const rhs = Fp.add(Fp.montMul(Fp.square(x), x), Fp.fromRaw(G1_B_RAW));
    const sqrt_candidate = fpSqrt(rhs);
    if (!fpIsSquareRoot(rhs, sqrt_candidate)) return PointDecodeError.NotOnCurve;

    // The sign flag selects the larger of the two roots, where "larger"
    // means y > -y when both are compared as raw integers.
    const sqrt_negated = Fp.neg(sqrt_candidate);
    const candidate_is_larger =
        bigint.cmp(6, Fp.toRaw(sqrt_candidate), Fp.toRaw(sqrt_negated)) == .gt;
    const y = if (wants_larger_y == candidate_is_larger) sqrt_candidate else sqrt_negated;

    return G1Affine{ .x = x, .y = y, .infinity = false };
}

// ---------------------------------------------------------------------------
// BLS12-381 G2 short Weierstrass curve: y² = x³ + 4(1 + u) over Fp2.
// ---------------------------------------------------------------------------

/// `B` curve coefficient for G2 in Fp2 form: `4 + 4·u`.
pub fn g2B() Fp2 {
    const coeff = Fp.fromRaw(.{ 4, 0, 0, 0, 0, 0 });
    return Fp2{ .c0 = coeff, .c1 = coeff };
}

/// Affine point on G2. Mirrors `G1Affine` but with `Fp2` coordinates.
/// The point at infinity is signalled by the `infinity` flag.
pub const G2Affine = struct {
    x: Fp2,
    y: Fp2,
    infinity: bool,

    /// The point at infinity, with both coordinates stored as zero.
    pub fn identity() G2Affine {
        return G2Affine{
            .x = Fp2.zero(),
            .y = Fp2.zero(),
            .infinity = true,
        };
    }

    /// True when this is the point at infinity.
    pub fn isIdentity(self: G2Affine) bool {
        return self.infinity;
    }

    /// Point equality. Two identity points are equal whatever their
    /// stored coordinates; an identity never equals a finite point.
    pub fn eql(a: G2Affine, b: G2Affine) bool {
        if (a.infinity or b.infinity) return a.infinity and b.infinity;
        return Fp2.eql(a.x, b.x) and Fp2.eql(a.y, b.y);
    }

    /// Curve membership test `y² == x³ + 4(1 + u)`. The identity is on
    /// the curve by definition.
    pub fn isOnCurve(self: G2Affine) bool {
        if (self.infinity) return true;
        const lhs = Fp2.square(self.y);
        const x_cubed = Fp2.mul(Fp2.square(self.x), self.x);
        return Fp2.eql(lhs, Fp2.add(x_cubed, g2B()));
    }

    /// Negation `-P = (x, -y)`; the identity is returned unchanged.
    pub fn neg(self: G2Affine) G2Affine {
        if (self.infinity) return self;
        var mirrored = self;
        mirrored.y = Fp2.neg(self.y);
        return mirrored;
    }

    /// Affine doubling. Same formulas as `G1Affine.double`, over Fp2:
    /// `λ = 3x² / (2y)`, `x₃ = λ² - 2x`, `y₃ = λ(x - x₃) - y`. The
    /// identity doubles to itself, and `y == 0` doubles to the identity.
    pub fn double(self: G2Affine) G2Affine {
        if (self.infinity) return self;
        if (Fp2.eql(self.y, Fp2.zero())) return identity();

        const xx = Fp2.square(self.x);
        const numerator = Fp2.add(Fp2.add(xx, xx), xx);
        const denominator = Fp2.add(self.y, self.y);
        const slope = Fp2.mul(numerator, Fp2.inv(denominator));

        const x_out = Fp2.sub(Fp2.square(slope), Fp2.add(self.x, self.x));
        const y_out = Fp2.sub(Fp2.mul(slope, Fp2.sub(self.x, x_out)), self.y);
        return G2Affine{ .x = x_out, .y = y_out, .infinity = false };
    }

    /// Affine addition. Handles identity operands, delegates to
    /// `double` when `P == Q`, and returns the identity when `P == -Q`.
    pub fn add(a: G2Affine, b: G2Affine) G2Affine {
        if (a.infinity) return b;
        if (b.infinity) return a;

        if (Fp2.eql(a.x, b.x)) {
            return if (Fp2.eql(a.y, b.y)) a.double() else identity();
        }

        const rise = Fp2.sub(b.y, a.y);
        const run = Fp2.sub(b.x, a.x);
        const slope = Fp2.mul(rise, Fp2.inv(run));

        const x_out = Fp2.sub(Fp2.sub(Fp2.square(slope), a.x), b.x);
        const y_out = Fp2.sub(Fp2.mul(slope, Fp2.sub(a.x, x_out)), a.y);
        return G2Affine{ .x = x_out, .y = y_out, .infinity = false };
    }

    /// Scalar multiplication by left-to-right double-and-add. `scalar`
    /// is a little-endian limb array of any width; a zero scalar yields
    /// the identity.
    pub fn mul(self: G2Affine, comptime ScalarLimbs: comptime_int, scalar: [ScalarLimbs]u64) G2Affine {
        var acc = identity();
        var remaining = bigint.bitLen(ScalarLimbs, scalar);
        while (remaining > 0) {
            remaining -= 1;
            acc = acc.double();
            const word = scalar[remaining / 64];
            const shift: u6 = @intCast(remaining % 64);
            if (((word >> shift) & 1) == 1) acc = acc.add(self);
        }
        return acc;
    }
};

/// BLS12-381 G2 generator coordinates.
/// Decomposed limb-by-limb from the
/// standard hex strings (also pinned by blst):
///
///   x.c0 = 0x024aa2b2f08f0a91260805272dc51051
///            c6e47ad4fa403b02b4510b647ae3d177
///            0bac0326a805bbefd48056c8c121bdb8
///   x.c1 = 0x13e02b6052719f607dacd3a088274f65
///            596bd0d09920b61ab5da61bbdc7f5049
///            334cf11213945d57e5ac7d055d042b7e
///   y.c0 = 0x0ce5d527727d6e118cc9cdc6da2e351a
///            adfd9baa8cbdd3a76d429a695160d12c
///            923ac9cc3baca289e193548608b82801
///   y.c1 = 0x0606c4a02ea734cc32acd2b02bc28b99
///            cb3e287e85a763af267492ab572e99ab
///            3f370d275cec1da1aaa9075ff05f79be
///
/// Each array below is little-endian: element 0 holds the lowest 64
/// bits of the corresponding hex value (its last 16 hex digits) and
/// element 5 holds the highest.
pub const G2_GENERATOR_X_C0: [6]u64 = .{
    0xd48056c8c121bdb8,
    0x0bac0326a805bbef,
    0xb4510b647ae3d177,
    0xc6e47ad4fa403b02,
    0x260805272dc51051,
    0x024aa2b2f08f0a91,
};
pub const G2_GENERATOR_X_C1: [6]u64 = .{
    0xe5ac7d055d042b7e,
    0x334cf11213945d57,
    0xb5da61bbdc7f5049,
    0x596bd0d09920b61a,
    0x7dacd3a088274f65,
    0x13e02b6052719f60,
};
pub const G2_GENERATOR_Y_C0: [6]u64 = .{
    0xe193548608b82801,
    0x923ac9cc3baca289,
    0x6d429a695160d12c,
    0xadfd9baa8cbdd3a7,
    0x8cc9cdc6da2e351a,
    0x0ce5d527727d6e11,
};
pub const G2_GENERATOR_Y_C1: [6]u64 = .{
    0xaaa9075ff05f79be,
    0x3f370d275cec1da1,
    0x267492ab572e99ab,
    0xcb3e287e85a763af,
    0x32acd2b02bc28b99,
    0x0606c4a02ea734cc,
};

/// The G2 generator as an affine point. Each of the four raw limb
/// arrays above is passed through `Fp.fromRaw` and placed in the
/// matching `c0` / `c1` slot of the `x` and `y` coordinates.
pub fn g2Generator() G2Affine {
    return .{
        .x = .{
            .c0 = Fp.fromRaw(G2_GENERATOR_X_C0),
            .c1 = Fp.fromRaw(G2_GENERATOR_X_C1),
        },
        .y = .{
            .c0 = Fp.fromRaw(G2_GENERATOR_Y_C0),
            .c1 = Fp.fromRaw(G2_GENERATOR_Y_C1),
        },
        .infinity = false,
    };
}

/// Decode a 96-byte compressed BLS12-381 G2 point. The wire format is
/// the IETF pairing-friendly-curves draft §C.2 layout: the first 48
/// bytes encode `x.c1` (the imaginary coordinate) and the next 48
/// encode `x.c0`, both as big-endian Fp elements.
/// The same three flag
/// bits live in the high bits of byte 0:
///
///   bit 7: compression flag
///   bit 6: infinity flag
///   bit 5: y-sign / lex flag
///
/// The recovered y is reconstructed from `y² = x³ + 4(1+u)` via
/// `fp2Sqrt`, with the y-sign bit picking between the two roots based
/// on lexicographic comparison of the c1/c0 limb representation.
///
/// Subgroup membership is NOT checked here.
///
/// Errors: `InvalidLength` if `bytes.len != 96`; `InvalidEncoding` if
/// the compression flag is clear or an infinity encoding is not exactly
/// `0xc0` followed by zeros; `NotInField` if either x coefficient is
/// ≥ p; `NotOnCurve` if `fp2Sqrt` fails.
pub fn decodeG2Compressed(bytes: []const u8) PointDecodeError!G2Affine {
    if (bytes.len != 96) return PointDecodeError.InvalidLength;

    const flag_byte = bytes[0];
    const is_compressed = (flag_byte & 0x80) != 0;
    const is_infinity = (flag_byte & 0x40) != 0;
    const wants_larger_y = (flag_byte & 0x20) != 0;
    if (!is_compressed) return PointDecodeError.InvalidEncoding;

    if (is_infinity) {
        // The only valid infinity encoding is 0xc0 followed by 95 zero
        // bytes. Comparing the whole first byte rejects a set sign bit
        // and stray x bits in one test.
        if (flag_byte != 0xc0) return PointDecodeError.InvalidEncoding;
        for (bytes[1..]) |tail_byte| {
            if (tail_byte != 0) return PointDecodeError.InvalidEncoding;
        }
        return G2Affine.identity();
    }

    // x.c1 is in the first 48 bytes (flags masked out of byte 0) and
    // x.c0 in the last 48.
    var c1_bytes: [48]u8 = undefined;
    @memcpy(&c1_bytes, bytes[0..48]);
    c1_bytes[0] &= 0x1f;
    const c0_bytes = bytes[48..96];

    const x_c1_raw = bigint.fromBytesBe(6, &c1_bytes);
    const x_c0_raw = bigint.fromBytesBe(6, c0_bytes);
    const c1_in_field = bigint.cmp(6, x_c1_raw, FP_MODULUS) == .lt;
    const c0_in_field = bigint.cmp(6, x_c0_raw, FP_MODULUS) == .lt;
    if (!c1_in_field or !c0_in_field) return PointDecodeError.NotInField;

    const x = Fp2{
        .c0 = Fp.fromRaw(x_c0_raw),
        .c1 = Fp.fromRaw(x_c1_raw),
    };

    // Recover y from y² = x³ + 4(1+u).
    const rhs = Fp2.add(Fp2.mul(Fp2.square(x), x), g2B());
    const sqrt_candidate = fp2Sqrt(rhs) catch return PointDecodeError.NotOnCurve;
    const sqrt_negated = Fp2.neg(sqrt_candidate);

    // Order the two roots by c1 first, then by c0 on a tie, matching
    // the order of the coefficients in the byte serialization.
    const candidate_is_larger = switch (bigint.cmp(6, Fp.toRaw(sqrt_candidate.c1), Fp.toRaw(sqrt_negated.c1))) {
        .gt => true,
        .lt => false,
        .eq => bigint.cmp(6, Fp.toRaw(sqrt_candidate.c0), Fp.toRaw(sqrt_negated.c0)) == .gt,
    };
    const y = if (wants_larger_y == candidate_is_larger) sqrt_candidate else sqrt_negated;

    return G2Affine{ .x = x, .y = y, .infinity = false };
}

// ---------------------------------------------------------------------------
// Compressed point encoding (inverse of `decodeG1Compressed` /
// `decodeG2Compressed`).
//
// The encoder is the simple inverse of the decoder:
//
//   1. Identity: 0xc0 || 47 zeros (G1) or 95 zeros (G2).
//   2. Otherwise: serialize the x coordinate big-endian, set the
//      compression flag (bit 7), and set the y-sign flag (bit 5) if
//      the point's y is the lexicographically larger of the two
//      square roots.
//
// The y-sign convention matches the decoder: the bit is set when the
// affine y compares greater than (-y) on their raw limb arrays. The
// infinity flag (bit 6) is mutually exclusive with the y-sign flag.
// ---------------------------------------------------------------------------

/// Encode an affine G1 point into 48 compressed bytes. The output is
/// the inverse of `decodeG1Compressed` — feeding the result back through
/// the decoder yields a point equal to `p`.
pub fn encodeG1Compressed(p: G1Affine) [48]u8 {
    var encoded = [_]u8{0} ** 48;
    if (p.infinity) {
        // Compression flag plus infinity flag, everything else zero.
        encoded[0] = 0xc0;
        return encoded;
    }

    // x as 48 big-endian bytes.
    bigint.toBytesBe(6, Fp.toRaw(p.x), &encoded);

    // Bit 7 is always set; bit 5 is set only when y > -y.
    const y_raw = Fp.toRaw(p.y);
    const y_neg_raw = Fp.toRaw(Fp.neg(p.y));
    const flag_bits: u8 = if (bigint.cmp(6, y_raw, y_neg_raw) == .gt) 0xa0 else 0x80;
    encoded[0] |= flag_bits;
    return encoded;
}

/// Encode an affine G2 point into 96 compressed bytes. Layout matches
/// the IETF pairing-friendly-curves draft §C.2: bytes [0..48] hold
/// `x.c1` (with the flag bits in byte 0) and bytes [48..96] hold
/// `x.c0`.
pub fn encodeG2Compressed(p: G2Affine) [96]u8 {
    var encoded = [_]u8{0} ** 96;
    if (p.infinity) {
        encoded[0] = 0xc0;
        return encoded;
    }

    // Imaginary coefficient first, then the real coefficient.
    bigint.toBytesBe(6, Fp.toRaw(p.x.c1), encoded[0..48]);
    bigint.toBytesBe(6, Fp.toRaw(p.x.c0), encoded[48..96]);

    // Same (c1, then c0) ordering as the decoder.
    const negated = Fp2.neg(p.y);
    const y_is_larger = switch (bigint.cmp(6, Fp.toRaw(p.y.c1), Fp.toRaw(negated.c1))) {
        .gt => true,
        .lt => false,
        .eq => bigint.cmp(6, Fp.toRaw(p.y.c0), Fp.toRaw(negated.c0)) == .gt,
    };

    const flag_bits: u8 = if (y_is_larger) 0xa0 else 0x80;
    encoded[0] |= flag_bits;
    return encoded;
}

// ---------------------------------------------------------------------------
// Sparse Fp6 / Fp12 multiplication helpers used by the Miller loop.
//
// The line function evaluation in BLS12 pairings produces an Fp12 element
// with most coefficients zero. Multiplying f ∈ Fp12 by such a sparse value
// is much cheaper than a full Fp12.mul if we exploit the zero positions.
//
// Conventions match arkworks (`mul_by_014`, `mul_by_01`, `mul_by_1`):
//
//   - `fp6MulBy01(a, c0, c1)`: multiply the Fp6 element `a` by `c0 + c1·v`
//   - `fp6MulBy1(a, c1)`: multiply the Fp6 element `a` by `c1·v`
//   - `fp12MulBy014(f, c0, c1, c4)`: multiply the Fp12 element `f` by an
//     Fp12 whose basis components at positions {0=1, 1=v, 4=vw} are
//     c0, c1, c4 and all other positions are zero. M-twist line
//     evaluations sit there.
// ---------------------------------------------------------------------------

/// Sparse Fp6 multiplication: `a · (c0 + c1·v)`. Saves the three Fp2
/// muls that would touch the zero `v²` coefficient of `(c0, c1, 0)`.
pub fn fp6MulBy01(a: Fp6, c0: Fp2, c1: Fp2) Fp6 {
    // Two products that are reused by every output coefficient.
    const a_a = Fp2.mul(a.c0, c0);
    const b_b = Fp2.mul(a.c1, c1);

    // out.c0 = ((c1 · (a.c1 + a.c2) − b_b) · (1+u)) + a_a
    const cross_12 = Fp2.sub(Fp2.mul(c1, Fp2.add(a.c1, a.c2)), b_b);
    const out_c0 = Fp2.add(fp2MulByNonresidue(cross_12), a_a);

    // out.c2 = c0 · (a.c0 + a.c2) − a_a + b_b
    const cross_02 = Fp2.sub(Fp2.mul(c0, Fp2.add(a.c0, a.c2)), a_a);
    const out_c2 = Fp2.add(cross_02, b_b);

    // out.c1 = (c0 + c1) · (a.c0 + a.c1) − a_a − b_b
    const cross_01 = Fp2.mul(Fp2.add(c0, c1), Fp2.add(a.c0, a.c1));
    const out_c1 = Fp2.sub(Fp2.sub(cross_01, a_a), b_b);

    return Fp6{ .c0 = out_c0, .c1 = out_c1, .c2 = out_c2 };
}

/// Sparse Fp6 multiplication: `a · (c1·v)`. Used by `fp12MulBy014` to
/// handle the upper-half multiplication where only `vw` is non-zero.
pub fn fp6MulBy1(a: Fp6, c1: Fp2) Fp6 {
    const b_b = Fp2.mul(a.c1, c1);

    // out.c0 = (c1 · (a.c1 + a.c2) − b_b) · (1+u)
    const cross_12 = Fp2.sub(Fp2.mul(c1, Fp2.add(a.c1, a.c2)), b_b);
    const out_c0 = fp2MulByNonresidue(cross_12);

    // out.c1 = c1 · (a.c0 + a.c1) − b_b
    const out_c1 = Fp2.sub(Fp2.mul(c1, Fp2.add(a.c0, a.c1)), b_b);

    // out.c2 is b_b itself.
    return Fp6{ .c0 = out_c0, .c1 = out_c1, .c2 = b_b };
}

/// Sparse Fp12 multiplication: `f · (c0 + c1·v + c4·vw)`.
///
/// The right operand has only three non-zero Fp2 coefficients out of
/// the six positions of Fp12 — exactly the shape of an M-twist line.
/// The Karatsuba-style structure follows arkworks' `mul_by_014` so a
/// future cross-check against a known-good Rust output is
/// straightforward.
pub fn fp12MulBy014(f: Fp12, c0: Fp2, c1: Fp2, c4: Fp2) Fp12 {
    // lower = f.c0 · (c0, c1, 0); upper = f.c1 · (0, c4, 0).
    const lower = fp6MulBy01(f.c0, c0, c1);
    const upper = fp6MulBy1(f.c1, c4);

    // New c1 = (f.c0 + f.c1) · (c0, c1 + c4, 0) − lower − upper.
    const combined = fp6MulBy01(Fp6.add(f.c0, f.c1), c0, Fp2.add(c1, c4));
    const out_c1 = Fp6.sub(Fp6.sub(combined, lower), upper);

    // New c0 = upper · v + lower.
    const out_c0 = Fp6.add(Fp6.mulByV(upper), lower);

    return Fp12{ .c0 = out_c0, .c1 = out_c1 };
}

/// Multiply an `Fp2` value by an `Fp` scalar: both coefficients of `a`
/// are multiplied by `s` with `Fp.montMul`. Used by line evaluation to
/// scale the line coefficients by `P.x` and `P.y`.
pub fn fp2MulByFp(a: Fp2, s: Fp.Element) Fp2 {
    const scaled_c0 = Fp.montMul(a.c0, s);
    const scaled_c1 = Fp.montMul(a.c1, s);
    return Fp2{ .c0 = scaled_c0, .c1 = scaled_c1 };
}

// ---------------------------------------------------------------------------
// G2 homogeneous projective coordinates (NOT Jacobian).
//
// A "homogeneous" projective point (X : Y : Z) represents the affine
// point (X/Z, Y/Z), unlike Jacobian which uses (X/Z², Y/Z³). The pairing
// formulas from arkworks (and Costello / Aranha-Karabina-Longa-Gebotys-
// Lopez) work in this representation. The Miller loop runs entirely on
// G2HomProjective; we never need to project back to affine until the
// loop is done.
//
// We do NOT touch the existing `G2Projective` Jacobian type — keeping
// this separate avoids a representation change that would touch every
// G2 test in the package.
// ---------------------------------------------------------------------------

/// G2 point in homogeneous projective coordinates `(X : Y : Z)`. The
/// Miller loop mutates a value of this type in place through
/// `doubleStep` and `addStep`.
pub const G2HomProjective = struct {
    x: Fp2,
    y: Fp2,
    z: Fp2,

    /// Lift an affine point: a finite point becomes `(x, y, 1)` and the
    /// identity becomes `(0, 1, 0)`.
    pub fn fromAffine(p: G2Affine) G2HomProjective {
        if (p.infinity) return .{ .x = Fp2.zero(), .y = Fp2.one(), .z = Fp2.zero() };
        return .{ .x = p.x, .y = p.y, .z = Fp2.one() };
    }

    /// True when `z` is zero, which is how the identity is represented
    /// in this coordinate system.
    pub fn isIdentity(self: G2HomProjective) bool {
        return Fp2.eql(self.z, Fp2.zero());
    }

    /// Doubling step: replaces `self` (call it T) with `2T` in place and
    /// returns the line coefficients `(c0, c1, c4)` of the tangent at T.
    /// The coefficients live in Fp2 and still need to be scaled by the
    /// G1 point's coordinates (done in `ellM`) before being fed to
    /// `fp12MulBy014`.
    ///
    /// Formulas from arkworks `bls12::g2::G2HomProjective::double_in_place`,
    /// which in turn cite the Costello / "Faster Explicit Formulas"
    /// reference. Uses `b' = 4(1+u)` for BLS12-381's G2 curve constant.
    pub fn doubleStep(self: *G2HomProjective) struct { Fp2, Fp2, Fp2 } {
        // All intermediates are computed from the OLD coordinates before
        // any field of `self` is overwritten.
        // a = (X · Y) / 2
        var a = Fp2.mul(self.x, self.y);
        a = halveFp2(a);
        // b = Y²
        const b = Fp2.square(self.y);
        // c = Z²
        const c = Fp2.square(self.z);
        // e = 3 b' c
        const three_c = Fp2.add(Fp2.add(c, c), c);
        const e = Fp2.mul(g2B(), three_c);
        // f = 3 e
        const f = Fp2.add(Fp2.add(e, e), e);
        // g = (b + f) / 2
        const g = halveFp2(Fp2.add(b, f));
        // h = (Y + Z)² − (b + c)
        const y_plus_z = Fp2.add(self.y, self.z);
        const h = Fp2.sub(Fp2.square(y_plus_z), Fp2.add(b, c));
        // i = e − b
        const i = Fp2.sub(e, b);
        // j = X²
        const j = Fp2.square(self.x);
        // e_square = e²
        const e_square = Fp2.square(e);

        // X' = a · (b − f)
        self.x = Fp2.mul(a, Fp2.sub(b, f));
        // Y' = g² − 3 e²
        self.y = Fp2.sub(Fp2.square(g), Fp2.add(Fp2.add(e_square, e_square), e_square));
        // Z' = b · h
        self.z = Fp2.mul(b, h);

        // M-twist line coefficients: (i, 3j, -h)
        const three_j = Fp2.add(Fp2.add(j, j), j);
        return .{ i, three_j, Fp2.neg(h) };
    }

    /// Mixed addition step: in-place adds `q` (affine) into `self`
    /// (homogeneous projective) and returns the line coefficients
    /// `(c0, c1, c4)`, which `ellM` later scales by the G1 point's
    /// coordinates. `q` is read as a finite point; its `infinity` flag
    /// is not consulted here.
    pub fn addStep(self: *G2HomProjective, q: G2Affine) struct { Fp2, Fp2, Fp2 } {
        // theta = Y - q.y · Z
        const theta = Fp2.sub(self.y, Fp2.mul(q.y, self.z));
        // lambda = X - q.x · Z
        const lambda = Fp2.sub(self.x, Fp2.mul(q.x, self.z));
        // c = theta²
        const c = Fp2.square(theta);
        // d = lambda²
        const d = Fp2.square(lambda);
        // e = lambda · d
        const e = Fp2.mul(lambda, d);
        // f = Z · c
        const f = Fp2.mul(self.z, c);
        // g = X · d
        const g = Fp2.mul(self.x, d);
        // h = e + f − 2g
        const h = Fp2.sub(Fp2.add(e, f), Fp2.add(g, g));
        // X' = lambda · h
        self.x = Fp2.mul(lambda, h);
        // Y' = theta · (g − h) − e · Y   (reads the OLD Y before it is replaced)
        self.y = Fp2.sub(Fp2.mul(theta, Fp2.sub(g, h)), Fp2.mul(e, self.y));
        // Z' = Z · e
        self.z = Fp2.mul(self.z, e);
        // j = theta · q.x − lambda · q.y
        const j = Fp2.sub(Fp2.mul(theta, q.x), Fp2.mul(lambda, q.y));

        // M-twist line coefficients: (j, -theta, lambda)
        return .{ j, Fp2.neg(theta), lambda };
    }
};

/// Halve an Fp2 element. `(a₀ + a₁·u)/2 = (a₀/2) + (a₁/2)·u`; each
/// coefficient is halved with `fpHalve`.
inline fn halveFp2(a: Fp2) Fp2 {
    return .{ .c0 = fpHalve(a.c0), .c1 = fpHalve(a.c1) };
}

// ---------------------------------------------------------------------------
// Optimal Ate Miller loop for BLS12-381.
//
// Walks the bits of `|x|` from the second-most-significant down to bit 0,
// squaring the accumulator and applying a doubling line evaluation each
// step, with an extra addition line evaluation when the bit is set. The
// final result is conjugated when `x` is negative.
// ---------------------------------------------------------------------------

/// The individual bits of `BLS_X_ABS`, ordered MSB → LSB (index 0 is the
/// top bit), so the loop can pull bits out from the top down without
/// recomputing the bit length on every iteration. The constant has bit
/// length 64 (top bit set).
/// Optimal Ate Miller loop for an affine G1 point `p` and an affine G2
/// point `q`. The returned Fp12 value becomes the pairing value once the
/// final exponentiation has been applied to it.
///
/// If either input is the point at infinity the loop is skipped and
/// `Fp12.one()` is returned.
pub fn millerLoop(p: G1Affine, q: G2Affine) Fp12 {
    if (p.infinity or q.infinity) return Fp12.one();

    var acc = Fp12.one();
    var running = G2HomProjective.fromAffine(q);

    // `BLS_X_ABS_BITS_MSB` is ordered most-significant bit first. The
    // leading bit is already accounted for by the starting state
    // (running point = Q, accumulator = 1), so iteration begins at
    // index 1. Every remaining bit costs one squaring plus a doubling
    // line; set bits additionally cost an addition line.
    for (BLS_X_ABS_BITS_MSB[1..]) |bit| {
        acc = ellM(Fp12.square(acc), running.doubleStep(), p);
        if (bit == 1) {
            acc = ellM(acc, running.addStep(q), p);
        }
    }

    // The loop ran over |x|. For negative x the Miller value has to be
    // inverted; conjugation (raising to p^6) is used in place of a true
    // inversion. The two differ by the factor acc^(p^6 + 1), which the
    // (p^6 − 1) exponent of the easy part of the final exponentiation
    // maps to acc^(p^12 − 1) = 1 (see arkworks `multi_miller_loop`).
    return if (BLS_X_IS_NEGATIVE) Fp12.conjugate(acc) else acc;
}
/// Hard part of the BLS12-381 final exponentiation, expressed as the
/// addition chain from arkworks (eprint 2020/875). All intermediate
/// values live in the cyclotomic subgroup so conjugation is the same
/// as inversion — every `cyclotomic_inverse_in_place` in the upstream
/// implementation maps to a plain `Fp12.conjugate` here.
///
/// `input` is expected to be the output of the easy part (that is how
/// `fp12FinalExp` calls this function); the conjugate-as-inverse steps
/// below rely on it.
///
/// The per-line exponent comments write `f` for `input`, `x` for the BLS
/// parameter and `p` for the base-field characteristic. They assume
/// `expByX(g) = g^x`, `fp12Frobenius(g) = g^p`,
/// `fp12FrobeniusSquared(g) = g^(p²)` and `conjugate(g) = g^(-1)`.
/// Under those assumptions the returned value is
/// `f^((x−1)²·(x+p)·(x²+p²−1) + 3)`.
/// NOTE(review): this should equal `f^(3·(p⁴ − p² + 1)/r)`, i.e. the cube
/// of the exponent quoted in the section header — confirm against
/// eprint 2020/875.
pub fn fp12FinalExpHard(input: Fp12) Fp12 {
    var r = input; // r  = f
    var y0 = Fp12.square(r); // y0 = f^2
    var y1 = expByX(r); // y1 = f^x
    var y2 = Fp12.conjugate(r); // y2 = f^(-1)
    y1 = Fp12.mul(y1, y2); // y1 = f^(x−1)
    y2 = expByX(y1); // y2 = f^(x·(x−1))
    y1 = Fp12.conjugate(y1); // y1 = f^(−(x−1))
    y1 = Fp12.mul(y1, y2); // y1 = f^((x−1)²)
    y2 = expByX(y1); // y2 = f^(x·(x−1)²)
    y1 = fp12Frobenius(y1); // y1 = f^(p·(x−1)²)
    y1 = Fp12.mul(y1, y2); // y1 = f^((x−1)²·(x+p)), call this A
    r = Fp12.mul(r, y0); // r  = f^3
    y0 = expByX(y1); // y0 = A^x
    y2 = expByX(y0); // y2 = A^(x²)
    y0 = y1; // y0 = A
    y0 = fp12FrobeniusSquared(y0); // y0 = A^(p²)
    y1 = Fp12.conjugate(y1); // y1 = A^(-1)
    y1 = Fp12.mul(y1, y2); // y1 = A^(x²−1)
    y1 = Fp12.mul(y1, y0); // y1 = A^(x²+p²−1)
    r = Fp12.mul(r, y1); // r  = f^3 · A^(x²+p²−1)
    return r;
}
test "Fp.zero / Fp.one" {
    // The additive identity must have all six limbs zero.
    try testing.expect(bigint.isZero(6, Fp.zero()));

    // `Fp.one()` is `R mod p`, not the raw integer 1, so its limbs are
    // not compared directly. The check goes through `toRaw`, which must
    // yield the plain integer 1.
    const expected_raw_one: [6]u64 = .{ 1, 0, 0, 0, 0, 0 };
    try testing.expectEqual(expected_raw_one, Fp.toRaw(Fp.one()));
}
test "Fp: (p-1)^2 + (2p-1) ≡ 0 mod p" {
    // p − 1 ≡ −1 (mod p), which gives three checks:
    //   (p−1) + 1        ≡ 0
    //   (p−1)²           ≡ 1
    //   (p−1)² + (2p−1)  ≡ 1 + (−1) ≡ 0   (the identity in the test name)
    // 2p − 1 is not a canonical representative, so the last check uses
    // its reduced form p − 1.
    var raw: [6]u64 = FP_MODULUS;
    raw[0] -= 1;
    const p_minus_one = Fp.fromRaw(raw);
    const one = Fp.fromRaw(.{ 1, 0, 0, 0, 0, 0 });

    // (p−1) + 1 wraps around to zero.
    try testing.expect(Fp.eql(Fp.add(p_minus_one, one), Fp.zero()));

    // (p−1)·(p−1) = 1 mod p; compare against the Montgomery-form one.
    const sq = Fp.montMul(p_minus_one, p_minus_one);
    try testing.expect(Fp.eql(sq, Fp.one()));

    // (p−1)² + (2p−1) ≡ (p−1)² + (p−1) ≡ 0.
    try testing.expect(Fp.eql(Fp.add(sq, p_minus_one), Fp.zero()));
}
test "Fp.sqrt: 25 -> ±5" {
    const radicand = Fp.fromRaw(.{ 25, 0, 0, 0, 0, 0 });
    const candidate = fpSqrt(radicand);

    // The returned value must square back to 25 …
    try testing.expect(fpIsSquareRoot(radicand, candidate));

    // … and must be one of the two roots, 5 or −5.
    const plus_five = Fp.fromRaw(.{ 5, 0, 0, 0, 0, 0 });
    const is_plus_five = Fp.eql(candidate, plus_five);
    const is_minus_five = Fp.eql(candidate, Fp.neg(plus_five));
    try testing.expect(is_plus_five or is_minus_five);
}
"isInG2Subgroup: G2 generator is in subgroup" { + try testing.expect(isInG2Subgroup(g2Generator())); +} + +test "isInG2Subgroup: identity is in every subgroup" { + try testing.expect(isInG2Subgroup(G2Affine.identity())); +} + +// --------------------------------------------------------------------------- +// Optimal Ate pairing tests +// --------------------------------------------------------------------------- + +test "pairing: e(O, Q) = 1" { + const id_g1 = G1Affine.identity(); + const g2 = g2Generator(); + const result = pairing(id_g1, g2); + try testing.expect(Fp12.eql(result, Fp12.one())); +} + +test "pairing: e(P, O) = 1" { + const g1 = g1Generator(); + const id_g2 = G2Affine.identity(); + const result = pairing(g1, id_g2); + try testing.expect(Fp12.eql(result, Fp12.one())); +} + +test "pairing: non-degenerate (e(g1, g2) ≠ 1)" { + const g1 = g1Generator(); + const g2 = g2Generator(); + const result = pairing(g1, g2); + try testing.expect(!Fp12.eql(result, Fp12.one())); +} + +test "pairing: deterministic" { + const g1 = g1Generator(); + const g2 = g2Generator(); + const result1 = pairing(g1, g2); + const result2 = pairing(g1, g2); + try testing.expect(Fp12.eql(result1, result2)); +} + +test "pairing bilinearity: e(2P, Q) = e(P, Q)^2" { + // e(2P, Q) should equal e(P, 2Q) should equal e(P, Q)². + const g1 = g1Generator(); + const g2 = g2Generator(); + + const two_g1 = g1.double(); + const two_g2 = g2.double(); + + const e_2pq = pairing(two_g1, g2); + const e_p2q = pairing(g1, two_g2); + const e_pq = pairing(g1, g2); + const e_pq_sq = Fp12.square(e_pq); + + try testing.expect(Fp12.eql(e_2pq, e_p2q)); + try testing.expect(Fp12.eql(e_2pq, e_pq_sq)); +} + +test "pairing bilinearity: e(aP, Q) = e(P, Q)^a for small a" { + // Pick a = 5 so the test runs in a sane amount of time. The full + // 4-limb scalar exponent path is exercised by the (aP, bQ) test. 
test "pairing: e(P, -Q) = e(-P, Q) = e(P, Q)^(-1)" {
    // Negating either argument of a bilinear pairing inverts the result
    // in the target group.
    const p_gen = g1Generator();
    const q_gen = g2Generator();

    const base = pairing(p_gen, q_gen);
    const with_neg_p = pairing(p_gen.neg(), q_gen);
    const with_neg_q = pairing(p_gen, q_gen.neg());

    // Outputs of the final exponentiation lie in the cyclotomic subgroup,
    // where conjugation coincides with inversion, so the expected inverse
    // is obtained with a conjugate instead of `Fp12.inv`.
    const expected_inverse = Fp12.conjugate(base);

    try testing.expect(Fp12.eql(with_neg_p, expected_inverse));
    try testing.expect(Fp12.eql(with_neg_q, expected_inverse));

    // Cross-check without relying on the conjugate shortcut:
    // e(−P, Q) · e(P, Q) = 1.
    const product = Fp12.mul(with_neg_p, base);
    try testing.expect(Fp12.eql(product, Fp12.one()));
}
test "G1Projective: associativity through projective then back to affine" {
    // Three operands derived from the generator G:
    //   first = G, second = 2G (doubling), third = 2G + G.
    const first = G1Projective.fromAffine(g1Generator());
    const second = first.double();
    const third = second.add(first);

    // (first + second) + third versus first + (second + third). The two
    // groupings are compared after normalising to affine form.
    const left_grouped = first.add(second).add(third);
    const right_grouped = first.add(second.add(third));
    try testing.expect(G1Affine.eql(left_grouped.toAffine(), right_grouped.toAffine()));
}
test "G2Projective: P + (-P) = identity" {
    const g = g2Generator();
    const g_p = G2Projective.fromAffine(g);
    const neg_g_p = G2Projective.fromAffine(g.neg());
    const sum = g_p.add(neg_g_p);
    // The sum must be recognised as the identity in projective form …
    try testing.expect(sum.isIdentity());
    // … and must also convert to the affine identity, mirroring the
    // corresponding G1Projective test.
    try testing.expect(sum.toAffine().isIdentity());
}
test "G1Projective.mul: 7 * G via projective matches G + G + ... + G" {
    const base = G1Projective.fromAffine(g1Generator());

    // Reference value: start from 1·G and add G six more times.
    var running_sum = base;
    for (0..6) |_| running_sum = running_sum.add(base);

    // The scalar-multiplication path with the one-limb scalar 7 must
    // produce the same projective point.
    try testing.expect(G1Projective.eql(base.mul(1, .{7}), running_sum));
}
/// Test helper: build an Fp element from a small integer by placing `n`
/// in the least-significant limb of an otherwise-zero six-limb raw value
/// and converting it with `Fp.fromRaw`.
fn fpFromU64(n: u64) Fp.Element {
    var limbs = [_]u64{0} ** 6;
    limbs[0] = n;
    return Fp.fromRaw(limbs);
}
* a^-1 = 1" { + const a = fp2FromU64Pair(0x12345678, 0x9abcdef0); + const inv_a = Fp2.inv(a); + const product = Fp2.mul(a, inv_a); + try testing.expect(Fp2.eql(product, Fp2.one())); +} + +test "Fp2.inv: inv(0) = 0" { + try testing.expect(Fp2.eql(Fp2.inv(Fp2.zero()), Fp2.zero())); +} + +test "Fp2.inv: inv(1) = 1" { + try testing.expect(Fp2.eql(Fp2.inv(Fp2.one()), Fp2.one())); +} + +test "Fp2.sqrt: round-trips a hand-built square" { + // Build a = (3 + 5u)² and check that fp2Sqrt(a) returns ±(3 + 5u). + const original = fp2FromU64Pair(3, 5); + const sq = Fp2.square(original); + const root = try fp2Sqrt(sq); + const round_trip = Fp2.square(root); + try testing.expect(Fp2.eql(round_trip, sq)); + try testing.expect(Fp2.eql(root, original) or Fp2.eql(root, Fp2.neg(original))); +} + +test "Fp2.sqrt: round-trips a 6-limb random square" { + const original: Fp2 = .{ + .c0 = Fp.fromRaw(.{ 0x12345678, 0x9abcdef0, 0x1111, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 0xfedcba98, 0x1234, 0, 0x4321, 0, 0 }), + }; + const sq = Fp2.square(original); + const root = try fp2Sqrt(sq); + try testing.expect(Fp2.eql(Fp2.square(root), sq)); +} + +test "Fp2.sqrt: zero -> zero" { + const root = try fp2Sqrt(Fp2.zero()); + try testing.expect(Fp2.eql(root, Fp2.zero())); +} + +test "Fp2.sqrt: 1 -> ±1" { + const root = try fp2Sqrt(Fp2.one()); + const sq = Fp2.square(root); + try testing.expect(Fp2.eql(sq, Fp2.one())); +} + +test "Fp2.sqrt: pure-Fp residue" { + // 4 = 2² in Fp; in Fp2 the sqrt should be (2, 0) (or its negation). + const four_in_fp2: Fp2 = .{ .c0 = Fp.fromRaw(.{ 4, 0, 0, 0, 0, 0 }), .c1 = Fp.zero() }; + const root = try fp2Sqrt(four_in_fp2); + const sq = Fp2.square(root); + try testing.expect(Fp2.eql(sq, four_in_fp2)); +} + +test "Fp2.sqrt: pure-Fp non-residue takes the (0, sqrt(-a₀)) branch" { + // -1 in Fp2 is u² = (0, 0).c0 = -1, but actually (0, 1) since u² = -1. + // Easier test: a = -4 has sqrt 2u, since (2u)² = -4. 
/// Test helper: assemble an Fp6 element from six small integers. Each
/// `(cN_a, cN_b)` pair is turned into the Fp2 coefficient `cN` via
/// `fp2FromU64Pair`.
fn fp6FromInts(c0_a: u64, c0_b: u64, c1_a: u64, c1_b: u64, c2_a: u64, c2_b: u64) Fp6 {
    const coeff0 = fp2FromU64Pair(c0_a, c0_b);
    const coeff1 = fp2FromU64Pair(c1_a, c1_b);
    const coeff2 = fp2FromU64Pair(c2_a, c2_b);
    return Fp6{ .c0 = coeff0, .c1 = coeff1, .c2 = coeff2 };
}
+ const expected_v_sq: Fp6 = .{ + .c0 = Fp2.zero(), + .c1 = Fp2.zero(), + .c2 = Fp2.one(), + }; + try testing.expect(Fp6.eql(v_sq, expected_v_sq)); + const one_plus_u: Fp2 = .{ .c0 = Fp.one(), .c1 = Fp.one() }; + const expected_v_cubed: Fp6 = .{ + .c0 = one_plus_u, + .c1 = Fp2.zero(), + .c2 = Fp2.zero(), + }; + try testing.expect(Fp6.eql(v_cubed, expected_v_cubed)); +} + +test "Fp6.mul: distributive over add" { + const a = fp6FromInts(1, 2, 3, 4, 5, 6); + const b = fp6FromInts(7, 8, 9, 10, 11, 12); + const c = fp6FromInts(13, 14, 15, 16, 17, 18); + const lhs = Fp6.mul(Fp6.add(a, b), c); + const rhs = Fp6.add(Fp6.mul(a, c), Fp6.mul(b, c)); + try testing.expect(Fp6.eql(lhs, rhs)); +} + +test "Fp6.mul: associative" { + const a = fp6FromInts(2, 3, 5, 7, 11, 13); + const b = fp6FromInts(17, 19, 23, 29, 31, 37); + const c = fp6FromInts(41, 43, 47, 53, 59, 61); + const lhs = Fp6.mul(Fp6.mul(a, b), c); + const rhs = Fp6.mul(a, Fp6.mul(b, c)); + try testing.expect(Fp6.eql(lhs, rhs)); +} + +test "Fp6.square: equivalent to mul(a, a)" { + const a = fp6FromInts(11, 13, 17, 19, 23, 29); + try testing.expect(Fp6.eql(Fp6.square(a), Fp6.mul(a, a))); +} + +test "Fp6.mulByV: equivalent to mul by (0, 1, 0)" { + const a = fp6FromInts(11, 13, 17, 19, 23, 29); + const v: Fp6 = .{ + .c0 = Fp2.zero(), + .c1 = Fp2.one(), + .c2 = Fp2.zero(), + }; + try testing.expect(Fp6.eql(Fp6.mulByV(a), Fp6.mul(a, v))); +} + +test "Fp6.inv: a * a⁻¹ = 1" { + const a = fp6FromInts(2, 3, 5, 7, 11, 13); + const inv_a = Fp6.inv(a); + try testing.expect(Fp6.eql(Fp6.mul(a, inv_a), Fp6.one())); +} + +test "Fp6.pow: a^0 = 1, a^1 = a, a^2 = square(a)" { + const a = fp6FromInts(2, 3, 5, 7, 11, 13); + try testing.expect(Fp6.eql(Fp6.pow(a, 1, .{0}), Fp6.one())); + try testing.expect(Fp6.eql(Fp6.pow(a, 1, .{1}), a)); + try testing.expect(Fp6.eql(Fp6.pow(a, 1, .{2}), Fp6.square(a))); +} + +test "fp6Frobenius: applied 6 times returns the input" { + // The Frobenius has order dividing 6 in Fp6 (since Fp6 has degree + // 6 
over Fp). Applying it 6 times should be the identity for + // EVERY element, even outside the cyclotomic subgroup. + const a = fp6FromInts(11, 13, 17, 19, 23, 29); + var result = a; + inline for (0..6) |_| result = fp6Frobenius(result); + try testing.expect(Fp6.eql(result, a)); +} + +test "fp6Frobenius: matches Fp6.pow with exponent p" { + // φ(a) = a^p must equal Fp6.pow(a, FP_MODULUS) for any a. Slow + // but the most direct cross-check. Limit to a small element to + // keep the test budget reasonable. + const a = fp6FromInts(2, 3, 5, 7, 11, 13); + const via_frobenius = fp6Frobenius(a); + const via_pow = Fp6.pow(a, 6, FP_MODULUS); + try testing.expect(Fp6.eql(via_frobenius, via_pow)); +} + +test "fp6Frobenius: leaves Fp2 elements alone for c1, c2 = 0" { + // For a = (a0, 0, 0), the v and v² components of fp6Frobenius + // are zero (because they're multiplied by gamma1 / gamma1², but + // a1 = a2 = 0). And the c0 component is fp2Frobenius(a0). + const a: Fp6 = .{ + .c0 = fp2FromU64Pair(42, 7), + .c1 = Fp2.zero(), + .c2 = Fp2.zero(), + }; + const result = fp6Frobenius(a); + try testing.expect(Fp2.eql(result.c0, fp2Frobenius(a.c0))); + try testing.expect(Fp2.eql(result.c1, Fp2.zero())); + try testing.expect(Fp2.eql(result.c2, Fp2.zero())); +} + +test "fp6FrobeniusGamma1: applying Frobenius 6 times to v returns v" { + // The Frobenius coefficient γ₁ = (1+u)^((p-1)/3) drives the action + // of φ on v ∈ Fp6. After 6 applications of Frobenius, every Fp6 + // element returns to itself; the test reaching back to v specifically + // exercises the gamma1 / gamma1² product chain. 
test "Fp12.mul: w² = v" {
    // w is the Fp12 element 0 + 1·w, i.e. c0 = 0 and c1 = 1 (both in Fp6).
    const w = Fp12{ .c0 = Fp6.zero(), .c1 = Fp6.one() };

    // v is the Fp6 element (0, 1, 0). As an Fp12 element it sits in the
    // c0 slot with a zero c1, so w² is expected to be (v, 0).
    const v_embedded = Fp12{
        .c0 = Fp6{ .c0 = Fp2.zero(), .c1 = Fp2.one(), .c2 = Fp2.zero() },
        .c1 = Fp6.zero(),
    };

    try testing.expect(Fp12.eql(Fp12.mul(w, w), v_embedded));
}
+ const a = fp12FromInts( + fp6FromInts(2, 3, 5, 7, 11, 13), + fp6FromInts(17, 19, 23, 29, 31, 37), + ); + const product = Fp12.mul(a, Fp12.conjugate(a)); + try testing.expect(Fp6.eql(product.c1, Fp6.zero())); +} + +test "Fp12.pow: a^0 = 1, a^1 = a, a^2 = square(a)" { + const a = fp12FromInts( + fp6FromInts(2, 3, 5, 7, 11, 13), + fp6FromInts(17, 19, 23, 29, 31, 37), + ); + try testing.expect(Fp12.eql(Fp12.pow(a, 1, .{0}), Fp12.one())); + try testing.expect(Fp12.eql(Fp12.pow(a, 1, .{1}), a)); + try testing.expect(Fp12.eql(Fp12.pow(a, 1, .{2}), Fp12.square(a))); +} + +test "Fp12.pow: matches manual repeated mul" { + const a = fp12FromInts( + fp6FromInts(11, 13, 17, 19, 23, 29), + fp6FromInts(31, 37, 41, 43, 47, 53), + ); + const a_to_5 = Fp12.pow(a, 1, .{5}); + var manual = a; + inline for (0..4) |_| manual = Fp12.mul(manual, a); + try testing.expect(Fp12.eql(a_to_5, manual)); +} + +test "fp12Frobenius: applied 12 times returns the input" { + // The Fp12 Frobenius has order dividing 12. + const a = fp12FromInts( + fp6FromInts(2, 3, 5, 7, 11, 13), + fp6FromInts(17, 19, 23, 29, 31, 37), + ); + var result = a; + inline for (0..12) |_| result = fp12Frobenius(result); + try testing.expect(Fp12.eql(result, a)); +} + +test "fp12Frobenius: matches Fp12.pow with exponent p" { + // φ(a) = a^p must equal Fp12.pow(a, FP_MODULUS) for any a. + // Slow but the most direct cross-check. The pow path here goes + // through ~381 Fp12 squarings + ~190 Fp12 mults; even with a small + // input that's a few hundred ms. + const a = fp12FromInts( + fp6FromInts(2, 3, 0, 0, 0, 0), + fp6FromInts(0, 0, 5, 7, 0, 0), + ); + const via_frobenius = fp12Frobenius(a); + const via_pow = Fp12.pow(a, 6, FP_MODULUS); + try testing.expect(Fp12.eql(via_frobenius, via_pow)); +} + +test "fp12Frobenius: leaves c0=Fp6, c1=0 alone for the c1 side" { + // For a = (c0, 0), the c1 side stays zero and the c0 side gets + // fp6Frobenius applied directly. 
+ const c0 = fp6FromInts(2, 3, 5, 7, 11, 13); + const a: Fp12 = .{ .c0 = c0, .c1 = Fp6.zero() }; + const result = fp12Frobenius(a); + try testing.expect(Fp6.eql(result.c0, fp6Frobenius(c0))); + try testing.expect(Fp6.eql(result.c1, Fp6.zero())); +} + +test "fp12FrobeniusSquared: matches fp12Frobenius applied twice" { + const a = fp12FromInts( + fp6FromInts(2, 3, 5, 7, 11, 13), + fp6FromInts(17, 19, 23, 29, 31, 37), + ); + const direct = fp12FrobeniusSquared(a); + const composed = fp12Frobenius(fp12Frobenius(a)); + try testing.expect(Fp12.eql(direct, composed)); +} + +test "fp12FinalExpEasy: result is in the cyclotomic subgroup" { + // After the easy part, the result `g` should satisfy + // `conjugate(g) · g = 1` (this is the defining property of the + // cyclotomic subgroup of Fp12). + const a = fp12FromInts( + fp6FromInts(2, 3, 5, 7, 11, 13), + fp6FromInts(17, 19, 23, 29, 31, 37), + ); + const g = fp12FinalExpEasy(a); + const product = Fp12.mul(Fp12.conjugate(g), g); + try testing.expect(Fp12.eql(product, Fp12.one())); +} + +test "fp12FinalExpEasy: a^0 stays zero, a=1 stays 1" { + // f = 1: easy result should be 1. 
+ const one = Fp12.one(); + const result = fp12FinalExpEasy(one); + try testing.expect(Fp12.eql(result, one)); +} + +test "fp2Frobenius: applied twice returns the input" { + const a = fp2FromU64Pair(123, 456); + try testing.expect(Fp2.eql(fp2Frobenius(fp2Frobenius(a)), a)); +} + +test "fp2Frobenius: conjugates the imaginary part" { + const a = fp2FromU64Pair(123, 456); + const expected: Fp2 = .{ .c0 = a.c0, .c1 = Fp.neg(a.c1) }; + try testing.expect(Fp2.eql(fp2Frobenius(a), expected)); +} + +test "fp2Frobenius: leaves Fp elements alone" { + const a: Fp2 = .{ .c0 = Fp.fromRaw(.{ 0x42, 0, 0, 0, 0, 0 }), .c1 = Fp.zero() }; + try testing.expect(Fp2.eql(fp2Frobenius(a), a)); +} + +test "fp2Pow: a^0 = 1" { + const a = fp2FromU64Pair(123, 456); + const result = fp2Pow(a, 1, .{0}); + try testing.expect(Fp2.eql(result, Fp2.one())); +} + +test "fp2Pow: a^1 = a" { + const a = fp2FromU64Pair(123, 456); + try testing.expect(Fp2.eql(fp2Pow(a, 1, .{1}), a)); +} + +test "fp2Pow: a^2 = square(a)" { + const a = fp2FromU64Pair(123, 456); + try testing.expect(Fp2.eql(fp2Pow(a, 1, .{2}), Fp2.square(a))); +} + +test "fp2Pow: 3^7 = 2187 (in Fp2)" { + const three = fp2FromU64Pair(3, 0); + const result = fp2Pow(three, 1, .{7}); + const expected = fp2FromU64Pair(2187, 0); + try testing.expect(Fp2.eql(result, expected)); +} + +test "fp2Pow matches manual repeated mul for small exponents" { + const a = fp2FromU64Pair(5, 7); + const a_to_5 = fp2Pow(a, 1, .{5}); + var manual = a; + inline for (0..4) |_| manual = Fp2.mul(manual, a); + try testing.expect(Fp2.eql(a_to_5, manual)); +} + +// --------------------------------------------------------------------------- +// G1 affine arithmetic tests +// --------------------------------------------------------------------------- + +test "G1: generator is on the curve" { + const g = g1Generator(); + try testing.expect(g.isOnCurve()); +} + +test "G1: identity is on the curve" { + try testing.expect(G1Affine.identity().isOnCurve()); +} + +test "G1: 
identity is the additive neutral element" { + const g = g1Generator(); + const id = G1Affine.identity(); + try testing.expect(G1Affine.eql(g.add(id), g)); + try testing.expect(G1Affine.eql(id.add(g), g)); +} + +test "G1: P + (-P) = identity" { + const g = g1Generator(); + const neg_g = g.neg(); + const sum = g.add(neg_g); + try testing.expect(sum.isIdentity()); +} + +test "G1: 2P via double matches P + P" { + const g = g1Generator(); + const doubled = g.double(); + const summed = g.add(g); + try testing.expect(G1Affine.eql(doubled, summed)); + // 2P should still be on the curve. + try testing.expect(doubled.isOnCurve()); +} + +test "G1: 3P = 2P + P matches P + 2P" { + const g = g1Generator(); + const two_g = g.double(); + const three_g_a = two_g.add(g); + const three_g_b = g.add(two_g); + try testing.expect(G1Affine.eql(three_g_a, three_g_b)); + try testing.expect(three_g_a.isOnCurve()); +} + +test "G1: 4P = 2(2P) matches 3P + P" { + const g = g1Generator(); + const two_g = g.double(); + const four_g_a = two_g.double(); + const four_g_b = two_g.add(g).add(g); + try testing.expect(G1Affine.eql(four_g_a, four_g_b)); + try testing.expect(four_g_a.isOnCurve()); +} + +test "G1: addition is commutative" { + const g = g1Generator(); + const two_g = g.double(); + const three_g = two_g.add(g); + const lhs = three_g.add(two_g); + const rhs = two_g.add(three_g); + try testing.expect(G1Affine.eql(lhs, rhs)); +} + +test "G1: addition is associative" { + const g = g1Generator(); + const two_g = g.double(); + const three_g = two_g.add(g); + // (g + 2g) + 3g == g + (2g + 3g) + const lhs = g.add(two_g).add(three_g); + const rhs = g.add(two_g.add(three_g)); + try testing.expect(G1Affine.eql(lhs, rhs)); +} + +test "G1: scalar mul matches repeated add for small scalars" { + const g = g1Generator(); + // 5 * G == G + G + G + G + G + const five_g = g.mul(1, .{5}); + const expected = g.add(g).add(g).add(g).add(g); + try testing.expect(G1Affine.eql(five_g, expected)); +} + +test "G1: 
scalar mul: 0 * G = identity, 1 * G = G" { + const g = g1Generator(); + try testing.expect(g.mul(1, .{0}).isIdentity()); + try testing.expect(G1Affine.eql(g.mul(1, .{1}), g)); +} + +test "G1: scalar mul: 2 * G via mul matches G.double()" { + const g = g1Generator(); + try testing.expect(G1Affine.eql(g.mul(1, .{2}), g.double())); +} + +test "G1: scalar mul distributes over scalar add (small scalars)" { + const g = g1Generator(); + // (3 + 5) * G == 3*G + 5*G + const lhs = g.mul(1, .{8}); + const rhs = g.mul(1, .{3}).add(g.mul(1, .{5})); + try testing.expect(G1Affine.eql(lhs, rhs)); +} + +// --------------------------------------------------------------------------- +// Compressed point decoding +// --------------------------------------------------------------------------- + +test "decodeG1Compressed: infinity flag round-trip" { + var bytes: [48]u8 = .{0} ** 48; + bytes[0] = 0xc0; // compression + infinity bits set + const point = try decodeG1Compressed(&bytes); + try testing.expect(point.isIdentity()); +} + +test "decodeG1Compressed: rejects wrong length" { + const short: [47]u8 = .{0} ** 47; + try testing.expectError(PointDecodeError.InvalidLength, decodeG1Compressed(&short)); +} + +test "decodeG1Compressed: rejects missing compression flag" { + var bytes: [48]u8 = .{0} ** 48; + // Top bit cleared = uncompressed encoding, which we don't support. + try testing.expectError(PointDecodeError.InvalidEncoding, decodeG1Compressed(&bytes)); +} + +test "decodeG1Compressed: G1 generator round-trip" { + // Compressed encoding of the standard G1 generator (from the + // BLS12-381 IETF spec): + // 0x97f1d3a7 3197d794 2695638c 4fa9ac0f + // c3688c4f 9774b905 a14e3a3f 171bac58 + // 6c55e83f f97a1aef fb3af00a db22c6bb + // The high bit (0x80) is the compression flag; the next bit + // would be infinity (cleared); the third bit is the y-sign. + // For the standard generator the y-sign bit is 0 (lex-smaller y), + // so the first byte is exactly 0x97 (not 0xb7). 
+ const compressed_hex = "97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb"; + var compressed: [48]u8 = undefined; + _ = try std.fmt.hexToBytes(&compressed, compressed_hex); + const decoded = try decodeG1Compressed(&compressed); + try testing.expect(decoded.isOnCurve()); + // The recovered x must equal the canonical generator x. + try testing.expect(Fp.eql(decoded.x, Fp.fromRaw(G1_GENERATOR_X))); + // And the recovered point must equal the canonical generator (with + // matching y root). + try testing.expect(G1Affine.eql(decoded, g1Generator())); +} + +test "decodeG1Compressed: bit-flipped y-sign decodes to -G" { + // Same encoding, but flip the y-sign bit (bit 5 of the first byte). + const compressed_hex = "97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb"; + var compressed: [48]u8 = undefined; + _ = try std.fmt.hexToBytes(&compressed, compressed_hex); + compressed[0] |= 0b0010_0000; + const decoded = try decodeG1Compressed(&compressed); + try testing.expect(decoded.isOnCurve()); + try testing.expect(G1Affine.eql(decoded, g1Generator().neg())); +} + +test "decodeG1Compressed: rejects x ≥ p" { + var bytes: [48]u8 = .{0xff} ** 48; + bytes[0] = 0x80 | 0x1f; // compression flag + low 5 bits of 0xff + try testing.expectError(PointDecodeError.NotInField, decodeG1Compressed(&bytes)); +} + +// --------------------------------------------------------------------------- +// G2 affine arithmetic tests. The test set mirrors G1 — the curve +// equation differs but the affine algebra is the same. 
+// ---------------------------------------------------------------------------
+
+test "G2: generator is on the curve" {
+    const g = g2Generator();
+    try testing.expect(g.isOnCurve());
+}
+
+test "G2: identity is on the curve" {
+    try testing.expect(G2Affine.identity().isOnCurve());
+}
+
+test "G2: identity is the additive neutral element" {
+    const g = g2Generator();
+    const id = G2Affine.identity();
+    try testing.expect(G2Affine.eql(g.add(id), g));
+    try testing.expect(G2Affine.eql(id.add(g), g));
+}
+
+test "G2: P + (-P) = identity" {
+    const g = g2Generator();
+    const sum = g.add(g.neg());
+    try testing.expect(sum.isIdentity());
+}
+
+test "G2: 2P via double matches P + P" {
+    const g = g2Generator();
+    const doubled = g.double();
+    const summed = g.add(g);
+    try testing.expect(G2Affine.eql(doubled, summed));
+    try testing.expect(doubled.isOnCurve());
+}
+
+test "G2: 3P consistency and on-curve" {
+    const g = g2Generator();
+    const two_g = g.double();
+    const three_g_a = two_g.add(g);
+    const three_g_b = g.add(two_g);
+    try testing.expect(G2Affine.eql(three_g_a, three_g_b));
+    try testing.expect(three_g_a.isOnCurve());
+}
+
+test "G2: addition is commutative" {
+    const g = g2Generator();
+    const two_g = g.double();
+    const three_g = two_g.add(g);
+    try testing.expect(G2Affine.eql(three_g.add(two_g), two_g.add(three_g)));
+}
+
+test "G2: scalar mul matches repeated add for small scalars" {
+    const g = g2Generator();
+    const four_g = g.mul(1, .{4});
+    const expected = g.double().double();
+    try testing.expect(G2Affine.eql(four_g, expected));
+}
+
+test "G2: 0 * G = identity, 1 * G = G" {
+    const g = g2Generator();
+    try testing.expect(g.mul(1, .{0}).isIdentity());
+    try testing.expect(G2Affine.eql(g.mul(1, .{1}), g));
+}
+
+// ---------------------------------------------------------------------------
+// Compressed G2 decoding
+// ---------------------------------------------------------------------------
+
+test "decodeG2Compressed: infinity flag round-trip" {
+    var bytes: [96]u8 = .{0} ** 96;
+    bytes[0] = 0xc0;
+    const point = try decodeG2Compressed(&bytes);
+    try testing.expect(point.isIdentity());
+}
+
+test "decodeG2Compressed: rejects wrong length" {
+    const short: [95]u8 = .{0} ** 95;
+    try testing.expectError(PointDecodeError.InvalidLength, decodeG2Compressed(&short));
+}
+
+test "decodeG2Compressed: rejects missing compression flag" {
+    var bytes: [96]u8 = .{0} ** 96;
+    try testing.expectError(PointDecodeError.InvalidEncoding, decodeG2Compressed(&bytes));
+}
+
+test "decodeG2Compressed: G2 generator round-trip" {
+    // Compressed encoding of the standard G2 generator (from the
+    // BLS12-381 IETF spec): bytes[0..48] = x.c1, bytes[48..96] = x.c0,
+    // with the high three bits of byte 0 set to {1, 0, sign}.
+    //
+    // For the standard generator, the lex-smaller y root is taken,
+    // so the y-sign bit is 0 and the first byte is 0x80 | (top byte
+    // of x.c1 with high 3 bits cleared).
+    //
+    // x.c1 high byte = 0x13, with the compression flag set the first
+    // byte becomes 0x80 | 0x13 = 0x93.
+    const compressed_hex = "93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8";
+    var compressed: [96]u8 = undefined;
+    _ = try std.fmt.hexToBytes(&compressed, compressed_hex);
+    const decoded = try decodeG2Compressed(&compressed);
+    try testing.expect(decoded.isOnCurve());
+    // The recovered x must equal the canonical generator x.
+    try testing.expect(Fp2.eql(decoded.x, g2Generator().x));
+    // And the full point must equal one of {±G2}.
+    const g2 = g2Generator();
+    try testing.expect(G2Affine.eql(decoded, g2) or G2Affine.eql(decoded, g2.neg()));
+}
+
+// ---------------------------------------------------------------------------
+// Compressed point encoder tests
+// ---------------------------------------------------------------------------
+
+test "encodeG1Compressed: identity round-trip" {
+    const id = G1Affine.identity();
+    const bytes = encodeG1Compressed(id);
+    try testing.expectEqual(@as(u8, 0xc0), bytes[0]);
+    inline for (1..48) |i| try testing.expectEqual(@as(u8, 0), bytes[i]);
+
+    const decoded = try decodeG1Compressed(&bytes);
+    try testing.expect(decoded.isIdentity());
+}
+
+test "encodeG1Compressed: generator round-trip" {
+    const g = g1Generator();
+    const bytes = encodeG1Compressed(g);
+    // Compression flag must be set.
+    try testing.expect((bytes[0] & 0x80) != 0);
+    // Infinity flag must be cleared.
+    try testing.expect((bytes[0] & 0x40) == 0);
+
+    const decoded = try decodeG1Compressed(&bytes);
+    try testing.expect(G1Affine.eql(decoded, g));
+}
+
+test "encodeG1Compressed: matches the canonical generator hex" {
+    // The canonical compressed encoding of G1 from the IETF spec.
+    const expected_hex = "97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb";
+    var expected: [48]u8 = undefined;
+    _ = try std.fmt.hexToBytes(&expected, expected_hex);
+
+    const bytes = encodeG1Compressed(g1Generator());
+    try testing.expectEqualSlices(u8, &expected, &bytes);
+}
+
+test "encodeG1Compressed: -G round-trip with sign flag set" {
+    const g = g1Generator();
+    const neg_g = g.neg();
+    const bytes = encodeG1Compressed(neg_g);
+    // The y-sign for -G should differ from G's. Decode and verify.
+    const decoded = try decodeG1Compressed(&bytes);
+    try testing.expect(G1Affine.eql(decoded, neg_g));
+}
+
+test "encodeG1Compressed: random scalar multiple round-trips" {
+    const g = g1Generator();
+    const five_g = g.mul(1, .{5});
+    const bytes = encodeG1Compressed(five_g);
+    const decoded = try decodeG1Compressed(&bytes);
+    try testing.expect(G1Affine.eql(decoded, five_g));
+}
+
+test "encodeG2Compressed: identity round-trip" {
+    const id = G2Affine.identity();
+    const bytes = encodeG2Compressed(id);
+    try testing.expectEqual(@as(u8, 0xc0), bytes[0]);
+    inline for (1..96) |i| try testing.expectEqual(@as(u8, 0), bytes[i]);
+
+    const decoded = try decodeG2Compressed(&bytes);
+    try testing.expect(decoded.isIdentity());
+}
+
+test "encodeG2Compressed: generator round-trip" {
+    const g = g2Generator();
+    const bytes = encodeG2Compressed(g);
+    try testing.expect((bytes[0] & 0x80) != 0);
+    try testing.expect((bytes[0] & 0x40) == 0);
+
+    const decoded = try decodeG2Compressed(&bytes);
+    try testing.expect(G2Affine.eql(decoded, g));
+}
+
+test "encodeG2Compressed: matches the canonical generator hex" {
+    // The canonical compressed encoding of G2 from the IETF spec.
+    const expected_hex = "93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8";
+    var expected: [96]u8 = undefined;
+    _ = try std.fmt.hexToBytes(&expected, expected_hex);
+
+    const bytes = encodeG2Compressed(g2Generator());
+    try testing.expectEqualSlices(u8, &expected, &bytes);
+}
+
+test "encodeG2Compressed: -G round-trip with sign flag set" {
+    const g = g2Generator();
+    const neg_g = g.neg();
+    const bytes = encodeG2Compressed(neg_g);
+    const decoded = try decodeG2Compressed(&bytes);
+    try testing.expect(G2Affine.eql(decoded, neg_g));
+}
+
+test "encodeG2Compressed: scalar multiple round-trips" {
+    const g = g2Generator();
+    const seven_g = g.mul(1, .{7});
+    const bytes = encodeG2Compressed(seven_g);
+    const decoded = try decodeG2Compressed(&bytes);
+    try testing.expect(G2Affine.eql(decoded, seven_g));
+}
diff --git a/packages/zolt-arith/src/curves/bls12_381/field.zig b/packages/zolt-arith/src/curves/bls12_381/field.zig
new file mode 100644
index 00000000..36578504
--- /dev/null
+++ b/packages/zolt-arith/src/curves/bls12_381/field.zig
@@ -0,0 +1,427 @@
+//! Montgomery field arithmetic over `[N]u64` limbs.
+//!
+//! `MontgomeryField(N, modulus, r2, n_prime)` instantiates a
+//! finite field `Fp = ℤ / pℤ` with `p = modulus` and stores elements in
+//! Montgomery form (`a · R mod p` where `R = 2^(64·N)`).
+//!
+//! The Montgomery representation lets multiplication be implemented
+//! without an expensive `mod p` step at the end of every operation —
+//! instead, multiplication uses the CIOS (Coarsely Integrated Operand
+//! Scanning) algorithm which interleaves multiplication and reduction.
+//! See Acar's "Analyzing and Comparing Montgomery Multiplication
+//! Algorithms" (1996) for the canonical description.
+//!
+//! Constants the caller must precompute (the constructor takes them so
+//!
the field can stay generic over any modulus):
+//!
+//! - `modulus`: `[N]u64`, the prime `p` (little-endian limbs).
+//! - `r2`: `R^2 mod p` where `R = 2^(64·N)`. Used to convert into
+//!   Montgomery form via `to_mont(a) = a · R^2 · R^{-1} mod p`.
+//! - `n_prime`: `-p^{-1} mod 2^64`. Drives the per-limb reduction
+//!   step inside CIOS.
+//!
+//! Element layout:
+//!
+//! - `Element` is `[N]u64`. The numeric value of `e` is
+//!   `(sum_i e[i] * 2^(64·i)) · R^{-1} mod p`.
+//! - Zero is `[0; N]` (Montgomery form is closed under zero).
+//! - One is `R mod p`, which the constructor derives from
+//!   `to_mont(1)`.
+
+const std = @import("std");
+const bigint = @import("../../bigint.zig");
+
+/// Build a Montgomery field type over the given modulus / N / R^2 / -p^{-1}.
+///
+/// The factory is comptime so callers get a fresh type per modulus and
+/// the inlined arithmetic stays branch-free over the limb count.
+pub fn MontgomeryField(
+    comptime N: comptime_int,
+    comptime modulus: [N]u64,
+    comptime r2: [N]u64,
+    comptime n_prime: u64,
+) type {
+    return struct {
+        const Self = @This();
+
+        /// One element of the field. Stored in Montgomery form.
+        pub const Element = [N]u64;
+
+        /// The prime modulus, exposed so consumers can interrogate it.
+        pub const MODULUS: [N]u64 = modulus;
+
+        /// `R^2 mod p` — used to convert raw integers into Montgomery
+        /// form via `montMul(raw, R2)`.
+        pub const R2: [N]u64 = r2;
+
+        /// Negative inverse of the modulus mod 2^64. Drives the per-
+        /// limb reduction in CIOS.
+        pub const N_PRIME: u64 = n_prime;
+
+        /// Limb count, exported for byte conversions.
+        pub const LIMB_COUNT: comptime_int = N;
+
+        /// The additive identity. Already in Montgomery form (zero is
+        /// the same in both representations).
+        pub fn zero() Element {
+            return .{0} ** N;
+        }
+
+        /// The multiplicative identity in Montgomery form: `R mod p`.
+        /// Computed as `montMul(1, R2)` so the constructor doesn't have
+        /// to embed `R mod p` separately.
+        pub fn one() Element {
+            var raw_one: [N]u64 = .{0} ** N;
+            raw_one[0] = 1;
+            return montMul(raw_one, r2);
+        }
+
+        /// Constant-shape equality. Walks all limbs.
+        pub fn eql(a: Element, b: Element) bool {
+            // OR-fold the limb-wise XOR so every limb is visited (no early exit).
+            var diff: u64 = 0;
+            inline for (0..N) |i| diff |= a[i] ^ b[i];
+            return diff == 0;
+        }
+
+        /// `a + b mod p`.
+        pub fn add(a: Element, b: Element) Element {
+            var sum: [N]u64 = undefined;
+            const carry = bigint.add(N, &sum, a, b);
+            // If the sum overflowed `R` (carry=1) or is now ≥ p,
+            // subtract the modulus to bring it back into range.
+            var reduced: [N]u64 = undefined;
+            const borrow = bigint.sub(N, &reduced, sum, modulus);
+            // Use sum if borrow=1 (sum < p) and carry=0; otherwise use
+            // the reduced form. Implemented as a branch, not a masked select.
+            if (carry == 1 or borrow == 0) return reduced;
+            return sum;
+        }
+
+        /// `a - b mod p`.
+        pub fn sub(a: Element, b: Element) Element {
+            var diff: [N]u64 = undefined;
+            const borrow = bigint.sub(N, &diff, a, b);
+            if (borrow == 0) return diff;
+            // Borrow: add the modulus back.
+            var corrected: [N]u64 = undefined;
+            _ = bigint.add(N, &corrected, diff, modulus);
+            return corrected;
+        }
+
+        /// `-a mod p` = `p - a` (or `0` for zero).
+        pub fn neg(a: Element) Element {
+            if (bigint.isZero(N, a)) return zero();
+            var out: [N]u64 = undefined;
+            _ = bigint.sub(N, &out, modulus, a);
+            return out;
+        }
+
+        /// `a^2 mod p`. Convenience wrapper around `montMul(a, a)`.
+        pub fn square(a: Element) Element {
+            return montMul(a, a);
+        }
+
+        /// Square-and-multiply exponentiation `a^e mod p` where the
+        /// exponent is a raw little-endian limb array (not in
+        /// Montgomery form). The exponent is walked from MSB to LSB
+        /// across all `N` limbs.
+        ///
+        /// `a` must be in Montgomery form.
+        pub fn pow(a: Element, exponent: [N]u64) Element {
+            // Find the index of the most-significant set bit so we can
+            // skip the leading zeros and avoid an empty `result *= a`
+            // step on `a^0`.
+            const top_bit = bigint.bitLen(N, exponent);
+            if (top_bit == 0) return one();
+
+            var result = a;
+            var i = top_bit - 1;
+            while (i > 0) {
+                i -= 1;
+                result = square(result);
+                const limb = i / 64;
+                const bit = @as(u6, @intCast(i % 64));
+                if (((exponent[limb] >> bit) & 1) == 1) {
+                    result = montMul(result, a);
+                }
+            }
+            return result;
+        }
+
+        /// Field inversion via Fermat's little theorem: `a^{-1} = a^{p-2} mod p`.
+        /// Returns `zero` for `zero` (which is mathematically undefined
+        /// but lets callers avoid an explicit branch in common code).
+        ///
+        /// This is much slower than the binary extended Euclidean
+        /// algorithm — Fermat does ~381 squarings + ~190 multiplies
+        /// for BLS12-381 — but it is constant-time-friendly and trivial
+        /// to validate against the cheaper algorithms once they land.
+        pub fn inv(a: Element) Element {
+            if (bigint.isZero(N, a)) return zero();
+            // Compute p - 2 as a raw limb array.
+            var p_minus_two: [N]u64 = modulus;
+            // The modulus must be odd for this to work; subtract 2 from
+            // the lowest limb without underflow.
+            std.debug.assert(p_minus_two[0] >= 2);
+            p_minus_two[0] -= 2;
+            return pow(a, p_minus_two);
+        }
+
+        /// Coarsely Integrated Operand Scanning (CIOS) Montgomery
+        /// multiplication. Returns `a · b · R^{-1} mod p` — i.e. the
+        /// product in Montgomery form.
+        ///
+        /// The algorithm uses an `[N+2]u64` accumulator (one extra limb
+        /// for partial sums and one for the carry-out of the reduction
+        /// step). For BLS12-381's 6-limb field this is 8 limbs of
+        /// scratch — well within the inlined comptime budget.
+        pub fn montMul(a: Element, b: Element) Element {
+            var t: [N + 2]u64 = .{0} ** (N + 2);
+            inline for (0..N) |i| {
+                // Outer-loop multiply: t += a * b[i]
+                var carry: u64 = 0;
+                inline for (0..N) |j| {
+                    const prod = @as(u128, a[j]) * @as(u128, b[i]) + @as(u128, t[j]) + @as(u128, carry);
+                    t[j] = @truncate(prod);
+                    carry = @intCast(prod >> 64);
+                }
+                const t_n_sum = @addWithOverflow(t[N], carry);
+                t[N] = t_n_sum[0];
+                t[N + 1] += @intCast(t_n_sum[1]);
+
+                // Reduce: m = t[0] * n_prime mod 2^64; t = (t + m * p) / 2^64
+                const m: u64 = t[0] *% n_prime;
+                {
+                    const prod0 = @as(u128, m) * @as(u128, modulus[0]) + @as(u128, t[0]);
+                    var reduce_carry: u64 = @intCast(prod0 >> 64);
+                    inline for (1..N) |j| {
+                        const prod = @as(u128, m) * @as(u128, modulus[j]) + @as(u128, t[j]) + @as(u128, reduce_carry);
+                        t[j - 1] = @truncate(prod);
+                        reduce_carry = @intCast(prod >> 64);
+                    }
+                    const t_n_red = @addWithOverflow(t[N], reduce_carry);
+                    t[N - 1] = t_n_red[0];
+                    const t_n1_red = @addWithOverflow(t[N + 1], @as(u64, t_n_red[1]));
+                    t[N] = t_n1_red[0];
+                    t[N + 1] = t_n1_red[1];
+                }
+            }
+            // Final conditional subtract — t may be in [0, 2p).
+            var result: [N]u64 = undefined;
+            inline for (0..N) |i| result[i] = t[i];
+            var corrected: [N]u64 = undefined;
+            const borrow = bigint.sub(N, &corrected, result, modulus);
+            // If borrow=0 the subtraction succeeded (result ≥ p), use
+            // corrected; otherwise use result. The N+1 carry above
+            // also forces the corrected branch.
+            if (t[N] != 0 or borrow == 0) return corrected;
+            return result;
+        }
+
+        /// Convert a raw little-endian integer into Montgomery form.
+        pub fn fromRaw(raw: [N]u64) Element {
+            return montMul(raw, r2);
+        }
+
+        /// Convert from Montgomery form back to a raw little-endian
+        /// integer. Implemented as `montMul(value, [1, 0, ...])` which
+        /// is `value · 1 · R^{-1} mod p`.
+        pub fn toRaw(value: Element) [N]u64 {
+            var one_raw: [N]u64 = .{0} ** N;
+            one_raw[0] = 1;
+            return montMul(value, one_raw);
+        }
+
+        /// Read a little-endian byte string into a field element.
+        /// The input MUST already be < p; values ≥ p are rejected.
+        pub fn fromBytesLeReduced(bytes: []const u8) error{NotInField}!Element {
+            const raw = bigint.fromBytesLe(N, bytes);
+            if (bigint.cmp(N, raw, modulus) != .lt) return error.NotInField;
+            return fromRaw(raw);
+        }
+
+        /// Write a field element as a little-endian byte string. The
+        /// output is canonical (i.e. fully reduced).
+        pub fn toBytesLe(value: Element, out: []u8) void {
+            const raw = toRaw(value);
+            bigint.toBytesLe(N, raw, out);
+        }
+    };
+}
+
+// ---------------------------------------------------------------------------
+// Tests — exercise the field machinery against a small toy modulus that
+// is easy to reason about by hand. The BLS12-381 instantiation lives in
+// `bls12_381.zig` once it lands.
+// ---------------------------------------------------------------------------
+
+const testing = std.testing;
+
+// Toy 4-limb prime: p = 2^255 - 19 (Curve25519's base field). Easy to
+// reason about, well-known constants on hand. A multi-limb modulus
+// (rather than a single-limb toy) keeps the test exercising the same
+// generic code path the BLS12-381 6-limb instantiation will use.
+//
+// Constants computed once via Python:
+//   p = 2**255 - 19
+//   R = 2**256
+//   r2 = (R * R) % p = 1444 = 0x5a4
+//   n_prime = (-pow(p, -1, 2**64)) % 2**64 = 0x86bca1af286bca1b
+const ED25519_P: [4]u64 = .{
+    0xffffffffffffffed,
+    0xffffffffffffffff,
+    0xffffffffffffffff,
+    0x7fffffffffffffff,
+};
+const ED25519_R2: [4]u64 = .{
+    0x00000000000005a4,
+    0x0000000000000000,
+    0x0000000000000000,
+    0x0000000000000000,
+};
+const ED25519_N_PRIME: u64 = 0x86bca1af286bca1b;
+
+const Ed25519Fp = MontgomeryField(4, ED25519_P, ED25519_R2, ED25519_N_PRIME);
+
+test "Ed25519Fp.zero is identity for add" {
+    const zero = Ed25519Fp.zero();
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(zero, zero), zero));
+    const one = Ed25519Fp.one();
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(one, zero), one));
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(zero, one), one));
+}
+
+test "Ed25519Fp.one is identity for mul" {
+    const one = Ed25519Fp.one();
+    const a = Ed25519Fp.fromRaw(.{ 0x1234567890abcdef, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montMul(a, one), a));
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montMul(one, a), a));
+}
+
+test "Ed25519Fp.toRaw round-trips fromRaw" {
+    const raw: [4]u64 = .{ 0x0102030405060708, 0x0a0b0c0d0e0f0001, 0x1122334455667788, 0x12345678 };
+    const m = Ed25519Fp.fromRaw(raw);
+    const back = Ed25519Fp.toRaw(m);
+    try testing.expectEqual(raw, back);
+}
+
+test "Ed25519Fp.add wraps around the modulus" {
+    // 1 + (p - 1) = p ≡ 0
+    const one = Ed25519Fp.fromRaw(.{ 1, 0, 0, 0 });
+    const p_minus_one_raw: [4]u64 = .{
+        ED25519_P[0] - 1,
+        ED25519_P[1],
+        ED25519_P[2],
+        ED25519_P[3],
+    };
+    const p_minus_one = Ed25519Fp.fromRaw(p_minus_one_raw);
+    const sum = Ed25519Fp.add(one, p_minus_one);
+    try testing.expect(Ed25519Fp.eql(sum, Ed25519Fp.zero()));
+}
+
+test "Ed25519Fp.sub borrows correctly" {
+    // 0 - 1 = p - 1
+    const zero = Ed25519Fp.zero();
+    const one_e = Ed25519Fp.fromRaw(.{ 1, 0, 0, 0 });
+    const result = Ed25519Fp.sub(zero, one_e);
+    const expected = Ed25519Fp.fromRaw(.{
+        ED25519_P[0] - 1,
+        ED25519_P[1],
+        ED25519_P[2],
+        ED25519_P[3],
+    });
+    try testing.expect(Ed25519Fp.eql(result, expected));
+}
+
+test "Ed25519Fp.neg + add = zero" {
+    const a = Ed25519Fp.fromRaw(.{ 0x1234567890abcdef, 0xabad1deafeedface, 0xdeadbeefcafebabe, 0x0123456789abcdef });
+    const neg_a = Ed25519Fp.neg(a);
+    const sum = Ed25519Fp.add(a, neg_a);
+    try testing.expect(Ed25519Fp.eql(sum, Ed25519Fp.zero()));
+}
+
+test "Ed25519Fp.montMul: 2 * 3 = 6" {
+    const two = Ed25519Fp.fromRaw(.{ 2, 0, 0, 0 });
+    const three = Ed25519Fp.fromRaw(.{ 3, 0, 0, 0 });
+    const product = Ed25519Fp.montMul(two, three);
+    const six = Ed25519Fp.fromRaw(.{ 6, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(product, six));
+}
+
+test "Ed25519Fp.montMul: distributive over add (random sample)" {
+    // (a + b) * c == a*c + b*c
+    const a = Ed25519Fp.fromRaw(.{ 0x1234, 0x5678, 0, 0 });
+    const b = Ed25519Fp.fromRaw(.{ 0x9abc, 0xdef0, 0x1234, 0 });
+    const c = Ed25519Fp.fromRaw(.{ 0x42, 0, 0, 0 });
+    const lhs = Ed25519Fp.montMul(Ed25519Fp.add(a, b), c);
+    const rhs = Ed25519Fp.add(Ed25519Fp.montMul(a, c), Ed25519Fp.montMul(b, c));
+    try testing.expect(Ed25519Fp.eql(lhs, rhs));
+}
+
+test "Ed25519Fp.fromBytesLeReduced rejects values ≥ p" {
+    var buf: [32]u8 = .{0xff} ** 32;
+    // All-0xff bytes encode 2^256 - 1, which is ≥ p, so decoding must fail.
+    try testing.expectError(error.NotInField, Ed25519Fp.fromBytesLeReduced(&buf));
+
+    // 1 < p; should succeed and decode to the field's `one()`.
+    var ok: [32]u8 = .{0} ** 32;
+    ok[0] = 1;
+    const e = try Ed25519Fp.fromBytesLeReduced(&ok);
+    try testing.expect(Ed25519Fp.eql(e, Ed25519Fp.one()));
+}
+
+test "Ed25519Fp.square: 5^2 = 25" {
+    const five = Ed25519Fp.fromRaw(.{ 5, 0, 0, 0 });
+    const sq = Ed25519Fp.square(five);
+    const twenty_five = Ed25519Fp.fromRaw(.{ 25, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(sq, twenty_five));
+}
+
+test "Ed25519Fp.pow: 3^7 = 2187" {
+    const three = Ed25519Fp.fromRaw(.{ 3, 0, 0, 0 });
+    const result = Ed25519Fp.pow(three, .{ 7, 0, 0, 0 });
+    const expected = Ed25519Fp.fromRaw(.{ 2187, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(result, expected));
+}
+
+test "Ed25519Fp.pow: a^0 = 1" {
+    const a = Ed25519Fp.fromRaw(.{ 0x42, 0, 0, 0 });
+    const result = Ed25519Fp.pow(a, .{ 0, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(result, Ed25519Fp.one()));
+}
+
+test "Ed25519Fp.pow: a^1 = a" {
+    const a = Ed25519Fp.fromRaw(.{ 0x42, 0xab, 0, 0 });
+    const result = Ed25519Fp.pow(a, .{ 1, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(result, a));
+}
+
+test "Ed25519Fp.inv: a * a^-1 = 1" {
+    const a = Ed25519Fp.fromRaw(.{ 0x12345678, 0, 0, 0 });
+    const inv_a = Ed25519Fp.inv(a);
+    const product = Ed25519Fp.montMul(a, inv_a);
+    try testing.expect(Ed25519Fp.eql(product, Ed25519Fp.one()));
+}
+
+test "Ed25519Fp.inv: inv(1) = 1" {
+    const one = Ed25519Fp.one();
+    const inv_one = Ed25519Fp.inv(one);
+    try testing.expect(Ed25519Fp.eql(inv_one, one));
+}
+
+test "Ed25519Fp.inv: inv(zero) = zero" {
+    const z = Ed25519Fp.zero();
+    const inv_z = Ed25519Fp.inv(z);
+    try testing.expect(Ed25519Fp.eql(inv_z, z));
+}
+
+test "Ed25519Fp.toBytesLe round-trips fromBytesLeReduced" {
+    var bytes_in: [32]u8 = .{0} ** 32;
+    bytes_in[0] = 0x42;
+    bytes_in[1] = 0x13;
+    const e = try Ed25519Fp.fromBytesLeReduced(&bytes_in);
+    var bytes_out: [32]u8 = undefined;
+    Ed25519Fp.toBytesLe(e, &bytes_out);
+    try testing.expectEqualSlices(u8, &bytes_in, &bytes_out);
+}
diff --git
a/packages/zolt-arith/src/curves/bls12_381/hash_to_curve_g2.zig b/packages/zolt-arith/src/curves/bls12_381/hash_to_curve_g2.zig
new file mode 100644
index 00000000..8a2d18b8
--- /dev/null
+++ b/packages/zolt-arith/src/curves/bls12_381/hash_to_curve_g2.zig
@@ -0,0 +1,459 @@
+//! BLS12-381 hash-to-curve for G2.
+//!
+//! Implements the RFC 9380 §8.8.2 suite `BLS12381G2_XMD:SHA-256_SSWU_RO_`
+//! (as used by `BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_`). The pipeline is:
+//!
+//!   1. `hash_to_field_fp2` (already in `hash_to_field.zig`) produces
+//!      two `Fp2` field elements `u0, u1` from `(msg, DST)`.
+//!   2. The simplified Shallue-van de Woestijne-Ulas (SSWU) map sends
+//!      each `u_i` to a point on the isogenous curve
+//!      `E': y² = x³ + 240·u·x + 1012·(1+u)`.
+//!   3. The degree-3 isogeny (4 polynomials over Fp2) pushes the SSWU
+//!      output from `E'` to BLS12-381 G2 (`y² = x³ + 4(1+u)`).
+//!   4. The two G2 points are added together to form `Q'`.
+//!   5. Cofactor clearing scales `Q'` by the G2 cofactor `h`, producing
+//!      a point in the prime-order `r`-subgroup.
+//!
+//! All constants here come from the IETF draft and are cross-checked
+//! against arkworks `ark-bls12-381` (`g2_swu_iso.rs`). The SSWU
+//! algorithm follows Wahby & Boneh (2019) "Fast and simple constant-time
+//! hashing to the BLS12-381 elliptic curve" §4.1, which is the
+//! "avoiding inversions" optimization arkworks uses too.
+//!
+//! Performance is correctness-first: cofactor clearing uses naive
+//! scalar multiplication by the 8-limb cofactor instead of the
+//! ψ-endomorphism shortcut. A future iteration can swap that in
+//! without changing the public surface.
+ +const std = @import("std"); +const bls12_381 = @import("curve.zig"); +const hash_to_field = @import("hash_to_field.zig"); +const bigint = @import("../../bigint.zig"); + +const Fp = bls12_381.Fp; +const Fp2 = bls12_381.Fp2; +const G2Affine = bls12_381.G2Affine; +const G2Projective = bls12_381.G2Projective; + +// --------------------------------------------------------------------------- +// Isogenous curve E' constants. +// +// E': y'² = x'³ + A'·x' + B' +// A' = 240·u +// B' = 1012 + 1012·u +// ZETA = -(2 + u) (the SSWU non-square parameter) +// --------------------------------------------------------------------------- + +/// `A' = 240·u` in Fp2 raw form (NOT Montgomery). We turn it into +/// Montgomery form lazily so the constant block stays declarative. +const ISO_A_RAW: Fp2Raw = .{ + .c0 = .{ 0, 0, 0, 0, 0, 0 }, + .c1 = .{ 240, 0, 0, 0, 0, 0 }, +}; + +/// `B' = 1012 + 1012·u`. +const ISO_B_RAW: Fp2Raw = .{ + .c0 = .{ 1012, 0, 0, 0, 0, 0 }, + .c1 = .{ 1012, 0, 0, 0, 0, 0 }, +}; + +/// `ZETA = -(2 + u) = -2 - u` mod p. +const ISO_ZETA_RAW: Fp2Raw = .{ + .c0 = .{ 0xb9feffffffffaaa9, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a }, + .c1 = .{ 0xb9feffffffffaaaa, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a }, +}; + +const Fp2Raw = struct { + c0: [6]u64, + c1: [6]u64, +}; + +inline fn fp2FromRaw(r: Fp2Raw) Fp2 { + return .{ + .c0 = Fp.fromRaw(r.c0), + .c1 = Fp.fromRaw(r.c1), + }; +} + +fn isoA() Fp2 { + return fp2FromRaw(ISO_A_RAW); +} + +fn isoB() Fp2 { + return fp2FromRaw(ISO_B_RAW); +} + +fn isoZeta() Fp2 { + return fp2FromRaw(ISO_ZETA_RAW); +} + +// --------------------------------------------------------------------------- +// 3-degree isogeny constants. Each entry is a polynomial in `x'` +// (the x-coordinate of the iso curve point) with Fp2 coefficients. +// Index 0 is the constant term; index k is the x'^k coefficient. 
+// +// Values converted from arkworks `g2_swu_iso.rs::ISOGENY_MAP_TO_G2` +// via Python (decimal → little-endian 6-limb hex). The Rust constants +// are pinned in the IETF draft, section E.3. +// --------------------------------------------------------------------------- + +const ISOGENY_X_NUM: [4]Fp2Raw = .{ + .{ .c0 = .{ 0x6238aaaaaaaa97d6, 0x5c2638e343d9c71c, 0x88b58423c50ae15d, 0x32c52d39fd3a042a, 0xbb5b7a9a47d7ed85, 0x05c759507e8e333e }, .c1 = .{ 0x6238aaaaaaaa97d6, 0x5c2638e343d9c71c, 0x88b58423c50ae15d, 0x32c52d39fd3a042a, 0xbb5b7a9a47d7ed85, 0x05c759507e8e333e } }, + .{ .c0 = .{ 0, 0, 0, 0, 0, 0 }, .c1 = .{ 0x26a9ffffffffc71a, 0x1472aaa9cb8d5555, 0x9a208c6b4f20a418, 0x984f87adf7ae0c7f, 0x32126fced787c88f, 0x11560bf17baa99bc } }, + .{ .c0 = .{ 0x26a9ffffffffc71e, 0x1472aaa9cb8d5555, 0x9a208c6b4f20a418, 0x984f87adf7ae0c7f, 0x32126fced787c88f, 0x11560bf17baa99bc }, .c1 = .{ 0x9354ffffffffe38d, 0x0a395554e5c6aaaa, 0xcd104635a790520c, 0xcc27c3d6fbd7063f, 0x190937e76bc3e447, 0x08ab05f8bdd54cde } }, + .{ .c0 = .{ 0x88e2aaaaaaaa5ed1, 0x7098e38d0f671c71, 0x22d6108f142b8575, 0xcb14b4e7f4e810aa, 0xed6dea691f5fb614, 0x171d6541fa38ccfa }, .c1 = .{ 0, 0, 0, 0, 0, 0 } }, +}; + +const ISOGENY_X_DEN: [3]Fp2Raw = .{ + .{ .c0 = .{ 0, 0, 0, 0, 0, 0 }, .c1 = .{ 0xb9feffffffffaa63, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a } }, + .{ .c0 = .{ 0x000000000000000c, 0, 0, 0, 0, 0 }, .c1 = .{ 0xb9feffffffffaa9f, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a } }, + .{ .c0 = .{ 0x0000000000000001, 0, 0, 0, 0, 0 }, .c1 = .{ 0, 0, 0, 0, 0, 0 } }, +}; + +const ISOGENY_Y_NUM: [4]Fp2Raw = .{ + .{ .c0 = .{ 0x12cfc71c71c6d706, 0xfc8c25ebf8c92f68, 0xf54439d87d27e500, 0x0f7da5d4a07f649b, 0x59a4c18b076d1193, 0x1530477c7ab4113b }, .c1 = .{ 0x12cfc71c71c6d706, 0xfc8c25ebf8c92f68, 0xf54439d87d27e500, 0x0f7da5d4a07f649b, 0x59a4c18b076d1193, 0x1530477c7ab4113b } }, + .{ .c0 = .{ 0, 0, 0, 0, 
0, 0 }, .c1 = .{ 0x6238aaaaaaaa97be, 0x5c2638e343d9c71c, 0x88b58423c50ae15d, 0x32c52d39fd3a042a, 0xbb5b7a9a47d7ed85, 0x05c759507e8e333e } }, + .{ .c0 = .{ 0x26a9ffffffffc71c, 0x1472aaa9cb8d5555, 0x9a208c6b4f20a418, 0x984f87adf7ae0c7f, 0x32126fced787c88f, 0x11560bf17baa99bc }, .c1 = .{ 0x9354ffffffffe38f, 0x0a395554e5c6aaaa, 0xcd104635a790520c, 0xcc27c3d6fbd7063f, 0x190937e76bc3e447, 0x08ab05f8bdd54cde } }, + .{ .c0 = .{ 0xe1b371c71c718b10, 0x4e79097a56dc4bd9, 0xb0e977c69aa27452, 0x761b0f37a1e26286, 0xfbf7043de3811ad0, 0x124c9ad43b6cf79b }, .c1 = .{ 0, 0, 0, 0, 0, 0 } }, +}; + +const ISOGENY_Y_DEN: [4]Fp2Raw = .{ + .{ .c0 = .{ 0xb9feffffffffa8fb, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a }, .c1 = .{ 0xb9feffffffffa8fb, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a } }, + .{ .c0 = .{ 0, 0, 0, 0, 0, 0 }, .c1 = .{ 0xb9feffffffffa9d3, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a } }, + .{ .c0 = .{ 0x0000000000000012, 0, 0, 0, 0, 0 }, .c1 = .{ 0xb9feffffffffaa99, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a } }, + .{ .c0 = .{ 0x0000000000000001, 0, 0, 0, 0, 0 }, .c1 = .{ 0, 0, 0, 0, 0, 0 } }, +}; + +/// G2 cofactor `h` from the IETF spec. 8 limbs. NOT used for hash-to- +/// curve cofactor clearing — the IETF SSWU_RO suite mandates `h_eff` +/// instead, which produces a *different* point in the same prime-order +/// subgroup. Multiplying by `h` instead of `h_eff` would still land in +/// the r-subgroup, but the result would be a different multiple of the +/// SSWU output and would not match what `blst::min_pk` produces. +/// +/// Kept around as documentation; the cofactor clearing path uses +/// `G2_H_EFF` below. 
+const G2_COFACTOR: [8]u64 = .{ + 0xcf1c38e31c7238e5, + 0x1616ec6e786f0c70, + 0x21537e293a6691ae, + 0xa628f1cb4d9e82ef, + 0xa68a205b2e5a7ddf, + 0xcd91de4547085aba, + 0x091d50792876a202, + 0x05d543a95414e7f1, +}; + +/// `h_eff` for BLS12-381 G2 hash-to-curve, from RFC 9380 §8.8.2 / IETF +/// pairing-friendly-curves draft. 636 bits → 10 limbs. +/// +/// The hex value (big-endian): +/// 0xbc69f08f2ee75b3584c6a0ea91b352888e2a8e9145ad7689986ff031508ffe13 +/// 29c2f178731db956d82bf015d1212b02ec0ec69d7477c1ae954cbc06689f6a35 +/// 9894c0adebbf6b4e8020005aaa95551 +/// +/// Multiplying an SSWU output by this value lands the point in the +/// prime-order r-subgroup at exactly the multiple `blst::min_pk` (and +/// every other RFC-9380-conformant implementation) lands at. The +/// regular cofactor `h` would also land in the r-subgroup but at a +/// different multiple, so signatures would not cross-verify. +const G2_H_EFF: [10]u64 = .{ + 0xe8020005aaa95551, + 0x59894c0adebbf6b4, + 0xe954cbc06689f6a3, + 0x2ec0ec69d7477c1a, + 0x6d82bf015d1212b0, + 0x329c2f178731db95, + 0x9986ff031508ffe1, + 0x88e2a8e9145ad768, + 0x584c6a0ea91b3528, + 0x0bc69f08f2ee75b3, +}; + +// --------------------------------------------------------------------------- +// Parity helpers (RFC 9380 §4.1 sgn0). +// --------------------------------------------------------------------------- + +/// Parity of a raw Fp value: true if the integer (NOT Montgomery form) +/// is odd, false if even. +fn fpIsOddRaw(a: Fp.Element) bool { + const raw = Fp.toRaw(a); + return (raw[0] & 1) == 1; +} + +/// Parity for an Fp2 element: take the parity of the first non-zero +/// coordinate in the order `(c0, c1)`. A pure-zero element has parity +/// 0. This matches arkworks `parity` and the RFC 9380 sgn0 rule with +/// the "first non-zero coordinate" tie-breaking. 
fn fp2Parity(a: Fp2) bool {
    // The first non-zero coordinate (c0 before c1) decides the sign bit.
    if (!Fp.eql(a.c0, Fp.zero())) return fpIsOddRaw(a.c0);
    // c0 is zero: fall back to c1; an all-zero element has parity 0.
    return !Fp.eql(a.c1, Fp.zero()) and fpIsOddRaw(a.c1);
}

/// Report whether `a` is a square in Fp2. A candidate root is requested
/// from `fp2Sqrt` and then squared to confirm it really maps back to `a`;
/// an error from `fp2Sqrt` counts as "not a square".
fn fp2IsSquare(a: Fp2) bool {
    if (bls12_381.fp2Sqrt(a)) |candidate| {
        return Fp2.eql(Fp2.square(candidate), a);
    } else |_| {
        return false;
    }
}

// ---------------------------------------------------------------------------
// Polynomial evaluation in Fp2 (Horner).
// ---------------------------------------------------------------------------

/// Horner evaluation of an Fp2 polynomial at `x`. `coeffs` is ordered
/// from the constant term upwards (`coeffs[k]` multiplies `x^k`). An
/// empty coefficient list evaluates to zero.
fn evalPoly(comptime N: comptime_int, coeffs: [N]Fp2Raw, x: Fp2) Fp2 {
    if (N == 0) return Fp2.zero();
    // Start from the leading coefficient and fold in the lower ones.
    var acc = fp2FromRaw(coeffs[N - 1]);
    var remaining: usize = N - 1;
    while (remaining > 0) : (remaining -= 1) {
        acc = Fp2.add(Fp2.mul(acc, x), fp2FromRaw(coeffs[remaining - 1]));
    }
    return acc;
}

// ---------------------------------------------------------------------------
// Simplified SWU map (Wahby & Boneh 2019, §4.1; arkworks `swu.rs`).
//
// Given an Fp2 element `u`, produces an affine point on the iso curve
// `E': y² = x³ + 240·u·x + 1012·(1 + u)`. The result is always a valid
// curve point — no failure mode.
// ---------------------------------------------------------------------------

/// Apply the simplified SWU map to a single field element. Returns an
/// affine point on the iso curve `E'`. Caller is responsible for then
/// applying the isogeny to push the point onto BLS12-381 G2.
///
/// Cost notes: the candidate square root of `g(x1)` is computed once and
/// validated by squaring, and a single inversion of `div` supplies both
/// `1/div` and `1/div³`.
pub fn sswuMapToCurve(u: Fp2) struct { x: Fp2, y: Fp2 } {
    const a = isoA();
    const b = isoB();
    const zeta = isoZeta();

    // zeta_u2 = ZETA · u²
    const zeta_u2 = Fp2.mul(zeta, Fp2.square(u));
    // ta = (ZETA·u²)² + (ZETA·u²) = Z²u⁴ + Zu²
    const ta = Fp2.add(Fp2.square(zeta_u2), zeta_u2);
    // num_x1 = B · (ta + 1)
    const num_x1 = Fp2.mul(b, Fp2.add(ta, Fp2.one()));
    // div = A·ZETA in the exceptional case ta == 0, otherwise A·(-ta).
    const div = if (Fp2.eql(ta, Fp2.zero()))
        Fp2.mul(a, zeta)
    else
        Fp2.mul(a, Fp2.neg(ta));

    const div2 = Fp2.square(div);
    const div3 = Fp2.mul(div2, div);
    // num_gx1 = (num_x1² + A·div²)·num_x1 + B·div³
    const num_gx1 = Fp2.add(
        Fp2.mul(Fp2.add(Fp2.square(num_x1), Fp2.mul(a, div2)), num_x1),
        Fp2.mul(b, div3),
    );

    // gx1 = num_gx1 / div³. Invert `div` once and cube the inverse:
    // 1/div³ = (1/div)³, and the same inverse is reused for x below.
    const div_inv = Fp2.inv(div);
    const div3_inv = Fp2.mul(Fp2.square(div_inv), div_inv);
    const gx1 = Fp2.mul(num_gx1, div3_inv);

    // Ask for sqrt(gx1) once. The candidate is only accepted if squaring
    // it gives gx1 back, which is the same test `fp2IsSquare` performs.
    const root_gx1: ?Fp2 = blk: {
        const candidate = bls12_381.fp2Sqrt(gx1) catch break :blk null;
        break :blk if (Fp2.eql(Fp2.square(candidate), gx1)) candidate else null;
    };

    // Square case:     (x1, sqrt(gx1)).
    // Non-square case: (x2, ZETA·u²·u · sqrt(ZETA·gx1)) with x2 = ZETA·u²·x1.
    const x_num = if (root_gx1 != null) num_x1 else Fp2.mul(zeta_u2, num_x1);
    var y = if (root_gx1) |y1| y1 else blk: {
        // ZETA · gx1 must be a square (by the structure of the SWU map).
        const y1 = bls12_381.fp2Sqrt(Fp2.mul(zeta, gx1)) catch unreachable;
        break :blk Fp2.mul(Fp2.mul(zeta_u2, u), y1);
    };

    // x = num_x / div
    const x = Fp2.mul(x_num, div_inv);
    // Final y-sign tweak: parity(y) must equal parity(u). RFC 9380 4.1.
    if (fp2Parity(y) != fp2Parity(u)) {
        y = Fp2.neg(y);
    }
    return .{ .x = x, .y = y };
}

// ---------------------------------------------------------------------------
// Isogeny push from E' to G2.
// ---------------------------------------------------------------------------

/// Push a point `(p_x, p_y)` on the iso curve `E'` through the 3-degree
/// isogeny onto BLS12-381 G2. The four rational-map polynomials are
/// evaluated at `p_x`; when either denominator evaluates to zero the
/// input is treated as the point at infinity and the G2 identity is
/// returned.
pub fn isogenyMap(p_x: Fp2, p_y: Fp2) G2Affine {
    const x_den = evalPoly(3, ISOGENY_X_DEN, p_x);
    const y_den = evalPoly(4, ISOGENY_Y_DEN, p_x);

    // A vanishing denominator puts the iso point on the kernel of the
    // isogeny — map to identity.
    const zero = Fp2.zero();
    const den_vanishes = Fp2.eql(x_den, zero) or Fp2.eql(y_den, zero);
    if (den_vanishes) return G2Affine.identity();

    const x_num = evalPoly(4, ISOGENY_X_NUM, p_x);
    const y_num = evalPoly(4, ISOGENY_Y_NUM, p_x);

    // x' = x_num / x_den,  y' = p_y · (y_num / y_den)
    const y_ratio = Fp2.mul(y_num, Fp2.inv(y_den));
    return .{
        .x = Fp2.mul(x_num, Fp2.inv(x_den)),
        .y = Fp2.mul(p_y, y_ratio),
        .infinity = false,
    };
}

// ---------------------------------------------------------------------------
// Cofactor clearing.
// ---------------------------------------------------------------------------

/// Multiply a G2Affine point by `h_eff` from the IETF SSWU_RO suite.
/// The result is in the prime-order r-subgroup at the same multiple
/// `blst::min_pk` lands at, so the resulting hash-to-curve output is
/// byte-for-byte cross-implementation-compatible.
///
/// Implementation: naive double-and-add scalar multiplication on
/// `G2Projective`. `h_eff` is a 636-bit scalar, so this is ~636
/// doublings + ~318 additions. The faster ψ-endomorphism shortcut from
/// the IETF draft can land later as a drop-in replacement — the public
/// surface stays the same.
pub fn clearCofactor(p: G2Affine) G2Affine {
    // The identity is already in every subgroup; return it unchanged.
    return if (p.infinity)
        p
    else
        G2Projective.fromAffine(p).mul(10, G2_H_EFF).toAffine();
}

// ---------------------------------------------------------------------------
// Top-level entry point.
// ---------------------------------------------------------------------------

/// `BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_` per draft-irtf-cfrg-bls-signature-05.
pub const DST_BLS_SIG_NUL: []const u8 = "BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_";

/// Hash `msg` to a BLS12-381 G2 point under domain-separation tag `dst`
/// using the RFC 9380 `SSWU_RO` (random-oracle) construction: two field
/// elements are derived, each is mapped to G2, the two points are added
/// and the sum is cofactor-cleared. Errors from `hash_to_field_fp2` are
/// propagated.
pub fn hashToG2(msg: []const u8, dst: []const u8) hash_to_field.ExpandError!G2Affine {
    var us: [2]Fp2 = undefined;
    try hash_to_field.hash_to_field_fp2(&us, msg, dst);

    // Map u0 and u1 through SSWU and then the isogeny.
    var mapped: [2]G2Affine = undefined;
    for (us, &mapped) |u, *q| {
        const iso_point = sswuMapToCurve(u);
        q.* = isogenyMap(iso_point.x, iso_point.y);
    }

    // Affine addition handles the (q0 == q1) and (q0 == -q1) edge cases
    // naturally; the sum lives on G2 but generally NOT in the
    // prime-order subgroup, so cofactor clearing comes last.
    return clearCofactor(mapped[0].add(mapped[1]));
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

const testing = std.testing;

test "sswuMapToCurve: produces a point on the iso curve E'" {
    // E': y² = x³ + 240·u·x + 1012·(1+u). Pick a small Fp2 element and
    // verify the SSWU output satisfies the iso curve equation.
+ const u = Fp2{ + .c0 = Fp.fromRaw(.{ 7, 0, 0, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 11, 0, 0, 0, 0, 0 }), + }; + const point = sswuMapToCurve(u); + const a = isoA(); + const b = isoB(); + const lhs = Fp2.square(point.y); + const rhs = Fp2.add( + Fp2.add(Fp2.mul(Fp2.square(point.x), point.x), Fp2.mul(a, point.x)), + b, + ); + try testing.expect(Fp2.eql(lhs, rhs)); +} + +test "sswuMapToCurve: deterministic" { + const u = Fp2{ + .c0 = Fp.fromRaw(.{ 0xdeadbeef, 0, 0, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 0xfeedface, 0, 0, 0, 0, 0 }), + }; + const a = sswuMapToCurve(u); + const b = sswuMapToCurve(u); + try testing.expect(Fp2.eql(a.x, b.x)); + try testing.expect(Fp2.eql(a.y, b.y)); +} + +test "sswuMapToCurve: distinct inputs produce distinct outputs" { + const ua = Fp2{ + .c0 = Fp.fromRaw(.{ 1, 0, 0, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 0, 0, 0, 0, 0, 0 }), + }; + const ub = Fp2{ + .c0 = Fp.fromRaw(.{ 2, 0, 0, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 0, 0, 0, 0, 0, 0 }), + }; + const a = sswuMapToCurve(ua); + const b = sswuMapToCurve(ub); + try testing.expect(!Fp2.eql(a.x, b.x) or !Fp2.eql(a.y, b.y)); +} + +test "isogenyMap: pushes iso curve point to G2" { + // Run SSWU then push through the isogeny. The result must lie on + // the actual G2 curve y² = x³ + 4(1+u). + const u = Fp2{ + .c0 = Fp.fromRaw(.{ 0x123456, 0, 0, 0, 0, 0 }), + .c1 = Fp.fromRaw(.{ 0xabcdef, 0, 0, 0, 0, 0 }), + }; + const iso_point = sswuMapToCurve(u); + const g2 = isogenyMap(iso_point.x, iso_point.y); + try testing.expect(g2.isOnCurve()); +} + +test "clearCofactor: result is in the prime-order subgroup" { + // Cofactor-clearing the G2 generator must leave it in the + // r-subgroup (which it already is), and the result should still + // be on the curve. This is a smoke test for the 8-limb scalar + // multiplication path. 
+ const g2 = bls12_381.g2Generator(); + const cleared = clearCofactor(g2); + try testing.expect(cleared.isOnCurve()); + try testing.expect(bls12_381.isInG2Subgroup(cleared)); +} + +test "hashToG2: result lies on G2 in the prime-order subgroup" { + const point = try hashToG2("hello world", DST_BLS_SIG_NUL); + try testing.expect(point.isOnCurve()); + try testing.expect(bls12_381.isInG2Subgroup(point)); +} + +test "hashToG2: deterministic for same (msg, dst)" { + const a = try hashToG2("test message", DST_BLS_SIG_NUL); + const b = try hashToG2("test message", DST_BLS_SIG_NUL); + try testing.expect(G2Affine.eql(a, b)); +} + +test "hashToG2: distinct messages produce distinct points" { + const a = try hashToG2("message A", DST_BLS_SIG_NUL); + const b = try hashToG2("message B", DST_BLS_SIG_NUL); + try testing.expect(!G2Affine.eql(a, b)); +} + +test "hashToG2: empty message works" { + const point = try hashToG2("", DST_BLS_SIG_NUL); + try testing.expect(point.isOnCurve()); + try testing.expect(bls12_381.isInG2Subgroup(point)); +} diff --git a/packages/zolt-arith/src/curves/bls12_381/hash_to_field.zig b/packages/zolt-arith/src/curves/bls12_381/hash_to_field.zig new file mode 100644 index 00000000..11cae178 --- /dev/null +++ b/packages/zolt-arith/src/curves/bls12_381/hash_to_field.zig @@ -0,0 +1,339 @@ +//! RFC 9380 hash-to-curve byte-expansion primitives. +//! +//! Hash-to-curve is the standardized way to deterministically map a +//! byte string into a curve point. It is the missing piece that BLS +//! signature verification needs to convert the message bytes into a +//! G2 element. +//! +//! The full pipeline for `BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_` +//! has four stages: +//! +//! 1. `expand_message_xmd` (RFC 9380 §5.3.1) — expand the input bytes +//! and DST into a byte string of arbitrary length using SHA-256. +//! 2. `hash_to_field` (RFC 9380 §5.2) — interpret the expanded bytes +//! as field elements modulo `p` (or `p²` for Fp2). +//! 3. 
`map_to_curve` (the SSWU map for BLS12-381 G2) — map a field +//! element to an isogenous curve point, then push it through the +//! isogeny to land on the actual curve. +//! 4. `clear_cofactor` — scale by the cofactor to land in the +//! prime-order subgroup. +//! +//! This file implements stages 1 and 2: `expand_message_xmd` plus the +//! `hash_to_field_fp` / `hash_to_field_fp2` helpers built on top of it. +//! Stages 3 and 4 for G2 live in `hash_to_curve_g2.zig`. + +const std = @import("std"); +const bls12_381 = @import("curve.zig"); + +const Sha256 = std.crypto.hash.sha2.Sha256; +const Fp = bls12_381.Fp; +const Fp2 = bls12_381.Fp2; + +/// Errors `expand_message_xmd` can surface. The `hash_to_field_*` +/// helpers in this file return the same set. +pub const ExpandError = error{ + /// `len_in_bytes` exceeds `255 * b_in_bytes` (the limit set by the + /// single-byte block counter `I2OSP(i, 1)` inside `expand_message_xmd`). + /// The `hash_to_field_*` helpers also return this when asked for more + /// elements than their fixed stack buffers hold. + OutputTooLong, + /// The DST exceeds 255 bytes. RFC 9380 caps DSTs at this length so + /// `len(DST)` always fits in the single byte appended to form `DST_prime`. + DstTooLong, +}; + +/// `expand_message_xmd(msg, DST, len_in_bytes)` from RFC 9380 §5.3.1. +/// `len_in_bytes` is taken from `out.len`. +/// +/// Writes exactly `out.len` bytes into `out`. The output is a uniformly +/// pseudo-random byte string derived from `msg` and `dst`, suitable for +/// feeding into `hash_to_field`.
///
/// Algorithm (with `H = SHA-256`, `b_in_bytes = 32`, `s_in_bytes = 64`):
///
///     ell       = ceil(len_in_bytes / b_in_bytes)
///     DST_prime = DST || I2OSP(len(DST), 1)
///     Z_pad     = I2OSP(0, s_in_bytes)
///     l_i_b_str = I2OSP(len_in_bytes, 2)
///     msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
///     b_0       = H(msg_prime)
///     b_1       = H(b_0 || I2OSP(1, 1) || DST_prime)
///     for i in 2..ell:
///         b_i   = H((b_0 XOR b_{i-1}) || I2OSP(i, 1) || DST_prime)
///     return concat(b_1, b_2, ..., b_ell)[0..len_in_bytes]
///
/// Returns `DstTooLong` when `dst.len > 255` and `OutputTooLong` when
/// `out.len` exceeds `255 * 32` bytes; both checks run before `out` is
/// written.
pub fn expand_message_xmd(
    out: []u8,
    msg: []const u8,
    dst: []const u8,
) ExpandError!void {
    const b_in_bytes: usize = 32; // SHA-256 digest size
    const s_in_bytes: usize = 64; // SHA-256 input block size

    if (dst.len > 255) return ExpandError.DstTooLong;
    if (out.len > 255 * b_in_bytes or out.len > std.math.maxInt(u16)) {
        return ExpandError.OutputTooLong;
    }

    const ell = (out.len + b_in_bytes - 1) / b_in_bytes;
    std.debug.assert(ell <= 255);

    // DST_prime = DST || I2OSP(len(DST), 1).
    var dst_prime_buf: [256]u8 = undefined;
    @memcpy(dst_prime_buf[0..dst.len], dst);
    dst_prime_buf[dst.len] = @intCast(dst.len);
    const dst_prime = dst_prime_buf[0 .. dst.len + 1];

    // l_i_b_str = I2OSP(len_in_bytes, 2): two bytes, most significant first.
    const l_i_b_str = [2]u8{ @intCast(out.len >> 8), @intCast(out.len & 0xff) };

    // b_0 = H(Z_pad || msg || l_i_b_str || 0x00 || DST_prime).
    var b_0: [b_in_bytes]u8 = undefined;
    {
        const z_pad = [_]u8{0} ** s_in_bytes;
        var h0 = Sha256.init(.{});
        h0.update(&z_pad);
        h0.update(msg);
        h0.update(&l_i_b_str);
        h0.update(&[_]u8{0});
        h0.update(dst_prime);
        h0.final(&b_0);
    }

    // Block 1 hashes b_0 directly; block i > 1 hashes b_0 XOR b_{i-1}.
    // Block 1 is always computed, even when `out` is empty.
    var b_prev: [b_in_bytes]u8 = undefined;
    var block: usize = 1;
    while (true) : (block += 1) {
        var prefix = b_0;
        if (block > 1) {
            for (&prefix, b_prev) |*byte, prev| byte.* ^= prev;
        }

        var h = Sha256.init(.{});
        h.update(&prefix);
        h.update(&[_]u8{@intCast(block)});
        h.update(dst_prime);
        h.final(&b_prev);

        // Copy this block into its slot; the last block may be truncated.
        const offset = (block - 1) * b_in_bytes;
        const take = @min(out.len - offset, b_in_bytes);
        @memcpy(out[offset..][0..take], b_prev[0..take]);

        if (block >= ell) break;
    }
}

// ---------------------------------------------------------------------------
// Tests
//
// These are self-consistency tests rather than cross-implementation
// vectors: they verify that the function is deterministic, produces
// distinct outputs for distinct inputs, fills the requested length,
// and rejects malformed inputs. A separate cross-check pass against
// RFC 9380 Appendix K.1 vectors lands once a Python or external test
// harness is in place — transcribing the exact byte values from the
// RFC by hand is too error-prone.
+// --------------------------------------------------------------------------- + +const testing = std.testing; + +const TEST_DST: []const u8 = "QUUX-V01-CS02-with-expander-SHA256-128"; + +test "expand_message_xmd: deterministic" { + var out_a: [32]u8 = undefined; + var out_b: [32]u8 = undefined; + try expand_message_xmd(&out_a, "abc", TEST_DST); + try expand_message_xmd(&out_b, "abc", TEST_DST); + try testing.expectEqualSlices(u8, &out_a, &out_b); +} + +test "expand_message_xmd: distinct messages produce distinct outputs" { + var out_a: [32]u8 = undefined; + var out_b: [32]u8 = undefined; + try expand_message_xmd(&out_a, "abc", TEST_DST); + try expand_message_xmd(&out_b, "abd", TEST_DST); + try testing.expect(!std.mem.eql(u8, &out_a, &out_b)); +} + +test "expand_message_xmd: distinct DSTs produce distinct outputs" { + var out_a: [32]u8 = undefined; + var out_b: [32]u8 = undefined; + try expand_message_xmd(&out_a, "abc", "DST_A"); + try expand_message_xmd(&out_b, "abc", "DST_B"); + try testing.expect(!std.mem.eql(u8, &out_a, &out_b)); +} + +test "expand_message_xmd: 128-byte output is fully populated" { + // The first b_in_bytes bytes come from b_1 and the next ones from + // b_2 through b_ell. Make sure no chunk is left as the 0x00 default + // by checking that the last byte is not the default zero value. + var out: [128]u8 = undefined; + try expand_message_xmd(&out, "abc", TEST_DST); + // It's overwhelmingly unlikely that the actual output is all zeros. + var any_nonzero = false; + for (out) |b| if (b != 0) { + any_nonzero = true; + break; + }; + try testing.expect(any_nonzero); + // The trailing 32 bytes (b_4) should be different from the first + // 32 bytes (b_1) for a non-degenerate output. 
+ try testing.expect(!std.mem.eql(u8, out[0..32], out[96..128])); +} + +test "expand_message_xmd: empty msg works" { + var out: [32]u8 = undefined; + try expand_message_xmd(&out, "", TEST_DST); + var any_nonzero = false; + for (out) |b| if (b != 0) { + any_nonzero = true; + break; + }; + try testing.expect(any_nonzero); +} + +test "expand_message_xmd: rejects DST longer than 255 bytes" { + var huge_dst: [300]u8 = .{0} ** 300; + var out: [16]u8 = undefined; + try testing.expectError(ExpandError.DstTooLong, expand_message_xmd(&out, "abc", &huge_dst)); +} + +test "expand_message_xmd: rejects oversized output" { + // 255 * 32 = 8160 is the maximum permitted output length. + const oversized = try testing.allocator.alloc(u8, 8161); + defer testing.allocator.free(oversized); + try testing.expectError(ExpandError.OutputTooLong, expand_message_xmd(oversized, "abc", "test")); +} + +test "expand_message_xmd: short outputs are exact-length" { + // Output length 1 stresses the case where ell = 1 and only the + // first byte of b_1 is copied. + var out: [1]u8 = undefined; + try expand_message_xmd(&out, "abc", TEST_DST); + // Determinism check. + var out2: [1]u8 = undefined; + try expand_message_xmd(&out2, "abc", TEST_DST); + try testing.expectEqual(out[0], out2[0]); +} + +// --------------------------------------------------------------------------- +// hash_to_field for BLS12-381 Fp / Fp2 +// +// RFC 9380 §5.2 with `k = 128` security parameter and BLS12-381 base +// prime `p`. The per-element byte length is +// +// L = ceil((ceil(log2(p)) + k) / 8) = ceil((381 + 128) / 8) = 64 +// +// `hash_to_field_fp(msg, dst, count)` produces `count` Fp elements. +// `hash_to_field_fp2(msg, dst, count)` produces `count` Fp2 elements. +// --------------------------------------------------------------------------- + +/// Per-element byte length for BLS12-381 hash-to-field with k = 128. +pub const BLS12_381_L: usize = 64; + +/// Hash a message into `count` BLS12-381 Fp elements. 
/// `count` is `out.len`; every slot of `out` is overwritten on success.
///
/// The expanded bytes are staged in a fixed stack buffer sized for 8
/// elements (512 bytes), which covers the BLS hash-to-curve cases that
/// need 1 or 2 elements. Longer slices are rejected with
/// `ExpandError.OutputTooLong` before any hashing happens. Errors from
/// `expand_message_xmd` are propagated.
pub fn hash_to_field_fp(
    out: []Fp.Element,
    msg: []const u8,
    dst: []const u8,
) ExpandError!void {
    // Single source of truth for both the guard and the buffer size.
    const max_elements = 8;
    if (out.len > max_elements) return ExpandError.OutputTooLong;

    var uniform_bytes: [max_elements * BLS12_381_L]u8 = undefined;
    const used = uniform_bytes[0 .. out.len * BLS12_381_L];
    try expand_message_xmd(used, msg, dst);

    // Element i is reduced from the i-th 64-byte chunk.
    for (out, 0..) |*element, i| {
        const chunk: *const [BLS12_381_L]u8 = used[i * BLS12_381_L ..][0..BLS12_381_L];
        element.* = bls12_381.fpFromBytes64Be(chunk);
    }
}

/// Hash a message into `out.len` BLS12-381 Fp2 elements. Each Fp2
/// element consumes two consecutive 64-byte chunks of expanded output:
/// the first becomes `c0`, the second `c1`.
///
/// At most 4 elements can be produced per call (fixed stack buffer);
/// longer slices are rejected with `ExpandError.OutputTooLong` before
/// any hashing happens. Errors from `expand_message_xmd` are propagated.
pub fn hash_to_field_fp2(
    out: []Fp2,
    msg: []const u8,
    dst: []const u8,
) ExpandError!void {
    const max_elements = 4;
    const bytes_per_element = 2 * BLS12_381_L;
    if (out.len > max_elements) return ExpandError.OutputTooLong;

    var uniform_bytes: [max_elements * bytes_per_element]u8 = undefined;
    const used = uniform_bytes[0 .. out.len * bytes_per_element];
    try expand_message_xmd(used, msg, dst);

    for (out, 0..) |*element, i| {
        const pair = used[i * bytes_per_element ..][0..bytes_per_element];
        element.* = .{
            .c0 = bls12_381.fpFromBytes64Be(pair[0..BLS12_381_L]),
            .c1 = bls12_381.fpFromBytes64Be(pair[BLS12_381_L..]),
        };
    }
}

test "hash_to_field_fp: deterministic" {
    var a: [2]Fp.Element = undefined;
    var b: [2]Fp.Element = undefined;
    try hash_to_field_fp(&a, "abc", TEST_DST);
    try hash_to_field_fp(&b, "abc", TEST_DST);
    try testing.expect(Fp.eql(a[0], b[0]));
    try testing.expect(Fp.eql(a[1], b[1]));
}

test "hash_to_field_fp: distinct inputs produce distinct outputs" {
    var a: [1]Fp.Element = undefined;
    var b: [1]Fp.Element = undefined;
    try hash_to_field_fp(&a, "abc", TEST_DST);
    try hash_to_field_fp(&b, "abd", TEST_DST);
    try testing.expect(!Fp.eql(a[0], b[0]));
}

test "hash_to_field_fp: produces non-zero element" {
    var elements: [1]Fp.Element = undefined;
    try hash_to_field_fp(&elements, "abc", TEST_DST);
    // Overwhelmingly unlikely to be zero (probability ~1/p ≈ 2^-381).
    // A plain runtime loop is enough here; no comptime unrolling and no
    // early exit are needed for six limbs.
    var any_nonzero = false;
    for (0..6) |i| {
        if (elements[0][i] != 0) any_nonzero = true;
    }
    try testing.expect(any_nonzero);
}

test "hash_to_field_fp2: deterministic" {
    var a: [2]Fp2 = undefined;
    var b: [2]Fp2 = undefined;
    try hash_to_field_fp2(&a, "abc", TEST_DST);
    try hash_to_field_fp2(&b, "abc", TEST_DST);
    try testing.expect(Fp2.eql(a[0], b[0]));
    try testing.expect(Fp2.eql(a[1], b[1]));
}

test "hash_to_field_fp2: c0 and c1 differ" {
    // Each Fp2 element is built from two distinct 64-byte chunks of the
    // expanded output, so c0 should differ from c1 with overwhelming
    // probability.
    var elements: [1]Fp2 = undefined;
    try hash_to_field_fp2(&elements, "abc", TEST_DST);
    try testing.expect(!Fp.eql(elements[0].c0, elements[0].c1));
}

test "expand_message_xmd: distinct output lengths produce distinct b_1" {
    // The l_i_b_str field is part of msg_prime, so b_0 — and therefore
    // b_1 — depends on the requested output length. A 16-byte output
    // and the first 16 bytes of a 32-byte output for the same (msg, DST)
    // should NOT match. This is a property of the spec, not a bug.
+ var short: [16]u8 = undefined; + var full: [32]u8 = undefined; + try expand_message_xmd(&short, "abc", TEST_DST); + try expand_message_xmd(&full, "abc", TEST_DST); + try testing.expect(!std.mem.eql(u8, &short, full[0..16])); +} diff --git a/packages/zolt-arith/src/curves/bls12_381/mod.zig b/packages/zolt-arith/src/curves/bls12_381/mod.zig new file mode 100644 index 00000000..b6412069 --- /dev/null +++ b/packages/zolt-arith/src/curves/bls12_381/mod.zig @@ -0,0 +1,43 @@ +//! BLS12-381 curve type bundle. +//! +//! Re-exports from the self-contained BLS12-381 implementation under +//! this directory. The public API matches the standalone `zolt-arith` +//! package that zyli consumed before consolidation. + +pub const curve = @import("curve.zig"); +pub const bls = @import("bls.zig"); +pub const hash_to_field = @import("hash_to_field.zig"); +pub const hash_to_curve_g2 = @import("hash_to_curve_g2.zig"); + +// Re-export the most commonly accessed types at the top level so +// callers can write `bls12_381.G1Affine` instead of +// `bls12_381.curve.G1Affine`. 
+pub const Fp = curve.Fp; +pub const Fr = curve.Fr; +pub const Fp2 = curve.Fp2; +pub const Fp6 = curve.Fp6; +pub const Fp12 = curve.Fp12; +pub const G1Affine = curve.G1Affine; +pub const G1Projective = curve.G1Projective; +pub const G2Affine = curve.G2Affine; +pub const G2Projective = curve.G2Projective; +pub const PointDecodeError = curve.PointDecodeError; + +pub const decodeG1Compressed = curve.decodeG1Compressed; +pub const decodeG2Compressed = curve.decodeG2Compressed; +pub const encodeG1Compressed = curve.encodeG1Compressed; +pub const encodeG2Compressed = curve.encodeG2Compressed; +pub const g1Generator = curve.g1Generator; +pub const g2Generator = curve.g2Generator; +pub const isInG1Subgroup = curve.isInG1Subgroup; +pub const isInG2Subgroup = curve.isInG2Subgroup; +pub const millerLoop = curve.millerLoop; +pub const fp12FinalExp = curve.fp12FinalExp; +pub const fpFromBytes64Be = curve.fpFromBytes64Be; + +test { + _ = curve; + _ = bls; + _ = hash_to_field; + _ = hash_to_curve_g2; +} diff --git a/packages/zolt-arith/src/curves/bn254/mod.zig b/packages/zolt-arith/src/curves/bn254/mod.zig new file mode 100644 index 00000000..407daf79 --- /dev/null +++ b/packages/zolt-arith/src/curves/bn254/mod.zig @@ -0,0 +1,109 @@ +//! BN254 curve type bundle. +//! +//! Phase 1: field types only (Fr, Fp). These are instantiations of the +//! generic `MontgomeryField(N, ...)` factory with BN254's specific +//! moduli and Montgomery constants. +//! +//! Later phases will add: Fp2, Fp6, Fp12, G1Affine, G2Affine, Pairing. + +const MontgomeryField = @import("../montgomery_field.zig").MontgomeryField; +const params = @import("params.zig"); + +/// BN254 scalar field element (4-limb, 254 bits). +pub const Fr = MontgomeryField(4, params.FR_MODULUS, params.FR_R2, params.FR_INV); + +/// BN254 base field element (4-limb, 254 bits). 
pub const Fp = MontgomeryField(4, params.FP_MODULUS, params.FP_R2, params.FP_INV);

// =========================================================================
// Tests — verify the generic instantiation matches the OG's constants
// =========================================================================

const testing = @import("std").testing;

// `one()` is computed through a Montgomery multiplication, so comparing it
// with the tabulated R checks modulus, R^2 and INV together.
test "BN254 Fr.one() matches BN254_R" {
    const o = Fr.one();
    try testing.expectEqual(params.FR_R, o.limbs);
}

test "BN254 Fp.one() matches BN254_FP_R" {
    const o = Fp.one();
    try testing.expectEqual(params.FP_R, o.limbs);
}

test "BN254 Fr: 2 * 3 = 6" {
    const two = Fr.fromRaw(.{ 2, 0, 0, 0 });
    const three = Fr.fromRaw(.{ 3, 0, 0, 0 });
    const six = Fr.fromRaw(.{ 6, 0, 0, 0 });
    try testing.expect(Fr.eql(Fr.montgomeryMul(two, three), six));
}

test "BN254 Fp: a * a^-1 = 1" {
    const a = Fp.fromRaw(.{ 0x12345678, 0, 0, 0 });
    // `a` is non-zero, so a null inverse is a bug in `inverse`. It must fail
    // the test; `error.SkipZigTest` would only report it as skipped.
    const inv_a = Fp.inverse(a) orelse return error.TestUnexpectedResult;
    try testing.expect(Fp.eql(Fp.montgomeryMul(a, inv_a), Fp.one()));
}

// Compile-level check only: the element type must keep exposing a `.limbs`
// array with (at least) indices 0..3.
test "BN254 Fr: .limbs access works for backward compat" {
    const a = Fr.fromU64(42);
    _ = a.limbs[0];
    _ = a.limbs[3];
}

// Cross-validation: generic Fp vs OG BN254BaseField (now aliased)
const field_mod = @import("../../field/mod.zig");
const OGFp = field_mod.BN254BaseField;

test "Cross-validate: generic Fp one() matches OG BN254BaseField one()" {
    try testing.expectEqual(Fp.one().limbs, OGFp.one().limbs);
}

test "Cross-validate: generic Fp add matches OG" {
    const a = Fp.fromU64(12345);
    const b = Fp.fromU64(67890);
    const og_a = OGFp.fromU64(12345);
    const og_b = OGFp.fromU64(67890);
    try testing.expectEqual(Fp.add(a, b).limbs, OGFp.add(og_a, og_b).limbs);
}

test "Cross-validate: generic Fp mul matches OG" {
    const a = Fp.fromU64(42);
    const b = Fp.fromU64(99);
    const og_a = OGFp.fromU64(42);
    const og_b = OGFp.fromU64(99);
    try testing.expectEqual(Fp.mul(a, b).limbs, OGFp.montgomeryMul(og_a, og_b).limbs);
}

test "Cross-validate: generic Fp inverse matches OG" {
    const a = Fp.fromU64(42);
    const og_a = OGFp.fromU64(42);
    // 42 is non-zero: a missing inverse on either side is a failure, not a skip.
    const inv = Fp.inverse(a) orelse return error.TestUnexpectedResult;
    const og_inv = OGFp.inverse(og_a) orelse return error.TestUnexpectedResult;
    try testing.expectEqual(inv.limbs, og_inv.limbs);
}

test "Cross-validate: generic Fp sumOfProducts matches OG" {
    const a0 = Fp.fromU64(3);
    const a1 = Fp.fromU64(7);
    const b0 = Fp.fromU64(11);
    const b1 = Fp.fromU64(13);
    const og_a0 = OGFp.fromU64(3);
    const og_a1 = OGFp.fromU64(7);
    const og_b0 = OGFp.fromU64(11);
    const og_b1 = OGFp.fromU64(13);
    const result = Fp.sumOfProducts(.{ a0, a1 }, .{ b0, b1 });
    const og_result = OGFp.sumOfProducts(.{ og_a0, og_a1 }, .{ og_b0, og_b1 });
    try testing.expectEqual(result.limbs, og_result.limbs);
}

test "Cross-validate: generic Fp double matches OG" {
    const a = Fp.fromU64(42);
    const og_a = OGFp.fromU64(42);
    try testing.expectEqual(Fp.double(a).limbs, OGFp.double(og_a).limbs);
}

test "Cross-validate: generic Fp toBytesBE matches OG" {
    const a = Fp.fromU64(0xdeadbeef);
    const og_a = OGFp.fromU64(0xdeadbeef);
    try testing.expectEqual(Fp.toBytesBE(a), OGFp.toBytesBE(og_a));
}
diff --git a/packages/zolt-arith/src/curves/bn254/params.zig b/packages/zolt-arith/src/curves/bn254/params.zig
new file mode 100644
index 00000000..2394e545
--- /dev/null
+++ b/packages/zolt-arith/src/curves/bn254/params.zig
@@ -0,0 +1,66 @@
//! BN254 curve parameters for the generic curve substrate.
//!
//! These are the field moduli, Montgomery constants, and (eventually)
//! extension field / pairing parameters. They are the SAME numbers that
//! live in `field/mod.zig` today — just gathered into a self-contained
//! location so the generic factories can reference them.

// ── Scalar field Fr ──────────────────────────────────────────────────
//
// Every multi-limb constant in this file is a `[4]u64` written
// least-significant limb first.

/// BN254 scalar field modulus
/// p = 21888242871839275222246405745257275088548364400416034343698204186575808495617
pub const FR_MODULUS: [4]u64 = .{
    0x43e1f593f0000001,
    0x2833e84879b97091,
    0xb85045b68181585d,
    0x30644e72e131a029,
};

/// Montgomery R for Fr (R = 2^256 mod p)
pub const FR_R: [4]u64 = .{
    0xac96341c4ffffffb,
    0x36fc76959f60cd29,
    0x666ea36f7879462e,
    0x0e0a77c19a07df2f,
};

/// Montgomery R^2 for Fr (R^2 = 2^512 mod p)
pub const FR_R2: [4]u64 = .{
    0x1bb8e645ae216da7,
    0x53fe3ab1e35c59e3,
    0x8c49833d53bb8085,
    0x0216d0b17f4e44a5,
};

/// -p^{-1} mod 2^64
pub const FR_INV: u64 = 0xc2e1f593efffffff;

// ── Base field Fp ────────────────────────────────────────────────────

/// BN254 base field modulus
/// q = 21888242871839275222246405745257275088696311157297823662689037894645226208583
pub const FP_MODULUS: [4]u64 = .{
    0x3c208c16d87cfd47,
    0x97816a916871ca8d,
    0xb85045b68181585d,
    0x30644e72e131a029,
};

/// Montgomery R for Fp (R = 2^256 mod q)
pub const FP_R: [4]u64 = .{
    0xd35d438dc58f0d9d,
    0x0a78eb28f5c70b3d,
    0x666ea36f7879462c,
    0x0e0a77c19a07df2f,
};

/// Montgomery R^2 for Fp (R^2 = 2^512 mod q)
pub const FP_R2: [4]u64 = .{
    0xf32cfc5b538afa89,
    0xb5e71911d44501fb,
    0x47ab1eff0a417ff6,
    0x06d89f71cab8351f,
};

/// -q^{-1} mod 2^64
pub const FP_INV: u64 = 0x87d20782e4866389;
diff --git a/packages/zolt-arith/src/curves/extensions.zig b/packages/zolt-arith/src/curves/extensions.zig
new file mode 100644
index 00000000..f7227c9e
--- /dev/null
+++ b/packages/zolt-arith/src/curves/extensions.zig
@@ -0,0 +1,520 @@
//! Generic extension field factories.
//!
//! `Fp2(BaseFp)` — quadratic extension Fp[u] / (u² + 1).
//! `Fp6(Fp2Type, mulByXiFn)` — cubic extension Fp2[v] / (v³ − ξ).
//! `Fp12(Fp6Type, Fp2Type)` — quadratic extension Fp6[w] / (w² − v).
//!
//! Both BN254 and BLS12-381 use non-residue β = -1 for the quadratic
//! extension, so the generic factory is specialized for that case.
//! A future version can parameterize over β if needed for other curves.

/// Build a quadratic extension field Fp[u] / (u² + 1) over a base
/// field `BaseFp`. The base field must support: `zero, one, add, sub,
/// neg, mul, square, inverse, eql, isZero`.
/// `sumOfProducts` and `addNoReduce` are optional; when `BaseFp` declares
/// them, `mul` and `square` switch to the cheaper code paths.
pub fn Fp2(comptime BaseFp: type) type {
    return struct {
        c0: BaseFp,
        c1: BaseFp,

        const Self = @This();

        /// Build `c0 + c1·u` from its two coefficients.
        pub fn init(c0: BaseFp, c1: BaseFp) Self {
            return .{ .c0 = c0, .c1 = c1 };
        }

        pub fn zero() Self {
            return .{ .c0 = BaseFp.zero(), .c1 = BaseFp.zero() };
        }

        pub fn one() Self {
            return .{ .c0 = BaseFp.one(), .c1 = BaseFp.zero() };
        }

        pub fn eql(a: Self, b: Self) bool {
            return BaseFp.eql(a.c0, b.c0) and BaseFp.eql(a.c1, b.c1);
        }

        pub fn isZero(self: Self) bool {
            return self.c0.isZero() and self.c1.isZero();
        }

        pub fn add(a: Self, b: Self) Self {
            return .{
                .c0 = BaseFp.add(a.c0, b.c0),
                .c1 = BaseFp.add(a.c1, b.c1),
            };
        }

        pub fn sub(a: Self, b: Self) Self {
            return .{
                .c0 = BaseFp.sub(a.c0, b.c0),
                .c1 = BaseFp.sub(a.c1, b.c1),
            };
        }

        pub fn neg(a: Self) Self {
            return .{ .c0 = BaseFp.neg(a.c0), .c1 = BaseFp.neg(a.c1) };
        }

        /// Whether the base field exposes sumOfProducts (fused mul-accumulate).
        const has_sum_of_products = @hasDecl(BaseFp, "sumOfProducts");
        /// Whether the base field exposes addNoReduce (lazy reduction).
        const has_add_no_reduce = @hasDecl(BaseFp, "addNoReduce");

        /// Multiplication: (a₀ + a₁u)(b₀ + b₁u) = (a₀b₀ - a₁b₁) + (a₀b₁ + a₁b₀)u
        /// When BaseFp has sumOfProducts, uses fused path (2 reductions
        /// instead of 3). Otherwise falls back to Karatsuba.
        pub fn mul(a: Self, b: Self) Self {
            if (comptime has_sum_of_products) {
                // c0 = a₀·b₀ + (-a₁)·b₁ ; c1 = a₀·b₁ + a₁·b₀
                const neg_a1 = BaseFp.neg(a.c1);
                return .{
                    .c0 = BaseFp.sumOfProducts(.{ a.c0, neg_a1 }, .{ b.c0, b.c1 }),
                    .c1 = BaseFp.sumOfProducts(.{ a.c0, a.c1 }, .{ b.c1, b.c0 }),
                };
            }
            const a0b0 = BaseFp.mul(a.c0, b.c0);
            const a1b1 = BaseFp.mul(a.c1, b.c1);
            const a0_plus_a1 = BaseFp.add(a.c0, a.c1);
            const b0_plus_b1 = BaseFp.add(b.c0, b.c1);
            const cross = BaseFp.mul(a0_plus_a1, b0_plus_b1);
            return .{
                .c0 = BaseFp.sub(a0b0, a1b1),
                .c1 = BaseFp.sub(BaseFp.sub(cross, a0b0), a1b1),
            };
        }

        /// Squaring: (a + bu)² = (a²-b²) + 2ab·u
        /// Uses addNoReduce for (a+b) when available (saves 1 reduction).
        pub fn square(a: Self) Self {
            const a_plus_b = if (comptime has_add_no_reduce)
                BaseFp.addNoReduce(a.c0, a.c1)
            else
                BaseFp.add(a.c0, a.c1);
            const a_minus_b = BaseFp.sub(a.c0, a.c1);
            // (a + b)(a - b) = a² - b²
            const c0 = BaseFp.mul(a_plus_b, a_minus_b);
            const ab = BaseFp.mul(a.c0, a.c1);
            return .{ .c0 = c0, .c1 = BaseFp.add(ab, ab) };
        }

        /// Conjugate: a + bu → a - bu
        pub fn conjugate(a: Self) Self {
            return .{ .c0 = a.c0, .c1 = BaseFp.neg(a.c1) };
        }

        /// Inverse: (a + bu)⁻¹ = (a - bu) / (a² + b²)
        /// Returns null for zero, or when the base-field inverse of the norm is null.
        pub fn inverse(a: Self) ?Self {
            if (a.isZero()) return null;
            const norm = BaseFp.add(BaseFp.square(a.c0), BaseFp.square(a.c1));
            const norm_inv = norm.inverse() orelse return null;
            return .{
                .c0 = BaseFp.mul(a.c0, norm_inv),
                .c1 = BaseFp.mul(BaseFp.neg(a.c1), norm_inv),
            };
        }

        /// Non-optional inverse (returns zero for zero).
        pub fn inv(a: Self) Self {
            return a.inverse() orelse zero();
        }

        /// Scalar multiplication by a base field element.
        pub fn scalarMul(a: Self, s: BaseFp) Self {
            return .{
                .c0 = BaseFp.mul(a.c0, s),
                .c1 = BaseFp.mul(a.c1, s),
            };
        }

        /// Multiply by the non-residue u: (a + bu) * u = -b + au
        pub fn mulByNonResidue(a: Self) Self {
            return .{ .c0 = BaseFp.neg(a.c1), .c1 = a.c0 };
        }

        /// Double
        pub fn double(a: Self) Self {
            return .{
                .c0 = BaseFp.add(a.c0, a.c0),
                .c1 = BaseFp.add(a.c1, a.c1),
            };
        }
    };
}

/// Build a cubic extension field Fp2[v] / (v³ − ξ) where ξ is the
/// cubic non-residue in Fp2. The `mulByXiFn` parameter is a comptime
/// function that multiplies an Fp2 element by ξ.
///
/// BN254:      ξ = 9 + u → shift-add (4 Fp additions per component)
/// BLS12-381:  ξ = 1 + u → (c0-c1) + (c0+c1)·u
pub fn Fp6(comptime Fp2Type: type, comptime mulByXiFn: fn (Fp2Type) Fp2Type) type {
    return struct {
        c0: Fp2Type,
        c1: Fp2Type,
        c2: Fp2Type,

        const Self = @This();

        pub fn zero() Self {
            return .{ .c0 = Fp2Type.zero(), .c1 = Fp2Type.zero(), .c2 = Fp2Type.zero() };
        }

        pub fn one() Self {
            return .{ .c0 = Fp2Type.one(), .c1 = Fp2Type.zero(), .c2 = Fp2Type.zero() };
        }

        pub fn eql(a: Self, b: Self) bool {
            return Fp2Type.eql(a.c0, b.c0) and Fp2Type.eql(a.c1, b.c1) and Fp2Type.eql(a.c2, b.c2);
        }

        pub fn add(a: Self, b: Self) Self {
            return .{
                .c0 = Fp2Type.add(a.c0, b.c0),
                .c1 = Fp2Type.add(a.c1, b.c1),
                .c2 = Fp2Type.add(a.c2, b.c2),
            };
        }

        pub fn sub(a: Self, b: Self) Self {
            return .{
                .c0 = Fp2Type.sub(a.c0, b.c0),
                .c1 = Fp2Type.sub(a.c1, b.c1),
                .c2 = Fp2Type.sub(a.c2, b.c2),
            };
        }

        pub fn neg(a: Self) Self {
            return .{ .c0 = Fp2Type.neg(a.c0), .c1 = Fp2Type.neg(a.c1), .c2 = Fp2Type.neg(a.c2) };
        }

        /// Multiply an Fp2 element by the cubic non-residue ξ.
        pub fn mulByXi(x: Fp2Type) Fp2Type {
            return mulByXiFn(x);
        }

        /// Karatsuba multiplication in Fp6.
        /// With v_i = a_i·b_i:
        ///   c0 = v0 + ξ·((a1+a2)(b1+b2) − v1 − v2)
        ///   c1 = (a0+a1)(b0+b1) − v0 − v1 + ξ·v2
        ///   c2 = (a0+a2)(b0+b2) − v0 − v2 + v1
        pub fn mul(a: Self, b: Self) Self {
            const v0 = Fp2Type.mul(a.c0, b.c0);
            const v1 = Fp2Type.mul(a.c1, b.c1);
            const v2 = Fp2Type.mul(a.c2, b.c2);

            const c1_plus_c2 = Fp2Type.add(a.c1, a.c2);
            const d1_plus_d2 = Fp2Type.add(b.c1, b.c2);
            const t0 = mulByXiFn(Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(c1_plus_c2, d1_plus_d2), v1), v2));
            const new_c0 = Fp2Type.add(v0, t0);

            const c0_plus_c1 = Fp2Type.add(a.c0, a.c1);
            const d0_plus_d1 = Fp2Type.add(b.c0, b.c1);
            const t1 = Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(c0_plus_c1, d0_plus_d1), v0), v1);
            const new_c1 = Fp2Type.add(t1, mulByXiFn(v2));

            const c0_plus_c2 = Fp2Type.add(a.c0, a.c2);
            const d0_plus_d2 = Fp2Type.add(b.c0, b.c2);
            const t2 = Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(c0_plus_c2, d0_plus_d2), v0), v2);
            const new_c2 = Fp2Type.add(t2, v1);

            return .{ .c0 = new_c0, .c1 = new_c1, .c2 = new_c2 };
        }

        /// Chung-Hasan SQ2 squaring.
        pub fn square(a: Self) Self {
            const s0 = Fp2Type.square(a.c0);
            const ab = Fp2Type.mul(a.c0, a.c1);
            const s1 = Fp2Type.add(ab, ab);
            const s2 = Fp2Type.square(Fp2Type.add(Fp2Type.sub(a.c0, a.c1), a.c2));
            const bc = Fp2Type.mul(a.c1, a.c2);
            const s3 = Fp2Type.add(bc, bc);
            const s4 = Fp2Type.square(a.c2);

            return .{
                .c0 = Fp2Type.add(s0, mulByXiFn(s3)),
                .c1 = Fp2Type.add(s1, mulByXiFn(s4)),
                .c2 = Fp2Type.sub(Fp2Type.add(Fp2Type.add(s1, s2), s3), Fp2Type.add(s0, s4)),
            };
        }

        /// Inverse in Fp6. Computes the three cofactors a0, a1, a2 and the
        /// norm `c0·a0 + ξ·(c1·a2 + c2·a1)` in Fp2, then scales the cofactors
        /// by the norm's inverse. Returns null when `Fp2Type.inverse` of the
        /// norm returns null.
        pub fn inverse(a: Self) ?Self {
            const c0_sq = Fp2Type.square(a.c0);
            const c1_sq = Fp2Type.square(a.c1);
            const c2_sq = Fp2Type.square(a.c2);
            const c0c1 = Fp2Type.mul(a.c0, a.c1);
            const c0c2 = Fp2Type.mul(a.c0, a.c2);
            const c1c2 = Fp2Type.mul(a.c1, a.c2);

            const a0 = Fp2Type.sub(c0_sq, mulByXiFn(c1c2));
            const a1 = Fp2Type.sub(mulByXiFn(c2_sq), c0c1);
            const a2 = Fp2Type.sub(c1_sq, c0c2);

            const tmp = mulByXiFn(Fp2Type.add(Fp2Type.mul(a.c1, a2), Fp2Type.mul(a.c2, a1)));
            const norm = Fp2Type.add(Fp2Type.mul(a.c0, a0), tmp);

            const norm_inv = norm.inverse() orelse return null;

            return .{
                .c0 = Fp2Type.mul(a0, norm_inv),
                .c1 = Fp2Type.mul(a1, norm_inv),
                .c2 = Fp2Type.mul(a2, norm_inv),
            };
        }

        /// Multiply by v (coefficient shift): (c0 + c1·v + c2·v²) · v = ξ·c2 + c0·v + c1·v²
        pub fn mulByV(f: Self) Self {
            return .{
                .c0 = mulByXiFn(f.c2),
                .c1 = f.c0,
                .c2 = f.c1,
            };
        }
    };
}

/// Build a quadratic extension Fp6[w] / (w² − v).
///
/// `Fp6Type` must provide `mulByV` (used for the w² = v reduction, a
/// coefficient shift in Fp6) and `mulByXi`. `Fp2Type` is the coefficient
/// field of `Fp6Type`; only `cyclotomicSquare` operates on it directly.
pub fn Fp12(comptime Fp6Type: type, comptime Fp2Type: type) type {
    return struct {
        c0: Fp6Type,
        c1: Fp6Type,

        const Self = @This();

        pub fn zero() Self {
            return .{ .c0 = Fp6Type.zero(), .c1 = Fp6Type.zero() };
        }

        pub fn one() Self {
            return .{ .c0 = Fp6Type.one(), .c1 = Fp6Type.zero() };
        }

        pub fn eql(a: Self, b: Self) bool {
            return Fp6Type.eql(a.c0, b.c0) and Fp6Type.eql(a.c1, b.c1);
        }

        pub fn isOne(self: Self) bool {
            return self.eql(Self.one());
        }

        pub fn add(a: Self, b: Self) Self {
            return .{ .c0 = Fp6Type.add(a.c0, b.c0), .c1 = Fp6Type.add(a.c1, b.c1) };
        }

        pub fn sub(a: Self, b: Self) Self {
            return .{ .c0 = Fp6Type.sub(a.c0, b.c0), .c1 = Fp6Type.sub(a.c1, b.c1) };
        }

        pub fn neg(a: Self) Self {
            return .{ .c0 = Fp6Type.neg(a.c0), .c1 = Fp6Type.neg(a.c1) };
        }

        /// Karatsuba: (a + bw)(c + dw) = (ac + bd·v) + ((a+b)(c+d) - ac - bd)w
        pub fn mul(a: Self, b: Self) Self {
            const ac = Fp6Type.mul(a.c0, b.c0);
            const bd = Fp6Type.mul(a.c1, b.c1);
            const cross = Fp6Type.sub(Fp6Type.sub(
                Fp6Type.mul(Fp6Type.add(a.c0, a.c1), Fp6Type.add(b.c0, b.c1)),
                ac,
            ), bd);
            return .{
                .c0 = Fp6Type.add(ac, Fp6Type.mulByV(bd)),
                .c1 = cross,
            };
        }

        /// Complex squaring.
        /// c1 = 2ab and c0 = (a + b)(a + b·v) − ab − ab·v = a² + b²·v.
        pub fn square(a: Self) Self {
            const ab = Fp6Type.mul(a.c0, a.c1);
            const c1_new = Fp6Type.add(ab, ab);

            const a_plus_b = Fp6Type.add(a.c0, a.c1);
            const bv = Fp6Type.mulByV(a.c1);
            const a_plus_bv = Fp6Type.add(a.c0, bv);
            const abv = Fp6Type.mulByV(ab);
            const c0_new = Fp6Type.sub(Fp6Type.sub(Fp6Type.mul(a_plus_b, a_plus_bv), ab), abv);

            return .{ .c0 = c0_new, .c1 = c1_new };
        }

        /// Conjugate: a + bw → a - bw
        pub fn conjugate(a: Self) Self {
            return .{ .c0 = a.c0, .c1 = Fp6Type.neg(a.c1) };
        }

        /// Inverse: (a + bw)⁻¹ = (a − bw) / (a² − b²·v).
        /// Returns null when `Fp6Type.inverse` of the norm returns null.
        pub fn inverse(a: Self) ?Self {
            const a_squared = Fp6Type.mul(a.c0, a.c0);
            const b_squared_v = Fp6Type.mulByV(Fp6Type.mul(a.c1, a.c1));
            const norm = Fp6Type.sub(a_squared, b_squared_v);
            const norm_inv = norm.inverse() orelse return null;
            return .{
                .c0 = Fp6Type.mul(a.c0, norm_inv),
                .c1 = Fp6Type.mul(Fp6Type.neg(a.c1), norm_inv),
            };
        }

        /// Cyclotomic squaring (Granger-Scott) for unitary Fp12.
        /// The function does not check the precondition: the result is
        /// only meaningful for inputs in the cyclotomic subgroup.
        pub fn cyclotomicSquare(a: Self) Self {
            const r0 = a.c0.c0;
            const r4 = a.c0.c1;
            const r3 = a.c0.c2;
            const r2 = a.c1.c0;
            const r1 = a.c1.c1;
            const r5 = a.c1.c2;

            const xi = Fp6Type.mulByXi;

            // Each pair (x, y) below yields t_even = x² + ξ·y² and t_odd = 2xy.
            const tmp01 = Fp2Type.mul(r0, r1);
            const t0 = Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(Fp2Type.add(r0, r1), Fp2Type.add(xi(r1), r0)), tmp01), xi(tmp01));
            const t1 = Fp2Type.add(tmp01, tmp01);

            const tmp23 = Fp2Type.mul(r2, r3);
            const t2 = Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(Fp2Type.add(r2, r3), Fp2Type.add(xi(r3), r2)), tmp23), xi(tmp23));
            const t3 = Fp2Type.add(tmp23, tmp23);

            const tmp45 = Fp2Type.mul(r4, r5);
            const t4 = Fp2Type.sub(Fp2Type.sub(Fp2Type.mul(Fp2Type.add(r4, r5), Fp2Type.add(xi(r5), r4)), tmp45), xi(tmp45));
            const t5 = Fp2Type.add(tmp45, tmp45);

            // z = 3t − 2r for the subtracting rows, 3t + 2r for the adding rows.
            const z0 = Fp2Type.add(Fp2Type.add(Fp2Type.sub(t0, r0), Fp2Type.sub(t0, r0)), t0);
            const z1 = Fp2Type.add(Fp2Type.add(Fp2Type.add(t1, r1), Fp2Type.add(t1, r1)), t1);
            const xi_t5 = xi(t5);
            const z2 = Fp2Type.add(Fp2Type.add(Fp2Type.add(xi_t5, r2), Fp2Type.add(xi_t5, r2)), xi_t5);
            const z3 = Fp2Type.add(Fp2Type.add(Fp2Type.sub(t4, r3), Fp2Type.sub(t4, r3)), t4);
            const z4 = Fp2Type.add(Fp2Type.add(Fp2Type.sub(t2, r4), Fp2Type.sub(t2, r4)), t2);
            const z5 = Fp2Type.add(Fp2Type.add(Fp2Type.add(t3, r5), Fp2Type.add(t3, r5)), t3);

            return .{
                .c0 = .{ .c0 = z0, .c1 = z4, .c2 = z3 },
                .c1 = .{ .c0 = z2, .c1 = z1, .c2 = z5 },
            };
        }
    };
}

// =========================================================================
// Tests
// =========================================================================

const testing = @import("std").testing;
const bn254 = @import("bn254/mod.zig");
const BN254Fp2 = Fp2(bn254.Fp);

test "Generic Fp2(BN254): zero + one" {
    const z = BN254Fp2.zero();
    const o = BN254Fp2.one();
    try testing.expect(BN254Fp2.eql(BN254Fp2.add(z, o), o));
}

test "Generic Fp2(BN254): mul commutativity" {
    const a = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7));
    const b = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13));
    try testing.expect(BN254Fp2.eql(BN254Fp2.mul(a, b), BN254Fp2.mul(b, a)));
}

test "Generic Fp2(BN254): a * a^-1 = 1" {
    const a = BN254Fp2.init(bn254.Fp.fromU64(42), bn254.Fp.fromU64(99));
    // `a` is non-zero, so a null inverse is a defect and must fail the test
    // rather than mark it as skipped.
    const a_inv = BN254Fp2.inverse(a) orelse return error.TestUnexpectedResult;
    try testing.expect(BN254Fp2.eql(BN254Fp2.mul(a, a_inv), BN254Fp2.one()));
}

test "Generic Fp2(BN254): square matches mul(a,a)" {
    const a = BN254Fp2.init(bn254.Fp.fromU64(7), bn254.Fp.fromU64(11));
    try testing.expect(BN254Fp2.eql(BN254Fp2.square(a), BN254Fp2.mul(a, a)));
}

test "Generic Fp2(BN254): conjugate identity a * conj(a) = norm" {
    const a = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(5));
    const prod = BN254Fp2.mul(a, BN254Fp2.conjugate(a));
    // norm = a₀² + a₁² (real number — imaginary part is zero)
    try testing.expect(prod.c1.isZero());
}

// -- Fp6 / Fp12 tests (BN254) --

/// BN254's cubic non-residue ξ = 9 + u.
/// Multiply Fp2 by 9+u using shift-add.
/// (9 + u)(a + bu) = (9a − b) + (a + 9b)u, with 9x built as 8x + x.
fn bn254MulByXi(x: BN254Fp2) BN254Fp2 {
    const a = x.c0;
    const a2 = bn254.Fp.add(a, a);
    const a4 = bn254.Fp.add(a2, a2);
    const a8 = bn254.Fp.add(a4, a4);
    const a9 = bn254.Fp.add(a8, a);
    const b = x.c1;
    const b2 = bn254.Fp.add(b, b);
    const b4 = bn254.Fp.add(b2, b2);
    const b8 = bn254.Fp.add(b4, b4);
    const b9 = bn254.Fp.add(b8, b);
    return BN254Fp2.init(bn254.Fp.sub(a9, b), bn254.Fp.add(a, b9));
}

const BN254Fp6 = Fp6(BN254Fp2, bn254MulByXi);
const BN254Fp12 = Fp12(BN254Fp6, BN254Fp2);

test "Generic Fp6(BN254): one is multiplicative identity" {
    const a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(0)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    try testing.expect(BN254Fp6.eql(BN254Fp6.mul(a, BN254Fp6.one()), a));
}

test "Generic Fp6(BN254): square matches mul(a,a)" {
    const a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    try testing.expect(BN254Fp6.eql(BN254Fp6.square(a), BN254Fp6.mul(a, a)));
}

test "Generic Fp6(BN254): a * a^-1 = 1" {
    const a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    // Non-zero input: a null inverse is a failure, not a reason to skip.
    const a_inv = BN254Fp6.inverse(a) orelse return error.TestUnexpectedResult;
    try testing.expect(BN254Fp6.eql(BN254Fp6.mul(a, a_inv), BN254Fp6.one()));
}

test "Generic Fp12(BN254): one is multiplicative identity" {
    const fp6_a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(0)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    const a = BN254Fp12{ .c0 = fp6_a, .c1 = BN254Fp6.one() };
    try testing.expect(BN254Fp12.eql(BN254Fp12.mul(a, BN254Fp12.one()), a));
}

test "Generic Fp12(BN254): square matches mul(a,a)" {
    const fp6_a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    const fp6_b = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(17), bn254.Fp.fromU64(19)),
        .c1 = BN254Fp2.zero(),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(23), bn254.Fp.fromU64(0)),
    };
    const a = BN254Fp12{ .c0 = fp6_a, .c1 = fp6_b };
    try testing.expect(BN254Fp12.eql(BN254Fp12.square(a), BN254Fp12.mul(a, a)));
}

test "Generic Fp12(BN254): a * a^-1 = 1" {
    const fp6_a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    const fp6_b = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(17), bn254.Fp.fromU64(19)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(1), bn254.Fp.fromU64(0)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(23), bn254.Fp.fromU64(29)),
    };
    const a = BN254Fp12{ .c0 = fp6_a, .c1 = fp6_b };
    // Non-zero input: a null inverse is a failure, not a reason to skip.
    const a_inv = BN254Fp12.inverse(a) orelse return error.TestUnexpectedResult;
    try testing.expect(BN254Fp12.eql(BN254Fp12.mul(a, a_inv), BN254Fp12.one()));
}

// Cross-validate Fp6 against OG
test "Generic Fp6(BN254): mul matches OG Fp6.mul" {
    const OGFp6 = @import("../field/extensions.zig").Fp6;
    const OGFp2 = @import("../field/extensions.zig").Fp2;
    const OGFp = @import("../field/mod.zig").BN254BaseField;

    const og_a = OGFp6{
        .c0 = OGFp2.init(OGFp.fromU64(3), OGFp.fromU64(7)),
        .c1 = OGFp2.init(OGFp.fromU64(11), OGFp.fromU64(13)),
        .c2 = OGFp2.init(OGFp.fromU64(5), OGFp.fromU64(2)),
    };
    const og_b = OGFp6{
        .c0 = OGFp2.init(OGFp.fromU64(17), OGFp.fromU64(19)),
        .c1 = OGFp2.init(OGFp.fromU64(23), OGFp.fromU64(29)),
        .c2 = OGFp2.init(OGFp.fromU64(31), OGFp.fromU64(37)),
    };
    const og_result = og_a.mul(og_b);

    const gen_a = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(5), bn254.Fp.fromU64(2)),
    };
    const gen_b = BN254Fp6{
        .c0 = BN254Fp2.init(bn254.Fp.fromU64(17), bn254.Fp.fromU64(19)),
        .c1 = BN254Fp2.init(bn254.Fp.fromU64(23), bn254.Fp.fromU64(29)),
        .c2 = BN254Fp2.init(bn254.Fp.fromU64(31), bn254.Fp.fromU64(37)),
    };
    const gen_result = BN254Fp6.mul(gen_a, gen_b);

    // Compare all six base-field coefficients; checking only some of them
    // would let a wrong c1.c1 or c2.c1 slip through.
    try testing.expectEqual(gen_result.c0.c0.limbs, og_result.c0.c0.limbs);
    try testing.expectEqual(gen_result.c0.c1.limbs, og_result.c0.c1.limbs);
    try testing.expectEqual(gen_result.c1.c0.limbs, og_result.c1.c0.limbs);
    try testing.expectEqual(gen_result.c1.c1.limbs, og_result.c1.c1.limbs);
    try testing.expectEqual(gen_result.c2.c0.limbs, og_result.c2.c0.limbs);
    try testing.expectEqual(gen_result.c2.c1.limbs, og_result.c2.c1.limbs);
}

// Cross-validate against the OG BN254 Fp2
test "Generic Fp2(BN254): mul matches OG extensions.Fp2.mul" {
    const OGFp2 = @import("../field/extensions.zig").Fp2;
    const OGFp = @import("../field/mod.zig").BN254BaseField;

    const og_a = OGFp2.init(OGFp.fromU64(3), OGFp.fromU64(7));
    const og_b = OGFp2.init(OGFp.fromU64(11), OGFp.fromU64(13));
    const og_result = og_a.mul(og_b);

    const gen_a = BN254Fp2.init(bn254.Fp.fromU64(3), bn254.Fp.fromU64(7));
    const gen_b = BN254Fp2.init(bn254.Fp.fromU64(11), bn254.Fp.fromU64(13));
    const gen_result = BN254Fp2.mul(gen_a, gen_b);

    // BN254BaseField is now an alias for curves.bn254.Fp, so the limbs
    // should be directly comparable.
    try testing.expectEqual(gen_result.c0.limbs, og_result.c0.limbs);
    try testing.expectEqual(gen_result.c1.limbs, og_result.c1.limbs);
}
diff --git a/packages/zolt-arith/src/curves/mod.zig b/packages/zolt-arith/src/curves/mod.zig
new file mode 100644
index 00000000..81ab72ad
--- /dev/null
+++ b/packages/zolt-arith/src/curves/mod.zig
@@ -0,0 +1,23 @@
//! Curve-generic substrate for pairing-friendly elliptic curves.
//!
//! This module provides the organizing abstraction that lets zolt-arith
//! support multiple curves (BN254 for zolt's prover, BLS12-381 for
//! zyli's validator signature verification) through the same generic
//! factories: `montgomery_field`, `extensions`, `weierstrass`, `pairing`.
//!
//! Each curve lives in its own submodule (`bn254/`, `bls12_381/`) and
//! exports a type bundle: `Fp, Fr, Fp2, Fp6, Fp12, G1Affine, G2Affine,
//! Pairing, HashToG2, Bls`.
//!
//! NOTE(review): only `montgomery_field` and `extensions` are imported
//! below; `weierstrass`, `pairing` and the per-curve submodules are not
//! re-exported from this file.

pub const montgomery_field = @import("montgomery_field.zig");
pub const MontgomeryField = montgomery_field.MontgomeryField;
pub const extensions = @import("extensions.zig");
pub const Fp2 = extensions.Fp2;

/// Curve family (drives final-exponentiation and loop-count derivation).
pub const Family = enum { bn, bls };

// Reference the imported namespaces so their `test` blocks are compiled
// and run together with this module.
test {
    _ = montgomery_field;
    _ = extensions;
}
diff --git a/packages/zolt-arith/src/curves/montgomery_field.zig b/packages/zolt-arith/src/curves/montgomery_field.zig
new file mode 100644
index 00000000..391b0449
--- /dev/null
+++ b/packages/zolt-arith/src/curves/montgomery_field.zig
@@ -0,0 +1,1097 @@
//! Generic Montgomery field arithmetic over `[N]u64` limbs.
//!
//! `MontgomeryField(N, modulus, r2, n_prime)` returns a concrete type
//! whose elements are `struct { limbs: [N]u64 }` in Montgomery form.
//! The factory is comptime-generic over the limb count, so both 4-limb
//! BN254 and 6-limb BLS12-381 fields flow through the same code.
//!
//! The API surface matches the OG zolt-arith `BN254Scalar` / `BN254BaseField`
//! shapes (struct wrapper, `.limbs` access, method-call convention) so
//! existing callers in `field/`, `poly/`, `msm/`, etc. continue to work
//! when the BN254 types are aliased to instantiations of this factory.
//!
//! Aliases (e.g. `montMul` for `montgomeryMul`, `fromRaw` for
//! `toMontgomery`) let the BLS12-381 side use the same type under the
//! names established in the standalone `zolt-arith` package that zyli
//! consumed before the consolidation.

const std = @import("std");
const builtin = @import("builtin");
const bigint = @import("../bigint.zig");

// Import the N=4 asm backends from field/mod.zig (ARM64 + x86 helpers).
const asm_mod = @import("../field/mod.zig");

/// Comptime flag: x86-64 BMI2+ADX available for fast Montgomery mul.
/// False on every non-x86-64 target; otherwise true only when both CPU
/// features are enabled for the compilation target.
const use_asm_mul = blk: {
    if (builtin.cpu.arch != .x86_64) break :blk false;
    const features = builtin.cpu.features;
    break :blk features.isEnabled(@intFromEnum(std.Target.x86.Feature.bmi2)) and
        features.isEnabled(@intFromEnum(std.Target.x86.Feature.adx));
};

/// Comptime flag: AArch64 (adds/adcs/subs/sbcs always available).
const use_arm64_asm = (builtin.cpu.arch == .aarch64);

/// LLVM x86 carry/borrow intrinsics — produce single adc/sbb instructions.
/// These work in Release builds (LLVM lowers them as intrinsics) but fail in
/// Debug (linker can't resolve the symbol). Guard all uses with !@inComptime().
// In Debug mode the linker can't resolve LLVM intrinsic symbols (they're
// only inlined by LLVM in optimized builds). So we only declare them when
// the module is compiled with optimizations. The bench build.zig creates
// a ReleaseFast dep chain for zolt-arith, so benches get the fast path.
const x86_has_intrinsics = builtin.cpu.arch == .x86_64 and builtin.mode != .Debug;
// When the flag is false this is an empty struct, so every use of
// `addcarry`/`subborrow` must sit behind a comptime `x86_has_intrinsics` check.
const x86_intrinsics = if (x86_has_intrinsics) struct {
    extern fn @"llvm.x86.addcarry.u64"(c_in: u8, a: u64, b: u64, result: *u64) u8;
    extern fn @"llvm.x86.subborrow.u64"(b_in: u8, a: u64, b: u64, result: *u64) u8;

    // Thin wrappers: the low 64 bits are written through `result`, and the
    // intrinsic's `u8` return value (the carry/borrow out) is passed back.
    pub inline fn addcarry(c_in: u8, a: u64, b: u64, result: *u64) u8 {
        return @"llvm.x86.addcarry.u64"(c_in, a, b, result);
    }
    pub inline fn subborrow(b_in: u8, a: u64, b: u64, result: *u64) u8 {
        return @"llvm.x86.subborrow.u64"(b_in, a, b, result);
    }
} else struct {};

/// Build a Montgomery field type over the given modulus.
///
/// All four arguments are comptime values:
///   - `N`: how many u64 limbs an element occupies
///   - `modulus`: the prime `p` as `[N]u64`, least-significant limb first
///   - `r2`: `R^2 mod p`, with `R = 2^(64·N)`
///   - `n_prime`: `-p^{-1} mod 2^64`
pub fn MontgomeryField(
    comptime N: comptime_int,
    comptime modulus: [N]u64,
    comptime r2: [N]u64,
    comptime n_prime: u64,
) type {
    return struct {
        limbs: [N]u64,

        const Self = @This();

        /// Limb count, re-exported for byte-conversion logic.
        pub const LIMB_COUNT: comptime_int = N;
        /// The factory arguments, re-exported under upper-case names.
        pub const MODULUS: [N]u64 = modulus;
        pub const R2: [N]u64 = r2;
        pub const N_PRIME: u64 = n_prime;

        // -----------------------------------------------------------------
        // Constants
        // -----------------------------------------------------------------

        /// Additive identity: every limb is zero.
        pub fn zero() Self {
            return .{ .limbs = [_]u64{0} ** N };
        }

        /// Multiplicative identity in Montgomery form, obtained as
        /// `montgomeryMul(1, R^2)`, i.e. `R mod p`.
        pub fn one() Self {
            const raw_unit: Self = .{ .limbs = [_]u64{1} ++ [_]u64{0} ** (N - 1) };
            const r_squared: Self = .{ .limbs = r2 };
            return montgomeryMul(raw_unit, r_squared);
        }

        // -----------------------------------------------------------------
        // Predicates
        // -----------------------------------------------------------------

        /// True when every limb is zero.
        pub fn isZero(self: Self) bool {
            for (self.limbs) |limb| {
                if (limb != 0) return false;
            }
            return true;
        }

        /// True when `self` has the same limbs as `one()`.
        pub fn isOne(self: Self) bool {
            return eql(self, one());
        }

        /// Limb-wise equality of the stored (Montgomery-form) representations.
        pub fn eql(a: Self, b: Self) bool {
            return std.mem.eql(u64, &a.limbs, &b.limbs);
        }

        // -----------------------------------------------------------------
        // Constructors
        // -----------------------------------------------------------------

        /// Create from a small u64, converting to Montgomery form.
+ pub fn fromU64(n: u64) Self { + var raw: [N]u64 = .{0} ** N; + raw[0] = n; + return montgomeryMul(.{ .limbs = raw }, .{ .limbs = r2 }); + } + + /// Create from u128, converting to Montgomery form. + pub fn fromU128(n: u128) Self { + var raw: [N]u64 = .{0} ** N; + raw[0] = @truncate(n); + if (N > 1) raw[1] = @truncate(n >> 64); + return montgomeryMul(.{ .limbs = raw }, .{ .limbs = r2 }); + } + + /// Create from little-endian bytes, converting to Montgomery form. + pub fn fromBytes(bytes: []const u8) Self { + var raw: [N]u64 = .{0} ** N; + const len = @min(bytes.len, N * 8); + var buf: [N * 8]u8 = .{0} ** (N * 8); + @memcpy(buf[0..len], bytes[0..len]); + inline for (0..N) |i| { + raw[i] = std.mem.readInt(u64, buf[i * 8 ..][0..8], .little); + } + return montgomeryMul(.{ .limbs = raw }, .{ .limbs = r2 }); + } + + /// Create from big-endian bytes (32 / 48 / 64 byte pubkeys etc.) + pub fn fromBytesBE(bytes: *const [N * 8]u8) Self { + var le_bytes: [N * 8]u8 = undefined; + for (0..N * 8) |i| { + le_bytes[i] = bytes[N * 8 - 1 - i]; + } + return fromBytes(&le_bytes); + } + + /// Convert a raw little-endian limb array INTO Montgomery form. + pub fn fromRaw(raw: [N]u64) Self { + return mul(.{ .limbs = raw }, .{ .limbs = r2 }); + } + + /// Alias used by the OG callers. + pub fn toMontgomery(self: Self) Self { + return mul(self, .{ .limbs = r2 }); + } + + /// Convert FROM Montgomery form back to raw little-endian limbs. + pub fn toRaw(self: Self) [N]u64 { + var one_raw: [N]u64 = .{0} ** N; + one_raw[0] = 1; + return mul(self, .{ .limbs = one_raw }).limbs; + } + + /// Alias used by the OG callers. + pub fn fromMontgomery(self: Self) Self { + var one_raw: [N]u64 = .{0} ** N; + one_raw[0] = 1; + return mul(self, .{ .limbs = one_raw }); + } + + /// Serialize to big-endian bytes, returning by value. + /// This matches the OG `BN254Scalar.toBytesBE()` signature. 
+        pub fn toBytesBE(self: Self) [N * 8]u8 {
+            var encoded: [N * 8]u8 = undefined;
+            bigint.toBytesBe(N, self.toRaw(), &encoded);
+            return encoded;
+        }
+
+        // -----------------------------------------------------------------
+        // Utility helpers used by accumulators.zig and other low-level code
+        // -----------------------------------------------------------------
+
+        /// Full 64x64 -> 128 bit product of two words. Utility, not a field operation.
+        pub inline fn mulWide(a: u64, b: u64) u128 {
+            const wide_a: u128 = a;
+            const wide_b: u128 = b;
+            return wide_a * wide_b;
+        }
+
+        /// Add with carry — public alias of the internal helper.
+        pub const addCarry = addCarryFn;
+
+        /// Subtract with borrow — public alias of the internal helper.
+        pub const subBorrow = subBorrowFn;
+
+        // -----------------------------------------------------------------
+        // Arithmetic
+        // -----------------------------------------------------------------
+
+        /// One step of a carry chain: `aa + bb + carry_in`, returned as the
+        /// low word plus the carry-out. Goes through `x86_intrinsics.addcarry`
+        /// when that is enabled and we are not evaluating at comptime;
+        /// otherwise the sum is formed in a u128.
+        inline fn addCarryFn(aa: u64, bb: u64, carry_in: u64) struct { result: u64, carry: u64 } {
+            if (!@inComptime() and comptime x86_has_intrinsics) {
+                var low_word: u64 = undefined;
+                const carry_out = x86_intrinsics.addcarry(@truncate(carry_in), aa, bb, &low_word);
+                return .{ .result = low_word, .carry = carry_out };
+            }
+            const wide_sum: u128 = @as(u128, aa) + @as(u128, bb) + @as(u128, carry_in);
+            return .{
+                .result = @truncate(wide_sum),
+                .carry = @truncate(wide_sum >> 64),
+            };
+        }
+
+        /// Borrow-chain sub helper. Uses LLVM sbb intrinsic on x86, u128 elsewhere.
+        inline fn subBorrowFn(aa: u64, bb: u64, borrow_in: u64) struct { result: u64, borrow: u64 } {
+            if (!@inComptime() and comptime x86_has_intrinsics) {
+                var low_word: u64 = undefined;
+                const borrow_out = x86_intrinsics.subborrow(@truncate(borrow_in), aa, bb, &low_word);
+                return .{ .result = low_word, .borrow = borrow_out };
+            }
+            // Portable path: wrapping 128-bit subtraction; the top bit of the
+            // wrapped difference is the borrow-out.
+            const subtrahend = @as(u128, bb) + @as(u128, borrow_in);
+            const wrapped = @as(u128, aa) -% subtrahend;
+            return .{
+                .result = @truncate(wrapped),
+                .borrow = @truncate(wrapped >> 127),
+            };
+        }
+
+        /// Modular addition `(a + b) mod p`.
+        pub inline fn add(a: Self, b: Self) Self {
+            @setEvalBranchQuota(10000);
+            if (N == 4 and !@inComptime() and comptime use_arm64_asm) {
+                // The asm returns the raw 256-bit sum; one conditional
+                // subtraction brings it back below the modulus.
+                return (Self{ .limbs = asm_mod.arm64Add256(a.limbs, b.limbs) }).reduce();
+            }
+            var sum = Self{ .limbs = undefined };
+            var carry_word: u64 = 0;
+            inline for (0..N) |k| {
+                const step = addCarryFn(a.limbs[k], b.limbs[k], carry_word);
+                sum.limbs[k] = step.result;
+                carry_word = step.carry;
+            }
+            const needs_sub = carry_word != 0 or !sum.lessThanModulus();
+            return if (needs_sub) sum.subtractModulus() else sum;
+        }
+
+        /// Modular subtraction `(a - b) mod p`.
+        pub inline fn sub(a: Self, b: Self) Self {
+            @setEvalBranchQuota(10000);
+            if (N == 4 and !@inComptime() and comptime use_arm64_asm)
+                return .{ .limbs = asm_mod.arm64SubMod256(a.limbs, b.limbs, modulus) };
+
+            var diff: [N]u64 = undefined;
+            var borrow_word: u64 = 0;
+            inline for (0..N) |k| {
+                const step = subBorrowFn(a.limbs[k], b.limbs[k], borrow_word);
+                diff[k] = step.result;
+                borrow_word = step.borrow;
+            }
+            // No borrow: the plain difference is already the answer.
+            if (borrow_word == 0) return .{ .limbs = diff };
+
+            // The subtraction wrapped: add the modulus back.
+            var carry_word: u64 = 0;
+            inline for (0..N) |k| {
+                const step = addCarryFn(diff[k], modulus[k], carry_word);
+                diff[k] = step.result;
+                carry_word = step.carry;
+            }
+            return .{ .limbs = diff };
+        }
+
+        /// Modular negation `-a mod p`.
+        pub inline fn neg(a: Self) Self {
+            // Zero is its own negation; `p - 0` would otherwise yield the
+            // non-canonical value `p`.
+            if (a.isZero()) return zero();
+            return (Self{ .limbs = modulus }).sub(a);
+        }
+
+        /// Squaring with asm dispatch for N=4.
+        /// Falls back to the portable `montgomeryMul(a, a)` at comptime, for
+        /// other limb counts, or when no asm path is enabled.
+        pub inline fn square(a: Self) Self {
+            if (N == 4 and !@inComptime()) {
+                if (comptime use_arm64_asm) {
+                    const mod_arr: [4]u64 = modulus;
+                    var r = Self{ .limbs = asm_mod.arm64MontgomerySquare256(&a.limbs, &mod_arr, n_prime) };
+                    if (!r.lessThanModulus()) r = r.subtractModulus();
+                    return r;
+                }
+                // x86 BMI2+ADX square uses the mul path (mul(a,a) is fine
+                // since the x86 mul is already fast; dedicated square asm
+                // can be added later for ~15% additional gain).
+                if (comptime use_asm_mul) return x86MontgomeryMul4(a, a);
+            }
+            return montgomeryMul(a, a);
+        }
+
+        /// Exponentiation `a^e mod p` by MSB-first square-and-multiply.
+        /// `exponent` is a raw little-endian limb array (NOT in Montgomery form).
+        /// Returns `one()` when `bigint.bitLen` reports 0 for the exponent.
+        /// NOTE(review): assumes `bigint.bitLen` returns the number of
+        /// significant bits of the exponent — confirm against bigint.zig.
+        pub fn powLimbs(a: Self, exponent: [N]u64) Self {
+            const top_bit = bigint.bitLen(N, exponent);
+            if (top_bit == 0) return one();
+            // The most significant set bit is consumed by starting from `a`.
+            var result = a;
+            var i = top_bit - 1;
+            while (i > 0) {
+                i -= 1;
+                result = square(result);
+                const limb = i / 64;
+                const bit: u6 = @intCast(i % 64);
+                if (((exponent[limb] >> bit) & 1) == 1) {
+                    // Use the dispatched `mul` so the multiply step takes the
+                    // same asm path as `square` above.
+                    result = mul(result, a);
+                }
+            }
+            return result;
+        }
+
+        /// Modular inverse via Fermat's little theorem (`a^{p-2} mod p`),
+        /// returning null for zero.
+        pub fn inverse(a: Self) ?Self {
+            if (a.isZero()) return null;
+
+            // p - 2, computed at comptime with borrow propagation so that
+            // moduli whose low limb is 1 do not overflow the subtraction.
+            const p_minus_two: [N]u64 = comptime blk: {
+                var exp: [N]u64 = modulus;
+                var borrow: u64 = 2;
+                for (&exp) |*limb| {
+                    const diff = @subWithOverflow(limb.*, borrow);
+                    limb.* = diff[0];
+                    borrow = diff[1];
+                    if (borrow == 0) break;
+                }
+                break :blk exp;
+            };
+
+            var result = Self.one();
+            var base = a;
+
+            // LSB-first square-and-multiply. Only the limb loop is unrolled;
+            // the 64-bit scan stays a runtime loop to keep the code compact.
+            inline for (0..N) |i| {
+                var bits = p_minus_two[i];
+                var j: usize = 0;
+                while (j < 64) : (j += 1) {
+                    if ((bits & 1) != 0) {
+                        result = result.mul(base);
+                    }
+                    base = base.square();
+                    bits >>= 1;
+                }
+            }
+            return result;
+        }
+
+        /// Non-optional inverse (returns zero for zero).
+        /// Matches the standalone zolt-arith naming convention used by BLS12-381 code.
+        pub fn inv(a: Self) Self {
+            return a.inverse() orelse zero();
+        }
+
+        /// CIOS Montgomery multiplication: `a * b * R^{-1} mod p`.
+        /// Uses an [N+1]u64 accumulator matching the original factory. A
+        /// carry out of the top accumulator word during the multiply step
+        /// is kept in a side word and folded back after the reduction
+        /// shift, so it is not silently dropped for moduli very close to
+        /// `2^(64*N)`.
+        pub fn montgomeryMul(a: Self, b: Self) Self {
+            var t: [N + 1]u64 = .{0} ** (N + 1);
+
+            inline for (0..N) |i| {
+                // Multiply step: t += a[i] * b.
+                var carry: u64 = 0;
+                inline for (0..N) |j| {
+                    const prod = @as(u128, a.limbs[i]) * @as(u128, b.limbs[j]);
+                    const s = @as(u128, t[j]) + prod + @as(u128, carry);
+                    t[j] = @truncate(s);
+                    carry = @truncate(s >> 64);
+                }
+                const sum_tn = @as(u128, t[N]) + @as(u128, carry);
+                t[N] = @truncate(sum_tn);
+                // Carry out of word N; after the one-word shift below it
+                // belongs to word N again.
+                const overflow: u64 = @truncate(sum_tn >> 64);
+
+                // Reduce: m = t[0] * n_prime mod 2^64
+                const m = t[0] *% n_prime;
+                carry = 0;
+                {
+                    // Low word of t + m*p is zero by construction; only its
+                    // carry is needed.
+                    const prod0 = @as(u128, m) * @as(u128, modulus[0]);
+                    const s0 = @as(u128, t[0]) + prod0;
+                    carry = @truncate(s0 >> 64);
+                }
+                inline for (1..N) |j| {
+                    const prod = @as(u128, m) * @as(u128, modulus[j]);
+                    const s = @as(u128, t[j]) + prod + @as(u128, carry);
+                    t[j - 1] = @truncate(s);
+                    carry = @truncate(s >> 64);
+                }
+                const final_sum = @as(u128, t[N]) + @as(u128, carry);
+                t[N - 1] = @truncate(final_sum);
+                t[N] = @as(u64, @truncate(final_sum >> 64)) +% overflow;
+            }
+
+            var result = Self{ .limbs = undefined };
+            inline for (0..N) |i| result.limbs[i] = t[i];
+            if (t[N] != 0 or !result.lessThanModulus()) return result.subtractModulus();
+            return result;
+        }
+
+        /// Alias matching the standalone's name.
+        pub const montMul = montgomeryMul;
+
+        /// R^2 as a field element (for compatibility with the OG
+        /// BN254Scalar.rSquared). Returns `fromRaw(R2)`, i.e. the raw value
+        /// `R^2 mod p` converted into Montgomery form.
+        pub fn rSquared() Self {
+            return fromRaw(r2);
+        }
+
+        // -----------------------------------------------------------------
+        // Aliases and JoltField-compatible surface
+        // -----------------------------------------------------------------
+
+        /// Field multiplication with asm dispatch for N=4.
+        /// Chooses the implementation at comptime: for N == 4 outside of
+        /// comptime evaluation it calls the AArch64 routine (followed by one
+        /// conditional subtraction) or the x86 BMI2+ADX routine when the
+        /// corresponding flag is set; in every other case it falls back to
+        /// the portable `montgomeryMul`.
+        pub inline fn mul(a: Self, b: Self) Self {
+            if (N == 4 and !@inComptime()) {
+                if (comptime use_arm64_asm) {
+                    const mod_arr: [4]u64 = modulus;
+                    var r = Self{ .limbs = asm_mod.arm64MontgomeryMul256(&a.limbs, &b.limbs, &mod_arr, n_prime) };
+                    if (!r.lessThanModulus()) r = r.subtractModulus();
+                    return r;
+                }
+                // x86 BMI2+ADX: use the asm from the old MontgomeryField factory
+                if (comptime use_asm_mul) return x86MontgomeryMul4(a, b);
+            }
+            return montgomeryMul(a, b);
+        }
+
+        /// x86 BMI2+ADX CIOS Montgomery mul (N=4 only).
+        ///
+        /// Register contract: `rdi` -> limbs of `a`, `rsi` -> limbs of `b`,
+        /// `r14` -> modulus limbs, `rbx` = `n_prime`; the four result limbs
+        /// are returned in `r8`..`r11`. `rax`, `rcx`, `rdx`, `r13`, the flags
+        /// and memory are declared as clobbers.
+        ///
+        /// The body is four rounds. Each round loads one limb of `a` into
+        /// `rdx`, multiplies it by the four limbs of `b` with `mulx`, then
+        /// forms `m = rbx * (low accumulator word)` and adds `m` times the
+        /// four modulus limbs, using the `adcx` / `adox` carry chains. One
+        /// conditional subtraction of the modulus follows the asm.
+        ///
+        /// NOTE(review): the flags left by the last `adcx` / `adox` of a
+        /// round are consumed by the first `adcx` / `adox` of the next one,
+        /// so the sequence appears to assume that no carry leaves the top
+        /// accumulator word — confirm that every modulus this path is
+        /// dispatched for satisfies that.
+        fn x86MontgomeryMul4(a_self: Self, b_other: Self) Self {
+            const a = a_self.limbs;
+            const b = b_other.limbs;
+            const mod_arr: [4]u64 = modulus;
+            var out0: u64 = undefined;
+            var out1: u64 = undefined;
+            var out2: u64 = undefined;
+            var out3: u64 = undefined;
+            asm volatile (
+                \\xorq %%rcx, %%rcx
+                \\movq (%%rdi), %%rdx
+                \\mulxq (%%rsi), %%r8, %%r9
+                \\mulxq 8(%%rsi), %%rax, %%r10
+                \\adcxq %%rax, %%r9
+                \\mulxq 16(%%rsi), %%rax, %%r11
+                \\adcxq %%rax, %%r10
+                \\mulxq 24(%%rsi), %%rax, %%rcx
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r11
+                \\adcxq %%r13, %%rcx
+                \\movq %%rbx, %%rdx
+                \\mulxq %%r8, %%rdx, %%rax
+                \\mulxq (%%r14), %%rax, %%r13
+                \\adcxq %%r8, %%rax
+                \\adoxq %%r13, %%r9
+                \\mulxq 8(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%r13, %%r10
+                \\mulxq 16(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%r13, %%r11
+                \\mulxq 24(%%r14), %%rax, %%r8
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%rcx, %%r8
+                \\adcxq %%r13, %%r8
+                \\movq 8(%%rdi), %%rdx
+                \\mulxq (%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%r13, %%r10
+                \\mulxq 8(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%r13, %%r11
+                \\mulxq 16(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%r13, %%r8
+                \\mulxq 24(%%rsi), %%rax, %%rcx
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%r13, %%rcx
+                \\adcxq %%r13, %%rcx
+                \\movq %%rbx, %%rdx
+                \\mulxq %%r9, %%rdx, %%rax
+                \\mulxq (%%r14), %%rax, %%r13
+                \\adcxq %%r9, %%rax
+                \\adoxq %%r13, %%r10
+                \\mulxq 8(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%r13, %%r11
+                \\mulxq 16(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%r13, %%r8
+                \\mulxq 24(%%r14), %%rax, %%r9
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%rcx, %%r9
+                \\adcxq %%r13, %%r9
+                \\movq 16(%%rdi), %%rdx
+                \\mulxq (%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%r13, %%r11
+                \\mulxq 8(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%r13, %%r8
+                \\mulxq 16(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%r13, %%r9
+                \\mulxq 24(%%rsi), %%rax, %%rcx
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%r13, %%rcx
+                \\adcxq %%r13, %%rcx
+                \\movq %%rbx, %%rdx
+                \\mulxq %%r10, %%rdx, %%rax
+                \\mulxq (%%r14), %%rax, %%r13
+                \\adcxq %%r10, %%rax
+                \\adoxq %%r13, %%r11
+                \\mulxq 8(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%r13, %%r8
+                \\mulxq 16(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%r13, %%r9
+                \\mulxq 24(%%r14), %%rax, %%r10
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%rcx, %%r10
+                \\adcxq %%r13, %%r10
+                \\movq 24(%%rdi), %%rdx
+                \\mulxq (%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r11
+                \\adoxq %%r13, %%r8
+                \\mulxq 8(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%r13, %%r9
+                \\mulxq 16(%%rsi), %%rax, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%r13, %%r10
+                \\mulxq 24(%%rsi), %%rax, %%rcx
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%r13, %%rcx
+                \\adcxq %%r13, %%rcx
+                \\movq %%rbx, %%rdx
+                \\mulxq %%r11, %%rdx, %%rax
+                \\mulxq (%%r14), %%rax, %%r13
+                \\adcxq %%r11, %%rax
+                \\adoxq %%r13, %%r8
+                \\mulxq 8(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r8
+                \\adoxq %%r13, %%r9
+                \\mulxq 16(%%r14), %%rax, %%r13
+                \\adcxq %%rax, %%r9
+                \\adoxq %%r13, %%r10
+                \\mulxq 24(%%r14), %%rax, %%r11
+                \\movq $0, %%r13
+                \\adcxq %%rax, %%r10
+                \\adoxq %%rcx, %%r11
+                \\adcxq %%r13, %%r11
+                : [_r0] "={r8}" (out0),
+                  [_r1] "={r9}" (out1),
+                  [_r2] "={r10}" (out2),
+                  [_r3] "={r11}" (out3),
+                : [_a] "{rdi}" (&a),
+                  [_b] "{rsi}" (&b),
+                  [_mod] "{r14}" (&mod_arr),
+                  [_inv] "{rbx}" (n_prime),
+                : .{ .rax = true, .rcx = true, .rdx = true, .r13 = true, .cc = true, .memory = true });
+            var result = Self{ .limbs = .{ out0, out1, out2, out3 } };
+            // The asm leaves the result unreduced by at most one modulus.
+            if (!result.lessThanModulus()) result = result.subtractModulus();
+            return result;
+        }
+
+        /// Doubling (2*self).
+        pub inline fn double(self: Self) Self {
+            return self.add(self);
+        }
+
+        /// Serialize to little-endian bytes, returning by value.
+        /// The element is first converted out of Montgomery form.
+        pub fn toBytes(self: Self) [N * 8]u8 {
+            var out: [N * 8]u8 = undefined;
+            const raw = toRaw(self);
+            bigint.toBytesLe(N, raw, &out);
+            return out;
+        }
+
+        /// Extract the low u64 from the raw (non-Montgomery) form.
+        pub fn toU64(self: Self) u64 {
+            return toRaw(self)[0];
+        }
+
+        /// Field element byte count.
+        pub const FIELD_ELEMENT_BYTES: usize = N * 8;
+
+        // -----------------------------------------------------------------
+        // Helpers for reduction / unreduced paths (used by OG callers)
+        // -----------------------------------------------------------------
+
+        /// Check if the limbs are strictly less than the modulus (MSB-first).
+        /// Unrolled comparison matching the original MontgomeryField factory.
+        inline fn lessThanModulus(self: Self) bool {
+            // Unrolled MSB-first comparison — const modulus enables LLVM to
+            // emit cmpq with immediates.
+            comptime var idx: usize = N;
+            inline while (idx > 0) {
+                idx -= 1;
+                if (self.limbs[idx] != modulus[idx]) return self.limbs[idx] < modulus[idx];
+            }
+            // Reached only when every limb equals the modulus limb, so this
+            // comparison yields false (equal is not "less than").
+            return self.limbs[0] < modulus[0];
+        }
+
+        /// Unconditional subtraction of the modulus.
+        pub inline fn subtractModulus(self: Self) Self {
+            @setEvalBranchQuota(10000);
+            if (N == 4 and !@inComptime() and comptime use_arm64_asm)
+                return .{ .limbs = asm_mod.arm64Sub256(self.limbs, modulus) };
+
+            // Portable borrow chain; the final borrow-out is discarded, so
+            // the result wraps modulo 2^(64*N).
+            var diff = Self{ .limbs = undefined };
+            var borrow_word: u64 = 0;
+            inline for (0..N) |k| {
+                const step = subBorrowFn(self.limbs[k], modulus[k], borrow_word);
+                diff.limbs[k] = step.result;
+                borrow_word = step.borrow;
+            }
+            return diff;
+        }
+
+        /// Addition without final reduction. Result in [0, 2p).
+        /// The modulus is subtracted only when the limb addition carries out.
+        pub inline fn addNoReduce(self: Self, other: Self) Self {
+            @setEvalBranchQuota(10000);
+            var sum = Self{ .limbs = undefined };
+            var carry_word: u64 = 0;
+            inline for (0..N) |k| {
+                const step = addCarryFn(self.limbs[k], other.limbs[k], carry_word);
+                sum.limbs[k] = step.result;
+                carry_word = step.carry;
+            }
+            return if (carry_word != 0) sum.subtractModulus() else sum;
+        }
+
+        /// Reduce from [0, 2p) to [0, p) with one conditional subtraction.
+        pub inline fn reduce(self: Self) Self {
+            return if (self.lessThanModulus()) self else self.subtractModulus();
+        }
+
+        // (inverseOpt removed — `inverse()` itself returns `?Self` now)
+
+        // -----------------------------------------------------------------
+        // Extended arithmetic (used by zolt's polynomial / MSM / accumulator code)
+        // -----------------------------------------------------------------
+
+        /// Fused multiply-accumulate: a[0]*b[0] + a[1]*b[1] with 2 reductions
+        /// instead of 3 (vs separate mul + mul + add). Interleaved CIOS.
+        ///
+        /// For every limb index `i` both products contribute their row
+        /// `a_pair[k].limbs[i] * b_pair[k]` to the shared accumulator `t`
+        /// before a single Montgomery reduction step is applied.
+        pub inline fn sumOfProducts(a_pair: [2]Self, b_pair: [2]Self) Self {
+            var t: [N + 1]u64 = .{0} ** (N + 1);
+
+            inline for (0..N) |i| {
+                // carry1 collects what spills out of the top word t[N] while
+                // the two rows are added.
+                // NOTE(review): the carry1 produced by the first pair is added
+                // into t[N] again by the second pair; that is only
+                // weight-correct while it is zero — confirm for moduli
+                // without spare high bits.
+                var carry1: u64 = 0;
+                inline for (0..2) |pair| {
+                    var carry: u64 = 0;
+                    inline for (0..N) |j| {
+                        const prod = @as(u128, a_pair[pair].limbs[i]) * @as(u128, b_pair[pair].limbs[j]);
+                        const s = @as(u128, t[j]) + prod + @as(u128, carry);
+                        t[j] = @truncate(s);
+                        carry = @truncate(s >> 64);
+                    }
+                    const s_tn = @as(u128, t[N]) + @as(u128, carry) + @as(u128, carry1);
+                    t[N] = @truncate(s_tn);
+                    carry1 = @truncate(s_tn >> 64);
+                }
+
+                // Montgomery reduction step (shared)
+                const m = t[0] *% n_prime;
+                var carry: u64 = 0;
+                {
+                    // Only the carry of the lowest word is kept; the word
+                    // itself is shifted out.
+                    const prod0 = @as(u128, m) * @as(u128, modulus[0]) + @as(u128, t[0]);
+                    carry = @truncate(prod0 >> 64);
+                }
+                inline for (1..N) |j| {
+                    // Add m * modulus[j] and store one word lower (the shift).
+                    const prod = @as(u128, m) * @as(u128, modulus[j]) + @as(u128, t[j]) + @as(u128, carry);
+                    t[j - 1] = @truncate(prod);
+                    carry = @truncate(prod >> 64);
+                }
+                const final_sum = @as(u128, t[N]) + @as(u128, carry);
+                t[N - 1] = @truncate(final_sum);
+                // After the shift the deferred carry1 lands in the top word.
+                t[N] = @as(u64, @truncate(final_sum >> 64)) +% carry1;
+            }
+
+            var result: Self = undefined;
+            inline for (0..N) |i| result.limbs[i] = t[i];
+            // A single conditional subtraction is applied to the accumulator.
+            if (t[N] != 0 or !result.lessThanModulus()) result = result.subtractModulus();
+            return result;
+        }
+
+        /// Multiply by a signed 128-bit integer.
+        pub fn mulI128(self: Self, val: i128) Self {
+            if (val == 0) return Self.zero();
+            if (val == 1) return self;
+            if (val == -1) return self.neg();
+            // `@abs` yields the magnitude as u128, so `minInt(i128)` is
+            // handled without the overflow that negating it would cause.
+            const magnitude: u128 = @abs(val);
+            const product = self.mulU128(magnitude);
+            return if (val < 0) product.neg() else product;
+        }
+
+        /// Multiply by an unsigned 128-bit integer (helper for `mulI128`).
+        fn mulU128(self: Self, val: u128) Self {
+            if (val == 0) return Self.zero();
+            if (val == 1) return self;
+            // With two or more limbs the whole value fits the raw
+            // representation, so a single conversion replaces the
+            // low/high split and the 64 doublings below.
+            if (comptime N >= 2) return self.mul(Self.fromU128(val));
+
+            // Single-limb field: val = high * 2^64 + low, assembled in the field.
+            const low: u64 = @truncate(val);
+            const high: u64 = @truncate(val >> 64);
+            var other = Self.fromU64(low);
+            if (high != 0) {
+                const high_fe = Self.fromU64(high);
+                var two_64 = Self.fromU64(1);
+                for (0..64) |_| two_64 = two_64.double();
+                other = other.add(high_fe.mul(two_64));
+            }
+            return self.mul(other);
+        }
+
+        /// Exponentiation with a u64 exponent (LSB-first square-and-multiply).
+        /// Returns `one()` for a zero exponent.
+        pub fn pow(self: Self, exp: u64) Self {
+            if (exp == 0) return one();
+            var result = Self.one();
+            var base = self;
+            var e = exp;
+            while (e > 0) {
+                if (e & 1 == 1) result = result.mul(base);
+                base = base.square();
+                e >>= 1;
+            }
+            return result;
+        }
+
+        /// Batch inversion using Montgomery's trick.
+        /// Inverts n elements in-place with 1 inversion + 3(n-1) muls.
+        /// Zero elements are skipped and left as zero. `scratch` receives
+        /// the running prefix products and must hold at least
+        /// `elements.len` entries.
+        pub fn batchInversion(elements: []Self, scratch: []Self) void {
+            const n = elements.len;
+            if (n == 0) return;
+            std.debug.assert(scratch.len >= n);
+            // Forward pass: scratch[i] = product of the non-zero elements before i.
+            var acc = one();
+            for (0..n) |i| {
+                scratch[i] = acc;
+                if (!elements[i].isZero()) acc = acc.mul(elements[i]);
+            }
+            // `acc` is a product of non-zero elements (or `one()`), hence invertible.
+            var inv_acc = acc.inverse() orelse unreachable;
+            // Backward pass: peel one element off the running inverse per step.
+            var i: usize = n;
+            while (i > 0) {
+                i -= 1;
+                if (elements[i].isZero()) continue;
+                const old = elements[i];
+                elements[i] = scratch[i].mul(inv_acc);
+                inv_acc = inv_acc.mul(old);
+            }
+        }
+
+        /// Optimized multiplication by a 128-bit value in the top 2 limbs
+        /// of a [N]u64 array. Only the limbs at indices [N-2] and [N-1] are
+        /// non-zero. Uses 2 iterations of CIOS instead of N.
+        pub fn mulHiBigIntU128(self: Self, hi_limbs: [N]u64) Self {
+            // Process only the top 2 non-zero limbs
+            const limb_n2 = hi_limbs[N - 2];
+            const limb_n1 = hi_limbs[N - 1];
+
+            // Accumulator starts at zero; the skipped low limbs of
+            // `hi_limbs` are not read at all.
+            var r: [N]u64 = .{0} ** N;
+
+            // Two rounds, first for the limb at index N-2 and then for the
+            // limb at index N-1. Each round interleaves the multiply row
+            // (carry1) with the reduction row (carry2).
+            inline for ([_]u64{ limb_n2, limb_n1 }) |limb_val| {
+                var carry1: u64 = 0;
+                {
+                    const prod0 = mulWide(self.limbs[0], limb_val);
+                    const sum0 = @as(u128, r[0]) + prod0 + @as(u128, carry1);
+                    r[0] = @truncate(sum0);
+                    carry1 = @truncate(sum0 >> 64);
+                }
+                // Reduction factor chosen from the updated low word.
+                const k = r[0] *% n_prime;
+                var carry2: u64 = 0;
+                {
+                    // Only the carry of the lowest word is kept; the word
+                    // itself is shifted out.
+                    const red0 = mulWide(k, modulus[0]);
+                    const red_sum0 = @as(u128, r[0]) + red0;
+                    carry2 = @truncate(red_sum0 >> 64);
+                }
+                inline for (1..N) |j| {
+                    const prod_j = mulWide(self.limbs[j], limb_val);
+                    const new_rj = @as(u128, r[j]) + prod_j + @as(u128, carry1);
+                    const new_rj_trunc: u64 = @truncate(new_rj);
+                    carry1 = @truncate(new_rj >> 64);
+
+                    // Add k * modulus[j] and store one word lower (the shift).
+                    const red_j = mulWide(k, modulus[j]);
+                    const red_sum_j = @as(u128, new_rj_trunc) + red_j + @as(u128, carry2);
+                    r[j - 1] = @truncate(red_sum_j);
+                    carry2 = @truncate(red_sum_j >> 64);
+
+                    r[j] = new_rj_trunc;
+                }
+                // Top word of the round: the two outgoing carries, added with wrap-around.
+                r[N - 1] = carry1 +% carry2;
+            }
+
+            var result = Self{ .limbs = r };
+            if (!result.lessThanModulus()) result = result.subtractModulus();
+            return result;
+        }
+
+        /// Unreduced product accumulator — defers Montgomery reduction
+        /// across multiple multiply-accumulate steps. The generic version
+        /// uses `[2*N]u128` slots matching the OG `UnreducedProductAccum`.
+        pub const ProductAccum = struct {
+            slots: [2 * N]u128,
+
+            /// Accumulator with every slot set to zero.
+            pub inline fn zero() @This() {
+                return .{ .slots = .{0} ** (2 * N) };
+            }
+
+            /// Schoolbook N×N without reduction.
+            pub inline fn fromMul(a: Self, b: Self) @This() {
+                @setEvalBranchQuota(10000);
+                var acc: [2 * N]u128 = .{0} ** (2 * N);
+                inline for (0..N) |row| {
+                    inline for (0..N) |col| {
+                        // Split each 128-bit partial product into its two
+                        // words and add them to neighbouring slots; carries
+                        // between slots stay deferred.
+                        const wide: u128 = @as(u128, a.limbs[row]) * @as(u128, b.limbs[col]);
+                        const lo_word: u64 = @truncate(wide);
+                        const hi_word: u64 = @truncate(wide >> 64);
+                        acc[row + col] += lo_word;
+                        acc[row + col + 1] += hi_word;
+                    }
+                }
+                return .{ .slots = acc };
+            }
+
+            /// Schoolbook N×2 (field_elem × raw u128) without reduction.
+            pub inline fn fromMulU128(field_elem: Self, raw: u128) @This() {
+                @setEvalBranchQuota(10000);
+                const lhs = field_elem.limbs;
+                const rhs: [2]u64 = .{ @truncate(raw), @truncate(raw >> 64) };
+                var acc: [2 * N]u128 = .{0} ** (2 * N);
+                inline for (0..N) |row| {
+                    inline for (0..2) |col| {
+                        const wide: u128 = @as(u128, lhs[row]) * @as(u128, rhs[col]);
+                        const lo_word: u64 = @truncate(wide);
+                        const hi_word: u64 = @truncate(wide >> 64);
+                        acc[row + col] += lo_word;
+                        acc[row + col + 1] += hi_word;
+                    }
+                }
+                return .{ .slots = acc };
+            }
+
+            /// Slot-wise in-place accumulation of `other` into `self`.
+            pub inline fn addAssign(self: *@This(), other: @This()) void {
+                for (&self.slots, other.slots) |*dst, src| {
+                    dst.* += src;
+                }
+            }
+
+            /// Slot-wise sum of two accumulators, returned by value.
+            pub inline fn add(self: @This(), other: @This()) @This() {
+                var total = self;
+                total.addAssign(other);
+                return total;
+            }
+
+            /// Reduce accumulated products to a field element via Montgomery
+            /// reduction. Handles overflow from heavy accumulation (up to
+            /// ~2^62 products).
+            pub fn reduce(self: @This()) Self {
+                // Step 1: Normalize [2N]u128 → [2N]u64 + overflow
+                // Propagates the deferred carries so that every slot becomes
+                // one 64-bit word; whatever is left after the last slot is
+                // kept as overflow.
+                var limbs_wide: [2 * N]u64 = undefined;
+                var carry: u128 = 0;
+                inline for (0..(2 * N)) |i| {
+                    const s = self.slots[i] + carry;
+                    limbs_wide[i] = @truncate(s);
+                    carry = s >> 64;
+                }
+                const overflow_lo: u64 = @truncate(carry);
+                const overflow_hi: u64 = @truncate(carry >> 64);
+
+                // Step 2: CIOS Montgomery reduction, folding upper limbs
+                // t starts as the low N words; round i shifts t down by one
+                // word and feeds limbs_wide[i + N] in at the top.
+                var t: [N + 1]u64 = undefined;
+                inline for (0..N) |i| t[i] = limbs_wide[i];
+                t[N] = 0;
+
+                inline for (0..N) |i| {
+                    const m = t[0] *% n_prime;
+                    var c: u64 = 0;
+                    {
+                        // Only the carry of the lowest word is kept.
+                        const prod0 = @as(u128, m) * @as(u128, modulus[0]) + @as(u128, t[0]);
+                        c = @truncate(prod0 >> 64);
+                    }
+                    inline for (1..N) |j| {
+                        const prod = @as(u128, m) * @as(u128, modulus[j]) + @as(u128, t[j]) + @as(u128, c);
+                        t[j - 1] = @truncate(prod);
+                        c = @truncate(prod >> 64);
+                    }
+                    const final_sum = @as(u128, t[N]) + @as(u128, c) + @as(u128, limbs_wide[i + N]);
+                    t[N - 1] = @truncate(final_sum);
+                    t[N] = @truncate(final_sum >> 64);
+                }
+
+                // Step 3: Multi-subtract for accumulated overflow
+                // The value is `extra * 2^(64*N) + result`. Subtracting the
+                // modulus from a `result` that is below it wraps, which is
+                // accounted for by taking one unit from `extra`.
+                var result = Self{ .limbs = undefined };
+                inline for (0..N) |i| result.limbs[i] = t[i];
+                var extra = t[N];
+                var iters: u32 = 0;
+                while (extra != 0 or !result.lessThanModulus()) : (iters += 1) {
+                    // Checked bound on the number of subtractions.
+                    std.debug.assert(iters < N + 2);
+                    const was_less = result.lessThanModulus();
+                    result = result.subtractModulus();
+                    if (was_less) extra -= 1;
+                }
+
+                // Step 4: Add overflow contribution
+                // The words carried out of the top slot are converted with
+                // `toMontgomery` and added to the reduced result.
+                if (overflow_lo != 0 or overflow_hi != 0) {
+                    var raw: [N]u64 = .{0} ** N;
+                    raw[0] = overflow_lo;
+                    if (N > 1) raw[1] = overflow_hi;
+                    result = Self.add(result, Self.toMontgomery(.{ .limbs = raw }));
+                }
+
+                return result;
+            }
+        };
+
+        /// Create an unreduced product accumulator from `self * other`.
+        pub inline fn mulToProductAccum(self: Self, other: Self) ProductAccum {
+            return ProductAccum.fromMul(self, other);
+        }
+
+        /// Debug formatter for field elements.
+        /// Writes `0x` followed by the raw (non-Montgomery) limbs, most
+        /// significant first, each as 16 zero-padded lower-case hex digits.
+        /// NOTE(review): this is the `format(fmt, options, writer)` form of
+        /// the std.fmt hook — confirm it matches the std.fmt API of the Zig
+        /// version this package targets.
+        pub fn format(self: Self, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
+            _ = fmt;
+            _ = options;
+            const raw = toRaw(self);
+            try writer.writeAll("0x");
+            var i: usize = N;
+            while (i > 0) {
+                i -= 1;
+                try std.fmt.formatInt(raw[i], 16, .lower, .{ .width = 16, .fill = '0' }, writer);
+            }
+        }
+    };
+}
+
+
+// =========================================================================
+// Tests
+// =========================================================================
+
+const testing = std.testing;
+
+// Toy 4-limb prime: p = 2^255 - 19 (Curve25519's base field).
+const ED25519_P: [4]u64 = .{
+    0xffffffffffffffed, 0xffffffffffffffff,
+    0xffffffffffffffff, 0x7fffffffffffffff,
+};
+const ED25519_R2: [4]u64 = .{ 0x5a4, 0, 0, 0 };
+const ED25519_N_PRIME: u64 = 0x86bca1af286bca1b;
+const Ed25519Fp = MontgomeryField(4, ED25519_P, ED25519_R2, ED25519_N_PRIME);
+
+// BLS12-381 Fp (6-limb) constants for cross-checking the N=6 path.
+const BLS12_381_FP_MODULUS: [6]u64 = .{
+    0xb9feffffffffaaab, 0x1eabfffeb153ffff,
+    0x6730d2a0f6b0f624, 0x64774b84f38512bf,
+    0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a,
+};
+const BLS12_381_FP_R2: [6]u64 = .{
+    0xf4df1f341c341746, 0x0a76e6a609d104f1,
+    0x8de5476c4c95b6d5, 0x67eb88a9939d83c0,
+    0x9a793e85b519952d, 0x11988fe592cae3aa,
+};
+const BLS12_381_FP_N_PRIME: u64 = 0x89f3fffcfffcfffd;
+const Bls12381Fp = MontgomeryField(6, BLS12_381_FP_MODULUS, BLS12_381_FP_R2, BLS12_381_FP_N_PRIME);
+
+// -- N=4 tests (Ed25519 base field) --
+
+test "N=4 zero is additive identity" {
+    const z = Ed25519Fp.zero();
+    const o = Ed25519Fp.one();
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(z, z), z));
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(o, z), o));
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(z, o), o));
+}
+
+test "N=4 one is multiplicative identity" {
+    const o = Ed25519Fp.one();
+    const a = Ed25519Fp.fromRaw(.{ 0x1234567890abcdef, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montgomeryMul(a, o), a));
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montgomeryMul(o, a), a));
+}
+
+test "N=4 toRaw round-trips fromRaw" {
+    const raw: [4]u64 = .{ 0x0102030405060708, 0x0a0b0c0d0e0f0001, 0x1122334455667788, 0x12345678 };
+    const m = Ed25519Fp.fromRaw(raw);
+    try testing.expectEqual(raw, Ed25519Fp.toRaw(m));
+}
+
+test "N=4 add wraps around modulus" {
+    const o = Ed25519Fp.fromRaw(.{ 1, 0, 0, 0 });
+    const pm1 = Ed25519Fp.fromRaw(.{ ED25519_P[0] - 1, ED25519_P[1], ED25519_P[2], ED25519_P[3] });
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.add(o, pm1), Ed25519Fp.zero()));
+}
+
+test "N=4 sub borrows correctly" {
+    const z = Ed25519Fp.zero();
+    const o = Ed25519Fp.fromRaw(.{ 1, 0, 0, 0 });
+    const result = Ed25519Fp.sub(z, o);
+    const expected = Ed25519Fp.fromRaw(.{ ED25519_P[0] - 1, ED25519_P[1], ED25519_P[2], ED25519_P[3] });
+    try testing.expect(Ed25519Fp.eql(result, expected));
+}
+
+test "N=4 mul: 2 * 3 = 6" {
+    const two = Ed25519Fp.fromRaw(.{ 2, 0, 0, 0 });
+    const three = Ed25519Fp.fromRaw(.{ 3, 0, 0, 0 });
+    const six = Ed25519Fp.fromRaw(.{ 6, 0, 0, 0 });
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montgomeryMul(two, three), six));
+}
+
+test "N=4 inv: a * a^-1 = 1" {
+    const a = Ed25519Fp.fromRaw(.{ 0x12345678, 0, 0, 0 });
+    // A null inverse for a non-zero input is a failure, not a reason to skip.
+    const inv_a = Ed25519Fp.inverse(a) orelse return error.TestUnexpectedResult;
+    try testing.expect(Ed25519Fp.eql(Ed25519Fp.montgomeryMul(a, inv_a), Ed25519Fp.one()));
+}
+
+test "N=4 .limbs access" {
+    const o = Ed25519Fp.one();
+    // The struct wrapper provides `.limbs` for backward compat.
+    try testing.expect(o.limbs[0] != 0 or o.limbs[1] != 0 or o.limbs[2] != 0 or o.limbs[3] != 0);
+}
+
+// -- N=6 tests (BLS12-381 base field) --
+
+test "N=6 zero is additive identity" {
+    const z = Bls12381Fp.zero();
+    const o = Bls12381Fp.one();
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.add(z, z), z));
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.add(o, z), o));
+}
+
+test "N=6 one is multiplicative identity" {
+    const o = Bls12381Fp.one();
+    const a = Bls12381Fp.fromRaw(.{ 42, 0, 0, 0, 0, 0 });
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.montgomeryMul(a, o), a));
+}
+
+test "N=6 toRaw round-trips fromRaw" {
+    const raw: [6]u64 = .{ 0x0102030405060708, 0x0a0b0c0d0e0f0001, 0x1122334455667788, 0x12345678, 0xabcd, 0 };
+    const m = Bls12381Fp.fromRaw(raw);
+    try testing.expectEqual(raw, Bls12381Fp.toRaw(m));
+}
+
+test "N=6 add wraps around modulus" {
+    const o = Bls12381Fp.fromRaw(.{ 1, 0, 0, 0, 0, 0 });
+    const pm1 = Bls12381Fp.fromRaw(.{
+        BLS12_381_FP_MODULUS[0] - 1,
+        BLS12_381_FP_MODULUS[1],
+        BLS12_381_FP_MODULUS[2],
+        BLS12_381_FP_MODULUS[3],
+        BLS12_381_FP_MODULUS[4],
+        BLS12_381_FP_MODULUS[5],
+    });
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.add(o, pm1), Bls12381Fp.zero()));
+}
+
+test "N=6 mul: 2 * 3 = 6" {
+    const two = Bls12381Fp.fromRaw(.{ 2, 0, 0, 0, 0, 0 });
+    const three = Bls12381Fp.fromRaw(.{ 3, 0, 0, 0, 0, 0 });
+    const six = Bls12381Fp.fromRaw(.{ 6, 0, 0, 0, 0, 0 });
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.montgomeryMul(two, three), six));
+}
+
+test "N=6 inv: a * a^-1 = 1" {
+    const a = Bls12381Fp.fromRaw(.{ 0x42, 0, 0, 0, 0, 0 });
+    // A null inverse for a non-zero input is a failure, not a reason to skip.
+    const inv_a = Bls12381Fp.inverse(a) orelse return error.TestUnexpectedResult;
+    try testing.expect(Bls12381Fp.eql(Bls12381Fp.montgomeryMul(a, inv_a), Bls12381Fp.one()));
+}
+
+test "N=6 montMul alias works" {
+    const two = Bls12381Fp.fromRaw(.{ 2, 0, 0, 0, 0, 0 });
+    const three = Bls12381Fp.fromRaw(.{ 3, 0, 0, 0, 0, 0 });
+    const via_mont = Bls12381Fp.montMul(two, three);
+    const via_full = Bls12381Fp.montgomeryMul(two, three);
+    try testing.expect(Bls12381Fp.eql(via_mont, via_full));
+}
+
+// -- Cross-N test: distributive property --
+
+test "N=4 distributive: (a+b)*c == a*c + b*c" {
+    const a = Ed25519Fp.fromRaw(.{ 0x1234, 0x5678, 0, 0 });
+    const b = Ed25519Fp.fromRaw(.{ 0x9abc, 0xdef0, 0x1234, 0 });
+    const c = Ed25519Fp.fromRaw(.{ 0x42, 0, 0, 0 });
+    const lhs = Ed25519Fp.montgomeryMul(Ed25519Fp.add(a, b), c);
+    const rhs = Ed25519Fp.add(Ed25519Fp.montgomeryMul(a, c), Ed25519Fp.montgomeryMul(b, c));
+    try testing.expect(Ed25519Fp.eql(lhs, rhs));
+}
+
+test "N=6 distributive: (a+b)*c == a*c + b*c" {
+    const a = Bls12381Fp.fromRaw(.{ 0x1234, 0x5678, 0, 0, 0, 0 });
+    const b = Bls12381Fp.fromRaw(.{ 0x9abc, 0xdef0, 0x1234, 0, 0, 0 });
+    const c = Bls12381Fp.fromRaw(.{ 0x42, 0, 0, 0, 0, 0 });
+    const lhs = Bls12381Fp.montgomeryMul(Bls12381Fp.add(a, b), c);
+    const rhs = Bls12381Fp.add(Bls12381Fp.montgomeryMul(a, c), Bls12381Fp.montgomeryMul(b, c));
+    try testing.expect(Bls12381Fp.eql(lhs, rhs));
+}
diff --git a/packages/zolt-arith/src/field/accumulators.zig b/packages/zolt-arith/src/field/accumulators.zig
index 8ab49dbd..531c08bc 100644
--- a/packages/zolt-arith/src/field/accumulators.zig
+++ b/packages/zolt-arith/src/field/accumulators.zig
@@ -18,26 +18,11 @@ const BN254_INV = mod.BN254_INV;
 const BN254BaseField = mod.BN254BaseField;
 
 /// Unreduced product accumulator for deferred Montgomery reduction.
-///
-/// Stores partial products in positional `u128` slots to avoid Montgomery reduction
-/// in hot accumulation loops. Each slot holds a sum of u64×u64 partial products;
-/// carries between slots are deferred until `reduce()`. This mirrors Jolt's
-/// `Folded256ProductAccum` type.
-/// -/// ## Usage -/// ``` -/// var accum = UnreducedProductAccum.zero(); -/// for (a_vals, b_vals) |a, b| { -/// accum.addAssign(a.mulToProductAccum(b)); -/// } -/// const result = accum.reduce(); // single Montgomery reduction -/// ``` -/// -/// ## Overflow Safety -/// Each `fromMul` contributes at most `4 × (2^64-1)` to any slot. After N `addAssign` -/// calls, max slot value = `N × 4 × (2^64-1)`. With u128 max = 2^128-1, safe for -/// N up to ~2^62. At T=2^30 with E_in=2^15, N=32768 → max slot ≈ 2^79, well within bounds. -pub const UnreducedProductAccum = struct { +/// Now an alias for the generic factory's ProductAccum type. +pub const UnreducedProductAccum = BN254Scalar.ProductAccum; + +/// Legacy UnreducedProductAccum implementation kept for reference. +const _UnreducedProductAccum_legacy = struct { slots: [8]u128, const Self = @This(); @@ -692,19 +677,31 @@ pub const WideAccumS = struct { } }; -/// LLVM carry/borrow intrinsics — map to single adc/sbb instructions on x86-64. -/// Wrapped in a comptime-conditional struct so they are not emitted on non-x86 targets. -const x86 = if (builtin.cpu.arch == .x86_64) struct { +// Carry/borrow: use LLVM intrinsics in Release, portable u128 in Debug. 
+// True only for x86-64 builds in a non-Debug mode.
+const has_x86_intrinsics = builtin.cpu.arch == .x86_64 and builtin.mode != .Debug;
+// Namespace with `addcarry` / `subborrow`. The first variant forwards to the
+// LLVM intrinsics; the fallback computes the same word and carry/borrow flag
+// with u128 arithmetic.
+const x86 = if (has_x86_intrinsics) struct {
     extern fn @"llvm.x86.addcarry.u64"(c_in: u8, a: u64, b: u64, result: *u64) u8;
     extern fn @"llvm.x86.subborrow.u64"(b_in: u8, a: u64, b: u64, result: *u64) u8;
-
     pub inline fn addcarry(c_in: u8, a: u64, b: u64, result: *u64) u8 {
         return @"llvm.x86.addcarry.u64"(c_in, a, b, result);
     }
     pub inline fn subborrow(b_in: u8, a: u64, b: u64, result: *u64) u8 {
         return @"llvm.x86.subborrow.u64"(b_in, a, b, result);
     }
-} else struct {};
+} else struct {
+    // Portable add-with-carry: stores the low 64 bits of a + b + c_in in
+    // `result` and returns the carry-out.
+    pub inline fn addcarry(c_in: u8, a: u64, b: u64, result: *u64) u8 {
+        const sum = @as(u128, a) + @as(u128, b) + @as(u128, c_in);
+        result.* = @truncate(sum);
+        return @truncate(sum >> 64);
+    }
+    // Portable subtract-with-borrow: stores the low 64 bits of
+    // a - (b + b_in) in `result`; the top bit of the wrapped 128-bit
+    // difference is returned as the borrow-out.
+    pub inline fn subborrow(b_in: u8, a: u64, b: u64, result: *u64) u8 {
+        const wide_a = @as(u128, a);
+        const wide_b = @as(u128, b) + @as(u128, b_in);
+        const diff = wide_a -% wide_b;
+        result.* = @truncate(diff);
+        return @truncate(diff >> 127);
+    }
+};
 
 /// Comptime flag: true when targeting AArch64 (all AArch64 has adds/adcs/subs/sbcs).
const use_arm64_asm = (builtin.cpu.arch == .aarch64); diff --git a/packages/zolt-arith/src/field/extensions.zig b/packages/zolt-arith/src/field/extensions.zig index ea17a6f0..ce82f77f 100644 --- a/packages/zolt-arith/src/field/extensions.zig +++ b/packages/zolt-arith/src/field/extensions.zig @@ -156,248 +156,38 @@ pub fn fp2ScalarMul(a: Fp2, s: Fp) Fp2 { // Extension Field Fp2 = Fp[u] / (u^2 + 1) // ============================================================================ -/// Fp2 element: a + b*u where u^2 = -1 -pub const Fp2 = struct { - c0: Fp, // Real part - c1: Fp, // Imaginary part - - pub fn init(c0: Fp, c1: Fp) Fp2 { - return .{ .c0 = c0, .c1 = c1 }; - } - - pub fn zero() Fp2 { - return .{ .c0 = Fp.zero(), .c1 = Fp.zero() }; - } - - pub fn one() Fp2 { - return .{ .c0 = Fp.one(), .c1 = Fp.zero() }; - } - - pub fn add(self: Fp2, other: Fp2) Fp2 { - return .{ - .c0 = self.c0.add(other.c0), - .c1 = self.c1.add(other.c1), - }; - } - - pub fn sub(self: Fp2, other: Fp2) Fp2 { - return .{ - .c0 = self.c0.sub(other.c0), - .c1 = self.c1.sub(other.c1), - }; - } - - pub fn mul(self: Fp2, other: Fp2) Fp2 { - // sumOfProducts fusion: 2 fused reductions instead of 3 separate muls - // c0 = a0*b0 - a1*b1 (NONRESIDUE = -1) - // c1 = a0*b1 + a1*b0 - const neg_a1 = self.c1.neg(); - return .{ - .c0 = Fp.sumOfProducts(.{ self.c0, neg_a1 }, .{ other.c0, other.c1 }), - .c1 = Fp.sumOfProducts(.{ self.c0, self.c1 }, .{ other.c1, other.c0 }), - }; - } - - pub fn square(self: Fp2) Fp2 { - // (a + bu)^2 = (a^2 - b^2) + 2abu - // Karatsuba optimization: (a+b)(a-b) = a^2 - b^2 - const a_plus_b = self.c0.addNoReduce(self.c1); // [0, 2p), saves 1 reduction - const a_minus_b = self.c0.sub(self.c1); // [0, p) - const a_squared_minus_b_squared = a_plus_b.mul(a_minus_b); - const two_ab = self.c0.mul(self.c1).double(); - - return .{ - .c0 = a_squared_minus_b_squared, - .c1 = two_ab, - }; - } - - pub fn neg(self: Fp2) Fp2 { - return .{ - .c0 = self.c0.neg(), - .c1 = self.c1.neg(), - }; - } - 
- /// Conjugate: a + bu -> a - bu - pub fn conjugate(self: Fp2) Fp2 { - return .{ - .c0 = self.c0, - .c1 = self.c1.neg(), - }; - } - - /// Inverse using the formula: 1/(a + bu) = (a - bu)/(a^2 + b^2) - pub fn inverse(self: Fp2) ?Fp2 { - const norm = self.c0.square().add(self.c1.square()); - const norm_inv = norm.inverse() orelse return null; - - return .{ - .c0 = self.c0.mul(norm_inv), - .c1 = self.c1.neg().mul(norm_inv), - }; - } - - pub fn eql(self: Fp2, other: Fp2) bool { - return self.c0.eql(other.c0) and self.c1.eql(other.c1); - } - - pub fn isZero(self: Fp2) bool { - return self.c0.isZero() and self.c1.isZero(); - } -}; +/// Fp2 element: a + b*u where u^2 = -1. +/// Now routed through the curve-generic Fp2 factory. +pub const Fp2 = @import("../curves/extensions.zig").Fp2(Fp); // ============================================================================ // Extension Field Fp6 = Fp2[v] / (v^3 - xi) where xi = 9 + u +// Now routed through the curve-generic Fp6 factory. // ============================================================================ -/// Fp6 element: c0 + c1*v + c2*v^2 where v^3 = xi -pub const Fp6 = struct { - c0: Fp2, - c1: Fp2, - c2: Fp2, - - pub fn zero() Fp6 { - return .{ .c0 = Fp2.zero(), .c1 = Fp2.zero(), .c2 = Fp2.zero() }; - } - - pub fn one() Fp6 { - return .{ .c0 = Fp2.one(), .c1 = Fp2.zero(), .c2 = Fp2.zero() }; - } - - pub fn add(self: Fp6, other: Fp6) Fp6 { - return .{ - .c0 = self.c0.add(other.c0), - .c1 = self.c1.add(other.c1), - .c2 = self.c2.add(other.c2), - }; - } - - pub fn sub(self: Fp6, other: Fp6) Fp6 { - return .{ - .c0 = self.c0.sub(other.c0), - .c1 = self.c1.sub(other.c1), - .c2 = self.c2.sub(other.c2), - }; - } - - pub fn neg(self: Fp6) Fp6 { - return .{ - .c0 = self.c0.neg(), - .c1 = self.c1.neg(), - .c2 = self.c2.neg(), - }; - } - - /// Multiplication by xi = 9 + u (the non-residue for BN254) - /// Uses shift-add (9x = 8x + x): 4 additions instead of 1 Montgomery mul - pub fn mulByXi(x: Fp2) Fp2 { - const a = x.c0; - 
const a2 = a.add(a); - const a4 = a2.add(a2); - const a8 = a4.add(a4); - const a9 = a8.add(a); - - const b = x.c1; - const b2 = b.add(b); - const b4 = b2.add(b2); - const b8 = b4.add(b4); - const b9 = b8.add(b); - - return Fp2.init( - a9.sub(b), - a.add(b9), - ); - } - - pub fn mul(self: Fp6, other: Fp6) Fp6 { - // Karatsuba-like multiplication for cubic extension - const v0 = self.c0.mul(other.c0); - const v1 = self.c1.mul(other.c1); - const v2 = self.c2.mul(other.c2); - - // c0 = v0 + xi((c1 + c2)(d1 + d2) - v1 - v2) - const c1_plus_c2 = self.c1.add(self.c2); - const d1_plus_d2 = other.c1.add(other.c2); - const t0 = mulByXi(c1_plus_c2.mul(d1_plus_d2).sub(v1).sub(v2)); - const new_c0 = v0.add(t0); - - // c1 = (c0 + c1)(d0 + d1) - v0 - v1 + xi*v2 - const c0_plus_c1 = self.c0.add(self.c1); - const d0_plus_d1 = other.c0.add(other.c1); - const t1 = c0_plus_c1.mul(d0_plus_d1).sub(v0).sub(v1); - const new_c1 = t1.add(mulByXi(v2)); - - // c2 = (c0 + c2)(d0 + d2) - v0 - v2 + v1 - const c0_plus_c2 = self.c0.add(self.c2); - const d0_plus_d2 = other.c0.add(other.c2); - const t2 = c0_plus_c2.mul(d0_plus_d2).sub(v0).sub(v2); - const new_c2 = t2.add(v1); - - return .{ .c0 = new_c0, .c1 = new_c1, .c2 = new_c2 }; - } - - /// Chung-Hasan SQ2 squaring: 2 Fp2.mul + 3 Fp2.square instead of 6 Fp2.mul - pub fn square(self: Fp6) Fp6 { - const s0 = self.c0.square(); - const ab = self.c0.mul(self.c1); - const s1 = ab.add(ab); // 2*c0*c1 - const s2 = self.c0.sub(self.c1).add(self.c2).square(); - const bc = self.c1.mul(self.c2); - const s3 = bc.add(bc); // 2*c1*c2 - const s4 = self.c2.square(); - - return .{ - .c0 = s0.add(mulByXi(s3)), - .c1 = s1.add(mulByXi(s4)), - .c2 = s1.add(s2).add(s3).sub(s0).sub(s4), - }; - } - - pub fn inverse(self: Fp6) ?Fp6 { - // Extended Euclidean algorithm for Fp6 - const c0_sq = self.c0.square(); - const c1_sq = self.c1.square(); - const c2_sq = self.c2.square(); - const c0c1 = self.c0.mul(self.c1); - const c0c2 = self.c0.mul(self.c2); - const c1c2 = 
self.c1.mul(self.c2); - - // Using the formula for inverse in cubic extension - const a0 = c0_sq.sub(mulByXi(c1c2)); - const a1 = mulByXi(c2_sq).sub(c0c1); - const a2 = c1_sq.sub(c0c2); - - const tmp = mulByXi(self.c1.mul(a2).add(self.c2.mul(a1))); - const norm = self.c0.mul(a0).add(tmp); - - const norm_inv = norm.inverse() orelse return null; - - return .{ - .c0 = a0.mul(norm_inv), - .c1 = a1.mul(norm_inv), - .c2 = a2.mul(norm_inv), - }; - } - - pub fn eql(self: Fp6, other: Fp6) bool { - return self.c0.eql(other.c0) and self.c1.eql(other.c1) and self.c2.eql(other.c2); - } +const generic_ext = @import("../curves/extensions.zig"); + +/// BN254 cubic non-residue: multiply Fp2 by ξ = 9 + u (shift-add). +fn bn254MulByXi(x: Fp2) Fp2 { + const a = x.c0; + const a2 = a.add(a); + const a4 = a2.add(a2); + const a8 = a4.add(a4); + const a9 = a8.add(a); + const b = x.c1; + const b2 = b.add(b); + const b4 = b2.add(b2); + const b8 = b4.add(b4); + const b9 = b8.add(b); + return Fp2.init(a9.sub(b), a.add(b9)); +} - // Note: Fp6 frobenius is not needed as a standalone method. - // We implement it directly in Fp12 frobenius to properly handle - // the coefficient structure across all 6 Fp2 components. -}; +pub const Fp6 = generic_ext.Fp6(Fp2, bn254MulByXi); -/// Fp6 multiplication by v (shift operation) -/// For Fp6 = Fp2[v]/(v^3 - xi), multiplying by v shifts coefficients: -/// (c0 + c1*v + c2*v^2) * v = c2*xi + c0*v + c1*v^2 +/// Fp6 multiplication by v (shift operation). +/// Delegates to the generic factory's `mulByV`. 
pub fn fp6MulByV(f: Fp6) Fp6 { - return Fp6{ - .c0 = Fp6.mulByXi(f.c2), - .c1 = f.c0, - .c2 = f.c1, - }; + return Fp6.mulByV(f); } // ============================================================================ diff --git a/packages/zolt-arith/src/field/mod.zig b/packages/zolt-arith/src/field/mod.zig index ac595774..bc3fe9f9 100644 --- a/packages/zolt-arith/src/field/mod.zig +++ b/packages/zolt-arith/src/field/mod.zig @@ -7,9 +7,9 @@ const std = @import("std"); const builtin = @import("builtin"); const testdata = @import("../testdata.zig"); -/// LLVM carry/borrow intrinsics — map to single adc/sbb instructions on x86-64. -/// Wrapped in a comptime-conditional struct so they are not emitted on non-x86 targets. -const x86 = if (builtin.cpu.arch == .x86_64) struct { +/// LLVM carry/borrow intrinsics — single adc/sbb on x86-64. +/// Only in non-Debug builds (LLVM inlines them in Release; linker fails in Debug). +const x86 = if (builtin.cpu.arch == .x86_64 and builtin.mode != .Debug) struct { extern fn @"llvm.x86.addcarry.u64"(c_in: u8, a: u64, b: u64, result: *u64) u8; extern fn @"llvm.x86.subborrow.u64"(b_in: u8, a: u64, b: u64, result: *u64) u8; @@ -20,6 +20,7 @@ const x86 = if (builtin.cpu.arch == .x86_64) struct { return @"llvm.x86.subborrow.u64"(b_in, a, b, result); } } else struct {}; +const has_x86_intrinsics = builtin.cpu.arch == .x86_64 and builtin.mode != .Debug; /// Comptime flag: true when x86-64 BMI2+ADX instructions are available const use_asm_mul = blk: { @@ -38,7 +39,7 @@ const use_arm64_asm = (builtin.cpu.arch == .aarch64); /// Field subtraction: (a - b) mod p, branch-based. /// subs/sbcs chain, then b.cs skips correction if no borrow. 
-inline fn arm64SubMod256(a: [4]u64, b: [4]u64, mod: [4]u64) [4]u64 { +pub inline fn arm64SubMod256(a: [4]u64, b: [4]u64, mod: [4]u64) [4]u64 { if (comptime builtin.cpu.arch != .aarch64) unreachable; var r0: u64 = undefined; var r1: u64 = undefined; @@ -76,7 +77,7 @@ inline fn arm64SubMod256(a: [4]u64, b: [4]u64, mod: [4]u64) [4]u64 { } /// Unconditional 4-limb subtraction (no modular reduction). -inline fn arm64Sub256(a: [4]u64, b: [4]u64) [4]u64 { +pub inline fn arm64Sub256(a: [4]u64, b: [4]u64) [4]u64 { if (comptime builtin.cpu.arch != .aarch64) unreachable; var r0: u64 = undefined; var r1: u64 = undefined; @@ -104,7 +105,7 @@ inline fn arm64Sub256(a: [4]u64, b: [4]u64) [4]u64 { } /// Unconditional 4-limb addition (no modular reduction). -inline fn arm64Add256(a: [4]u64, b: [4]u64) [4]u64 { +pub inline fn arm64Add256(a: [4]u64, b: [4]u64) [4]u64 { if (comptime builtin.cpu.arch != .aarch64) unreachable; var r0: u64 = undefined; var r1: u64 = undefined; @@ -183,7 +184,7 @@ pub inline fn arm64Sub192(a: [3]u64, b: [3]u64) [3]u64 { /// instead of the 16 muls required for generic a*b. /// Phase 2: 4 iterations of Montgomery reduction (same as CIOS mul). /// Total: ~52 multiply instructions vs ~68 for generic mul(a, a). -fn arm64MontgomerySquare256(a: *const [4]u64, mod: *const [4]u64, inv: u64) [4]u64 { +pub fn arm64MontgomerySquare256(a: *const [4]u64, mod: *const [4]u64, inv: u64) [4]u64 { if (comptime builtin.cpu.arch != .aarch64) unreachable; var r0: u64 = undefined; var r1: u64 = undefined; @@ -381,7 +382,7 @@ fn arm64MontgomerySquare256(a: *const [4]u64, mod: *const [4]u64, inv: u64) [4]u /// Computes a * b * R^{-1} mod p using fully-unrolled CIOS with two-pass /// carry accumulation (ARM64 substitute for x86 ADX dual carry chains). /// All constants (b, mod, inv) are preloaded into registers. 
-fn arm64MontgomeryMul256(a: *const [4]u64, b: *const [4]u64, mod: *const [4]u64, inv: u64) [4]u64 { +pub fn arm64MontgomeryMul256(a: *const [4]u64, b: *const [4]u64, mod: *const [4]u64, inv: u64) [4]u64 { if (comptime builtin.cpu.arch != .aarch64) unreachable; // After 4 iterations of register rotation, the result limbs end up in: // iter 0: t=[x0,x1,x2,x3,x14] → shift → [x1,x2,x3,x14], t4=x0 @@ -678,9 +679,9 @@ pub const BN254_FP_R2: [4]u64 = .{ /// Montgomery constant: -q^{-1} mod 2^64 pub const BN254_FP_INV: u64 = 0x87d20782e4866389; -/// BN254 base field element for pairing operations -/// This is a wrapper around BN254Scalar that uses the base field modulus -/// Used for Fp, Fp2, Fp6, Fp12 tower and G1/G2 coordinates +/// BN254 base field element for pairing operations. +/// Uses the in-file MontgomeryField factory for optimal codegen (same +/// compilation unit as extensions.zig / pairing.zig callers). pub const BN254BaseField = MontgomeryField( BN254_FP_MODULUS, BN254_FP_R, @@ -845,9 +846,9 @@ pub fn MontgomeryField( return @as(u128, a) * @as(u128, b); } - /// Add with carry + /// Add with carry — uses the LLVM adc intrinsic on x86-64 in non-Debug builds, the portable fallback otherwise. 
inline fn addCarry(a: u64, b: u64, carry_in: u64) struct { result: u64, carry: u64 } { - if (comptime builtin.cpu.arch == .x86_64) { + if (comptime has_x86_intrinsics) { var result: u64 = undefined; const c = x86.addcarry(@truncate(carry_in), a, b, &result); return .{ .result = result, .carry = c }; @@ -857,7 +858,7 @@ pub fn MontgomeryField( } inline fn subBorrow(a: u64, b: u64, borrow_in: u64) struct { result: u64, borrow: u64 } { - if (comptime builtin.cpu.arch == .x86_64) { + if (comptime has_x86_intrinsics) { var result: u64 = undefined; const b_out = x86.subborrow(@truncate(borrow_in), a, b, &result); return .{ .result = result, .borrow = b_out }; @@ -1664,9 +1665,14 @@ pub fn JoltField(comptime Self: type) type { }; } -/// BN254 scalar field element -/// Stored in Montgomery form: a is represented as a*R mod p -pub const BN254Scalar = struct { +/// BN254 scalar field element. +/// Now routed through the curve-generic MontgomeryField factory. +pub const BN254Scalar = @import("../curves/bn254/mod.zig").Fr; + +/// Legacy bespoke BN254Scalar body retained for reference. The type above +/// is the canonical definition; everything below up to the closing `};` +/// is dead code that compiles but is never linked. +const _BN254Scalar_legacy = struct { limbs: [4]u64, const Self = @This(); @@ -1811,9 +1817,9 @@ pub const BN254Scalar = struct { return @as(u128, a) * @as(u128, b); } - /// Add with carry + /// Add with carry — uses the LLVM adc intrinsic on x86-64 in non-Debug builds, the portable fallback otherwise. 
inline fn addCarry(a: u64, b: u64, carry_in: u64) struct { result: u64, carry: u64 } { - if (comptime builtin.cpu.arch == .x86_64) { + if (comptime has_x86_intrinsics) { var result: u64 = undefined; const c = x86.addcarry(@truncate(carry_in), a, b, &result); return .{ .result = result, .carry = c }; @@ -1823,7 +1829,7 @@ pub const BN254Scalar = struct { } pub inline fn subBorrow(a: u64, b: u64, borrow_in: u64) struct { result: u64, borrow: u64 } { - if (!@inComptime() and comptime builtin.cpu.arch == .x86_64) { + if (!@inComptime() and comptime has_x86_intrinsics) { var result: u64 = undefined; const b_out = x86.subborrow(@truncate(borrow_in), a, b, &result); return .{ .result = result, .borrow = b_out }; diff --git a/packages/zolt-arith/src/root.zig b/packages/zolt-arith/src/root.zig index a3367ca6..d36337ed 100644 --- a/packages/zolt-arith/src/root.zig +++ b/packages/zolt-arith/src/root.zig @@ -1,8 +1,11 @@ //! zolt-arith: Arithmetic primitives for Zolt. //! -//! BN254 field arithmetic, polynomial operations, MSM, commitments, -//! transcripts, and sumcheck protocol. +//! Hosts both the BN254 stack (consumed by zolt's prover/zkVM) and the +//! BLS12-381 stack (consumed by zyli's validator signature verification). +//! The curve-generic `MontgomeryField` factory lives under `curves/`; +//! each curve has its own implementation subtree. 
+// --- BN254 surface (existing — used by zolt) ---------------------------- pub const field = @import("field/mod.zig"); pub const poly = @import("poly/mod.zig"); pub const msm = @import("msm/mod.zig"); @@ -10,11 +13,9 @@ pub const gpu = @import("gpu/mod.zig"); pub const transcripts = @import("transcripts/mod.zig"); pub const subprotocols = @import("subprotocols/mod.zig"); -// Re-export commonly used types pub const JoltField = field.JoltField; pub const BN254Scalar = field.BN254Scalar; -// Utility types pub const bits = @import("bits.zig"); pub const LookupBits = bits.LookupBits; pub const uninterleaveBits = bits.uninterleaveBits; @@ -22,6 +23,14 @@ pub const interleaveBits = bits.interleaveBits; pub const expanding_table = @import("expanding_table.zig"); pub const ExpandingTable = expanding_table.ExpandingTable; +// --- Curve-generic substrate + BLS12-381 surface (used by zyli) ---------- +pub const curves = @import("curves/mod.zig"); +pub const bigint = @import("bigint.zig"); +pub const bls12_381 = @import("curves/bls12_381/mod.zig"); +pub const hash_to_field = bls12_381.hash_to_field; +pub const hash_to_curve_g2 = bls12_381.hash_to_curve_g2; +pub const bls = bls12_381.bls; + test { _ = @import("field/mod.zig"); _ = @import("poly/mod.zig"); @@ -29,4 +38,10 @@ test { _ = @import("transcripts/mod.zig"); _ = @import("subprotocols/mod.zig"); _ = @import("bits.zig"); + + _ = @import("bigint.zig"); + _ = @import("curves/mod.zig"); + _ = @import("curves/montgomery_field.zig"); + _ = @import("curves/bn254/mod.zig"); + _ = @import("curves/bls12_381/mod.zig"); }