diff --git a/piop/benches/multipoint_eval.rs b/piop/benches/multipoint_eval.rs index 491870a8..929ff8ca 100644 --- a/piop/benches/multipoint_eval.rs +++ b/piop/benches/multipoint_eval.rs @@ -84,9 +84,11 @@ fn bench_multipoint_eval(c: &mut Criterion, num_vars: usize, num_cols: usize) { MultipointEval::::prove_as_subprotocol( &mut t, &trace_mles, + &[], &eval_point, &up_evals, &down_evals, + &[], &shifts, &field_cfg, ) @@ -104,9 +106,11 @@ fn bench_multipoint_eval(c: &mut Criterion, num_vars: usize, num_cols: usize) { let (proof, prover_state) = MultipointEval::::prove_as_subprotocol( &mut prover_transcript, &trace_mles, + &[], &eval_point, &up_evals, &down_evals, + &[], &shifts, &field_cfg, ) @@ -131,13 +135,20 @@ fn bench_multipoint_eval(c: &mut Criterion, num_vars: usize, num_cols: usize) { &eval_point, &up_evals, &down_evals, + &[], &shifts, num_vars, &field_cfg, ) .expect("verifier failed"); - MultipointEval::::verify_subclaim(&subclaim, &open_evals, &shifts, &field_cfg) - .expect("subclaim check failed"); + MultipointEval::::verify_subclaim( + &subclaim, + &open_evals, + &[], + &shifts, + &field_cfg, + ) + .expect("subclaim check failed"); }, BatchSize::SmallInput, ); diff --git a/piop/src/multipoint_eval.rs b/piop/src/multipoint_eval.rs index c160bdc2..8dffcb5a 100644 --- a/piop/src/multipoint_eval.rs +++ b/piop/src/multipoint_eval.rs @@ -1,26 +1,32 @@ //! Multi-point evaluation subprotocol. //! -//! Reduces two sets of MLE evaluation claims at a shared point r' - the -//! "up" evaluations `v_j(r')` and the "down" (shifted) evaluations -//! `v_j^{down}(r')` - to a single set of standard MLE evaluation claims -//! `v_j(r_0)` at a new random point `r_0` via one sumcheck. +//! Reduces MLE evaluation claims at a shared point r' - the "up" evaluations +//! `v_j(r')`, the "down" (shifted) evaluations `v_j^{down}(r')`, and optional +//! bit-op virtual evaluations - to a single set of standard MLE evaluation +//! claims at a new random point `r_0` via one sumcheck. //! //! The trace column MLEs are precombined into a single MLE -//! `precombined(b) = \sum_j \gamma_j * v_j(b)` before entering the sumcheck, so -//! the prover works with only 3 MLEs (`eq`, `next`, `precombined`) regardless -//! of the number of columns. The sumcheck proves: +//! `precombined(b) = \sum_j \gamma_j * v_j(b) +//! + \sum_l \gamma_l^bit * bit_op_l(b)` +//! before entering the sumcheck, so the prover works with only 3 MLE groups +//! (`eq`, `next`, `precombined`) regardless of the number of columns. The +//! sumcheck proves: //! ```text -//! \sum_b [eq(b, r') * \sum_j \gamma_j * v_j(b) +//! \sum_b [eq(b, r') * (\sum_j \gamma_j * v_j(b) +//! + \sum_l \gamma_l^bit * bit_op_l(b)) //! + \sum_k \alpha_k * next_{c_k}(r', b) * v_{src_k}(b)] -//! = \sum_j \gamma_j * up_eval_j + \sum_k \alpha_k * down_eval_k +//! = \sum_j \gamma_j * up_eval_j +//! + \sum_l \gamma_l^bit * bit_op_eval_l +//! + \sum_k \alpha_k * down_eval_k //! ``` //! //! where `\alpha_k` batch the per-shift evaluation kernels and `\gamma_j` //! batch across columns. After the sumcheck reduces to point `r_0`, the -//! verifier calls [`MultipointEval::verify_subclaim`] with the `open_evals` -//! (the F_q-valued MLE evaluations at `r_0`, typically derived from -//! polynomial-valued `lifted_evals` via `\psi_a`) to check the final -//! consistency equation. +//! verifier calls [`MultipointEval::verify_subclaim`] with the committed-column +//! `open_evals` and the verifier-derived `bit_op_open_evals` to check the +//! final consistency equation. For bit-op virtuals, those open evaluations are +//! derived from source lifted openings via Lemma 2.3 rather than trusted as +//! independent witness openings. //! //! This corresponds to the T=2 case of Pi_{BMLE} in the paper. Following //! the paper, the prover sends only the polynomial-valued lifted evaluations @@ -89,6 +95,8 @@ pub struct Subclaim { pub gammas: Vec, /// Per-shift batching coefficients \alpha_k sampled during the protocol. pub alphas: Vec, + /// Per-bit-op-virtual batching coefficients sampled during the protocol. + pub bit_op_gammas: Vec, /// `eq(r_0, r')` — the equality selector at the sumcheck output point. pub eq_at_r0: F, /// Per-shift selector values at r_0: @@ -111,22 +119,32 @@ where /// Multi-point evaluation protocol prover. /// /// Runs the combined sumcheck over - /// `eq(b, r') * \sum_j(\gamma_j * v_j(b)) + \sum_k \alpha_k * - /// next_{c_k}(r', b) * v_{src_k}(b)`. Returns only the sumcheck proof - /// and the challenge point `r_0`; the caller is responsible for - /// computing and sending `lifted_evals` at `r_0`. - #[allow(clippy::arithmetic_side_effects)] + /// `eq(b, r') * (\sum_j(\gamma_j * v_j(b)) + /// + \sum_l(\gamma_l^bit * bit_op_l(b))) + /// + \sum_k \alpha_k * next_{c_k}(r', b) * v_{src_k}(b)`. + /// Returns only the sumcheck proof and the challenge point `r_0`; the + /// caller is responsible for computing and sending `lifted_evals` at + /// `r_0`. + #[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)] pub fn prove_as_subprotocol( transcript: &mut impl Transcript, trace_mles: &[DenseMultilinearExtension], + bit_op_mles: &[DenseMultilinearExtension], eval_point: &[F], up_evals: &[F], down_evals: &[F], + bit_op_evals: &[F], shifts: &[ShiftSpec], field_cfg: &F::Config, ) -> Result<(Proof, ProverState), MultipointEvalError> { let num_cols = trace_mles.len(); let num_down_cols = shifts.len(); + let num_bit_op_cols = bit_op_evals.len(); + assert_eq!( + bit_op_mles.len(), + num_bit_op_cols, + "bit_op_mles count must match bit_op_evals.len()", + ); let num_vars = eval_point.len(); let zero = F::zero_with_cfg(field_cfg); let zero_inner = zero.inner(); @@ -135,6 +153,7 @@ where // batching coefficients \gamma_1,...,\gamma_J. let alphas: Vec = transcript.get_field_challenges(num_down_cols, field_cfg); let gammas: Vec = transcript.get_field_challenges(num_cols, field_cfg); + let bit_op_gammas: Vec = transcript.get_field_challenges(num_bit_op_cols, field_cfg); // Step 2: Build the two selector MLEs: // eq_r(b) = eq(b, r') @@ -151,21 +170,29 @@ where .into_iter() .unzip(); - // Precombine up cols with gammas, precombined[b] = Σ_j γ_j trace[j][b] + // Precombine committed columns and bit-op virtual columns. let precombined = { let evaluations: Vec<_> = cfg_into_iter!(0..1 << num_vars) .map(|b| { - gammas - .iter() - .enumerate() - .fold(zero.clone(), |acc, (i, gamma)| { - let eval_f = F::new_unchecked_with_cfg( - trace_mles[i].evaluations[b].clone(), - field_cfg, - ); - acc + eval_f * gamma - }) - .into_inner() + let mut acc = + gammas + .iter() + .enumerate() + .fold(zero.clone(), |acc, (i, gamma)| { + let eval_f = F::new_unchecked_with_cfg( + trace_mles[i].evaluations[b].clone(), + field_cfg, + ); + acc + eval_f * gamma + }); + for (i, gamma) in bit_op_gammas.iter().enumerate() { + let eval_f = F::new_unchecked_with_cfg( + bit_op_mles[i].evaluations[b].clone(), + field_cfg, + ); + acc += eval_f * gamma; + } + acc.into_inner() }) .collect(); DenseMultilinearExtension::from_evaluations_vec( @@ -209,7 +236,15 @@ where // Sanity check debug_assert_eq!( sumcheck_proof.claimed_sum, - compute_expected_sum(up_evals, down_evals, &gammas, &alphas, zero) + compute_expected_sum( + up_evals, + down_evals, + bit_op_evals, + &gammas, + &alphas, + &bit_op_gammas, + zero, + ) ); Ok(( @@ -235,22 +270,32 @@ where eval_point: &[F], up_evals: &[F], down_evals: &[F], + bit_op_evals: &[F], shifts: &[ShiftSpec], num_vars: usize, field_cfg: &F::Config, ) -> Result, MultipointEvalError> { let num_cols = up_evals.len(); let num_down_cols = shifts.len(); + let num_bit_op_cols = bit_op_evals.len(); let zero = F::zero_with_cfg(field_cfg); let one = F::one_with_cfg(field_cfg); // Step 1: Sample \alpha_k and \gamma_j (must match prover). let alphas: Vec = transcript.get_field_challenges(num_down_cols, field_cfg); let gammas: Vec = transcript.get_field_challenges(num_cols, field_cfg); + let bit_op_gammas: Vec = transcript.get_field_challenges(num_bit_op_cols, field_cfg); // Step 2: Compute expected sum - let expected_sum: F = - compute_expected_sum(up_evals, down_evals, &gammas, &alphas, zero.clone()); + let expected_sum: F = compute_expected_sum( + up_evals, + down_evals, + bit_op_evals, + &gammas, + &alphas, + &bit_op_gammas, + zero.clone(), + ); if proof.sumcheck_proof.claimed_sum != expected_sum { return Err(MultipointEvalError::WrongSumcheckSum { @@ -281,6 +326,7 @@ where sumcheck_subclaim, gammas, alphas, + bit_op_gammas, eq_at_r0, shifts_at_r0, }) @@ -297,10 +343,12 @@ where pub fn verify_subclaim( subclaim: &Subclaim, open_evals: &[F], + bit_op_open_evals: &[F], shifts: &[ShiftSpec], field_cfg: &F::Config, ) -> Result<(), MultipointEvalError> { let num_cols = subclaim.gammas.len(); + let num_bit_op_cols = subclaim.bit_op_gammas.len(); if open_evals.len() != num_cols { return Err(MultipointEvalError::WrongOpenEvalsNumber { @@ -309,6 +357,13 @@ where }); } + if bit_op_open_evals.len() != num_bit_op_cols { + return Err(MultipointEvalError::WrongBitOpOpenEvalsNumber { + got: bit_op_open_evals.len(), + expected: num_bit_op_cols, + }); + } + let zero = F::zero_with_cfg(field_cfg); let batched_up: F = subclaim @@ -318,6 +373,11 @@ where .fold(zero.clone(), |acc, (gamma, eval)| { acc + gamma.clone() * eval }); + let batched_up = subclaim + .bit_op_gammas + .iter() + .zip(bit_op_open_evals.iter()) + .fold(batched_up, |acc, (gamma, eval)| acc + gamma.clone() * eval); // open_evals[j] = trace_col_j(r_0) for all committed (up) columns. // Shifted columns reuse the same opening: the shift is captured by @@ -345,13 +405,17 @@ where } } -/// `expected_sum = \sum_j \gamma_j * up_eval_j + \sum_k \alpha_k * -/// down_eval_k` +/// `expected_sum = \sum_j \gamma_j * up_eval_j +/// + \sum_k \alpha_k * down_eval_k +/// + \sum_l \gamma_l^bit * bit_op_eval_l` +#[allow(clippy::too_many_arguments)] fn compute_expected_sum( up_evals: &[F], down_evals: &[F], + bit_op_evals: &[F], gammas: &[F], alphas: &[F], + bit_op_gammas: &[F], zero: F, ) -> F { let up_sum = gammas @@ -359,10 +423,15 @@ fn compute_expected_sum( .zip(up_evals.iter()) .fold(zero, |acc, (gamma, up)| acc + gamma.clone() * up); - alphas + let up_and_down = alphas .iter() .zip(down_evals.iter()) - .fold(up_sum, |acc, (alpha, down)| acc + alpha.clone() * down) + .fold(up_sum, |acc, (alpha, down)| acc + alpha.clone() * down); + + bit_op_gammas + .iter() + .zip(bit_op_evals.iter()) + .fold(up_and_down, |acc, (gamma, eval)| acc + gamma.clone() * eval) } // @@ -373,6 +442,8 @@ fn compute_expected_sum( pub enum MultipointEvalError { #[error("wrong number of open evaluations: got {got}, expected {expected}")] WrongOpenEvalsNumber { got: usize, expected: usize }, + #[error("wrong number of bit-op open evaluations: got {got}, expected {expected}")] + WrongBitOpOpenEvalsNumber { got: usize, expected: usize }, #[error("wrong sumcheck claimed sum: got {got}, expected {expected}")] WrongSumcheckSum { got: F, expected: F }, #[error("multi-point eval claim mismatch: got {got}, expected {expected}")] @@ -483,9 +554,11 @@ mod tests { let (proof, prover_state) = MultipointEval::::prove_as_subprotocol( &mut transcript, trace_mles, + &[], &public.eval_point, &public.up_evals, &public.down_evals, + &[], &public.shifts, &(), ) @@ -511,12 +584,13 @@ mod tests { &public.eval_point, &public.up_evals, &public.down_evals, + &[], &public.shifts, public.num_vars, &(), )?; - MultipointEval::::verify_subclaim(&subclaim, &msg.open_evals, &public.shifts, &())?; + MultipointEval::::verify_subclaim(&subclaim, &msg.open_evals, &[], &public.shifts, &())?; Ok(subclaim) } @@ -585,6 +659,92 @@ mod tests { run_verifier(&public, &msg).unwrap(); } + #[test] + fn bit_op_virtual_opening_is_bound_in_subclaim() { + let shifts = vec![ShiftSpec::new(0, 1)]; + let (trace_mles, public) = build_trace(3, 2, &shifts); + + let bit_op_mles = vec![DenseMultilinearExtension::from_evaluations_vec( + public.num_vars, + trace_mles[0] + .evaluations + .iter() + .map(|eval| (F::new_unchecked_with_cfg(*eval, &()) + F::from(11_u32)).into_inner()) + .collect(), + F::ZERO.into_inner(), + )]; + let bit_op_evals: Vec = bit_op_mles + .iter() + .map(|mle| { + mle.clone() + .evaluate_with_config(&public.eval_point, &()) + .unwrap() + }) + .collect(); + + let mut prover_transcript = make_transcript(); + let (proof, prover_state) = MultipointEval::::prove_as_subprotocol( + &mut prover_transcript, + &trace_mles, + &bit_op_mles, + &public.eval_point, + &public.up_evals, + &public.down_evals, + &bit_op_evals, + &public.shifts, + &(), + ) + .expect("prover should succeed"); + + let r_0 = &prover_state.eval_point; + let open_evals: Vec = trace_mles + .iter() + .map(|mle| mle.clone().evaluate_with_config(r_0, &()).unwrap()) + .collect(); + let bit_op_open_evals: Vec = bit_op_mles + .iter() + .map(|mle| mle.clone().evaluate_with_config(r_0, &()).unwrap()) + .collect(); + + let mut verifier_transcript = make_transcript(); + let subclaim = MultipointEval::::verify_as_subprotocol( + &mut verifier_transcript, + proof, + &public.eval_point, + &public.up_evals, + &public.down_evals, + &bit_op_evals, + &public.shifts, + public.num_vars, + &(), + ) + .expect("verifier should accept sumcheck"); + + MultipointEval::::verify_subclaim( + &subclaim, + &open_evals, + &bit_op_open_evals, + &public.shifts, + &(), + ) + .expect("correct bit-op opening should satisfy subclaim"); + + let mut bad_bit_op_open_evals = bit_op_open_evals; + bad_bit_op_open_evals[0] += F::ONE; + let err = MultipointEval::::verify_subclaim( + &subclaim, + &open_evals, + &bad_bit_op_open_evals, + &public.shifts, + &(), + ) + .unwrap_err(); + assert!( + matches!(err, MultipointEvalError::ClaimMismatch { .. }), + "expected ClaimMismatch, got {err:?}", + ); + } + // --- Failure: corrupted down_evals with mixed shifts --- #[test] diff --git a/protocol/src/prover.rs b/protocol/src/prover.rs index b2935342..20e689f0 100644 --- a/protocol/src/prover.rs +++ b/protocol/src/prover.rs @@ -743,9 +743,11 @@ impl_with_type_bounds!(ProverSumchecked let (mp_proof, mp_prover_state) = MultipointEval::prove_as_subprotocol( &mut self.base.pcs_transcript.fs_transcript, &trace_mles, + &[], &self.cpr_eval_point, &up_evals, &self.cpr_proof.down_evals, + &[], self.base.uair_signature.shifts(), &self.field_cfg, )?; diff --git a/protocol/src/verifier.rs b/protocol/src/verifier.rs index e619d7e8..1c9596e7 100644 --- a/protocol/src/verifier.rs +++ b/protocol/src/verifier.rs @@ -672,6 +672,7 @@ where &self.cpr_eval_point, &up_evals, &self.cpr_down_evals, + &[], self.base.uair_signature.shifts(), self.base.num_vars, &self.field_cfg, @@ -789,6 +790,7 @@ where MultipointEval::verify_subclaim( &self.mp_subclaim, &open_evals, + &[], self.base.uair_signature.shifts(), &self.field_cfg, )?;