diff --git a/Cargo.lock b/Cargo.lock index 965bc80ee2..948643bc94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -581,40 +581,26 @@ dependencies = [ "derive_arbitrary", ] -[[package]] -name = "ark-bn254" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" -replace = "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)" - [[package]] name = "ark-bn254" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", ] -[[package]] -name = "ark-ec" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" -replace = "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)" - [[package]] name = "ark-ec" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ "ahash", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ff 0.5.0", "ark-poly", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "ark-std 0.5.0", "educe", "fnv", @@ -665,13 +651,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "ark-ff" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" -replace = "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)" - [[package]] name = "ark-ff" version = "0.5.0" @@ -680,7 +659,7 @@ dependencies = [ "allocative", "ark-ff-asm 0.5.0", "ark-ff-macros 0.5.0", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "ark-std 0.5.0", "arrayvec", "digest 0.10.7", @@ -764,9 +743,9 @@ name = "ark-grumpkin" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", "ark-std 0.5.0", ] @@ -776,8 +755,8 @@ version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ "ahash", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "educe", "fnv", @@ -789,8 +768,8 @@ name = "ark-secp256k1" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ec", + "ark-ff 0.5.0", "ark-std 0.5.0", ] @@ -799,8 +778,8 @@ name = "ark-secp256r1" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ec", + "ark-ff 0.5.0", "ark-std 0.5.0", ] @@ -825,13 +804,6 @@ dependencies = [ "num-bigint", ] -[[package]] -name = "ark-serialize" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" -replace = "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)" - [[package]] name = "ark-serialize" version = "0.5.0" @@ -1403,7 +1375,7 @@ name = "common" version = "0.2.0" dependencies = [ "allocative", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "serde", "syn 2.0.117", ] @@ -1918,10 +1890,10 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8c58baea9f0ed973489cd1981b0e6a8c91aafddb05e3903b1dd54175ddcb52d" dependencies = [ - "ark-bn254 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ark-ec 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ark-ff 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ark-serialize 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "bincode 1.3.3", "blake2 0.10.6", @@ -2789,7 +2761,7 @@ name = "jolt" version = "0.1.0" dependencies = [ "anyhow", - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", "clap", "common", "env_logger", @@ -2817,6 +2789,7 @@ dependencies = [ "jolt-transcript", "rand_chacha 0.3.1", "rand_core 0.6.4", + "rayon", "serde", "thiserror 2.0.18", ] @@ -2839,10 +2812,10 @@ name = "jolt-core" version = "0.1.0" dependencies = [ "allocative", - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "bincode 2.0.1", "blake2 0.11.0-rc.6", @@ -2891,15 +2864,18 @@ dependencies = [ name = "jolt-crypto" version = "0.1.0" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", + "ark-grumpkin", + "ark-serialize 0.5.0", "ark-std 0.5.0", "bincode 2.0.1", + "blake2 0.11.0-rc.6", "criterion", "jolt-field", "jolt-poly", + "jolt-r1cs", "jolt-transcript", "num-bigint", "num-integer", @@ -2909,13 +2885,14 @@ dependencies = [ "rayon", "serde", "serde_json", + "thiserror 2.0.18", ] [[package]] name = "jolt-dory" version = "0.1.0" dependencies = [ - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "criterion", "dory-pcs", "jolt-crypto", @@ -2936,7 +2913,7 @@ name = "jolt-eval" version = "0.1.0" dependencies = [ "arbitrary", - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", "clap", "common", "criterion", @@ -2979,9 +2956,9 @@ name = "jolt-field" version = "0.1.0" dependencies = [ "allocative", - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "criterion", "num-traits", @@ -3067,8 +3044,8 @@ dependencies = [ name = "jolt-inlines-grumpkin" version = "0.1.0" dependencies = [ - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ec", + "ark-ff 0.5.0", "ark-grumpkin", "jolt-inlines-sdk", "num-bigint", @@ -3092,8 +3069,8 @@ dependencies = [ name = "jolt-inlines-p256" version = "0.1.0" dependencies = [ - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ec", + "ark-ff 0.5.0", "ark-secp256r1", "jolt-inlines-sdk", "num-bigint", @@ -3121,7 +3098,7 @@ dependencies = [ name = "jolt-inlines-secp256k1" version = "0.1.0" dependencies = [ - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-ff 0.5.0", "ark-secp256k1", "jolt-inlines-sdk", "num-bigint", @@ -3161,6 +3138,7 @@ dependencies = [ "jolt-crypto", "jolt-field", "jolt-poly", + "jolt-r1cs", "jolt-transcript", "rand 0.8.5", "rand_chacha 0.3.1", @@ -3175,10 +3153,10 @@ name = "jolt-optimizations" version = "0.5.0" source = "git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout#76bb3a4518928f1ff7f15875f940d614bb9845e6" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ec 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "arrayvec", "num-bigint", @@ -3233,7 +3211,7 @@ dependencies = [ name = "jolt-program" version = "0.1.0" dependencies = [ - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "common", "hex", "jolt-riscv", @@ -3251,6 +3229,7 @@ dependencies = [ "jolt-claims", "jolt-field", "jolt-poly", + "num-bigint", "num-traits", "rand_chacha 0.3.1", "rand_core 0.6.4", @@ -3264,7 +3243,7 @@ dependencies = [ name = "jolt-riscv" version = "0.1.0" dependencies = [ - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "serde", "serde_json", "strum 0.28.0", @@ -3320,12 +3299,13 @@ dependencies = [ name = "jolt-transcript" version = "0.1.0" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ff 0.5.0", "blake2 0.11.0-rc.6", "criterion", "digest 0.11.3", "jolt-field", + "jolt-r1cs", "light-poseidon", "num-traits", "rand 0.8.5", @@ -3336,8 +3316,8 @@ dependencies = [ name = "jolt-verifier" version = "0.1.0" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-serialize 0.5.0", "common", "jolt-blindfold", "jolt-claims", @@ -3446,8 +3426,8 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47a1ccadd0bb5a32c196da536fd72c59183de24a055f6bf0513bf845fefab862" dependencies = [ - "ark-bn254 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ark-ff 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "ark-bn254", + "ark-ff 0.5.0", "num-bigint", "thiserror 1.0.69", ] @@ -4896,7 +4876,7 @@ dependencies = [ name = "recursion" version = "0.1.0" dependencies = [ - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "clap", "jolt-sdk", "postcard", @@ -4908,7 +4888,7 @@ dependencies = [ name = "recursion-guest" version = "0.1.0" dependencies = [ - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "jolt-sdk", ] @@ -5256,7 +5236,7 @@ dependencies = [ "alloy-rlp", "ark-ff 0.3.0", "ark-ff 0.4.2", - "ark-ff 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "ark-ff 0.5.0", "bytes", "fastrlp 0.3.1", "fastrlp 0.4.0", @@ -5953,8 +5933,8 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "074f823019979d89e8d46a966feb3d173f3db9a21c6764f8c2282e137017bba5" dependencies = [ - "ark-ff 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ark-serialize 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "blake2 0.11.0-rc.6", "digest 0.11.3", "keccak 0.1.6", @@ -6308,7 +6288,7 @@ name = "tracer" version = "0.2.0" dependencies = [ "addr2line 0.26.1", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-serialize 0.5.0", "clap", "common", "derive_more", @@ -6408,9 +6388,9 @@ dependencies = [ name = "transpiler" version = "0.1.0" dependencies = [ - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "clap", "common", @@ -7297,9 +7277,9 @@ name = "zklean-extractor" version = "0.1.0" dependencies = [ "allocative", - "ark-bn254 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-ff 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", - "ark-serialize 0.5.0 (git+https://github.com/a16z/arkworks-algebra?branch=dev%2Ftwist-shout)", + "ark-bn254", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "build-fs-tree", "clap", diff --git a/Cargo.toml b/Cargo.toml index f36c49c003..336ba1ece3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -167,11 +167,11 @@ lto = "fat" strip = false codegen-units = 1 -[replace] -"ark-bn254:0.5.0" = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } -"ark-ff:0.5.0" = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } -"ark-ec:0.5.0" = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } -"ark-serialize:0.5.0" = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +[patch.crates-io] +ark-bn254 = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +ark-ec = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } +ark-serialize = { git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" } [workspace.metadata.cargo-machete] ignored = ["jolt-sdk"] diff --git a/crates/jolt-blindfold/Cargo.toml b/crates/jolt-blindfold/Cargo.toml index d0d9f572de..97044987f3 100644 --- a/crates/jolt-blindfold/Cargo.toml +++ b/crates/jolt-blindfold/Cargo.toml @@ -16,6 +16,8 @@ jolt-poly.workspace = true jolt-r1cs.workspace = true jolt-sumcheck = { workspace = true, features = ["r1cs"] } jolt-transcript.workspace = true +rand_core.workspace = true +rayon.workspace = true serde = { workspace = true, features = ["derive"] } thiserror.workspace = true diff --git a/crates/jolt-blindfold/src/error.rs b/crates/jolt-blindfold/src/error.rs index 1ff7a645df..3e356d615c 100644 --- a/crates/jolt-blindfold/src/error.rs +++ b/crates/jolt-blindfold/src/error.rs @@ -84,6 +84,59 @@ pub enum RelaxedError { }, } +#[derive(Debug, ThisError)] +pub enum ProverError { + #[error(transparent)] + Relaxed(#[from] RelaxedError), + #[error(transparent)] + R1csMatrix(#[from] ConstraintMatrixEvalError), + #[error(transparent)] + VectorOpening(#[from] VectorOpeningError), + #[error("{name} length mismatch: expected {expected}, got {actual}")] + LengthMismatch { + name: &'static str, + expected: usize, + actual: usize, + }, + #[error("witness row {row} length mismatch: expected {expected}, got {actual}")] + WitnessRowLengthMismatch { + row: usize, + expected: usize, + actual: usize, + }, + #[error("{name} row length {row_len} exceeds vector commitment capacity {capacity}")] + CommitmentCapacityExceeded { + name: &'static str, + capacity: usize, + row_len: usize, + }, + #[error("{name} row commitment backend failed: {reason}")] + RowCommitmentBackend { name: &'static str, reason: String }, + #[error("{name} backend kernel failed: {reason}")] + BackendKernel { name: &'static str, reason: String }, + #[error("{name} must be a non-zero power of two, got {value}")] + InvalidPowerOfTwo { name: &'static str, value: usize }, + #[error("{name} dimension {value} cannot be represented")] + DimensionOverflow { name: &'static str, value: usize }, + #[error("folded eval commitment {index} does not match opened value and blinding")] + EvalCommitmentMismatch { index: usize }, + #[error("folded eval witness {kind} {index} does not match opened witness coordinate: expected {expected}, got {actual}")] + EvalWitnessMismatch { + kind: &'static str, + index: usize, + expected: F, + actual: F, + }, + #[error("interpolation denominator is zero for degree {degree} at point {point}")] + ZeroInterpolationDenominator { degree: usize, point: usize }, + #[error("sumcheck round claim mismatch: expected {expected}, got {actual}")] + SumcheckRoundClaimMismatch { expected: F, actual: F }, + #[error("multilinear evaluation length mismatch: expected {expected}, got {actual}")] + MultilinearLengthMismatch { expected: usize, actual: usize }, + #[error("{name} must have at least one sumcheck round")] + DegenerateSumcheck { name: &'static str }, +} + #[derive(Debug, ThisError)] pub enum VerificationError { #[error("claims have {claim_stages} stages but proof has {proof_stages}")] diff --git a/crates/jolt-blindfold/src/lib.rs b/crates/jolt-blindfold/src/lib.rs index 0f728887c6..c567f35401 100644 --- a/crates/jolt-blindfold/src/lib.rs +++ b/crates/jolt-blindfold/src/lib.rs @@ -4,19 +4,25 @@ mod builder; mod error; mod proof; pub mod protocol; +mod prove; pub mod r1cs; mod relaxed; mod statements; mod verify; pub use builder::{BlindFoldProtocolBuilder, BlindFoldStageBuilder}; -pub use error::{Error, LayoutError, RelaxedError, VerificationError}; +pub use error::{Error, LayoutError, ProverError, RelaxedError, VerificationError}; pub use proof::BlindFoldProof; pub use protocol::{ BlindFoldDimensions, BlindFoldProtocol, FinalOpeningWitnessCoordinates, RowDimensions, WitnessCoordinate, WitnessRowLayout, }; +pub use prove::{ + prove, prove_with_row_committer, BlindFoldRowCommitter, BlindFoldWitness, + DirectBlindFoldRowCommitter, +}; pub use relaxed::{RelaxedInstance, RelaxedWitness}; pub use statements::{ BlindFoldStage, BlindFoldStatement, CommittedClaimRows, FinalOpeningBinding, OpeningAlias, }; +pub use verify::verify; diff --git a/crates/jolt-blindfold/src/protocol.rs b/crates/jolt-blindfold/src/protocol.rs index 6f206b2d07..0013ee446f 100644 --- a/crates/jolt-blindfold/src/protocol.rs +++ b/crates/jolt-blindfold/src/protocol.rs @@ -6,8 +6,9 @@ use jolt_r1cs::{ConstraintMatrices, R1csBuilder, Variable}; use jolt_sumcheck::{CommittedOutputClaims, CommittedSumcheckConsistency}; use crate::{ - r1cs::Layout, BlindFoldProtocolBuilder, BlindFoldStatement, Error, RelaxedError, - RelaxedInstance, VerificationError, + r1cs::{build_with_sources, Layout}, + BlindFoldProtocolBuilder, BlindFoldStatement, Error, RelaxedError, RelaxedInstance, + VerificationError, }; #[derive(Clone, Debug)] @@ -437,7 +438,7 @@ where challenges: &[(Ch, F)], ) -> Result<(ConstraintMatrices, Layout), VerificationError> { let mut r1cs = R1csBuilder::new(); - let layout = self.build_with_sources(&mut r1cs, publics, challenges)?; + let layout = build_with_sources(&mut r1cs, self, publics, challenges)?; Ok((r1cs.into_matrices(), layout)) } } diff --git a/crates/jolt-blindfold/src/prove.rs b/crates/jolt-blindfold/src/prove.rs new file mode 100644 index 0000000000..58c1f90065 --- /dev/null +++ b/crates/jolt-blindfold/src/prove.rs @@ -0,0 +1,1365 @@ +use jolt_crypto::{HomomorphicCommitment, VectorCommitment, VectorCommitmentOpening}; +use jolt_field::Field; +use jolt_poly::{BindingOrder, EqPolynomial, Polynomial, UnivariatePoly}; +use jolt_r1cs::{ConstraintMatrices, ConstraintMatrixEvalError, SparseRow}; +use jolt_sumcheck::{CompressedSumcheckProof, SUMCHECK_ROUND_TRANSCRIPT_LABEL}; +use jolt_transcript::{AppendToTranscript, Label, LabelWithCount, Transcript}; +use rand_core::RngCore; +use rayon::prelude::*; + +use crate::{BlindFoldProof, BlindFoldProtocol, ProverError, WitnessCoordinate}; + +const OUTER_SUMCHECK_DEGREE: usize = 3; +const INNER_SUMCHECK_DEGREE: usize = 2; +const INNER_SUMCHECK_LABEL: &[u8] = b"inner_sumcheck_poly"; + +#[derive(Clone, Copy, Debug)] +pub struct BlindFoldWitness<'a, F: Field> { + pub rows: &'a [Vec], + pub blindings: &'a [F], + pub eval_outputs: &'a [F], + pub eval_blindings: &'a [F], +} + +pub trait BlindFoldRowCommitter +where + F: Field, + VC: VectorCommitment, +{ + fn commit_rows( + &mut self, + setup: &VC::Setup, + rows: &[Vec], + blindings: &[F], + name: &'static str, + ) -> Result, ProverError>; + + fn compute_error_rows( + &mut self, + r1cs: &ConstraintMatrices, + u: F, + witness: &[F], + row_count: usize, + row_len: usize, + name: &'static str, + ) -> Result>, ProverError> { + let _ = name; + error_rows_for(r1cs, u, witness, row_count, row_len) + } + + #[expect( + clippy::too_many_arguments, + reason = "cross-term error rows are defined by two relaxed witnesses" + )] + fn compute_cross_term_error_rows( + &mut self, + r1cs: &ConstraintMatrices, + real_u: F, + real_witness: &[F], + random_u: F, + random_witness: &[F], + row_count: usize, + row_len: usize, + name: &'static str, + ) -> Result>, ProverError> { + let _ = name; + cross_term_error_rows_for( + r1cs, + real_u, + real_witness, + random_u, + random_witness, + row_count, + row_len, + ) + } + + fn fold_rows( + &mut self, + real: &[Vec], + random: &[Vec], + challenge: F, + name: &'static str, + ) -> Result>, ProverError> { + let _ = name; + fold_rows(real, random, challenge) + } + + fn fold_scalars( + &mut self, + real: &[F], + random: &[F], + challenge: F, + name: &'static str, + ) -> Result, ProverError> { + fold_scalars(name, real, random, challenge) + } + + fn fold_error_rows( + &mut self, + real: &[Vec], + cross: &[Vec], + random: &[Vec], + challenge: F, + name: &'static str, + ) -> Result>, ProverError> { + let _ = name; + fold_error_rows(real, cross, random, challenge) + } + + fn fold_error_scalars( + &mut self, + real: &[F], + cross: &[F], + random: &[F], + challenge: F, + name: &'static str, + ) -> Result, ProverError> { + fold_error_scalars(name, real, cross, random, challenge) + } + + fn open_rows( + &mut self, + setup: &VC::Setup, + rows: &[Vec], + blindings: &[F], + row_point: &[F], + entry_point: &[F], + name: &'static str, + ) -> Result<(VectorCommitmentOpening, F), ProverError> { + open_committed_rows::(setup, rows, blindings, row_point, entry_point, name) + } +} + +#[derive(Debug, Default)] +pub struct DirectBlindFoldRowCommitter; + +impl BlindFoldRowCommitter for DirectBlindFoldRowCommitter +where + F: Field, + VC: VectorCommitment, +{ + fn commit_rows( + &mut self, + setup: &VC::Setup, + rows: &[Vec], + blindings: &[F], + name: &'static str, + ) -> Result, ProverError> { + commit_rows::(setup, rows, blindings, name) + } +} + +pub fn prove( + setup: &VC::Setup, + protocol: &BlindFoldProtocol, + transcript: &mut T, + witness: BlindFoldWitness<'_, F>, + rng: &mut R, +) -> Result, ProverError> +where + F: Field + AppendToTranscript, + VC: VectorCommitment, + VC::Output: HomomorphicCommitment + AppendToTranscript, + T: Transcript, + R: RngCore, +{ + let mut row_committer = DirectBlindFoldRowCommitter; + prove_with_row_committer::( + setup, + protocol, + transcript, + witness, + rng, + &mut row_committer, + ) +} + +pub fn prove_with_row_committer( + setup: &VC::Setup, + protocol: &BlindFoldProtocol, + transcript: &mut T, + witness: BlindFoldWitness<'_, F>, + rng: &mut R, + row_committer: &mut C, +) -> Result, ProverError> +where + F: Field + AppendToTranscript, + VC: VectorCommitment, + VC::Output: HomomorphicCommitment + AppendToTranscript, + T: Transcript, + R: RngCore, + C: BlindFoldRowCommitter, +{ + validate_witness::(setup, protocol, witness)?; + + let auxiliary_range = protocol.dimensions.witness_rows.auxiliary.clone(); + let auxiliary_row_commitments = row_committer.commit_rows( + setup, + &witness.rows[auxiliary_range.clone()], + &witness.blindings[auxiliary_range], + "auxiliary witness rows", + )?; + let committed = protocol.committed_relaxed_instance(&auxiliary_row_commitments)?; + for (index, ((commitment, &output), &blinding)) in protocol + .eval_commitments + .iter() + .zip(witness.eval_outputs) + .zip(witness.eval_blindings) + .enumerate() + { + if !VC::verify(setup, commitment, &[output], &blinding) { + return Err(ProverError::EvalCommitmentMismatch { index }); + } + } + + let random_u = F::random(rng); + let mut random_witness_rows = random_rows( + protocol.dimensions.witness.row_count, + protocol.dimensions.witness.row_len, + rng, + ); + let mut random_witness_blindings = (0..protocol.dimensions.witness.row_count) + .map(|_| F::random(rng)) + .collect::>(); + for row in protocol.dimensions.witness_rows.padding.clone() { + random_witness_rows[row].fill(F::zero()); + random_witness_blindings[row] = F::zero(); + } + + let random_eval_outputs = (0..protocol.eval_commitments.len()) + .map(|_| F::random(rng)) + .collect::>(); + let random_eval_blindings = (0..protocol.eval_commitments.len()) + .map(|_| F::random(rng)) + .collect::>(); + let final_coordinates = protocol.final_opening_witness_coordinates()?; + let mut dedicated_rows = Vec::new(); + for coordinates in &final_coordinates { + if let Some(coordinate) = coordinates.evaluation { + dedicated_rows.push(coordinate.row); + } + if let Some(coordinate) = coordinates.blinding { + dedicated_rows.push(coordinate.row); + } + } + dedicated_rows.sort_unstable(); + dedicated_rows.dedup(); + for row in dedicated_rows { + random_witness_rows[row].fill(F::zero()); + random_witness_blindings[row] = F::zero(); + } + for (index, coordinates) in final_coordinates.iter().enumerate() { + if let Some(coordinate) = coordinates.evaluation { + random_witness_rows[coordinate.row][coordinate.column] = random_eval_outputs[index]; + } + if let Some(coordinate) = coordinates.blinding { + random_witness_rows[coordinate.row][coordinate.column] = random_eval_blindings[index]; + } + } + + let random_error_rows = row_committer.compute_error_rows( + &protocol.r1cs, + random_u, + &flatten(&random_witness_rows), + protocol.dimensions.error.row_count, + protocol.dimensions.error.row_len, + "random error rows", + )?; + ensure_len( + "random error rows", + protocol.dimensions.error.row_count, + random_error_rows.len(), + )?; + let random_error_blindings = (0..protocol.dimensions.error.row_count) + .map(|_| F::random(rng)) + .collect::>(); + let coefficient_range = protocol.dimensions.witness_rows.coefficients.clone(); + let output_claim_range = protocol.dimensions.witness_rows.output_claims.clone(); + let auxiliary_range = protocol.dimensions.witness_rows.auxiliary.clone(); + let random_round_commitments = row_committer.commit_rows( + setup, + &random_witness_rows[coefficient_range.clone()], + &random_witness_blindings[coefficient_range], + "random coefficient rows", + )?; + let random_output_claim_row_commitments = row_committer.commit_rows( + setup, + &random_witness_rows[output_claim_range.clone()], + &random_witness_blindings[output_claim_range], + "random output-claim rows", + )?; + let random_auxiliary_row_commitments = row_committer.commit_rows( + setup, + &random_witness_rows[auxiliary_range.clone()], + &random_witness_blindings[auxiliary_range], + "random auxiliary rows", + )?; + let random_error_row_commitments = row_committer.commit_rows( + setup, + &random_error_rows, + &random_error_blindings, + "random error rows", + )?; + let random_eval_rows = random_eval_outputs + .iter() + .copied() + .map(|output| vec![output]) + .collect::>(); + let random_eval_commitments = row_committer.commit_rows( + setup, + &random_eval_rows, + &random_eval_blindings, + "random eval rows", + )?; + let random_instance = protocol.random_relaxed_instance( + &random_round_commitments, + &random_output_claim_row_commitments, + &random_auxiliary_row_commitments, + &random_error_row_commitments, + &random_eval_commitments, + random_u, + )?; + + let cross_term_error_rows = row_committer.compute_cross_term_error_rows( + &protocol.r1cs, + F::one(), + &flatten(witness.rows), + random_u, + &flatten(&random_witness_rows), + protocol.dimensions.error.row_count, + protocol.dimensions.error.row_len, + "cross-term error rows", + )?; + ensure_len( + "cross-term error rows", + protocol.dimensions.error.row_count, + cross_term_error_rows.len(), + )?; + let cross_term_error_blindings = (0..protocol.dimensions.error.row_count) + .map(|_| F::random(rng)) + .collect::>(); + let cross_term_error_row_commitments = row_committer.commit_rows( + setup, + &cross_term_error_rows, + &cross_term_error_blindings, + "cross-term error rows", + )?; + + append_relaxed_instance( + transcript, + RelaxedInstanceLabels { + u: b"bf_committed_u", + witness: b"bf_committed_w", + error: b"bf_committed_e", + eval: b"bf_committed_eval", + }, + committed.u, + &committed.witness_row_commitments, + &committed.error_row_commitments, + &committed.eval_commitments, + ); + append_relaxed_instance( + transcript, + RelaxedInstanceLabels { + u: b"bf_random_u", + witness: b"bf_random_w", + error: b"bf_random_e", + eval: b"bf_random_eval", + }, + random_u, + &random_instance.witness_row_commitments, + &random_instance.error_row_commitments, + &random_instance.eval_commitments, + ); + append_values(transcript, b"bf_cross_e", &cross_term_error_row_commitments); + let folding_challenge = transcript.challenge(); + + let folded_u = F::one() + folding_challenge * random_u; + let folded_witness_rows = row_committer.fold_rows( + witness.rows, + &random_witness_rows, + folding_challenge, + "folded witness rows", + )?; + let folded_witness_blindings = row_committer.fold_scalars( + witness.blindings, + &random_witness_blindings, + folding_challenge, + "folded witness blindings", + )?; + let folded_error_rows = row_committer.fold_error_rows( + &zero_rows( + protocol.dimensions.error.row_count, + protocol.dimensions.error.row_len, + ), + &cross_term_error_rows, + &random_error_rows, + folding_challenge, + "folded error rows", + )?; + let zero_error_blindings = vec![F::zero(); protocol.dimensions.error.row_count]; + let folded_error_blindings = row_committer.fold_error_scalars( + &zero_error_blindings, + &cross_term_error_blindings, + &random_error_blindings, + folding_challenge, + "folded error blindings", + )?; + let folded_eval_outputs = row_committer.fold_scalars( + witness.eval_outputs, + &random_eval_outputs, + folding_challenge, + "folded eval outputs", + )?; + let folded_eval_blindings = row_committer.fold_scalars( + witness.eval_blindings, + &random_eval_blindings, + folding_challenge, + "folded eval blindings", + )?; + + let mut folded_eval_output_openings = Vec::new(); + let mut folded_eval_blinding_openings = Vec::new(); + for (index, coordinates) in final_coordinates.iter().enumerate() { + if let Some(coordinate) = coordinates.evaluation { + let (opening, opened) = open_witness_coordinate::( + setup, + row_committer, + &folded_witness_rows, + &folded_witness_blindings, + coordinate, + "folded eval output opening", + )?; + if opened != folded_eval_outputs[index] { + return Err(ProverError::EvalWitnessMismatch { + kind: "output", + index, + expected: folded_eval_outputs[index], + actual: opened, + }); + } + folded_eval_output_openings.push(opening); + } + if let Some(coordinate) = coordinates.blinding { + let (opening, opened) = open_witness_coordinate::( + setup, + row_committer, + &folded_witness_rows, + &folded_witness_blindings, + coordinate, + "folded eval blinding opening", + )?; + if opened != folded_eval_blindings[index] { + return Err(ProverError::EvalWitnessMismatch { + kind: "blinding", + index, + expected: folded_eval_blindings[index], + actual: opened, + }); + } + folded_eval_blinding_openings.push(opening); + } + } + for opening in &folded_eval_output_openings { + append_vector_opening( + transcript, + b"bf_eval_out_open", + b"bf_eval_out_blind", + opening, + ); + } + for opening in &folded_eval_blinding_openings { + append_vector_opening( + transcript, + b"bf_eval_blind_open", + b"bf_eval_blind_bl", + opening, + ); + } + + transcript.append(&Label(b"bf_spartan")); + let outer_num_vars = log2_power_of_two("error row count", protocol.dimensions.error.row_count)? + + log2_power_of_two("error row length", protocol.dimensions.error.row_len)?; + if outer_num_vars == 0 { + return Err(ProverError::DegenerateSumcheck { + name: "outer folded R1CS sumcheck", + }); + } + let tau = transcript.challenge_vector(outer_num_vars); + let flattened_folded_witness = flatten(&folded_witness_rows); + let flattened_folded_error = flatten(&folded_error_rows); + let outer_trace = prove_outer_sumcheck( + &protocol.r1cs, + folded_u, + &flattened_folded_witness, + &flattened_folded_error, + &tau, + transcript, + )?; + + let (az_rx, bz_rx, cz_rx) = abc_at_point( + &protocol.r1cs, + folded_u, + &flattened_folded_witness, + &outer_trace.point, + ); + let error_row_vars = log2_power_of_two("error row count", protocol.dimensions.error.row_count)?; + let (error_row_point, error_entry_point) = outer_trace.point.split_at(error_row_vars); + let (error_opening, _) = row_committer.open_rows( + setup, + &folded_error_rows, + &folded_error_blindings, + error_row_point, + error_entry_point, + "folded error row opening", + )?; + + append_values(transcript, b"bf_az_bz_cz", &[az_rx, bz_rx, cz_rx]); + append_vector_opening( + transcript, + b"bf_error_opening", + b"bf_error_blind", + &error_opening, + ); + + let ra = transcript.challenge(); + let rb = transcript.challenge(); + let rc = transcript.challenge(); + let inner_num_vars = + log2_power_of_two("witness row count", protocol.dimensions.witness.row_count)? + + log2_power_of_two("witness row length", protocol.dimensions.witness.row_len)?; + if inner_num_vars == 0 { + return Err(ProverError::DegenerateSumcheck { + name: "inner folded R1CS sumcheck", + }); + } + let row_weights = EqPolynomial::::evals(&outer_trace.point, None); + let public = protocol + .r1cs + .public_column_contributions(&row_weights, 0, folded_u)?; + let inner_claim = ra * (az_rx - public.a) + rb * (bz_rx - public.b) + rc * (cz_rx - public.c); + let inner_trace = prove_inner_sumcheck( + &protocol.r1cs, + &outer_trace.point, + &folded_witness_rows, + ra, + rb, + rc, + inner_claim, + transcript, + )?; + let witness_row_vars = + log2_power_of_two("witness row count", protocol.dimensions.witness.row_count)?; + let (witness_row_point, witness_entry_point) = inner_trace.point.split_at(witness_row_vars); + let (witness_opening, _) = row_committer.open_rows( + setup, + &folded_witness_rows, + &folded_witness_blindings, + witness_row_point, + witness_entry_point, + "folded witness row opening", + )?; + + Ok(BlindFoldProof { + auxiliary_row_commitments, + random_round_commitments, + random_output_claim_row_commitments, + random_auxiliary_row_commitments, + random_error_row_commitments, + random_eval_commitments, + random_u, + cross_term_error_row_commitments, + outer_sumcheck: outer_trace.proof, + az_rx, + bz_rx, + cz_rx, + inner_sumcheck: inner_trace.proof, + witness_opening, + error_opening, + folded_eval_outputs, + folded_eval_blindings, + folded_eval_output_openings, + folded_eval_blinding_openings, + }) +} + +fn validate_witness( + setup: &VC::Setup, + protocol: &BlindFoldProtocol, + witness: BlindFoldWitness<'_, F>, +) -> Result<(), ProverError> +where + F: Field, + VC: VectorCommitment, +{ + let _ = log2_power_of_two("witness row count", protocol.dimensions.witness.row_count)?; + let _ = log2_power_of_two("witness row length", protocol.dimensions.witness.row_len)?; + let _ = log2_power_of_two("error row count", protocol.dimensions.error.row_count)?; + let _ = log2_power_of_two("error row length", protocol.dimensions.error.row_len)?; + ensure_row_capacity::(setup, "witness rows", protocol.dimensions.witness.row_len)?; + ensure_row_capacity::(setup, "error rows", protocol.dimensions.error.row_len)?; + ensure_row_capacity::(setup, "evaluation rows", 1)?; + ensure_len( + "witness rows", + protocol.dimensions.witness.row_count, + witness.rows.len(), + )?; + ensure_len( + "witness row blindings", + protocol.dimensions.witness.row_count, + witness.blindings.len(), + )?; + for (row, values) in witness.rows.iter().enumerate() { + if values.len() != protocol.dimensions.witness.row_len { + return Err(ProverError::WitnessRowLengthMismatch { + row, + expected: protocol.dimensions.witness.row_len, + actual: values.len(), + }); + } + } + ensure_len( + "final opening evaluation values", + protocol.eval_commitments.len(), + witness.eval_outputs.len(), + )?; + ensure_len( + "final opening blindings", + protocol.eval_commitments.len(), + witness.eval_blindings.len(), + )?; + Ok(()) +} + +fn ensure_row_capacity( + setup: &VC::Setup, + name: &'static str, + row_len: usize, +) -> Result<(), ProverError> +where + F: Field, + VC: VectorCommitment, +{ + let capacity = VC::capacity(setup); + if row_len > capacity { + return Err(ProverError::CommitmentCapacityExceeded { + name, + capacity, + row_len, + }); + } + Ok(()) +} + +fn commit_rows( + setup: &VC::Setup, + rows: &[Vec], + blindings: &[F], + name: &'static str, +) -> Result, ProverError> +where + F: Field, + VC: VectorCommitment, +{ + ensure_len(name, rows.len(), blindings.len())?; + let capacity = VC::capacity(setup); + for row in rows { + if row.len() > capacity { + return Err(ProverError::CommitmentCapacityExceeded { + name, + capacity, + row_len: row.len(), + }); + } + } + Ok(rows + .par_iter() + .zip(blindings.par_iter()) + .map(|(row, blinding)| VC::commit(setup, row, blinding)) + .collect()) +} + +fn open_committed_rows( + setup: &VC::Setup, + rows: &[Vec], + blindings: &[F], + row_point: &[F], + entry_point: &[F], + name: &'static str, +) -> Result<(VectorCommitmentOpening, F), ProverError> +where + F: Field, + VC: VectorCommitment, +{ + let row_count = basis_len_from_point_len("row point", row_point.len())?; + ensure_len(name, row_count, rows.len())?; + ensure_len(name, row_count, blindings.len())?; + let row_len = rows.first().map_or(0, Vec::len); + let expected_row_len = basis_len_from_point_len("entry point", entry_point.len())?; + ensure_len(name, expected_row_len, row_len)?; + ensure_row_capacity::(setup, name, row_len)?; + for (row_index, row) in rows.iter().enumerate() { + if row.len() != row_len { + return Err(ProverError::WitnessRowLengthMismatch { + row: row_index, + expected: row_len, + actual: row.len(), + }); + } + } + Ok(VC::open_committed_rows( + &flatten(rows), + blindings, + row_len, + row_point, + entry_point, + )?) +} + +fn basis_len_from_point_len( + name: &'static str, + point_len: usize, +) -> Result> +where + F: Field, +{ + if point_len >= usize::BITS as usize { + return Err(ProverError::DimensionOverflow { + name, + value: point_len, + }); + } + Ok(1_usize << point_len) +} + +#[derive(Clone, Debug)] +struct SumcheckTrace { + proof: CompressedSumcheckProof, + point: Vec, +} + +fn prove_outer_sumcheck( + r1cs: &ConstraintMatrices, + u: F, + witness: &[F], + error_values: &[F], + tau: &[F], + transcript: &mut T, +) -> Result, ProverError> +where + F: Field + AppendToTranscript, + T: Transcript, +{ + let num_vars = log2_power_of_two("outer folded R1CS sumcheck", error_values.len())?; + ensure_len("outer challenge vector", num_vars, tau.len())?; + + let z = z_vector(u, witness); + let mut az = matrix_vector_product(&r1cs.a, &z); + let mut bz = matrix_vector_product(&r1cs.b, &z); + let mut cz = matrix_vector_product(&r1cs.c, &z); + let mut e = error_values.to_vec(); + let padded_len = error_values.len(); + pad_to_len("outer Az values", &mut az, padded_len)?; + pad_to_len("outer Bz values", &mut bz, padded_len)?; + pad_to_len("outer Cz values", &mut cz, padded_len)?; + + let mut az = Polynomial::new(az); + let mut bz = Polynomial::new(bz); + let mut cz = Polynomial::new(cz); + let mut e = Polynomial::new(std::mem::take(&mut e)); + let mut eq_tau = Polynomial::new(EqPolynomial::::evals(tau, None)); + + let mut running_sum = F::zero(); + let mut rounds = Vec::with_capacity(num_vars); + let mut point = Vec::with_capacity(num_vars); + + for _round in 0..num_vars { + let half = az.len() / 2; + let mut evals = [F::zero(); OUTER_SUMCHECK_DEGREE + 1]; + for i in 0..half { + let (eq_lo, eq_hi) = eq_tau.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let (az_lo, az_hi) = az.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let (bz_lo, bz_hi) = bz.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let (cz_lo, cz_hi) = cz.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let (e_lo, e_hi) = e.sumcheck_eval_pair(i, BindingOrder::HighToLow); + + let eq_delta = eq_hi - eq_lo; + let az_delta = az_hi - az_lo; + let bz_delta = bz_hi - bz_lo; + let cz_delta = cz_hi - cz_lo; + let e_delta = e_hi - e_lo; + + evals[0] += eq_lo * (az_lo * bz_lo - u * cz_lo - e_lo); + evals[1] += eq_hi * (az_hi * bz_hi - u * cz_hi - e_hi); + + let eq_2 = eq_lo + eq_delta + eq_delta; + let az_2 = az_lo + az_delta + az_delta; + let bz_2 = bz_lo + bz_delta + bz_delta; + let cz_2 = cz_lo + cz_delta + cz_delta; + let e_2 = e_lo + e_delta + e_delta; + evals[2] += eq_2 * (az_2 * bz_2 - u * cz_2 - e_2); + + let eq_3 = eq_2 + eq_delta; + let az_3 = az_2 + az_delta; + let bz_3 = bz_2 + bz_delta; + let cz_3 = cz_2 + cz_delta; + let e_3 = e_2 + e_delta; + evals[3] += eq_3 * (az_3 * bz_3 - u * cz_3 - e_3); + } + + let round_poly = UnivariatePoly::from_evals(&evals); + let round_sum = + round_poly.coefficients()[0] + round_poly.coefficients().iter().copied().sum::(); + if round_sum != running_sum { + return Err(ProverError::SumcheckRoundClaimMismatch { + expected: running_sum, + actual: round_sum, + }); + } + let compressed = round_poly.compress(); + append_values( + transcript, + SUMCHECK_ROUND_TRANSCRIPT_LABEL, + compressed.coeffs_except_linear_term(), + ); + let challenge = transcript.challenge(); + running_sum = round_poly.evaluate(challenge); + az.bind_with_order(challenge, BindingOrder::HighToLow); + bz.bind_with_order(challenge, BindingOrder::HighToLow); + cz.bind_with_order(challenge, BindingOrder::HighToLow); + e.bind_with_order(challenge, BindingOrder::HighToLow); + eq_tau.bind_with_order(challenge, BindingOrder::HighToLow); + point.push(challenge); + rounds.push(compressed); + } + + Ok(SumcheckTrace { + proof: CompressedSumcheckProof { + round_polynomials: rounds, + }, + point, + }) +} + +#[expect( + clippy::too_many_arguments, + reason = "inner folded R1CS sumcheck is parameterized by three random matrix weights" +)] +fn prove_inner_sumcheck( + r1cs: &ConstraintMatrices, + outer_point: &[F], + witness_rows: &[Vec], + ra: F, + rb: F, + rc: F, + claim: F, + transcript: &mut T, +) -> Result, ProverError> +where + F: Field + AppendToTranscript, + T: Transcript, +{ + let witness_values = flatten(witness_rows); + let num_vars = log2_power_of_two("inner folded R1CS sumcheck", witness_values.len())?; + let row_weights = EqPolynomial::::evals(outer_point, None); + let l_w = + linear_form_project_columns(r1cs, &row_weights, 1, witness_values.len(), [ra, rb, rc])?; + + let mut l_w = Polynomial::new(l_w); + let mut witness = Polynomial::new(witness_values); + let mut running_sum = claim; + let mut rounds = Vec::with_capacity(num_vars); + let mut point = Vec::with_capacity(num_vars); + + for _round in 0..num_vars { + let half = l_w.len() / 2; + let mut evals = [F::zero(); INNER_SUMCHECK_DEGREE + 1]; + for i in 0..half { + let (lw_lo, lw_hi) = l_w.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let (w_lo, w_hi) = witness.sumcheck_eval_pair(i, BindingOrder::HighToLow); + let lw_delta = lw_hi - lw_lo; + let w_delta = w_hi - w_lo; + + evals[0] += lw_lo * w_lo; + evals[1] += lw_hi * w_hi; + + let lw_2 = lw_lo + lw_delta + lw_delta; + let w_2 = w_lo + w_delta + w_delta; + evals[2] += lw_2 * w_2; + } + + let round_poly = UnivariatePoly::from_evals(&evals); + let round_sum = + round_poly.coefficients()[0] + round_poly.coefficients().iter().copied().sum::(); + if round_sum != running_sum { + return Err(ProverError::SumcheckRoundClaimMismatch { + expected: running_sum, + actual: round_sum, + }); + } + let compressed = round_poly.compress(); + append_values( + transcript, + INNER_SUMCHECK_LABEL, + compressed.coeffs_except_linear_term(), + ); + let challenge = transcript.challenge(); + running_sum = round_poly.evaluate(challenge); + l_w.bind_with_order(challenge, BindingOrder::HighToLow); + witness.bind_with_order(challenge, BindingOrder::HighToLow); + point.push(challenge); + rounds.push(compressed); + } + + Ok(SumcheckTrace { + proof: CompressedSumcheckProof { + round_polynomials: rounds, + }, + point, + }) +} + +fn matrix_vector_product(rows: &[SparseRow], vector: &[F]) -> Vec +where + F: Field, +{ + rows.par_iter().map(|row| dot(row, vector)).collect() +} + +fn linear_form_project_columns( + r1cs: &ConstraintMatrices, + row_weights: &[F], + start_col: usize, + col_count: usize, + weights: [F; 3], +) -> Result, ProverError> +where + F: Field, +{ + if row_weights.len() < r1cs.num_constraints { + return Err(ConstraintMatrixEvalError::RowWeightsLengthMismatch { + expected: r1cs.num_constraints, + actual: row_weights.len(), + } + .into()); + } + let end_col = + start_col + .checked_add(col_count) + .ok_or(ConstraintMatrixEvalError::ColumnRangeOverflow { + start: start_col, + count: col_count, + })?; + + let mut projected = vec![F::zero(); col_count]; + project_matrix_columns( + &mut projected, + &r1cs.a, + row_weights, + start_col, + end_col, + weights[0], + ); + project_matrix_columns( + &mut projected, + &r1cs.b, + row_weights, + start_col, + end_col, + weights[1], + ); + project_matrix_columns( + &mut projected, + &r1cs.c, + row_weights, + start_col, + end_col, + weights[2], + ); + Ok(projected) +} + +fn project_matrix_columns( + projected: &mut [F], + rows: &[SparseRow], + row_weights: &[F], + start_col: usize, + end_col: usize, + weight: F, +) where + F: Field, +{ + if weight.is_zero() { + return; + } + for (row, &row_weight) in rows.iter().zip(row_weights) { + let scaled_weight = weight * row_weight; + for &(column, coefficient) in row { + if (start_col..end_col).contains(&column) { + projected[column - start_col] += scaled_weight * coefficient; + } + } + } +} + +fn abc_at_point(r1cs: &ConstraintMatrices, u: F, witness: &[F], point: &[F]) -> (F, F, F) +where + F: Field, +{ + let row_weights = EqPolynomial::::evals(point, None); + let z = z_vector(u, witness); + let mut az = F::zero(); + let mut bz = F::zero(); + let mut cz = F::zero(); + for (row_index, &row_weight) in row_weights.iter().enumerate().take(r1cs.num_constraints) { + az += row_weight * dot(&r1cs.a[row_index], &z); + bz += row_weight * dot(&r1cs.b[row_index], &z); + cz += row_weight * dot(&r1cs.c[row_index], &z); + } + (az, bz, cz) +} + +fn open_witness_coordinate( + setup: &VC::Setup, + row_committer: &mut C, + witness_rows: &[Vec], + witness_blindings: &[F], + coordinate: WitnessCoordinate, + name: &'static str, +) -> Result<(VectorCommitmentOpening, F), ProverError> +where + F: Field, + VC: VectorCommitment, + C: BlindFoldRowCommitter, +{ + let row_vars = log2_power_of_two("witness row count", witness_rows.len())?; + let entry_vars = log2_power_of_two("witness row length", witness_rows[0].len())?; + row_committer.open_rows( + setup, + witness_rows, + witness_blindings, + &boolean_point(coordinate.row, row_vars), + &boolean_point(coordinate.column, entry_vars), + name, + ) +} + +#[derive(Clone, Copy, Debug)] +struct RelaxedInstanceLabels { + u: &'static [u8], + witness: &'static [u8], + error: &'static [u8], + eval: &'static [u8], +} + +fn append_relaxed_instance( + transcript: &mut T, + labels: RelaxedInstanceLabels, + u: F, + witness_commitments: &[C], + error_commitments: &[C], + eval_commitments: &[C], +) where + F: AppendToTranscript, + C: AppendToTranscript, + T: Transcript, +{ + transcript.append(&Label(labels.u)); + u.append_to_transcript(transcript); + append_values(transcript, labels.witness, witness_commitments); + append_values(transcript, labels.error, error_commitments); + append_values(transcript, labels.eval, eval_commitments); +} + +fn append_values(transcript: &mut T, label: &'static [u8], values: &[A]) +where + A: AppendToTranscript, + T: Transcript, +{ + transcript.append(&LabelWithCount(label, values.len() as u64)); + for value in values { + value.append_to_transcript(transcript); + } +} + +fn append_vector_opening( + transcript: &mut T, + row_label: &'static [u8], + blinding_label: &'static [u8], + opening: &VectorCommitmentOpening, +) where + F: AppendToTranscript, + T: Transcript, +{ + append_values(transcript, row_label, &opening.combined_vector); + transcript.append(&Label(blinding_label)); + opening.combined_blinding.append_to_transcript(transcript); +} + +fn random_rows(row_count: usize, row_len: usize, rng: &mut R) -> Vec> +where + F: Field, + R: RngCore, +{ + (0..row_count) + .map(|_| (0..row_len).map(|_| F::random(rng)).collect()) + .collect() +} + +fn zero_rows(row_count: usize, row_len: usize) -> Vec> { + vec![vec![F::zero(); row_len]; row_count] +} + +fn fold_rows( + real: &[Vec], + random: &[Vec], + challenge: F, +) -> Result>, ProverError> +where + F: Field, +{ + ensure_len("random witness rows", real.len(), random.len())?; + let mut folded = Vec::with_capacity(real.len()); + for (row_index, (real_row, random_row)) in real.iter().zip(random).enumerate() { + if real_row.len() != random_row.len() { + return Err(ProverError::WitnessRowLengthMismatch { + row: row_index, + expected: real_row.len(), + actual: random_row.len(), + }); + } + folded.push( + real_row + .iter() + .zip(random_row) + .map(|(&real, &random)| real + challenge * random) + .collect(), + ); + } + Ok(folded) +} + +fn fold_scalars( + name: &'static str, + real: &[F], + random: &[F], + challenge: F, +) -> Result, ProverError> +where + F: Field, +{ + ensure_len(name, real.len(), random.len())?; + Ok(real + .iter() + .zip(random) + .map(|(&real, &random)| real + challenge * random) + .collect()) +} + +fn fold_error_rows( + real: &[Vec], + cross: &[Vec], + random: &[Vec], + challenge: F, +) -> Result>, ProverError> +where + F: Field, +{ + ensure_len("cross-term error rows", real.len(), cross.len())?; + ensure_len("random error rows", real.len(), random.len())?; + let challenge_squared = challenge * challenge; + let mut folded = Vec::with_capacity(real.len()); + for (row_index, ((real_row, cross_row), random_row)) in + real.iter().zip(cross).zip(random).enumerate() + { + if real_row.len() != cross_row.len() { + return Err(ProverError::WitnessRowLengthMismatch { + row: row_index, + expected: real_row.len(), + actual: cross_row.len(), + }); + } + if real_row.len() != random_row.len() { + return Err(ProverError::WitnessRowLengthMismatch { + row: row_index, + expected: real_row.len(), + actual: random_row.len(), + }); + } + folded.push( + real_row + .iter() + .zip(cross_row) + .zip(random_row) + .map(|((&real, &cross), &random)| { + real + challenge * cross + challenge_squared * random + }) + .collect(), + ); + } + Ok(folded) +} + +fn fold_error_scalars( + name: &'static str, + real: &[F], + cross: &[F], + random: &[F], + challenge: F, +) -> Result, ProverError> +where + F: Field, +{ + ensure_len(name, real.len(), cross.len())?; + ensure_len(name, real.len(), random.len())?; + let challenge_squared = challenge * challenge; + Ok(real + .iter() + .zip(cross) + .zip(random) + .map(|((&real, &cross), &random)| real + challenge * cross + challenge_squared * random) + .collect()) +} + +fn error_rows_for( + r1cs: &ConstraintMatrices, + u: F, + witness: &[F], + row_count: usize, + row_len: usize, +) -> Result>, ProverError> +where + F: Field, +{ + let _ = log2_power_of_two("error row length", row_len)?; + let target_len = row_count + .checked_mul(row_len) + .ok_or(ProverError::DimensionOverflow { + name: "error values", + value: row_count, + })?; + let z = z_vector(u, witness); + let mut errors = (0..r1cs.num_constraints) + .map(|row_index| { + dot(&r1cs.a[row_index], &z) * dot(&r1cs.b[row_index], &z) + - u * dot(&r1cs.c[row_index], &z) + }) + .collect::>(); + pad_to_len("error values", &mut errors, target_len)?; + Ok(errors.chunks(row_len).map(<[F]>::to_vec).collect()) +} + +fn cross_term_error_rows_for( + r1cs: &ConstraintMatrices, + real_u: F, + real_witness: &[F], + random_u: F, + random_witness: &[F], + row_count: usize, + row_len: usize, +) -> Result>, ProverError> +where + F: Field, +{ + let _ = log2_power_of_two("error row length", row_len)?; + let target_len = row_count + .checked_mul(row_len) + .ok_or(ProverError::DimensionOverflow { + name: "cross-term error values", + value: row_count, + })?; + let real_z = z_vector(real_u, real_witness); + let random_z = z_vector(random_u, random_witness); + let mut errors = (0..r1cs.num_constraints) + .map(|row_index| { + dot(&r1cs.a[row_index], &real_z) * dot(&r1cs.b[row_index], &random_z) + + dot(&r1cs.a[row_index], &random_z) * dot(&r1cs.b[row_index], &real_z) + - real_u * dot(&r1cs.c[row_index], &random_z) + - random_u * dot(&r1cs.c[row_index], &real_z) + }) + .collect::>(); + pad_to_len("cross-term error values", &mut errors, target_len)?; + Ok(errors.chunks(row_len).map(<[F]>::to_vec).collect()) +} + +fn boolean_point(index: usize, num_vars: usize) -> Vec +where + F: Field, +{ + (0..num_vars) + .map(|bit| { + let shift = num_vars - bit - 1; + F::from_u64(((index >> shift) & 1) as u64) + }) + .collect() +} + +fn pad_to_len( + name: &'static str, + values: &mut Vec, + target_len: usize, +) -> Result<(), ProverError> +where + F: Field, +{ + if values.len() > target_len { + return Err(ProverError::LengthMismatch { + name, + expected: target_len, + actual: values.len(), + }); + } + values.resize(target_len, F::zero()); + Ok(()) +} + +fn z_vector(u: F, witness: &[F]) -> Vec +where + F: Field, +{ + let mut z = Vec::with_capacity(witness.len() + 1); + z.push(u); + z.extend_from_slice(witness); + z +} + +fn dot(row: &[(usize, F)], witness: &[F]) -> F +where + F: Field, +{ + row.iter() + .map(|&(column, coefficient)| coefficient * witness[column]) + .sum() +} + +fn flatten(rows: &[Vec]) -> Vec +where + F: Field, +{ + rows.iter().flat_map(|row| row.iter().copied()).collect() +} + +fn ensure_len(name: &'static str, expected: usize, actual: usize) -> Result<(), ProverError> +where + F: Field, +{ + if expected != actual { + return Err(ProverError::LengthMismatch { + name, + expected, + actual, + }); + } + Ok(()) +} + +fn log2_power_of_two(name: &'static str, value: usize) -> Result> +where + F: Field, +{ + if value == 0 || !value.is_power_of_two() { + return Err(ProverError::InvalidPowerOfTwo { name, value }); + } + Ok(value.trailing_zeros() as usize) +} diff --git a/crates/jolt-blindfold/src/r1cs.rs b/crates/jolt-blindfold/src/r1cs.rs index c237f86f72..b4851dfcfc 100644 --- a/crates/jolt-blindfold/src/r1cs.rs +++ b/crates/jolt-blindfold/src/r1cs.rs @@ -53,196 +53,196 @@ pub struct FinalOpeningLayout { pub blinding: Option, } -impl BlindFoldStatement +pub fn build_with_sources( + builder: &mut R1csBuilder, + statement: &BlindFoldStatement, + publics: &[(P, F)], + challenges: &[(Ch, F)], +) -> Result where F: Field, O: Clone + PartialEq, P: Clone + PartialEq, Ch: Clone + PartialEq, { - pub fn build_with_sources( - &self, - builder: &mut R1csBuilder, - publics: &[(P, F)], - challenges: &[(Ch, F)], - ) -> Result { - let layout = self.allocate_layout(builder)?; - let mut claim_sources = ClaimSourceTable::::new(); - insert_output_claim_sources(self, &layout, &mut claim_sources)?; - for (id, value) in publics { - claim_sources.insert_public(id.clone(), *value); - } - for (id, value) in challenges { - claim_sources.insert_challenge(id.clone(), *value); - } - self.append(builder, &layout, &mut claim_sources)?; - Ok(layout) + let layout = allocate_layout(builder, statement)?; + let mut claim_sources = ClaimSourceTable::::new(); + insert_output_claim_sources(statement, &layout, &mut claim_sources)?; + for (id, value) in publics { + claim_sources.insert_public(id.clone(), *value); } + for (id, value) in challenges { + claim_sources.insert_challenge(id.clone(), *value); + } + append(builder, statement, &layout, &mut claim_sources)?; + Ok(layout) } -impl BlindFoldStatement +pub fn build( + builder: &mut R1csBuilder, + statement: &BlindFoldStatement, + claim_sources: &mut R, +) -> Result where F: Field, + R: ClaimSources, { - pub fn build( - &self, - builder: &mut R1csBuilder, - claim_sources: &mut R, - ) -> Result - where - R: ClaimSources, - { - let layout = self.allocate_layout(builder)?; - self.append(builder, &layout, claim_sources)?; - Ok(layout) - } - - pub fn append( - &self, - builder: &mut R1csBuilder, - layout: &Layout, - claim_sources: &mut R, - ) -> Result<(), Error> - where - R: ClaimSources, - { - validate_stage_count(self, layout)?; - validate_final_opening_count(self, layout)?; - - for (stage_index, (stage, stage_layout)) in - self.stages.iter().zip(&layout.stages).enumerate() - { - assert_claim_expr_eq( - builder, - &stage.input_claim, - stage_layout.sumcheck.input_claim, - claim_sources, - )?; + let layout = allocate_layout(builder, statement)?; + append(builder, statement, &layout, claim_sources)?; + Ok(layout) +} - append_sumcheck_r1cs_constraints_for_domain( - builder, - stage.statement, - &stage.consistency.rounds, - &stage_layout.sumcheck, - stage.domain, - ) - .map_err(|source| Error::Sumcheck { - stage_index, - source, - })?; - - assert_claim_expr_eq( - builder, - &stage.output_claim, - stage_layout.sumcheck.output_claim, - claim_sources, - )?; - } +pub fn append( + builder: &mut R1csBuilder, + statement: &BlindFoldStatement, + layout: &Layout, + claim_sources: &mut R, +) -> Result<(), Error> +where + F: Field, + R: ClaimSources, +{ + validate_stage_count(statement, layout)?; + validate_final_opening_count(statement, layout)?; - for (binding, binding_layout) in self.final_openings.iter().zip(&layout.final_openings) { - let Some(evaluation) = binding_layout.evaluation else { - continue; - }; - let mut combined = LinearCombination::zero(); - for (opening_id, &coefficient) in binding.opening_ids.iter().zip(&binding.coefficients) - { - combined = combined - + claim_sources - .opening(opening_id)? - .into_linear_combination() - .scale(coefficient); - } - builder.assert_equal(combined, evaluation); + for (stage_index, (stage, stage_layout)) in + statement.stages.iter().zip(&layout.stages).enumerate() + { + assert_claim_expr_eq( + builder, + &stage.input_claim, + stage_layout.sumcheck.input_claim, + claim_sources, + )?; + + append_sumcheck_r1cs_constraints_for_domain( + builder, + stage.statement, + &stage.consistency.rounds, + &stage_layout.sumcheck, + stage.domain, + ) + .map_err(|source| Error::Sumcheck { + stage_index, + source, + })?; + + assert_claim_expr_eq( + builder, + &stage.output_claim, + stage_layout.sumcheck.output_claim, + claim_sources, + )?; + } + + for (binding, binding_layout) in statement.final_openings.iter().zip(&layout.final_openings) { + let Some(evaluation) = binding_layout.evaluation else { + continue; + }; + let mut combined = LinearCombination::zero(); + for (opening_id, &coefficient) in binding.opening_ids.iter().zip(&binding.coefficients) { + combined = combined + + claim_sources + .opening(opening_id)? + .into_linear_combination() + .scale(coefficient); } - - Ok(()) + builder.assert_equal(combined, evaluation); } - pub fn allocate_layout(&self, builder: &mut R1csBuilder) -> Result { - let witness_row_len = witness_row_len(self)?; - - let coefficients = self - .stages - .iter() - .enumerate() - .map(|(stage_index, stage)| { - validate_stage_statement(stage.statement, &stage.consistency.rounds).map_err( - |source| LayoutError::Sumcheck { - stage_index, - source, - }, - )?; - Ok(stage - .consistency - .rounds - .iter() - .map(|round| { - let coefficients = (0..=round.degree) - .map(|_| builder.alloc_unknown()) - .collect::>(); - for _ in coefficients.len()..witness_row_len { - let _ = builder.alloc(F::zero()); - } - coefficients - }) - .collect::>()) - }) - .collect::, _>>()?; + Ok(()) +} - let output_claim_rows = self - .stages - .iter() - .map(|stage| allocate_output_claim_rows(builder, stage, witness_row_len)) - .collect::>(); +pub fn allocate_layout( + builder: &mut R1csBuilder, + statement: &BlindFoldStatement, +) -> Result +where + F: Field, +{ + let witness_row_len = witness_row_len(statement)?; - let stages = self - .stages - .iter() - .zip(coefficients) - .zip(output_claim_rows) - .map(|((stage, stage_coefficients), output_claim_rows)| { - let input_claim = builder.alloc_unknown(); - let mut claim_in = input_claim; - let mut rounds = Vec::with_capacity(stage.consistency.rounds.len()); - - for coefficients in stage_coefficients { - let claim_out = builder.alloc_unknown(); - rounds.push(SumcheckR1csRoundLayout { - claim_in, - coefficients, - claim_out, - }); - claim_in = claim_out; - } - - Ok(StageLayout { - output_claim_rows, - sumcheck: SumcheckR1csLayout { - input_claim, - rounds, - output_claim: claim_in, - }, + let coefficients = statement + .stages + .iter() + .enumerate() + .map(|(stage_index, stage)| { + validate_stage_statement(stage.statement, &stage.consistency.rounds).map_err( + |source| LayoutError::Sumcheck { + stage_index, + source, + }, + )?; + Ok(stage + .consistency + .rounds + .iter() + .map(|round| { + let coefficients = (0..=round.degree) + .map(|_| builder.alloc_unknown()) + .collect::>(); + for _ in coefficients.len()..witness_row_len { + let _ = builder.alloc(F::zero()); + } + coefficients }) - }) - .collect::, LayoutError>>()?; + .collect::>()) + }) + .collect::, _>>()?; - let final_openings = self - .final_openings - .iter() - .map(|binding| FinalOpeningLayout { - evaluation: (!binding.opening_ids.is_empty()) - .then(|| allocate_private_row_scalar(builder, witness_row_len)), - blinding: (!binding.opening_ids.is_empty()) - .then(|| allocate_private_row_scalar(builder, witness_row_len)), + let output_claim_rows = statement + .stages + .iter() + .map(|stage| allocate_output_claim_rows(builder, stage, witness_row_len)) + .collect::>(); + + let stages = statement + .stages + .iter() + .zip(coefficients) + .zip(output_claim_rows) + .map(|((stage, stage_coefficients), output_claim_rows)| { + let input_claim = builder.alloc_unknown(); + let mut claim_in = input_claim; + let mut rounds = Vec::with_capacity(stage.consistency.rounds.len()); + + for coefficients in stage_coefficients { + let claim_out = builder.alloc_unknown(); + rounds.push(SumcheckR1csRoundLayout { + claim_in, + coefficients, + claim_out, + }); + claim_in = claim_out; + } + + Ok(StageLayout { + output_claim_rows, + sumcheck: SumcheckR1csLayout { + input_claim, + rounds, + output_claim: claim_in, + }, }) - .collect(); + }) + .collect::, LayoutError>>()?; - Ok(Layout { - witness_row_len, - stages, - final_openings, + let final_openings = statement + .final_openings + .iter() + .map(|binding| FinalOpeningLayout { + evaluation: (!binding.opening_ids.is_empty()) + .then(|| allocate_private_row_scalar(builder, witness_row_len)), + blinding: (!binding.opening_ids.is_empty()) + .then(|| allocate_private_row_scalar(builder, witness_row_len)), }) - } + .collect(); + + Ok(Layout { + witness_row_len, + stages, + final_openings, + }) } fn allocate_private_row_scalar( @@ -519,9 +519,7 @@ mod tests { let statement = BlindFoldStatement::new(vec![empty_stage(2, 3, &[1, 3])], Vec::new()); let mut builder = R1csBuilder::::new(); - let layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); + let layout = allocate_layout(&mut builder, &statement).expect("layout allocates"); assert_eq!(layout.stage_count(), 1); let stage = &layout.stages[0]; @@ -553,8 +551,7 @@ mod tests { final_openings: Vec::new(), }; - let error = statement - .append(&mut builder, &layout, &mut sources) + let error = append(&mut builder, &statement, &layout, &mut sources) .expect_err("stage counts differ"); assert_eq!( @@ -571,9 +568,7 @@ mod tests { let statement = BlindFoldStatement::new(vec![empty_stage(2, 2, &[2])], Vec::new()); let mut builder = R1csBuilder::::new(); - let error = statement - .allocate_layout(&mut builder) - .expect_err("round counts differ"); + let error = allocate_layout(&mut builder, &statement).expect_err("round counts differ"); assert_eq!( error, @@ -592,9 +587,7 @@ mod tests { let statement = BlindFoldStatement::new(vec![empty_stage(1, 2, &[3])], Vec::new()); let mut builder = R1csBuilder::::new(); - let error = statement - .allocate_layout(&mut builder) - .expect_err("degree exceeds bound"); + let error = allocate_layout(&mut builder, &statement).expect_err("degree exceeds bound"); assert_eq!( error, @@ -629,9 +622,8 @@ mod tests { Vec::new(), ); - let layout = statement - .build(&mut builder, &mut sources) - .expect("constraints should build"); + let layout = + build(&mut builder, &statement, &mut sources).expect("constraints should build"); let stage_layout = &layout.stages[0].sumcheck; assign(&mut builder, stage_layout.input_claim, 10); @@ -666,8 +658,7 @@ mod tests { ); let mut builder = R1csBuilder::::new(); - let layout = statement - .build_with_sources(&mut builder, &[], &[]) + let layout = build_with_sources(&mut builder, &statement, &[], &[]) .expect("constraints should build"); let stage_layout = &layout.stages[0].sumcheck; assign(&mut builder, stage_layout.input_claim, 10); @@ -699,8 +690,7 @@ mod tests { ); let mut sources = ClaimSourceTable::::new(); - let layout = statement - .build(&mut builder, &mut sources) + let layout = build(&mut builder, &statement, &mut sources) .expect("constraints should build for centered domain"); let stage_layout = &layout.stages[0].sumcheck; assign(&mut builder, stage_layout.input_claim, 26); @@ -725,9 +715,7 @@ mod tests { let mut builder = R1csBuilder::::new(); let mut sources = ClaimSourceTable::::new(); - let error = statement - .build(&mut builder, &mut sources) - .expect_err("opening is missing"); + let error = build(&mut builder, &statement, &mut sources).expect_err("opening is missing"); assert_eq!(error, Error::Claim(ClaimLoweringError::MissingOpening)); } @@ -745,15 +733,12 @@ mod tests { Vec::new(), ); let mut builder = R1csBuilder::::new(); - let mut layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); + let mut layout = allocate_layout(&mut builder, &statement).expect("layout allocates"); layout.stages[0].sumcheck.rounds[0].claim_in = layout.stages[0].sumcheck.rounds[0].claim_out; let mut sources = ClaimSourceTable::::new(); - let error = statement - .append(&mut builder, &layout, &mut sources) + let error = append(&mut builder, &statement, &layout, &mut sources) .expect_err("layout claim chain is broken"); assert_eq!( diff --git a/crates/jolt-blindfold/src/verify.rs b/crates/jolt-blindfold/src/verify.rs index d6a81baf51..ac09e6a355 100644 --- a/crates/jolt-blindfold/src/verify.rs +++ b/crates/jolt-blindfold/src/verify.rs @@ -3,7 +3,7 @@ use jolt_field::{Field, FieldCore, RingAccumulator, WithAccumulator}; use jolt_poly::EqPolynomial; use jolt_r1cs::{ConstraintMatrices, MatrixColumnContributions}; use jolt_sumcheck::{BooleanHypercube, SumcheckClaim, SUMCHECK_ROUND_TRANSCRIPT_LABEL}; -use jolt_transcript::{AppendToTranscript, Label, Transcript}; +use jolt_transcript::{AppendToTranscript, Label, LabelWithCount, Transcript}; use crate::{ BlindFoldProof, BlindFoldProtocol, RelaxedError, RelaxedInstance, VerificationError, @@ -14,481 +14,433 @@ const OUTER_SUMCHECK_DEGREE: usize = 3; const INNER_SUMCHECK_DEGREE: usize = 2; const INNER_SUMCHECK_LABEL: &[u8] = b"inner_sumcheck_poly"; -impl BlindFoldProtocol +pub fn verify( + protocol: &BlindFoldProtocol, + proof: &BlindFoldProof, + vc_setup: &VC::Setup, + transcript: &mut T, +) -> Result<(), VerificationError> where F: Field + AppendToTranscript, - Com: Copy + HomomorphicCommitment + AppendToTranscript, + VC: VectorCommitment, + VC::Output: Copy + HomomorphicCommitment + AppendToTranscript, + T: Transcript, ::Accumulator: RingAccumulator, { - pub fn verify( - &self, - proof: &BlindFoldProof, - vc_setup: &VC::Setup, - transcript: &mut T, - ) -> Result<(), VerificationError> - where - VC: VectorCommitment, - T: Transcript, - { - let folded = self.folded_instance_from_proof(proof, transcript)?; - ensure_len( - "folded eval outputs", - folded.eval_commitments.len(), - proof.folded_eval_outputs.len(), - )?; - ensure_len( - "folded eval blindings", - folded.eval_commitments.len(), - proof.folded_eval_blindings.len(), - )?; - proof.verify_folded_eval_commitments::(vc_setup, &folded)?; - self.verify_folded_eval_witness_bindings::(proof, vc_setup, &folded, transcript)?; - let outer = self.verify_outer_folded_r1cs::(proof, vc_setup, &folded, transcript)?; - self.verify_inner_folded_r1cs::(proof, vc_setup, &folded, &outer, transcript)?; - Ok(()) - } + let folded = folded_instance_from_proof(protocol, proof, transcript)?; + ensure_len( + "folded eval outputs", + folded.eval_commitments.len(), + proof.folded_eval_outputs.len(), + )?; + ensure_len( + "folded eval blindings", + folded.eval_commitments.len(), + proof.folded_eval_blindings.len(), + )?; + verify_folded_eval_commitments::(proof, vc_setup, &folded)?; + verify_folded_eval_witness_bindings::( + protocol, proof, vc_setup, &folded, transcript, + )?; + let outer = + verify_outer_folded_r1cs::(protocol, proof, vc_setup, &folded, transcript)?; + verify_inner_folded_r1cs::(protocol, proof, vc_setup, &folded, &outer, transcript)?; + Ok(()) } -impl BlindFoldProtocol +fn folded_instance_from_proof( + protocol: &BlindFoldProtocol, + proof: &BlindFoldProof, + transcript: &mut T, +) -> Result, VerificationError> where F: Field + AppendToTranscript, - Com: Clone + HomomorphicCommitment + AppendToTranscript, + C: Clone + HomomorphicCommitment + AppendToTranscript, + T: Transcript, { - fn folded_instance_from_proof( - &self, - proof: &BlindFoldProof, - transcript: &mut T, - ) -> Result, VerificationError> - where - T: Transcript, - { - let committed = self.committed_relaxed_instance(&proof.auxiliary_row_commitments)?; - committed.append_to_transcript( - transcript, - b"bf_committed_u", - b"bf_committed_w", - b"bf_committed_e", - b"bf_committed_eval", - ); - - let random = self.random_relaxed_instance( - &proof.random_round_commitments, - &proof.random_output_claim_row_commitments, - &proof.random_auxiliary_row_commitments, - &proof.random_error_row_commitments, - &proof.random_eval_commitments, - proof.random_u, - )?; - random.append_to_transcript( - transcript, - b"bf_random_u", - b"bf_random_w", - b"bf_random_e", - b"bf_random_eval", - ); - - self.validate_cross_term_error_rows(&proof.cross_term_error_row_commitments)?; - transcript.append_values(b"bf_cross_e", &proof.cross_term_error_row_commitments); + let committed = protocol.committed_relaxed_instance(&proof.auxiliary_row_commitments)?; + append_relaxed_instance( + transcript, + b"bf_committed_u", + b"bf_committed_w", + b"bf_committed_e", + b"bf_committed_eval", + &committed, + ); + + let random = protocol.random_relaxed_instance( + &proof.random_round_commitments, + &proof.random_output_claim_row_commitments, + &proof.random_auxiliary_row_commitments, + &proof.random_error_row_commitments, + &proof.random_eval_commitments, + proof.random_u, + )?; + append_relaxed_instance( + transcript, + b"bf_random_u", + b"bf_random_w", + b"bf_random_e", + b"bf_random_eval", + &random, + ); + + protocol.validate_cross_term_error_rows(&proof.cross_term_error_row_commitments)?; + append_values( + transcript, + b"bf_cross_e", + &proof.cross_term_error_row_commitments, + ); + + let folding_challenge = transcript.challenge(); + Ok(committed.fold( + &random, + &proof.cross_term_error_row_commitments, + folding_challenge, + )?) +} - let folding_challenge = transcript.challenge(); - Ok(committed.fold( - &random, - &proof.cross_term_error_row_commitments, - folding_challenge, - )?) - } +fn append_relaxed_instance( + transcript: &mut T, + u_label: &'static [u8], + witness_label: &'static [u8], + error_label: &'static [u8], + eval_label: &'static [u8], + instance: &RelaxedInstance, +) where + F: AppendToTranscript, + C: AppendToTranscript, + T: Transcript, +{ + transcript.append(&Label(u_label)); + instance.u.append_to_transcript(transcript); + append_values(transcript, witness_label, &instance.witness_row_commitments); + append_values(transcript, error_label, &instance.error_row_commitments); + append_values(transcript, eval_label, &instance.eval_commitments); } -impl RelaxedInstance +fn append_values(transcript: &mut T, label: &'static [u8], values: &[A]) where - F: AppendToTranscript, - Com: AppendToTranscript, + A: AppendToTranscript, + T: Transcript, { - fn append_to_transcript( - &self, - transcript: &mut T, - u_label: &'static [u8], - witness_label: &'static [u8], - error_label: &'static [u8], - eval_label: &'static [u8], - ) where - T: Transcript, - { - transcript.append(&Label(u_label)); - self.u.append_to_transcript(transcript); - transcript.append_values(witness_label, &self.witness_row_commitments); - transcript.append_values(error_label, &self.error_row_commitments); - transcript.append_values(eval_label, &self.eval_commitments); + transcript.append(&LabelWithCount(label, values.len() as u64)); + for value in values { + value.append_to_transcript(transcript); } } -impl BlindFoldProtocol +fn verify_outer_folded_r1cs( + protocol: &BlindFoldProtocol, + proof: &BlindFoldProof, + vc_setup: &VC::Setup, + folded: &RelaxedInstance, + transcript: &mut T, +) -> Result, VerificationError> where F: Field + AppendToTranscript, - Com: Copy + HomomorphicCommitment + AppendToTranscript, + VC: VectorCommitment, + VC::Output: Copy + HomomorphicCommitment + AppendToTranscript, + T: Transcript, ::Accumulator: RingAccumulator, { - fn verify_outer_folded_r1cs( - &self, - proof: &BlindFoldProof, - vc_setup: &VC::Setup, - folded: &RelaxedInstance, - transcript: &mut T, - ) -> Result, VerificationError> - where - VC: VectorCommitment, - T: Transcript, - { - let error_row_count = self.dimensions.error.row_count; - if error_row_count == 0 || !error_row_count.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "error row count", - value: error_row_count, - }); - } - let row_vars = error_row_count.trailing_zeros() as usize; - - let error_row_len = self.dimensions.error.row_len; - if error_row_len == 0 || !error_row_len.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "error row length", - value: error_row_len, - }); - } - let entry_vars = error_row_len.trailing_zeros() as usize; - let num_vars = - row_vars - .checked_add(entry_vars) - .ok_or(VerificationError::InvalidPowerOfTwo { - name: "outer sumcheck dimension", - value: usize::MAX, - })?; - if num_vars == 0 { - return Err(VerificationError::DegenerateSumcheck { - name: "outer folded R1CS sumcheck", - }); - } - - transcript.append(&Label(b"bf_spartan")); - let tau = transcript.challenge_vector(num_vars); - let claim = SumcheckClaim::new(num_vars, OUTER_SUMCHECK_DEGREE, F::zero()); - let outer = proof - .outer_sumcheck - .verify( - &claim, - BooleanHypercube, - SUMCHECK_ROUND_TRANSCRIPT_LABEL, - transcript, - ) - .map_err(|source| VerificationError::OuterSumcheck { source })?; - - let (row_point, entry_point) = outer.point.split_at(row_vars); - let e_rx = VC::verify_committed_rows( - vc_setup, - &folded.error_row_commitments, - row_point, - entry_point, - &proof.error_opening, - )?; - - let eq_tau_rx = EqPolynomial::::mle(&tau, &outer.point); - let expected = eq_tau_rx * (proof.az_rx * proof.bz_rx - folded.u * proof.cz_rx - e_rx); - if outer.value != expected { - return Err(VerificationError::OuterFinalClaimMismatch { - expected, - actual: outer.value, - }); - } + let row_vars = log2_power_of_two::("error row count", protocol.dimensions.error.row_count)?; + let entry_vars = log2_power_of_two::("error row length", protocol.dimensions.error.row_len)?; + let num_vars = + row_vars + .checked_add(entry_vars) + .ok_or(VerificationError::InvalidPowerOfTwo { + name: "outer sumcheck dimension", + value: usize::MAX, + })?; + if num_vars == 0 { + return Err(VerificationError::DegenerateSumcheck { + name: "outer folded R1CS sumcheck", + }); + } - transcript.append_values(b"bf_az_bz_cz", &[proof.az_rx, proof.bz_rx, proof.cz_rx]); - append_vector_opening( + transcript.append(&Label(b"bf_spartan")); + let tau = transcript.challenge_vector(num_vars); + let claim = SumcheckClaim::new(num_vars, OUTER_SUMCHECK_DEGREE, F::zero()); + let outer = proof + .outer_sumcheck + .verify( + &claim, + BooleanHypercube, + SUMCHECK_ROUND_TRANSCRIPT_LABEL, transcript, - b"bf_error_opening", - b"bf_error_blind", - &proof.error_opening, - ); - - Ok(OuterCheck { - point: outer.point.into_vec(), - }) + ) + .map_err(|source| VerificationError::OuterSumcheck { source })?; + + let (row_point, entry_point) = outer.point.split_at(row_vars); + let e_rx = VC::verify_committed_rows( + vc_setup, + &folded.error_row_commitments, + row_point, + entry_point, + &proof.error_opening, + )?; + + let eq_tau_rx = EqPolynomial::::mle(&tau, &outer.point); + let expected = eq_tau_rx * (proof.az_rx * proof.bz_rx - folded.u * proof.cz_rx - e_rx); + if outer.value != expected { + return Err(VerificationError::OuterFinalClaimMismatch { + expected, + actual: outer.value, + }); } + + append_values( + transcript, + b"bf_az_bz_cz", + &[proof.az_rx, proof.bz_rx, proof.cz_rx], + ); + append_vector_opening( + transcript, + b"bf_error_opening", + b"bf_error_blind", + &proof.error_opening, + ); + + Ok(OuterCheck { + point: outer.point.into_vec(), + }) } -impl BlindFoldProof +fn verify_folded_eval_commitments( + proof: &BlindFoldProof, + vc_setup: &VC::Setup, + folded: &RelaxedInstance, +) -> Result<(), VerificationError> where F: Field, - Com: Copy + AppendToTranscript, + VC: VectorCommitment, + VC::Output: Copy + AppendToTranscript, { - fn verify_folded_eval_commitments( - &self, - vc_setup: &VC::Setup, - folded: &RelaxedInstance, - ) -> Result<(), VerificationError> - where - VC: VectorCommitment, + for (index, ((commitment, &output), &blinding)) in folded + .eval_commitments + .iter() + .zip(&proof.folded_eval_outputs) + .zip(&proof.folded_eval_blindings) + .enumerate() { - for (index, ((commitment, &output), &blinding)) in folded - .eval_commitments - .iter() - .zip(&self.folded_eval_outputs) - .zip(&self.folded_eval_blindings) - .enumerate() - { - if !VC::verify(vc_setup, commitment, &[output], &blinding) { - return Err(VerificationError::EvalCommitmentMismatch { index }); - } + if !VC::verify(vc_setup, commitment, &[output], &blinding) { + return Err(VerificationError::EvalCommitmentMismatch { index }); } - - Ok(()) } + + Ok(()) } -impl BlindFoldProtocol +fn verify_folded_eval_witness_bindings( + protocol: &BlindFoldProtocol, + proof: &BlindFoldProof, + vc_setup: &VC::Setup, + folded: &RelaxedInstance, + transcript: &mut T, +) -> Result<(), VerificationError> where F: Field + AppendToTranscript, - Com: Copy + HomomorphicCommitment + AppendToTranscript, + VC: VectorCommitment, + VC::Output: Copy + HomomorphicCommitment + AppendToTranscript, + T: Transcript, ::Accumulator: RingAccumulator, { - fn verify_folded_eval_witness_bindings( - &self, - proof: &BlindFoldProof, - vc_setup: &VC::Setup, - folded: &RelaxedInstance, - transcript: &mut T, - ) -> Result<(), VerificationError> - where - VC: VectorCommitment, - T: Transcript, - { - let coordinates = self.final_opening_witness_coordinates()?; - ensure_len( - "final opening bindings", - coordinates.len(), - folded.eval_commitments.len(), - )?; - - let expected_outputs = coordinates - .iter() - .filter(|coordinates| coordinates.evaluation.is_some()) - .count(); - ensure_len( - "folded eval output witness openings", - expected_outputs, - proof.folded_eval_output_openings.len(), - )?; - let expected_blindings = coordinates - .iter() - .filter(|coordinates| coordinates.blinding.is_some()) - .count(); - ensure_len( - "folded eval blinding witness openings", - expected_blindings, - proof.folded_eval_blinding_openings.len(), - )?; - - let mut output_openings = proof.folded_eval_output_openings.iter(); - let mut blinding_openings = proof.folded_eval_blinding_openings.iter(); - for (index, coordinates) in coordinates.iter().enumerate() { - if let Some(coordinate) = coordinates.evaluation { - let opening = output_openings.next().ok_or(RelaxedError::LengthMismatch { - name: "folded eval output witness openings", - expected: expected_outputs, - actual: proof.folded_eval_output_openings.len(), - })?; - let opened = coordinate.verify_opening::(vc_setup, folded, opening)?; - if opened != proof.folded_eval_outputs[index] { - return Err(VerificationError::EvalWitnessMismatch { - kind: "output", - index, - }); - } - coordinate.require_dedicated_row(opening, "output", index)?; - append_vector_opening( - transcript, - b"bf_eval_out_open", - b"bf_eval_out_blind", - opening, - ); + let coordinates = protocol.final_opening_witness_coordinates()?; + ensure_len( + "final opening bindings", + coordinates.len(), + folded.eval_commitments.len(), + )?; + + let expected_outputs = coordinates + .iter() + .filter(|coordinates| coordinates.evaluation.is_some()) + .count(); + ensure_len( + "folded eval output witness openings", + expected_outputs, + proof.folded_eval_output_openings.len(), + )?; + let expected_blindings = coordinates + .iter() + .filter(|coordinates| coordinates.blinding.is_some()) + .count(); + ensure_len( + "folded eval blinding witness openings", + expected_blindings, + proof.folded_eval_blinding_openings.len(), + )?; + + let mut output_openings = proof.folded_eval_output_openings.iter(); + let mut blinding_openings = proof.folded_eval_blinding_openings.iter(); + for (index, coordinates) in coordinates.iter().enumerate() { + if let Some(coordinate) = coordinates.evaluation { + let opening = output_openings.next().ok_or(RelaxedError::LengthMismatch { + name: "folded eval output witness openings", + expected: expected_outputs, + actual: proof.folded_eval_output_openings.len(), + })?; + let opened = verify_witness_coordinate::(vc_setup, folded, coordinate, opening)?; + if opened != proof.folded_eval_outputs[index] { + return Err(VerificationError::EvalWitnessMismatch { + kind: "output", + index, + }); } + require_dedicated_witness_row(opening, coordinate, "output", index)?; + append_vector_opening( + transcript, + b"bf_eval_out_open", + b"bf_eval_out_blind", + opening, + ); + } - if let Some(coordinate) = coordinates.blinding { - let opening = blinding_openings - .next() - .ok_or(RelaxedError::LengthMismatch { - name: "folded eval blinding witness openings", - expected: expected_blindings, - actual: proof.folded_eval_blinding_openings.len(), - })?; - let opened = coordinate.verify_opening::(vc_setup, folded, opening)?; - if opened != proof.folded_eval_blindings[index] { - return Err(VerificationError::EvalWitnessMismatch { - kind: "blinding", - index, - }); - } - coordinate.require_dedicated_row(opening, "blinding", index)?; - append_vector_opening( - transcript, - b"bf_eval_blind_open", - b"bf_eval_blind_bl", - opening, - ); + if let Some(coordinate) = coordinates.blinding { + let opening = blinding_openings + .next() + .ok_or(RelaxedError::LengthMismatch { + name: "folded eval blinding witness openings", + expected: expected_blindings, + actual: proof.folded_eval_blinding_openings.len(), + })?; + let opened = verify_witness_coordinate::(vc_setup, folded, coordinate, opening)?; + if opened != proof.folded_eval_blindings[index] { + return Err(VerificationError::EvalWitnessMismatch { + kind: "blinding", + index, + }); } + require_dedicated_witness_row(opening, coordinate, "blinding", index)?; + append_vector_opening( + transcript, + b"bf_eval_blind_open", + b"bf_eval_blind_bl", + opening, + ); } - - Ok(()) } + + Ok(()) } -impl WitnessCoordinate { - fn require_dedicated_row( - self, - opening: &VectorCommitmentOpening, - kind: &'static str, - index: usize, - ) -> Result<(), VerificationError> { - for (slot, value) in opening.combined_vector.iter().enumerate() { - if slot != self.column && !value.is_zero() { - return Err(VerificationError::EvalWitnessRowNotDedicated { kind, index }); - } +fn require_dedicated_witness_row( + opening: &VectorCommitmentOpening, + coordinate: WitnessCoordinate, + kind: &'static str, + index: usize, +) -> Result<(), VerificationError> { + for (slot, value) in opening.combined_vector.iter().enumerate() { + if slot != coordinate.column && !value.is_zero() { + return Err(VerificationError::EvalWitnessRowNotDedicated { kind, index }); } - Ok(()) } + Ok(()) +} - fn verify_opening( - self, - vc_setup: &VC::Setup, - folded: &RelaxedInstance, - opening: &VectorCommitmentOpening, - ) -> Result> - where - F: Field, - VC: VectorCommitment, - VC::Output: Copy + HomomorphicCommitment, - ::Accumulator: RingAccumulator, - { - let witness_row_count = folded.witness_row_commitments.len(); - if witness_row_count == 0 || !witness_row_count.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "witness row count", - value: witness_row_count, - }); - } - let row_vars = witness_row_count.trailing_zeros() as usize; - - let witness_row_len = opening.combined_vector.len(); - if witness_row_len == 0 || !witness_row_len.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "witness row length", - value: witness_row_len, - }); - } - let entry_vars = witness_row_len.trailing_zeros() as usize; - let row_point = boolean_point::(self.row, row_vars)?; - let entry_point = boolean_point::(self.column, entry_vars)?; - Ok(VC::verify_committed_rows( - vc_setup, - &folded.witness_row_commitments, - &row_point, - &entry_point, - opening, - )?) - } +fn verify_witness_coordinate( + vc_setup: &VC::Setup, + folded: &RelaxedInstance, + coordinate: WitnessCoordinate, + opening: &VectorCommitmentOpening, +) -> Result> +where + F: Field, + VC: VectorCommitment, + VC::Output: Copy + HomomorphicCommitment, + ::Accumulator: RingAccumulator, +{ + let row_vars = + log2_power_of_two::("witness row count", folded.witness_row_commitments.len())?; + let entry_vars = log2_power_of_two::("witness row length", opening.combined_vector.len())?; + let row_point = boolean_point::(coordinate.row, row_vars)?; + let entry_point = boolean_point::(coordinate.column, entry_vars)?; + Ok(VC::verify_committed_rows( + vc_setup, + &folded.witness_row_commitments, + &row_point, + &entry_point, + opening, + )?) } -impl BlindFoldProtocol +fn verify_inner_folded_r1cs( + protocol: &BlindFoldProtocol, + proof: &BlindFoldProof, + vc_setup: &VC::Setup, + folded: &RelaxedInstance, + outer: &OuterCheck, + transcript: &mut T, +) -> Result<(), VerificationError> where F: Field + AppendToTranscript, - Com: Copy + HomomorphicCommitment + AppendToTranscript, + VC: VectorCommitment, + VC::Output: Copy + HomomorphicCommitment + AppendToTranscript, + T: Transcript, ::Accumulator: RingAccumulator, { - fn verify_inner_folded_r1cs( - &self, - proof: &BlindFoldProof, - vc_setup: &VC::Setup, - folded: &RelaxedInstance, - outer: &OuterCheck, - transcript: &mut T, - ) -> Result<(), VerificationError> - where - VC: VectorCommitment, - T: Transcript, - { - let ra = transcript.challenge(); - let rb = transcript.challenge(); - let rc = transcript.challenge(); - let public = public_contributions(&self.r1cs, &outer.point, folded.u)?; - let claim = ra * (proof.az_rx - public.a) - + rb * (proof.bz_rx - public.b) - + rc * (proof.cz_rx - public.c); - - let witness_row_count = self.dimensions.witness.row_count; - if witness_row_count == 0 || !witness_row_count.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "witness row count", - value: witness_row_count, - }); - } - let row_vars = witness_row_count.trailing_zeros() as usize; - - let witness_row_len = self.dimensions.witness.row_len; - if witness_row_len == 0 || !witness_row_len.is_power_of_two() { - return Err(VerificationError::InvalidPowerOfTwo { - name: "witness row length", - value: witness_row_len, - }); - } - let entry_vars = witness_row_len.trailing_zeros() as usize; - let num_vars = - row_vars - .checked_add(entry_vars) - .ok_or(VerificationError::InvalidPowerOfTwo { - name: "inner sumcheck dimension", - value: usize::MAX, - })?; - if num_vars == 0 { - return Err(VerificationError::DegenerateSumcheck { - name: "inner folded R1CS sumcheck", - }); - } - let inner_claim = SumcheckClaim::new(num_vars, INNER_SUMCHECK_DEGREE, claim); - let inner = proof - .inner_sumcheck - .verify( - &inner_claim, - BooleanHypercube, - INNER_SUMCHECK_LABEL, - transcript, - ) - .map_err(|source| VerificationError::InnerSumcheck { source })?; - - let (row_point, entry_point) = inner.point.split_at(row_vars); - let w_ry = VC::verify_committed_rows( - vc_setup, - &folded.witness_row_commitments, - row_point, - entry_point, - &proof.witness_opening, - )?; - - let l_w_at_ry = compute_l_w_at_ry(&self.r1cs, &outer.point, &inner.point, ra, rb, rc)?; - let expected = l_w_at_ry * w_ry; - if inner.value != expected { - return Err(VerificationError::InnerFinalClaimMismatch { - expected, - actual: inner.value, - }); - } - - append_vector_opening( + let ra = transcript.challenge(); + let rb = transcript.challenge(); + let rc = transcript.challenge(); + let public = public_contributions(&protocol.r1cs, &outer.point, folded.u)?; + let claim = ra * (proof.az_rx - public.a) + + rb * (proof.bz_rx - public.b) + + rc * (proof.cz_rx - public.c); + + let row_vars = + log2_power_of_two::("witness row count", protocol.dimensions.witness.row_count)?; + let entry_vars = + log2_power_of_two::("witness row length", protocol.dimensions.witness.row_len)?; + let num_vars = + row_vars + .checked_add(entry_vars) + .ok_or(VerificationError::InvalidPowerOfTwo { + name: "inner sumcheck dimension", + value: usize::MAX, + })?; + if num_vars == 0 { + return Err(VerificationError::DegenerateSumcheck { + name: "inner folded R1CS sumcheck", + }); + } + let inner_claim = SumcheckClaim::new(num_vars, INNER_SUMCHECK_DEGREE, claim); + let inner = proof + .inner_sumcheck + .verify( + &inner_claim, + BooleanHypercube, + INNER_SUMCHECK_LABEL, transcript, - b"bf_witness_opening", - b"bf_witness_blind", - &proof.witness_opening, - ); - - Ok(()) + ) + .map_err(|source| VerificationError::InnerSumcheck { source })?; + + let (row_point, entry_point) = inner.point.split_at(row_vars); + let w_ry = VC::verify_committed_rows( + vc_setup, + &folded.witness_row_commitments, + row_point, + entry_point, + &proof.witness_opening, + )?; + + let l_w_at_ry = compute_l_w_at_ry(&protocol.r1cs, &outer.point, &inner.point, ra, rb, rc)?; + let expected = l_w_at_ry * w_ry; + if inner.value != expected { + return Err(VerificationError::InnerFinalClaimMismatch { + expected, + actual: inner.value, + }); } + + append_vector_opening( + transcript, + b"bf_witness_opening", + b"bf_witness_blind", + &proof.witness_opening, + ); + + Ok(()) } #[derive(Clone, Debug, PartialEq, Eq)] @@ -505,11 +457,21 @@ fn append_vector_opening( F: AppendToTranscript, T: Transcript, { - transcript.append_values(row_label, &opening.combined_vector); + append_values(transcript, row_label, &opening.combined_vector); transcript.append(&Label(blinding_label)); opening.combined_blinding.append_to_transcript(transcript); } +fn log2_power_of_two(name: &'static str, value: usize) -> Result> +where + F: FieldCore, +{ + if value == 0 || !value.is_power_of_two() { + return Err(VerificationError::InvalidPowerOfTwo { name, value }); + } + Ok(value.trailing_zeros() as usize) +} + fn public_contributions( r1cs: &ConstraintMatrices, rx: &[F], @@ -830,21 +792,27 @@ mod tests { ) .expect("random instance builds"); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - committed.append_to_transcript( + append_relaxed_instance( &mut transcript, b"bf_committed_u", b"bf_committed_w", b"bf_committed_e", b"bf_committed_eval", + &committed, ); - random.append_to_transcript( + append_relaxed_instance( &mut transcript, b"bf_random_u", b"bf_random_w", b"bf_random_e", b"bf_random_eval", + &random, + ); + append_values( + &mut transcript, + b"bf_cross_e", + &proof.cross_term_error_row_commitments, ); - transcript.append_values(b"bf_cross_e", &proof.cross_term_error_row_commitments); transcript.challenge() } @@ -866,8 +834,7 @@ mod tests { let proof = proof(&setup, &protocol); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("degenerate outer sumcheck is rejected"); assert!(matches!( @@ -885,8 +852,7 @@ mod tests { let proof = proof(&setup, &protocol); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let folded = protocol - .folded_instance_from_proof(&proof, &mut transcript) + let folded = folded_instance_from_proof(&protocol, &proof, &mut transcript) .expect("fold inputs are well-shaped"); let committed = protocol @@ -903,21 +869,27 @@ mod tests { ) .expect("random instance builds"); let mut manual_transcript = Blake2bTranscript::::new(b"blindfold-verify"); - committed.append_to_transcript( + append_relaxed_instance( &mut manual_transcript, b"bf_committed_u", b"bf_committed_w", b"bf_committed_e", b"bf_committed_eval", + &committed, ); - random.append_to_transcript( + append_relaxed_instance( &mut manual_transcript, b"bf_random_u", b"bf_random_w", b"bf_random_e", b"bf_random_eval", + &random, + ); + append_values( + &mut manual_transcript, + b"bf_cross_e", + &proof.cross_term_error_row_commitments, ); - manual_transcript.append_values(b"bf_cross_e", &proof.cross_term_error_row_commitments); let folding_challenge = manual_transcript.challenge(); let expected = committed .fold( @@ -939,8 +911,7 @@ mod tests { let _ = proof.random_round_commitments.pop(); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("random rows are missing"); assert!(matches!( @@ -960,8 +931,7 @@ mod tests { proof.folded_eval_outputs.push(f(7)); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("folded eval count differs"); assert_eq!( @@ -976,12 +946,10 @@ mod tests { let protocol = protocol_with_eval(&setup); let proof = proof_with_valid_eval_opening(&setup, &protocol); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let folded = protocol - .folded_instance_from_proof(&proof, &mut transcript) + let folded = folded_instance_from_proof(&protocol, &proof, &mut transcript) .expect("folded instance builds"); - proof - .verify_folded_eval_commitments::>(&setup, &folded) + verify_folded_eval_commitments::>(&proof, &setup, &folded) .expect("folded eval commitment opens"); } @@ -993,8 +961,7 @@ mod tests { proof.folded_eval_outputs[0] += f(1); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("folded eval commitment opening is wrong"); assert!(matches!( @@ -1011,8 +978,7 @@ mod tests { let _ = proof.outer_sumcheck.round_polynomials.pop(); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("outer sumcheck has wrong length"); assert!(matches!( @@ -1032,8 +998,7 @@ mod tests { CompressedPoly::new(vec![f(0), f(0), f(0), f(0)]); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("outer sumcheck degree is too high"); assert!(matches!( @@ -1052,8 +1017,7 @@ mod tests { proof.error_opening.combined_blinding = f(1); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("error opening is not binding to folded rows"); assert!(matches!( @@ -1071,8 +1035,7 @@ mod tests { proof.bz_rx = f(1); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("outer final claim does not match opened error row"); assert!(matches!( @@ -1088,8 +1051,7 @@ mod tests { let proof = proof(&setup, &protocol); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("inner sumcheck has wrong length"); assert!(matches!( @@ -1112,8 +1074,7 @@ mod tests { proof.witness_opening.combined_blinding = f(1); let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("witness opening is not binding to folded rows"); assert!(matches!( @@ -1138,8 +1099,7 @@ mod tests { }; let mut transcript = Blake2bTranscript::::new(b"blindfold-verify"); - let error = protocol - .verify::, _>(&proof, &setup, &mut transcript) + let error = verify::, _>(&protocol, &proof, &setup, &mut transcript) .expect_err("inner final claim does not match opened witness row"); assert!(matches!( diff --git a/crates/jolt-blindfold/tests/jolt_claims_pipeline.rs b/crates/jolt-blindfold/tests/jolt_claims_pipeline.rs index fd597a993e..d7fbad6816 100644 --- a/crates/jolt-blindfold/tests/jolt_claims_pipeline.rs +++ b/crates/jolt-blindfold/tests/jolt_claims_pipeline.rs @@ -2,7 +2,7 @@ mod support; -use jolt_blindfold::{BlindFoldStage, BlindFoldStatement, CommittedClaimRows}; +use jolt_blindfold::{r1cs, BlindFoldStage, BlindFoldStatement, CommittedClaimRows}; use jolt_claims::protocols::jolt::{ formulas::{ booleanity::{booleanity, BooleanityDimensions}, @@ -159,12 +159,8 @@ fn build_jolt_stage_relation( sources.insert_public(id, value); } - let layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); - statement - .append(&mut builder, &layout, &mut sources) - .expect("constraints append"); + let layout = r1cs::allocate_layout(&mut builder, &statement).expect("layout allocates"); + r1cs::append(&mut builder, &statement, &layout, &mut sources).expect("constraints append"); assign_generated_stage(&mut builder, &layout.stages[0].sumcheck, generated); let witness = builder.witness().expect("all witnesses assigned"); @@ -283,12 +279,8 @@ fn jolt_claims_pipeline_lowers_booleanity_relation() { sources.insert_public(public_id, f(11)); } - let r1cs_layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); - statement - .append(&mut builder, &r1cs_layout, &mut sources) - .expect("constraints append"); + let r1cs_layout = r1cs::allocate_layout(&mut builder, &statement).expect("layout allocates"); + r1cs::append(&mut builder, &statement, &r1cs_layout, &mut sources).expect("constraints append"); assign_generated_stage(&mut builder, &r1cs_layout.stages[0].sumcheck, &generated); let witness = builder.witness().expect("all witnesses assigned"); diff --git a/crates/jolt-blindfold/tests/proof_pipeline.rs b/crates/jolt-blindfold/tests/proof_pipeline.rs index af1e3e6058..f15011aafd 100644 --- a/crates/jolt-blindfold/tests/proof_pipeline.rs +++ b/crates/jolt-blindfold/tests/proof_pipeline.rs @@ -2,8 +2,12 @@ mod support; -use jolt_blindfold::VerificationError; +use jolt_blindfold::{ + verify, BlindFoldRowCommitter, DirectBlindFoldRowCommitter, ProverError, VerificationError, +}; +use jolt_crypto::VectorCommitmentOpening; use jolt_poly::CompressedPoly; +use jolt_r1cs::ConstraintMatrices; use jolt_transcript::{Blake2bTranscript, Transcript}; use rand_chacha::ChaCha20Rng; use rand_core::SeedableRng; @@ -14,8 +18,178 @@ fn verify_blindfold_protocol_pipeline( ) -> Result<(), VerificationError> { let mut transcript = Blake2bTranscript::::new(b"protocol-backed-blindfold-proof"); append_protocol_transcript_prefix(&full.protocol, &mut transcript); - full.protocol - .verify::(&full.proof, &full.setup, &mut transcript) + verify::(&full.protocol, &full.proof, &full.setup, &mut transcript) +} + +#[derive(Debug, Default)] +struct CountingRowCommitter { + inner: DirectBlindFoldRowCommitter, + row_commitments: usize, + error_rows: usize, + cross_term_error_rows: usize, + row_folds: usize, + scalar_folds: usize, + error_row_folds: usize, + error_scalar_folds: usize, + row_openings: usize, +} + +impl BlindFoldRowCommitter for CountingRowCommitter { + fn commit_rows( + &mut self, + setup: &::Setup, + rows: &[Vec], + blindings: &[F], + name: &'static str, + ) -> Result::Output>, ProverError> { + self.row_commitments += 1; + >::commit_rows( + &mut self.inner, + setup, + rows, + blindings, + name, + ) + } + + fn compute_error_rows( + &mut self, + r1cs: &ConstraintMatrices, + u: F, + witness: &[F], + row_count: usize, + row_len: usize, + name: &'static str, + ) -> Result>, ProverError> { + self.error_rows += 1; + >::compute_error_rows( + &mut self.inner, + r1cs, + u, + witness, + row_count, + row_len, + name, + ) + } + + fn compute_cross_term_error_rows( + &mut self, + r1cs: &ConstraintMatrices, + real_u: F, + real_witness: &[F], + random_u: F, + random_witness: &[F], + row_count: usize, + row_len: usize, + name: &'static str, + ) -> Result>, ProverError> { + self.cross_term_error_rows += 1; + >::compute_cross_term_error_rows( + &mut self.inner, + r1cs, + real_u, + real_witness, + random_u, + random_witness, + row_count, + row_len, + name, + ) + } + + fn fold_rows( + &mut self, + real: &[Vec], + random: &[Vec], + challenge: F, + name: &'static str, + ) -> Result>, ProverError> { + self.row_folds += 1; + >::fold_rows( + &mut self.inner, + real, + random, + challenge, + name, + ) + } + + fn fold_scalars( + &mut self, + real: &[F], + random: &[F], + challenge: F, + name: &'static str, + ) -> Result, ProverError> { + self.scalar_folds += 1; + >::fold_scalars( + &mut self.inner, + real, + random, + challenge, + name, + ) + } + + fn fold_error_rows( + &mut self, + real: &[Vec], + cross: &[Vec], + random: &[Vec], + challenge: F, + name: &'static str, + ) -> Result>, ProverError> { + self.error_row_folds += 1; + >::fold_error_rows( + &mut self.inner, + real, + cross, + random, + challenge, + name, + ) + } + + fn fold_error_scalars( + &mut self, + real: &[F], + cross: &[F], + random: &[F], + challenge: F, + name: &'static str, + ) -> Result, ProverError> { + self.error_scalar_folds += 1; + >::fold_error_scalars( + &mut self.inner, + real, + cross, + random, + challenge, + name, + ) + } + + fn open_rows( + &mut self, + setup: &::Setup, + rows: &[Vec], + blindings: &[F], + row_point: &[F], + entry_point: &[F], + name: &'static str, + ) -> Result<(VectorCommitmentOpening, F), ProverError> { + self.row_openings += 1; + >::open_rows( + &mut self.inner, + setup, + rows, + blindings, + row_point, + entry_point, + name, + ) + } } #[test] @@ -30,6 +204,23 @@ fn blindfold_protocol_pipeline_verifies_committed_sumcheck_outputs_and_eval_comm verify_blindfold_protocol_pipeline(&full).expect("protocol-backed BlindFold proof verifies"); } +#[test] +fn blindfold_protocol_pipeline_uses_row_committer_hooks() { + let mut rng = ChaCha20Rng::from_seed([89; 32]); + let mut row_committer = CountingRowCommitter::default(); + let full = prove_blindfold_protocol_pipeline_with_committer(&mut rng, &mut row_committer); + + verify_blindfold_protocol_pipeline(&full).expect("protocol-backed BlindFold proof verifies"); + assert!(row_committer.row_commitments > 0); + assert_eq!(row_committer.error_rows, 1); + assert_eq!(row_committer.cross_term_error_rows, 1); + assert_eq!(row_committer.row_folds, 1); + assert!(row_committer.scalar_folds >= 3); + assert_eq!(row_committer.error_row_folds, 1); + assert_eq!(row_committer.error_scalar_folds, 1); + assert!(row_committer.row_openings >= 4); +} + #[test] fn blindfold_protocol_pipeline_randomness_is_empirically_independent() { const SAMPLES: usize = 128; @@ -147,10 +338,7 @@ fn blindfold_protocol_pipeline_rejects_wrong_transcript() { let mut transcript = Blake2bTranscript::::new(b"wrong-transcript"); append_protocol_transcript_prefix(&full.protocol, &mut transcript); - assert!(full - .protocol - .verify::(&full.proof, &full.setup, &mut transcript) - .is_err()); + assert!(verify::(&full.protocol, &full.proof, &full.setup, &mut transcript).is_err()); } #[test] diff --git a/crates/jolt-blindfold/tests/support/mod.rs b/crates/jolt-blindfold/tests/support/mod.rs index a7a0886951..4a8b322b0a 100644 --- a/crates/jolt-blindfold/tests/support/mod.rs +++ b/crates/jolt-blindfold/tests/support/mod.rs @@ -5,8 +5,8 @@ )] use jolt_blindfold::{ - BlindFoldProof, BlindFoldProtocol, BlindFoldStage, BlindFoldStatement, CommittedClaimRows, - FinalOpeningBinding, WitnessCoordinate, + r1cs, BlindFoldProof, BlindFoldProtocol, BlindFoldStage, BlindFoldStatement, BlindFoldWitness, + CommittedClaimRows, FinalOpeningBinding, WitnessCoordinate, }; use jolt_claims::{challenge, constant, opening, public, Expr}; use jolt_crypto::{ @@ -20,7 +20,7 @@ use jolt_sumcheck::{ CommittedSumcheckProof, CompressedSumcheckProof, RoundMessage, SumcheckDomainSpec, SumcheckR1csLayout, SumcheckStatement, SUMCHECK_ROUND_TRANSCRIPT_LABEL, }; -use jolt_transcript::{AppendToTranscript, Blake2bTranscript, Label, Transcript}; +use jolt_transcript::{AppendToTranscript, Blake2bTranscript, Label, LabelWithCount, Transcript}; use rand_core::RngCore; pub type F = Fr; @@ -160,7 +160,7 @@ pub fn transcript_projection(label: &'static [u8], value: pub fn field_slice_projection(label: &'static [u8], values: &[F]) -> u64 { let mut transcript = Blake2bTranscript::::new(b"blindfold-statistical-projection"); - transcript.append_values(label, values); + append_values(&mut transcript, label, values); field_low_u64(transcript.challenge()) } @@ -692,12 +692,8 @@ pub fn build_deep_relation( sources.insert_public(Public::Offset, values.offset); sources.insert_public(Public::Multiplier, values.multiplier); - let layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); - statement - .append(&mut builder, &layout, &mut sources) - .expect("constraints append"); + let layout = r1cs::allocate_layout(&mut builder, &statement).expect("layout allocates"); + r1cs::append(&mut builder, &statement, &layout, &mut sources).expect("constraints append"); assign_generated_stage(&mut builder, &layout.stages[0].sumcheck, stage1); assign_generated_stage(&mut builder, &layout.stages[1].sumcheck, stage2); assign_generated_stage(&mut builder, &layout.stages[2].sumcheck, stage3); @@ -762,6 +758,18 @@ struct SumcheckTrace { } pub fn prove_blindfold_protocol_pipeline(rng: &mut R) -> BlindFoldTestProof { + let mut row_committer = jolt_blindfold::DirectBlindFoldRowCommitter; + prove_blindfold_protocol_pipeline_with_committer(rng, &mut row_committer) +} + +pub fn prove_blindfold_protocol_pipeline_with_committer( + rng: &mut R, + row_committer: &mut C, +) -> BlindFoldTestProof +where + R: RngCore, + C: jolt_blindfold::BlindFoldRowCommitter, +{ let setup = pedersen_setup(4); let transcript_label = b"protocol-backed-blindfold-proof"; let statement1 = SumcheckStatement::new(3, 3); @@ -853,13 +861,21 @@ pub fn prove_blindfold_protocol_pipeline(rng: &mut R) -> BlindFoldTe &real_eval_blindings, rng, ); - let witness = ProtocolWitness { + let witness = BlindFoldWitness { rows: &real_witness_rows, blindings: &real_witness_blindings, eval_outputs: &real_eval_outputs, eval_blindings: &real_eval_blindings, }; - let proof = prove_from_protocol_witness(&setup, &protocol, &mut transcript, witness, rng); + let proof = jolt_blindfold::prove_with_row_committer::( + &setup, + &protocol, + &mut transcript, + witness, + rng, + row_committer, + ) + .expect("protocol-backed BlindFold proof proves"); BlindFoldTestProof { protocol, @@ -934,9 +950,7 @@ fn protocol_backed_witness( ) -> (Vec>, Vec) { let mut builder = R1csBuilder::::new(); let mut sources = ClaimSourceTable::::new(); - let layout = statement - .allocate_layout(&mut builder) - .expect("layout allocates"); + let layout = r1cs::allocate_layout(&mut builder, statement).expect("layout allocates"); for (stage, stage_layout) in statement.stages.iter().zip(&layout.stages) { let variables = stage_layout .output_claim_rows @@ -946,9 +960,7 @@ fn protocol_backed_witness( sources.insert_opening(*opening_id, variable); } } - statement - .append(&mut builder, &layout, &mut sources) - .expect("constraints append"); + r1cs::append(&mut builder, statement, &layout, &mut sources).expect("constraints append"); for (stage, (stage_layout, generated)) in statement .stages .iter() @@ -1216,7 +1228,7 @@ fn prove_from_protocol_witness( &random_instance.error_row_commitments, &random_instance.eval_commitments, ); - transcript.append_values(b"bf_cross_e", &cross_term_error_row_commitments); + append_values(transcript, b"bf_cross_e", &cross_term_error_row_commitments); let folding_challenge = transcript.challenge(); let folded_u = f(1) + folding_challenge * random_u; @@ -1333,7 +1345,7 @@ fn prove_from_protocol_witness( ) .expect("folded error rows open"); - transcript.append_values(b"bf_az_bz_cz", &[az_rx, bz_rx, cz_rx]); + append_values(transcript, b"bf_az_bz_cz", &[az_rx, bz_rx, cz_rx]); append_vector_opening( transcript, b"bf_error_opening", @@ -1456,9 +1468,20 @@ fn append_relaxed_instance_from_parts( ) { transcript.append(&Label(labels.u)); u.append_to_transcript(transcript); - transcript.append_values(labels.witness, witness_commitments); - transcript.append_values(labels.error, error_commitments); - transcript.append_values(labels.eval, eval_commitments); + append_values(transcript, labels.witness, witness_commitments); + append_values(transcript, labels.error, error_commitments); + append_values(transcript, labels.eval, eval_commitments); +} + +fn append_values( + transcript: &mut Blake2bTranscript, + label: &'static [u8], + values: &[A], +) { + transcript.append(&LabelWithCount(label, values.len() as u64)); + for value in values { + value.append_to_transcript(transcript); + } } fn append_vector_opening( @@ -1467,7 +1490,7 @@ fn append_vector_opening( blinding_label: &'static [u8], opening: &jolt_crypto::VectorCommitmentOpening, ) { - transcript.append_values(row_label, &opening.combined_vector); + append_values(transcript, row_label, &opening.combined_vector); transcript.append(&Label(blinding_label)); opening.combined_blinding.append_to_transcript(transcript); } @@ -1689,7 +1712,7 @@ fn prove_slow_sumcheck( let mut compressed = Vec::with_capacity(degree); compressed.push(coefficients[0]); compressed.extend_from_slice(&coefficients[2..]); - transcript.append_values(label, &compressed); + append_values(transcript, label, &compressed); let challenge = transcript.challenge(); running_sum = eval_poly(&coefficients, challenge); prefix.push(challenge); diff --git a/crates/jolt-crypto/Cargo.toml b/crates/jolt-crypto/Cargo.toml index 247760ff72..05d0fee8e7 100644 --- a/crates/jolt-crypto/Cargo.toml +++ b/crates/jolt-crypto/Cargo.toml @@ -31,17 +31,34 @@ bn254 = [ "dep:num-traits", "parallel", ] +grumpkin = [ + "dep:ark-grumpkin", + "dep:ark-ec", + "dep:ark-ff", + "dep:ark-serialize", + "dep:ark-std", + "dep:blake2", + "dep:num-bigint", + "dep:num-traits", + "jolt-field/bn254", + "parallel", +] parallel = ["dep:rayon", "jolt-poly/parallel"] +r1cs = ["grumpkin", "dep:jolt-r1cs", "dep:thiserror", "jolt-poly/r1cs"] [dependencies] jolt-field = { path = "../jolt-field" } jolt-poly = { path = "../jolt-poly", default-features = false } +jolt-r1cs = { workspace = true, optional = true } jolt-transcript = { path = "../jolt-transcript" } serde = { workspace = true, features = ["derive", "alloc"] } rand_core = { workspace = true } +thiserror = { workspace = true, optional = true } +blake2 = { workspace = true, optional = true } # Arkworks — BN254 backend (internal, gated behind `bn254` feature) ark-bn254 = { workspace = true, features = ["curve"], optional = true } +ark-grumpkin = { workspace = true, optional = true } ark-ec = { workspace = true, optional = true } ark-ff = { workspace = true, optional = true } ark-serialize = { workspace = true, optional = true } diff --git a/crates/jolt-crypto/src/commitment.rs b/crates/jolt-crypto/src/commitment.rs index 2727b5fd67..abdf63fb2d 100644 --- a/crates/jolt-crypto/src/commitment.rs +++ b/crates/jolt-crypto/src/commitment.rs @@ -3,7 +3,7 @@ use std::{ fmt::{self, Debug}, }; -use jolt_field::{AdditiveAccumulator, Field, RingAccumulator, WithAccumulator}; +use jolt_field::Field; use jolt_poly::EqPolynomial; use jolt_transcript::AppendToTranscript; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -18,7 +18,9 @@ const PAR_THRESHOLD: usize = 1024; /// polynomial commitment schemes (`jolt_openings::CommitmentScheme`). /// The `Output` associated type is the single piece of connective tissue /// between these different levels of abstraction. -pub trait Commitment: Clone + Debug + Eq + Send + Sync + 'static { +pub trait Commitment: + Clone + Debug + Eq + Send + Sync + 'static + Serialize + DeserializeOwned +{ /// The commitment value (e.g., a group element, a Merkle root, a lattice vector). type Output: Clone + Debug + Eq + Send + Sync + 'static + Serialize + DeserializeOwned; } @@ -59,20 +61,16 @@ pub trait VectorCommitment: /// Opens a row-major matrix of committed rows at `(row_point, entry_point)`. /// - /// Missing entries at the end of `flattened_rows` are treated as zero. - /// Callers must either pass exactly `row_count * row_len` entries or commit - /// each row with the same trailing zero-padding convention; otherwise - /// verification rejects with [`VectorOpeningError::CommitmentMismatch`]. + /// Missing entries at the end of `flattened_rows` are treated as zero. Row + /// commitments passed to verification must be produced with the same + /// row count and zero-padding convention. fn open_committed_rows( flattened_rows: &[Self::Field], row_blindings: &[Self::Field], row_len: usize, row_point: &[Self::Field], entry_point: &[Self::Field], - ) -> Result<(VectorCommitmentOpening, Self::Field), VectorOpeningError> - where - ::Accumulator: RingAccumulator, - { + ) -> Result<(VectorCommitmentOpening, Self::Field), VectorOpeningError> { let row_count = point_len_to_basis_len(row_point.len())?; validate_row_len(row_len, entry_point.len())?; let max_len = row_count @@ -117,7 +115,6 @@ pub trait VectorCommitment: ) -> Result where Self::Output: HomomorphicCommitment, - ::Accumulator: RingAccumulator, { let row_count = point_len_to_basis_len(row_point.len())?; if row_commitments.len() != row_count { @@ -279,10 +276,7 @@ fn combine_rows( row_len: usize, row_weights: &[F], max_len: usize, -) -> Vec -where - ::Accumulator: RingAccumulator, -{ +) -> Vec { let mut combined_vector = vec![F::zero(); row_len]; if max_len >= PAR_THRESHOLD { @@ -292,23 +286,23 @@ where .par_iter_mut() .enumerate() .for_each(|(entry_index, combined_entry)| { - let mut acc = ::Accumulator::default(); + let mut acc = F::zero(); for (row_index, row_weight) in row_weights.iter().copied().enumerate() { if let Some(value) = flattened_rows.get(row_index * row_len + entry_index) { - acc.fmadd(row_weight, *value); + acc += row_weight * *value; } } - *combined_entry = acc.reduce(); + *combined_entry = acc; }); } else { for (entry_index, combined_entry) in combined_vector.iter_mut().enumerate() { - let mut acc = ::Accumulator::default(); + let mut acc = F::zero(); for (row_index, row_weight) in row_weights.iter().copied().enumerate() { if let Some(value) = flattened_rows.get(row_index * row_len + entry_index) { - acc.fmadd(row_weight, *value); + acc += row_weight * *value; } } - *combined_entry = acc.reduce(); + *combined_entry = acc; } } @@ -321,29 +315,23 @@ fn combine_rows( row_len: usize, row_weights: &[F], _max_len: usize, -) -> Vec -where - ::Accumulator: RingAccumulator, -{ +) -> Vec { let mut combined_vector = vec![F::zero(); row_len]; for (entry_index, combined_entry) in combined_vector.iter_mut().enumerate() { - let mut acc = ::Accumulator::default(); + let mut acc = F::zero(); for (row_index, row_weight) in row_weights.iter().copied().enumerate() { if let Some(value) = flattened_rows.get(row_index * row_len + entry_index) { - acc.fmadd(row_weight, *value); + acc += row_weight * *value; } } - *combined_entry = acc.reduce(); + *combined_entry = acc; } combined_vector } -fn inner_product(lhs: &[F], rhs: &[F]) -> F -where - ::Accumulator: RingAccumulator, -{ +fn inner_product(lhs: &[F], rhs: &[F]) -> F { #[cfg(feature = "parallel")] { if lhs.len() >= PAR_THRESHOLD { @@ -352,29 +340,15 @@ where return lhs .par_iter() .zip(rhs.par_iter()) - .fold( - ::Accumulator::default, - |mut acc, (left, right)| { - acc.fmadd(*left, *right); - acc - }, - ) - .reduce( - ::Accumulator::default, - |mut left, right| { - left.merge(right); - left - }, - ) - .reduce(); + .map(|(left, right)| *left * *right) + .sum(); } } - let mut acc = ::Accumulator::default(); - for (left, right) in lhs.iter().zip(rhs.iter()) { - acc.fmadd(*left, *right); - } - acc.reduce() + lhs.iter() + .zip(rhs.iter()) + .map(|(left, right)| *left * *right) + .sum() } fn combine_commitments(commitments: &[C], weights: &[F]) -> C @@ -403,14 +377,14 @@ where }) } -impl HomomorphicCommitment for G { +impl HomomorphicCommitment for G { #[inline] fn add(c1: &G, c2: &G) -> G { *c1 + c2 } #[inline] - fn linear_combine(c1: &G, c2: &G, scalar: &F) -> G { + fn linear_combine(c1: &G, c2: &G, scalar: &G::ScalarField) -> G { *c1 + c2.scalar_mul(scalar) } } diff --git a/crates/jolt-crypto/src/ec/bn254/fq12.rs b/crates/jolt-crypto/src/ec/bn254/fq12.rs new file mode 100644 index 0000000000..603461152d --- /dev/null +++ b/crates/jolt-crypto/src/ec/bn254/fq12.rs @@ -0,0 +1,120 @@ +use std::fmt::Debug; + +use ark_bn254::{Bn254 as ArkBn254, Fq12}; +use ark_ec::{pairing::MillerLoopOutput, pairing::Pairing}; +use ark_ff::{Field as ArkField, Zero}; +use jolt_field::{FixedByteSize, Fq}; +use jolt_transcript::{AppendToTranscript, Transcript}; +use serde::{Deserialize, Serialize}; + +use super::gt::Bn254GT; + +/// BN254 Fq12 element used for raw Miller-loop outputs before final exponentiation. +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(transparent)] +pub struct Bn254Fq12(pub(crate) Fq12); + +impl Debug for Bn254Fq12 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Bn254Fq12").field(&self.0).finish() + } +} + +impl Default for Bn254Fq12 { + #[inline(always)] + fn default() -> Self { + Self(Fq12::ONE) + } +} + +impl From for Bn254Fq12 { + #[inline(always)] + fn from(value: Bn254GT) -> Self { + Self(value.0) + } +} + +impl Bn254Fq12 { + pub const COEFFICIENTS: usize = 12; + + /// Returns the BN254 Fq12 tower coefficients in canonical basis order. + /// + /// The order is `(c0.c0, c0.c1, c0.c2, c1.c0, c1.c1, c1.c2)`, with each + /// Fq2 coefficient emitted as `(c0, c1)`. + pub fn coefficients(&self) -> [Fq; Self::COEFFICIENTS] { + [ + field_to_fq(&self.0.c0.c0.c0), + field_to_fq(&self.0.c0.c0.c1), + field_to_fq(&self.0.c0.c1.c0), + field_to_fq(&self.0.c0.c1.c1), + field_to_fq(&self.0.c0.c2.c0), + field_to_fq(&self.0.c0.c2.c1), + field_to_fq(&self.0.c1.c0.c0), + field_to_fq(&self.0.c1.c0.c1), + field_to_fq(&self.0.c1.c1.c0), + field_to_fq(&self.0.c1.c1.c1), + field_to_fq(&self.0.c1.c2.c0), + field_to_fq(&self.0.c1.c2.c1), + ] + } + + /// Applies BN254 final exponentiation to this raw Miller-loop output. + pub fn final_exponentiation(&self) -> Option { + ArkBn254::final_exponentiation(MillerLoopOutput(self.0)).map(|value| Bn254GT(value.0)) + } +} + +#[expect( + clippy::expect_used, + reason = "canonical BN254 Fq serialization into a fixed 32-byte buffer cannot fail" +)] +fn field_to_fq(value: &ark_bn254::Fq) -> Fq { + use ark_serialize::CanonicalSerialize; + + let mut bytes = [0_u8; Fq::NUM_BYTES]; + value + .serialize_compressed(&mut bytes[..]) + .expect("BN254 Fq serialization cannot fail"); + Fq::from_le_bytes_mod_order(&bytes) +} + +#[expect(clippy::expect_used)] +impl AppendToTranscript for Bn254Fq12 { + fn append_to_transcript(&self, transcript: &mut T) { + use ark_serialize::CanonicalSerialize; + let mut buf = Vec::with_capacity(self.0.uncompressed_size()); + self.0 + .serialize_uncompressed(&mut buf) + .expect("Fq12 serialization cannot fail"); + buf.reverse(); + transcript.append_bytes(&buf); + } + + fn transcript_payload_len(&self) -> Option { + use ark_serialize::CanonicalSerialize; + Some(self.0.uncompressed_size() as u64) + } +} + +impl Serialize for Bn254Fq12 { + fn serialize(&self, serializer: S) -> Result { + use ark_serialize::CanonicalSerialize; + let mut buf = Vec::with_capacity(self.0.compressed_size()); + self.0 + .serialize_compressed(&mut buf) + .map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&buf) + } +} + +impl<'de> Deserialize<'de> for Bn254Fq12 { + fn deserialize>(deserializer: D) -> Result { + use ark_serialize::CanonicalDeserialize; + let buf = >::deserialize(deserializer)?; + let inner = Fq12::deserialize_compressed(&buf[..]).map_err(serde::de::Error::custom)?; + if inner.is_zero() { + return Err(serde::de::Error::custom("Fq12 Miller-loop output is zero")); + } + Ok(Self(inner)) + } +} diff --git a/crates/jolt-crypto/src/ec/bn254/g1.rs b/crates/jolt-crypto/src/ec/bn254/g1.rs index aa12921381..69936ea431 100644 --- a/crates/jolt-crypto/src/ec/bn254/g1.rs +++ b/crates/jolt-crypto/src/ec/bn254/g1.rs @@ -1,4 +1,7 @@ use ark_bn254::{G1Affine, G1Projective}; +use ark_ec::{AffineRepr, CurveGroup}; +use ark_ff::Zero; +use jolt_field::{FixedByteSize, Fq, FromPrimitiveInt}; super::impl_jolt_group_wrapper!( Bn254G1, @@ -6,3 +9,53 @@ super::impl_jolt_group_wrapper!( G1Affine, "BN254 G1 group element (projective coordinates)." ); + +impl Bn254G1 { + /// Returns affine `(x, y, infinity)` coordinates over BN254 Fq. + pub fn affine_coordinates_with_infinity(&self) -> [Fq; 3] { + let affine = self.0.into_affine(); + if affine.infinity { + return [Fq::default(), Fq::default(), Fq::from_u64(1)]; + } + + [ + field_to_fq(&affine.x), + field_to_fq(&affine.y), + Fq::default(), + ] + } + + pub fn from_affine_coordinates_with_infinity(coordinates: [Fq; 3]) -> Option { + let [x, y, infinity] = coordinates; + if infinity == Fq::from_u64(1) { + if x != Fq::default() || y != Fq::default() { + return None; + } + return Some(Self(G1Projective::zero())); + } + if infinity != Fq::default() { + return None; + } + + let affine = G1Affine::new_unchecked(super::fq_to_field(&x), super::fq_to_field(&y)); + if !affine.is_on_curve() || !affine.is_in_correct_subgroup_assuming_on_curve() { + return None; + } + + Some(Self(affine.into_group())) + } +} + +#[expect( + clippy::expect_used, + reason = "canonical BN254 Fq serialization into a fixed 32-byte buffer cannot fail" +)] +fn field_to_fq(value: &ark_bn254::Fq) -> Fq { + use ark_serialize::CanonicalSerialize; + + let mut bytes = [0_u8; Fq::NUM_BYTES]; + value + .serialize_compressed(&mut bytes[..]) + .expect("BN254 Fq serialization cannot fail"); + Fq::from_le_bytes_mod_order(&bytes) +} diff --git a/crates/jolt-crypto/src/ec/bn254/g2.rs b/crates/jolt-crypto/src/ec/bn254/g2.rs index 457215eba4..8e8ae6a4ca 100644 --- a/crates/jolt-crypto/src/ec/bn254/g2.rs +++ b/crates/jolt-crypto/src/ec/bn254/g2.rs @@ -1,4 +1,7 @@ use ark_bn254::{G2Affine, G2Projective}; +use ark_ec::{AffineRepr, CurveGroup}; +use ark_ff::Zero; +use jolt_field::{FixedByteSize, Fq, FromPrimitiveInt}; super::impl_jolt_group_wrapper!( Bn254G2, @@ -6,3 +9,68 @@ super::impl_jolt_group_wrapper!( G2Affine, "BN254 G2 group element (projective coordinates)." ); + +impl Bn254G2 { + /// Returns affine `(x.c0, x.c1, y.c0, y.c1, infinity)` coordinates over BN254 Fq. + pub fn affine_coordinates_with_infinity(&self) -> [Fq; 5] { + let affine = self.0.into_affine(); + if affine.infinity { + return [ + Fq::default(), + Fq::default(), + Fq::default(), + Fq::default(), + Fq::from_u64(1), + ]; + } + + [ + field_to_fq(&affine.x.c0), + field_to_fq(&affine.x.c1), + field_to_fq(&affine.y.c0), + field_to_fq(&affine.y.c1), + Fq::default(), + ] + } + + pub fn from_affine_coordinates_with_infinity(coordinates: [Fq; 5]) -> Option { + let [x0, x1, y0, y1, infinity] = coordinates; + if infinity == Fq::from_u64(1) { + if x0 != Fq::default() + || x1 != Fq::default() + || y0 != Fq::default() + || y1 != Fq::default() + { + return None; + } + return Some(Self(G2Projective::zero())); + } + if infinity != Fq::default() { + return None; + } + + let affine = G2Affine::new_unchecked( + ark_bn254::Fq2::new(super::fq_to_field(&x0), super::fq_to_field(&x1)), + ark_bn254::Fq2::new(super::fq_to_field(&y0), super::fq_to_field(&y1)), + ); + if !affine.is_on_curve() || !affine.is_in_correct_subgroup_assuming_on_curve() { + return None; + } + + Some(Self(affine.into_group())) + } +} + +#[expect( + clippy::expect_used, + reason = "canonical BN254 Fq serialization into a fixed 32-byte buffer cannot fail" +)] +fn field_to_fq(value: &ark_bn254::Fq) -> Fq { + use ark_serialize::CanonicalSerialize; + + let mut bytes = [0_u8; Fq::NUM_BYTES]; + value + .serialize_compressed(&mut bytes[..]) + .expect("BN254 Fq serialization cannot fail"); + Fq::from_le_bytes_mod_order(&bytes) +} diff --git a/crates/jolt-crypto/src/ec/bn254/glv/dory_g1.rs b/crates/jolt-crypto/src/ec/bn254/glv/dory_g1.rs index 2aa44b87f7..01da4b4ad6 100644 --- a/crates/jolt-crypto/src/ec/bn254/glv/dory_g1.rs +++ b/crates/jolt-crypto/src/ec/bn254/glv/dory_g1.rs @@ -1,6 +1,7 @@ //! Vector-scalar operations on G1 using 2D GLV, for Dory inner-product argument rounds. use ark_bn254::{Fr, G1Projective}; +use ark_ff::Zero; use rayon::prelude::*; use super::decomp_2d::{decompose_scalar_2d, glv_endomorphism}; @@ -18,6 +19,9 @@ pub fn vector_add_scalar_mul_g1_online( v.par_iter_mut() .zip(generators.par_iter()) .for_each(|(vi, gen)| { + if gen.is_zero() { + return; + } let bases = [*gen, glv_endomorphism(gen)]; *vi += shamir_glv_mul_2d(&bases, &coeffs, &signs); }); @@ -35,6 +39,10 @@ pub fn vector_scalar_mul_add_gamma_g1_online( v.par_iter_mut() .zip(gamma.par_iter()) .for_each(|(vi, &gamma_i)| { + if vi.is_zero() { + *vi = gamma_i; + return; + } let bases = [*vi, glv_endomorphism(vi)]; *vi = shamir_glv_mul_2d(&bases, &coeffs, &signs) + gamma_i; }); diff --git a/crates/jolt-crypto/src/ec/bn254/glv/dory_g2.rs b/crates/jolt-crypto/src/ec/bn254/glv/dory_g2.rs index 1d70156d40..36b7ac0419 100644 --- a/crates/jolt-crypto/src/ec/bn254/glv/dory_g2.rs +++ b/crates/jolt-crypto/src/ec/bn254/glv/dory_g2.rs @@ -1,6 +1,7 @@ //! Vector-scalar operations on G2 using 4D GLV with Frobenius, for Dory inner-product argument rounds. use ark_bn254::{Fr, G2Projective}; +use ark_ff::Zero; use rayon::prelude::*; use super::decomp_4d::decompose_scalar_4d; @@ -19,6 +20,9 @@ pub fn vector_add_scalar_mul_g2_online( v.par_iter_mut() .zip(generators.par_iter()) .for_each(|(vi, gen)| { + if gen.is_zero() { + return; + } let bases = [ *gen, frobenius_psi_power_projective(gen, 1), @@ -41,6 +45,10 @@ pub fn vector_scalar_mul_add_gamma_g2_online( v.par_iter_mut() .zip(gamma.par_iter()) .for_each(|(vi, &gamma_i)| { + if vi.is_zero() { + *vi = gamma_i; + return; + } let bases = [ *vi, frobenius_psi_power_projective(vi, 1), diff --git a/crates/jolt-crypto/src/ec/bn254/gt.rs b/crates/jolt-crypto/src/ec/bn254/gt.rs index 603089517e..e1da15d1ec 100644 --- a/crates/jolt-crypto/src/ec/bn254/gt.rs +++ b/crates/jolt-crypto/src/ec/bn254/gt.rs @@ -1,9 +1,9 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use ark_bn254::{Fq12, Fr}; +use ark_bn254::{Fq12, Fq2, Fq6, Fr as ArkFr}; use ark_ff::{AdditiveGroup, Field as ArkField, PrimeField}; -use jolt_field::Field; +use jolt_field::{FixedByteSize, Fq, Fr}; use jolt_transcript::{AppendToTranscript, Transcript}; @@ -51,6 +51,70 @@ impl Default for Bn254GT { } } +impl Bn254GT { + pub const FQ12_COEFFICIENTS: usize = 12; + + /// Returns the BN254 Fq12 tower coefficients in canonical basis order. + /// + /// The order is `(c0.c0, c0.c1, c0.c2, c1.c0, c1.c1, c1.c2)`, with each + /// Fq2 coefficient emitted as `(c0, c1)`. + pub fn fq12_coefficients(&self) -> [Fq; Self::FQ12_COEFFICIENTS] { + [ + field_to_fq(&self.0.c0.c0.c0), + field_to_fq(&self.0.c0.c0.c1), + field_to_fq(&self.0.c0.c1.c0), + field_to_fq(&self.0.c0.c1.c1), + field_to_fq(&self.0.c0.c2.c0), + field_to_fq(&self.0.c0.c2.c1), + field_to_fq(&self.0.c1.c0.c0), + field_to_fq(&self.0.c1.c0.c1), + field_to_fq(&self.0.c1.c1.c0), + field_to_fq(&self.0.c1.c1.c1), + field_to_fq(&self.0.c1.c2.c0), + field_to_fq(&self.0.c1.c2.c1), + ] + } + + pub fn from_fq12_coefficients(coefficients: [Fq; Self::FQ12_COEFFICIENTS]) -> Option { + let inner = Fq12::new( + Fq6::new( + Fq2::new( + super::fq_to_field(&coefficients[0]), + super::fq_to_field(&coefficients[1]), + ), + Fq2::new( + super::fq_to_field(&coefficients[2]), + super::fq_to_field(&coefficients[3]), + ), + Fq2::new( + super::fq_to_field(&coefficients[4]), + super::fq_to_field(&coefficients[5]), + ), + ), + Fq6::new( + Fq2::new( + super::fq_to_field(&coefficients[6]), + super::fq_to_field(&coefficients[7]), + ), + Fq2::new( + super::fq_to_field(&coefficients[8]), + super::fq_to_field(&coefficients[9]), + ), + Fq2::new( + super::fq_to_field(&coefficients[10]), + super::fq_to_field(&coefficients[11]), + ), + ), + ); + + if inner == Fq12::ZERO || inner.pow(ArkFr::MODULUS) != Fq12::ONE { + return None; + } + + Some(Self(inner)) + } +} + // GT's additive notation maps to Fq12 multiplication by design. #[expect( clippy::suspicious_arithmetic_impl, @@ -146,7 +210,23 @@ impl AppendToTranscript for Bn254GT { } } +#[expect( + clippy::expect_used, + reason = "canonical BN254 Fq serialization into a fixed 32-byte buffer cannot fail" +)] +fn field_to_fq(value: &ark_bn254::Fq) -> Fq { + use ark_serialize::CanonicalSerialize; + + let mut bytes = [0_u8; Fq::NUM_BYTES]; + value + .serialize_compressed(&mut bytes[..]) + .expect("BN254 Fq serialization cannot fail"); + Fq::from_le_bytes_mod_order(&bytes) +} + impl JoltGroup for Bn254GT { + type ScalarField = Fr; + #[inline(always)] fn identity() -> Self { Self(Fq12::ONE) @@ -163,14 +243,14 @@ impl JoltGroup for Bn254GT { } #[inline] - fn scalar_mul(&self, scalar: &F) -> Self { + fn scalar_mul(&self, scalar: &Self::ScalarField) -> Self { // GT exponentiation: self^scalar (written additively as scalar * self). let fr = field_to_fr(scalar); Self(self.0.pow(fr.into_bigint())) } #[inline] - fn msm(bases: &[Self], scalars: &[F]) -> Self { + fn msm(bases: &[Self], scalars: &[Self::ScalarField]) -> Self { debug_assert_eq!(bases.len(), scalars.len()); // GT "MSM" is Π bases[i]^scalars[i] (written additively as Σ scalars[i] * bases[i]). let mut acc = Fq12::ONE; @@ -206,7 +286,7 @@ impl<'de> serde::Deserialize<'de> for Bn254GT { )); } // Subgroup membership: GT is the r-torsion subgroup, so x^r == 1. - if inner.pow(Fr::MODULUS) != Fq12::ONE { + if inner.pow(ArkFr::MODULUS) != Fq12::ONE { return Err(serde::de::Error::custom( "GT element is not in the r-torsion subgroup", )); diff --git a/crates/jolt-crypto/src/ec/bn254/mod.rs b/crates/jolt-crypto/src/ec/bn254/mod.rs index 38f713e6b9..d71ddccacb 100644 --- a/crates/jolt-crypto/src/ec/bn254/mod.rs +++ b/crates/jolt-crypto/src/ec/bn254/mod.rs @@ -141,6 +141,8 @@ macro_rules! impl_jolt_group_wrapper { } impl $crate::JoltGroup for $wrapper { + type ScalarField = ::jolt_field::Fr; + #[inline(always)] fn identity() -> Self { Self(<$projective as ::ark_ff::Zero>::zero()) @@ -157,12 +159,12 @@ macro_rules! impl_jolt_group_wrapper { } #[inline] - fn scalar_mul(&self, scalar: &F) -> Self { + fn scalar_mul(&self, scalar: &Self::ScalarField) -> Self { Self(self.0 * super::field_to_fr(scalar)) } #[inline] - fn msm(bases: &[Self], scalars: &[F]) -> Self { + fn msm(bases: &[Self], scalars: &[Self::ScalarField]) -> Self { use ::ark_ec::{CurveGroup, VariableBaseMSM}; use ::ark_ff::PrimeField; debug_assert_eq!(bases.len(), scalars.len()); @@ -180,6 +182,7 @@ macro_rules! impl_jolt_group_wrapper { pub(crate) use impl_jolt_group_wrapper; +mod fq12; mod g1; mod g2; mod gt; @@ -189,6 +192,7 @@ pub mod batch_addition; #[doc(hidden)] pub mod glv; +pub use fq12::Bn254Fq12; pub use g1::Bn254G1; pub use g2::Bn254G2; pub use gt::Bn254GT; @@ -197,12 +201,12 @@ use ark_bn254::Bn254 as ArkBn254; use ark_ec::pairing::Pairing; use ark_ec::CurveGroup; use ark_ff::PrimeField as _; -use jolt_field::Field; +use jolt_field::{CanonicalBytes, FixedByteSize, Fq, Fr}; use crate::PairingGroup; /// BN254 pairing-friendly curve. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Bn254; impl Bn254 { @@ -223,6 +227,20 @@ impl Bn254 { use ark_std::UniformRand; Bn254G1(ark_bn254::G1Projective::rand(rng)) } + + /// Computes the raw product of BN254 Miller loops without final exponentiation. + pub fn multi_miller_loop(g1s: &[Bn254G1], g2s: &[Bn254G2]) -> Bn254Fq12 { + debug_assert_eq!(g1s.len(), g2s.len()); + let g1_projs: Vec = g1s.iter().map(|g| g.0).collect(); + let g2_projs: Vec = g2s.iter().map(|g| g.0).collect(); + let g1_affines = ark_bn254::G1Projective::normalize_batch(&g1_projs); + let g2_affines = ark_bn254::G2Projective::normalize_batch(&g2_projs); + let g1_prepared: Vec<::G1Prepared> = + g1_affines.into_iter().map(Into::into).collect(); + let g2_prepared: Vec<::G2Prepared> = + g2_affines.into_iter().map(Into::into).collect(); + Bn254Fq12(ArkBn254::multi_miller_loop(g1_prepared, g2_prepared).0) + } } impl PairingGroup for Bn254 { @@ -257,8 +275,8 @@ impl PairingGroup for Bn254 { /// In debug builds, asserts that the source value fits in the BN254 Fr modulus — /// catches silent modular reduction when `F` has a larger modulus than BN254 Fr. #[inline] -pub(crate) fn field_to_fr(f: &F) -> ark_bn254::Fr { - let mut bytes = vec![0u8; F::NUM_BYTES]; +pub(crate) fn field_to_fr(f: &Fr) -> ark_bn254::Fr { + let mut bytes = vec![0u8; Fr::NUM_BYTES]; f.to_bytes_le(&mut bytes); #[cfg(debug_assertions)] { @@ -273,6 +291,20 @@ pub(crate) fn field_to_fr(f: &F) -> ark_bn254::Fr { ark_bn254::Fr::from_le_bytes_mod_order(&bytes) } +#[inline] +#[expect( + clippy::expect_used, + reason = "jolt_field::Fq canonical serialization is already a valid BN254 Fq encoding" +)] +pub(crate) fn fq_to_field(f: &Fq) -> ark_bn254::Fq { + use ark_serialize::CanonicalDeserialize; + + let mut bytes = vec![0u8; Fq::NUM_BYTES]; + f.to_bytes_le(&mut bytes); + ark_bn254::Fq::deserialize_compressed(&bytes[..]) + .expect("jolt_field::Fq serializes as a canonical BN254 Fq element") +} + #[cfg(test)] #[expect(clippy::expect_used, reason = "tests may fail loudly")] mod tests { diff --git a/crates/jolt-crypto/src/ec/group.rs b/crates/jolt-crypto/src/ec/group.rs index 9268a1bdc4..4563b44f53 100644 --- a/crates/jolt-crypto/src/ec/group.rs +++ b/crates/jolt-crypto/src/ec/group.rs @@ -37,6 +37,9 @@ pub trait JoltGroup: + for<'de> Deserialize<'de> + AppendToTranscript { + /// Scalar field used for scalar multiplication in this group. + type ScalarField: Field; + /// Group identity element. #[must_use] fn identity() -> Self; @@ -51,7 +54,7 @@ pub trait JoltGroup: /// Scalar multiplication: `scalar * self`. #[must_use] - fn scalar_mul(&self, scalar: &F) -> Self; + fn scalar_mul(&self, scalar: &Self::ScalarField) -> Self; /// Multi-scalar multiplication: `Σᵢ scalars[i] * bases[i]`. /// @@ -59,5 +62,5 @@ pub trait JoltGroup: /// /// Debug-asserts that `bases.len() == scalars.len()`. #[must_use] - fn msm(bases: &[Self], scalars: &[F]) -> Self; + fn msm(bases: &[Self], scalars: &[Self::ScalarField]) -> Self; } diff --git a/crates/jolt-crypto/src/ec/grumpkin/mod.rs b/crates/jolt-crypto/src/ec/grumpkin/mod.rs new file mode 100644 index 0000000000..cd893a33ed --- /dev/null +++ b/crates/jolt-crypto/src/ec/grumpkin/mod.rs @@ -0,0 +1,347 @@ +//! Concrete Grumpkin curve implementation. +//! +//! This module wraps `ark-grumpkin` behind `JoltGroup`. Arkworks types stay +//! internal to this backend module. + +use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM}; +use ark_ff::{PrimeField, UniformRand, Zero}; +use blake2::{ + digest::{consts::U32, Digest}, + Blake2b, +}; +use jolt_field::{CanonicalBytes, FixedByteSize, Fq}; + +use super::JoltGroup; + +/// Grumpkin prime-order group element. +#[derive(Clone, Copy, Default, Eq, PartialEq)] +#[repr(transparent)] +pub struct GrumpkinPoint(pub(crate) ark_grumpkin::Projective); + +const _: () = assert!( + std::mem::size_of::() == std::mem::size_of::() +); + +impl std::fmt::Debug for GrumpkinPoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let affine = self.0.into_affine(); + f.debug_tuple("GrumpkinPoint").field(&affine).finish() + } +} + +impl std::ops::Add for GrumpkinPoint { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } +} + +impl<'a> std::ops::Add<&'a GrumpkinPoint> for GrumpkinPoint { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: &'a GrumpkinPoint) -> Self { + Self(self.0 + rhs.0) + } +} + +impl std::ops::Sub for GrumpkinPoint { + type Output = Self; + + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } +} + +impl<'a> std::ops::Sub<&'a GrumpkinPoint> for GrumpkinPoint { + type Output = Self; + + #[inline(always)] + fn sub(self, rhs: &'a GrumpkinPoint) -> Self { + Self(self.0 - rhs.0) + } +} + +impl std::ops::Neg for GrumpkinPoint { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self { + Self(-self.0) + } +} + +impl std::ops::AddAssign for GrumpkinPoint { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + } +} + +impl std::ops::SubAssign for GrumpkinPoint { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + self.0 -= rhs.0; + } +} + +impl serde::Serialize for GrumpkinPoint { + fn serialize(&self, serializer: S) -> Result { + use ark_serialize::CanonicalSerialize; + let mut buf = Vec::with_capacity(self.0.compressed_size()); + self.0 + .serialize_compressed(&mut buf) + .map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&buf) + } +} + +impl<'de> serde::Deserialize<'de> for GrumpkinPoint { + fn deserialize>(deserializer: D) -> Result { + use ark_serialize::CanonicalDeserialize; + let buf = >::deserialize(deserializer)?; + let inner = ark_grumpkin::Projective::deserialize_compressed(&buf[..]) + .map_err(serde::de::Error::custom)?; + Ok(Self(inner)) + } +} + +impl jolt_transcript::AppendToTranscript for GrumpkinPoint { + #[expect(clippy::expect_used, reason = "serialization into Vec cannot fail")] + fn append_to_transcript(&self, transcript: &mut T) { + use ark_serialize::CanonicalSerialize; + let mut buf = Vec::with_capacity(self.0.compressed_size()); + self.0 + .serialize_compressed(&mut buf) + .expect("GrumpkinPoint serialization cannot fail"); + transcript.append_bytes(&buf); + } +} + +impl crate::JoltGroup for GrumpkinPoint { + type ScalarField = Fq; + + #[inline(always)] + fn identity() -> Self { + Self(::zero()) + } + + #[inline(always)] + fn is_identity(&self) -> bool { + ::is_zero(&self.0) + } + + #[inline(always)] + fn double(&self) -> Self { + Self(::double( + &self.0, + )) + } + + #[inline] + fn scalar_mul(&self, scalar: &Self::ScalarField) -> Self { + Self(self.0 * fq_to_grumpkin_fr(scalar)) + } + + #[inline] + fn msm(bases: &[Self], scalars: &[Self::ScalarField]) -> Self { + debug_assert_eq!(bases.len(), scalars.len()); + let affines: Vec = bases.iter().map(|b| b.0.into_affine()).collect(); + let ark_scalars: Vec = scalars.iter().map(fq_to_grumpkin_fr).collect(); + let bigints: Vec<_> = ark_scalars.iter().map(|s| s.into_bigint()).collect(); + Self(::msm_bigint( + &affines, &bigints, + )) + } +} + +/// Grumpkin curve marker and constructors. +#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct Grumpkin; + +impl Grumpkin { + /// Standard Grumpkin generator. + pub fn generator() -> GrumpkinPoint { + GrumpkinPoint(ark_grumpkin::Affine::generator().into()) + } + + /// Samples a uniformly random Grumpkin group element. + pub fn random(rng: &mut R) -> GrumpkinPoint { + GrumpkinPoint(ark_grumpkin::Projective::rand(rng)) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct GrumpkinPedersenSetupSeed<'a> { + pub domain: &'a [u8], + pub seed: &'a [u8], +} + +impl<'a> GrumpkinPedersenSetupSeed<'a> { + pub const fn new(domain: &'a [u8], seed: &'a [u8]) -> Self { + Self { domain, seed } + } +} + +impl crate::DeriveSetup> for super::PedersenSetup { + fn derive(source: &GrumpkinPedersenSetupSeed<'_>, capacity: usize) -> Self { + let mut message_generators = Vec::with_capacity(capacity); + for index in 0..capacity { + message_generators.push(derive_grumpkin_pedersen_point(source, b"message", index)); + } + let blinding_generator = derive_grumpkin_pedersen_point(source, b"blinding", 0); + Self::new(message_generators, blinding_generator) + } +} + +fn derive_grumpkin_pedersen_point( + source: &GrumpkinPedersenSetupSeed<'_>, + role: &[u8], + index: usize, +) -> GrumpkinPoint { + let mut attempt = 0_u64; + loop { + let bytes = hash_grumpkin_pedersen_candidate(source, role, index, attempt); + let x = ark_grumpkin::Fq::from_le_bytes_mod_order(&bytes); + let greatest = bytes[31] & 1 == 1; + if let Some(affine) = ark_grumpkin::Affine::get_point_from_x_unchecked(x, greatest) { + if affine.is_in_correct_subgroup_assuming_on_curve() { + let point = GrumpkinPoint(affine.into_group()); + if !point.is_identity() { + return point; + } + } + } + attempt = attempt.wrapping_add(1); + } +} + +fn hash_grumpkin_pedersen_candidate( + source: &GrumpkinPedersenSetupSeed<'_>, + role: &[u8], + index: usize, + attempt: u64, +) -> [u8; 32] { + let mut hasher = Blake2b::::new(); + hash_len_prefixed(&mut hasher, b"JoltGrumpkinPedersenSetupV1"); + hash_len_prefixed(&mut hasher, source.domain); + hash_len_prefixed(&mut hasher, source.seed); + hash_len_prefixed(&mut hasher, role); + hasher.update((index as u64).to_le_bytes()); + hasher.update(attempt.to_le_bytes()); + hasher.finalize().into() +} + +fn hash_len_prefixed(hasher: &mut Blake2b, bytes: &[u8]) { + hasher.update((bytes.len() as u64).to_le_bytes()); + hasher.update(bytes); +} + +#[inline] +pub(crate) fn fq_to_grumpkin_fr(f: &Fq) -> ark_grumpkin::Fr { + let mut bytes = vec![0u8; Fq::NUM_BYTES]; + f.to_bytes_le(&mut bytes); + #[cfg(debug_assertions)] + { + use ark_ff::{BigInteger, PrimeField as _}; + let value = num_bigint::BigUint::from_bytes_le(&bytes); + let modulus = num_bigint::BigUint::from_bytes_le(&ark_grumpkin::Fr::MODULUS.to_bytes_le()); + debug_assert!( + value < modulus, + "fq_to_grumpkin_fr: source value >= Grumpkin scalar modulus", + ); + } + ark_grumpkin::Fr::from_le_bytes_mod_order(&bytes) +} + +#[cfg(test)] +mod tests { + use jolt_field::{Fq, FromPrimitiveInt}; + + use super::*; + use crate::{DeriveSetup, JoltGroup, Pedersen, PedersenSetup, VectorCommitment}; + + #[test] + fn scalar_mul_and_msm_match() { + let generator = Grumpkin::generator(); + let a = Fq::from_u64(11); + let b = Fq::from_u64(19); + let p = generator.scalar_mul(&a); + let q = generator.scalar_mul(&b); + + assert_eq!(GrumpkinPoint::msm(&[generator, generator], &[a, b]), p + q); + } + + #[test] + fn pedersen_over_grumpkin_uses_fq_scalars() { + type VC = Pedersen; + + let generator = Grumpkin::generator(); + let setup = PedersenSetup::new( + vec![ + generator, + generator.scalar_mul(&Fq::from_u64(2)), + generator.scalar_mul(&Fq::from_u64(3)), + ], + generator.scalar_mul(&Fq::from_u64(99)), + ); + let values = [Fq::from_u64(4), Fq::from_u64(5), Fq::from_u64(6)]; + let opening_scalar = Fq::from_u64(7); + let commitment = VC::commit(&setup, &values, &opening_scalar); + + assert!(VC::verify(&setup, &commitment, &values, &opening_scalar)); + assert!(!VC::verify( + &setup, + &commitment, + &values, + &(opening_scalar + Fq::from_u64(1)), + )); + } + + #[test] + fn seed_derived_pedersen_setup_is_deterministic() { + let seed = GrumpkinPedersenSetupSeed::new(b"dory-assist-hyrax", b"v1"); + let left = PedersenSetup::::derive(&seed, 4); + let right = PedersenSetup::::derive(&seed, 4); + + assert_eq!(left, right); + assert_eq!(left.message_generators.len(), 4); + assert!(left + .message_generators + .iter() + .all(|point| !point.is_identity())); + assert!(!left.blinding_generator.is_identity()); + } + + #[test] + fn seed_derived_pedersen_setup_changes_with_seed() { + let left_seed = GrumpkinPedersenSetupSeed::new(b"dory-assist-hyrax", b"v1"); + let right_seed = GrumpkinPedersenSetupSeed::new(b"dory-assist-hyrax", b"v2"); + let left = PedersenSetup::::derive(&left_seed, 2); + let right = PedersenSetup::::derive(&right_seed, 2); + + assert_ne!(left, right); + } + + #[test] + fn seed_derived_pedersen_setup_works_for_commitments() { + type VC = Pedersen; + + let seed = GrumpkinPedersenSetupSeed::new(b"dory-assist-hyrax", b"v1"); + let setup = PedersenSetup::::derive(&seed, 3); + let values = [Fq::from_u64(5), Fq::from_u64(8), Fq::from_u64(13)]; + let opening_scalar = Fq::from_u64(21); + let commitment = VC::commit(&setup, &values, &opening_scalar); + + assert!(VC::verify(&setup, &commitment, &values, &opening_scalar)); + assert!(!VC::verify( + &setup, + &commitment, + &values, + &(opening_scalar + Fq::from_u64(1)), + )); + } +} diff --git a/crates/jolt-crypto/src/ec/mod.rs b/crates/jolt-crypto/src/ec/mod.rs index 40ce215f34..54cdb6f110 100644 --- a/crates/jolt-crypto/src/ec/mod.rs +++ b/crates/jolt-crypto/src/ec/mod.rs @@ -11,3 +11,6 @@ pub use pedersen::{Pedersen, PedersenSetup}; #[cfg(feature = "bn254")] pub mod bn254; + +#[cfg(feature = "grumpkin")] +pub mod grumpkin; diff --git a/crates/jolt-crypto/src/ec/pairing.rs b/crates/jolt-crypto/src/ec/pairing.rs index ab9c164262..144e0d0043 100644 --- a/crates/jolt-crypto/src/ec/pairing.rs +++ b/crates/jolt-crypto/src/ec/pairing.rs @@ -1,4 +1,5 @@ use jolt_field::Field; +use serde::{de::DeserializeOwned, Serialize}; use std::fmt::Debug; use super::group::JoltGroup; @@ -12,12 +13,14 @@ use super::group::JoltGroup; /// G1, G2, and GT all implement `JoltGroup` (additive notation). GT uses /// additive notation for uniformity, even though the underlying operation /// is Fq12 multiplication. See `Bn254GT` for the mapping. -pub trait PairingGroup: Clone + Debug + Eq + Sync + Send + 'static { +pub trait PairingGroup: + Clone + Debug + Eq + Sync + Send + 'static + Serialize + DeserializeOwned +{ /// Scalar field for G1 and G2 (e.g., BN254 Fr). type ScalarField: Field; - type G1: JoltGroup; - type G2: JoltGroup; - type GT: JoltGroup; + type G1: JoltGroup; + type G2: JoltGroup; + type GT: JoltGroup; /// Computes the bilinear pairing `e(g1, g2)`. #[must_use] diff --git a/crates/jolt-crypto/src/ec/pedersen.rs b/crates/jolt-crypto/src/ec/pedersen.rs index 0054a3fe23..6068b9c4d3 100644 --- a/crates/jolt-crypto/src/ec/pedersen.rs +++ b/crates/jolt-crypto/src/ec/pedersen.rs @@ -1,4 +1,3 @@ -use jolt_field::Fr; use serde::{Deserialize, Serialize}; use super::group::JoltGroup; @@ -21,7 +20,7 @@ pub struct Pedersen { /// Setup parameters for Pedersen commitments: a vector of message generators /// and a separate blinding generator. -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] #[serde(bound = "")] pub struct PedersenSetup { pub message_generators: Vec, @@ -67,7 +66,7 @@ impl Commitment for Pedersen { } impl VectorCommitment for Pedersen { - type Field = Fr; + type Field = G::ScalarField; type Setup = PedersenSetup; #[inline] @@ -87,7 +86,7 @@ impl VectorCommitment for Pedersen { /// # Panics /// /// Panics if `values.len() > Self::capacity(setup)`. - fn commit(setup: &Self::Setup, values: &[Fr], blinding: &Fr) -> G { + fn commit(setup: &Self::Setup, values: &[Self::Field], blinding: &Self::Field) -> G { assert!( values.len() <= setup.message_generators.len(), "values length ({}) exceeds generator count ({})", @@ -99,7 +98,12 @@ impl VectorCommitment for Pedersen { msg + blind } - fn verify(setup: &Self::Setup, commitment: &G, values: &[Fr], blinding: &Fr) -> bool { + fn verify( + setup: &Self::Setup, + commitment: &G, + values: &[Self::Field], + blinding: &Self::Field, + ) -> bool { *commitment == Self::commit(setup, values, blinding) } } diff --git a/crates/jolt-crypto/src/lib.rs b/crates/jolt-crypto/src/lib.rs index 2cdf92fc3e..2cbe985a4a 100644 --- a/crates/jolt-crypto/src/lib.rs +++ b/crates/jolt-crypto/src/lib.rs @@ -14,5 +14,10 @@ pub use commitment::{ VectorOpeningError, }; +#[cfg(feature = "r1cs")] +pub mod r1cs; + #[cfg(feature = "bn254")] -pub use ec::bn254::{Bn254, Bn254G1, Bn254G2, Bn254GT}; +pub use ec::bn254::{Bn254, Bn254Fq12, Bn254G1, Bn254G2, Bn254GT}; +#[cfg(feature = "grumpkin")] +pub use ec::grumpkin::{Grumpkin, GrumpkinPedersenSetupSeed, GrumpkinPoint}; diff --git a/crates/jolt-crypto/src/r1cs.rs b/crates/jolt-crypto/src/r1cs.rs new file mode 100644 index 0000000000..257e2f5626 --- /dev/null +++ b/crates/jolt-crypto/src/r1cs.rs @@ -0,0 +1,1888 @@ +//! R1CS helpers for cryptographic group operations. + +use ark_ec::CurveGroup; +use jolt_field::{Fr, FromPrimitiveInt, Invertible}; +use jolt_r1cs::{AssignedScalar, FqVar, LinearCombination, R1csBuilder, ScalarGadget}; +use num_traits::{One, Zero}; +use thiserror::Error; + +use crate::{GrumpkinPoint, JoltGroup, Pedersen, PedersenSetup}; + +#[derive(Clone, Debug, Error, PartialEq, Eq)] +pub enum CryptoR1csError { + #[error("affine Grumpkin point gadget does not represent the identity")] + IdentityPoint, + #[error("non-exceptional affine addition requires distinct input x-coordinates")] + ExceptionalAffineAddition, + #[error("affine doubling requires non-zero y-coordinate")] + ExceptionalAffineDoubling, + #[error("fixed-base MSM length mismatch: bases={bases}, scalars={scalars}")] + FixedBaseMsmLengthMismatch { bases: usize, scalars: usize }, + #[error("vector commitment opening length {values} exceeds setup capacity {capacity}")] + VectorCommitmentCapacityExceeded { capacity: usize, values: usize }, +} + +pub trait GroupElementVar: Clone { + type BuilderField: jolt_field::Field; + type ScalarVar: jolt_poly::r1cs::PolynomialScalarGadget< + ConstraintBuilder = R1csBuilder, + >; + type Error; + + fn assert_valid( + &self, + builder: &mut R1csBuilder, + ) -> Result<(), Self::Error>; + + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self); +} + +pub trait NonExceptionalAddGroupVar: GroupElementVar { + fn assert_nonexceptional_add( + builder: &mut R1csBuilder, + lhs: &Self, + rhs: &Self, + output: &Self, + ) -> Result<(), Self::Error>; +} + +pub trait CompleteAddGroupVar: GroupElementVar { + fn complete_add( + builder: &mut R1csBuilder, + lhs: &Self, + rhs: &Self, + ) -> Result; +} + +pub trait DoubleGroupVar: GroupElementVar { + fn assert_double( + builder: &mut R1csBuilder, + input: &Self, + output: &Self, + ) -> Result<(), Self::Error>; +} + +pub trait VariableBaseScalarMulGroupVar: CompleteAddGroupVar { + fn variable_base_scalar_mul( + builder: &mut R1csBuilder, + base: &Self, + scalar: &Self::ScalarVar, + ) -> Result; +} + +pub trait FixedBaseScalarMulGroupVar: GroupElementVar { + type Constant; + + fn fixed_base_scalar_mul( + builder: &mut R1csBuilder, + base: &Self::Constant, + scalar: &Self::ScalarVar, + ) -> Result; +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VectorCommitmentOpeningVar +where + S: jolt_poly::r1cs::PolynomialScalarGadget, +{ + pub values: Vec, + pub blinding: S, +} + +impl VectorCommitmentOpeningVar +where + S: jolt_poly::r1cs::PolynomialScalarGadget, +{ + pub fn new(values: Vec, blinding: S) -> Self { + Self { values, blinding } + } +} + +pub trait VectorCommitmentR1cs { + type BuilderField: jolt_field::Field; + type ScalarVar: jolt_poly::r1cs::PolynomialScalarGadget< + ConstraintBuilder = R1csBuilder, + >; + type OutputVar: GroupElementVar; + type SetupVar; + type Error; + + fn capacity(setup: &Self::SetupVar) -> usize; + + fn linear_combine_commitments( + builder: &mut R1csBuilder, + commitments: &[Self::OutputVar], + coefficients: &[Self::ScalarVar], + ) -> Result; + + fn verify_opening( + builder: &mut R1csBuilder, + setup: &Self::SetupVar, + commitment: &Self::OutputVar, + opening: &VectorCommitmentOpeningVar, + ) -> Result<(), Self::Error>; +} + +/// Affine Grumpkin point with coordinates in the BN254 scalar field `Fr`. +/// +/// This is an affine non-identity representation. It is useful for early +/// Grumpkin equation gadgets, but a full Pedersen verifier should move to a +/// complete representation before accepting arbitrary commitment outputs. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GrumpkinPointVar { + pub x: AssignedScalar, + pub y: AssignedScalar, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GrumpkinPointWithIdentityVar { + pub x: AssignedScalar, + pub y: AssignedScalar, + pub is_identity: AssignedScalar, +} + +impl GrumpkinPointVar { + pub fn new(x: AssignedScalar, y: AssignedScalar) -> Self { + Self { x, y } + } + + pub fn alloc( + builder: &mut R1csBuilder, + point: &GrumpkinPoint, + ) -> Result { + let (x, y) = grumpkin_coordinates(point)?; + Ok(Self { + x: AssignedScalar::alloc(builder, x), + y: AssignedScalar::alloc(builder, y), + }) + } + + pub fn constant(point: &GrumpkinPoint) -> Result { + let (x, y) = grumpkin_coordinates(point)?; + Ok(Self { + x: AssignedScalar::constant(x), + y: AssignedScalar::constant(y), + }) + } + + pub fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + self.x.assert_equal(builder, &rhs.x); + self.y.assert_equal(builder, &rhs.y); + } + + pub fn assert_on_curve(&self, builder: &mut R1csBuilder) { + let y_squared = self.y.mul(builder, &self.y); + let x_squared = self.x.mul(builder, &self.x); + let x_cubed = x_squared.mul(builder, &self.x); + let rhs = x_cubed.add(builder, &AssignedScalar::constant(-Fr::from_u64(17))); + y_squared.assert_equal(builder, &rhs); + } +} + +impl GrumpkinPointWithIdentityVar { + pub fn new( + x: AssignedScalar, + y: AssignedScalar, + is_identity: AssignedScalar, + ) -> Self { + Self { x, y, is_identity } + } + + pub fn identity() -> Self { + Self { + x: AssignedScalar::constant(Fr::zero()), + y: AssignedScalar::constant(Fr::zero()), + is_identity: AssignedScalar::constant(Fr::one()), + } + } + + pub fn alloc(builder: &mut R1csBuilder, point: &GrumpkinPoint) -> Self { + let affine = point.0.into_affine(); + if affine.infinity { + return Self { + x: AssignedScalar::alloc(builder, Fr::zero()), + y: AssignedScalar::alloc(builder, Fr::zero()), + is_identity: AssignedScalar::alloc(builder, Fr::one()), + }; + } + + Self { + x: AssignedScalar::alloc(builder, Fr::from(affine.x)), + y: AssignedScalar::alloc(builder, Fr::from(affine.y)), + is_identity: AssignedScalar::alloc(builder, Fr::zero()), + } + } + + pub fn constant(point: &GrumpkinPoint) -> Self { + let affine = point.0.into_affine(); + if affine.infinity { + return Self::identity(); + } + + Self { + x: AssignedScalar::constant(Fr::from(affine.x)), + y: AssignedScalar::constant(Fr::from(affine.y)), + is_identity: AssignedScalar::constant(Fr::zero()), + } + } + + pub fn from_nonidentity(point: GrumpkinPointVar) -> Self { + Self { + x: point.x, + y: point.y, + is_identity: AssignedScalar::constant(Fr::zero()), + } + } + + pub fn assert_valid(&self, builder: &mut R1csBuilder) { + assert_boolean(builder, &self.is_identity); + builder.assert_product( + self.is_identity.lc.clone(), + self.x.lc.clone(), + LinearCombination::zero(), + ); + builder.assert_product( + self.is_identity.lc.clone(), + self.y.lc.clone(), + LinearCombination::zero(), + ); + + let y_squared = self.y.mul(builder, &self.y); + let x_squared = self.x.mul(builder, &self.x); + let x_cubed = x_squared.mul(builder, &self.x); + let rhs = x_cubed.add(builder, &AssignedScalar::constant(-Fr::from_u64(17))); + let curve_residual = y_squared.sub(builder, &rhs); + let not_identity = AssignedScalar::constant(Fr::one()).sub(builder, &self.is_identity); + builder.assert_product( + not_identity.lc, + curve_residual.lc, + LinearCombination::zero(), + ); + } + + pub fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + self.x.assert_equal(builder, &rhs.x); + self.y.assert_equal(builder, &rhs.y); + self.is_identity.assert_equal(builder, &rhs.is_identity); + } + + pub fn select( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + when_true: &Self, + when_false: &Self, + ) -> Self { + Self { + x: as ScalarGadget>::select( + builder, + selector, + &when_true.x, + &when_false.x, + ), + y: as ScalarGadget>::select( + builder, + selector, + &when_true.y, + &when_false.y, + ), + is_identity: as ScalarGadget>::select( + builder, + selector, + &when_true.is_identity, + &when_false.is_identity, + ), + } + } + + pub fn witness_value(&self) -> GrumpkinPoint { + if self.is_identity.value == Fr::one() { + GrumpkinPoint::identity() + } else { + grumpkin_point_from_coordinates(self.x.value, self.y.value) + } + } +} + +impl GroupElementVar for GrumpkinPointVar { + type BuilderField = Fr; + type ScalarVar = jolt_r1cs::FqVar; + type Error = CryptoR1csError; + + fn assert_valid( + &self, + builder: &mut R1csBuilder, + ) -> Result<(), Self::Error> { + self.assert_on_curve(builder); + Ok(()) + } + + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + GrumpkinPointVar::assert_equal(self, builder, rhs); + } +} + +impl GroupElementVar for GrumpkinPointWithIdentityVar { + type BuilderField = Fr; + type ScalarVar = FqVar; + type Error = CryptoR1csError; + + fn assert_valid( + &self, + builder: &mut R1csBuilder, + ) -> Result<(), Self::Error> { + self.assert_valid(builder); + Ok(()) + } + + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + GrumpkinPointWithIdentityVar::assert_equal(self, builder, rhs); + } +} + +impl CompleteAddGroupVar for GrumpkinPointWithIdentityVar { + fn complete_add( + builder: &mut R1csBuilder, + lhs: &Self, + rhs: &Self, + ) -> Result { + complete_grumpkin_add(builder, lhs, rhs) + } +} + +impl VariableBaseScalarMulGroupVar for GrumpkinPointWithIdentityVar { + fn variable_base_scalar_mul( + builder: &mut R1csBuilder, + base: &Self, + scalar: &Self::ScalarVar, + ) -> Result { + variable_base_grumpkin_scalar_mul(builder, base, scalar) + } +} + +impl FixedBaseScalarMulGroupVar for GrumpkinPointWithIdentityVar { + type Constant = GrumpkinPoint; + + fn fixed_base_scalar_mul( + builder: &mut R1csBuilder, + base: &Self::Constant, + scalar: &Self::ScalarVar, + ) -> Result { + fixed_base_grumpkin_scalar_mul(builder, base, scalar) + } +} + +impl NonExceptionalAddGroupVar for GrumpkinPointVar { + fn assert_nonexceptional_add( + builder: &mut R1csBuilder, + lhs: &Self, + rhs: &Self, + output: &Self, + ) -> Result<(), Self::Error> { + assert_nonexceptional_grumpkin_add(builder, lhs, rhs, output) + } +} + +impl DoubleGroupVar for GrumpkinPointVar { + fn assert_double( + builder: &mut R1csBuilder, + input: &Self, + output: &Self, + ) -> Result<(), Self::Error> { + assert_grumpkin_double(builder, input, output) + } +} + +impl VectorCommitmentR1cs for Pedersen { + type BuilderField = Fr; + type ScalarVar = FqVar; + type OutputVar = GrumpkinPointWithIdentityVar; + type SetupVar = PedersenSetup; + type Error = CryptoR1csError; + + fn capacity(setup: &Self::SetupVar) -> usize { + setup.message_generators.len() + } + + fn linear_combine_commitments( + builder: &mut R1csBuilder, + commitments: &[Self::OutputVar], + coefficients: &[Self::ScalarVar], + ) -> Result { + linear_combine_grumpkin_commitments(builder, commitments, coefficients) + } + + fn verify_opening( + builder: &mut R1csBuilder, + setup: &Self::SetupVar, + commitment: &Self::OutputVar, + opening: &VectorCommitmentOpeningVar, + ) -> Result<(), Self::Error> { + verify_grumpkin_pedersen_opening(builder, setup, commitment, opening) + } +} + +/// Constrains `output = 2 * input` for affine Grumpkin points. +/// +/// The formula is valid for non-identity affine points with non-zero +/// y-coordinate. In the Grumpkin prime-order subgroup this is the ordinary +/// doubling case used by scalar multiplication. +pub fn assert_grumpkin_double( + builder: &mut R1csBuilder, + input: &GrumpkinPointVar, + output: &GrumpkinPointVar, +) -> Result<(), CryptoR1csError> { + input.assert_on_curve(builder); + output.assert_on_curve(builder); + + let two_y = input.y.scale_by_constant(builder, Fr::from_u64(2)); + if two_y.value.is_zero() { + return Err(CryptoR1csError::ExceptionalAffineDoubling); + } + let two_y_inverse = AssignedScalar::alloc( + builder, + two_y + .value + .inverse() + .ok_or(CryptoR1csError::ExceptionalAffineDoubling)?, + ); + builder.assert_product( + two_y.lc.clone(), + two_y_inverse.lc.clone(), + LinearCombination::one(), + ); + + let x_squared = input.x.mul(builder, &input.x); + let three_x_squared = x_squared.scale_by_constant(builder, Fr::from_u64(3)); + let slope = three_x_squared.mul(builder, &two_y_inverse); + let slope_squared = slope.mul(builder, &slope); + let two_x = input.x.scale_by_constant(builder, Fr::from_u64(2)); + let expected_slope_squared = output.x.add(builder, &two_x); + slope_squared.assert_equal(builder, &expected_slope_squared); + + let x_delta = input.x.sub(builder, &output.x); + let y_delta = output.y.add(builder, &input.y); + let expected_y_delta = slope.mul(builder, &x_delta); + y_delta.assert_equal(builder, &expected_y_delta); + + Ok(()) +} + +/// Constrains `output = lhs + rhs` for identity-aware affine points. +/// +/// The non-identity/non-identity case still uses the ordinary affine addition +/// relation, so doubling and inverse-pair addition are intentionally rejected. +pub fn assert_grumpkin_add_with_identity( + builder: &mut R1csBuilder, + lhs: &GrumpkinPointWithIdentityVar, + rhs: &GrumpkinPointWithIdentityVar, + output: &GrumpkinPointWithIdentityVar, +) -> Result<(), CryptoR1csError> { + lhs.assert_valid(builder); + rhs.assert_valid(builder); + output.assert_valid(builder); + + let not_lhs_identity = AssignedScalar::constant(Fr::one()).sub(builder, &lhs.is_identity); + let not_rhs_identity = AssignedScalar::constant(Fr::one()).sub(builder, &rhs.is_identity); + let rhs_only_gate = not_lhs_identity.mul(builder, &rhs.is_identity); + let active = not_lhs_identity.mul(builder, ¬_rhs_identity); + let expected_output_identity = lhs.is_identity.mul(builder, &rhs.is_identity); + output + .is_identity + .assert_equal(builder, &expected_output_identity); + + gated_assert_equal(builder, &lhs.is_identity, &output.x, &rhs.x); + gated_assert_equal(builder, &lhs.is_identity, &output.y, &rhs.y); + gated_assert_equal(builder, &rhs_only_gate, &output.x, &lhs.x); + gated_assert_equal(builder, &rhs_only_gate, &output.y, &lhs.y); + + let dx = rhs.x.sub(builder, &lhs.x); + if !active.value.is_zero() && dx.value.is_zero() { + return Err(CryptoR1csError::ExceptionalAffineAddition); + } + let dx_inverse_value = if active.value.is_zero() { + Fr::zero() + } else { + dx.value + .inverse() + .ok_or(CryptoR1csError::ExceptionalAffineAddition)? + }; + let dx_inverse = AssignedScalar::alloc(builder, dx_inverse_value); + let dx_times_inverse = dx.mul(builder, &dx_inverse); + let inverse_residual = dx_times_inverse.sub(builder, &AssignedScalar::constant(Fr::one())); + gated_assert_zero(builder, &active, &inverse_residual); + + let dy = rhs.y.sub(builder, &lhs.y); + let slope = dy.mul(builder, &dx_inverse); + let slope_squared = slope.mul(builder, &slope); + let expected_slope_squared = output.x.add(builder, &lhs.x).add(builder, &rhs.x); + let slope_residual = slope_squared.sub(builder, &expected_slope_squared); + gated_assert_zero(builder, &active, &slope_residual); + + let x_delta = lhs.x.sub(builder, &output.x); + let y_delta = output.y.add(builder, &lhs.y); + let expected_y_delta = slope.mul(builder, &x_delta); + let y_residual = y_delta.sub(builder, &expected_y_delta); + gated_assert_zero(builder, &active, &y_residual); + + Ok(()) +} + +/// Allocates and constrains `lhs + rhs` for all affine Grumpkin cases. +pub fn complete_grumpkin_add( + builder: &mut R1csBuilder, + lhs: &GrumpkinPointWithIdentityVar, + rhs: &GrumpkinPointWithIdentityVar, +) -> Result { + let output = + GrumpkinPointWithIdentityVar::alloc(builder, &(lhs.witness_value() + rhs.witness_value())); + assert_complete_grumpkin_add(builder, lhs, rhs, &output)?; + Ok(output) +} + +/// Constrains `output = lhs + rhs` for identity-aware affine Grumpkin points. +/// +/// This relation covers identity cases, ordinary affine addition, doubling, +/// and inverse-pair addition to the identity. Case flags are constrained with +/// zero tests and gates, so callers do not need to preclude exceptional inputs. +pub fn assert_complete_grumpkin_add( + builder: &mut R1csBuilder, + lhs: &GrumpkinPointWithIdentityVar, + rhs: &GrumpkinPointWithIdentityVar, + output: &GrumpkinPointWithIdentityVar, +) -> Result<(), CryptoR1csError> { + lhs.assert_valid(builder); + rhs.assert_valid(builder); + output.assert_valid(builder); + + let one = AssignedScalar::constant(Fr::one()); + let zero = AssignedScalar::constant(Fr::zero()); + let not_lhs_identity = one.sub(builder, &lhs.is_identity); + let not_rhs_identity = one.sub(builder, &rhs.is_identity); + let rhs_only_gate = not_lhs_identity.mul(builder, &rhs.is_identity); + let both_nonidentity = not_lhs_identity.mul(builder, ¬_rhs_identity); + + gated_assert_equal(builder, &lhs.is_identity, &output.x, &rhs.x); + gated_assert_equal(builder, &lhs.is_identity, &output.y, &rhs.y); + gated_assert_equal( + builder, + &lhs.is_identity, + &output.is_identity, + &rhs.is_identity, + ); + gated_assert_equal(builder, &rhs_only_gate, &output.x, &lhs.x); + gated_assert_equal(builder, &rhs_only_gate, &output.y, &lhs.y); + gated_assert_zero(builder, &rhs_only_gate, &output.is_identity); + + let dx = rhs.x.sub(builder, &lhs.x); + let dx_zero = zero_check(builder, &dx); + let dy = rhs.y.sub(builder, &lhs.y); + let dy_zero = zero_check(builder, &dy); + let y_sum = rhs.y.add(builder, &lhs.y); + let y_sum_zero = zero_check(builder, &y_sum); + + let not_dx_zero = one.sub(builder, &dx_zero.is_zero); + let ordinary_gate = both_nonidentity.mul(builder, ¬_dx_zero); + let same_x_gate = both_nonidentity.mul(builder, &dx_zero.is_zero); + let same_point_gate = same_x_gate.mul(builder, &dy_zero.is_zero); + let not_dy_zero = one.sub(builder, &dy_zero.is_zero); + let inverse_candidate_gate = same_x_gate.mul(builder, ¬_dy_zero); + let inverse_gate = inverse_candidate_gate.mul(builder, &y_sum_zero.is_zero); + let not_y_sum_zero = one.sub(builder, &y_sum_zero.is_zero); + let invalid_same_x_gate = inverse_candidate_gate.mul(builder, ¬_y_sum_zero); + invalid_same_x_gate.assert_equal(builder, &zero); + + let two_y = lhs.y.scale_by_constant(builder, Fr::from_u64(2)); + let two_y_zero = zero_check(builder, &two_y); + let not_two_y_zero = one.sub(builder, &two_y_zero.is_zero); + let double_gate = same_point_gate.mul(builder, ¬_two_y_zero); + let double_to_identity_gate = same_point_gate.mul(builder, &two_y_zero.is_zero); + + gated_assert_zero(builder, &ordinary_gate, &output.is_identity); + let ordinary_slope = dy.mul(builder, &dx_zero.inverse); + let ordinary_slope_squared = ordinary_slope.mul(builder, &ordinary_slope); + let expected_ordinary_slope_squared = output.x.add(builder, &lhs.x).add(builder, &rhs.x); + let ordinary_x_residual = ordinary_slope_squared.sub(builder, &expected_ordinary_slope_squared); + gated_assert_zero(builder, &ordinary_gate, &ordinary_x_residual); + let ordinary_x_delta = lhs.x.sub(builder, &output.x); + let ordinary_y_delta = output.y.add(builder, &lhs.y); + let expected_ordinary_y_delta = ordinary_slope.mul(builder, &ordinary_x_delta); + let ordinary_y_residual = ordinary_y_delta.sub(builder, &expected_ordinary_y_delta); + gated_assert_zero(builder, &ordinary_gate, &ordinary_y_residual); + + gated_assert_zero(builder, &double_gate, &output.is_identity); + let x_squared = lhs.x.mul(builder, &lhs.x); + let three_x_squared = x_squared.scale_by_constant(builder, Fr::from_u64(3)); + let double_slope = three_x_squared.mul(builder, &two_y_zero.inverse); + let double_slope_squared = double_slope.mul(builder, &double_slope); + let two_x = lhs.x.scale_by_constant(builder, Fr::from_u64(2)); + let expected_double_slope_squared = output.x.add(builder, &two_x); + let double_x_residual = double_slope_squared.sub(builder, &expected_double_slope_squared); + gated_assert_zero(builder, &double_gate, &double_x_residual); + let double_x_delta = lhs.x.sub(builder, &output.x); + let double_y_delta = output.y.add(builder, &lhs.y); + let expected_double_y_delta = double_slope.mul(builder, &double_x_delta); + let double_y_residual = double_y_delta.sub(builder, &expected_double_y_delta); + gated_assert_zero(builder, &double_gate, &double_y_residual); + + assert_output_identity(builder, &inverse_gate, output); + assert_output_identity(builder, &double_to_identity_gate, output); + + Ok(()) +} + +pub fn fixed_base_grumpkin_scalar_mul( + builder: &mut R1csBuilder, + base: &GrumpkinPoint, + scalar: &FqVar, +) -> Result { + let mut accumulator = GrumpkinPointWithIdentityVar::identity(); + let mut base_power = *base; + let identity = GrumpkinPointWithIdentityVar::identity(); + + for bit in scalar.bits_le(builder) { + let selected = GrumpkinPointWithIdentityVar::select( + builder, + &bit, + &GrumpkinPointWithIdentityVar::constant(&base_power), + &identity, + ); + accumulator = complete_grumpkin_add(builder, &accumulator, &selected)?; + base_power = base_power.double(); + } + + Ok(accumulator) +} + +/// Constrains a fixed-base MSM over Grumpkin. +/// +pub fn fixed_base_grumpkin_msm( + builder: &mut R1csBuilder, + bases: &[GrumpkinPoint], + scalars: &[FqVar], +) -> Result { + if bases.len() != scalars.len() { + return Err(CryptoR1csError::FixedBaseMsmLengthMismatch { + bases: bases.len(), + scalars: scalars.len(), + }); + } + + let mut accumulator = GrumpkinPointWithIdentityVar::identity(); + for (base, scalar) in bases.iter().zip(scalars) { + let term = fixed_base_grumpkin_scalar_mul(builder, base, scalar)?; + accumulator = complete_grumpkin_add(builder, &accumulator, &term)?; + } + + Ok(accumulator) +} + +pub fn variable_base_grumpkin_scalar_mul( + builder: &mut R1csBuilder, + base: &GrumpkinPointWithIdentityVar, + scalar: &FqVar, +) -> Result { + base.assert_valid(builder); + + let mut accumulator = GrumpkinPointWithIdentityVar::identity(); + let mut base_power = base.clone(); + let identity = GrumpkinPointWithIdentityVar::identity(); + + for bit in scalar.bits_le(builder) { + let selected = GrumpkinPointWithIdentityVar::select(builder, &bit, &base_power, &identity); + accumulator = complete_grumpkin_add(builder, &accumulator, &selected)?; + base_power = complete_grumpkin_add(builder, &base_power, &base_power)?; + } + + Ok(accumulator) +} + +pub fn linear_combine_grumpkin_commitments( + builder: &mut R1csBuilder, + commitments: &[GrumpkinPointWithIdentityVar], + coefficients: &[FqVar], +) -> Result { + if commitments.len() != coefficients.len() { + return Err(CryptoR1csError::FixedBaseMsmLengthMismatch { + bases: commitments.len(), + scalars: coefficients.len(), + }); + } + + let mut accumulator = GrumpkinPointWithIdentityVar::identity(); + for (commitment, coefficient) in commitments.iter().zip(coefficients) { + let term = variable_base_grumpkin_scalar_mul(builder, commitment, coefficient)?; + accumulator = complete_grumpkin_add(builder, &accumulator, &term)?; + } + + Ok(accumulator) +} + +pub fn grumpkin_pedersen_opening_commitment( + builder: &mut R1csBuilder, + setup: &PedersenSetup, + opening: &VectorCommitmentOpeningVar, +) -> Result { + let capacity = setup.message_generators.len(); + if opening.values.len() > capacity { + return Err(CryptoR1csError::VectorCommitmentCapacityExceeded { + capacity, + values: opening.values.len(), + }); + } + + let mut bases = setup.message_generators[..opening.values.len()].to_vec(); + bases.push(setup.blinding_generator); + let mut scalars = opening.values.clone(); + scalars.push(opening.blinding.clone()); + + fixed_base_grumpkin_msm(builder, &bases, &scalars) +} + +pub fn verify_grumpkin_pedersen_opening( + builder: &mut R1csBuilder, + setup: &PedersenSetup, + commitment: &GrumpkinPointWithIdentityVar, + opening: &VectorCommitmentOpeningVar, +) -> Result<(), CryptoR1csError> { + commitment.assert_valid(builder); + let computed = grumpkin_pedersen_opening_commitment(builder, setup, opening)?; + computed.assert_equal(builder, commitment); + Ok(()) +} + +/// Constrains `output = lhs + rhs` for ordinary affine additions. +/// +/// This helper intentionally rejects exceptional cases (`lhs.x == rhs.x`). +/// Doubling and adding inverse points need separate formulas so callers cannot +/// accidentally use an incomplete relation for those cases. +pub fn assert_nonexceptional_grumpkin_add( + builder: &mut R1csBuilder, + lhs: &GrumpkinPointVar, + rhs: &GrumpkinPointVar, + output: &GrumpkinPointVar, +) -> Result<(), CryptoR1csError> { + lhs.assert_on_curve(builder); + rhs.assert_on_curve(builder); + output.assert_on_curve(builder); + + let dx = rhs.x.sub(builder, &lhs.x); + if dx.value.is_zero() { + return Err(CryptoR1csError::ExceptionalAffineAddition); + } + let dy = rhs.y.sub(builder, &lhs.y); + let dx_inverse = AssignedScalar::alloc( + builder, + dx.value + .inverse() + .ok_or(CryptoR1csError::ExceptionalAffineAddition)?, + ); + builder.assert_product( + dx.lc.clone(), + dx_inverse.lc.clone(), + LinearCombination::one(), + ); + + let slope = dy.mul(builder, &dx_inverse); + let slope_squared = slope.mul(builder, &slope); + let expected_slope_squared = output.x.add(builder, &lhs.x).add(builder, &rhs.x); + slope_squared.assert_equal(builder, &expected_slope_squared); + + let x_delta = lhs.x.sub(builder, &output.x); + let y_delta = output.y.add(builder, &lhs.y); + let expected_y_delta = slope.mul(builder, &x_delta); + y_delta.assert_equal(builder, &expected_y_delta); + + Ok(()) +} + +fn grumpkin_coordinates(point: &GrumpkinPoint) -> Result<(Fr, Fr), CryptoR1csError> { + let affine = point.0.into_affine(); + if affine.infinity { + return Err(CryptoR1csError::IdentityPoint); + } + Ok((Fr::from(affine.x), Fr::from(affine.y))) +} + +fn grumpkin_point_from_coordinates(x: Fr, y: Fr) -> GrumpkinPoint { + GrumpkinPoint(ark_grumpkin::Affine::new_unchecked(x.into(), y.into()).into()) +} + +#[derive(Clone, Debug)] +struct ZeroCheck { + is_zero: AssignedScalar, + inverse: AssignedScalar, +} + +fn zero_check(builder: &mut R1csBuilder, value: &AssignedScalar) -> ZeroCheck { + let is_zero = AssignedScalar::alloc(builder, Fr::from_bool(value.value.is_zero())); + assert_boolean(builder, &is_zero); + let inverse = AssignedScalar::alloc(builder, value.value.inverse().unwrap_or_else(Fr::zero)); + builder.assert_product( + value.lc.clone(), + is_zero.lc.clone(), + LinearCombination::zero(), + ); + let product = value.mul(builder, &inverse); + let one_minus_is_zero = AssignedScalar::constant(Fr::one()).sub(builder, &is_zero); + product.assert_equal(builder, &one_minus_is_zero); + + ZeroCheck { is_zero, inverse } +} + +fn assert_boolean(builder: &mut R1csBuilder, value: &AssignedScalar) { + builder.assert_product( + value.lc.clone(), + value.lc.clone() - LinearCombination::one(), + LinearCombination::zero(), + ); +} + +fn gated_assert_equal( + builder: &mut R1csBuilder, + gate: &AssignedScalar, + lhs: &AssignedScalar, + rhs: &AssignedScalar, +) { + let difference = lhs.sub(builder, rhs); + gated_assert_zero(builder, gate, &difference); +} + +fn gated_assert_zero( + builder: &mut R1csBuilder, + gate: &AssignedScalar, + value: &AssignedScalar, +) { + builder.assert_product(gate.lc.clone(), value.lc.clone(), LinearCombination::zero()); +} + +fn assert_output_identity( + builder: &mut R1csBuilder, + gate: &AssignedScalar, + output: &GrumpkinPointWithIdentityVar, +) { + gated_assert_zero(builder, gate, &output.x); + gated_assert_zero(builder, gate, &output.y); + gated_assert_equal( + builder, + gate, + &output.is_identity, + &AssignedScalar::constant(Fr::one()), + ); +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] +mod tests { + use jolt_field::Fq; + use jolt_r1cs::Variable; + + use super::*; + use crate::{Grumpkin, JoltGroup, Pedersen, VectorCommitment}; + + #[test] + fn grumpkin_on_curve_accepts_valid_point() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + + point.assert_on_curve(&mut builder); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_on_curve_rejects_tampered_coordinate() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + let targets = [ + ("x-coordinate", variable(&point.x)), + ("y-coordinate", variable(&point.y)), + ]; + + point.assert_on_curve(&mut builder); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_on_curve_rejects_invalid_point() { + let mut builder = R1csBuilder::::new(); + let point = GrumpkinPointVar::new( + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + ); + + point.assert_on_curve(&mut builder); + + assert!(builder_rejects(builder)); + } + + #[test] + fn group_element_trait_accepts_grumpkin_point_var() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + + assert_group_var_valid(&mut builder, &point).expect("valid group variable"); + GroupElementVar::assert_equal(&point, &mut builder, &point); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_nonexceptional_add_accepts_valid_sum() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let sum = p + q; + let p = GrumpkinPointVar::alloc(&mut builder, &p).expect("non-identity point"); + let q = GrumpkinPointVar::alloc(&mut builder, &q).expect("non-identity point"); + let sum = GrumpkinPointVar::alloc(&mut builder, &sum).expect("non-identity point"); + + assert_nonexceptional_grumpkin_add(&mut builder, &p, &q, &sum).expect("ordinary addition"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn nonexceptional_add_trait_accepts_valid_grumpkin_sum() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let sum = p + q; + let p = GrumpkinPointVar::alloc(&mut builder, &p).expect("non-identity point"); + let q = GrumpkinPointVar::alloc(&mut builder, &q).expect("non-identity point"); + let sum = GrumpkinPointVar::alloc(&mut builder, &sum).expect("non-identity point"); + + assert_generic_nonexceptional_add(&mut builder, &p, &q, &sum).expect("ordinary addition"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_double_accepts_valid_double() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let doubled = point.double(); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + let doubled = GrumpkinPointVar::alloc(&mut builder, &doubled).expect("non-identity point"); + + assert_grumpkin_double(&mut builder, &point, &doubled).expect("ordinary doubling"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_double_trait_accepts_valid_double() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let doubled = point.double(); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + let doubled = GrumpkinPointVar::alloc(&mut builder, &doubled).expect("non-identity point"); + + assert_generic_double(&mut builder, &point, &doubled).expect("ordinary doubling"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_double_rejects_tampered_output() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let doubled = point.double(); + let point = GrumpkinPointVar::alloc(&mut builder, &point).expect("non-identity point"); + let doubled = GrumpkinPointVar::alloc(&mut builder, &doubled).expect("non-identity point"); + let targets = [ + ("doubled x-coordinate", variable(&doubled.x)), + ("doubled y-coordinate", variable(&doubled.y)), + ]; + + assert_grumpkin_double(&mut builder, &point, &doubled).expect("ordinary doubling"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_double_rejects_zero_y_input() { + let mut builder = R1csBuilder::::new(); + let input = GrumpkinPointVar::new( + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + AssignedScalar::alloc(&mut builder, Fr::zero()), + ); + let output = GrumpkinPointVar::new( + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + ); + + assert_eq!( + assert_grumpkin_double(&mut builder, &input, &output), + Err(CryptoR1csError::ExceptionalAffineDoubling) + ); + } + + #[test] + fn grumpkin_nonexceptional_add_rejects_tampered_output() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let sum = p + q; + let p = GrumpkinPointVar::alloc(&mut builder, &p).expect("non-identity point"); + let q = GrumpkinPointVar::alloc(&mut builder, &q).expect("non-identity point"); + let sum = GrumpkinPointVar::alloc(&mut builder, &sum).expect("non-identity point"); + let targets = [ + ("sum x-coordinate", variable(&sum.x)), + ("sum y-coordinate", variable(&sum.y)), + ]; + + assert_nonexceptional_grumpkin_add(&mut builder, &p, &q, &sum).expect("ordinary addition"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_nonexceptional_add_rejects_equal_x_inputs() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let doubled = p + p; + let p = GrumpkinPointVar::alloc(&mut builder, &p).expect("non-identity point"); + let doubled = GrumpkinPointVar::alloc(&mut builder, &doubled).expect("non-identity point"); + + assert_eq!( + assert_nonexceptional_grumpkin_add(&mut builder, &p, &p, &doubled), + Err(CryptoR1csError::ExceptionalAffineAddition) + ); + } + + #[test] + fn grumpkin_affine_gadget_rejects_identity() { + let mut builder = R1csBuilder::::new(); + + assert_eq!( + GrumpkinPointVar::alloc(&mut builder, &GrumpkinPoint::identity()), + Err(CryptoR1csError::IdentityPoint) + ); + } + + #[test] + fn grumpkin_identity_aware_point_accepts_identity() { + let mut builder = R1csBuilder::::new(); + let identity = + GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + + identity.assert_valid(&mut builder); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_identity_aware_point_rejects_malformed_identity() { + let mut builder = R1csBuilder::::new(); + let point = GrumpkinPointWithIdentityVar::new( + AssignedScalar::alloc(&mut builder, Fr::from_u64(1)), + AssignedScalar::alloc(&mut builder, Fr::zero()), + AssignedScalar::alloc(&mut builder, Fr::one()), + ); + + point.assert_valid(&mut builder); + + assert!(builder_rejects(builder)); + } + + #[test] + fn grumpkin_add_with_identity_accepts_left_identity() { + let mut builder = R1csBuilder::::new(); + let identity = + GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointWithIdentityVar::alloc(&mut builder, &point); + let output = point.clone(); + + assert_grumpkin_add_with_identity(&mut builder, &identity, &point, &output) + .expect("identity plus point"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_add_with_identity_accepts_right_identity() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointWithIdentityVar::alloc(&mut builder, &point); + let identity = + GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let output = point.clone(); + + assert_grumpkin_add_with_identity(&mut builder, &point, &identity, &output) + .expect("point plus identity"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_add_with_identity_accepts_both_identities() { + let mut builder = R1csBuilder::::new(); + let lhs = GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let rhs = GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let output = GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + + assert_grumpkin_add_with_identity(&mut builder, &lhs, &rhs, &output) + .expect("identity plus identity"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_add_with_identity_accepts_ordinary_sum() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let sum = p + q; + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let q = GrumpkinPointWithIdentityVar::alloc(&mut builder, &q); + let sum = GrumpkinPointWithIdentityVar::alloc(&mut builder, &sum); + + assert_grumpkin_add_with_identity(&mut builder, &p, &q, &sum).expect("ordinary addition"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_add_with_identity_rejects_tampered_identity_case_output() { + let mut builder = R1csBuilder::::new(); + let identity = + GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointWithIdentityVar::alloc(&mut builder, &point); + let output = point.clone(); + let targets = [("identity-case output x", variable(&output.x))]; + + assert_grumpkin_add_with_identity(&mut builder, &identity, &point, &output) + .expect("identity plus point"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_add_with_identity_rejects_tampered_ordinary_output() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let sum = p + q; + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let q = GrumpkinPointWithIdentityVar::alloc(&mut builder, &q); + let sum = GrumpkinPointWithIdentityVar::alloc(&mut builder, &sum); + let targets = [ + ("ordinary output x", variable(&sum.x)), + ("ordinary output y", variable(&sum.y)), + ]; + + assert_grumpkin_add_with_identity(&mut builder, &p, &q, &sum).expect("ordinary addition"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_add_with_identity_rejects_nonidentity_exceptional_addition() { + let mut builder = R1csBuilder::::new(); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let doubled = point.double(); + let point = GrumpkinPointWithIdentityVar::alloc(&mut builder, &point); + let doubled = GrumpkinPointWithIdentityVar::alloc(&mut builder, &doubled); + + assert_eq!( + assert_grumpkin_add_with_identity(&mut builder, &point, &point, &doubled), + Err(CryptoR1csError::ExceptionalAffineAddition) + ); + } + + #[test] + fn complete_grumpkin_add_accepts_identity_cases() { + let mut builder = R1csBuilder::::new(); + let identity = + GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let point = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let point = GrumpkinPointWithIdentityVar::alloc(&mut builder, &point); + + let left = complete_grumpkin_add(&mut builder, &identity, &point).expect("O + P"); + let right = complete_grumpkin_add(&mut builder, &point, &identity).expect("P + O"); + let both = complete_grumpkin_add(&mut builder, &identity, &identity).expect("O + O"); + + left.assert_equal(&mut builder, &point); + right.assert_equal(&mut builder, &point); + both.assert_equal(&mut builder, &identity); + assert!(builder_accepts(builder)); + } + + #[test] + fn complete_grumpkin_add_accepts_ordinary_sum() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let expected = GrumpkinPointWithIdentityVar::constant(&(p + q)); + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let q = GrumpkinPointWithIdentityVar::alloc(&mut builder, &q); + + let sum = complete_grumpkin_add(&mut builder, &p, &q).expect("ordinary sum"); + + sum.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn complete_grumpkin_add_accepts_doubling() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let expected = GrumpkinPointWithIdentityVar::constant(&p.double()); + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + + let doubled = complete_grumpkin_add(&mut builder, &p, &p).expect("doubling"); + + doubled.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn complete_grumpkin_add_accepts_inverse_to_identity() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let neg_p = -p; + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let neg_p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &neg_p); + + let sum = complete_grumpkin_add(&mut builder, &p, &neg_p).expect("inverse sum"); + + sum.assert_equal(&mut builder, &GrumpkinPointWithIdentityVar::identity()); + assert!(builder_accepts(builder)); + } + + #[test] + fn complete_grumpkin_add_trait_accepts_valid_sum() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let expected = GrumpkinPointWithIdentityVar::constant(&(p + q)); + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let q = GrumpkinPointWithIdentityVar::alloc(&mut builder, &q); + + let sum = assert_generic_complete_add::(&mut builder, &p, &q) + .expect("complete add"); + + sum.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn complete_grumpkin_add_rejects_tampered_ordinary_output() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let q = Grumpkin::generator().scalar_mul(&Fq::from_u64(7)); + let expected = p + q; + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let q = GrumpkinPointWithIdentityVar::alloc(&mut builder, &q); + let output = GrumpkinPointWithIdentityVar::alloc(&mut builder, &expected); + let targets = [ + ("complete ordinary output x", variable(&output.x)), + ("complete ordinary output y", variable(&output.y)), + ]; + + assert_complete_grumpkin_add(&mut builder, &p, &q, &output).expect("complete add"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn complete_grumpkin_add_rejects_tampered_inverse_output() { + let mut builder = R1csBuilder::::new(); + let p = Grumpkin::generator().scalar_mul(&Fq::from_u64(5)); + let neg_p = -p; + let p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &p); + let neg_p = GrumpkinPointWithIdentityVar::alloc(&mut builder, &neg_p); + let output = GrumpkinPointWithIdentityVar::alloc(&mut builder, &GrumpkinPoint::identity()); + let targets = [( + "complete inverse output identity flag", + variable(&output.is_identity), + )]; + + assert_complete_grumpkin_add(&mut builder, &p, &neg_p, &output).expect("complete add"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn fixed_base_grumpkin_scalar_mul_accepts_zero_scalar() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator(); + let scalar = FqVar::alloc(&mut builder, Fq::zero()); + + let result = + fixed_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + + result.assert_equal(&mut builder, &GrumpkinPointWithIdentityVar::identity()); + assert!(builder_accepts(builder)); + } + + #[test] + fn fixed_base_grumpkin_scalar_mul_accepts_nonzero_scalar() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator(); + let scalar_value = Fq::from_u64(13); + let scalar = FqVar::alloc(&mut builder, scalar_value); + + let result = + fixed_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + let expected = GrumpkinPointWithIdentityVar::constant(&base.scalar_mul(&scalar_value)); + + result.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn fixed_base_scalar_mul_trait_accepts_nonzero_scalar() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator(); + let scalar_value = Fq::from_u64(19); + let scalar = FqVar::alloc(&mut builder, scalar_value); + + let result = assert_generic_fixed_base_scalar_mul::( + &mut builder, + &base, + &scalar, + ) + .expect("scalar mul"); + let expected = GrumpkinPointWithIdentityVar::constant(&base.scalar_mul(&scalar_value)); + + result.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn fixed_base_grumpkin_scalar_mul_rejects_tampered_output() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator(); + let scalar = FqVar::alloc(&mut builder, Fq::from_u64(13)); + + let result = + fixed_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + let targets = [("scalar mul output x", variable(&result.x))]; + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn variable_base_grumpkin_scalar_mul_accepts_zero_scalar() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator().scalar_mul(&Fq::from_u64(11)); + let base = GrumpkinPointWithIdentityVar::alloc(&mut builder, &base); + let scalar = FqVar::alloc(&mut builder, Fq::zero()); + + let result = + variable_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + + result.assert_equal(&mut builder, &GrumpkinPointWithIdentityVar::identity()); + assert!(builder_accepts(builder)); + } + + #[test] + fn variable_base_grumpkin_scalar_mul_accepts_nonzero_scalar() { + let mut builder = R1csBuilder::::new(); + let base_value = Grumpkin::generator().scalar_mul(&Fq::from_u64(11)); + let scalar_value = Fq::from_u64(13); + let base = GrumpkinPointWithIdentityVar::alloc(&mut builder, &base_value); + let scalar = FqVar::alloc(&mut builder, scalar_value); + + let result = + variable_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + let expected = + GrumpkinPointWithIdentityVar::constant(&base_value.scalar_mul(&scalar_value)); + + result.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn variable_base_scalar_mul_trait_accepts_nonzero_scalar() { + let mut builder = R1csBuilder::::new(); + let base_value = Grumpkin::generator().scalar_mul(&Fq::from_u64(11)); + let scalar_value = Fq::from_u64(19); + let base = GrumpkinPointWithIdentityVar::alloc(&mut builder, &base_value); + let scalar = FqVar::alloc(&mut builder, scalar_value); + + let result = assert_generic_variable_base_scalar_mul::( + &mut builder, + &base, + &scalar, + ) + .expect("scalar mul"); + let expected = + GrumpkinPointWithIdentityVar::constant(&base_value.scalar_mul(&scalar_value)); + + result.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn variable_base_grumpkin_scalar_mul_rejects_tampered_base() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator().scalar_mul(&Fq::from_u64(11)); + let base = GrumpkinPointWithIdentityVar::alloc(&mut builder, &base); + let scalar = FqVar::alloc(&mut builder, Fq::from_u64(13)); + let result = + variable_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + result.assert_valid(&mut builder); + let targets = [("variable-base scalar mul base x", variable(&base.x))]; + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn variable_base_grumpkin_scalar_mul_rejects_tampered_output() { + let mut builder = R1csBuilder::::new(); + let base = Grumpkin::generator().scalar_mul(&Fq::from_u64(11)); + let base = GrumpkinPointWithIdentityVar::alloc(&mut builder, &base); + let scalar = FqVar::alloc(&mut builder, Fq::from_u64(13)); + + let result = + variable_base_grumpkin_scalar_mul(&mut builder, &base, &scalar).expect("scalar mul"); + let targets = [("variable-base scalar mul output x", variable(&result.x))]; + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn fixed_base_grumpkin_msm_accepts_nonzero_scalars() { + let mut builder = R1csBuilder::::new(); + let bases = vec![ + Grumpkin::generator().scalar_mul(&Fq::from_u64(11)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(17)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(23)), + ]; + let scalar_values = vec![Fq::from_u64(3), Fq::from_u64(5), Fq::from_u64(7)]; + let scalars = scalar_values + .iter() + .copied() + .map(|scalar| FqVar::alloc(&mut builder, scalar)) + .collect::>(); + + let result = + fixed_base_grumpkin_msm(&mut builder, &bases, &scalars).expect("fixed-base MSM"); + let expected = + GrumpkinPointWithIdentityVar::constant(&GrumpkinPoint::msm(&bases, &scalar_values)); + + result.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn fixed_base_grumpkin_msm_rejects_length_mismatch() { + let mut builder = R1csBuilder::::new(); + let bases = vec![Grumpkin::generator()]; + let scalars = vec![ + FqVar::alloc(&mut builder, Fq::from_u64(3)), + FqVar::alloc(&mut builder, Fq::from_u64(5)), + ]; + + assert_eq!( + fixed_base_grumpkin_msm(&mut builder, &bases, &scalars), + Err(CryptoR1csError::FixedBaseMsmLengthMismatch { + bases: 1, + scalars: 2, + }) + ); + } + + #[test] + fn linear_combine_grumpkin_commitments_accepts_valid_combination() { + let mut builder = R1csBuilder::::new(); + let commitment_values = [ + Grumpkin::generator().scalar_mul(&Fq::from_u64(11)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(17)), + ]; + let coefficient_values = [Fq::from_u64(3), Fq::from_u64(5)]; + let commitments = commitment_values + .iter() + .map(|commitment| GrumpkinPointWithIdentityVar::alloc(&mut builder, commitment)) + .collect::>(); + let coefficients = coefficient_values + .iter() + .copied() + .map(|coefficient| FqVar::alloc(&mut builder, coefficient)) + .collect::>(); + + let combined = + linear_combine_grumpkin_commitments(&mut builder, &commitments, &coefficients) + .expect("linear combination"); + let expected = GrumpkinPointWithIdentityVar::constant(&GrumpkinPoint::msm( + &commitment_values, + &coefficient_values, + )); + + combined.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn vector_commitment_r1cs_trait_combines_grumpkin_commitments() { + let mut builder = R1csBuilder::::new(); + let commitment_values = [ + Grumpkin::generator().scalar_mul(&Fq::from_u64(11)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(17)), + ]; + let coefficient_values = [Fq::from_u64(3), Fq::from_u64(5)]; + let commitments = commitment_values + .iter() + .map(|commitment| GrumpkinPointWithIdentityVar::alloc(&mut builder, commitment)) + .collect::>(); + let coefficients = coefficient_values + .iter() + .copied() + .map(|coefficient| FqVar::alloc(&mut builder, coefficient)) + .collect::>(); + + let combined = + as VectorCommitmentR1cs>::linear_combine_commitments( + &mut builder, + &commitments, + &coefficients, + ) + .expect("linear combination"); + let expected = GrumpkinPointWithIdentityVar::constant(&GrumpkinPoint::msm( + &commitment_values, + &coefficient_values, + )); + + combined.assert_equal(&mut builder, &expected); + assert!(builder_accepts(builder)); + } + + #[test] + fn linear_combine_grumpkin_commitments_rejects_tampered_coefficient() { + let mut builder = R1csBuilder::::new(); + let commitment_values = [ + Grumpkin::generator().scalar_mul(&Fq::from_u64(11)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(17)), + ]; + let coefficient_values = [Fq::from_u64(3), Fq::from_u64(5)]; + let commitments = commitment_values + .iter() + .map(|commitment| GrumpkinPointWithIdentityVar::alloc(&mut builder, commitment)) + .collect::>(); + let coefficients = coefficient_values + .iter() + .copied() + .map(|coefficient| FqVar::alloc(&mut builder, coefficient)) + .collect::>(); + + let combined = + linear_combine_grumpkin_commitments(&mut builder, &commitments, &coefficients) + .expect("linear combination"); + combined.assert_valid(&mut builder); + let targets = [( + "linear-combine coefficient limb", + variable(&coefficients[0].limbs()[0]), + )]; + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn linear_combine_grumpkin_commitments_rejects_length_mismatch() { + let mut builder = R1csBuilder::::new(); + let commitments = vec![GrumpkinPointWithIdentityVar::alloc( + &mut builder, + &Grumpkin::generator(), + )]; + let coefficients = vec![ + FqVar::alloc(&mut builder, Fq::from_u64(3)), + FqVar::alloc(&mut builder, Fq::from_u64(5)), + ]; + + assert_eq!( + linear_combine_grumpkin_commitments(&mut builder, &commitments, &coefficients), + Err(CryptoR1csError::FixedBaseMsmLengthMismatch { + bases: 1, + scalars: 2, + }) + ); + } + + #[test] + fn grumpkin_pedersen_opening_accepts_valid_opening() { + let mut builder = R1csBuilder::::new(); + let setup = grumpkin_pedersen_setup(); + let value_scalars = vec![Fq::from_u64(3), Fq::from_u64(5)]; + let blinding_scalar = Fq::from_u64(7); + let commitment_value = + Pedersen::::commit(&setup, &value_scalars, &blinding_scalar); + let commitment = GrumpkinPointWithIdentityVar::alloc(&mut builder, &commitment_value); + let opening = grumpkin_pedersen_opening_var(&mut builder, &value_scalars, blinding_scalar); + + verify_grumpkin_pedersen_opening(&mut builder, &setup, &commitment, &opening) + .expect("valid Pedersen opening"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn vector_commitment_r1cs_trait_verifies_grumpkin_pedersen_opening() { + let mut builder = R1csBuilder::::new(); + let setup = grumpkin_pedersen_setup(); + let value_scalars = vec![Fq::from_u64(3), Fq::from_u64(5)]; + let blinding_scalar = Fq::from_u64(7); + let commitment_value = + Pedersen::::commit(&setup, &value_scalars, &blinding_scalar); + let commitment = GrumpkinPointWithIdentityVar::alloc(&mut builder, &commitment_value); + let opening = grumpkin_pedersen_opening_var(&mut builder, &value_scalars, blinding_scalar); + + as VectorCommitmentR1cs>::verify_opening( + &mut builder, + &setup, + &commitment, + &opening, + ) + .expect("valid Pedersen opening"); + + assert_eq!( + as VectorCommitmentR1cs>::capacity(&setup), + setup.message_generators.len() + ); + assert!(builder_accepts(builder)); + } + + #[test] + fn grumpkin_pedersen_opening_rejects_tampered_value() { + let mut builder = R1csBuilder::::new(); + let setup = grumpkin_pedersen_setup(); + let value_scalars = vec![Fq::from_u64(3), Fq::from_u64(5)]; + let blinding_scalar = Fq::from_u64(7); + let commitment_value = + Pedersen::::commit(&setup, &value_scalars, &blinding_scalar); + let commitment = GrumpkinPointWithIdentityVar::alloc(&mut builder, &commitment_value); + let opening = grumpkin_pedersen_opening_var(&mut builder, &value_scalars, blinding_scalar); + let targets = [( + "Pedersen opening value limb", + variable(&opening.values[0].limbs()[0]), + )]; + + verify_grumpkin_pedersen_opening(&mut builder, &setup, &commitment, &opening) + .expect("valid Pedersen opening before tampering"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_pedersen_opening_rejects_tampered_blinding() { + let mut builder = R1csBuilder::::new(); + let setup = grumpkin_pedersen_setup(); + let value_scalars = vec![Fq::from_u64(3), Fq::from_u64(5)]; + let blinding_scalar = Fq::from_u64(7); + let commitment_value = + Pedersen::::commit(&setup, &value_scalars, &blinding_scalar); + let commitment = GrumpkinPointWithIdentityVar::alloc(&mut builder, &commitment_value); + let opening = grumpkin_pedersen_opening_var(&mut builder, &value_scalars, blinding_scalar); + let targets = [( + "Pedersen opening blinding limb", + variable(&opening.blinding.limbs()[0]), + )]; + + verify_grumpkin_pedersen_opening(&mut builder, &setup, &commitment, &opening) + .expect("valid Pedersen opening before tampering"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_pedersen_opening_rejects_tampered_commitment() { + let mut builder = R1csBuilder::::new(); + let setup = grumpkin_pedersen_setup(); + let value_scalars = vec![Fq::from_u64(3), Fq::from_u64(5)]; + let blinding_scalar = Fq::from_u64(7); + let commitment_value = + Pedersen::::commit(&setup, &value_scalars, &blinding_scalar); + let commitment = GrumpkinPointWithIdentityVar::alloc(&mut builder, &commitment_value); + let opening = grumpkin_pedersen_opening_var(&mut builder, &value_scalars, blinding_scalar); + let targets = [("Pedersen commitment x-coordinate", variable(&commitment.x))]; + + verify_grumpkin_pedersen_opening(&mut builder, &setup, &commitment, &opening) + .expect("valid Pedersen opening before tampering"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn grumpkin_pedersen_opening_rejects_capacity_exceeded() { + let mut builder = R1csBuilder::::new(); + let setup = PedersenSetup::new( + vec![Grumpkin::generator().scalar_mul(&Fq::from_u64(11))], + Grumpkin::generator().scalar_mul(&Fq::from_u64(23)), + ); + let opening = grumpkin_pedersen_opening_var( + &mut builder, + &[Fq::from_u64(3), Fq::from_u64(5)], + Fq::from_u64(7), + ); + + assert_eq!( + grumpkin_pedersen_opening_commitment(&mut builder, &setup, &opening), + Err(CryptoR1csError::VectorCommitmentCapacityExceeded { + capacity: 1, + values: 2, + }) + ); + } + + #[test] + fn vector_commitment_opening_var_preserves_values_and_blinding() { + let mut builder = R1csBuilder::::new(); + let values = vec![ + jolt_r1cs::FqVar::alloc(&mut builder, Fq::from_u64(3)), + jolt_r1cs::FqVar::alloc(&mut builder, Fq::from_u64(5)), + ]; + let blinding = jolt_r1cs::FqVar::alloc(&mut builder, Fq::from_u64(7)); + + let opening = VectorCommitmentOpeningVar::new(values.clone(), blinding.clone()); + + assert_eq!(opening.values, values); + assert_eq!(opening.blinding, blinding); + assert!(builder_accepts(builder)); + } + + fn assert_group_var_valid( + builder: &mut R1csBuilder, + point: &G, + ) -> Result<(), CryptoR1csError> + where + G: GroupElementVar, + { + point.assert_valid(builder) + } + + fn assert_generic_nonexceptional_add( + builder: &mut R1csBuilder, + lhs: &G, + rhs: &G, + output: &G, + ) -> Result<(), CryptoR1csError> + where + G: NonExceptionalAddGroupVar, + { + G::assert_nonexceptional_add(builder, lhs, rhs, output) + } + + fn assert_generic_double( + builder: &mut R1csBuilder, + input: &G, + output: &G, + ) -> Result<(), CryptoR1csError> + where + G: DoubleGroupVar, + { + G::assert_double(builder, input, output) + } + + fn assert_generic_complete_add( + builder: &mut R1csBuilder, + lhs: &G, + rhs: &G, + ) -> Result + where + G: CompleteAddGroupVar, + { + G::complete_add(builder, lhs, rhs) + } + + fn assert_generic_variable_base_scalar_mul( + builder: &mut R1csBuilder, + base: &G, + scalar: &FqVar, + ) -> Result + where + G: VariableBaseScalarMulGroupVar< + BuilderField = Fr, + ScalarVar = FqVar, + Error = CryptoR1csError, + >, + { + G::variable_base_scalar_mul(builder, base, scalar) + } + + fn assert_generic_fixed_base_scalar_mul( + builder: &mut R1csBuilder, + base: &GrumpkinPoint, + scalar: &FqVar, + ) -> Result + where + G: FixedBaseScalarMulGroupVar< + BuilderField = Fr, + ScalarVar = FqVar, + Constant = GrumpkinPoint, + Error = CryptoR1csError, + >, + { + G::fixed_base_scalar_mul(builder, base, scalar) + } + + fn grumpkin_pedersen_setup() -> PedersenSetup { + PedersenSetup::new( + vec![ + Grumpkin::generator().scalar_mul(&Fq::from_u64(11)), + Grumpkin::generator().scalar_mul(&Fq::from_u64(17)), + ], + Grumpkin::generator().scalar_mul(&Fq::from_u64(23)), + ) + } + + fn grumpkin_pedersen_opening_var( + builder: &mut R1csBuilder, + values: &[Fq], + blinding: Fq, + ) -> VectorCommitmentOpeningVar { + VectorCommitmentOpeningVar::new( + values + .iter() + .copied() + .map(|value| FqVar::alloc(builder, value)) + .collect(), + FqVar::alloc(builder, blinding), + ) + } + + fn builder_accepts(builder: R1csBuilder) -> bool { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn builder_rejects(builder: R1csBuilder) -> bool { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_err() + } + + fn assert_tampering_rejected( + builder: R1csBuilder, + targets: impl IntoIterator, + ) { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for (label, variable) in targets { + let mut tampered = witness.clone(); + tampered[variable.index()] += Fr::from_u64(1); + assert!( + matrices.check_witness(&tampered).is_err(), + "{label} accepted after tampering variable {}", + variable.index() + ); + } + } + + fn variable(scalar: &AssignedScalar) -> Variable { + scalar + .lc + .terms + .first() + .copied() + .expect("expected scalar backed by one variable") + .0 + } +} diff --git a/crates/jolt-crypto/tests/pairing.rs b/crates/jolt-crypto/tests/pairing.rs index a50c75dd4d..745e820963 100644 --- a/crates/jolt-crypto/tests/pairing.rs +++ b/crates/jolt-crypto/tests/pairing.rs @@ -1,7 +1,7 @@ //! Pairing bilinearity and consistency tests for BN254. -use jolt_crypto::{Bn254, Bn254G2, Bn254GT, JoltGroup, PairingGroup}; -use jolt_field::{Fr, FromPrimitiveInt, RandomSampling}; +use jolt_crypto::{Bn254, Bn254Fq12, Bn254G1, Bn254G2, Bn254GT, JoltGroup, PairingGroup}; +use jolt_field::{Fq, Fr, FromPrimitiveInt, RandomSampling}; use rand_chacha::ChaCha20Rng; use rand_core::SeedableRng; @@ -195,3 +195,87 @@ fn gt_msm_matches_naive() { fn gt_default_is_identity() { assert_eq!(Bn254GT::default(), Bn254GT::identity()); } + +#[test] +fn gt_fq12_coefficients_expose_identity_layout() { + let coeffs = Bn254GT::identity().fq12_coefficients(); + + assert_eq!(coeffs[0], Fq::from_u64(1)); + assert!(coeffs[1..] + .iter() + .all(|coefficient| *coefficient == Fq::default())); +} + +#[test] +fn gt_fq12_coefficients_are_deterministic() { + let value = Bn254::pairing(&Bn254::g1_generator(), &Bn254::g2_generator()); + + assert_eq!(value.fq12_coefficients(), value.fq12_coefficients()); +} + +#[test] +fn g1_affine_coordinates_roundtrip() { + let value = Bn254::g1_generator().scalar_mul(&Fr::from_u64(42)); + let coordinates = value.affine_coordinates_with_infinity(); + + assert_eq!( + Bn254G1::from_affine_coordinates_with_infinity(coordinates), + Some(value) + ); +} + +#[test] +fn g2_affine_coordinates_roundtrip() { + let value = Bn254::g2_generator().scalar_mul(&Fr::from_u64(42)); + let coordinates = value.affine_coordinates_with_infinity(); + + assert_eq!( + Bn254G2::from_affine_coordinates_with_infinity(coordinates), + Some(value) + ); +} + +#[test] +fn gt_fq12_coefficients_roundtrip() { + let value = Bn254::pairing(&Bn254::g1_generator(), &Bn254::g2_generator()); + let coefficients = value.fq12_coefficients(); + + assert_eq!(Bn254GT::from_fq12_coefficients(coefficients), Some(value)); +} + +#[test] +fn fq12_from_gt_preserves_coefficients() { + let value = Bn254::pairing(&Bn254::g1_generator(), &Bn254::g2_generator()); + let raw = Bn254Fq12::from(value); + + assert_eq!(raw.coefficients(), value.fq12_coefficients()); +} + +#[test] +fn fq12_final_exponentiation_accepts_identity() { + assert_eq!( + Bn254Fq12::default().final_exponentiation(), + Some(Bn254GT::identity()) + ); +} + +#[test] +fn raw_multi_miller_loop_final_exponentiation_matches_pairing() { + let g1 = Bn254::g1_generator(); + let g2 = Bn254::g2_generator(); + let raw = Bn254::multi_miller_loop(&[g1], &[g2]); + + assert_eq!(raw.final_exponentiation(), Some(Bn254::pairing(&g1, &g2))); +} + +#[test] +fn raw_multi_miller_loop_batches_pairings() { + let g1 = Bn254::g1_generator(); + let g2 = Bn254::g2_generator(); + let two = Fr::from_u64(2); + let g1_twice = g1.scalar_mul(&two); + let raw = Bn254::multi_miller_loop(&[g1, g1_twice], &[g2, g2]); + let expected = Bn254::multi_pairing(&[g1, g1_twice], &[g2, g2]); + + assert_eq!(raw.final_exponentiation(), Some(expected)); +} diff --git a/crates/jolt-dory/src/scheme.rs b/crates/jolt-dory/src/scheme.rs index 63526d87a8..93d690dab0 100644 --- a/crates/jolt-dory/src/scheme.rs +++ b/crates/jolt-dory/src/scheme.rs @@ -20,6 +20,7 @@ use jolt_openings::{AdditivelyHomomorphic, CommitmentScheme, OpeningsError, ZkOp use jolt_poly::MultilinearPoly; use jolt_transcript::{AppendToTranscript, Label, LabelWithCount, Transcript}; use rayon::prelude::*; +use serde::{Deserialize, Serialize}; use crate::transcript::JoltToDoryTranscript; use crate::types::{DoryCommitment, DoryHint, DoryProof, DoryProverSetup, DoryVerifierSetup}; @@ -78,7 +79,7 @@ pub(crate) fn ark_to_jolt_g1(ark: ArkG1) -> Bn254G1 { unsafe { std::mem::transmute(ark) } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct DoryScheme; impl DoryScheme { diff --git a/crates/jolt-dory/src/streaming.rs b/crates/jolt-dory/src/streaming.rs index a26bd68ea0..7d1f7936e3 100644 --- a/crates/jolt-dory/src/streaming.rs +++ b/crates/jolt-dory/src/streaming.rs @@ -2,7 +2,8 @@ use dory::backends::arkworks::G1Routines; use dory::primitives::arithmetic::DoryRoutines; -use jolt_field::Fr; +use jolt_crypto::{Bn254G1, JoltGroup}; +use jolt_field::{Fr, FromPrimitiveInt}; use jolt_openings::StreamingCommitment; use crate::scheme::{ @@ -32,6 +33,8 @@ impl crate::DoryScheme { impl StreamingCommitment for crate::DoryScheme { type PartialCommitment = DoryPartialCommitment; + type OneHotChunkCommitment = Vec; + type OneHotStreamContext = usize; fn begin(_setup: &Self::ProverSetup) -> Self::PartialCommitment { DoryPartialCommitment { @@ -62,6 +65,44 @@ impl StreamingCommitment for crate::DoryScheme { partial.row_commitments.push(ark_to_jolt_g1(row_commitment)); } + fn begin_one_hot_column_major_stream( + _setup: &Self::ProverSetup, + _row_width: usize, + ) -> Self::OneHotStreamContext { + 0 + } + + #[tracing::instrument(skip_all, name = "DoryScheme::stream_one_hot_chunk")] + fn process_one_hot_chunk( + context: &mut Self::OneHotStreamContext, + setup: &Self::ProverSetup, + one_hot_k: usize, + chunk: &[Option], + ) -> Self::OneHotChunkCommitment { + assert!( + context.saturating_add(chunk.len()) <= setup.0.g1_vec.len(), + "one-hot chunk exceeds Dory SRS size" + ); + + let g1_bases = &setup.0.g1_vec[*context..*context + chunk.len()]; + let mut row_commitments = vec![Bn254G1::identity(); one_hot_k]; + for (row, row_commitment) in row_commitments.iter_mut().enumerate() { + let scalars: Vec = chunk + .iter() + .map(|hot_row| { + if *hot_row == Some(row) { + jolt_fr_to_ark(&Fr::from_u64(1)) + } else { + jolt_fr_to_ark(&Fr::from_u64(0)) + } + }) + .collect(); + *row_commitment = ark_to_jolt_g1(G1Routines::msm(g1_bases, &scalars)); + } + *context += chunk.len(); + row_commitments + } + /// Aggregates row commitments into the final tier-2 commitment, matching /// [`DoryScheme::commit`](crate::DoryScheme::commit). Asserts that the /// streamed row count is a power of two (the layout `DoryScheme::commit` @@ -75,6 +116,36 @@ impl StreamingCommitment for crate::DoryScheme { let (tier_2, _) = commit_rows_tier_2::(&ark_rows, setup); DoryCommitment(ark_to_jolt_gt(&tier_2)) } + + #[tracing::instrument(skip_all, name = "DoryScheme::stream_one_hot_finish")] + fn finish_one_hot_column_major_chunks( + setup: &Self::ProverSetup, + one_hot_k: usize, + chunks: &[Self::OneHotChunkCommitment], + ) -> (Self::Output, Self::OpeningHint) { + assert!( + !chunks.is_empty(), + "one-hot stream must contain at least one chunk" + ); + let mut row_commitments = vec![Bn254G1::identity(); one_hot_k]; + for chunk in chunks { + assert_eq!( + chunk.len(), + one_hot_k, + "one-hot chunk row count must match one_hot_k" + ); + for (row_commitment, chunk_commitment) in row_commitments.iter_mut().zip(chunk) { + *row_commitment += *chunk_commitment; + } + } + validate_row_count(row_commitments.len(), setup); + let ark_rows = jolt_g1_vec_to_ark(row_commitments); + let (tier_2, commit_blind) = commit_rows_tier_2::(&ark_rows, setup); + ( + DoryCommitment(ark_to_jolt_gt(&tier_2)), + DoryHint::new(ark_to_jolt_g1_vec(ark_rows), ark_to_jolt_fr(&commit_blind)), + ) + } } fn validate_row_count(num_rows: usize, setup: &DoryProverSetup) { diff --git a/crates/jolt-dory/src/types.rs b/crates/jolt-dory/src/types.rs index bdce8d8fc5..b82b1ac778 100644 --- a/crates/jolt-dory/src/types.rs +++ b/crates/jolt-dory/src/types.rs @@ -52,14 +52,14 @@ impl AppendToTranscript for DoryCommitment { } } -impl HomomorphicCommitment for DoryCommitment { +impl HomomorphicCommitment for DoryCommitment { #[inline] fn add(c1: &Self, c2: &Self) -> Self { - Self(>::add(&c1.0, &c2.0)) + Self(>::add(&c1.0, &c2.0)) } #[inline] - fn linear_combine(c1: &Self, c2: &Self, scalar: &F) -> Self { + fn linear_combine(c1: &Self, c2: &Self, scalar: &Fr) -> Self { Self(HomomorphicCommitment::linear_combine(&c1.0, &c2.0, scalar)) } } diff --git a/crates/jolt-field/Cargo.toml b/crates/jolt-field/Cargo.toml index fdd2ed8c3c..4cff4757a3 100644 --- a/crates/jolt-field/Cargo.toml +++ b/crates/jolt-field/Cargo.toml @@ -14,7 +14,7 @@ workspace = true [dependencies] ark-ff = { workspace = true, optional = true } ark-serialize = { workspace = true, optional = true } -ark-bn254 = { workspace = true, features = ["scalar_field"], optional = true } +ark-bn254 = { workspace = true, features = ["curve"], optional = true } num-traits = { workspace = true } serde = { workspace = true, features = ["derive"] } allocative = { workspace = true, optional = true } diff --git a/crates/jolt-field/src/arkworks/bn254.rs b/crates/jolt-field/src/arkworks/bn254.rs index d31096ca7f..62088d6baf 100644 --- a/crates/jolt-field/src/arkworks/bn254.rs +++ b/crates/jolt-field/src/arkworks/bn254.rs @@ -6,6 +6,7 @@ use crate::{ AdditiveGroup, CanonicalBitLength, CanonicalBytes, CanonicalU64, Field, FieldCore, FixedByteSize, FixedBytes, FromPrimitiveInt, Invertible, Limbs, MulPrimitiveInt, RandomSampling, ReducingBytes, RingCore, TranscriptChallenge, WithAccumulator, + WithSignedProductAccumulator, WithSmallScalarAccumulator, }; use ark_ff::{prelude::*, PrimeField, UniformRand}; use rand_core::RngCore; @@ -363,7 +364,6 @@ impl TranscriptChallenge for Fr { buf[..len].copy_from_slice(&bytes[..len]); let value = u128::from_le_bytes(buf); let low = value as u64; - // Top 3 bits of high limb are zeroed to ensure value < BN254 modulus. let high = ((value >> 64) as u64) & (u64::MAX >> 3); let Some(inner) = InnerFr::from_bigint_unchecked(ark_ff::BigInt::new([0, 0, low, high])) else { @@ -447,6 +447,14 @@ impl WithAccumulator for Fr { type Accumulator = super::wide_accumulator::WideAccumulator; } +impl WithSmallScalarAccumulator for Fr { + type SmallScalarAccumulator = super::small_scalar_accumulator::FrSmallScalarAccumulator; +} + +impl WithSignedProductAccumulator for Fr { + type SignedProductAccumulator = super::signed_product_accumulator::FrSignedProductAccumulator; +} + impl crate::MulPow2 for Fr {} impl MulPrimitiveInt for Fr { diff --git a/crates/jolt-field/src/arkworks/bn254_fq.rs b/crates/jolt-field/src/arkworks/bn254_fq.rs new file mode 100644 index 0000000000..73d620f695 --- /dev/null +++ b/crates/jolt-field/src/arkworks/bn254_fq.rs @@ -0,0 +1,479 @@ +//! Newtype wrapper around `ark_bn254::Fq` that decouples the public API from arkworks. +//! +//! [`Fq`] is the BN254 base field. In the BN254/Grumpkin cycle it is also the +//! scalar field of Grumpkin. + +use crate::{ + AdditiveGroup, CanonicalBitLength, CanonicalBytes, CanonicalU64, Field, FieldCore, + FixedByteSize, FixedBytes, FromPrimitiveInt, Invertible, Limbs, MulPrimitiveInt, + NaiveAccumulator, NaiveSignedProductAccumulator, NaiveSignedScalarAccumulator, RandomSampling, + ReducingBytes, RingCore, TranscriptChallenge, WithAccumulator, WithSignedProductAccumulator, + WithSmallScalarAccumulator, +}; +use ark_ff::{prelude::*, PrimeField, UniformRand}; +use rand_core::RngCore; + +type InnerFq = ark_bn254::Fq; + +/// BN254 base field element. +/// +/// A `#[repr(transparent)]` newtype over `ark_bn254::Fq`. +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct Fq(pub(crate) InnerFq); + +impl From for Fq { + #[inline(always)] + fn from(v: bool) -> Self { + ::from_bool(v) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: u8) -> Self { + ::from_u64(v as u64) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: u16) -> Self { + ::from_u64(v as u64) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: u32) -> Self { + ::from_u64(v as u64) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: u64) -> Self { + ::from_u64(v) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: i64) -> Self { + ::from_i64(v) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: i128) -> Self { + ::from_i128(v) + } +} + +impl From for Fq { + #[inline(always)] + fn from(v: u128) -> Self { + ::from_u128(v) + } +} + +impl std::fmt::Debug for Fq { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&self.0, f) + } +} + +impl std::fmt::Display for Fq { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } +} + +macro_rules! delegate_binop { + ($Trait:ident, $method:ident) => { + impl std::ops::$Trait for Fq { + type Output = Fq; + #[inline(always)] + fn $method(self, rhs: Fq) -> Fq { + Fq(std::ops::$Trait::$method(self.0, rhs.0)) + } + } + + impl std::ops::$Trait<&Fq> for Fq { + type Output = Fq; + #[inline(always)] + fn $method(self, rhs: &Fq) -> Fq { + Fq(std::ops::$Trait::$method(self.0, rhs.0)) + } + } + + impl std::ops::$Trait for &Fq { + type Output = Fq; + #[inline(always)] + fn $method(self, rhs: Fq) -> Fq { + Fq(std::ops::$Trait::$method(self.0, rhs.0)) + } + } + + impl<'a, 'b> std::ops::$Trait<&'b Fq> for &'a Fq { + type Output = Fq; + #[inline(always)] + fn $method(self, rhs: &'b Fq) -> Fq { + Fq(std::ops::$Trait::$method(self.0, rhs.0)) + } + } + }; +} + +delegate_binop!(Add, add); +delegate_binop!(Sub, sub); +delegate_binop!(Mul, mul); +delegate_binop!(Div, div); + +impl std::ops::Neg for Fq { + type Output = Fq; + + #[inline(always)] + fn neg(self) -> Fq { + Fq(self.0.neg()) + } +} + +impl std::ops::AddAssign for Fq { + #[inline(always)] + fn add_assign(&mut self, rhs: Fq) { + self.0.add_assign(rhs.0); + } +} + +impl std::ops::SubAssign for Fq { + #[inline(always)] + fn sub_assign(&mut self, rhs: Fq) { + self.0.sub_assign(rhs.0); + } +} + +impl std::ops::MulAssign for Fq { + #[inline(always)] + fn mul_assign(&mut self, rhs: Fq) { + self.0.mul_assign(rhs.0); + } +} + +impl std::iter::Sum for Fq { + fn sum>(iter: I) -> Self { + Fq(iter.map(|f| f.0).sum()) + } +} + +impl<'a> std::iter::Sum<&'a Fq> for Fq { + fn sum>(iter: I) -> Self { + Fq(iter.map(|f| f.0).sum()) + } +} + +impl std::iter::Product for Fq { + fn product>(iter: I) -> Self { + Fq(iter.map(|f| f.0).product()) + } +} + +impl<'a> std::iter::Product<&'a Fq> for Fq { + fn product>(iter: I) -> Self { + Fq(iter.map(|f| f.0).product()) + } +} + +impl num_traits::Zero for Fq { + #[inline(always)] + fn zero() -> Self { + Fq(InnerFq::zero()) + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl num_traits::One for Fq { + #[inline(always)] + fn one() -> Self { + Fq(InnerFq::one()) + } + + #[inline(always)] + fn is_one(&self) -> bool { + self.0.is_one() + } +} + +impl serde::Serialize for Fq { + fn serialize(&self, serializer: S) -> Result { + use ark_serialize::CanonicalSerialize; + let mut buf = [0u8; 32]; + self.0 + .serialize_compressed(&mut buf[..]) + .map_err(serde::ser::Error::custom)?; + <[u8; 32]>::serialize(&buf, serializer) + } +} + +impl<'de> serde::Deserialize<'de> for Fq { + fn deserialize>(deserializer: D) -> Result { + use ark_serialize::CanonicalDeserialize; + let buf = <[u8; 32]>::deserialize(deserializer)?; + let inner = InnerFq::deserialize_compressed(&buf[..]).map_err(serde::de::Error::custom)?; + Ok(Fq(inner)) + } +} + +impl ark_serialize::CanonicalSerialize for Fq { + fn serialize_with_mode( + &self, + writer: W, + compress: ark_serialize::Compress, + ) -> Result<(), ark_serialize::SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: ark_serialize::Compress) -> usize { + self.0.serialized_size(compress) + } +} + +impl ark_serialize::Valid for Fq { + fn check(&self) -> Result<(), ark_serialize::SerializationError> { + self.0.check() + } +} + +impl ark_serialize::CanonicalDeserialize for Fq { + fn deserialize_with_mode( + reader: R, + compress: ark_serialize::Compress, + validate: ark_serialize::Validate, + ) -> Result { + InnerFq::deserialize_with_mode(reader, compress, validate).map(Fq) + } +} + +impl UniformRand for Fq { + fn rand(rng: &mut R) -> Self { + Fq(::rand(rng)) + } +} + +#[cfg(feature = "allocative")] +impl allocative::Allocative for Fq { + fn visit<'a, 'b: 'a>(&self, visitor: &'a mut allocative::Visitor<'b>) { + visitor.visit_simple_sized::(); + } +} + +impl Fq { + /// Deserializes from little-endian bytes, reducing modulo the field prime. + #[inline] + pub fn from_le_bytes_mod_order(bytes: &[u8]) -> Self { + Fq(InnerFq::from_le_bytes_mod_order(bytes)) + } + + /// Converts a limb array to a field element without checking that it is + /// less than the modulus. + #[inline] + pub fn from_bigint_unchecked(limbs: Limbs<4>) -> Self { + let Some(inner) = InnerFq::from_bigint(ark_ff::BigInt::new(limbs.0)) else { + unreachable!("unchecked BN254 Fq construction received non-canonical limbs") + }; + Fq(inner) + } + + /// Access the internal Montgomery-form limbs. + #[inline(always)] + pub fn inner_limbs(self) -> Limbs<4> { + Limbs((self.0).0 .0) + } +} + +impl AdditiveGroup for Fq {} + +impl RingCore for Fq { + #[inline] + fn square(&self) -> Self { + Fq(::square(&self.0)) + } +} + +impl Invertible for Fq { + #[inline] + fn inverse(&self) -> Option { + ::inverse(&self.0).map(Fq) + } +} + +impl FieldCore for Fq {} + +impl FixedByteSize for Fq { + const NUM_BYTES: usize = 32; +} + +impl CanonicalBytes for Fq { + #[expect(clippy::expect_used)] + #[inline] + fn to_bytes_le(&self, out: &mut [u8]) { + assert_eq!(out.len(), ::NUM_BYTES); + use ark_serialize::CanonicalSerialize; + self.0 + .serialize_compressed(out) + .expect("BN254 Fq always serializes to 32 bytes"); + } +} + +impl ReducingBytes for Fq { + #[inline] + fn from_le_bytes_mod_order(bytes: &[u8]) -> Self { + Fq::from_le_bytes_mod_order(bytes) + } +} + +impl TranscriptChallenge for Fq { + #[inline] + fn from_challenge_bytes(bytes: &[u8]) -> Self { + let mut buf = [0u8; 16]; + let len = bytes.len().min(buf.len()); + buf[..len].copy_from_slice(&bytes[..len]); + let value = u128::from_le_bytes(buf); + let low = value as u64; + let high = ((value >> 64) as u64) & (u64::MAX >> 3); + let Some(inner) = InnerFq::from_bigint(ark_ff::BigInt::new([0, 0, low, high])) else { + unreachable!("masked 125-bit shifted challenge fits in BN254 Fq") + }; + Fq(inner) + } + + #[inline] + fn from_scalar_challenge_bytes(bytes: &[u8]) -> Self { + let mut buf = bytes.to_vec(); + buf.reverse(); + Fq::from_le_bytes_mod_order(&buf) + } +} + +impl FixedBytes<32> for Fq {} + +impl CanonicalU64 for Fq { + #[inline] + fn to_canonical_u64_checked(&self) -> Option { + let bigint = ::into_bigint(self.0); + let limbs: &[u64] = bigint.as_ref(); + let result = limbs[0]; + + if ::from_u64(result) != *self { + None + } else { + Some(result) + } + } +} + +impl CanonicalBitLength for Fq { + #[inline] + fn num_bits(&self) -> u32 { + ::into_bigint(self.0).num_bits() + } +} + +impl RandomSampling for Fq { + #[inline] + fn random(rng: &mut R) -> Self { + Fq(::rand(rng)) + } +} + +impl FromPrimitiveInt for Fq { + #[inline] + fn from_u64(n: u64) -> Self { + Fq(InnerFq::from(n)) + } + + #[inline] + fn from_i64(val: i64) -> Self { + if val.is_negative() { + -Fq(InnerFq::from(val.unsigned_abs())) + } else { + Fq(InnerFq::from(val as u64)) + } + } + + #[inline] + fn from_i128(val: i128) -> Self { + if val.is_negative() { + -Fq(InnerFq::from(val.unsigned_abs())) + } else { + Fq(InnerFq::from(val as u128)) + } + } + + #[inline] + fn from_u128(val: u128) -> Self { + Fq(InnerFq::from(val)) + } +} + +impl WithAccumulator for Fq { + type Accumulator = NaiveAccumulator; +} + +impl WithSmallScalarAccumulator for Fq { + type SmallScalarAccumulator = NaiveSignedScalarAccumulator; +} + +impl WithSignedProductAccumulator for Fq { + type SignedProductAccumulator = NaiveSignedProductAccumulator; +} + +impl crate::MulPow2 for Fq {} + +impl MulPrimitiveInt for Fq {} + +impl Field for Fq {} + +#[cfg(test)] +#[expect(clippy::unwrap_used)] +mod tests { + use super::*; + use crate::{CanonicalU64, FixedBytes}; + + #[test] + fn field_arithmetic_basic() { + let a = Fq::from_u64(7); + let b = Fq::from_u64(6); + assert_eq!(a * b, Fq::from_u64(42)); + assert_eq!(a + b, Fq::from_u64(13)); + assert_eq!(b - a, Fq::from_i64(-1)); + } + + #[test] + fn serialization_roundtrip() { + let val = Fq::from_u64(123_456_789); + let bytes = val.to_bytes_array(); + let recovered = Fq::from_bytes_array(&bytes); + assert_eq!(val, recovered); + } + + #[test] + fn inverse_and_square() { + let a = Fq::from_u64(42); + let inv = a.inverse().unwrap(); + assert_eq!(a * inv, Fq::one()); + assert!(Fq::zero().inverse().is_none()); + assert_eq!(a.square(), a * a); + } + + #[test] + fn to_u64_roundtrip() { + let value = Fq::from_u64(12345); + assert_eq!(value.to_canonical_u64_checked(), Some(12345)); + } +} diff --git a/crates/jolt-field/src/arkworks/bn254_ops.rs b/crates/jolt-field/src/arkworks/bn254_ops.rs index 3a1d061791..a7a00e24d1 100644 --- a/crates/jolt-field/src/arkworks/bn254_ops.rs +++ b/crates/jolt-field/src/arkworks/bn254_ops.rs @@ -2,6 +2,7 @@ //! //! Low-level field arithmetic (Montgomery/Barrett reduction, scalar multiplication, //! precomputed lookup tables). +use crate::Limbs; use ark_bn254::FrConfig; use ark_ff::{BigInt, Fp, MontConfig}; use num_traits::Zero; @@ -331,6 +332,14 @@ fn bigint4_mul_u64(a: &BigInt, b: u64) -> BigInt<5> { res } +#[inline(always)] +pub(crate) fn mul_u64_unreduced(a: Fr, b: u64) -> Limbs<5> { + if b == 0 || Zero::is_zero(&a) { + return Limbs::zero(); + } + bigint4_mul_u64(&a.0, b).into() +} + /// Multiply BigInt<4> by u128, producing BigInt<6>. #[inline(always)] fn bigint4_mul_u128(a: &BigInt, b: u128) -> BigInt<6> { @@ -366,6 +375,11 @@ fn from_unchecked_nplus1(element: BigInt<5>) -> Fr { Fp::new_unchecked(r) } +#[inline(always)] +pub(crate) fn reduce_nplus1(element: Limbs<5>) -> Fr { + from_unchecked_nplus1(element.into()) +} + /// Barrett reduce BigInt<6> → Fr via two rounds #[inline(always)] fn from_unchecked_nplus2(element: BigInt<6>) -> Fr { diff --git a/crates/jolt-field/src/arkworks/mod.rs b/crates/jolt-field/src/arkworks/mod.rs index 841398bb79..7ff1700012 100644 --- a/crates/jolt-field/src/arkworks/mod.rs +++ b/crates/jolt-field/src/arkworks/mod.rs @@ -7,8 +7,11 @@ use crate::Limbs; use ark_ff::BigInt; pub mod bn254; +pub mod bn254_fq; pub(crate) mod bn254_ops; pub mod montgomery_impl; +pub mod signed_product_accumulator; +pub mod small_scalar_accumulator; pub mod wide_accumulator; impl From> for BigInt { diff --git a/crates/jolt-field/src/arkworks/signed_product_accumulator.rs b/crates/jolt-field/src/arkworks/signed_product_accumulator.rs new file mode 100644 index 0000000000..7e287ea029 --- /dev/null +++ b/crates/jolt-field/src/arkworks/signed_product_accumulator.rs @@ -0,0 +1,121 @@ +use crate::{signed::S256, Limbs, SignedProductAccumulator}; +use ark_ff::{BigInt, MontConfig}; +use num_traits::Zero; + +use super::{bn254::Fr, bn254_ops}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FrSignedProductAccumulator { + pos: [u128; 8], + neg: [u128; 8], +} + +impl Default for FrSignedProductAccumulator { + #[inline] + fn default() -> Self { + Self { + pos: [0; 8], + neg: [0; 8], + } + } +} + +impl FrSignedProductAccumulator { + #[inline(always)] + fn fmadd_magnitude(slots: &mut [u128; 8], value: Fr, magnitude: Limbs<4>) { + let value = value.inner_limbs(); + for i in 0..4 { + for j in 0..4 { + let product = (value.0[i] as u128) * (magnitude.0[j] as u128); + slots[i + j] += (product as u64) as u128; + slots[i + j + 1] += ((product >> 64) as u64) as u128; + } + } + } + + #[inline] + fn normalize(slots: [u128; 8]) -> Limbs<9> { + let mut out = [0u64; 9]; + let mut carry = 0u128; + for (index, slot) in slots.into_iter().enumerate() { + let (sum, overflow) = slot.overflowing_add(carry); + out[index] = sum as u64; + carry = (sum >> 64) + ((overflow as u128) << 64); + } + out[8] = carry as u64; + Limbs(out) + } +} + +impl SignedProductAccumulator for FrSignedProductAccumulator { + type Element = Fr; + + #[inline(always)] + fn fmadd_s256(&mut self, value: Fr, scalar: &S256) { + if scalar.is_zero() { + return; + } + if scalar.is_positive { + Self::fmadd_magnitude(&mut self.pos, value, scalar.magnitude); + } else { + Self::fmadd_magnitude(&mut self.neg, value, scalar.magnitude); + } + } + + #[inline] + fn reduce(self) -> Fr { + let pos = Self::normalize(self.pos); + let neg = Self::normalize(self.neg); + let montgomery_r_value = + Fr::from_bigint_unchecked(Limbs(>::R2.0)); + let reduced = if pos >= neg { + Fr::from_inner(bn254_ops::from_montgomery_reduce(BigInt::from( + pos.sub_trunc::<9, 9>(&neg), + ))) + } else { + -Fr::from_inner(bn254_ops::from_montgomery_reduce(BigInt::from( + neg.sub_trunc::<9, 9>(&pos), + ))) + }; + reduced * montgomery_r_value + } +} + +#[cfg(test)] +mod tests { + use crate::FromPrimitiveInt; + + use super::*; + + fn s256_to_fr(value: &S256) -> Fr { + let mut bytes = [0u8; 32]; + for (index, limb) in value.magnitude_limbs().iter().copied().enumerate() { + bytes[index * 8..(index + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + let magnitude = Fr::from_le_bytes_mod_order(&bytes); + if value.is_positive { + magnitude + } else { + -magnitude + } + } + + #[test] + fn signed_product_accumulator_reduces_mixed_terms() { + let terms = [ + (Fr::from_u64(3), S256::from_i128(17)), + (Fr::from_u64(11), S256::from_i128(-9)), + (Fr::from_u64(42), S256::new([7, 5, 3, 1], true)), + (Fr::from_u64(6), S256::new([u64::MAX, 19, 0, 0], false)), + ]; + + let mut acc = FrSignedProductAccumulator::default(); + let mut expected = Fr::from_u64(0); + for (field, scalar) in terms { + acc.fmadd_s256(field, &scalar); + expected += field * s256_to_fr(&scalar); + } + + assert_eq!(acc.reduce(), expected); + } +} diff --git a/crates/jolt-field/src/arkworks/small_scalar_accumulator.rs b/crates/jolt-field/src/arkworks/small_scalar_accumulator.rs new file mode 100644 index 0000000000..3523395a4f --- /dev/null +++ b/crates/jolt-field/src/arkworks/small_scalar_accumulator.rs @@ -0,0 +1,120 @@ +use crate::{Limbs, SignedScalarAccumulator}; + +use super::{bn254::Fr, bn254_ops}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FrSmallScalarAccumulator { + pos: Limbs<5>, + neg: Limbs<5>, +} + +impl Default for FrSmallScalarAccumulator { + #[inline(always)] + fn default() -> Self { + Self { + pos: Limbs::zero(), + neg: Limbs::zero(), + } + } +} + +impl FrSmallScalarAccumulator { + #[inline(always)] + fn add_to_pos(&mut self, value: Fr) { + self.pos.add_assign_trunc::<4>(&value.inner_limbs()); + } + + #[inline(always)] + fn add_to_neg(&mut self, value: Fr) { + self.neg.add_assign_trunc::<4>(&value.inner_limbs()); + } + + #[inline(always)] + fn fmadd_magnitude_to_pos(&mut self, value: Fr, scalar: u64) { + if scalar == 0 { + return; + } + if scalar == 1 { + self.add_to_pos(value); + return; + } + self.pos + .add_assign_trunc::<5>(&bn254_ops::mul_u64_unreduced(value.0, scalar)); + } + + #[inline(always)] + fn fmadd_magnitude_to_neg(&mut self, value: Fr, scalar: u64) { + if scalar == 0 { + return; + } + if scalar == 1 { + self.add_to_neg(value); + return; + } + self.neg + .add_assign_trunc::<5>(&bn254_ops::mul_u64_unreduced(value.0, scalar)); + } +} + +impl SignedScalarAccumulator for FrSmallScalarAccumulator { + type Element = Fr; + + #[inline(always)] + fn add(&mut self, value: Fr) { + self.add_to_pos(value); + } + + #[inline(always)] + fn fmadd_u64(&mut self, value: Fr, scalar: u64) { + self.fmadd_magnitude_to_pos(value, scalar); + } + + #[inline(always)] + fn fmadd_i64(&mut self, value: Fr, scalar: i64) { + let magnitude = scalar.unsigned_abs(); + if scalar >= 0 { + self.fmadd_magnitude_to_pos(value, magnitude); + } else { + self.fmadd_magnitude_to_neg(value, magnitude); + } + } + + #[inline(always)] + fn reduce(self) -> Fr { + if self.pos >= self.neg { + Fr::from_inner(bn254_ops::reduce_nplus1( + self.pos.sub_trunc::<5, 5>(&self.neg), + )) + } else { + -Fr::from_inner(bn254_ops::reduce_nplus1( + self.neg.sub_trunc::<5, 5>(&self.pos), + )) + } + } +} + +#[cfg(test)] +mod tests { + use crate::FromPrimitiveInt; + + use super::*; + + #[test] + fn signed_small_scalar_accumulator_reduces_mixed_terms() { + let mut acc = FrSmallScalarAccumulator::default(); + acc.fmadd_u64(Fr::from_u64(3), 16); + acc.fmadd_i64(Fr::from_u64(5), -7); + acc.add(Fr::from_u64(11)); + + assert_eq!(acc.reduce(), Fr::from_u64(24)); + } + + #[test] + fn signed_small_scalar_accumulator_handles_negative_result() { + let mut acc = FrSmallScalarAccumulator::default(); + acc.fmadd_i64(Fr::from_u64(9), -13); + acc.fmadd_u64(Fr::from_u64(2), 7); + + assert_eq!(acc.reduce(), -Fr::from_u64(103)); + } +} diff --git a/crates/jolt-field/src/arkworks/wide_accumulator.rs b/crates/jolt-field/src/arkworks/wide_accumulator.rs index ab76c66ab4..f5c9fba8ad 100644 --- a/crates/jolt-field/src/arkworks/wide_accumulator.rs +++ b/crates/jolt-field/src/arkworks/wide_accumulator.rs @@ -1,37 +1,34 @@ //! Wide-integer accumulator for BN254 Fr deferred reduction. //! -//! Accumulates `sum += a * b` as 9-limb (576-bit) schoolbook products, -//! deferring the Montgomery reduction to a single call at the end. +//! Accumulates `sum += a * b` as folded 4x4 limb products, deferring +//! carry propagation and Montgomery reduction to a single call at the end. //! //! # Capacity //! -//! Each Fr element is 4 limbs (256 bits). The unreduced product of two -//! elements is 8 limbs (512 bits). A 9-limb accumulator (576 bits) can -//! hold up to 2^64 such products without overflow. +//! Each Fr element is 4 limbs (256 bits). The product of two elements is +//! accumulated into eight positional `u128` slots. Carry headroom in each +//! slot lets the hot loop avoid carry propagation until reduction. use crate::accumulator::{AdditiveAccumulator, RingAccumulator}; use crate::arkworks::bn254::Fr; -use crate::Limbs; +use ark_ff::BigInt; use super::bn254_ops; -/// Wide 9-limb accumulator for BN254 Fr deferred reduction. +/// Folded 4x4 product accumulator for BN254 Fr deferred reduction. /// -/// Stores the running sum of Montgomery-form products as a 576-bit integer. -/// Converting to a field element requires a single Montgomery reduction -/// via [`AdditiveAccumulator::reduce`]. +/// Stores the running sum of Montgomery-form products in positional `u128` +/// slots. Converting to a field element requires one carry propagation pass +/// and one Montgomery reduction via [`AdditiveAccumulator::reduce`]. #[derive(Clone, Copy)] pub struct WideAccumulator { - /// 9 limbs = 2×4 (product width) + 1 (addition headroom). - limbs: Limbs<9>, + slots: [u128; 8], } impl Default for WideAccumulator { #[inline] fn default() -> Self { - Self { - limbs: Limbs::zero(), - } + Self { slots: [0; 8] } } } @@ -45,22 +42,45 @@ impl AdditiveAccumulator for WideAccumulator { #[inline(always)] fn merge(&mut self, other: Self) { - self.limbs.add_assign_trunc::<9>(&other.limbs); + for (lhs, rhs) in self.slots.iter_mut().zip(other.slots) { + *lhs += rhs; + } } fn reduce(self) -> Fr { // The accumulator holds Montgomery-form products and/or elements. - // Montgomery reduction divides product terms by R; raw element additions - // are already in Montgomery form and live in the low limbs. - let bigint = self.limbs.into(); - Fr::from_inner(bn254_ops::from_montgomery_reduce(bigint)) + // Montgomery reduction divides product terms by R. + Fr::from_inner(bn254_ops::from_montgomery_reduce(self.normalize())) } } impl RingAccumulator for WideAccumulator { #[inline(always)] fn fmadd(&mut self, a: Fr, b: Fr) { - self.limbs.fmadd::<4, 4>(&a.inner_limbs(), &b.inner_limbs()); + let a = a.inner_limbs(); + let b = b.inner_limbs(); + for i in 0..4 { + for j in 0..4 { + let product = (a.0[i] as u128) * (b.0[j] as u128); + self.slots[i + j] += (product as u64) as u128; + self.slots[i + j + 1] += ((product >> 64) as u64) as u128; + } + } + } +} + +impl WideAccumulator { + #[inline] + fn normalize(self) -> BigInt<9> { + let mut out = [0u64; 9]; + let mut carry = 0u128; + for (index, slot) in self.slots.into_iter().enumerate() { + let (sum, overflow) = slot.overflowing_add(carry); + out[index] = sum as u64; + carry = (sum >> 64) + ((overflow as u128) << 64); + } + out[8] = carry as u64; + BigInt::new(out) } } diff --git a/crates/jolt-field/src/field.rs b/crates/jolt-field/src/field.rs index 80c1da6778..fe46ce719f 100644 --- a/crates/jolt-field/src/field.rs +++ b/crates/jolt-field/src/field.rs @@ -8,7 +8,7 @@ use std::ops::Mul; use crate::{ CanonicalBitLength, CanonicalBytes, CanonicalU64, FieldCore, FixedByteSize, FixedBytes, FromPrimitiveInt, MulPow2, MulPrimitiveInt, RandomSampling, ReducingBytes, RingCore, - TranscriptChallenge, WithAccumulator, + TranscriptChallenge, WithAccumulator, WithSignedProductAccumulator, WithSmallScalarAccumulator, }; /// Prime field element abstraction used throughout Jolt. @@ -41,6 +41,8 @@ pub trait Field: + CanonicalU64 + RandomSampling + WithAccumulator + + WithSmallScalarAccumulator + + WithSignedProductAccumulator + MulPow2 + MulPrimitiveInt + Serialize diff --git a/crates/jolt-field/src/lib.rs b/crates/jolt-field/src/lib.rs index 143c649521..b43f3dbfc6 100644 --- a/crates/jolt-field/src/lib.rs +++ b/crates/jolt-field/src/lib.rs @@ -22,6 +22,7 @@ //! # BN254 types (feature `bn254`) //! //! - [`Fr`] — BN254 scalar field element +//! - [`Fq`] — BN254 base field element //! - [`WideAccumulator`] — 9-limb deferred Montgomery reduction //! //! # Multi-precision arithmetic @@ -46,6 +47,8 @@ mod mul_primitive_int; mod random_sampling; mod reducing_bytes; mod ring_core; +mod signed_product_accumulator; +mod small_scalar_accumulator; mod transcript_challenge; mod with_accumulator; @@ -66,6 +69,12 @@ pub use mul_primitive_int::MulPrimitiveInt; pub use random_sampling::RandomSampling; pub use reducing_bytes::ReducingBytes; pub use ring_core::RingCore; +pub use signed_product_accumulator::{ + NaiveSignedProductAccumulator, SignedProductAccumulator, WithSignedProductAccumulator, +}; +pub use small_scalar_accumulator::{ + NaiveSignedScalarAccumulator, SignedScalarAccumulator, WithSmallScalarAccumulator, +}; pub use transcript_challenge::TranscriptChallenge; pub use with_accumulator::WithAccumulator; @@ -79,4 +88,10 @@ pub mod arkworks; #[cfg(feature = "bn254")] pub use arkworks::bn254::Fr; #[cfg(feature = "bn254")] +pub use arkworks::bn254_fq::Fq; +#[cfg(feature = "bn254")] +pub use arkworks::signed_product_accumulator::FrSignedProductAccumulator; +#[cfg(feature = "bn254")] +pub use arkworks::small_scalar_accumulator::FrSmallScalarAccumulator; +#[cfg(feature = "bn254")] pub use arkworks::wide_accumulator::WideAccumulator; diff --git a/crates/jolt-field/src/signed_product_accumulator.rs b/crates/jolt-field/src/signed_product_accumulator.rs new file mode 100644 index 0000000000..25447e6081 --- /dev/null +++ b/crates/jolt-field/src/signed_product_accumulator.rs @@ -0,0 +1,54 @@ +use crate::{signed::S256, AdditiveGroup, ReducingBytes, RingCore}; +use num_traits::Zero; + +pub trait SignedProductAccumulator: Default + Copy + Send + Sync { + type Element: AdditiveGroup + RingCore + ReducingBytes; + + fn fmadd_s256(&mut self, value: Self::Element, scalar: &S256); + + fn reduce(self) -> Self::Element; +} + +pub trait WithSignedProductAccumulator: AdditiveGroup { + type SignedProductAccumulator: SignedProductAccumulator; +} + +#[derive(Clone, Copy)] +pub struct NaiveSignedProductAccumulator(R); + +impl Default for NaiveSignedProductAccumulator { + #[inline] + fn default() -> Self { + Self(R::zero()) + } +} + +impl SignedProductAccumulator for NaiveSignedProductAccumulator +where + R: AdditiveGroup + RingCore + ReducingBytes, +{ + type Element = R; + + #[inline] + fn fmadd_s256(&mut self, value: R, scalar: &S256) { + if scalar.is_zero() { + return; + } + let mut bytes = [0u8; 32]; + for (index, limb) in scalar.magnitude_limbs().iter().copied().enumerate() { + bytes[index * 8..(index + 1) * 8].copy_from_slice(&limb.to_le_bytes()); + } + let magnitude = R::from_le_bytes_mod_order(&bytes); + let term = if scalar.is_positive { + value * magnitude + } else { + -(value * magnitude) + }; + self.0 += term; + } + + #[inline] + fn reduce(self) -> R { + self.0 + } +} diff --git a/crates/jolt-field/src/small_scalar_accumulator.rs b/crates/jolt-field/src/small_scalar_accumulator.rs new file mode 100644 index 0000000000..9a6c3cc173 --- /dev/null +++ b/crates/jolt-field/src/small_scalar_accumulator.rs @@ -0,0 +1,60 @@ +use crate::{AdditiveGroup, MulPrimitiveInt}; + +pub trait SignedScalarAccumulator: Default + Copy + Send + Sync { + type Element: AdditiveGroup + MulPrimitiveInt; + + fn add(&mut self, value: Self::Element); + + fn fmadd_u64(&mut self, value: Self::Element, scalar: u64); + + fn fmadd_i64(&mut self, value: Self::Element, scalar: i64) { + let magnitude = scalar.unsigned_abs(); + if scalar >= 0 { + self.fmadd_u64(value, magnitude); + } else { + self.add(-value.mul_u64(magnitude)); + } + } + + fn reduce(self) -> Self::Element; +} + +pub trait WithSmallScalarAccumulator: AdditiveGroup { + type SmallScalarAccumulator: SignedScalarAccumulator; +} + +#[derive(Clone, Copy)] +pub struct NaiveSignedScalarAccumulator(R); + +impl Default for NaiveSignedScalarAccumulator { + #[inline] + fn default() -> Self { + Self(R::zero()) + } +} + +impl SignedScalarAccumulator + for NaiveSignedScalarAccumulator +{ + type Element = R; + + #[inline] + fn add(&mut self, value: R) { + self.0 += value; + } + + #[inline] + fn fmadd_u64(&mut self, value: R, scalar: u64) { + self.0 += value.mul_u64(scalar); + } + + #[inline] + fn fmadd_i64(&mut self, value: R, scalar: i64) { + self.0 += value.mul_i64(scalar); + } + + #[inline] + fn reduce(self) -> R { + self.0 + } +} diff --git a/crates/jolt-hyperkzg/src/scheme.rs b/crates/jolt-hyperkzg/src/scheme.rs index 426038f4a5..fd9d8d846f 100644 --- a/crates/jolt-hyperkzg/src/scheme.rs +++ b/crates/jolt-hyperkzg/src/scheme.rs @@ -17,6 +17,7 @@ use jolt_poly::Polynomial; use jolt_transcript::{AppendToTranscript, Label, LabelWithCount, Transcript}; use num_traits::{One, Zero}; use rayon::prelude::*; +use serde::{Deserialize, Serialize}; use crate::error::HyperKZGError; use crate::kzg::{self, kzg_open_batch, kzg_verify_batch}; @@ -26,8 +27,9 @@ use crate::types::{HyperKZGCommitment, HyperKZGProof, HyperKZGProverSetup, Hyper /// /// Generic over `P: PairingGroup`. Implements [`CommitmentScheme`] and /// [`AdditivelyHomomorphic`] from `jolt-openings`. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct HyperKZGScheme { + #[serde(skip)] _phantom: PhantomData

, } diff --git a/crates/jolt-hyperkzg/src/types.rs b/crates/jolt-hyperkzg/src/types.rs index 55ba4e0848..73798c11b5 100644 --- a/crates/jolt-hyperkzg/src/types.rs +++ b/crates/jolt-hyperkzg/src/types.rs @@ -43,16 +43,16 @@ impl PartialEq for HyperKZGCommitment

{ impl Eq for HyperKZGCommitment

{} -impl HomomorphicCommitment for HyperKZGCommitment

{ +impl HomomorphicCommitment for HyperKZGCommitment

{ #[inline] fn add(c1: &Self, c2: &Self) -> Self { Self { - point: >::add(&c1.point, &c2.point), + point: >::add(&c1.point, &c2.point), } } #[inline] - fn linear_combine(c1: &Self, c2: &Self, scalar: &F) -> Self { + fn linear_combine(c1: &Self, c2: &Self, scalar: &P::ScalarField) -> Self { Self { point: HomomorphicCommitment::linear_combine(&c1.point, &c2.point, scalar), } diff --git a/crates/jolt-openings/Cargo.toml b/crates/jolt-openings/Cargo.toml index 8e23641600..357b41b840 100644 --- a/crates/jolt-openings/Cargo.toml +++ b/crates/jolt-openings/Cargo.toml @@ -12,6 +12,7 @@ workspace = true jolt-crypto.workspace = true jolt-field.workspace = true jolt-poly.workspace = true +jolt-r1cs = { workspace = true, optional = true } jolt-transcript.workspace = true serde.workspace = true tracing.workspace = true @@ -19,6 +20,7 @@ thiserror.workspace = true [dev-dependencies] jolt-crypto = { workspace = true, features = ["bn254"] } +jolt-transcript = { workspace = true, features = ["poseidon-r1cs"] } rand.workspace = true rand_chacha.workspace = true rand_core.workspace = true @@ -32,4 +34,5 @@ harness = false ignored = ["rand", "jolt-crypto"] [features] +r1cs = ["dep:jolt-r1cs"] test-utils = [] diff --git a/crates/jolt-openings/src/lib.rs b/crates/jolt-openings/src/lib.rs index f38aa1bde6..3521fb5fec 100644 --- a/crates/jolt-openings/src/lib.rs +++ b/crates/jolt-openings/src/lib.rs @@ -25,14 +25,16 @@ //! AdditivelyHomomorphic ZkOpeningScheme //! (+ combine) (+ commit_zk/open_zk/verify_zk) //! │ -//! StreamingCommitment -//! (+ begin/feed/finish) +//! StreamingCommitment ── ZkStreamingCommitment +//! (+ begin/feed/finish) (+ finish_zk) //! ``` mod claims; mod error; #[cfg(any(test, feature = "test-utils"))] pub mod mock; +#[cfg(feature = "r1cs")] +pub mod r1cs; mod reduction; mod schemes; @@ -40,4 +42,7 @@ pub use claims::{EvaluationClaim, ProverOpeningClaim, VerifierOpeningClaim}; pub use error::OpeningsError; pub use reduction::{reduce_prover, reduce_verifier, rlc_combine, rlc_combine_scalars}; -pub use schemes::{AdditivelyHomomorphic, CommitmentScheme, StreamingCommitment, ZkOpeningScheme}; +pub use schemes::{ + AdditivelyHomomorphic, CommitmentScheme, StreamingCommitment, ZkOpeningScheme, + ZkStreamingCommitment, +}; diff --git a/crates/jolt-openings/src/mock.rs b/crates/jolt-openings/src/mock.rs index 5d079454f1..155635a67f 100644 --- a/crates/jolt-openings/src/mock.rs +++ b/crates/jolt-openings/src/mock.rs @@ -11,7 +11,10 @@ use serde::{Deserialize, Serialize}; use jolt_crypto::HomomorphicCommitment; use crate::error::OpeningsError; -use crate::schemes::{AdditivelyHomomorphic, CommitmentScheme, ZkOpeningScheme}; +use crate::schemes::{ + AdditivelyHomomorphic, CommitmentScheme, StreamingCommitment, ZkOpeningScheme, + ZkStreamingCommitment, +}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] @@ -22,16 +25,24 @@ pub struct MockCommitmentScheme(PhantomData); #[serde(bound = "")] pub struct MockCommitment { evaluations: Vec, + zk: bool, } impl Default for MockCommitment { fn default() -> Self { Self { evaluations: Vec::new(), + zk: false, } } } +impl MockCommitment { + pub const fn is_zk(&self) -> bool { + self.zk + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct MockProof { @@ -47,6 +58,7 @@ impl AppendToTranscript for MockCommitment { e.to_bytes_le(&mut buf[start..]); } buf.reverse(); + buf.push(u8::from(self.zk)); transcript.append_bytes(&buf); } } @@ -78,7 +90,13 @@ impl CommitmentScheme for MockCommitmentScheme { poly.for_each_row(poly.num_vars(), &mut |_, row| { evaluations.extend_from_slice(row); }); - (MockCommitment { evaluations }, ()) + ( + MockCommitment { + evaluations, + zk: false, + }, + (), + ) } fn open( @@ -141,6 +159,7 @@ impl HomomorphicCommitment for MockCommitment { } MockCommitment { evaluations: result, + zk: c1.zk || c2.zk, } } } @@ -159,11 +178,115 @@ impl AdditivelyHomomorphic for MockCommitmentScheme { MockCommitment { evaluations: result, + zk: commitments.iter().any(|commitment| commitment.zk), } } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +impl StreamingCommitment for MockCommitmentScheme { + type PartialCommitment = Vec; + type OneHotChunkCommitment = Vec>; + type OneHotStreamContext = (); + + fn begin(_setup: &Self::ProverSetup) -> Self::PartialCommitment { + Vec::new() + } + + fn feed( + partial: &mut Self::PartialCommitment, + chunk: &[Self::Field], + _setup: &Self::ProverSetup, + ) { + partial.extend_from_slice(chunk); + } + + fn begin_one_hot_column_major_stream( + _setup: &Self::ProverSetup, + _row_width: usize, + ) -> Self::OneHotStreamContext { + } + + fn process_one_hot_chunk( + _context: &mut Self::OneHotStreamContext, + _setup: &Self::ProverSetup, + one_hot_k: usize, + chunk: &[Option], + ) -> Self::OneHotChunkCommitment { + let mut rows = vec![vec![F::zero(); chunk.len()]; one_hot_k]; + for (column, hot_row) in chunk.iter().copied().enumerate() { + if let Some(hot_row) = hot_row { + rows[hot_row][column] = F::one(); + } + } + rows + } + + fn finish(partial: Self::PartialCommitment, _setup: &Self::ProverSetup) -> Self::Output { + MockCommitment { + evaluations: partial, + zk: false, + } + } + + fn finish_one_hot_column_major_chunks( + _setup: &Self::ProverSetup, + one_hot_k: usize, + chunks: &[Self::OneHotChunkCommitment], + ) -> (Self::Output, Self::OpeningHint) { + ( + MockCommitment { + evaluations: flatten_one_hot_chunks(one_hot_k, chunks), + zk: false, + }, + (), + ) + } +} + +impl ZkStreamingCommitment for MockCommitmentScheme { + fn finish_zk_with_hint( + partial: Self::PartialCommitment, + _setup: &Self::ProverSetup, + ) -> (Self::Output, Self::OpeningHint) { + ( + MockCommitment { + evaluations: partial, + zk: true, + }, + (), + ) + } + + fn finish_zk_one_hot_column_major_chunks( + _setup: &Self::ProverSetup, + one_hot_k: usize, + chunks: &[Self::OneHotChunkCommitment], + ) -> (Self::Output, Self::OpeningHint) { + ( + MockCommitment { + evaluations: flatten_one_hot_chunks(one_hot_k, chunks), + zk: true, + }, + (), + ) + } +} + +fn flatten_one_hot_chunks(one_hot_k: usize, chunks: &[Vec>]) -> Vec { + let chunk_width = chunks + .first() + .and_then(|chunk| chunk.first()) + .map_or(0, Vec::len); + let mut evaluations = Vec::with_capacity(one_hot_k * chunks.len() * chunk_width); + for row in 0..one_hot_k { + for chunk in chunks { + evaluations.extend_from_slice(&chunk[row]); + } + } + evaluations +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct MockHidingCommitment { pub eval: F, @@ -183,7 +306,9 @@ impl ZkOpeningScheme for MockCommitmentScheme { poly: &P, setup: &Self::ProverSetup, ) -> (Self::Output, Self::OpeningHint) { - Self::commit(poly, setup) + let (mut commitment, hint) = Self::commit(poly, setup); + commitment.zk = true; + (commitment, hint) } fn open_zk( diff --git a/crates/jolt-openings/src/r1cs.rs b/crates/jolt-openings/src/r1cs.rs new file mode 100644 index 0000000000..1610174642 --- /dev/null +++ b/crates/jolt-openings/src/r1cs.rs @@ -0,0 +1,525 @@ +//! R1CS helpers for generic opening-claim preparation. +//! +//! This module only constrains scheme-independent opening algebra. Concrete +//! commitment checks belong to the R1CS module of the selected PCS. + +use thiserror::Error; + +use jolt_r1cs::{R1csBuilder, ScalarGadget}; + +#[derive(Clone, Debug, Error, PartialEq, Eq)] +pub enum OpeningR1csError { + #[error("opening claim reduction requires at least one opening claim")] + EmptyOpeningClaims, + #[error("opening point length mismatch: expected {expected}, got {got}")] + OpeningPointLengthMismatch { expected: usize, got: usize }, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OpeningClaimVar +where + S: ScalarGadget, +{ + pub commitment: C, + pub point: Vec, + pub opening_claim: S, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ReducedOpeningClaimScalars +where + S: ScalarGadget, +{ + pub point: Vec, + pub opening_claim: S, +} + +impl OpeningClaimVar +where + S: ScalarGadget, +{ + pub fn new(commitment: C, point: Vec, opening_claim: S) -> Self { + Self { + commitment, + point, + opening_claim, + } + } +} + +pub fn reduce_same_point_opening_claims( + builder: &mut R1csBuilder, + claims: &[OpeningClaimVar], + batching_challenge: &S, +) -> Result, OpeningR1csError> +where + S: ScalarGadget, +{ + let point = assert_same_opening_point(builder, claims)?; + let opening_claim = reduce_opening_claim_scalars( + builder, + claims.iter().map(|claim| &claim.opening_claim), + batching_challenge, + )?; + + Ok(ReducedOpeningClaimScalars { + point, + opening_claim, + }) +} + +pub fn assert_same_opening_point( + builder: &mut R1csBuilder, + claims: &[OpeningClaimVar], +) -> Result, OpeningR1csError> +where + S: ScalarGadget, +{ + let Some((first, rest)) = claims.split_first() else { + return Err(OpeningR1csError::EmptyOpeningClaims); + }; + + for claim in rest { + if claim.point.len() != first.point.len() { + return Err(OpeningR1csError::OpeningPointLengthMismatch { + expected: first.point.len(), + got: claim.point.len(), + }); + } + + for (actual, expected) in claim.point.iter().zip(&first.point) { + actual.assert_equal(builder, expected); + } + } + + Ok(first.point.clone()) +} + +pub fn reduce_opening_claim_scalars<'a, S>( + builder: &mut R1csBuilder, + opening_claims: impl IntoIterator, + batching_challenge: &S, +) -> Result +where + S: ScalarGadget + 'a, +{ + let opening_claims = opening_claims.into_iter().collect::>(); + let Some((last, rest)) = opening_claims.split_last() else { + return Err(OpeningR1csError::EmptyOpeningClaims); + }; + + let mut reduced = (*last).clone(); + for opening_claim in rest.iter().rev() { + reduced = reduced.mul(builder, batching_challenge); + reduced = reduced.add(builder, opening_claim); + } + + Ok(reduced) +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] +mod tests { + use jolt_field::{Fq, Fr, FromPrimitiveInt}; + use jolt_r1cs::{AssignedScalar, FqVar, Variable}; + use jolt_transcript::r1cs::{PoseidonR1csTranscript, R1csJoltTranscript, R1csTranscript}; + + use super::*; + use crate::rlc_combine_scalars; + + #[test] + fn native_reduces_same_point_opening_claims() { + let mut builder = R1csBuilder::::new(); + let claims = native_claims(&mut builder); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(2)); + + let reduced = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("same-point opening claims reduce"); + + reduced + .opening_claim + .assert_equal(&mut builder, &AssignedScalar::constant(Fr::from_u64(170))); + assert_eq!(reduced.point.len(), 2); + assert!(builder_accepts(builder)); + } + + #[test] + fn native_reduction_matches_opening_rlc_helper() { + let mut builder = R1csBuilder::::new(); + let values = [ + Fr::from_u64(10), + Fr::from_u64(20), + Fr::from_u64(30), + Fr::from_u64(40), + ]; + let opening_claims = values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(7)); + let expected = rlc_combine_scalars(&values, gamma.value); + + let reduced = reduce_opening_claim_scalars(&mut builder, &opening_claims, &gamma) + .expect("opening claims reduce"); + reduced.assert_equal(&mut builder, &AssignedScalar::constant(expected)); + + assert!(builder_accepts(builder)); + } + + #[test] + fn native_rejects_opening_claim_challenge_and_point_tampering() { + let mut builder = R1csBuilder::::new(); + let a = AssignedScalar::alloc(&mut builder, Fr::from_u64(10)); + let b = AssignedScalar::alloc(&mut builder, Fr::from_u64(20)); + let c = AssignedScalar::alloc(&mut builder, Fr::from_u64(30)); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(2)); + let point0 = AssignedScalar::alloc(&mut builder, Fr::from_u64(7)); + let point1 = AssignedScalar::alloc(&mut builder, Fr::from_u64(8)); + let point1_copy = AssignedScalar::alloc(&mut builder, Fr::from_u64(8)); + let targets = [ + ("opening claim", variable(&a)), + ("batching challenge", variable(&gamma)), + ("opening point", variable(&point1_copy)), + ]; + let claims = vec![ + OpeningClaimVar::new(0usize, vec![point0.clone(), point1], a), + OpeningClaimVar::new(1usize, vec![point0.clone(), point1_copy], b), + OpeningClaimVar::new( + 2usize, + vec![point0, AssignedScalar::constant(Fr::from_u64(8))], + c, + ), + ]; + + let reduced = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("same-point opening claims reduce"); + reduced + .opening_claim + .assert_equal(&mut builder, &AssignedScalar::constant(Fr::from_u64(170))); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn native_rejects_wrong_opening_point() { + let mut builder = R1csBuilder::::new(); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(2)); + let claims = vec![ + OpeningClaimVar::new( + 0usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(7))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(10)), + ), + OpeningClaimVar::new( + 1usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(8))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(20)), + ), + ]; + + let _ = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("point equality constraints are emitted"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn native_accepts_independent_same_point_groups() { + let mut builder = R1csBuilder::::new(); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(3)); + let first_group = vec![ + OpeningClaimVar::new( + 0usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(7))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(10)), + ), + OpeningClaimVar::new( + 1usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(7))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(20)), + ), + ]; + let second_group = vec![ + OpeningClaimVar::new( + 2usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(99))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(30)), + ), + OpeningClaimVar::new( + 3usize, + vec![AssignedScalar::alloc(&mut builder, Fr::from_u64(99))], + AssignedScalar::alloc(&mut builder, Fr::from_u64(40)), + ), + ]; + + let first = reduce_same_point_opening_claims(&mut builder, &first_group, &gamma) + .expect("first point group reduces"); + let second = reduce_same_point_opening_claims(&mut builder, &second_group, &gamma) + .expect("second point group reduces"); + + first + .opening_claim + .assert_equal(&mut builder, &AssignedScalar::constant(Fr::from_u64(70))); + second + .opening_claim + .assert_equal(&mut builder, &AssignedScalar::constant(Fr::from_u64(150))); + + assert!(builder_accepts(builder)); + } + + #[test] + fn transcript_challenge_composes_with_native_reduction() { + let mut builder = R1csBuilder::::new(); + let mut transcript = PoseidonR1csTranscript::new(&mut builder, b"OpeningR1cs"); + let claims = [ + AssignedScalar::alloc(&mut builder, Fr::from_u64(11)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(13)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(17)), + ]; + transcript.append_scalars(&mut builder, b"opening_claim", &claims); + let gamma = transcript.challenge_scalar(&mut builder); + let expected_values: [Fr; 3] = std::array::from_fn(|index| claims[index].value); + let expected = rlc_combine_scalars(&expected_values, gamma.value); + let targets = [ + ("transcript opening claim", variable(&claims[0])), + ("transcript challenge", variable(&gamma)), + ]; + + let reduced = + reduce_opening_claim_scalars(&mut builder, &claims, &gamma).expect("claims reduce"); + reduced.assert_equal(&mut builder, &AssignedScalar::constant(expected)); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn nonnative_reduces_same_point_opening_claims() { + let mut builder = R1csBuilder::::new(); + let claims = nonnative_claims(&mut builder); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + + let reduced = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("same-point opening claims reduce"); + + reduced + .opening_claim + .assert_equal(&mut builder, &FqVar::constant(Fq::from_u64(170))); + assert_eq!(reduced.point.len(), 2); + assert!(builder_accepts(builder)); + } + + #[test] + fn nonnative_reduction_matches_opening_rlc_helper() { + let mut builder = R1csBuilder::::new(); + let values = [ + Fq::from_u64(10), + Fq::from_u64(20), + Fq::from_u64(30), + Fq::from_u64(40), + ]; + let opening_claims = values + .iter() + .copied() + .map(|value| FqVar::alloc(&mut builder, value)) + .collect::>(); + let gamma_value = Fq::from_u64(7); + let gamma = FqVar::alloc(&mut builder, gamma_value); + let expected = rlc_combine_scalars(&values, gamma_value); + + let reduced = reduce_opening_claim_scalars(&mut builder, &opening_claims, &gamma) + .expect("opening claims reduce"); + reduced.assert_equal(&mut builder, &FqVar::constant(expected)); + + assert!(builder_accepts(builder)); + } + + #[test] + fn nonnative_rejects_opening_claim_challenge_and_point_tampering() { + let mut builder = R1csBuilder::::new(); + let a = FqVar::alloc(&mut builder, Fq::from_u64(10)); + let b = FqVar::alloc(&mut builder, Fq::from_u64(20)); + let c = FqVar::alloc(&mut builder, Fq::from_u64(30)); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + let point0 = FqVar::alloc(&mut builder, Fq::from_u64(7)); + let point1 = FqVar::alloc(&mut builder, Fq::from_u64(8)); + let point1_copy = FqVar::alloc(&mut builder, Fq::from_u64(8)); + let targets = [ + ("opening claim limb", variable(&a.limbs()[0])), + ("batching challenge limb", variable(&gamma.limbs()[0])), + ("opening point limb", variable(&point1_copy.limbs()[0])), + ]; + let claims = vec![ + OpeningClaimVar::new(0usize, vec![point0.clone(), point1], a), + OpeningClaimVar::new(1usize, vec![point0.clone(), point1_copy], b), + OpeningClaimVar::new(2usize, vec![point0, FqVar::constant(Fq::from_u64(8))], c), + ]; + + let reduced = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("same-point opening claims reduce"); + reduced + .opening_claim + .assert_equal(&mut builder, &FqVar::constant(Fq::from_u64(170))); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn nonnative_rejects_wrong_opening_claim() { + let mut builder = R1csBuilder::::new(); + let claims = nonnative_claims(&mut builder); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + + let reduced = reduce_same_point_opening_claims(&mut builder, &claims, &gamma) + .expect("same-point opening claims reduce"); + reduced + .opening_claim + .assert_equal(&mut builder, &FqVar::constant(Fq::from_u64(171))); + + assert!(builder_rejects(builder)); + } + + #[test] + fn empty_opening_claims_are_typed_errors() { + let mut builder = R1csBuilder::::new(); + let gamma = AssignedScalar::constant(Fr::from_u64(2)); + + assert_eq!( + reduce_same_point_opening_claims::, usize>( + &mut builder, + &[], + &gamma + ), + Err(OpeningR1csError::EmptyOpeningClaims) + ); + assert_eq!( + reduce_opening_claim_scalars::>(&mut builder, &[], &gamma), + Err(OpeningR1csError::EmptyOpeningClaims) + ); + } + + #[test] + fn opening_point_length_mismatch_is_a_typed_error() { + let mut builder = R1csBuilder::::new(); + let gamma = AssignedScalar::constant(Fr::from_u64(2)); + let claims = vec![ + OpeningClaimVar::new( + 0usize, + vec![AssignedScalar::constant(Fr::from_u64(7))], + AssignedScalar::constant(Fr::from_u64(10)), + ), + OpeningClaimVar::new( + 1usize, + vec![ + AssignedScalar::constant(Fr::from_u64(7)), + AssignedScalar::constant(Fr::from_u64(8)), + ], + AssignedScalar::constant(Fr::from_u64(20)), + ), + ]; + + assert_eq!( + reduce_same_point_opening_claims(&mut builder, &claims, &gamma), + Err(OpeningR1csError::OpeningPointLengthMismatch { + expected: 1, + got: 2 + }) + ); + } + + fn native_claims( + builder: &mut R1csBuilder, + ) -> Vec, usize>> { + let point0 = AssignedScalar::alloc(builder, Fr::from_u64(7)); + let point1 = AssignedScalar::alloc(builder, Fr::from_u64(8)); + vec![ + OpeningClaimVar::new( + 0, + vec![point0.clone(), point1.clone()], + AssignedScalar::alloc(builder, Fr::from_u64(10)), + ), + OpeningClaimVar::new( + 1, + vec![point0.clone(), point1.clone()], + AssignedScalar::alloc(builder, Fr::from_u64(20)), + ), + OpeningClaimVar::new( + 2, + vec![point0, point1], + AssignedScalar::alloc(builder, Fr::from_u64(30)), + ), + ] + } + + fn nonnative_claims(builder: &mut R1csBuilder) -> Vec> { + let point0 = FqVar::alloc(builder, Fq::from_u64(7)); + let point1 = FqVar::alloc(builder, Fq::from_u64(8)); + vec![ + OpeningClaimVar::new( + 0, + vec![point0.clone(), point1.clone()], + FqVar::alloc(builder, Fq::from_u64(10)), + ), + OpeningClaimVar::new( + 1, + vec![point0.clone(), point1.clone()], + FqVar::alloc(builder, Fq::from_u64(20)), + ), + OpeningClaimVar::new( + 2, + vec![point0, point1], + FqVar::alloc(builder, Fq::from_u64(30)), + ), + ] + } + + fn builder_accepts(builder: R1csBuilder) -> bool + where + F: jolt_field::Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn builder_rejects(builder: R1csBuilder) -> bool + where + F: jolt_field::Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_err() + } + + fn assert_tampering_rejected( + builder: R1csBuilder, + targets: impl IntoIterator, + ) where + F: jolt_field::Field, + { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for (label, variable) in targets { + let mut tampered = witness.clone(); + tampered[variable.index()] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "{label} accepted after tampering variable {}", + variable.index() + ); + } + } + + fn variable(scalar: &AssignedScalar) -> Variable + where + F: jolt_field::Field, + { + scalar + .lc + .terms + .first() + .copied() + .expect("expected scalar backed by one variable") + .0 + } +} diff --git a/crates/jolt-openings/src/schemes.rs b/crates/jolt-openings/src/schemes.rs index 97e68e59aa..bd3c573f89 100644 --- a/crates/jolt-openings/src/schemes.rs +++ b/crates/jolt-openings/src/schemes.rs @@ -4,11 +4,12 @@ //! - [`AdditivelyHomomorphic`] — linear combination of commitments. //! - [`StreamingCommitment`] — chunked commitment without full materialization. //! - [`ZkOpeningScheme`] — zero-knowledge commitments and opening proofs. +//! - [`ZkStreamingCommitment`] — chunked zero-knowledge commitments. use std::fmt::Debug; use jolt_crypto::{Commitment, HomomorphicCommitment}; -use jolt_field::Field; +use jolt_field::{Field, FromPrimitiveInt}; use jolt_poly::MultilinearPoly; use jolt_transcript::{AppendToTranscript, Transcript}; use serde::{de::DeserializeOwned, Serialize}; @@ -47,6 +48,20 @@ pub trait CommitmentScheme: Commitment { transcript: &mut impl Transcript, ) -> Self::Proof; + fn open_poly>( + poly: &P, + point: &[Self::Field], + eval: Self::Field, + setup: &Self::ProverSetup, + hint: Option, + transcript: &mut impl Transcript, + ) -> Self::Proof { + let mut evals = Vec::with_capacity(1usize << poly.num_vars()); + poly.for_each_row(poly.num_vars(), &mut |_, row| evals.extend_from_slice(row)); + let dense = Self::Polynomial::from(evals); + Self::open(&dense, point, eval, setup, hint, transcript) + } + fn verify( commitment: &Self::Output, point: &[Self::Field], @@ -81,6 +96,8 @@ where /// Incremental commitment without full materialization. pub trait StreamingCommitment: CommitmentScheme { type PartialCommitment: Clone + Send + Sync; + type OneHotChunkCommitment: Clone + Send + Sync; + type OneHotStreamContext: Send + Sync; fn begin(setup: &Self::ProverSetup) -> Self::PartialCommitment; @@ -90,7 +107,80 @@ pub trait StreamingCommitment: CommitmentScheme { setup: &Self::ProverSetup, ); + fn feed_zeros( + partial: &mut Self::PartialCommitment, + row_width: usize, + rows: usize, + setup: &Self::ProverSetup, + ) { + if rows == 0 { + return; + } + let row = vec![Self::Field::from_u64(0); row_width]; + for _ in 0..rows { + Self::feed(partial, &row, setup); + } + } + + fn feed_u64(partial: &mut Self::PartialCommitment, chunk: &[u64], setup: &Self::ProverSetup) { + let values: Vec = chunk + .iter() + .copied() + .map(::from_u64) + .collect(); + Self::feed(partial, &values, setup); + } + + fn feed_i128(partial: &mut Self::PartialCommitment, chunk: &[i128], setup: &Self::ProverSetup) { + let values: Vec = chunk + .iter() + .copied() + .map(::from_i128) + .collect(); + Self::feed(partial, &values, setup); + } + + fn begin_one_hot_column_major_stream( + setup: &Self::ProverSetup, + row_width: usize, + ) -> Self::OneHotStreamContext; + + fn process_one_hot_chunk( + context: &mut Self::OneHotStreamContext, + setup: &Self::ProverSetup, + one_hot_k: usize, + chunk: &[Option], + ) -> Self::OneHotChunkCommitment; + fn finish(partial: Self::PartialCommitment, setup: &Self::ProverSetup) -> Self::Output; + + fn finish_with_hint( + partial: Self::PartialCommitment, + setup: &Self::ProverSetup, + ) -> (Self::Output, Self::OpeningHint) { + (Self::finish(partial, setup), Self::OpeningHint::default()) + } + + fn finish_one_hot_column_major_chunks( + setup: &Self::ProverSetup, + one_hot_k: usize, + chunks: &[Self::OneHotChunkCommitment], + ) -> (Self::Output, Self::OpeningHint); +} + +/// Incremental commitment support for schemes whose hiding/ZK commitment mode +/// is distinct from the transparent streaming path. +pub trait ZkStreamingCommitment: StreamingCommitment + ZkOpeningScheme { + fn finish_zk_with_hint( + partial: Self::PartialCommitment, + setup: &Self::ProverSetup, + ) -> (Self::Output, Self::OpeningHint); + + fn finish_zk_one_hot_column_major_chunks( + setup: &Self::ProverSetup, + one_hot_k: usize, + chunks: &[Self::OneHotChunkCommitment], + ) -> (Self::Output, Self::OpeningHint); } /// Opening proofs that hide the evaluation behind a commitment. @@ -124,6 +214,20 @@ pub trait ZkOpeningScheme: CommitmentScheme { transcript: &mut impl Transcript, ) -> (Self::Proof, Self::HidingCommitment, Self::Blind); + fn open_zk_poly>( + poly: &P, + point: &[Self::Field], + eval: Self::Field, + setup: &Self::ProverSetup, + hint: Self::OpeningHint, + transcript: &mut impl Transcript, + ) -> (Self::Proof, Self::HidingCommitment, Self::Blind) { + let mut evals = Vec::with_capacity(1usize << poly.num_vars()); + poly.for_each_row(poly.num_vars(), &mut |_, row| evals.extend_from_slice(row)); + let dense = Self::Polynomial::from(evals); + Self::open_zk(&dense, point, eval, setup, hint, transcript) + } + /// Verify a ZK opening proof and return the hiding commitment to the /// evaluation that the proof binds internally. fn verify_zk( diff --git a/crates/jolt-poly/Cargo.toml b/crates/jolt-poly/Cargo.toml index 738aae75ee..394a704a3a 100644 --- a/crates/jolt-poly/Cargo.toml +++ b/crates/jolt-poly/Cargo.toml @@ -36,3 +36,4 @@ ignored = ["rand"] [features] default = ["parallel"] parallel = ["rayon"] +r1cs = [] diff --git a/crates/jolt-poly/src/compressed_univariate.rs b/crates/jolt-poly/src/compressed_univariate.rs index e013789f29..964968f937 100644 --- a/crates/jolt-poly/src/compressed_univariate.rs +++ b/crates/jolt-poly/src/compressed_univariate.rs @@ -68,12 +68,17 @@ impl CompressedPoly { /// in O(d) multiplications. #[inline] pub fn evaluate_with_hint(&self, hint: F, point: F) -> F { - let linear_term = self.recover_linear_term(hint); + self.eval_from_hint(&hint, &point) + } + + #[inline] + pub fn eval_from_hint(&self, hint: &F, point: &F) -> F { + let linear_term = self.recover_linear_term(*hint); - let mut x_pow = point; - let mut sum = self.coeffs_except_linear_term[0] + point * linear_term; + let mut x_pow = *point; + let mut sum = self.coeffs_except_linear_term[0] + *point * linear_term; for &c in &self.coeffs_except_linear_term[1..] { - x_pow *= point; + x_pow *= *point; sum += c * x_pow; } sum diff --git a/crates/jolt-poly/src/lagrange.rs b/crates/jolt-poly/src/lagrange.rs index e10cf96fcf..d4430a21e8 100644 --- a/crates/jolt-poly/src/lagrange.rs +++ b/crates/jolt-poly/src/lagrange.rs @@ -4,7 +4,7 @@ //! protocols. All functions are generic over [`Field`] and operate on //! integer-indexed domains (symmetric or arbitrary). -use std::fmt; +use std::{fmt, marker::PhantomData}; use jolt_field::Field; @@ -69,22 +69,46 @@ pub fn centered_lagrange_evals( domain_size: usize, r: F, ) -> Result, CenteredIntegerDomainError> { - Ok(lagrange_evals( - centered_domain_start(domain_size)?, - domain_size, - r, - )) + let _ = centered_domain_start(domain_size)?; + macro_rules! evals { + ($n:literal) => { + Ok(LagrangePolynomial::::evals::<$n>(r).to_vec()) + }; + } + match domain_size { + 1 => evals!(1), + 2 => evals!(2), + 3 => evals!(3), + 4 => evals!(4), + 5 => evals!(5), + 6 => evals!(6), + 7 => evals!(7), + 8 => evals!(8), + 9 => evals!(9), + 10 => evals!(10), + 11 => evals!(11), + 12 => evals!(12), + 13 => evals!(13), + 14 => evals!(14), + 15 => evals!(15), + 16 => evals!(16), + 17 => evals!(17), + 18 => evals!(18), + 19 => evals!(19), + 20 => evals!(20), + _ => Ok(lagrange_evals( + centered_domain_start(domain_size)?, + domain_size, + r, + )), + } } pub fn centered_lagrange_evals_array( r: F, ) -> Result<[F; N], CenteredIntegerDomainError> { - let evals = centered_lagrange_evals(N, r)?; - let mut result = [F::zero(); N]; - for (dst, src) in result.iter_mut().zip(evals) { - *dst = src; - } - Ok(result) + let _ = centered_domain_start(N)?; + Ok(LagrangePolynomial::::evals::(r)) } /// Computes `sum_i L_i(x) * L_i(y)` over the centered consecutive integer @@ -94,13 +118,43 @@ pub fn centered_lagrange_kernel( x: F, y: F, ) -> Result { - let x_evals = centered_lagrange_evals(domain_size, x)?; - let y_evals = centered_lagrange_evals(domain_size, y)?; - Ok(x_evals - .into_iter() - .zip(y_evals) - .map(|(left, right)| left * right) - .sum()) + let _ = centered_domain_start(domain_size)?; + macro_rules! kernel { + ($n:literal) => { + Ok(LagrangePolynomial::::lagrange_kernel::<$n>(x, y)) + }; + } + match domain_size { + 1 => kernel!(1), + 2 => kernel!(2), + 3 => kernel!(3), + 4 => kernel!(4), + 5 => kernel!(5), + 6 => kernel!(6), + 7 => kernel!(7), + 8 => kernel!(8), + 9 => kernel!(9), + 10 => kernel!(10), + 11 => kernel!(11), + 12 => kernel!(12), + 13 => kernel!(13), + 14 => kernel!(14), + 15 => kernel!(15), + 16 => kernel!(16), + 17 => kernel!(17), + 18 => kernel!(18), + 19 => kernel!(19), + 20 => kernel!(20), + _ => { + let x_evals = lagrange_evals(centered_domain_start(domain_size)?, domain_size, x); + let y_evals = lagrange_evals(centered_domain_start(domain_size)?, domain_size, y); + Ok(x_evals + .into_iter() + .zip(y_evals) + .map(|(left, right)| left * right) + .sum()) + } + } } /// Computes power sums $S_k = \sum_{t=-D}^{D} t^k$ for $k = 0, 1, \ldots, \text{num\_powers}-1$ @@ -149,6 +203,319 @@ impl fmt::Display for CenteredIntegerDomainError { impl std::error::Error for CenteredIntegerDomainError {} +pub struct LagrangeHelper; + +impl LagrangeHelper { + #[inline] + pub const fn fact(n: usize) -> u64 { + let mut acc = 1u64; + let mut i = 2usize; + while i <= n { + acc *= i as u64; + i += 1; + } + acc + } + + pub const FACT_U64_0_TO_20: [u64; 21] = { + let mut out = [0u64; 21]; + let mut i = 0usize; + while i <= 20 { + out[i] = Self::fact(i); + i += 1; + } + out + }; + + #[inline] + pub const fn den_row_i64() -> [i64; N] { + let mut out = [0i64; N]; + let mut i = 0usize; + while i < N { + let left = Self::FACT_U64_0_TO_20[i] as i128; + let right = Self::FACT_U64_0_TO_20[N - 1 - i] as i128; + let mut value = left * right; + if ((N - 1 - i) & 1) == 1 { + value = -value; + } + out[i] = value as i64; + i += 1; + } + out + } +} + +pub struct LagrangePolynomial(PhantomData); + +impl LagrangePolynomial { + #[inline] + fn start_i64() -> i64 { + -(((N - 1) / 2) as i64) + } + + #[inline] + fn distances(r: F) -> ([F; N], Option) { + let mut dists = [F::zero(); N]; + let mut base = r - F::from_i64(Self::start_i64::()); + let mut hit = None; + for (i, dist) in dists.iter_mut().enumerate() { + let current = base; + if current.is_zero() { + hit = Some(i); + } + *dist = current; + base -= F::one(); + } + (dists, hit) + } + + #[inline] + #[expect(clippy::expect_used)] + fn inv_denom() -> [F; N] { + let den_i64 = LagrangeHelper::den_row_i64::(); + let mut denom = [F::zero(); N]; + for (dst, &src) in denom.iter_mut().zip(den_i64.iter()) { + *dst = F::from_i64(src); + } + + let mut left = [F::one(); N]; + for i in 1..N { + left[i] = left[i - 1] * denom[i - 1]; + } + let inv_total = (left[N - 1] * denom[N - 1]) + .inverse() + .expect("Lagrange denominator product is invertible"); + + let mut inv_denom = [F::zero(); N]; + let mut right = F::one(); + for i in (0..N).rev() { + inv_denom[i] = left[i] * right * inv_total; + right *= denom[i]; + } + inv_denom + } + + #[inline] + #[expect(clippy::expect_used)] + fn bary_terms_from_dists(dists: &[F; N], inv_denom: &[F; N]) -> ([F; N], F) { + let mut prefix = [F::one(); N]; + for i in 1..N { + prefix[i] = prefix[i - 1] * dists[i - 1]; + } + let inv_prod = (prefix[N - 1] * dists[N - 1]) + .inverse() + .expect("off-domain Lagrange distance product is invertible"); + + let mut suffix = [F::one(); N]; + for i in (0..N.saturating_sub(1)).rev() { + suffix[i] = suffix[i + 1] * dists[i + 1]; + } + + let mut terms = [F::zero(); N]; + let mut sum = F::zero(); + for i in 0..N { + let inv_di = prefix[i] * suffix[i] * inv_prod; + let term = inv_denom[i] * inv_di; + terms[i] = term; + sum += term; + } + (terms, sum) + } + + #[inline] + #[expect(clippy::expect_used)] + pub fn evaluate(values: &[F; N], r: F) -> F { + debug_assert!(N > 0, "N must be positive"); + debug_assert!(N <= 20, "evaluate is intended for small N"); + let (dists, hit) = Self::distances::(r); + if let Some(i) = hit { + return values[i]; + } + let inv_denom = Self::inv_denom::(); + let (terms, sum) = Self::bary_terms_from_dists::(&dists, &inv_denom); + let inv_sum = sum + .inverse() + .expect("off-domain Lagrange term sum is invertible"); + let mut numerator = F::zero(); + for i in 0..N { + numerator += values[i] * terms[i]; + } + numerator * inv_sum + } + + #[inline] + #[expect(clippy::expect_used)] + pub fn evals(r: F) -> [F; N] { + debug_assert!(N > 0, "N must be positive"); + debug_assert!(N <= 20, "evals is intended for small N"); + let (dists, hit) = Self::distances::(r); + if let Some(i) = hit { + let mut out = [F::zero(); N]; + out[i] = F::one(); + return out; + } + let inv_denom = Self::inv_denom::(); + let (terms, sum) = Self::bary_terms_from_dists::(&dists, &inv_denom); + let inv_sum = sum + .inverse() + .expect("off-domain Lagrange term sum is invertible"); + let mut out = [F::zero(); N]; + for i in 0..N { + out[i] = terms[i] * inv_sum; + } + out + } + + #[inline] + #[expect(clippy::expect_used)] + pub fn lagrange_kernel(x: F, y: F) -> F { + debug_assert!(N > 0, "N must be positive"); + debug_assert!(N <= 20, "lagrange_kernel is intended for small N"); + let (dists_x, hit_x) = Self::distances::(x); + let (dists_y, hit_y) = Self::distances::(y); + + if let (Some(ix), Some(jy)) = (hit_x, hit_y) { + return if ix == jy { F::one() } else { F::zero() }; + } + + let inv_denom = Self::inv_denom::(); + if let Some(ix) = hit_x { + let (terms_y, sum_y) = Self::bary_terms_from_dists::(&dists_y, &inv_denom); + return terms_y[ix] + * sum_y + .inverse() + .expect("off-domain Lagrange term sum is invertible"); + } + if let Some(jy) = hit_y { + let (terms_x, sum_x) = Self::bary_terms_from_dists::(&dists_x, &inv_denom); + return terms_x[jy] + * sum_x + .inverse() + .expect("off-domain Lagrange term sum is invertible"); + } + + let (terms_x, sum_x) = Self::bary_terms_from_dists::(&dists_x, &inv_denom); + let (terms_y, sum_y) = Self::bary_terms_from_dists::(&dists_y, &inv_denom); + let mut numerator = F::zero(); + for i in 0..N { + numerator += terms_x[i] * terms_y[i]; + } + numerator + * (sum_x * sum_y) + .inverse() + .expect("off-domain Lagrange kernel denominator is invertible") + } + + pub fn evaluate_many(values: &[F; N], points: &[F]) -> Vec { + if points.is_empty() { + return Vec::new(); + } + + if points.len() > N { + let coeffs = Self::interpolate_coeffs(values); + points + .iter() + .map(|&point| { + let mut result = coeffs[N - 1]; + for i in (0..N - 1).rev() { + result = result * point + coeffs[i]; + } + result + }) + .collect() + } else { + points + .iter() + .map(|&point| Self::evaluate::(values, point)) + .collect() + } + } + + #[inline] + #[expect(clippy::expect_used)] + pub fn interpolate_coeffs(values: &[F; N]) -> [F; N] { + debug_assert!(N > 0, "N must be positive"); + let degree = N - 1; + let start = Self::start_i64::(); + + let mut smalls = [0u64; N]; + let mut prefix = [F::one(); N]; + for m in 1..=degree { + smalls[m] = m as u64; + prefix[m] = prefix[m - 1].mul_u64(smalls[m]); + } + let inv_total = prefix[degree] + .inverse() + .expect("factorial product is invertible"); + let mut right = F::one(); + let mut inverses = [F::zero(); N]; + for idx in (1..=degree).rev() { + inverses[idx] = prefix[idx - 1] * right * inv_total; + right = right.mul_u64(smalls[idx]); + } + + let mut dd = *values; + let mut newton = [F::zero(); N]; + newton[0] = dd[0]; + for order in 1..=degree { + let inv = inverses[order]; + for i in 0..(N - order) { + dd[i] = (dd[i + 1] - dd[i]) * inv; + } + newton[order] = dd[0]; + } + + let mut coeffs = [F::zero(); N]; + let mut basis = [F::zero(); N]; + basis[0] = F::one(); + let mut basis_degree = 0usize; + for (k, &scale) in newton.iter().enumerate() { + for j in 0..=basis_degree { + coeffs[j] += scale * basis[j]; + } + + if k == degree { + break; + } + + let node = start + k as i64; + let last = basis[basis_degree]; + for idx in (1..=basis_degree).rev() { + let old = basis[idx]; + basis[idx] = basis[idx - 1] - old.mul_i64(node); + } + basis[0] = -basis[0].mul_i64(node); + basis_degree += 1; + basis[basis_degree] = last; + } + + coeffs + } +} + +pub fn centered_lagrange_evaluate( + values: &[F; N], + r: F, +) -> Result { + let _ = centered_domain_start(N)?; + Ok(LagrangePolynomial::::evaluate::(values, r)) +} + +pub fn centered_lagrange_evaluate_many( + values: &[F; N], + points: &[F], +) -> Result, CenteredIntegerDomainError> { + let _ = centered_domain_start(N)?; + Ok(LagrangePolynomial::::evaluate_many::(values, points)) +} + +pub fn centered_interpolate_coeffs_array( + values: &[F; N], +) -> Result<[F; N], CenteredIntegerDomainError> { + let _ = centered_domain_start(N)?; + Ok(LagrangePolynomial::::interpolate_coeffs::(values)) +} + /// Start of the centered consecutive-integer domain used by core univariate skip. /// /// The domain has `domain_size` consecutive integer points diff --git a/crates/jolt-poly/src/lib.rs b/crates/jolt-poly/src/lib.rs index 8016ca2083..be480806ae 100644 --- a/crates/jolt-poly/src/lib.rs +++ b/crates/jolt-poly/src/lib.rs @@ -20,6 +20,7 @@ //! - [`EqPolynomial`]: Equality polynomial `eq(x, r)`, materialized via bottom-up doubling //! - [`EqPlusOnePolynomial`]: Successor polynomial `eq+1(x, y)` evaluating to 1 when `y = x + 1` //! - [`EqPlusOnePrefixSuffix`]: Prefix-suffix decomposition of `eq+1` for sqrt-sized sumcheck buffers +//! - [`TensorEqTable`] and [`GruenSplitEqPolynomial`]: Split eq tables for sqrt-memory sumcheck kernels //! - [`LtPolynomial`]: Less-than polynomial `LT(x, r)` with split optimization for sqrt-sized buffers //! - [`IdentityPolynomial`]: Maps hypercube points to their integer index //! - [`UnivariatePoly`]: Coefficient-form univariate with Lagrange interpolation and compression @@ -56,6 +57,9 @@ mod mle; mod multilinear; mod one_hot; mod point; +#[cfg(feature = "r1cs")] +pub mod r1cs; +mod split_eq; pub mod thread; mod univariate; @@ -70,6 +74,7 @@ pub use mle::{ block_selector_mle_msb, range_mask_mle_msb, sparse_mle_msb, sparse_segments_mle_msb, MleError, }; pub use multilinear::{MultilinearBinding, MultilinearEvaluation, MultilinearPoly, RlcSource}; -pub use one_hot::OneHotPolynomial; +pub use one_hot::{OneHotIndexOrder, OneHotPolynomial}; pub use point::{Endianness, Point, HIGH_TO_LOW, LOW_TO_HIGH}; +pub use split_eq::{GruenSplitEqPolynomial, TensorEqTable}; pub use univariate::{UnivariatePoly, UnivariatePolynomial}; diff --git a/crates/jolt-poly/src/lt.rs b/crates/jolt-poly/src/lt.rs index 1b50547463..d30cfc6de2 100644 --- a/crates/jolt-poly/src/lt.rs +++ b/crates/jolt-poly/src/lt.rs @@ -20,9 +20,9 @@ //! LT(j, r) = LT(j_hi, r_hi) + eq(j_hi, r_hi) · LT(j_lo, r_lo) //! ``` //! -//! where `j = (j_hi, j_lo)`. Binding proceeds HighToLow: first all hi vars -//! (shrinking `lt_hi` and `eq_hi`), then all lo vars (shrinking `lt_lo`). -//! Total memory stays at 3 · √N throughout. +//! where `j = (j_hi, j_lo)`. Binding can proceed HighToLow or LowToHigh. +//! HighToLow first shrinks `lt_hi` and `eq_hi`; LowToHigh first shrinks +//! `lt_lo`. Total memory stays at 3 · √N throughout. use jolt_field::Field; @@ -33,7 +33,6 @@ use crate::EqPolynomial; /// Stores three sub-tables of size ≤ √N each, reconstructing full-table /// values on demand via `LT[j] = lt_hi[j_hi] + eq_hi[j_hi] · lt_lo[j_lo]`. /// -/// Supports HighToLow binding only (MSB first). pub struct LtPolynomial { lt_lo: Vec, lt_hi: Vec, @@ -87,17 +86,37 @@ impl LtPolynomial { } /// Returns `(LT[j], LT[j + half])` for HighToLow sumcheck pairing. - /// - /// In the hi-binding phase, the pairing splits across hi-table halves. - /// In the lo-binding phase (all hi vars bound), it splits across lo-table halves. #[inline] pub fn sumcheck_eval_pair(&self, j: usize) -> (F, F) { - let half = self.len() / 2; - (self.get(j), self.get(j + half)) + self.sumcheck_eval_pair_with_order(j, crate::BindingOrder::HighToLow) + } + + /// Returns the `(lo, hi)` pair for the requested sumcheck binding order. + #[inline] + pub fn sumcheck_eval_pair_with_order(&self, j: usize, order: crate::BindingOrder) -> (F, F) { + match order { + crate::BindingOrder::HighToLow => { + let half = self.len() / 2; + (self.get(j), self.get(j + half)) + } + crate::BindingOrder::LowToHigh => (self.get(2 * j), self.get(2 * j + 1)), + } } /// Binds the MSB (HighToLow), halving the effective table size. pub fn bind(&mut self, challenge: F) { + self.bind_with_order(challenge, crate::BindingOrder::HighToLow); + } + + /// Binds the next variable for the requested sumcheck binding order. + pub fn bind_with_order(&mut self, challenge: F, order: crate::BindingOrder) { + match order { + crate::BindingOrder::HighToLow => self.bind_high_to_low(challenge), + crate::BindingOrder::LowToHigh => self.bind_low_to_high(challenge), + } + } + + fn bind_high_to_low(&mut self, challenge: F) { if self.n_hi_vars > 0 { bind_in_place(&mut self.lt_hi, challenge); bind_in_place(&mut self.eq_hi, challenge); @@ -109,6 +128,18 @@ impl LtPolynomial { } } + fn bind_low_to_high(&mut self, challenge: F) { + if self.n_lo_vars > 0 { + bind_in_place_low_to_high(&mut self.lt_lo, challenge); + self.n_lo_vars -= 1; + } else { + assert!(self.n_hi_vars > 0, "no variables left to bind"); + bind_in_place_low_to_high(&mut self.lt_hi, challenge); + bind_in_place_low_to_high(&mut self.eq_hi, challenge); + self.n_hi_vars -= 1; + } + } + /// Materializes the full `2^n` evaluation table `[LT(0, r), ..., LT(2^n - 1, r)]`. /// /// Big-endian index order: `j = 0` corresponds to the all-zeros vertex. @@ -166,6 +197,18 @@ fn bind_in_place(v: &mut Vec, challenge: F) { v.truncate(half); } +/// In-place LowToHigh bind: `v[j] = v[2j] + challenge · (v[2j+1] - v[2j])`. +#[inline] +fn bind_in_place_low_to_high(v: &mut Vec, challenge: F) { + let half = v.len() / 2; + for j in 0..half { + let lo = v[2 * j]; + let hi = v[2 * j + 1]; + v[j] = lo + challenge * (hi - lo); + } + v.truncate(half); +} + #[cfg(test)] mod tests { use super::*; @@ -256,6 +299,23 @@ mod tests { } } + #[test] + fn low_to_high_sumcheck_eval_pair_matches_full_table() { + let mut rng = ChaCha20Rng::seed_from_u64(124); + for n in 2..=7 { + let r: Vec = (0..n).map(|_| Fr::random(&mut rng)).collect(); + let full_table = LtPolynomial::evaluations(&r); + let split = LtPolynomial::new(&r); + + for (j, pair) in full_table.chunks_exact(2).enumerate() { + let (lo, hi) = + split.sumcheck_eval_pair_with_order(j, crate::BindingOrder::LowToHigh); + assert_eq!(lo, pair[0], "lo mismatch at j={j}, n={n}"); + assert_eq!(hi, pair[1], "hi mismatch at j={j}, n={n}"); + } + } + } + #[test] fn sequential_bind_converges() { // Bind all variables → single scalar = evaluate(challenges, r). @@ -325,6 +385,34 @@ mod tests { } } + #[test] + fn low_to_high_multi_round_bind_matches_full_table() { + let mut rng = ChaCha20Rng::seed_from_u64(401); + let n = 6; + let r: Vec = (0..n).map(|_| Fr::random(&mut rng)).collect(); + let challenges: Vec = (0..n).map(|_| Fr::random(&mut rng)).collect(); + + let mut split = LtPolynomial::new(&r); + let mut full = LtPolynomial::evaluations(&r); + + for (round, &c) in challenges.iter().enumerate() { + split.bind_with_order(c, crate::BindingOrder::LowToHigh); + bind_in_place_low_to_high(&mut full, c); + + assert_eq!(split.len(), full.len(), "size mismatch after round {round}"); + for (j, &expected) in full.iter().enumerate() { + assert_eq!( + split.get(j), + expected, + "mismatch at j={j} after round {round}" + ); + } + } + + let point = challenges.iter().rev().copied().collect::>(); + assert_eq!(split.get(0), LtPolynomial::evaluate(&point, &r)); + } + #[test] fn inline_evaluate_matches_table() { let mut rng = ChaCha20Rng::seed_from_u64(500); diff --git a/crates/jolt-poly/src/multilinear.rs b/crates/jolt-poly/src/multilinear.rs index b2b49906eb..133ae378c2 100644 --- a/crates/jolt-poly/src/multilinear.rs +++ b/crates/jolt-poly/src/multilinear.rs @@ -124,6 +124,23 @@ pub trait MultilinearPoly: Send + Sync { false } + /// Returns the one-hot address-space size `K` when this polynomial exposes + /// its sparse row representation. + fn one_hot_k(&self) -> Option { + None + } + + /// Returns one optional hot column per cycle row when this polynomial + /// exposes its sparse row representation. + fn one_hot_indices(&self) -> Option<&[Option]> { + None + } + + /// Returns the coefficient order used by the sparse one-hot representation. + fn one_hot_index_order(&self) -> Option { + None + } + /// Iterates over positions whose value is exactly `F::one()`. /// /// Only implementations that return true from [`is_one_hot`](Self::is_one_hot) diff --git a/crates/jolt-poly/src/one_hot.rs b/crates/jolt-poly/src/one_hot.rs index da9f097a34..7d450ff4f6 100644 --- a/crates/jolt-poly/src/one_hot.rs +++ b/crates/jolt-poly/src/one_hot.rs @@ -25,6 +25,14 @@ pub struct OneHotPolynomial { k: usize, indices: Vec>, num_vars: usize, + index_order: OneHotIndexOrder, +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum OneHotIndexOrder { + #[default] + RowMajor, + ColumnMajor, } impl OneHotPolynomial { @@ -34,6 +42,23 @@ impl OneHotPolynomial { /// /// Panics if `k * indices.len()` is not a power of two. pub fn new(k: usize, indices: Vec>) -> Self { + Self::new_with_index_order(k, indices, OneHotIndexOrder::RowMajor) + } + + /// Creates a one-hot polynomial with an explicit coefficient order. + /// + /// `RowMajor` stores each cycle contiguously as `cycle * k + column`. + /// `ColumnMajor` stores each column contiguously as `column * T + cycle`, + /// matching the legacy core Dory cycle-major RA commitment layout. + /// + /// # Panics + /// + /// Panics if `k * indices.len()` is not a power of two. + pub fn new_with_index_order( + k: usize, + indices: Vec>, + index_order: OneHotIndexOrder, + ) -> Self { assert!( k <= u8::MAX as usize + 1, "k exceeds u8 index range ({k} > 256)" @@ -48,6 +73,7 @@ impl OneHotPolynomial { k, indices, num_vars, + index_order, } } @@ -74,6 +100,19 @@ impl OneHotPolynomial { pub fn num_vars(&self) -> usize { self.num_vars } + + #[inline] + pub const fn index_order(&self) -> OneHotIndexOrder { + self.index_order + } + + #[inline] + fn flat_index(&self, row: usize, col: u8) -> usize { + match self.index_order { + OneHotIndexOrder::RowMajor => row * self.k + col as usize, + OneHotIndexOrder::ColumnMajor => col as usize * self.indices.len() + row, + } + } } impl MultilinearPoly for OneHotPolynomial { @@ -88,7 +127,7 @@ impl MultilinearPoly for OneHotPolynomial { let mut result = F::zero(); for (row, &opt_col) in self.indices.iter().enumerate() { if let Some(col) = opt_col { - result += eq_evals[row * self.k + col as usize]; + result += eq_evals[self.flat_index(row, col)]; } } result @@ -103,7 +142,7 @@ impl MultilinearPoly for OneHotPolynomial { let mut row_hot_cols: Vec> = vec![Vec::new(); num_rows]; for (cycle, &opt_col) in self.indices.iter().enumerate() { if let Some(col) = opt_col { - let flat = cycle * self.k + col as usize; + let flat = self.flat_index(cycle, col); row_hot_cols[flat / num_cols].push(flat % num_cols); } } @@ -125,7 +164,7 @@ impl MultilinearPoly for OneHotPolynomial { let mut result = crate::thread::unsafe_allocate_zero_vec(num_cols); for (cycle, &opt_col) in self.indices.iter().enumerate() { if let Some(col) = opt_col { - let flat = cycle * self.k + col as usize; + let flat = self.flat_index(cycle, col); result[flat % num_cols] += left[flat / num_cols]; } } @@ -137,10 +176,25 @@ impl MultilinearPoly for OneHotPolynomial { true } + #[inline] + fn one_hot_k(&self) -> Option { + Some(self.k) + } + + #[inline] + fn one_hot_indices(&self) -> Option<&[Option]> { + Some(&self.indices) + } + + #[inline] + fn one_hot_index_order(&self) -> Option { + Some(self.index_order) + } + fn for_each_one(&self, f: &mut dyn FnMut(usize)) { for (cycle, &opt_col) in self.indices.iter().enumerate() { if let Some(col) = opt_col { - f(cycle * self.k + col as usize); + f(self.flat_index(cycle, col)); } } } @@ -164,7 +218,7 @@ mod tests { let mut table = vec![F::zero(); total]; for (row, &opt_col) in oh.indices.iter().enumerate() { if let Some(col) = opt_col { - table[row * oh.k + col as usize] = F::one(); + table[oh.flat_index(row, col)] = F::one(); } } Polynomial::new(table) @@ -244,6 +298,18 @@ mod tests { assert_eq!(entries[2], 3 * 4 + 3); } + #[test] + fn column_major_order_groups_cycles_by_column() { + let k = 4; + let indices = vec![Some(2), None, Some(0), Some(3)]; + let oh = OneHotPolynomial::new_with_index_order(k, indices, OneHotIndexOrder::ColumnMajor); + + let mut entries = Vec::new(); + >::for_each_one(&oh, &mut |idx| entries.push(idx)); + + assert_eq!(entries, vec![2 * 4, 2, 3 * 4 + 3]); + } + #[test] fn is_one_hot_returns_true() { let oh = make_one_hot(4, &[Some(0), Some(1), Some(2), Some(3)]); diff --git a/crates/jolt-poly/src/r1cs.rs b/crates/jolt-poly/src/r1cs.rs new file mode 100644 index 0000000000..d91d25cbdb --- /dev/null +++ b/crates/jolt-poly/src/r1cs.rs @@ -0,0 +1,266 @@ +use jolt_field::Field; +use num_traits::{One, Zero}; +use thiserror::Error; + +#[derive(Clone, Debug, Error, PartialEq, Eq)] +pub enum PolyR1csError { + #[error("eq polynomial arity mismatch: left has {left} coordinates, right has {right}")] + EqArityMismatch { left: usize, right: usize }, + #[error( + "multilinear evaluation table length mismatch: {num_vars} variables require {expected} evaluations, got {got}" + )] + EvaluationTableLengthMismatch { + num_vars: usize, + expected: usize, + got: usize, + }, + #[error("cannot materialize a hypercube of dimension {num_vars}")] + HypercubeTooLarge { num_vars: usize }, +} + +pub trait PolynomialScalarGadget: Clone { + type ConstraintBuilder; + type Scalar: Field; + + fn constant(scalar: Self::Scalar) -> Self; + fn add(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self; + fn sub(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self; + fn mul(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self; +} + +pub fn eq_eval( + builder: &mut S::ConstraintBuilder, + left: &[S], + right: &[S], +) -> Result +where + S: PolynomialScalarGadget, +{ + if left.len() != right.len() { + return Err(PolyR1csError::EqArityMismatch { + left: left.len(), + right: right.len(), + }); + } + + let mut result = S::constant(S::Scalar::one()); + for (left_coordinate, right_coordinate) in left.iter().zip(right) { + let both_one = left_coordinate.mul(builder, right_coordinate); + let one_minus_left = S::constant(S::Scalar::one()).sub(builder, left_coordinate); + let one_minus_right = S::constant(S::Scalar::one()).sub(builder, right_coordinate); + let both_zero = one_minus_left.mul(builder, &one_minus_right); + let coordinate_eq = both_one.add(builder, &both_zero); + result = result.mul(builder, &coordinate_eq); + } + + Ok(result) +} + +pub fn eq_evals(builder: &mut S::ConstraintBuilder, point: &[S]) -> Vec +where + S: PolynomialScalarGadget, +{ + scaled_eq_evals(builder, point, &S::constant(S::Scalar::one())) +} + +pub fn scaled_eq_evals( + builder: &mut S::ConstraintBuilder, + point: &[S], + scaling_factor: &S, +) -> Vec +where + S: PolynomialScalarGadget, +{ + let mut table = vec![scaling_factor.clone()]; + for coordinate in point { + let mut next_table = Vec::with_capacity(table.len() * 2); + for entry in &table { + let selected = entry.mul(builder, coordinate); + next_table.push(entry.sub(builder, &selected)); + next_table.push(selected); + } + table = next_table; + } + table +} + +pub fn inner_product( + builder: &mut S::ConstraintBuilder, + left: &[S], + right: &[S], +) -> Result +where + S: PolynomialScalarGadget, +{ + if left.len() != right.len() { + return Err(PolyR1csError::EqArityMismatch { + left: left.len(), + right: right.len(), + }); + } + + let mut result = S::constant(S::Scalar::zero()); + for (left_scalar, right_scalar) in left.iter().zip(right) { + let term = left_scalar.mul(builder, right_scalar); + result = result.add(builder, &term); + } + Ok(result) +} + +pub fn multilinear_eval( + builder: &mut S::ConstraintBuilder, + evaluations: &[S], + point: &[S], +) -> Result +where + S: PolynomialScalarGadget, +{ + let expected = hypercube_len(point.len())?; + if evaluations.len() != expected { + return Err(PolyR1csError::EvaluationTableLengthMismatch { + num_vars: point.len(), + expected, + got: evaluations.len(), + }); + } + + let weights = eq_evals(builder, point); + inner_product(builder, evaluations, &weights) +} + +fn hypercube_len(num_vars: usize) -> Result { + 1usize + .checked_shl(num_vars.try_into().unwrap_or(u32::MAX)) + .ok_or(PolyR1csError::HypercubeTooLarge { num_vars }) +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] +mod tests { + use jolt_field::{Fr, FromPrimitiveInt}; + + use super::*; + use crate::{EqPolynomial, Polynomial}; + + #[derive(Clone, Debug, PartialEq, Eq)] + struct PlainScalar(F); + + impl PolynomialScalarGadget for PlainScalar { + type ConstraintBuilder = (); + type Scalar = F; + + fn constant(scalar: Self::Scalar) -> Self { + Self(scalar) + } + + fn add(&self, _builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + Self(self.0 + rhs.0) + } + + fn sub(&self, _builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + Self(self.0 - rhs.0) + } + + fn mul(&self, _builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + Self(self.0 * rhs.0) + } + } + + #[test] + fn eq_eval_matches_plain_eq_mle() { + let mut builder = (); + let left_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let right_values = [Fr::from_u64(7), Fr::from_u64(11), Fr::from_u64(13)]; + let left = left_values + .iter() + .copied() + .map(PlainScalar) + .collect::>(); + let right = right_values + .iter() + .copied() + .map(PlainScalar) + .collect::>(); + + let result = eq_eval(&mut builder, &left, &right).expect("eq eval succeeds"); + let expected = EqPolynomial::new(left_values.to_vec()).evaluate(&right_values); + + assert_eq!(result, PlainScalar(expected)); + } + + #[test] + fn eq_evals_match_plain_table_order() { + let mut builder = (); + let point_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let point = point_values + .iter() + .copied() + .map(PlainScalar) + .collect::>(); + + let evals = eq_evals(&mut builder, &point) + .into_iter() + .map(|scalar| scalar.0) + .collect::>(); + let expected = EqPolynomial::new(point_values.to_vec()).evaluations(); + + assert_eq!(evals, expected); + } + + #[test] + fn multilinear_eval_matches_plain_evaluation() { + let mut builder = (); + let evaluation_values = (0..8) + .map(|index| Fr::from_u64((3 * index + 2) as u64)) + .collect::>(); + let point_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let evaluations = evaluation_values + .iter() + .copied() + .map(PlainScalar) + .collect::>(); + let point = point_values + .iter() + .copied() + .map(PlainScalar) + .collect::>(); + + let result = + multilinear_eval(&mut builder, &evaluations, &point).expect("evaluation succeeds"); + let expected = Polynomial::new(evaluation_values).evaluate(&point_values); + + assert_eq!(result, PlainScalar(expected)); + } + + #[test] + fn dimension_errors_are_typed() { + let mut builder = (); + let x = PlainScalar(Fr::from_u64(2)); + let y = PlainScalar(Fr::from_u64(3)); + + assert_eq!( + eq_eval( + &mut builder, + std::slice::from_ref(&x), + &[x.clone(), y.clone()] + ), + Err(PolyR1csError::EqArityMismatch { left: 1, right: 2 }) + ); + assert_eq!( + inner_product(&mut builder, std::slice::from_ref(&x), &[x.clone(), y]), + Err(PolyR1csError::EqArityMismatch { left: 1, right: 2 }) + ); + assert_eq!( + multilinear_eval( + &mut builder, + std::slice::from_ref(&x), + std::slice::from_ref(&x) + ), + Err(PolyR1csError::EvaluationTableLengthMismatch { + num_vars: 1, + expected: 2, + got: 1, + }) + ); + } +} diff --git a/crates/jolt-poly/src/split_eq.rs b/crates/jolt-poly/src/split_eq.rs new file mode 100644 index 0000000000..4011dfd9a5 --- /dev/null +++ b/crates/jolt-poly/src/split_eq.rs @@ -0,0 +1,637 @@ +//! Split equality tables for sqrt-memory sumcheck kernels. + +use jolt_field::Field; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use crate::{BindingOrder, EqPolynomial, Polynomial, UnivariatePoly}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TensorEqTable { + e_out: Vec, + e_in: Vec, + in_bits: usize, +} + +impl TensorEqTable { + pub fn new(point: &[F]) -> Self { + let split = point.len() / 2; + let (out_point, in_point) = point.split_at(split); + #[cfg(feature = "parallel")] + let (e_out, e_in) = rayon::join( + || EqPolynomial::::evals(out_point, None), + || EqPolynomial::::evals(in_point, None), + ); + #[cfg(not(feature = "parallel"))] + let (e_out, e_in) = ( + EqPolynomial::::evals(out_point, None), + EqPolynomial::::evals(in_point, None), + ); + Self { + e_out, + e_in, + in_bits: in_point.len(), + } + } + + pub fn len(&self) -> usize { + self.e_out.len() * self.e_in.len() + } + + pub fn is_empty(&self) -> bool { + self.e_out.is_empty() || self.e_in.is_empty() + } + + pub fn e_out(&self) -> &[F] { + &self.e_out + } + + pub fn e_in(&self) -> &[F] { + &self.e_in + } + + pub fn evaluate_index(&self, index: usize) -> F { + let x_out = index >> self.in_bits; + let x_in = index & ((1usize << self.in_bits) - 1); + self.e_out[x_out] * self.e_in[x_in] + } + + pub fn evaluate_slices(&self, values: &[&[F]]) -> Vec { + if values.is_empty() { + return Vec::new(); + } + debug_assert!( + values.iter().all(|values| values.len() == self.len()), + "TensorEqTable::evaluate_slices length mismatch" + ); + + self.par_fold_out_in( + || vec![F::zero(); values.len()], + |inner, row, _x_in, e_in| { + if e_in.is_zero() { + return; + } + for (accumulator, values) in inner.iter_mut().zip(values) { + *accumulator += e_in * values[row]; + } + }, + |_x_out, e_out, mut inner| { + if e_out.is_zero() { + inner.fill(F::zero()); + } else { + for value in &mut inner { + *value *= e_out; + } + } + inner + }, + |mut left, right| { + for (left, right) in left.iter_mut().zip(right) { + *left += right; + } + left + }, + ) + } + + #[inline(always)] + pub fn group_index(&self, x_out: usize, x_in: usize) -> usize { + (x_out << self.in_bits) | x_in + } + + #[inline] + pub fn par_fold_out_in< + OuterAcc: Send, + InnerAcc: Send, + MakeInner: Fn() -> InnerAcc + Sync + Send, + InnerStep: Fn(&mut InnerAcc, usize, usize, F) + Sync + Send, + OuterStep: Fn(usize, F, InnerAcc) -> OuterAcc + Sync + Send, + Merge: Fn(OuterAcc, OuterAcc) -> OuterAcc + Sync + Send, + >( + &self, + make_inner: MakeInner, + inner_step: InnerStep, + outer_step: OuterStep, + merge: Merge, + ) -> OuterAcc { + #[cfg(feature = "parallel")] + { + (0..self.e_out.len()) + .into_par_iter() + .map(|x_out| { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in self.e_in.iter().enumerate() { + let row = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, row, x_in, e_in); + } + outer_step(x_out, self.e_out[x_out], inner_acc) + }) + .reduce_with(merge) + .unwrap_or_else(|| { + let inner_acc = make_inner(); + outer_step(0, F::zero(), inner_acc) + }) + } + #[cfg(not(feature = "parallel"))] + { + let mut acc = None; + for x_out in 0..self.e_out.len() { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in self.e_in.iter().enumerate() { + let row = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, row, x_in, e_in); + } + let value = outer_step(x_out, self.e_out[x_out], inner_acc); + acc = Some(match acc { + Some(acc) => merge(acc, value), + None => value, + }); + } + acc.unwrap_or_else(|| { + let inner_acc = make_inner(); + outer_step(0, F::zero(), inner_acc) + }) + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GruenSplitEqPolynomial { + current_index: usize, + current_scalar: F, + point: Vec, + e_in_vec: Vec>, + e_out_vec: Vec>, + binding_order: BindingOrder, +} + +impl GruenSplitEqPolynomial { + pub fn new(point: &[F], binding_order: BindingOrder) -> Self { + Self::new_with_scaling(point, binding_order, None) + } + + pub fn new_with_scaling( + point: &[F], + binding_order: BindingOrder, + scaling_factor: Option, + ) -> Self { + if point.is_empty() { + return Self { + current_index: match binding_order { + BindingOrder::LowToHigh => 0, + BindingOrder::HighToLow => 0, + }, + current_scalar: scaling_factor.unwrap_or(F::one()), + point: Vec::new(), + e_in_vec: vec![vec![F::one()]], + e_out_vec: vec![vec![F::one()]], + binding_order, + }; + } + + match binding_order { + BindingOrder::LowToHigh => { + let split = point.len() / 2; + let head = &point[..point.len() - 1]; + let (out_point, in_point) = head.split_at(split.min(head.len())); + #[cfg(feature = "parallel")] + let (e_out_vec, e_in_vec) = rayon::join( + || EqPolynomial::::evals_cached(out_point, None), + || EqPolynomial::::evals_cached(in_point, None), + ); + #[cfg(not(feature = "parallel"))] + let (e_out_vec, e_in_vec) = ( + EqPolynomial::::evals_cached(out_point, None), + EqPolynomial::::evals_cached(in_point, None), + ); + Self { + current_index: point.len(), + current_scalar: scaling_factor.unwrap_or(F::one()), + point: point.to_vec(), + e_in_vec, + e_out_vec, + binding_order, + } + } + BindingOrder::HighToLow => { + let split = point.len() / 2; + let tail = &point[1..]; + let (in_point, out_point) = tail.split_at(split.min(tail.len())); + #[cfg(feature = "parallel")] + let (e_in_vec, e_out_vec) = rayon::join( + || EqPolynomial::::evals_cached_rev(in_point, None), + || EqPolynomial::::evals_cached_rev(out_point, None), + ); + #[cfg(not(feature = "parallel"))] + let (e_in_vec, e_out_vec) = ( + EqPolynomial::::evals_cached_rev(in_point, None), + EqPolynomial::::evals_cached_rev(out_point, None), + ); + Self { + current_index: 0, + current_scalar: scaling_factor.unwrap_or(F::one()), + point: point.to_vec(), + e_in_vec, + e_out_vec, + binding_order, + } + } + } + } + + pub fn current_scalar(&self) -> F { + self.current_scalar + } + + pub fn current_linear_evals(&self) -> (F, F) { + let point = match self.binding_order { + BindingOrder::LowToHigh => self.point[self.current_index - 1], + BindingOrder::HighToLow => self.point[self.current_index], + }; + let at_one = self.current_scalar * point; + (self.current_scalar - at_one, at_one) + } + + pub fn current_index(&self) -> usize { + self.current_index + } + + pub fn e_in_current(&self) -> &[F] { + &self.e_in_vec[self.e_in_vec.len() - 1] + } + + pub fn e_out_current(&self) -> &[F] { + &self.e_out_vec[self.e_out_vec.len() - 1] + } + + pub fn e_in_current_len(&self) -> usize { + self.e_in_current().len() + } + + pub fn e_out_current_len(&self) -> usize { + self.e_out_current().len() + } + + pub fn e_out_in_for_window(&self, window_size: usize) -> (&[F], &[F]) { + assert!( + matches!(self.binding_order, BindingOrder::LowToHigh), + "streaming split-eq windows are only defined for low-to-high" + ); + + let window_size = core::cmp::min(window_size, self.current_index); + let head_len = self.current_index.saturating_sub(window_size); + let split = self.point.len() / 2; + + let head_out_bits = core::cmp::min(head_len, split); + let head_in_bits = head_len.saturating_sub(head_out_bits); + + debug_assert_eq!(head_out_bits + head_in_bits, head_len); + debug_assert!(head_out_bits < self.e_out_vec.len()); + debug_assert!(head_in_bits < self.e_in_vec.len()); + + (&self.e_out_vec[head_out_bits], &self.e_in_vec[head_in_bits]) + } + + pub fn e_active_for_window(&self, window_size: usize) -> Vec { + assert!( + matches!(self.binding_order, BindingOrder::LowToHigh), + "streaming split-eq windows are only defined for low-to-high" + ); + + if window_size <= 1 { + return vec![F::one()]; + } + + let num_unbound = self.current_index; + if window_size > num_unbound { + return vec![F::one()]; + } + + let remaining_point = &self.point[..num_unbound]; + let window_start = remaining_point.len() - window_size; + let (_, window_point) = remaining_point.split_at(window_start); + let (active_point, _) = window_point.split_at(window_size - 1); + EqPolynomial::::evals(active_point, None) + } + + pub fn bind(&mut self, challenge: F) { + if self.point.is_empty() { + return; + } + + match self.binding_order { + BindingOrder::LowToHigh => { + let point = self.point[self.current_index - 1]; + let product = point * challenge; + self.current_scalar *= F::one() - point - challenge + product + product; + self.current_index -= 1; + if self.point.len() / 2 < self.current_index && self.e_in_vec.len() > 1 { + let _ = self.e_in_vec.pop(); + } else if 0 < self.current_index && self.e_out_vec.len() > 1 { + let _ = self.e_out_vec.pop(); + } + } + BindingOrder::HighToLow => { + let point = self.point[self.current_index]; + let product = point * challenge; + self.current_scalar *= F::one() - point - challenge + product + product; + self.current_index += 1; + if self.current_index <= self.point.len() / 2 && self.e_in_vec.len() > 1 { + let _ = self.e_in_vec.pop(); + } else if self.current_index <= self.point.len() && self.e_out_vec.len() > 1 { + let _ = self.e_out_vec.pop(); + } + } + } + } + + pub fn merge(&self) -> Polynomial { + let evals = match self.binding_order { + BindingOrder::LowToHigh => EqPolynomial::::evals( + &self.point[..self.current_index], + Some(self.current_scalar), + ), + BindingOrder::HighToLow => EqPolynomial::::evals( + &self.point[self.current_index..], + Some(self.current_scalar), + ), + }; + Polynomial::new(evals) + } + + /// Computes `s(X) = l(X) * q(X)` where `l` is the current linear eq + /// factor and `q` is quadratic, represented by its constant and quadratic + /// coefficients plus the sumcheck hint `s(0) + s(1)`. + #[expect(clippy::expect_used)] + pub fn gruen_poly_deg_3( + &self, + q_constant: F, + q_quadratic_coeff: F, + s_0_plus_s_1: F, + ) -> UnivariatePoly { + let eq_eval_1 = self.current_scalar + * match self.binding_order { + BindingOrder::LowToHigh => self.point[self.current_index - 1], + BindingOrder::HighToLow => self.point[self.current_index], + }; + let eq_eval_0 = self.current_scalar - eq_eval_1; + let eq_m = eq_eval_1 - eq_eval_0; + let eq_eval_2 = eq_eval_1 + eq_m; + let eq_eval_3 = eq_eval_2 + eq_m; + + let quadratic_eval_0 = q_constant; + let cubic_eval_0 = eq_eval_0 * quadratic_eval_0; + let cubic_eval_1 = s_0_plus_s_1 - cubic_eval_0; + let quadratic_eval_1 = cubic_eval_1 + * eq_eval_1 + .inverse() + .expect("current eq evaluation at one must be invertible"); + let e_times_2 = q_quadratic_coeff + q_quadratic_coeff; + let quadratic_eval_2 = quadratic_eval_1 + quadratic_eval_1 - quadratic_eval_0 + e_times_2; + let quadratic_eval_3 = + quadratic_eval_2 + quadratic_eval_1 - quadratic_eval_0 + e_times_2 + e_times_2; + + UnivariatePoly::interpolate_over_integers(&[ + cubic_eval_0, + cubic_eval_1, + eq_eval_2 * quadratic_eval_2, + eq_eval_3 * quadratic_eval_3, + ]) + } + + #[expect(clippy::expect_used)] + pub fn gruen_poly_from_evals(&self, q_evals: &[F], s_0_plus_s_1: F) -> UnivariatePoly { + let r_round = match self.binding_order { + BindingOrder::LowToHigh => self.point[self.current_index - 1], + BindingOrder::HighToLow => self.point[self.current_index], + }; + + let l_at_0 = self.current_scalar * (F::one() - r_round); + let l_at_1 = self.current_scalar * r_round; + let q_at_0 = (s_0_plus_s_1 - l_at_1 * q_evals[0]) + * l_at_0 + .inverse() + .expect("current eq evaluation at zero must be invertible"); + + let mut full_q_evals = Vec::with_capacity(q_evals.len() + 1); + full_q_evals.push(q_at_0); + full_q_evals.extend_from_slice(q_evals); + let q_coeffs = UnivariatePoly::from_evals_toom(&full_q_evals).into_coefficients(); + + let l_c0 = l_at_0; + let l_c1 = l_at_1 - l_at_0; + let mut s_coeffs = vec![F::zero(); q_coeffs.len() + 1]; + for (index, q_coeff) in q_coeffs.into_iter().enumerate() { + s_coeffs[index] += q_coeff * l_c0; + s_coeffs[index + 1] += q_coeff * l_c1; + } + + UnivariatePoly::new(s_coeffs) + } + + #[inline(always)] + pub fn group_index(&self, x_out: usize, x_in: usize) -> usize { + let in_bits = self.e_in_current_len().trailing_zeros() as usize; + (x_out << in_bits) | x_in + } + + #[inline] + pub fn par_fold_out_in< + OuterAcc: Send, + InnerAcc: Send, + MakeInner: Fn() -> InnerAcc + Sync + Send, + InnerStep: Fn(&mut InnerAcc, usize, usize, F) + Sync + Send, + OuterStep: Fn(usize, F, InnerAcc) -> OuterAcc + Sync + Send, + Merge: Fn(OuterAcc, OuterAcc) -> OuterAcc + Sync + Send, + >( + &self, + make_inner: MakeInner, + inner_step: InnerStep, + outer_step: OuterStep, + merge: Merge, + ) -> OuterAcc { + let e_out = self.e_out_current(); + let e_in = self.e_in_current(); + #[cfg(feature = "parallel")] + { + (0..e_out.len()) + .into_par_iter() + .map(|x_out| { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in e_in.iter().enumerate() { + let row = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, row, x_in, e_in); + } + outer_step(x_out, e_out[x_out], inner_acc) + }) + .reduce_with(merge) + .unwrap_or_else(|| { + let inner_acc = make_inner(); + outer_step(0, F::zero(), inner_acc) + }) + } + #[cfg(not(feature = "parallel"))] + { + let mut acc = None; + for x_out in 0..e_out.len() { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in e_in.iter().enumerate() { + let row = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, row, x_in, e_in); + } + let value = outer_step(x_out, e_out[x_out], inner_acc); + acc = Some(match acc { + Some(acc) => merge(acc, value), + None => value, + }); + } + acc.unwrap_or_else(|| { + let inner_acc = make_inner(); + outer_step(0, F::zero(), inner_acc) + }) + } + } +} + +#[cfg(test)] +mod tests { + use jolt_field::{Fr, FromPrimitiveInt, RandomSampling}; + use rand_chacha::ChaCha20Rng; + use rand_core::SeedableRng; + + use super::*; + + fn random_point(len: usize, seed: u64) -> Vec { + let mut rng = ChaCha20Rng::seed_from_u64(seed); + (0..len).map(|_| Fr::random(&mut rng)).collect() + } + + #[test] + fn tensor_eq_table_factors_full_eq_table() { + for vars in 0..=10 { + let point = random_point(vars, 100 + vars as u64); + let tensor = TensorEqTable::::new(&point); + let full = EqPolynomial::::evals(&point, None); + assert_eq!(tensor.len(), full.len()); + for x_out in 0..tensor.e_out().len() { + for x_in in 0..tensor.e_in().len() { + let row = tensor.group_index(x_out, x_in); + assert_eq!(tensor.e_out()[x_out] * tensor.e_in()[x_in], full[row]); + } + } + } + } + + #[test] + fn tensor_eq_fold_matches_full_table_dot_product() { + let point = random_point(9, 211); + let values = random_point(1 << point.len(), 307); + let tensor = TensorEqTable::::new(&point); + let folded = tensor.par_fold_out_in( + || Fr::from_u64(0), + |inner, row, _x_in, e_in| { + *inner += e_in * values[row]; + }, + |_x_out, e_out, inner| e_out * inner, + |left, right| left + right, + ); + let full = EqPolynomial::::evals(&point, None) + .into_iter() + .zip(values) + .map(|(eq, value)| eq * value) + .sum::(); + assert_eq!(folded, full); + } + + #[test] + fn tensor_eq_evaluates_slices_in_one_fold() { + let point = random_point(8, 811); + let values = [ + random_point(1 << point.len(), 907), + random_point(1 << point.len(), 1009), + random_point(1 << point.len(), 1103), + ]; + let slices = values.iter().map(Vec::as_slice).collect::>(); + let tensor = TensorEqTable::::new(&point); + let actual = tensor.evaluate_slices(&slices); + let eq = EqPolynomial::::evals(&point, None); + let expected = values + .iter() + .map(|values| { + eq.iter() + .zip(values) + .map(|(&eq, &value)| eq * value) + .sum::() + }) + .collect::>(); + assert_eq!(actual, expected); + } + + #[test] + fn gruen_low_to_high_merge_matches_bound_eq() { + let point = random_point(10, 401); + let mut split = GruenSplitEqPolynomial::::new(&point, BindingOrder::LowToHigh); + let mut dense = Polynomial::new(EqPolynomial::::evals(&point, None)); + assert_eq!(split.merge(), dense); + + let challenges = random_point(point.len(), 509); + for challenge in challenges { + split.bind(challenge); + dense.bind_with_order(challenge, BindingOrder::LowToHigh); + assert_eq!(split.merge(), dense); + } + } + + #[test] + fn gruen_high_to_low_merge_matches_bound_eq() { + let point = random_point(10, 601); + let mut split = GruenSplitEqPolynomial::::new(&point, BindingOrder::HighToLow); + let mut dense = Polynomial::new(EqPolynomial::::evals(&point, None)); + assert_eq!(split.merge(), dense); + + let challenges = random_point(point.len(), 709); + for challenge in challenges { + split.bind(challenge); + dense.bind_with_order(challenge, BindingOrder::HighToLow); + assert_eq!(split.merge(), dense); + } + } + + #[test] + fn gruen_current_linear_factor_matches_merged_sumcheck_pair() { + for order in [BindingOrder::LowToHigh, BindingOrder::HighToLow] { + let point = random_point(8, 811); + let challenges = random_point(4, 919); + let mut split = GruenSplitEqPolynomial::::new(&point, order); + for (round, challenge) in challenges.into_iter().enumerate() { + let merged = split.merge(); + let (linear_0, linear_1) = split.current_linear_evals(); + for x_out in 0..split.e_out_current_len() { + for x_in in 0..split.e_in_current_len() { + let row = split.group_index(x_out, x_in); + let dense_row = match order { + BindingOrder::LowToHigh => row, + BindingOrder::HighToLow => { + let out_bits = split.e_out_current_len().trailing_zeros() as usize; + (x_in << out_bits) | x_out + } + }; + let head = split.e_out_current()[x_out] * split.e_in_current()[x_in]; + let (dense_0, dense_1) = merged.sumcheck_eval_pair(dense_row, order); + assert_eq!( + head * linear_0, + dense_0, + "{order:?} round {round} row {row} eval 0" + ); + assert_eq!( + head * linear_1, + dense_1, + "{order:?} round {round} row {row} eval 1" + ); + } + } + split.bind(challenge); + } + } + } +} diff --git a/crates/jolt-poly/src/univariate.rs b/crates/jolt-poly/src/univariate.rs index 47cc890e80..3d9ab4ac62 100644 --- a/crates/jolt-poly/src/univariate.rs +++ b/crates/jolt-poly/src/univariate.rs @@ -186,7 +186,9 @@ impl UnivariatePoly { self.coefficients.len() >= 2, "cannot compress a polynomial of degree < 1" ); - let coeffs = [&self.coefficients[..1], &self.coefficients[2..]].concat(); + let mut coeffs = Vec::with_capacity(self.coefficients.len() - 1); + coeffs.push(self.coefficients[0]); + coeffs.extend_from_slice(&self.coefficients[2..]); debug_assert_eq!(coeffs.len() + 1, self.coefficients.len()); crate::CompressedPoly::new(coeffs) } @@ -195,8 +197,12 @@ impl UnivariatePoly { /// on the Vandermonde system. Equivalent to `interpolate_over_integers` but uses a /// direct matrix solve instead of the Lagrange formula. pub fn from_evals(evals: &[F]) -> Self { - Self { - coefficients: gaussian_elimination_vandermonde(evals), + match evals.len() { + 3 => Self::from_evals_degree2(evals[0], evals[1], evals[2]), + 4 => Self::from_evals_degree3(evals[0], evals[1], evals[2], evals[3]), + _ => Self { + coefficients: gaussian_elimination_vandermonde(evals), + }, } } @@ -204,10 +210,52 @@ impl UnivariatePoly { /// /// Recovers `p(1) = hint - p(0)` and then interpolates over the full set `{0, 1, ..., n-1}`. pub fn from_evals_and_hint(hint: F, evals: &[F]) -> Self { - let mut full = evals.to_vec(); - let eval_at_1 = hint - full[0]; - full.insert(1, eval_at_1); - Self::from_evals(&full) + match evals.len() { + 2 => { + let e0 = evals[0]; + let e1 = hint - e0; + let e2 = evals[1]; + Self::from_evals_degree2(e0, e1, e2) + } + 3 => { + let e0 = evals[0]; + let e1 = hint - e0; + let e2 = evals[1]; + let e3 = evals[2]; + Self::from_evals_degree3(e0, e1, e2, e3) + } + _ => { + let mut full = Vec::with_capacity(evals.len() + 1); + full.push(evals[0]); + full.push(hint - evals[0]); + full.extend_from_slice(&evals[1..]); + Self::from_evals(&full) + } + } + } + + #[expect(clippy::expect_used)] + fn from_evals_degree2(e0: F, e1: F, e2: F) -> Self { + let two_inv = F::from_u64(2).inverse().expect("2 is invertible"); + let c0 = e0; + let c2 = (e0 - e1 - e1 + e2) * two_inv; + let c1 = e1 - e0 - c2; + Self { + coefficients: vec![c0, c1, c2], + } + } + + #[expect(clippy::expect_used)] + fn from_evals_degree3(e0: F, e1: F, e2: F, e3: F) -> Self { + let two_inv = F::from_u64(2).inverse().expect("2 is invertible"); + let six_inv = F::from_u64(6).inverse().expect("6 is invertible"); + let c0 = e0; + let c3 = (e3 - e0 + (e1 - e2) * F::from_u64(3)) * six_inv; + let c2 = (e0 - e1 - e1 + e2) * two_inv - c3 - c3 - c3; + let c1 = e1 - e0 - c2 - c3; + Self { + coefficients: vec![c0, c1, c2, c3], + } } /// Interpolates from evaluations at `[0, 1, ..., degree-1, ∞]`. diff --git a/crates/jolt-r1cs/Cargo.toml b/crates/jolt-r1cs/Cargo.toml index d1c8ec6080..d39544c3fd 100644 --- a/crates/jolt-r1cs/Cargo.toml +++ b/crates/jolt-r1cs/Cargo.toml @@ -11,7 +11,8 @@ workspace = true [dependencies] jolt-claims.workspace = true jolt-field = { path = "../jolt-field" } -jolt-poly = { path = "../jolt-poly" } +jolt-poly = { path = "../jolt-poly", features = ["r1cs"] } +num-bigint.workspace = true num-traits.workspace = true rayon = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } diff --git a/crates/jolt-r1cs/src/builder.rs b/crates/jolt-r1cs/src/builder.rs index 9f41b7e867..33b35a0ca7 100644 --- a/crates/jolt-r1cs/src/builder.rs +++ b/crates/jolt-r1cs/src/builder.rs @@ -39,6 +39,12 @@ pub struct LinearCombination { pub terms: Vec<(Variable, F)>, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AssignedScalar { + pub value: F, + pub lc: LinearCombination, +} + impl LinearCombination { pub fn zero() -> Self { Self { terms: Vec::new() } @@ -76,7 +82,6 @@ impl LinearCombination { self } - #[cfg(test)] pub fn as_constant(&self) -> Option { let mut value = F::zero(); for &(variable, coefficient) in &self.terms { @@ -128,6 +133,29 @@ impl From for LinearCombination { } } +impl AssignedScalar { + pub fn new(value: F, lc: LinearCombination) -> Self { + Self { value, lc } + } + + pub fn constant(value: F) -> Self { + Self::new(value, LinearCombination::constant(value)) + } + + pub fn variable(value: F, variable: Variable) -> Self { + Self::new(value, LinearCombination::variable(variable)) + } + + pub fn alloc(builder: &mut R1csBuilder, value: F) -> Self { + let variable = builder.alloc(value); + Self::variable(value, variable) + } + + pub fn scale(self, scale: F) -> Self { + Self::new(self.value * scale, self.lc.scale(scale)) + } +} + impl Add for LinearCombination { type Output = Self; diff --git a/crates/jolt-r1cs/src/constraint.rs b/crates/jolt-r1cs/src/constraint.rs index 5ba0d2b510..a7c9f5247d 100644 --- a/crates/jolt-r1cs/src/constraint.rs +++ b/crates/jolt-r1cs/src/constraint.rs @@ -1,6 +1,7 @@ //! Sparse per-cycle R1CS constraint matrices. use jolt_field::Field; +use jolt_poly::EqPolynomial; use serde::{Deserialize, Serialize}; use thiserror::Error as ThisError; @@ -9,12 +10,21 @@ pub type SparseRow = Vec<(usize, F)>; #[derive(Clone, Debug, ThisError, PartialEq, Eq)] pub enum ConstraintMatrixEvalError { + #[error("row point length mismatch: expected {expected}, got {actual}")] + RowPointLengthMismatch { expected: usize, actual: usize }, + #[error("column point length mismatch: expected {expected}, got {actual}")] + ColumnPointLengthMismatch { expected: usize, actual: usize }, #[error("row weights length mismatch: expected at least {expected}, got {actual}")] RowWeightsLengthMismatch { expected: usize, actual: usize }, #[error("column weights length mismatch: expected {expected}, got {actual}")] ColumnWeightsLengthMismatch { expected: usize, actual: usize }, #[error("column {column} out of bounds for {num_vars} variables")] ColumnOutOfBounds { column: usize, num_vars: usize }, + #[error("matrix {dimension} dimension {value} cannot be padded to a power of two")] + PaddedDimensionOverflow { + dimension: &'static str, + value: usize, + }, #[error("matrix column range overflow: start {start}, count {count}")] ColumnRangeOverflow { start: usize, count: usize }, } @@ -216,6 +226,55 @@ impl ConstraintMatrices { Ok(weighted) } + pub fn evaluate_matrix_mles( + &self, + row_point: &[F], + column_point: &[F], + ) -> Result, ConstraintMatrixEvalError> { + let expected_row_vars = log_padded_dimension("rows", self.num_constraints)?; + if row_point.len() != expected_row_vars { + return Err(ConstraintMatrixEvalError::RowPointLengthMismatch { + expected: expected_row_vars, + actual: row_point.len(), + }); + } + + let expected_column_vars = log_padded_dimension("columns", self.num_vars)?; + if column_point.len() != expected_column_vars { + return Err(ConstraintMatrixEvalError::ColumnPointLengthMismatch { + expected: expected_column_vars, + actual: column_point.len(), + }); + } + + let row_eq = EqPolynomial::new(row_point.to_vec()).evaluations(); + let column_eq = EqPolynomial::new(column_point.to_vec()).evaluations(); + + Ok(MatrixColumnContributions { + a: matrix_bilinear_eval_columns( + &self.a, + &row_eq, + &column_eq[..self.num_vars], + 0, + self.num_vars, + )?, + b: matrix_bilinear_eval_columns( + &self.b, + &row_eq, + &column_eq[..self.num_vars], + 0, + self.num_vars, + )?, + c: matrix_bilinear_eval_columns( + &self.c, + &row_eq, + &column_eq[..self.num_vars], + 0, + self.num_vars, + )?, + }) + } + pub fn linear_form_bilinear_eval( &self, row_weights: &[F], @@ -249,6 +308,19 @@ impl ConstraintMatrices { } } +fn log_padded_dimension( + dimension: &'static str, + raw: usize, +) -> Result { + let padded = raw.max(1).checked_next_power_of_two().ok_or( + ConstraintMatrixEvalError::PaddedDimensionOverflow { + dimension, + value: raw, + }, + )?; + Ok(padded.trailing_zeros() as usize) +} + #[inline] fn dot(row: &[(usize, F)], witness: &[F]) -> F { let mut acc = F::zero(); @@ -352,6 +424,70 @@ mod tests { assert_eq!(m.check_witness(&w), Err(0)); } + #[test] + fn matrix_mles_evaluate_sparse_entries() { + let m = ConstraintMatrices::new( + 2, + 3, + vec![vec![(1, Fr::from_u64(2))], vec![(2, Fr::from_u64(3))]], + vec![vec![(0, Fr::from_u64(5))], vec![]], + vec![vec![], vec![(1, Fr::from_u64(7))]], + ); + + let row_point = [Fr::from_u64(11)]; + let column_point = [Fr::from_u64(13), Fr::from_u64(17)]; + let evals = m + .evaluate_matrix_mles(&row_point, &column_point) + .expect("matrix MLE evaluation should accept matching point sizes"); + + let row_0 = Fr::from_u64(1) - row_point[0]; + let row_1 = row_point[0]; + let col_0 = (Fr::from_u64(1) - column_point[0]) * (Fr::from_u64(1) - column_point[1]); + let col_1 = (Fr::from_u64(1) - column_point[0]) * column_point[1]; + let col_2 = column_point[0] * (Fr::from_u64(1) - column_point[1]); + + assert_eq!( + evals.a, + row_0 * col_1 * Fr::from_u64(2) + row_1 * col_2 * Fr::from_u64(3) + ); + assert_eq!(evals.b, row_0 * col_0 * Fr::from_u64(5)); + assert_eq!(evals.c, row_1 * col_1 * Fr::from_u64(7)); + } + + #[test] + fn matrix_mles_reject_wrong_point_lengths() { + let m = ConstraintMatrices::new(2, 3, vec![vec![]; 2], vec![vec![]; 2], vec![vec![]; 2]); + + assert_eq!( + m.evaluate_matrix_mles(&[], &[Fr::from_u64(1), Fr::from_u64(2)]), + Err(ConstraintMatrixEvalError::RowPointLengthMismatch { + expected: 1, + actual: 0 + }) + ); + assert_eq!( + m.evaluate_matrix_mles(&[Fr::from_u64(1)], &[Fr::from_u64(2)]), + Err(ConstraintMatrixEvalError::ColumnPointLengthMismatch { + expected: 2, + actual: 1 + }) + ); + } + + #[test] + fn matrix_mles_reject_unpaddable_dimensions() { + let m: ConstraintMatrices = + ConstraintMatrices::new(1, usize::MAX, vec![vec![]], vec![vec![]], vec![vec![]]); + + assert_eq!( + m.evaluate_matrix_mles(&[], &[]), + Err(ConstraintMatrixEvalError::PaddedDimensionOverflow { + dimension: "columns", + value: usize::MAX + }) + ); + } + #[test] #[should_panic(expected = "num_constraints")] fn new_rejects_row_count_mismatch() { diff --git a/crates/jolt-r1cs/src/lib.rs b/crates/jolt-r1cs/src/lib.rs index be006826c2..07c4da6fef 100644 --- a/crates/jolt-r1cs/src/lib.rs +++ b/crates/jolt-r1cs/src/lib.rs @@ -16,9 +16,11 @@ pub mod constraint; pub mod constraints; pub mod key; pub mod lowering; +pub mod nonnative; pub mod provider; +pub mod scalar; -pub use builder::{LinearCombination, R1csBuilder, R1csBuilderError, Variable}; +pub use builder::{AssignedScalar, LinearCombination, R1csBuilder, R1csBuilderError, Variable}; pub use column::R1csColumn; pub use constraint::{ ConstraintMatrices, ConstraintMatrixEvalError, MatrixColumnContributions, SparseRow, @@ -26,7 +28,10 @@ pub use constraint::{ }; pub use key::R1csKey; pub use lowering::{ - assert_claim_expr_eq, lower_claim_expr, ClaimLoweringError, ClaimSourceTable, ClaimSources, - SourceValue, + assert_claim_expr_eq, assert_claim_expr_gadget_eq, lower_claim_expr, lower_claim_expr_gadget, + ClaimLoweringError, ClaimSourceTable, ClaimSources, ScalarClaimSourceTable, ScalarClaimSources, + ScalarSourceValue, SourceValue, }; +pub use nonnative::FqVar; pub use provider::{R1csSource, SpartanChallenges}; +pub use scalar::{scalar_affine_combination, scalar_dot_product, ScalarGadget}; diff --git a/crates/jolt-r1cs/src/lowering.rs b/crates/jolt-r1cs/src/lowering.rs index 0656f39b6c..cbe616bb20 100644 --- a/crates/jolt-r1cs/src/lowering.rs +++ b/crates/jolt-r1cs/src/lowering.rs @@ -1,8 +1,9 @@ use jolt_claims::{Expr, Source}; use jolt_field::Field; +use num_traits::Zero; use thiserror::Error; -use crate::{LinearCombination, R1csBuilder, Variable}; +use crate::{LinearCombination, R1csBuilder, ScalarGadget, Variable}; #[derive(Clone, Debug, Error, PartialEq, Eq)] pub enum ClaimLoweringError { @@ -24,12 +25,50 @@ pub trait ClaimSources { fn public(&mut self, id: &Self::Public) -> Result, ClaimLoweringError>; } +pub trait ScalarClaimSources +where + S: ScalarGadget, +{ + type Opening; + type Challenge; + type Public; + + fn opening(&mut self, id: &Self::Opening) -> Result, ClaimLoweringError>; + fn challenge( + &mut self, + id: &Self::Challenge, + ) -> Result, ClaimLoweringError>; + fn public(&mut self, id: &Self::Public) -> Result, ClaimLoweringError>; +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum SourceValue { Constant(F), LinearCombination(LinearCombination), } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ScalarSourceValue +where + S: ScalarGadget, +{ + Constant(S::Scalar), + Scalar(S), +} + +impl ScalarSourceValue +where + S: ScalarGadget, +{ + pub fn constant(value: S::Scalar) -> Self { + Self::Constant(value) + } + + pub fn scalar(value: S) -> Self { + Self::Scalar(value) + } +} + impl SourceValue { pub fn variable(variable: Variable) -> Self { Self::LinearCombination(LinearCombination::variable(variable)) @@ -54,6 +93,16 @@ pub struct ClaimSourceTable { publics: Vec<(P, SourceValue)>, } +#[derive(Clone, Debug, Default)] +pub struct ScalarClaimSourceTable +where + S: ScalarGadget, +{ + openings: Vec<(O, ScalarSourceValue)>, + challenges: Vec<(C, ScalarSourceValue)>, + publics: Vec<(P, ScalarSourceValue)>, +} + impl ClaimSourceTable { pub fn new() -> Self { Self { @@ -146,6 +195,97 @@ impl ClaimSourceTable { } } +impl ScalarClaimSourceTable +where + S: ScalarGadget, +{ + pub fn new() -> Self { + Self { + openings: Vec::new(), + challenges: Vec::new(), + publics: Vec::new(), + } + } + + pub fn insert_opening_constant(&mut self, id: O, value: S::Scalar) + where + O: PartialEq, + { + self.insert_opening_source(id, ScalarSourceValue::constant(value)); + } + + pub fn insert_opening_scalar(&mut self, id: O, value: S) + where + O: PartialEq, + { + self.insert_opening_source(id, ScalarSourceValue::scalar(value)); + } + + pub fn insert_opening_source(&mut self, id: O, source: ScalarSourceValue) + where + O: PartialEq, + { + assert!( + !self.openings.iter().any(|(candidate, _)| candidate == &id), + "duplicate opening source" + ); + self.openings.push((id, source)); + } + + pub fn insert_challenge_constant(&mut self, id: C, value: S::Scalar) + where + C: PartialEq, + { + self.insert_challenge_source(id, ScalarSourceValue::constant(value)); + } + + pub fn insert_challenge_scalar(&mut self, id: C, value: S) + where + C: PartialEq, + { + self.insert_challenge_source(id, ScalarSourceValue::scalar(value)); + } + + pub fn insert_challenge_source(&mut self, id: C, source: ScalarSourceValue) + where + C: PartialEq, + { + assert!( + !self + .challenges + .iter() + .any(|(candidate, _)| candidate == &id), + "duplicate challenge source" + ); + self.challenges.push((id, source)); + } + + pub fn insert_public_constant(&mut self, id: P, value: S::Scalar) + where + P: PartialEq, + { + self.insert_public_source(id, ScalarSourceValue::constant(value)); + } + + pub fn insert_public_scalar(&mut self, id: P, value: S) + where + P: PartialEq, + { + self.insert_public_source(id, ScalarSourceValue::scalar(value)); + } + + pub fn insert_public_source(&mut self, id: P, source: ScalarSourceValue) + where + P: PartialEq, + { + assert!( + !self.publics.iter().any(|(candidate, _)| candidate == &id), + "duplicate public source" + ); + self.publics.push((id, source)); + } +} + impl ClaimSources for ClaimSourceTable { @@ -175,6 +315,42 @@ impl ClaimSources } } +impl ScalarClaimSources for ScalarClaimSourceTable +where + S: ScalarGadget, + O: PartialEq, + P: PartialEq, + C: PartialEq, +{ + type Opening = O; + type Challenge = C; + type Public = P; + + fn opening(&mut self, id: &Self::Opening) -> Result, ClaimLoweringError> { + self.openings + .iter() + .find_map(|(candidate, source)| (candidate == id).then_some(source.clone())) + .ok_or(ClaimLoweringError::MissingOpening) + } + + fn challenge( + &mut self, + id: &Self::Challenge, + ) -> Result, ClaimLoweringError> { + self.challenges + .iter() + .find_map(|(candidate, source)| (candidate == id).then_some(source.clone())) + .ok_or(ClaimLoweringError::MissingChallenge) + } + + fn public(&mut self, id: &Self::Public) -> Result, ClaimLoweringError> { + self.publics + .iter() + .find_map(|(candidate, source)| (candidate == id).then_some(source.clone())) + .ok_or(ClaimLoweringError::MissingPublic) + } +} + pub fn lower_claim_expr( builder: &mut R1csBuilder, expression: &Expr, @@ -226,6 +402,55 @@ where Ok(()) } +pub fn lower_claim_expr_gadget( + builder: &mut R1csBuilder, + expression: &Expr, + sources: &mut R, +) -> Result +where + S: ScalarGadget, + R: ScalarClaimSources, +{ + let mut result = S::constant(S::Scalar::zero()); + + for term in &expression.terms { + let mut coefficient = term.coefficient; + let mut factors = Vec::new(); + + for source in &term.factors { + let source = match source { + Source::Opening(id) => sources.opening(id)?, + Source::Challenge(id) => sources.challenge(id)?, + Source::Public(id) => sources.public(id)?, + }; + match source { + ScalarSourceValue::Constant(value) => coefficient *= value, + ScalarSourceValue::Scalar(value) => factors.push(value), + } + } + + let term = lower_gadget_product(builder, coefficient, factors); + result = result.add(builder, &term); + } + + Ok(result) +} + +pub fn assert_claim_expr_gadget_eq( + builder: &mut R1csBuilder, + expression: &Expr, + expected: &S, + sources: &mut R, +) -> Result<(), ClaimLoweringError> +where + S: ScalarGadget, + R: ScalarClaimSources, +{ + let actual = lower_claim_expr_gadget(builder, expression, sources)?; + actual.assert_equal(builder, expected); + Ok(()) +} + fn lower_product( builder: &mut R1csBuilder, coefficient: F, @@ -247,12 +472,38 @@ fn lower_product( product.scale(coefficient) } +fn lower_gadget_product( + builder: &mut R1csBuilder, + coefficient: S::Scalar, + factors: Vec, +) -> S +where + S: ScalarGadget, +{ + if coefficient.is_zero() { + return S::constant(S::Scalar::zero()); + } + + let mut factors = factors.into_iter(); + let Some(mut product) = factors.next() else { + return S::constant(coefficient); + }; + + for factor in factors { + product = product.mul(builder, &factor); + } + + product.scale_by_constant(builder, coefficient) +} + #[cfg(test)] #[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] mod tests { use super::*; use jolt_claims::{challenge, constant, opening, public, Expr}; - use jolt_field::{Fr, FromPrimitiveInt}; + use jolt_field::{Fq, Fr, FromPrimitiveInt}; + + use crate::{AssignedScalar, FqVar}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum Opening { @@ -485,4 +736,196 @@ mod tests { let witness = builder.witness().expect("witness is assigned"); assert!(builder.into_matrices().check_witness(&witness).is_ok()); } + + #[test] + fn native_gadget_lowering_accepts_formula() { + let mut builder = R1csBuilder::::new(); + let a = AssignedScalar::alloc(&mut builder, Fr::from_u64(3)); + let b = AssignedScalar::alloc(&mut builder, Fr::from_u64(5)); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(2)); + let mut sources = + ScalarClaimSourceTable::, Opening, Public, Challenge>::new(); + sources.insert_opening_scalar(Opening::A, a); + sources.insert_opening_scalar(Opening::B, b); + sources.insert_challenge_scalar(Challenge::Gamma, gamma); + sources.insert_public_constant(Public::Offset, Fr::from_u64(4)); + + let expected = AssignedScalar::constant(Fr::from_u64(47)); + assert_claim_expr_gadget_eq(&mut builder, &sample_expression(), &expected, &mut sources) + .expect("native scalar-gadget expression lowers"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn native_gadget_lowering_rejects_tampering() { + let mut builder = R1csBuilder::::new(); + let a = AssignedScalar::alloc(&mut builder, Fr::from_u64(3)); + let b = AssignedScalar::alloc(&mut builder, Fr::from_u64(5)); + let gamma = AssignedScalar::alloc(&mut builder, Fr::from_u64(2)); + let expected = AssignedScalar::alloc(&mut builder, Fr::from_u64(47)); + let targets = [ + ("opening", variable(&a)), + ("challenge", variable(&gamma)), + ("expected", variable(&expected)), + ]; + let mut sources = + ScalarClaimSourceTable::, Opening, Public, Challenge>::new(); + sources.insert_opening_scalar(Opening::A, a); + sources.insert_opening_scalar(Opening::B, b); + sources.insert_challenge_scalar(Challenge::Gamma, gamma); + sources.insert_public_constant(Public::Offset, Fr::from_u64(4)); + + assert_claim_expr_gadget_eq(&mut builder, &sample_expression(), &expected, &mut sources) + .expect("native scalar-gadget expression lowers"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn nonnative_gadget_lowering_accepts_formula() { + let mut builder = R1csBuilder::::new(); + let a = FqVar::alloc(&mut builder, Fq::from_u64(3)); + let b = FqVar::alloc(&mut builder, Fq::from_u64(5)); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + let mut sources = ScalarClaimSourceTable::::new(); + sources.insert_opening_scalar(Opening::A, a); + sources.insert_opening_scalar(Opening::B, b); + sources.insert_challenge_scalar(Challenge::Gamma, gamma); + sources.insert_public_constant(Public::Offset, Fq::from_u64(4)); + + let expected = FqVar::constant(Fq::from_u64(47)); + assert_claim_expr_gadget_eq(&mut builder, &sample_expression(), &expected, &mut sources) + .expect("non-native scalar-gadget expression lowers"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn nonnative_gadget_lowering_rejects_tampering() { + let mut builder = R1csBuilder::::new(); + let a = FqVar::alloc(&mut builder, Fq::from_u64(3)); + let b = FqVar::alloc(&mut builder, Fq::from_u64(5)); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + let expected = FqVar::alloc(&mut builder, Fq::from_u64(47)); + let targets = [ + ("opening limb", variable(&a.limbs()[0])), + ("challenge limb", variable(&gamma.limbs()[0])), + ("expected limb", variable(&expected.limbs()[0])), + ]; + let mut sources = ScalarClaimSourceTable::::new(); + sources.insert_opening_scalar(Opening::A, a); + sources.insert_opening_scalar(Opening::B, b); + sources.insert_challenge_scalar(Challenge::Gamma, gamma); + sources.insert_public_constant(Public::Offset, Fq::from_u64(4)); + + assert_claim_expr_gadget_eq(&mut builder, &sample_expression(), &expected, &mut sources) + .expect("non-native scalar-gadget expression lowers"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn nonnative_gadget_lowering_rejects_bad_expected_output() { + let mut builder = R1csBuilder::::new(); + let a = FqVar::alloc(&mut builder, Fq::from_u64(3)); + let b = FqVar::alloc(&mut builder, Fq::from_u64(5)); + let gamma = FqVar::alloc(&mut builder, Fq::from_u64(2)); + let mut sources = ScalarClaimSourceTable::::new(); + sources.insert_opening_scalar(Opening::A, a); + sources.insert_opening_scalar(Opening::B, b); + sources.insert_challenge_scalar(Challenge::Gamma, gamma); + sources.insert_public_constant(Public::Offset, Fq::from_u64(4)); + + let expected = FqVar::constant(Fq::from_u64(48)); + assert_claim_expr_gadget_eq(&mut builder, &sample_expression(), &expected, &mut sources) + .expect("non-native scalar-gadget expression lowers"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn gadget_lowering_missing_sources_are_typed_errors() { + let mut builder = R1csBuilder::::new(); + let mut sources = + ScalarClaimSourceTable::, Opening, Public, Challenge>::new(); + + let opening_expr: Expr = opening(Opening::A); + let challenge_expr: Expr = challenge(Challenge::Gamma); + let public_expr: Expr = public(Public::Offset); + + assert_eq!( + lower_claim_expr_gadget(&mut builder, &opening_expr, &mut sources), + Err(ClaimLoweringError::MissingOpening) + ); + assert_eq!( + lower_claim_expr_gadget(&mut builder, &challenge_expr, &mut sources), + Err(ClaimLoweringError::MissingChallenge) + ); + assert_eq!( + lower_claim_expr_gadget(&mut builder, &public_expr, &mut sources), + Err(ClaimLoweringError::MissingPublic) + ); + } + + fn sample_expression() -> Expr + where + F: Field, + { + constant(F::from_u64(2)) * opening(Opening::A) * opening(Opening::B) + + challenge(Challenge::Gamma) * public(Public::Offset) + + constant(F::from_u64(9)) + } + + fn builder_accepts(builder: R1csBuilder) -> bool + where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn builder_rejects(builder: R1csBuilder) -> bool + where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_err() + } + + fn assert_tampering_rejected( + builder: R1csBuilder, + targets: impl IntoIterator, + ) where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for (label, variable) in targets { + let mut tampered = witness.clone(); + tampered[variable.index()] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "{label} accepted after tampering variable {}", + variable.index() + ); + } + } + + fn variable(scalar: &AssignedScalar) -> Variable + where + F: Field, + { + assert_eq!(scalar.lc.terms.len(), 1); + let (variable, coefficient) = scalar + .lc + .terms + .first() + .copied() + .expect("linear combination has one term"); + assert_eq!(coefficient, F::one()); + variable + } } diff --git a/crates/jolt-r1cs/src/nonnative.rs b/crates/jolt-r1cs/src/nonnative.rs new file mode 100644 index 0000000000..b32671dfe9 --- /dev/null +++ b/crates/jolt-r1cs/src/nonnative.rs @@ -0,0 +1,918 @@ +//! Non-native field representations for verifier circuits. +//! +//! This module currently targets the wrapper path that constrains BN254 `Fq` +//! values inside an R1CS over BN254 `Fr`. Values are represented as canonical +//! little-endian integer limbs, not as Montgomery-form field internals. + +use jolt_field::{CanonicalBytes, FixedByteSize, Fq, Fr, FromPrimitiveInt, Invertible}; +use num_bigint::BigUint; +use num_traits::Zero; + +use crate::{AssignedScalar, LinearCombination, R1csBuilder}; + +const LIMB_BITS: usize = 64; +const NUM_LIMBS: usize = 4; +const MUL_LIMBS: usize = 2 * NUM_LIMBS; +const RADIX: u128 = 1u128 << LIMB_BITS; +const CARRY_BITS: usize = 68; + +const FQ_MODULUS_LIMBS: [u64; NUM_LIMBS] = [ + 4_332_616_871_279_656_263, + 10_917_124_144_477_883_021, + 13_281_191_951_274_694_749, + 3_486_998_266_802_970_665, +]; + +const FQ_MODULUS_MINUS_ONE_LIMBS: [u64; NUM_LIMBS] = [ + 4_332_616_871_279_656_262, + 10_917_124_144_477_883_021, + 13_281_191_951_274_694_749, + 3_486_998_266_802_970_665, +]; + +const FR_MODULUS_MINUS_ONE_LIMBS: [u64; NUM_LIMBS] = [ + 4_891_460_686_036_598_784, + 2_896_914_383_306_846_353, + 13_281_191_951_274_694_749, + 3_486_998_266_802_970_665, +]; + +/// A canonical BN254 `Fq` integer represented by four 64-bit limbs in an +/// `Fr`-native R1CS. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FqVar { + limbs: [AssignedScalar; NUM_LIMBS], +} + +impl FqVar { + pub const LIMB_BITS: usize = LIMB_BITS; + pub const NUM_LIMBS: usize = NUM_LIMBS; + + pub fn constant(value: Fq) -> Self { + let limbs = fq_to_u64_limbs(value).map(|limb| AssignedScalar::constant(fr(limb))); + Self { limbs } + } + + pub fn alloc(builder: &mut R1csBuilder, value: Fq) -> Self { + let limbs = fq_to_u64_limbs(value).map(|limb| AssignedScalar::alloc(builder, fr(limb))); + Self::from_checked_limbs(builder, limbs) + } + + pub fn from_checked_limbs( + builder: &mut R1csBuilder, + limbs: [AssignedScalar; NUM_LIMBS], + ) -> Self { + for limb in &limbs { + assert_u64(builder, limb); + } + assert_limbs_less_or_equal(builder, &limbs, FQ_MODULUS_MINUS_ONE_LIMBS); + + Self { limbs } + } + + /// Injects a native `Fr` scalar into `Fq` as an integer. + /// + /// This is the canonical integer injection `Fr -> Fq`: constrain the + /// `Fr` value's little-endian integer limbs to be canonical modulo `q_Fr`, + /// then reuse those same limbs as an `Fq` value. This is valid because + /// BN254 has `q_Fr < q_Fq`; it is not a residue reinterpretation. + pub fn inject_fr(builder: &mut R1csBuilder, value: &AssignedScalar) -> Self { + let limbs = + fr_to_u64_limbs(value.value).map(|limb| AssignedScalar::alloc(builder, fr(limb))); + for limb in &limbs { + assert_u64(builder, limb); + } + assert_limbs_less_or_equal(builder, &limbs, FR_MODULUS_MINUS_ONE_LIMBS); + builder.assert_equal(value.lc.clone(), compose_limbs(&limbs)); + + Self { limbs } + } + + /// Converts a Poseidon-over-`Fr` challenge into an `Fq` challenge by the + /// canonical integer injection `Fr -> Fq`. + /// + /// The transcript remains an `Fr` transcript. Callers should domain-separate + /// in the transcript before squeezing `challenge`, then call this at the + /// protocol boundary where `Fq` arithmetic begins. + pub fn inject_fr_challenge( + builder: &mut R1csBuilder, + challenge: &AssignedScalar, + ) -> Self { + Self::inject_fr(builder, challenge) + } + + pub fn limbs(&self) -> &[AssignedScalar; NUM_LIMBS] { + &self.limbs + } + + pub fn witness_value(&self) -> Fq { + self.value() + } + + /// Returns a constrained little-endian bit decomposition of this `Fq` + /// variable's canonical integer representation. + pub fn bits_le(&self, builder: &mut R1csBuilder) -> Vec> { + self.limbs + .iter() + .flat_map(|limb| assert_unsigned_bits(builder, limb, LIMB_BITS)) + .collect() + } + + pub fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + for (lhs_limb, rhs_limb) in self.limbs.iter().zip(&rhs.limbs) { + builder.assert_equal(lhs_limb.lc.clone(), rhs_limb.lc.clone()); + } + } + + pub fn add(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + let output = Self::alloc(builder, self.value() + rhs.value()); + assert_add_relation(builder, self, rhs, &output); + output + } + + pub fn sub(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + let output = Self::alloc(builder, self.value() - rhs.value()); + assert_sub_relation(builder, self, rhs, &output); + output + } + + pub fn mul(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + let output = Self::alloc(builder, self.value() * rhs.value()); + assert_mul_relation(builder, self, rhs, &output); + output + } + + pub fn inverse(&self, builder: &mut R1csBuilder) -> Option { + let output = Self::alloc(builder, self.value().inverse()?); + let one = Self::constant(Fq::from_u64(1)); + assert_mul_relation(builder, self, &output, &one); + Some(output) + } + + pub fn select( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + when_true: &Self, + when_false: &Self, + ) -> Self { + assert_boolean(builder, selector); + + let output = if selector.value == fr(1) { + Self::alloc(builder, when_true.value()) + } else { + Self::alloc(builder, when_false.value()) + }; + + for ((output_limb, true_limb), false_limb) in output + .limbs + .iter() + .zip(&when_true.limbs) + .zip(&when_false.limbs) + { + let selected_delta = builder.multiply( + selector.lc.clone(), + true_limb.lc.clone() - false_limb.lc.clone(), + ); + builder.assert_equal( + output_limb.lc.clone(), + false_limb.lc.clone() + selected_delta, + ); + } + + output + } + + fn value(&self) -> Fq { + fq_from_limbs(self.value_limbs()) + } + + fn value_limbs(&self) -> [u64; NUM_LIMBS] { + std::array::from_fn(|index| scalar_low_u64(self.limbs[index].value)) + } +} + +fn assert_u64(builder: &mut R1csBuilder, value: &AssignedScalar) { + let _ = assert_unsigned_bits(builder, value, LIMB_BITS); +} + +fn assert_unsigned_bits( + builder: &mut R1csBuilder, + value: &AssignedScalar, + bit_len: usize, +) -> Vec> { + let bits = scalar_low_u128(value.value); + let bit_vars = (0..bit_len) + .map(|index| { + let bit = Fr::from_bool(((bits >> index) & 1) == 1); + let assigned = AssignedScalar::alloc(builder, bit); + assert_boolean(builder, &assigned); + assigned + }) + .collect::>(); + + builder.assert_equal(value.lc.clone(), compose_bits(&bit_vars)); + bit_vars +} + +fn assert_add_relation(builder: &mut R1csBuilder, lhs: &FqVar, rhs: &FqVar, output: &FqVar) { + let lhs_limbs = lhs.value_limbs(); + let rhs_limbs = rhs.value_limbs(); + let output_limbs = output.value_limbs(); + let sum = limbs_to_biguint(lhs_limbs) + limbs_to_biguint(rhs_limbs); + let wraps_modulus = sum >= fq_modulus(); + let quotient = AssignedScalar::alloc(builder, Fr::from_bool(wraps_modulus)); + assert_boolean(builder, "ient); + let normalized = alloc_u64_limbs(builder, biguint_to_u64_limbs::(&sum)); + + let lhs_terms = + std::array::from_fn(|index| lhs.limbs[index].lc.clone() + rhs.limbs[index].lc.clone()); + let lhs_raw_terms = std::array::from_fn(|index| { + BigUint::from(lhs_limbs[index]) + BigUint::from(rhs_limbs[index]) + }); + assert_terms_normalize_to(builder, lhs_terms, lhs_raw_terms, &normalized); + + let rhs_terms = std::array::from_fn(|index| { + output.limbs[index].lc.clone() + quotient.lc.clone().scale(fr(FQ_MODULUS_LIMBS[index])) + }); + let quotient_value = u64::from(wraps_modulus); + let rhs_raw_terms = std::array::from_fn(|index| { + BigUint::from(output_limbs[index]) + BigUint::from(FQ_MODULUS_LIMBS[index]) * quotient_value + }); + assert_terms_normalize_to(builder, rhs_terms, rhs_raw_terms, &normalized); +} + +fn assert_sub_relation(builder: &mut R1csBuilder, lhs: &FqVar, rhs: &FqVar, output: &FqVar) { + let lhs_limbs = lhs.value_limbs(); + let rhs_limbs = rhs.value_limbs(); + let output_limbs = output.value_limbs(); + let lhs_value = limbs_to_biguint(lhs_limbs); + let rhs_value = limbs_to_biguint(rhs_limbs); + let borrow = lhs_value < rhs_value; + let adjusted_lhs = if borrow { + lhs_value + fq_modulus() + } else { + lhs_value + }; + let quotient = AssignedScalar::alloc(builder, Fr::from_bool(borrow)); + assert_boolean(builder, "ient); + let normalized = alloc_u64_limbs(builder, biguint_to_u64_limbs::(&adjusted_lhs)); + + let lhs_terms = std::array::from_fn(|index| { + lhs.limbs[index].lc.clone() + quotient.lc.clone().scale(fr(FQ_MODULUS_LIMBS[index])) + }); + let quotient_value = u64::from(borrow); + let lhs_raw_terms = std::array::from_fn(|index| { + BigUint::from(lhs_limbs[index]) + BigUint::from(FQ_MODULUS_LIMBS[index]) * quotient_value + }); + assert_terms_normalize_to(builder, lhs_terms, lhs_raw_terms, &normalized); + + let rhs_terms = + std::array::from_fn(|index| rhs.limbs[index].lc.clone() + output.limbs[index].lc.clone()); + let rhs_raw_terms = std::array::from_fn(|index| { + BigUint::from(rhs_limbs[index]) + BigUint::from(output_limbs[index]) + }); + assert_terms_normalize_to(builder, rhs_terms, rhs_raw_terms, &normalized); +} + +fn assert_mul_relation(builder: &mut R1csBuilder, lhs: &FqVar, rhs: &FqVar, output: &FqVar) { + let lhs_limbs = lhs.value_limbs(); + let rhs_limbs = rhs.value_limbs(); + let output_limbs = output.value_limbs(); + let product = limbs_to_biguint(lhs_limbs) * limbs_to_biguint(rhs_limbs); + let output_value = limbs_to_biguint(output_limbs); + let quotient = if product >= output_value { + (product.clone() - output_value) / fq_modulus() + } else { + BigUint::zero() + }; + let quotient_limbs = alloc_limbs(builder, biguint_to_u64_limbs::("ient)); + let quotient = FqVar::from_checked_limbs(builder, quotient_limbs); + let normalized = alloc_u64_limbs(builder, biguint_to_u64_limbs::(&product)); + + let mut product_terms = std::array::from_fn(|_| LinearCombination::zero()); + for (lhs_index, lhs_limb) in lhs.limbs.iter().enumerate() { + for (rhs_index, rhs_limb) in rhs.limbs.iter().enumerate() { + let product_limb = builder.multiply(lhs_limb.lc.clone(), rhs_limb.lc.clone()); + product_terms[lhs_index + rhs_index] = + product_terms[lhs_index + rhs_index].clone() + product_limb; + } + } + let product_raw_terms = convolution_terms(lhs_limbs, rhs_limbs); + assert_terms_normalize_to(builder, product_terms, product_raw_terms, &normalized); + + let mut reduction_terms = std::array::from_fn(|_| LinearCombination::zero()); + for (modulus_index, modulus_limb) in FQ_MODULUS_LIMBS.into_iter().enumerate() { + for (quotient_index, quotient_limb) in quotient.limbs.iter().enumerate() { + reduction_terms[modulus_index + quotient_index] = + reduction_terms[modulus_index + quotient_index].clone() + + quotient_limb.lc.clone().scale(fr(modulus_limb)); + } + } + for (index, output_limb) in output.limbs.iter().enumerate() { + reduction_terms[index] = reduction_terms[index].clone() + output_limb.lc.clone(); + } + let mut reduction_raw_terms = constant_mul_terms(FQ_MODULUS_LIMBS, quotient.value_limbs()); + for (index, output_limb) in output_limbs.into_iter().enumerate() { + reduction_raw_terms[index] += BigUint::from(output_limb); + } + assert_terms_normalize_to(builder, reduction_terms, reduction_raw_terms, &normalized); +} + +fn alloc_u64_limbs( + builder: &mut R1csBuilder, + limbs: [u64; N], +) -> [AssignedScalar; N] { + let assigned = alloc_limbs(builder, limbs); + for limb in &assigned { + assert_u64(builder, limb); + } + assigned +} + +fn alloc_limbs( + builder: &mut R1csBuilder, + limbs: [u64; N], +) -> [AssignedScalar; N] { + limbs.map(|limb| AssignedScalar::alloc(builder, fr(limb))) +} + +fn assert_terms_normalize_to( + builder: &mut R1csBuilder, + terms: [LinearCombination; N], + raw_terms: [BigUint; N], + normalized: &[AssignedScalar; N], +) { + let mut carry_value = BigUint::zero(); + let mut carry = AssignedScalar::constant(fr(0)); + + for ((term, raw_term), normalized_limb) in terms.into_iter().zip(raw_terms).zip(normalized) { + let total = raw_term + &carry_value; + let limb = BigUint::from(scalar_low_u64(normalized_limb.value)); + let next_carry_value = if total >= limb { + (total - limb) / BigUint::from(RADIX) + } else { + BigUint::zero() + }; + let next_carry = AssignedScalar::alloc(builder, fr_from_biguint(&next_carry_value)); + let _ = assert_unsigned_bits(builder, &next_carry, CARRY_BITS); + + let lhs = term + carry.lc; + let rhs = normalized_limb.lc.clone() + next_carry.lc.clone().scale(radix_fr()); + builder.assert_equal(lhs, rhs); + + carry_value = next_carry_value; + carry = next_carry; + } + + builder.assert_equal(carry.lc, LinearCombination::zero()); +} + +fn assert_boolean(builder: &mut R1csBuilder, value: &AssignedScalar) { + builder.assert_product( + value.lc.clone(), + value.lc.clone() - LinearCombination::one(), + LinearCombination::zero(), + ); +} + +fn assert_limbs_less_or_equal( + builder: &mut R1csBuilder, + limbs: &[AssignedScalar; NUM_LIMBS], + bound: [u64; NUM_LIMBS], +) { + let difference = wrapping_difference_limbs(bound, limbs); + let mut borrow = AssignedScalar::constant(fr(0)); + + for ((limb, bound_limb), difference_limb) in limbs.iter().zip(bound).zip(difference) { + let next_borrow = AssignedScalar::alloc(builder, Fr::from_bool(difference_limb.borrow)); + assert_boolean(builder, &next_borrow); + + let difference = AssignedScalar::alloc(builder, fr(difference_limb.value)); + assert_u64(builder, &difference); + + let lhs = + LinearCombination::constant(fr(bound_limb)) + next_borrow.lc.clone().scale(radix_fr()); + let rhs = limb.lc.clone() + borrow.lc + difference.lc; + builder.assert_equal(lhs, rhs); + + borrow = next_borrow; + } + + builder.assert_equal(borrow.lc, LinearCombination::zero()); +} + +fn compose_bits(bits: &[AssignedScalar]) -> LinearCombination { + bits.iter() + .enumerate() + .fold(LinearCombination::zero(), |acc, (index, bit)| { + acc + bit.lc.clone().scale(Fr::from_u128(1u128 << index)) + }) +} + +fn compose_limbs(limbs: &[AssignedScalar; NUM_LIMBS]) -> LinearCombination { + let mut coefficient = fr(1); + let radix = radix_fr(); + limbs.iter().fold(LinearCombination::zero(), |acc, limb| { + let term = limb.lc.clone().scale(coefficient); + coefficient *= radix; + acc + term + }) +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct DifferenceLimb { + value: u64, + borrow: bool, +} + +fn wrapping_difference_limbs( + bound: [u64; NUM_LIMBS], + limbs: &[AssignedScalar; NUM_LIMBS], +) -> [DifferenceLimb; NUM_LIMBS] { + let mut borrow = 0u128; + std::array::from_fn(|index| { + let bound = u128::from(bound[index]); + let limb = u128::from(scalar_low_u64(limbs[index].value)); + let subtrahend = limb + borrow; + let (value, next_borrow) = if bound >= subtrahend { + (bound - subtrahend, 0u128) + } else { + (bound + RADIX - subtrahend, 1u128) + }; + borrow = next_borrow; + + DifferenceLimb { + value: value as u64, + borrow: next_borrow == 1, + } + }) +} + +fn fq_to_u64_limbs(value: Fq) -> [u64; NUM_LIMBS] { + let mut bytes = [0u8; Fq::NUM_BYTES]; + value.to_bytes_le(&mut bytes); + bytes_to_u64_limbs(bytes) +} + +fn fq_from_limbs(limbs: [u64; NUM_LIMBS]) -> Fq { + Fq::from_le_bytes_mod_order(&limbs_to_bytes(limbs)) +} + +fn fr_to_u64_limbs(value: Fr) -> [u64; NUM_LIMBS] { + let mut bytes = [0u8; Fr::NUM_BYTES]; + value.to_bytes_le(&mut bytes); + bytes_to_u64_limbs(bytes) +} + +fn bytes_to_u64_limbs(bytes: [u8; 32]) -> [u64; NUM_LIMBS] { + std::array::from_fn(|index| { + let offset = index * 8; + let mut limb = [0u8; 8]; + limb.copy_from_slice(&bytes[offset..offset + 8]); + u64::from_le_bytes(limb) + }) +} + +fn scalar_low_u64(value: Fr) -> u64 { + let mut bytes = [0u8; Fr::NUM_BYTES]; + value.to_bytes_le(&mut bytes); + let mut limb = [0u8; 8]; + limb.copy_from_slice(&bytes[..8]); + u64::from_le_bytes(limb) +} + +fn scalar_low_u128(value: Fr) -> u128 { + let mut bytes = [0u8; Fr::NUM_BYTES]; + value.to_bytes_le(&mut bytes); + let mut limbs = [0u8; 16]; + limbs.copy_from_slice(&bytes[..16]); + u128::from_le_bytes(limbs) +} + +fn limbs_to_biguint(limbs: [u64; N]) -> BigUint { + BigUint::from_bytes_le(&limbs_to_bytes(limbs)) +} + +fn biguint_to_u64_limbs(value: &BigUint) -> [u64; N] { + let bytes = value.to_bytes_le(); + std::array::from_fn(|index| { + let offset = index * 8; + let mut limb = [0u8; 8]; + if offset < bytes.len() { + let available = (bytes.len() - offset).min(8); + limb[..available].copy_from_slice(&bytes[offset..offset + available]); + } + u64::from_le_bytes(limb) + }) +} + +fn limbs_to_bytes(limbs: [u64; N]) -> Vec { + let mut bytes = Vec::with_capacity(N * 8); + for limb in limbs { + bytes.extend_from_slice(&limb.to_le_bytes()); + } + bytes +} + +fn fq_modulus() -> BigUint { + limbs_to_biguint(FQ_MODULUS_LIMBS) +} + +fn convolution_terms(lhs: [u64; NUM_LIMBS], rhs: [u64; NUM_LIMBS]) -> [BigUint; MUL_LIMBS] { + let mut terms = std::array::from_fn(|_| BigUint::zero()); + for (lhs_index, lhs_limb) in lhs.into_iter().enumerate() { + for (rhs_index, rhs_limb) in rhs.into_iter().enumerate() { + terms[lhs_index + rhs_index] += BigUint::from(lhs_limb) * BigUint::from(rhs_limb); + } + } + terms +} + +fn constant_mul_terms(lhs: [u64; NUM_LIMBS], rhs: [u64; NUM_LIMBS]) -> [BigUint; MUL_LIMBS] { + let mut terms = std::array::from_fn(|_| BigUint::zero()); + for (lhs_index, lhs_limb) in lhs.into_iter().enumerate() { + for (rhs_index, rhs_limb) in rhs.into_iter().enumerate() { + terms[lhs_index + rhs_index] += BigUint::from(lhs_limb) * BigUint::from(rhs_limb); + } + } + terms +} + +fn fr_from_biguint(value: &BigUint) -> Fr { + Fr::from_le_bytes_mod_order(&value.to_bytes_le()) +} + +fn radix_fr() -> Fr { + Fr::from_u128(RADIX) +} + +fn fr(value: u64) -> Fr { + Fr::from_u64(value) +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] +mod tests { + use super::*; + use crate::Variable; + + #[test] + fn accepts_fq_values_at_canonical_edges() { + assert_fq_value_satisfies(Fq::from_u64(0)); + assert_fq_value_satisfies(fq_from_limbs(FQ_MODULUS_MINUS_ONE_LIMBS)); + } + + #[test] + fn rejects_limb_out_of_u64_range() { + let mut builder = R1csBuilder::new(); + let limbs = [ + AssignedScalar::alloc(&mut builder, Fr::from_u128(RADIX)), + AssignedScalar::alloc(&mut builder, fr(0)), + AssignedScalar::alloc(&mut builder, fr(0)), + AssignedScalar::alloc(&mut builder, fr(0)), + ]; + + let _ = FqVar::from_checked_limbs(&mut builder, limbs); + + assert!(builder_rejects(builder)); + } + + #[test] + fn rejects_fq_modulus_as_noncanonical() { + let mut builder = R1csBuilder::new(); + let limbs = [ + AssignedScalar::alloc(&mut builder, fr(4_332_616_871_279_656_263)), + AssignedScalar::alloc(&mut builder, fr(10_917_124_144_477_883_021)), + AssignedScalar::alloc(&mut builder, fr(13_281_191_951_274_694_749)), + AssignedScalar::alloc(&mut builder, fr(3_486_998_266_802_970_665)), + ]; + + let _ = FqVar::from_checked_limbs(&mut builder, limbs); + + assert!(builder_rejects(builder)); + } + + #[test] + fn injects_native_fr_as_canonical_fq_limbs() { + assert_fr_injection_satisfies(fr(0)); + assert_fr_injection_satisfies(Fr::from_u64(17)); + assert_fr_injection_satisfies(fr_from_limbs(FR_MODULUS_MINUS_ONE_LIMBS)); + } + + #[test] + fn fq_bits_le_match_canonical_integer_bits() { + let mut builder = R1csBuilder::new(); + let value = Fq::from_u64(0xdead_beef); + let fq = FqVar::alloc(&mut builder, value); + + let bits = fq.bits_le(&mut builder); + + assert_eq!(bits.len(), FqVar::NUM_LIMBS * FqVar::LIMB_BITS); + for (index, bit) in bits.iter().enumerate().take(32) { + let expected = Fr::from_bool(((0xdead_beefu64 >> index) & 1) == 1); + assert_eq!(bit.value, expected); + } + assert!(bits[32..].iter().all(|bit| bit.value.is_zero())); + assert!(builder_accepts(builder)); + } + + #[test] + fn fq_bits_le_rejects_tampered_bit() { + let mut builder = R1csBuilder::new(); + let fq = FqVar::alloc(&mut builder, Fq::from_u64(0xdead_beef)); + let bits = fq.bits_le(&mut builder); + let bit = variable(&bits[0]); + let mut witness = builder.witness().expect("witness is assigned"); + witness[bit.index()] += fr(1); + + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + #[test] + fn rejects_tampered_fr_injection_limb() { + let mut builder = R1csBuilder::new(); + let value = AssignedScalar::alloc(&mut builder, Fr::from_u64(17)); + let fq = FqVar::inject_fr(&mut builder, &value); + let first_limb = variable(&fq.limbs()[0]); + let mut witness = builder.witness().expect("witness is assigned"); + witness[first_limb.index()] += fr(1); + + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + #[test] + fn fr_challenge_injection_uses_canonical_integer_embedding() { + assert_fr_challenge_injection_satisfies(fr(0)); + assert_fr_challenge_injection_satisfies(Fr::from_u64(99)); + assert_fr_challenge_injection_satisfies(fr_from_limbs(FR_MODULUS_MINUS_ONE_LIMBS)); + } + + #[test] + fn fr_challenge_injection_rejects_tampered_source_challenge() { + let mut builder = R1csBuilder::new(); + let challenge = AssignedScalar::alloc(&mut builder, Fr::from_u64(99)); + let _ = FqVar::inject_fr_challenge(&mut builder, &challenge); + let challenge_variable = variable(&challenge); + let mut witness = builder.witness().expect("witness is assigned"); + witness[challenge_variable.index()] += fr(1); + + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + #[test] + fn fr_challenge_injection_rejects_tampered_fq_limb() { + let mut builder = R1csBuilder::new(); + let challenge = AssignedScalar::alloc(&mut builder, Fr::from_u64(99)); + let fq = FqVar::inject_fr_challenge(&mut builder, &challenge); + let first_limb = variable(&fq.limbs()[0]); + let mut witness = builder.witness().expect("witness is assigned"); + witness[first_limb.index()] += fr(1); + + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + #[test] + fn fq_arithmetic_completeness_matches_native_field() { + let values = interesting_fq_values(); + for &lhs in &values { + for &rhs in &values { + let mut builder = R1csBuilder::new(); + let lhs_var = FqVar::alloc(&mut builder, lhs); + let rhs_var = FqVar::alloc(&mut builder, rhs); + + let sum = lhs_var.add(&mut builder, &rhs_var); + let difference = lhs_var.sub(&mut builder, &rhs_var); + let product = lhs_var.mul(&mut builder, &rhs_var); + + assert_eq!(sum.value(), lhs + rhs); + assert_eq!(difference.value(), lhs - rhs); + assert_eq!(product.value(), lhs * rhs); + assert!(builder_accepts(builder)); + } + } + } + + #[test] + fn fq_inverse_completeness_matches_native_field() { + for value in interesting_fq_values() + .into_iter() + .filter(|value| *value != Fq::from_u64(0)) + { + let mut builder = R1csBuilder::new(); + let value_var = FqVar::alloc(&mut builder, value); + let inverse = value_var + .inverse(&mut builder) + .expect("nonzero field element has an inverse"); + + assert_eq!(value * inverse.value(), Fq::from_u64(1)); + assert!(builder_accepts(builder)); + } + } + + #[test] + fn fq_inverse_returns_none_for_zero() { + let mut builder = R1csBuilder::new(); + let zero = FqVar::alloc(&mut builder, Fq::from_u64(0)); + + assert!(zero.inverse(&mut builder).is_none()); + } + + #[test] + fn fq_select_completeness_matches_selector() { + let mut builder = R1csBuilder::new(); + let when_true = FqVar::alloc(&mut builder, Fq::from_u64(11)); + let when_false = FqVar::alloc(&mut builder, fq_from_limbs(FQ_MODULUS_MINUS_ONE_LIMBS)); + let true_selector = AssignedScalar::alloc(&mut builder, fr(1)); + let false_selector = AssignedScalar::alloc(&mut builder, fr(0)); + + let selected_true = FqVar::select(&mut builder, &true_selector, &when_true, &when_false); + let selected_false = FqVar::select(&mut builder, &false_selector, &when_true, &when_false); + + assert_eq!(selected_true.value(), when_true.value()); + assert_eq!(selected_false.value(), when_false.value()); + assert!(builder_accepts(builder)); + } + + #[test] + fn fq_select_rejects_non_boolean_selector() { + let mut builder = R1csBuilder::new(); + let when_true = FqVar::alloc(&mut builder, Fq::from_u64(11)); + let when_false = FqVar::alloc(&mut builder, Fq::from_u64(7)); + let selector = AssignedScalar::alloc(&mut builder, fr(2)); + + let _ = FqVar::select(&mut builder, &selector, &when_true, &when_false); + + assert!(builder_rejects(builder)); + } + + #[test] + fn fq_add_soundness_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::new(); + let lhs = FqVar::alloc(&mut builder, fq_from_limbs(FQ_MODULUS_MINUS_ONE_LIMBS)); + let rhs = FqVar::alloc(&mut builder, Fq::from_u64(9)); + + let _ = lhs.add(&mut builder, &rhs); + + assert_single_variable_tampering_rejected(builder); + } + + #[test] + fn fq_sub_soundness_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::new(); + let lhs = FqVar::alloc(&mut builder, Fq::from_u64(3)); + let rhs = FqVar::alloc(&mut builder, Fq::from_u64(9)); + + let _ = lhs.sub(&mut builder, &rhs); + + assert_single_variable_tampering_rejected(builder); + } + + #[test] + fn fq_mul_soundness_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::new(); + let lhs = FqVar::alloc(&mut builder, fq_from_limbs(FQ_MODULUS_MINUS_ONE_LIMBS)); + let rhs = FqVar::alloc(&mut builder, fq_from_limbs([17, 9_876_543_210, 11, 3])); + + let _ = lhs.mul(&mut builder, &rhs); + + assert_single_variable_tampering_rejected(builder); + } + + #[test] + fn fq_inverse_soundness_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::new(); + let value = FqVar::alloc(&mut builder, fq_from_limbs([17, 9_876_543_210, 11, 3])); + + let _ = value + .inverse(&mut builder) + .expect("nonzero field element has an inverse"); + + assert_single_variable_tampering_rejected(builder); + } + + #[test] + fn fq_select_soundness_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::new(); + let when_true = FqVar::alloc(&mut builder, fq_from_limbs([17, 9_876_543_210, 11, 3])); + let when_false = FqVar::alloc(&mut builder, Fq::from_u64(7)); + let selector = AssignedScalar::alloc(&mut builder, fr(1)); + + let _ = FqVar::select(&mut builder, &selector, &when_true, &when_false); + + assert_single_variable_tampering_rejected(builder); + } + + fn assert_fq_value_satisfies(value: Fq) { + let mut builder = R1csBuilder::new(); + let fq = FqVar::alloc(&mut builder, value); + + assert_eq!( + fq.limbs() + .iter() + .map(|limb| scalar_low_u64(limb.value)) + .collect::>(), + fq_to_u64_limbs(value) + ); + assert!(builder_accepts(builder)); + } + + fn assert_fr_injection_satisfies(value: Fr) { + let mut builder = R1csBuilder::new(); + let scalar = AssignedScalar::alloc(&mut builder, value); + let fq = FqVar::inject_fr(&mut builder, &scalar); + + assert_eq!( + fq.limbs() + .iter() + .map(|limb| scalar_low_u64(limb.value)) + .collect::>(), + fr_to_u64_limbs(value) + ); + assert!(builder_accepts(builder)); + } + + fn assert_fr_challenge_injection_satisfies(value: Fr) { + let mut builder = R1csBuilder::new(); + let challenge = AssignedScalar::alloc(&mut builder, value); + let fq = FqVar::inject_fr_challenge(&mut builder, &challenge); + + assert_eq!( + fq.limbs() + .iter() + .map(|limb| scalar_low_u64(limb.value)) + .collect::>(), + fr_to_u64_limbs(value) + ); + assert!(builder_accepts(builder)); + } + + fn builder_accepts(builder: R1csBuilder) -> bool { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn builder_rejects(builder: R1csBuilder) -> bool { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_err() + } + + fn assert_single_variable_tampering_rejected(builder: R1csBuilder) { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for index in 1..witness.len() { + let mut tampered = witness.clone(); + tampered[index] += fr(1); + assert!( + matrices.check_witness(&tampered).is_err(), + "variable {index} accepted after single-variable tampering" + ); + } + } + + fn interesting_fq_values() -> Vec { + vec![ + Fq::from_u64(0), + Fq::from_u64(1), + Fq::from_u64(2), + Fq::from_u64(17), + fq_from_limbs([123, 456, 789, 101_112]), + fq_from_limbs([ + FQ_MODULUS_MINUS_ONE_LIMBS[0] - 1, + FQ_MODULUS_MINUS_ONE_LIMBS[1], + FQ_MODULUS_MINUS_ONE_LIMBS[2], + FQ_MODULUS_MINUS_ONE_LIMBS[3], + ]), + fq_from_limbs(FQ_MODULUS_MINUS_ONE_LIMBS), + ] + } + + fn fq_from_limbs(limbs: [u64; NUM_LIMBS]) -> Fq { + Fq::from_le_bytes_mod_order(&limbs_to_bytes(limbs)) + } + + fn fr_from_limbs(limbs: [u64; NUM_LIMBS]) -> Fr { + Fr::from_le_bytes_mod_order(&limbs_to_bytes(limbs)) + } + + fn limbs_to_bytes(limbs: [u64; NUM_LIMBS]) -> [u8; 32] { + let mut bytes = [0u8; 32]; + for (index, limb) in limbs.into_iter().enumerate() { + bytes[index * 8..index * 8 + 8].copy_from_slice(&limb.to_le_bytes()); + } + bytes + } + + fn variable(value: &AssignedScalar) -> Variable { + assert_eq!(value.lc.terms.len(), 1); + let (variable, coefficient) = value + .lc + .terms + .first() + .copied() + .expect("linear combination has one term"); + assert_eq!(coefficient, fr(1)); + variable + } +} diff --git a/crates/jolt-r1cs/src/scalar.rs b/crates/jolt-r1cs/src/scalar.rs new file mode 100644 index 0000000000..f560b70b98 --- /dev/null +++ b/crates/jolt-r1cs/src/scalar.rs @@ -0,0 +1,443 @@ +//! Shared scalar helpers for R1CS verifier equations. +//! +//! These helpers let verifier equations be written once over either native +//! builder-field scalars or non-native scalars represented inside the builder. + +use jolt_field::{Field, Fq, Fr}; +use jolt_poly::r1cs::PolynomialScalarGadget; +use num_traits::{One, Zero}; + +use crate::{AssignedScalar, FqVar, LinearCombination, R1csBuilder}; + +pub trait ScalarGadget: Clone { + type BuilderField: Field; + type Scalar: Field; + + fn constant(scalar: Self::Scalar) -> Self; + fn alloc(builder: &mut R1csBuilder, scalar: Self::Scalar) -> Self; + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self); + fn add(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self; + fn sub(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self; + fn mul(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self; + fn scale_by_constant( + &self, + builder: &mut R1csBuilder, + scalar: Self::Scalar, + ) -> Self; + fn select( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + when_true: &Self, + when_false: &Self, + ) -> Self; +} + +pub fn scalar_affine_combination<'a, S>( + builder: &mut R1csBuilder, + constant: S::Scalar, + terms: impl IntoIterator, +) -> S +where + S: ScalarGadget + 'a, +{ + let mut result = S::constant(constant); + for (coefficient, scalar) in terms { + let term = scalar.scale_by_constant(builder, coefficient); + result = result.add(builder, &term); + } + result +} + +pub fn scalar_dot_product<'a, S>( + builder: &mut R1csBuilder, + terms: impl IntoIterator, +) -> S +where + S: ScalarGadget + 'a, +{ + let mut result = S::constant(S::Scalar::zero()); + for (lhs, rhs) in terms { + let term = lhs.mul(builder, rhs); + result = result.add(builder, &term); + } + result +} + +impl ScalarGadget for AssignedScalar +where + F: Field, +{ + type BuilderField = F; + type Scalar = F; + + fn constant(scalar: Self::Scalar) -> Self { + AssignedScalar::constant(scalar) + } + + fn alloc(builder: &mut R1csBuilder, scalar: Self::Scalar) -> Self { + AssignedScalar::alloc(builder, scalar) + } + + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + builder.assert_equal(self.lc.clone(), rhs.lc.clone()); + } + + fn add(&self, _builder: &mut R1csBuilder, rhs: &Self) -> Self { + Self::new(self.value + rhs.value, self.lc.clone() + rhs.lc.clone()) + } + + fn sub(&self, _builder: &mut R1csBuilder, rhs: &Self) -> Self { + Self::new(self.value - rhs.value, self.lc.clone() - rhs.lc.clone()) + } + + fn mul(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + Self::new( + self.value * rhs.value, + builder.multiply(self.lc.clone(), rhs.lc.clone()), + ) + } + + fn scale_by_constant( + &self, + _builder: &mut R1csBuilder, + scalar: Self::Scalar, + ) -> Self { + self.clone().scale(scalar) + } + + fn select( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + when_true: &Self, + when_false: &Self, + ) -> Self { + assert_boolean(builder, selector); + + let selected_delta = builder.multiply( + selector.lc.clone(), + when_true.lc.clone() - when_false.lc.clone(), + ); + Self::new( + when_false.value + selector.value * (when_true.value - when_false.value), + when_false.lc.clone() + selected_delta, + ) + } +} + +impl ScalarGadget for FqVar { + type BuilderField = Fr; + type Scalar = Fq; + + fn constant(scalar: Self::Scalar) -> Self { + FqVar::constant(scalar) + } + + fn alloc(builder: &mut R1csBuilder, scalar: Self::Scalar) -> Self { + FqVar::alloc(builder, scalar) + } + + fn assert_equal(&self, builder: &mut R1csBuilder, rhs: &Self) { + self.assert_equal(builder, rhs); + } + + fn add(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + self.add(builder, rhs) + } + + fn sub(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + self.sub(builder, rhs) + } + + fn mul(&self, builder: &mut R1csBuilder, rhs: &Self) -> Self { + self.mul(builder, rhs) + } + + fn scale_by_constant( + &self, + builder: &mut R1csBuilder, + scalar: Self::Scalar, + ) -> Self { + if scalar.is_zero() { + Self::constant(Fq::zero()) + } else if scalar == Fq::one() { + self.clone() + } else { + self.mul(builder, &Self::constant(scalar)) + } + } + + fn select( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + when_true: &Self, + when_false: &Self, + ) -> Self { + FqVar::select(builder, selector, when_true, when_false) + } +} + +impl PolynomialScalarGadget for AssignedScalar +where + F: Field, +{ + type ConstraintBuilder = R1csBuilder; + type Scalar = F; + + fn constant(scalar: Self::Scalar) -> Self { + ::constant(scalar) + } + + fn add(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::add(self, builder, rhs) + } + + fn sub(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::sub(self, builder, rhs) + } + + fn mul(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::mul(self, builder, rhs) + } +} + +impl PolynomialScalarGadget for FqVar { + type ConstraintBuilder = R1csBuilder; + type Scalar = Fq; + + fn constant(scalar: Self::Scalar) -> Self { + ::constant(scalar) + } + + fn add(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::add(self, builder, rhs) + } + + fn sub(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::sub(self, builder, rhs) + } + + fn mul(&self, builder: &mut Self::ConstraintBuilder, rhs: &Self) -> Self { + ::mul(self, builder, rhs) + } +} + +fn assert_boolean(builder: &mut R1csBuilder, value: &AssignedScalar) +where + F: Field, +{ + builder.assert_product( + value.lc.clone(), + value.lc.clone() - LinearCombination::one(), + LinearCombination::zero(), + ); +} + +#[cfg(test)] +#[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] +mod tests { + use jolt_field::{Fq, Fr, FromPrimitiveInt}; + + use super::*; + use crate::Variable; + + #[test] + fn native_scalar_relation_accepts_valid_witnesses() { + for selector in [false, true] { + let mut builder = R1csBuilder::::new(); + let x = AssignedScalar::alloc(&mut builder, Fr::from_u64(9)); + let y = AssignedScalar::alloc(&mut builder, Fr::from_u64(12)); + let z = AssignedScalar::alloc(&mut builder, Fr::from_u64(5)); + let w = AssignedScalar::alloc(&mut builder, Fr::from_u64(4)); + let selector = AssignedScalar::alloc(&mut builder, Fr::from_bool(selector)); + + let result = shared_relation(&mut builder, &selector, &x, &y, &z, &w); + result.assert_equal( + &mut builder, + &AssignedScalar::constant(expected_relation( + selector.value == Fr::one(), + x.value, + y.value, + z.value, + w.value, + )), + ); + + assert!(builder_accepts(builder)); + } + } + + #[test] + fn native_scalar_relation_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::::new(); + let x = AssignedScalar::alloc(&mut builder, Fr::from_u64(9)); + let y = AssignedScalar::alloc(&mut builder, Fr::from_u64(12)); + let z = AssignedScalar::alloc(&mut builder, Fr::from_u64(5)); + let w = AssignedScalar::alloc(&mut builder, Fr::from_u64(4)); + let selector = AssignedScalar::alloc(&mut builder, Fr::one()); + + let result = shared_relation(&mut builder, &selector, &x, &y, &z, &w); + result.assert_equal( + &mut builder, + &AssignedScalar::constant(expected_relation(true, x.value, y.value, z.value, w.value)), + ); + + assert_single_variable_tampering_rejected(builder); + } + + #[test] + fn nonnative_scalar_relation_accepts_valid_witnesses() { + for selector in [false, true] { + let mut builder = R1csBuilder::::new(); + let x = FqVar::alloc(&mut builder, Fq::from_u64(9)); + let y = FqVar::alloc(&mut builder, Fq::from_u64(12)); + let z = FqVar::alloc(&mut builder, Fq::from_u64(5)); + let w = FqVar::alloc(&mut builder, Fq::from_u64(4)); + let selector = AssignedScalar::alloc(&mut builder, Fr::from_bool(selector)); + + let result = shared_relation(&mut builder, &selector, &x, &y, &z, &w); + result.assert_equal( + &mut builder, + &FqVar::constant(expected_relation( + selector.value == Fr::one(), + Fq::from_u64(9), + Fq::from_u64(12), + Fq::from_u64(5), + Fq::from_u64(4), + )), + ); + + assert!(builder_accepts(builder)); + } + } + + #[test] + fn nonnative_scalar_relation_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::::new(); + let x = FqVar::alloc(&mut builder, Fq::from_u64(9)); + let y = FqVar::alloc(&mut builder, Fq::from_u64(12)); + let z = FqVar::alloc(&mut builder, Fq::from_u64(5)); + let w = FqVar::alloc(&mut builder, Fq::from_u64(4)); + let selector = AssignedScalar::alloc(&mut builder, Fr::one()); + + let result = shared_relation(&mut builder, &selector, &x, &y, &z, &w); + result.assert_equal( + &mut builder, + &FqVar::constant(expected_relation( + true, + Fq::from_u64(9), + Fq::from_u64(12), + Fq::from_u64(5), + Fq::from_u64(4), + )), + ); + + assert_variable_tampering_rejected( + builder, + [ + variable(&selector), + variable(&x.limbs()[0]), + variable(&result.limbs()[0]), + ], + ); + } + + fn shared_relation( + builder: &mut R1csBuilder, + selector: &AssignedScalar, + x: &S, + y: &S, + z: &S, + w: &S, + ) -> S + where + S: ScalarGadget, + { + let dot_product = scalar_dot_product(builder, [(x, y), (z, w)]); + let true_branch = scalar_affine_combination( + builder, + S::Scalar::from_u64(7), + [ + (S::Scalar::from_u64(3), &dot_product), + (S::Scalar::from_u64(5), x), + ], + ); + + let product = y.mul(builder, z); + let difference = product.sub(builder, w); + let false_branch = scalar_affine_combination( + builder, + -S::Scalar::from_u64(11), + [(S::Scalar::from_u64(2), &difference), (S::Scalar::one(), z)], + ); + + S::select(builder, selector, &true_branch, &false_branch) + } + + fn expected_relation(selector: bool, x: F, y: F, z: F, w: F) -> F + where + F: Field, + { + let dot_product = x * y + z * w; + let true_branch = F::from_u64(7) + F::from_u64(3) * dot_product + F::from_u64(5) * x; + let false_branch = -F::from_u64(11) + F::from_u64(2) * (y * z - w) + z; + if selector { + true_branch + } else { + false_branch + } + } + + fn builder_accepts(builder: R1csBuilder) -> bool + where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn assert_single_variable_tampering_rejected(builder: R1csBuilder) + where + F: Field, + { + let variables = (1..builder.num_vars()) + .map(Variable::new) + .collect::>(); + assert_variable_tampering_rejected(builder, variables); + } + + fn assert_variable_tampering_rejected( + builder: R1csBuilder, + variables: impl IntoIterator, + ) where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for variable in variables { + let mut tampered = witness.clone(); + let index = variable.index(); + tampered[index] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "variable {index} accepted after single-variable tampering" + ); + } + } + + fn variable(scalar: &AssignedScalar) -> Variable + where + F: Field, + { + assert_eq!(scalar.lc.terms.len(), 1); + let (variable, coefficient) = scalar + .lc + .terms + .first() + .copied() + .expect("linear combination has one term"); + assert_eq!(coefficient, F::one()); + variable + } +} diff --git a/crates/jolt-r1cs/tests/poly_r1cs.rs b/crates/jolt-r1cs/tests/poly_r1cs.rs new file mode 100644 index 0000000000..7ce35bc045 --- /dev/null +++ b/crates/jolt-r1cs/tests/poly_r1cs.rs @@ -0,0 +1,191 @@ +#![expect(clippy::expect_used, reason = "integration tests may fail by panic")] + +use jolt_field::{Fq, Fr, FromPrimitiveInt}; +use jolt_poly::{ + r1cs::{eq_evals, multilinear_eval}, + EqPolynomial, Polynomial, +}; +use jolt_r1cs::{AssignedScalar, FqVar, R1csBuilder, Variable}; + +#[test] +fn native_poly_r1cs_multilinear_eval_accepts_valid_witness() { + let mut builder = R1csBuilder::::new(); + let evaluation_values = (0..8) + .map(|index| Fr::from_u64((3 * index + 2) as u64)) + .collect::>(); + let point_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let evaluations = evaluation_values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + let point = point_values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + + let result = multilinear_eval(&mut builder, &evaluations, &point).expect("evaluation succeeds"); + let expected = Polynomial::new(evaluation_values).evaluate(&point_values); + builder.assert_equal(result.lc, AssignedScalar::constant(expected).lc); + + assert!(builder_accepts(builder)); +} + +#[test] +fn native_poly_r1cs_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::::new(); + let evaluation_values = (0..8) + .map(|index| Fr::from_u64((5 * index + 1) as u64)) + .collect::>(); + let point_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let evaluations = evaluation_values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + let point = point_values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + + let result = multilinear_eval(&mut builder, &evaluations, &point).expect("evaluation succeeds"); + let expected = Polynomial::new(evaluation_values).evaluate(&point_values); + builder.assert_equal(result.lc, AssignedScalar::constant(expected).lc); + + assert_single_variable_tampering_rejected(builder); +} + +#[test] +fn nonnative_poly_r1cs_multilinear_eval_accepts_valid_witness() { + let mut builder = R1csBuilder::::new(); + let evaluation_values = (0..8) + .map(|index| Fq::from_u64((7 * index + 4) as u64)) + .collect::>(); + let point_values = [Fq::from_u64(2), Fq::from_u64(3), Fq::from_u64(5)]; + let evaluations = evaluation_values + .iter() + .copied() + .map(|value| FqVar::alloc(&mut builder, value)) + .collect::>(); + let point = point_values + .iter() + .copied() + .map(|value| FqVar::alloc(&mut builder, value)) + .collect::>(); + + let result = multilinear_eval(&mut builder, &evaluations, &point).expect("evaluation succeeds"); + let expected = Polynomial::new(evaluation_values).evaluate(&point_values); + result.assert_equal(&mut builder, &FqVar::constant(expected)); + + assert!(builder_accepts(builder)); +} + +#[test] +fn nonnative_poly_r1cs_rejects_single_variable_tampering() { + let mut builder = R1csBuilder::::new(); + let evaluation_values = (0..8) + .map(|index| Fq::from_u64((11 * index + 4) as u64)) + .collect::>(); + let point_values = [Fq::from_u64(2), Fq::from_u64(3), Fq::from_u64(5)]; + let evaluations = evaluation_values + .iter() + .copied() + .map(|value| FqVar::alloc(&mut builder, value)) + .collect::>(); + let point = point_values + .iter() + .copied() + .map(|value| FqVar::alloc(&mut builder, value)) + .collect::>(); + + let result = multilinear_eval(&mut builder, &evaluations, &point).expect("evaluation succeeds"); + let expected = Polynomial::new(evaluation_values).evaluate(&point_values); + result.assert_equal(&mut builder, &FqVar::constant(expected)); + + let tamper_targets = [ + variable(&evaluations[0].limbs()[0]).index(), + variable(&point[0].limbs()[0]).index(), + variable(&result.limbs()[0]).index(), + ]; + assert_selected_variable_tampering_rejected(builder, &tamper_targets); +} + +#[test] +fn native_eq_table_order_matches_jolt_poly_plain_order() { + let mut builder = R1csBuilder::::new(); + let point_values = [Fr::from_u64(2), Fr::from_u64(3), Fr::from_u64(5)]; + let point = point_values + .iter() + .copied() + .map(|value| AssignedScalar::alloc(&mut builder, value)) + .collect::>(); + + let actual = eq_evals(&mut builder, &point); + let expected = EqPolynomial::new(point_values.to_vec()).evaluations(); + for (actual, expected) in actual.into_iter().zip(expected) { + builder.assert_equal(actual.lc, AssignedScalar::constant(expected).lc); + } + + assert!(builder_accepts(builder)); +} + +fn builder_accepts(builder: R1csBuilder) -> bool +where + F: jolt_field::Field, +{ + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() +} + +fn assert_single_variable_tampering_rejected(builder: R1csBuilder) +where + F: jolt_field::Field, +{ + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for variable in 1..witness.len() { + let mut tampered = witness.clone(); + tampered[variable] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "tampering variable {variable} was accepted" + ); + } +} + +fn assert_selected_variable_tampering_rejected(builder: R1csBuilder, targets: &[usize]) +where + F: jolt_field::Field, +{ + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for &variable in targets { + let mut tampered = witness.clone(); + tampered[variable] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "tampering variable {variable} was accepted" + ); + } +} + +fn variable(scalar: &AssignedScalar) -> Variable +where + F: jolt_field::Field, +{ + assert_eq!(scalar.lc.terms.len(), 1); + let (variable, coefficient) = scalar + .lc + .terms + .first() + .copied() + .expect("assigned scalar should be backed by one variable"); + assert_eq!(coefficient, F::one()); + variable +} diff --git a/crates/jolt-sumcheck/src/lib.rs b/crates/jolt-sumcheck/src/lib.rs index 601ce234f7..ea8a639a37 100644 --- a/crates/jolt-sumcheck/src/lib.rs +++ b/crates/jolt-sumcheck/src/lib.rs @@ -112,8 +112,9 @@ pub use proof::{ClearProof, ClearSumcheckProof, CompressedSumcheckProof, Sumchec #[cfg(feature = "r1cs")] pub use r1cs::{ allocate_sumcheck_r1cs_layout, append_sumcheck_r1cs_constraints, - append_sumcheck_r1cs_constraints_for_domain, SumcheckR1csError, SumcheckR1csLayout, - SumcheckR1csRound, SumcheckR1csRoundLayout, + append_sumcheck_r1cs_constraints_for_domain, append_sumcheck_r1cs_gadget_constraints, + append_sumcheck_r1cs_gadget_constraints_for_domain, SumcheckR1csError, SumcheckR1csGadgetRound, + SumcheckR1csLayout, SumcheckR1csRound, SumcheckR1csRoundLayout, }; pub use round_proof::{ClearRound, CompressedLabeledRoundPoly, LabeledRoundPoly, RoundMessage}; pub use scalar::SumcheckScalar; diff --git a/crates/jolt-sumcheck/src/r1cs.rs b/crates/jolt-sumcheck/src/r1cs.rs index c1214987b9..9b7f57cd23 100644 --- a/crates/jolt-sumcheck/src/r1cs.rs +++ b/crates/jolt-sumcheck/src/r1cs.rs @@ -1,21 +1,23 @@ -use jolt_field::Field; -use jolt_r1cs::{LinearCombination, R1csBuilder, Variable}; +use jolt_field::{Field, FromPrimitiveInt}; +use jolt_r1cs::{ + scalar_affine_combination, LinearCombination, R1csBuilder, ScalarGadget, Variable, +}; use thiserror::Error; use crate::{BooleanHypercube, SumcheckDomain, SumcheckStatement, VerifiedCommittedRound}; -pub trait SumcheckR1csRound { +pub trait SumcheckR1csRound { fn degree(&self) -> usize; - fn challenge(&self) -> F; + fn challenge(&self) -> LinearCombination; } -impl SumcheckR1csRound for VerifiedCommittedRound { +impl SumcheckR1csRound for VerifiedCommittedRound { fn degree(&self) -> usize { self.degree } - fn challenge(&self) -> F { - self.challenge + fn challenge(&self) -> LinearCombination { + LinearCombination::constant(self.challenge) } } @@ -33,6 +35,8 @@ pub enum SumcheckR1csError { LayoutRoundCountMismatch { expected: usize, actual: usize }, #[error("round {round_index} layout has no coefficient variables")] EmptyRoundLayout { round_index: usize }, + #[error("round {round_index} has no coefficient gadgets")] + EmptyRoundCoefficients { round_index: usize }, #[error( "round {round_index} layout has degree {actual} but sumcheck input has degree {expected}" )] @@ -86,6 +90,32 @@ pub struct SumcheckR1csRoundLayout { pub claim_out: Variable, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SumcheckR1csGadgetRound { + pub claim_in: S, + pub coefficients: Vec, + pub challenge: S, + pub claim_out: S, +} + +impl SumcheckR1csGadgetRound +where + S: ScalarGadget, +{ + pub fn new(claim_in: S, coefficients: Vec, challenge: S, claim_out: S) -> Self { + Self { + claim_in, + coefficients, + challenge, + claim_out, + } + } + + pub fn degree(&self) -> usize { + self.coefficients.len().saturating_sub(1) + } +} + impl SumcheckR1csRoundLayout { pub fn degree(&self) -> usize { self.coefficients.len().saturating_sub(1) @@ -183,6 +213,46 @@ where Ok(()) } +pub fn append_sumcheck_r1cs_gadget_constraints( + builder: &mut R1csBuilder, + statement: SumcheckStatement, + rounds: &[SumcheckR1csGadgetRound], +) -> Result<(), SumcheckR1csError> +where + S: ScalarGadget, + BooleanHypercube: SumcheckDomain, +{ + append_sumcheck_r1cs_gadget_constraints_for_domain(builder, statement, rounds, BooleanHypercube) +} + +pub fn append_sumcheck_r1cs_gadget_constraints_for_domain( + builder: &mut R1csBuilder, + statement: SumcheckStatement, + rounds: &[SumcheckR1csGadgetRound], + domain: D, +) -> Result<(), SumcheckR1csError> +where + S: ScalarGadget, + D: SumcheckDomain, +{ + validate_gadget_rounds(statement, rounds)?; + + for (round_index, round) in rounds.iter().enumerate() { + let round_sum_coefficients = + domain + .round_sum_coefficients(round.degree()) + .map_err( + |source| SumcheckR1csError::RoundSumCoefficientsUnavailable { + round_index, + reason: source.to_string(), + }, + )?; + append_gadget_round_constraints(builder, round_index, round, &round_sum_coefficients)?; + } + + Ok(()) +} + fn validate_layout( num_vars: usize, statement: SumcheckStatement, @@ -250,6 +320,7 @@ fn validate_rounds_statement( rounds: &[R], ) -> Result<(), SumcheckR1csError> where + F: Field, R: SumcheckR1csRound, { if statement.num_vars != rounds.len() { @@ -272,6 +343,36 @@ where Ok(()) } +fn validate_gadget_rounds( + statement: SumcheckStatement, + rounds: &[SumcheckR1csGadgetRound], +) -> Result<(), SumcheckR1csError> +where + S: ScalarGadget, +{ + if statement.num_vars != rounds.len() { + return Err(SumcheckR1csError::WrongNumberOfRounds { + expected: statement.num_vars, + actual: rounds.len(), + }); + } + + for (round_index, round) in rounds.iter().enumerate() { + if round.coefficients.is_empty() { + return Err(SumcheckR1csError::EmptyRoundCoefficients { round_index }); + } + if round.degree() > statement.degree { + return Err(SumcheckR1csError::DegreeBoundExceeded { + round_index, + bound: statement.degree, + actual: round.degree(), + }); + } + } + + Ok(()) +} + fn validate_variable(variable: Variable, num_vars: usize) -> Result<(), SumcheckR1csError> { if variable.index() >= num_vars { return Err(SumcheckR1csError::LayoutVariableOutOfBounds { variable, num_vars }); @@ -284,15 +385,46 @@ fn append_round_constraints( builder: &mut R1csBuilder, round_index: usize, round: &SumcheckR1csRoundLayout, - challenge: F, + challenge: LinearCombination, round_sum_coefficients: &[F], ) -> Result<(), SumcheckR1csError> { let round_sum = round_sum_lc(round_index, round, round_sum_coefficients)?; builder.assert_equal(round_sum, round.claim_in); - builder.assert_equal( - polynomial_eval_lc(&round.coefficients, challenge), - round.claim_out, + let round_eval = polynomial_eval_at_challenge(builder, &round.coefficients, challenge); + builder.assert_equal(round_eval, round.claim_out); + Ok(()) +} + +fn append_gadget_round_constraints( + builder: &mut R1csBuilder, + round_index: usize, + round: &SumcheckR1csGadgetRound, + round_sum_coefficients: &[S::Scalar], +) -> Result<(), SumcheckR1csError> +where + S: ScalarGadget, +{ + if round_sum_coefficients.len() != round.coefficients.len() { + return Err(SumcheckR1csError::RoundSumCoefficientCountMismatch { + round_index, + expected: round.coefficients.len(), + actual: round_sum_coefficients.len(), + }); + } + + let round_sum = scalar_affine_combination( + builder, + S::Scalar::from_u64(0), + round_sum_coefficients + .iter() + .copied() + .zip(&round.coefficients), ); + round_sum.assert_equal(builder, &round.claim_in); + + let round_eval = polynomial_eval_gadget(builder, &round.coefficients, &round.challenge); + round_eval.assert_equal(builder, &round.claim_out); + Ok(()) } @@ -329,11 +461,64 @@ fn polynomial_eval_lc(coefficients: &[Variable], point: F) -> LinearCo result } +fn polynomial_eval_at_challenge( + builder: &mut R1csBuilder, + coefficients: &[Variable], + point: LinearCombination, +) -> LinearCombination { + if let Some(point) = point.as_constant() { + return polynomial_eval_lc(coefficients, point); + } + + let Some((&last, rest)) = coefficients.split_last() else { + return LinearCombination::zero(); + }; + + let mut evaluation = LinearCombination::variable(last); + for &coefficient in rest.iter().rev() { + evaluation = + builder.multiply(evaluation, point.clone()) + LinearCombination::variable(coefficient); + } + evaluation +} + +fn polynomial_eval_gadget( + builder: &mut R1csBuilder, + coefficients: &[S], + point: &S, +) -> S +where + S: ScalarGadget, +{ + let Some((last, rest)) = coefficients.split_last() else { + return S::constant(S::Scalar::from_u64(0)); + }; + + let mut evaluation = last.clone(); + for coefficient in rest.iter().rev() { + evaluation = evaluation.mul(builder, point); + evaluation = evaluation.add(builder, coefficient); + } + evaluation +} + #[cfg(test)] #[expect(clippy::expect_used, reason = "tests may panic on assertion failures")] mod tests { use super::*; - use jolt_field::{Fr, FromPrimitiveInt}; + use jolt_field::{Fq, Fr, FromPrimitiveInt}; + use jolt_r1cs::{AssignedScalar, FqVar}; + + type NativeGadgetRoundFixture = ( + R1csBuilder, + Vec<(&'static str, Variable)>, + SumcheckR1csGadgetRound>, + ); + type NonnativeGadgetRoundFixture = ( + R1csBuilder, + Vec<(&'static str, Variable)>, + SumcheckR1csGadgetRound, + ); #[derive(Clone, Copy, Debug, PartialEq, Eq)] struct Round { @@ -346,8 +531,24 @@ mod tests { self.degree } - fn challenge(&self) -> Fr { - self.challenge + fn challenge(&self) -> LinearCombination { + LinearCombination::constant(self.challenge) + } + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct LinearChallengeRound { + degree: usize, + challenge: LinearCombination, + } + + impl SumcheckR1csRound for LinearChallengeRound { + fn degree(&self) -> usize { + self.degree + } + + fn challenge(&self) -> LinearCombination { + self.challenge.clone() } } @@ -376,6 +577,68 @@ mod tests { assign(builder, round.claim_out, claim_out); } + fn native_gadget_round( + input_claim: u64, + coefficients: &[u64], + challenge: u64, + output_claim: u64, + ) -> NativeGadgetRoundFixture { + let mut builder = R1csBuilder::::new(); + let input_claim = AssignedScalar::alloc(&mut builder, Fr::from_u64(input_claim)); + let coefficients = coefficients + .iter() + .map(|&coefficient| AssignedScalar::alloc(&mut builder, Fr::from_u64(coefficient))) + .collect::>(); + let challenge = AssignedScalar::alloc(&mut builder, Fr::from_u64(challenge)); + let output_claim = AssignedScalar::alloc(&mut builder, Fr::from_u64(output_claim)); + let tamper_targets = std::iter::once(("input claim", variable(&input_claim))) + .chain(coefficients.iter().enumerate().map(|(index, coefficient)| { + let label = match index { + 0 => "coefficient 0", + 1 => "coefficient 1", + _ => "coefficient", + }; + (label, variable(coefficient)) + })) + .chain([ + ("challenge", variable(&challenge)), + ("output claim", variable(&output_claim)), + ]) + .collect(); + let round = + SumcheckR1csGadgetRound::new(input_claim, coefficients, challenge, output_claim); + + (builder, tamper_targets, round) + } + + fn nonnative_gadget_round( + input_claim: u64, + coefficients: &[u64], + challenge: u64, + output_claim: u64, + ) -> NonnativeGadgetRoundFixture { + let mut builder = R1csBuilder::::new(); + let input_claim = FqVar::alloc(&mut builder, Fq::from_u64(input_claim)); + let coefficients = coefficients + .iter() + .map(|&coefficient| FqVar::alloc(&mut builder, Fq::from_u64(coefficient))) + .collect::>(); + let challenge = FqVar::alloc(&mut builder, Fq::from_u64(challenge)); + let output_claim = FqVar::alloc(&mut builder, Fq::from_u64(output_claim)); + let mut tamper_targets = vec![("input claim limb", variable(&input_claim.limbs()[0]))]; + if let Some(coefficient) = coefficients.get(1) { + tamper_targets.push(("coefficient limb", variable(&coefficient.limbs()[0]))); + } + tamper_targets.extend([ + ("challenge limb", variable(&challenge.limbs()[0])), + ("output claim limb", variable(&output_claim.limbs()[0])), + ]); + let round = + SumcheckR1csGadgetRound::new(input_claim, coefficients, challenge, output_claim); + + (builder, tamper_targets, round) + } + #[test] fn emits_satisfied_round_constraints() { let statement = SumcheckStatement::new(2, 1); @@ -395,6 +658,167 @@ mod tests { assert!(builder.into_matrices().check_witness(&witness).is_ok()); } + #[test] + fn supports_variable_challenge_transition() { + let statement = SumcheckStatement::new(1, 2); + let mut builder = R1csBuilder::::new(); + let challenge = builder.alloc(Fr::from_u64(2)); + let rounds = [LinearChallengeRound { + degree: 2, + challenge: LinearCombination::variable(challenge), + }]; + + let layout = allocate_sumcheck_r1cs_layout(&mut builder, statement, &rounds) + .expect("layout should allocate"); + assign(&mut builder, layout.input_claim, 15); + assign_round(&mut builder, &layout.rounds[0], &[3, 4, 5], 31); + append_sumcheck_r1cs_constraints(&mut builder, statement, &rounds, &layout) + .expect("constraints should build"); + + let witness = builder.witness().expect("witness is assigned"); + assert!(builder.into_matrices().check_witness(&witness).is_ok()); + } + + #[test] + fn rejects_bad_variable_challenge_transition() { + let statement = SumcheckStatement::new(1, 2); + let mut builder = R1csBuilder::::new(); + let challenge = builder.alloc(Fr::from_u64(3)); + let rounds = [LinearChallengeRound { + degree: 2, + challenge: LinearCombination::variable(challenge), + }]; + + let layout = allocate_sumcheck_r1cs_layout(&mut builder, statement, &rounds) + .expect("layout should allocate"); + assign(&mut builder, layout.input_claim, 15); + assign_round(&mut builder, &layout.rounds[0], &[3, 4, 5], 31); + append_sumcheck_r1cs_constraints(&mut builder, statement, &rounds, &layout) + .expect("constraints should build"); + + let witness = builder.witness().expect("witness is assigned"); + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + #[test] + fn native_gadget_constraints_accept_valid_sumcheck() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = native_gadget_round(15, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn native_gadget_constraints_reject_tampering() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, targets, round) = native_gadget_round(15, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn native_gadget_constraints_reject_bad_round_sum() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = native_gadget_round(16, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn native_gadget_constraints_reject_bad_output_claim() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = native_gadget_round(15, &[3, 4, 5], 2, 32); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn gadget_constraints_reject_empty_coefficients() { + let statement = SumcheckStatement::new(1, 2); + let mut builder = R1csBuilder::::new(); + let round = SumcheckR1csGadgetRound::new( + AssignedScalar::alloc(&mut builder, Fr::from_u64(0)), + Vec::new(), + AssignedScalar::alloc(&mut builder, Fr::from_u64(2)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(0)), + ); + + let error = append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect_err("empty coefficient list should be rejected"); + + assert_eq!( + error, + SumcheckR1csError::EmptyRoundCoefficients { round_index: 0 } + ); + } + + #[test] + fn nonnative_gadget_constraints_accept_valid_sumcheck() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = nonnative_gadget_round(15, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_accepts(builder)); + } + + #[test] + fn nonnative_gadget_constraints_reject_tampering() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, targets, round) = nonnative_gadget_round(15, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert_tampering_rejected(builder, targets); + } + + #[test] + fn nonnative_gadget_constraints_reject_bad_round_sum() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = nonnative_gadget_round(16, &[3, 4, 5], 2, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn nonnative_gadget_constraints_reject_bad_challenge_transition() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = nonnative_gadget_round(15, &[3, 4, 5], 3, 31); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_rejects(builder)); + } + + #[test] + fn nonnative_gadget_constraints_reject_bad_output_claim() { + let statement = SumcheckStatement::new(1, 2); + let (mut builder, _, round) = nonnative_gadget_round(15, &[3, 4, 5], 2, 32); + + append_sumcheck_r1cs_gadget_constraints(&mut builder, statement, &[round]) + .expect("constraints should build"); + + assert!(builder_rejects(builder)); + } + #[test] fn rejects_bad_round_sum() { let statement = SumcheckStatement::new(1, 1); @@ -555,4 +979,56 @@ mod tests { } ); } + + fn builder_accepts(builder: R1csBuilder) -> bool + where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_ok() + } + + fn builder_rejects(builder: R1csBuilder) -> bool + where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + builder.into_matrices().check_witness(&witness).is_err() + } + + fn assert_tampering_rejected( + builder: R1csBuilder, + targets: impl IntoIterator, + ) where + F: Field, + { + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + + for (label, variable) in targets { + let mut tampered = witness.clone(); + tampered[variable.index()] += F::one(); + assert!( + matrices.check_witness(&tampered).is_err(), + "{label} accepted after tampering variable {}", + variable.index() + ); + } + } + + fn variable(scalar: &AssignedScalar) -> Variable + where + F: Field, + { + assert_eq!(scalar.lc.terms.len(), 1); + let (variable, coefficient) = scalar + .lc + .terms + .first() + .copied() + .expect("linear combination has one term"); + assert_eq!(coefficient, F::one()); + variable + } } diff --git a/crates/jolt-transcript/Cargo.toml b/crates/jolt-transcript/Cargo.toml index dd35eb939c..2f725080d5 100644 --- a/crates/jolt-transcript/Cargo.toml +++ b/crates/jolt-transcript/Cargo.toml @@ -16,6 +16,8 @@ default = ["transcript-blake2b", "transcript-keccak", "transcript-poseidon"] transcript-blake2b = ["spongefish/blake2", "dep:blake2"] transcript-keccak = ["spongefish/keccak"] transcript-poseidon = ["dep:light-poseidon"] +r1cs = ["dep:jolt-r1cs"] +poseidon-r1cs = ["transcript-poseidon", "r1cs"] [dependencies] ark-bn254.workspace = true @@ -25,6 +27,7 @@ spongefish = { workspace = true, features = ["ark-ff"] } blake2 = { workspace = true, optional = true } digest = { workspace = true } jolt-field = { path = "../jolt-field", features = ["bn254"] } +jolt-r1cs = { path = "../jolt-r1cs", optional = true } rand.workspace = true [dev-dependencies] diff --git a/crates/jolt-transcript/src/lib.rs b/crates/jolt-transcript/src/lib.rs index 6983acbd75..c36ddbe0e7 100644 --- a/crates/jolt-transcript/src/lib.rs +++ b/crates/jolt-transcript/src/lib.rs @@ -23,6 +23,8 @@ mod legacy; #[cfg(feature = "transcript-poseidon")] mod poseidon; mod prover; +#[cfg(feature = "r1cs")] +pub mod r1cs; mod setup; mod verifier; diff --git a/crates/jolt-transcript/src/r1cs/mod.rs b/crates/jolt-transcript/src/r1cs/mod.rs new file mode 100644 index 0000000000..9efbd1ec06 --- /dev/null +++ b/crates/jolt-transcript/src/r1cs/mod.rs @@ -0,0 +1,143 @@ +//! In-circuit Fiat-Shamir transcript interfaces. +//! +//! The base trait owns transcript initialization and scalar challenge +//! production. Absorption is split by capability: algebraic transcripts absorb +//! field elements directly, while byte-oriented transcripts can later absorb +//! constrained bytes or bits. + +use jolt_field::Field; +use jolt_r1cs::{AssignedScalar, R1csBuilder}; + +#[cfg(feature = "transcript-poseidon")] +mod poseidon; + +#[cfg(feature = "transcript-poseidon")] +pub use poseidon::PoseidonR1csTranscript; + +/// Transcript operations shared by all in-circuit Fiat-Shamir backends. +pub trait R1csTranscript { + /// The challenge representation returned by this transcript. + type Challenge; + + /// Creates a transcript with the provided protocol label. + fn new(builder: &mut R1csBuilder, label: &'static [u8]) -> Self; + + /// Squeezes a scalar challenge and advances transcript state. + fn challenge_scalar(&mut self, builder: &mut R1csBuilder) -> Self::Challenge; +} + +/// In-circuit transcript backend that absorbs algebraic field elements. +pub trait R1csAlgebraicTranscript: + R1csTranscript> +{ + /// Absorbs an assigned scalar into the transcript. + fn absorb_scalar(&mut self, builder: &mut R1csBuilder, value: AssignedScalar); + + /// Absorbs a constant scalar into the transcript. + fn absorb_constant_scalar(&mut self, builder: &mut R1csBuilder, value: F) { + self.absorb_scalar(builder, AssignedScalar::constant(value)); + } + + /// Absorbs a constant `u64` domain-separation word. + fn absorb_u64(&mut self, builder: &mut R1csBuilder, value: u64); + + /// Absorbs a constant protocol label. + fn absorb_label(&mut self, builder: &mut R1csBuilder, label: &'static [u8]); + + /// Absorbs a packed protocol label and length/count word. + fn absorb_label_with_len( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + len: u64, + ); +} + +/// Jolt proof transcript operations over algebraic in-circuit backends. +pub trait R1csJoltTranscript: R1csAlgebraicTranscript { + /// Appends a domain-separation label. + fn append_label(&mut self, builder: &mut R1csBuilder, label: &'static [u8]) { + self.absorb_label(builder, label); + } + + /// Appends a labeled `u64`. + fn append_u64(&mut self, builder: &mut R1csBuilder, label: &'static [u8], value: u64) { + self.absorb_label(builder, label); + self.absorb_u64(builder, value); + } + + /// Appends one labeled scalar. + fn append_scalar( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + value: AssignedScalar, + ) { + self.absorb_label(builder, label); + self.absorb_scalar(builder, value); + } + + /// Appends a labeled scalar slice. + fn append_scalars( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + values: &[AssignedScalar], + ) { + self.absorb_label_with_len(builder, label, values.len() as u64); + for value in values { + self.absorb_scalar(builder, value.clone()); + } + } +} + +/// In-circuit transcript backend that absorbs byte-oriented values. +pub trait R1csByteTranscript: R1csTranscript { + /// The in-circuit representation of one byte. + type Byte; + + /// Absorbs constrained byte values into the transcript. + fn absorb_bytes(&mut self, builder: &mut R1csBuilder, bytes: &[Self::Byte]); + + /// Absorbs public constant bytes into the transcript. + fn absorb_constant_bytes(&mut self, builder: &mut R1csBuilder, bytes: &'static [u8]); +} + +/// Jolt proof transcript byte operations over in-circuit backends. +pub trait R1csJoltByteTranscript: R1csJoltTranscript + R1csByteTranscript { + /// Appends labeled byte values. + fn append_bytes( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + bytes: &[Self::Byte], + ) { + self.absorb_label_with_len(builder, label, bytes.len() as u64); + self.absorb_bytes(builder, bytes); + } + + /// Appends labeled constant bytes. + fn append_constant_bytes( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + bytes: &'static [u8], + ) { + self.absorb_label_with_len(builder, label, bytes.len() as u64); + self.absorb_constant_bytes(builder, bytes); + } +} + +impl R1csJoltTranscript for T +where + F: Field, + T: R1csAlgebraicTranscript, +{ +} + +impl R1csJoltByteTranscript for T +where + F: Field, + T: R1csJoltTranscript + R1csByteTranscript, +{ +} diff --git a/crates/jolt-transcript/src/r1cs/poseidon.rs b/crates/jolt-transcript/src/r1cs/poseidon.rs new file mode 100644 index 0000000000..57b07f05c5 --- /dev/null +++ b/crates/jolt-transcript/src/r1cs/poseidon.rs @@ -0,0 +1,568 @@ +//! Jolt Poseidon proof transcript gadget for BN254. + +use std::sync::OnceLock; + +use jolt_field::{Fr, FromPrimitiveInt}; +use jolt_r1cs::{AssignedScalar, LinearCombination, R1csBuilder}; +use light_poseidon::parameters::bn254_x5; + +use super::{R1csAlgebraicTranscript, R1csByteTranscript, R1csTranscript}; + +const POSEIDON_INPUTS: usize = 3; +const POSEIDON_WIDTH: usize = POSEIDON_INPUTS + 1; + +#[derive(Clone, Debug)] +struct PoseidonR1csParameters { + ark: Vec, + mds: Vec>, + full_rounds: usize, + partial_rounds: usize, + width: usize, +} + +/// Poseidon Fiat-Shamir transcript encoded as R1CS constraints. +/// +/// This mirrors `jolt-core`'s `transcript-poseidon` proof transcript: every +/// raw absorb and challenge hashes `(state, n_rounds, payload)`, scalar payloads +/// are absorbed as BN254 field elements, and challenges are full field elements. +#[derive(Clone, Debug)] +pub struct PoseidonR1csTranscript { + state: AssignedScalar, + round: u64, +} + +impl PoseidonR1csTranscript { + /// Returns the current assigned transcript state. + pub fn state(&self) -> &AssignedScalar { + &self.state + } +} + +impl R1csTranscript for PoseidonR1csTranscript { + type Challenge = AssignedScalar; + + fn new(builder: &mut R1csBuilder, label: &'static [u8]) -> Self { + let label = label_scalar(label); + let state = poseidon_permutation( + builder, + [ + AssignedScalar::constant(label), + AssignedScalar::constant(zero()), + AssignedScalar::constant(zero()), + ], + ); + Self { state, round: 0 } + } + + fn challenge_scalar(&mut self, builder: &mut R1csBuilder) -> AssignedScalar { + let challenge = poseidon_permutation( + builder, + [ + self.state.clone(), + AssignedScalar::constant(round_tag(self.round)), + AssignedScalar::constant(zero()), + ], + ); + self.state = challenge.clone(); + self.round += 1; + challenge + } +} + +impl R1csAlgebraicTranscript for PoseidonR1csTranscript { + fn absorb_scalar(&mut self, builder: &mut R1csBuilder, value: AssignedScalar) { + self.state = poseidon_permutation( + builder, + [ + self.state.clone(), + AssignedScalar::constant(round_tag(self.round)), + value, + ], + ); + self.round += 1; + } + + fn absorb_u64(&mut self, builder: &mut R1csBuilder, value: u64) { + self.absorb_constant_scalar(builder, Fr::from_u64(value)); + } + + fn absorb_label(&mut self, builder: &mut R1csBuilder, label: &'static [u8]) { + self.absorb_constant_scalar(builder, label_scalar(label)); + } + + fn absorb_label_with_len( + &mut self, + builder: &mut R1csBuilder, + label: &'static [u8], + len: u64, + ) { + self.absorb_constant_scalar(builder, label_with_len_scalar(label, len)); + } +} + +impl R1csByteTranscript for PoseidonR1csTranscript { + type Byte = AssignedScalar; + + fn absorb_bytes(&mut self, builder: &mut R1csBuilder, bytes: &[Self::Byte]) { + let mut chunks = bytes.chunks(32); + let first = chunks + .next() + .map_or_else(|| AssignedScalar::constant(zero()), pack_bytes); + let mut current = poseidon_permutation( + builder, + [ + self.state.clone(), + AssignedScalar::constant(round_tag(self.round)), + first, + ], + ); + + for chunk in chunks { + let chunk = pack_bytes(chunk); + current = + poseidon_permutation(builder, [current, AssignedScalar::constant(zero()), chunk]); + } + + self.state = current; + self.round += 1; + } + + fn absorb_constant_bytes(&mut self, builder: &mut R1csBuilder, bytes: &'static [u8]) { + let mut chunks = bytes.chunks(32); + let first = chunks.next().map_or_else(zero, bytes_scalar); + let mut current = poseidon_permutation( + builder, + [ + self.state.clone(), + AssignedScalar::constant(round_tag(self.round)), + AssignedScalar::constant(first), + ], + ); + + for chunk in chunks { + current = poseidon_permutation( + builder, + [ + current, + AssignedScalar::constant(zero()), + AssignedScalar::constant(bytes_scalar(chunk)), + ], + ); + } + + self.state = current; + self.round += 1; + } +} + +fn label_scalar(label: &[u8]) -> Fr { + assert!( + label.len() <= 32, + "label must fit in one Jolt transcript word" + ); + Fr::from_le_bytes_mod_order(label) +} + +fn bytes_scalar(bytes: &[u8]) -> Fr { + assert!( + bytes.len() <= 32, + "Poseidon byte chunks must fit in one BN254 scalar" + ); + Fr::from_le_bytes_mod_order(bytes) +} + +fn label_with_len_scalar(label: &[u8], len: u64) -> Fr { + assert!( + label.len() <= 24, + "label must leave 8 bytes for the Jolt transcript length word" + ); + let mut packed = [0u8; 32]; + packed[..label.len()].copy_from_slice(label); + packed[24..32].copy_from_slice(&len.to_be_bytes()); + Fr::from_le_bytes_mod_order(&packed) +} + +fn round_tag(round: u64) -> Fr { + Fr::from_u64(round) +} + +fn zero() -> Fr { + Fr::from_u64(0) +} + +fn pack_bytes(bytes: &[AssignedScalar]) -> AssignedScalar { + assert!( + bytes.len() <= 32, + "Poseidon byte chunks must fit in one BN254 scalar" + ); + let mut value = zero(); + let mut lc = LinearCombination::zero(); + let mut coefficient = Fr::from_u64(1); + let radix = Fr::from_u64(256); + for byte in bytes { + value += byte.value * coefficient; + lc = lc + byte.lc.clone().scale(coefficient); + coefficient *= radix; + } + AssignedScalar::new(value, lc) +} + +#[cfg(test)] +fn assigned_bytes(builder: &mut R1csBuilder, bytes: &[u8]) -> Vec> { + bytes + .iter() + .map(|byte| AssignedScalar::alloc(builder, Fr::from_u64(u64::from(*byte)))) + .collect() +} + +fn poseidon_permutation( + builder: &mut R1csBuilder, + inputs: [AssignedScalar; POSEIDON_INPUTS], +) -> AssignedScalar { + let params = poseidon_parameters(); + let mut state = Vec::with_capacity(params.width); + state.push(AssignedScalar::constant(zero())); + state.extend(inputs); + + let all_rounds = params.full_rounds + params.partial_rounds; + let half_rounds = params.full_rounds / 2; + + for round in 0..half_rounds { + apply_ark(&mut state, round, params); + apply_sbox_full(builder, &mut state); + apply_mds(builder, &mut state, params); + } + + for round in half_rounds..half_rounds + params.partial_rounds { + apply_ark(&mut state, round, params); + apply_sbox_partial(builder, &mut state); + apply_mds(builder, &mut state, params); + } + + for round in half_rounds + params.partial_rounds..all_rounds { + apply_ark(&mut state, round, params); + apply_sbox_full(builder, &mut state); + apply_mds(builder, &mut state, params); + } + + state[0].clone() +} + +#[cfg(test)] +fn poseidon_hash(inputs: [Fr; POSEIDON_INPUTS]) -> Fr { + let params = poseidon_parameters(); + let mut state = Vec::with_capacity(params.width); + state.push(zero()); + state.extend(inputs); + + let all_rounds = params.full_rounds + params.partial_rounds; + let half_rounds = params.full_rounds / 2; + + for round in 0..half_rounds { + apply_ark_values(&mut state, round, params); + apply_sbox_full_values(&mut state); + apply_mds_values(&mut state, params); + } + + for round in half_rounds..half_rounds + params.partial_rounds { + apply_ark_values(&mut state, round, params); + state[0] = pow5_value(state[0]); + apply_mds_values(&mut state, params); + } + + for round in half_rounds + params.partial_rounds..all_rounds { + apply_ark_values(&mut state, round, params); + apply_sbox_full_values(&mut state); + apply_mds_values(&mut state, params); + } + + state[0] +} + +fn apply_ark(state: &mut [AssignedScalar], round: usize, params: &PoseidonR1csParameters) { + for (index, assigned) in state.iter_mut().enumerate() { + let constant = params.ark[round * params.width + index]; + assigned.value += constant; + assigned.lc = assigned.lc.clone() + LinearCombination::constant(constant); + } +} + +#[cfg(test)] +fn apply_ark_values(state: &mut [Fr], round: usize, params: &PoseidonR1csParameters) { + for (index, value) in state.iter_mut().enumerate() { + *value += params.ark[round * params.width + index]; + } +} + +fn apply_sbox_full(builder: &mut R1csBuilder, state: &mut [AssignedScalar]) { + for assigned in state.iter_mut() { + *assigned = pow5(builder, assigned.clone()); + } +} + +#[cfg(test)] +fn apply_sbox_full_values(state: &mut [Fr]) { + for value in state { + *value = pow5_value(*value); + } +} + +fn apply_sbox_partial(builder: &mut R1csBuilder, state: &mut [AssignedScalar]) { + state[0] = pow5(builder, state[0].clone()); +} + +fn apply_mds( + builder: &mut R1csBuilder, + state: &mut Vec>, + params: &PoseidonR1csParameters, +) { + let previous = state.clone(); + state.clear(); + for row in 0..params.width { + let mut value = zero(); + let mut lc = LinearCombination::zero(); + for (assigned, &coefficient) in previous.iter().zip(¶ms.mds[row]) { + value += assigned.value * coefficient; + lc = lc + assigned.lc.clone().scale(coefficient); + } + let output = AssignedScalar::alloc(builder, value); + builder.assert_equal(output.lc.clone(), lc); + state.push(output); + } +} + +#[cfg(test)] +fn apply_mds_values(state: &mut Vec, params: &PoseidonR1csParameters) { + let previous = state.clone(); + state.clear(); + for row in 0..params.width { + let mut value = zero(); + for (input, &coefficient) in previous.iter().zip(¶ms.mds[row]) { + value += *input * coefficient; + } + state.push(value); + } +} + +fn pow5(builder: &mut R1csBuilder, value: AssignedScalar) -> AssignedScalar { + let square = multiply(builder, &value, &value); + let fourth = multiply(builder, &square, &square); + multiply(builder, &fourth, &value) +} + +#[cfg(test)] +fn pow5_value(value: Fr) -> Fr { + let square = value * value; + let fourth = square * square; + fourth * value +} + +fn multiply( + builder: &mut R1csBuilder, + lhs: &AssignedScalar, + rhs: &AssignedScalar, +) -> AssignedScalar { + AssignedScalar::new( + lhs.value * rhs.value, + builder.multiply(lhs.lc.clone(), rhs.lc.clone()), + ) +} + +fn poseidon_parameters() -> &'static PoseidonR1csParameters { + static PARAMS: OnceLock = OnceLock::new(); + PARAMS.get_or_init(load_poseidon_parameters) +} + +#[expect( + clippy::expect_used, + reason = "constant width-4 BN254 Poseidon parameters are generated by light-poseidon" +)] +fn load_poseidon_parameters() -> PoseidonR1csParameters { + let params = bn254_x5::get_poseidon_parameters::(POSEIDON_WIDTH as u8) + .expect("valid width-4 BN254 Poseidon parameters"); + PoseidonR1csParameters { + ark: params.ark.into_iter().map(Fr::from).collect(), + mds: params + .mds + .into_iter() + .map(|row| row.into_iter().map(Fr::from).collect()) + .collect(), + full_rounds: params.full_rounds, + partial_rounds: params.partial_rounds, + width: params.width, + } +} + +#[cfg(test)] +#[expect(clippy::expect_used, clippy::panic, reason = "tests may fail loudly")] +mod tests { + use super::*; + use crate::r1cs::{R1csJoltByteTranscript, R1csJoltTranscript}; + use jolt_r1cs::Variable; + use light_poseidon::{Poseidon, PoseidonHasher}; + + #[derive(Clone, Copy, Debug)] + struct NativeTranscript { + state: Fr, + round: u64, + } + + impl NativeTranscript { + fn new(label: &'static [u8]) -> Self { + Self { + state: poseidon_hash([label_scalar(label), zero(), zero()]), + round: 0, + } + } + + fn absorb_scalar(&mut self, value: Fr) { + self.state = poseidon_hash([self.state, round_tag(self.round), value]); + self.round += 1; + } + + fn absorb_bytes(&mut self, bytes: &[u8]) { + let mut chunks = bytes.chunks(32); + let first = chunks.next().map_or_else(zero, bytes_scalar); + let mut current = poseidon_hash([self.state, round_tag(self.round), first]); + for chunk in chunks { + current = poseidon_hash([current, zero(), bytes_scalar(chunk)]); + } + self.state = current; + self.round += 1; + } + + fn absorb_label(&mut self, label: &'static [u8]) { + self.absorb_scalar(label_scalar(label)); + } + + fn absorb_u64(&mut self, value: u64) { + self.absorb_scalar(Fr::from_u64(value)); + } + + fn challenge_scalar(&mut self) -> Fr { + self.state = poseidon_hash([self.state, round_tag(self.round), zero()]); + self.round += 1; + self.state + } + + fn append_scalar(&mut self, label: &'static [u8], value: Fr) { + self.absorb_label(label); + self.absorb_scalar(value); + } + + fn append_scalars(&mut self, label: &'static [u8], values: &[Fr]) { + self.absorb_scalar(label_with_len_scalar(label, values.len() as u64)); + for value in values { + self.absorb_scalar(*value); + } + } + + fn append_bytes(&mut self, label: &'static [u8], bytes: &[u8]) { + self.absorb_scalar(label_with_len_scalar(label, bytes.len() as u64)); + self.absorb_bytes(bytes); + } + } + + #[test] + fn native_permutation_matches_light_poseidon() { + let inputs = [Fr::from_u64(7), Fr::from_u64(11), Fr::from_u64(19)]; + + assert_eq!(poseidon_hash(inputs), light_poseidon_hash(inputs)); + } + + #[test] + fn poseidon_gadget_matches_native_permutation() { + let inputs = [Fr::from_u64(1), Fr::from_u64(2), Fr::from_u64(3)]; + let mut builder = R1csBuilder::new(); + let assigned_inputs = inputs.map(|input| AssignedScalar::alloc(&mut builder, input)); + + let output = poseidon_permutation(&mut builder, assigned_inputs); + + assert_eq!(output.value, poseidon_hash(inputs)); + let witness = builder.witness().expect("witness is assigned"); + let matrices = builder.into_matrices(); + assert!(matrices.check_witness(&witness).is_ok()); + assert!((500..=560).contains(&matrices.num_constraints)); + } + + #[test] + fn jolt_poseidon_transcript_matches_native_sequence() { + let mut builder = R1csBuilder::new(); + let mut r1cs = PoseidonR1csTranscript::new(&mut builder, b"Jolt"); + let mut native = NativeTranscript::new(b"Jolt"); + + r1cs.append_label(&mut builder, b"sumcheck"); + native.absorb_label(b"sumcheck"); + + r1cs.append_u64(&mut builder, b"round", 9); + native.absorb_label(b"round"); + native.absorb_u64(9); + + let scalar = AssignedScalar::alloc(&mut builder, Fr::from_u64(42)); + r1cs.append_scalar(&mut builder, b"sumcheck_claim", scalar); + native.append_scalar(b"sumcheck_claim", Fr::from_u64(42)); + + let scalars = [ + AssignedScalar::alloc(&mut builder, Fr::from_u64(3)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(5)), + AssignedScalar::alloc(&mut builder, Fr::from_u64(8)), + ]; + r1cs.append_scalars(&mut builder, b"sumcheck_poly", &scalars); + native.append_scalars( + b"sumcheck_poly", + &[Fr::from_u64(3), Fr::from_u64(5), Fr::from_u64(8)], + ); + + let byte_payload = [0xabu8; 45]; + let assigned_payload = assigned_bytes(&mut builder, &byte_payload); + r1cs.append_bytes(&mut builder, b"inputs", &assigned_payload); + native.append_bytes(b"inputs", &byte_payload); + + r1cs.append_constant_bytes(&mut builder, b"preprocessing_digest", &[0xcdu8; 32]); + native.append_bytes(b"preprocessing_digest", &[0xcdu8; 32]); + + let first = r1cs.challenge_scalar(&mut builder); + let native_first = native.challenge_scalar(); + assert_eq!(first.value, native_first); + + r1cs.absorb_constant_scalar(&mut builder, Fr::from_u64(100)); + native.absorb_scalar(Fr::from_u64(100)); + + let second = r1cs.challenge_scalar(&mut builder); + let native_second = native.challenge_scalar(); + assert_eq!(second.value, native_second); + assert_eq!(r1cs.state().value, native.state); + + let witness = builder.witness().expect("witness is assigned"); + assert!(builder.into_matrices().check_witness(&witness).is_ok()); + } + + #[test] + fn tampered_challenge_witness_fails_constraints() { + let mut builder = R1csBuilder::new(); + let mut transcript = PoseidonR1csTranscript::new(&mut builder, b"Jolt"); + transcript.absorb_constant_scalar(&mut builder, Fr::from_u64(42)); + let challenge = transcript.challenge_scalar(&mut builder); + + let mut witness = builder.witness().expect("witness is assigned"); + let variable = challenge_variable(&challenge); + witness[variable.index()] += Fr::from_u64(1); + + assert!(builder.into_matrices().check_witness(&witness).is_err()); + } + + fn light_poseidon_hash(inputs: [Fr; POSEIDON_INPUTS]) -> Fr { + let mut poseidon = + Poseidon::::new_circom(POSEIDON_INPUTS).expect("Poseidon init"); + let inputs = inputs.map(ark_bn254::Fr::from); + poseidon.hash(&inputs).expect("Poseidon hash").into() + } + + fn challenge_variable(challenge: &AssignedScalar) -> Variable { + let [(variable, coefficient)] = challenge.lc.terms.as_slice() else { + panic!("challenge should be represented by one allocated variable"); + }; + assert_eq!(*coefficient, Fr::from_u64(1)); + *variable + } +} diff --git a/crates/jolt-verifier/src/verifier.rs b/crates/jolt-verifier/src/verifier.rs index 6285db612c..48a1d45e7c 100644 --- a/crates/jolt-verifier/src/verifier.rs +++ b/crates/jolt-verifier/src/verifier.rs @@ -131,12 +131,15 @@ where .as_ref() .ok_or(VerifierError::MissingVectorCommitmentSetup)?; transcript.append(&Label(b"BlindFold")); - blindfold - .protocol - .verify::(proof.blindfold_proof()?, vc_setup, &mut transcript) - .map_err(|error| VerifierError::BlindFoldVerificationFailed { - reason: error.to_string(), - })?; + jolt_blindfold::verify::( + &blindfold.protocol, + proof.blindfold_proof()?, + vc_setup, + &mut transcript, + ) + .map_err(|error| VerifierError::BlindFoldVerificationFailed { + reason: error.to_string(), + })?; return Ok(()); }