diff --git a/Cargo.toml b/Cargo.toml index 1b877138..225558b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,15 @@ crypto-bigint = { version = "= 0.7.0-rc.9", features = ["zeroize"] } criterion = "0.7.0" derive_more = { version = "2.1.1", features = ["full"] } itertools = "0.14" +num-integer = "0.1" num-traits = "0.2" proptest = "1.9.0" rand = "0.9" rand_core = "0.9" rayon = { version = "1.10" } thiserror = "2.0" +tracing = "0.1" +tracing-subscriber = "0.3" zinc-primality = { path = "primality/" } zinc-test-uair = { path = "test-uair/" } zinc-transcript = { path = "transcript/" } @@ -52,3 +55,9 @@ unwrap_used = "deny" [profile.release] lto = true # Enable Link Time Optimization codegen-units = 1 # Slower compilation but potentially better optimization + +[profile.bench] +inherits = "release" +debug = 2 +strip = false +split-debuginfo = "unpacked" diff --git a/documentation/sha-uair-doc/uair-object-model.md b/documentation/sha-uair-doc/uair-object-model.md new file mode 100644 index 00000000..9d3c5d0d --- /dev/null +++ b/documentation/sha-uair-doc/uair-object-model.md @@ -0,0 +1,665 @@ +# UAIR Object Model for Production SHA-256 + +This note records the naming and lifecycle decisions for the generic UAIR +objects we want to use before specializing them to production SHA-256. + +## Core Principle + +Keep these roles separate: + +- `Shape`: static relation metadata and layout. +- `Witness`: the prover's full pre-projection assignment. +- `Instance`: verifier-visible public data plus commitments. +- `FoldedWitness`: prover-private folded state after folding. +- `FoldedInstance`: verifier-visible folded accumulator. + +Do not put projected/evaluated data in the initial instance object. Projection +and evaluation happen later after transcript challenges are known. 
+ +## Existing e2e.rs Lifecycle + +The generic path in `protocol/benches/e2e.rs` works like this: + +```rust +let trace = U::generate_random_trace(num_vars, &mut rng); +let sig = U::signature(); +let public_trace = trace.public(&sig); +``` + +The prover receives the full `UairTrace`, splits it with `UairSignature`, commits +only the witness columns, and absorbs the unprojected public columns into the +transcript: + +```rust +let public_trace = trace.public(&uair_signature); +let witness_trace = trace.witness(&uair_signature); + +commit(witness_trace); +absorb(public_trace); +``` + +The verifier receives the proof and the unprojected `public_trace`. It absorbs +that same public trace before later projection/evaluation steps. + +So the instance's public data is an unprojected public UAIR trace, not a +`ProjectedPublicTrace`. + +## UairSignature + +`UairSignature` is the layout contract for a UAIR. It defines: + +- total column counts for `binary_poly`, `arbitrary_poly`, and `int` +- public column prefix counts +- witness column suffix counts +- shifted columns +- virtual columns +- lookup and booleanity metadata + +`UairTrace::public(sig)` and `UairTrace::witness(sig)` use this signature to +split a full trace into public and witness subtraces. + +## Generic Objects + +`UairShape` is useful as a value-level handle for the static UAIR relation plus +the trace length. + +```rust +pub struct UairShape<U> { + pub num_vars: usize, + pub signature: UairSignature, + _marker: PhantomData<U>, +} +``` + +`num_vars` should stay on the shape or protocol input. It is the log trace +length, so the row domain has size `1 << num_vars`. The prover and verifier use +it for MLE sizes, sumcheck rounds, public-structure checks, and PCS parameters. + +Do not add `shape_digest` in the first pass. A shape digest may be useful later +for serialization or cached transcript binding, but it should be derived from +the shape rather than treated as fundamental state. 
+ +The witness is the full pre-projection prover assignment. It includes public +columns and private columns because `UairTrace` itself stores both; public +columns are the prefix determined by `UairSignature`. + +```rust +pub struct UairWitness<'a, PolyCoeff: Clone, Int: Clone, const D: usize> { + pub trace: UairTrace<'a, PolyCoeff, Int, D>, +} +``` + +Document this clearly: `UairWitness` means full prover assignment, not +private-only data. + +The fresh verifier-visible instance contains the unprojected public trace and +commitments to the witness columns. + +```rust +pub struct UairInstance<'a, PolyCoeff: Clone, Int: Clone, Commitments, const D: usize> { + pub public_trace: UairTrace<'a, PolyCoeff, Int, D>, + pub commitments: Commitments, +} +``` + +The folded objects should mirror the fresh split: + +```rust +pub struct FoldedUairWitness { + pub trace: FoldedUairTrace, + pub opening_witness: OpeningWitness, +} + +pub struct FoldedUairTrace { + pub binary_poly: Vec>>, + pub arbitrary_poly: Vec>>, + pub int: Vec>, +} + +pub struct FoldedUairInstance { + pub commitments: Commitments, + pub public: Public, + pub u: F, +} +``` + +`FoldedUairTrace` intentionally keeps the same top-level column families as +`UairTrace`, but its cell types are proof-field objects after projection and +instance folding. Polynomial-valued sources become MLEs whose row values are +univariate proof-field polynomials. Scalar or integer sources become scalar MLEs +over the proof field. + +`opening_witness` is prover-only data needed by the PCS to open the folded +commitments, such as commitment randomness or backend-specific prover state. It +is not part of the proof object. Residual and ideal-polynomial caches should +remain prover working state unless a later phase genuinely needs to carry them. + +The exact field types can be specialized by the folding protocol. 
For example, +the SHA-256 production path may use a SHA-specific projected trace with +`bit_slices`, `scalarized_words`, `int_columns`, and `public_columns`, while the +generic UAIR object should preserve the `binary_poly`, `arbitrary_poly`, and +`int` families. The boundary remains the same: witness is prover-private, +instance is verifier-visible. + +## LinearIdealFold Proof Objects + +`LinearIdealFold` is the generic folding layer for UAIRs whose projected +residue constraints live in linear ideals. The proof object should contain only +verifier messages and claimed evaluations. It should not contain prover-side +caches such as PCS prover data, and it should not contain NeutronNova-specific +objects such as `comm_E` for a committed power-vector witness. + +`ProjectionFold Concise` is the source of truth for the production protocol: +the verifier algorithm, Fiat-Shamir ordering, concrete SHA-256 ideal families, +degree bounds, and equations. This file records the UAIR object boundaries and +Rust-facing shape of the proof. If a protocol equation is duplicated here, it is +included only to make the object model unambiguous. + +The implementation should reuse the generic proof objects exercised by +`protocol/benches/e2e.rs`. The baseline proof shape in `protocol/src/lib.rs` is: + +```rust +pub struct Proof { + pub commitments: Commitments, + pub zip: Vec, + pub ideal_check: IdealCheckProof, + pub resolver: CombinedPolyResolverProof, + pub combined_sumcheck: MultiDegreeSumcheckProof, + pub multipoint_eval: MultipointEvalProof, + pub witness_lifted_evals: Vec>, + pub lookup_proof: Option>, +} +``` + +Here `IdealCheckProof`, `CombinedPolyResolverProof`, +`MultiDegreeSumcheckProof`, `MultipointEvalProof`, `DynamicPolynomialF`, and +`BatchedLookupProof` are existing protocol/PIOP types exposed by the baseline +e2e proof shape. The production object model should reuse the active proof +components directly. 
In particular, do not add a separate family-tag proof layer +such as `IdealFamilyId`, `IdealPolySlot`, `IdealFamilyPolys`, or +`IdealFamilyPoly` just to carry the batched ideal polynomials. + +`lookup_proof` is currently a forward-compatible stub in the e2e proof shape. +The prover sets it to `None`, serialization skips it, and the verifier only +carries it through. Production SHA currently has empty `lookup_specs`, so the +production SHA wrapper omits this field. + +The current e2e Rust type calls the serialized PCS opening transcript `zip` +because the first backend was Zip+. Semantically this field is the PCS opening +proof. The production object model should not expose this as Zip-specific state. + +Production folding may need a thin wrapper around this shape because it also has +an instance-axis SumFold/NIFS proof and multiple fresh commitments. Those extra +fields should reuse existing types: + +```rust +pub struct ProductionLinearIdealFoldProof +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub instance_commitments: Vec>, + pub ideal_check: IdealCheckProof, + pub sumfold_proof: MultiDegreeSumcheckProof, + pub resolver: CombinedPolyResolverProof, + pub combined_sumcheck: MultiDegreeSumcheckProof, + pub multipoint_eval: MultipointEvalProof, + pub witness_lifted_evals: Vec>, + pub opening_proof: PCSOpeningProof, +} +``` + +This is intentionally a field-level reuse of the e2e proof object, not a new +PIOP object model. The production wrapper belongs in `protocol`, if needed; it +should not introduce new generic proof structs in `piop/src/neutron_nova`. + +The PCS backend remains generic through `ZincPCSTypes` and the component `PCS` +implementations in `zip-plus/src/pcs/generic.rs`. Do not introduce a separate +SHA-specific PCS trait for the proof object. The proof object only needs the +associated commitment and opening-proof types. 
Production code that actually +folds commitments should put any homomorphic-folding requirements directly on +the component PCS types at the prover/verifier function boundary. + +```rust +pub struct PCSOpeningProof +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: <

>::BinaryPCS + as PCS, D>>::OpeningProof, + pub arbitrary: <

>::ArbitraryPCS + as PCS, D>>::OpeningProof, + pub int: <

>::IntPCS + as PCS>::OpeningProof, +} +``` + +That requires the PCS trait to expose the opening proof as an associated type: + +```rust +pub trait PCS: Clone + Debug + Send + Sync +where + F: PrimeField, + Eval: Clone + Debug + Send + Sync, +{ + type CommitmentKey: Clone + Debug + Send + Sync; + type VerifierKey: Clone + Debug + Send + Sync; + type Commitment: Clone + Debug + Send + Sync; + type ProverData: Clone + Debug + Send + Sync; + type OpeningProof: Clone + Debug + Send + Sync + Default; + + fn prove_open( + transcript: &mut PcsProverTranscript, + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + point: &[F], + prover_data: &Self::ProverData, + field_cfg: &F::Config, + ) -> Result; + + fn verify_open( + transcript: &mut PcsVerifierTranscript, + vk: &Self::VerifierKey, + commitment: &Self::Commitment, + point: &[F], + lifted_evals: &[DynamicPolynomialF], + opening_proof: &Self::OpeningProof, + field_cfg: &F::Config, + ) -> Result<(), ZipError>; +} +``` + +The current trait writes and reads opening data through PCS transcripts and +returns `Result<(), ZipError>`. That should be treated as the current adapter +shape, not the production proof-object shape. Zip+ can set +`type OpeningProof = Vec` while Hyrax or any future PCS can use its native +typed proof. + +The aggregate ideal component from `ProjectionFold Concise` should be carried by +the existing ideal-check proof: + +```rust +pub struct IdealCheckProof { + pub combined_mle_values: Vec>, +} +``` + +The family/order information is setup data, not a new proof object. 
For +production SHA-256, `verify_setup` fixes the canonical mapping from entries of +`combined_mle_values` to the nonzero ideal families: + + ℱ_≠0 = {R₀, R₁, R₄, R₅, R₆, R₉, R₁₀} + +The compact production interpretation is: + + ideal_check.combined_mle_values[f] = Ē_f^β(X) + for f ∈ ℱ_≠0 in setup-defined order + +If the generic e2e verifier path is used unchanged, the vector length/order must +match `U::verify_as_subprotocol` and `count_constraints::()`. If the +production verifier uses the seven-family compact form, that compact mapping is +part of `verify_setup`; the carrier is still `IdealCheckProof`. + +The honest aggregate polynomial is: + + Ē_f^β(X) + = ∑_{b ∈ {0,1}^ℓ} eq(β,b) + ∑_{z ∈ H_row} eq(r_ic,z) · C_f(z,X;w_b,y_b) + +In the production transcript, r_ic and β are sampled after binding VS and the +fresh instances, and before E_agg is read. `ProjectionFold Concise` owns the +full Fiat-Shamir sequence. + +Thus the submitted ideal component is already batched over both verifier-visible +axes: + + instance axis b via eq(β,b) + row axis z via eq(r_ic,z) + +The verifier still does not trust these aggregate polynomials blindly. It checks +the shape-level degree bound and ideal membership for each family: + + deg_X Ē_f^β(X) < δ_f + Ē_f^β(X) ∈ I_f + +After accepting and absorbing the aggregate polynomials, the verifier samples +the scalarization and family-batching challenges and computes the initial +NIFS/SumFold claim: + + C₀ = ∑_{f ∈ ℱ_≠0} λ^f · Ē_f^β(a) + +The later NIFS, row sumcheck, terminal reconstruction, multipoint reduction, and +PCS opening bind this same scalar to the folded commitments and public trace. + +The folded verifier-visible instance is the result of folding fresh instances. 
+It contains the folded target claim, folded commitments, and any folded public +values needed by the later checks: + +```rust +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FoldedLinearIdealInstance { + pub target: F, + pub commitments: Commitments, + pub public: Public, +} +``` + +The folded witness is prover-private. Its concrete representation can be an +owned folded trace, folded source MLEs, or another protocol-specific witness +bundle. In the generic UAIR model, this is normally a +`FoldedUairWitness`: + +```rust +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FoldedLinearIdealWitness { + pub witness: Witness, +} +``` + +Here `OpeningWitness` is the PCS/backend-specific prover-only state needed to +open the folded commitments. + +The ideal-check proof is the production `E_agg` component from +`ProjectionFold Concise`, represented with `IdealCheckProof`. It contains the +seven SHA-256 aggregate ideal polynomials for the nonzero families, or the +shape-defined analogue for another UAIR. These polynomials are verifier-visible +and must be absorbed before sampling a, λ, ρ, ξ. + +`sumfold_proof` is the instance-axis `MultiDegreeSumcheckProof`. It proves the +SumFold transition from the verifier-computed C₀ to the folded target. The +verifier derives r_b, folding weights θ_b, and T′: + + θ_b = eq(r_b,b) + + T′ = c_SF / eq(β,r_b) + +`resolver` and `combined_sumcheck` are the same terminal-reconstruction objects +used by e2e step 4. 
`combined_sumcheck` reduces the folded row claim to r⋆, and +`resolver` carries the terminal evaluations needed to close the combined +polynomial resolver: + +```rust +pub struct CombinedPolyResolverProof { + pub up_evals: Vec, + pub down_evals: Vec, + pub bit_slice_evals: Vec, + pub bit_op_down_evals: Vec, + pub shifted_bit_slice_evals: Vec, +} +``` + +Together they prove that the folded target is the row-domain sum of the folded +residue expression: + + T′ = ∑_{x ∈ {0,1}^d} eq(r_ic,x) · Φ_folded(x) + +This sumcheck reduces the folded claim to a terminal point r⋆. + +The verifier uses `resolver` to check: + + terminal = eq(r_ic,r⋆) · Φ_folded(r⋆) + +`multipoint_eval` is the existing e2e multipoint proof. It reduces all terminal +evaluation claims at r⋆ and shifted points into one batched opening claim at a +verifier-derived point r₀: + + { p_i(s_i(r⋆)) = v_i }_i ⇒ P(r₀) = v₀ + +`witness_lifted_evals` is the existing e2e opening-evaluation carrier. Its entries +are witness-only lifted MLE evaluations at r₀ in F_q[X], ordered as +`[wit_bin..., wit_arb..., wit_int...]`. The verifier recomputes public lifted +evals from the public trace, interleaves public and witness lifted evals, +derives scalar `open_evals` by ψ_a, derives bit-op virtual opens locally, and +checks the `multipoint_eval` subclaim. The serialized PCS opening proof is +`opening_proof`. 
+ +The proof chain is: + +```text +ideal_check + → check ideal membership and derive C₀ +sumfold_proof + → fold C₀ into T′ +resolver + combined_sumcheck + → reduce T′ to terminal point r⋆ + → reconstruct the terminal folded expression +multipoint_eval + → reduce many endpoint claims to one opening point r₀ +witness_lifted_evals + opening_proof + → prove consistency with folded commitments +``` + +## SHA-256 Domain Objects + +For one SHA-256 compression, the semantic relation is: + +```text +H_{i+1} = compress(H_i, M_i) +``` + +where: + +```text +H_i: [u32; 8] +M_i: [u32; 16] +H_{i+1}: [u32; 8] +``` + +For a chain of `N` compressions: + +```rust +pub struct Sha256ChainPublicInput<const N: usize> { + pub initial_state: [u32; 8], + pub message_blocks: [[u32; 16]; N], + pub final_state: [u32; 8], +} +``` + +For standard SHA-256 hashing from the fixed IV, use a wrapper: + +```rust +pub struct Sha256HashPublicInput<const N: usize> { + pub message_blocks: [[u32; 16]; N], + pub digest: [u32; 8], +} +``` + +This wrapper expands into `Sha256ChainPublicInput` by setting +`initial_state = SHA256_IV`. + +The current e2e SHA UAIR packs multiple compressions into a single trace. It +does not pass intermediate states as separate public inputs. Witness generation +computes: + +```text +H_1 = compress(H_0, M_0) +H_2 = compress(H_1, M_1) +... +H_N = compress(H_{N-1}, M_{N-1}) +``` + +and writes the relevant public values into public UAIR columns. + +## SHA-256 to UAIR Mapping + +The SHA domain input is not the same thing as the UAIR public trace. + +The SHA public input: + +```rust +Sha256ChainPublicInput { + initial_state, + message_blocks, + final_state, +} +``` + +is used to build the UAIR public trace. 
For the current SHA UAIR this includes +columns such as: + +- `PA_M`: message block words +- `PA_A` / `PA_E`: chaining states and final output prefix +- `PA_K`: SHA-256 round constants +- selector columns +- implementation-specific public helper columns + +The prover also builds the full `UairWitness` trace containing the public +columns plus private/witness columns such as: + +- message schedule `W` +- round state columns +- sigma/Sigma columns +- Ch/Maj auxiliary columns +- carry columns +- compensator columns + +The verifier should build or receive only the public trace, then verify the +proof against commitments to the private columns. + +## Recommended Production Flow + +Witness generation is outside the prover. It consumes semantic public input and +produces a full pre-projection UAIR witness: + +```rust +pub fn build_uair_witness( + shape: &UairShape, + public: &Input, +) -> Result, UairWitnessError> +where + U: Uair, + PolyCoeff: Clone, + Int: Clone; +``` + +The SHA-256 instantiation can be written more concretely as: + +```rust +pub fn build_sha256_witness( + shape: &UairShape>, + public: &Sha256ChainPublicInput, +) -> Result, ShaWitnessError> +where + Zt: ZincTypes; +``` + +The prover receives witnesses and commits to the witness columns internally. 
It +returns fresh verifier-visible instances, the folded accumulator pair, and the +proof: + +```rust +pub struct LinearIdealFoldProveOutput { + pub fresh_instances: Vec, + pub folded_instance: FoldedInstance, + pub folded_witness: FoldedWitness, + pub proof: Proof, +} + +pub fn prove_linear_ideal_fold( + pp: &LinearIdealFoldProverParams, + shape: &UairShape, + witnesses: &[UairWitness<'_, Zt::Int, Zt::Int, D>], + transcript: &mut impl Transcript, +) -> Result< + LinearIdealFoldProveOutput< + UairInstance<'static, Zt::Int, Zt::Int, PCSCommitments, D>, + FoldedLinearIdealInstance, ProjectedPublic>, + FoldedLinearIdealWitness>, + ProductionLinearIdealFoldProof, + >, + LinearIdealFoldError, +> +where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes; +``` + +The SHA production folded witness keeps the structured folded SHA trace and the +folded PCS opening witness: + +```rust +pub struct ProductionShaFoldedWitness +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub trace: ProjectedTrace, + pub opening_witness: PCSProverData, +} +``` + +The folded verifier-visible public value is `ProjectedPublic`, not a flat +field vector, because the verifier needs structured SHA public columns for +terminal reconstruction and multipoint checks. + +The production verifier interface in `ProjectionFold Concise` is the acceptance +predicate: + + verify(VS, {Inst_b}_{b ∈ {0,1}^ℓ}, π) → {true, false} + +When VS is fixed by context, the shorthand is: + + verify({Inst_b}_{b ∈ {0,1}^ℓ}, π) → {true, false} + +The Rust-facing API uses the same two-step shape. Setup verification checks and +stores static material: + +```rust +pub fn setup_verify_linear_ideal_fold( + params: LinearIdealFoldVerifierParams, + shape: UairShape, +) -> Result, LinearIdealFoldError> +where + U: Uair + ProductionShaProjectionAdapter, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes; +``` + +The verifier then receives `VS`, fresh instances, and the proof. 
It derives the +same folded instance, but it never receives the folded witness: + +```rust +pub fn verify_linear_ideal_fold( + vs: &VerifiedLinearIdealFoldSetup, + instances: &[UairInstance<'_, Zt::Int, Zt::Int, PCSCommitments, D>], + proof: &ProductionLinearIdealFoldProof, + transcript: &mut impl Transcript, +) -> Result< + FoldedLinearIdealInstance, ProjectedPublic>, + LinearIdealFoldError, +> +where + U: Uair + ProductionShaProjectionAdapter, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes; +``` + +Returning `Ok(folded_instance)` means the production verifier accepts. Returning +`Err(_)` means rejection. + +```rust +Sha256ChainPublicInput + -> build public UairTrace + +Sha256ChainPublicInput + witness generation + -> full UairWitness + +UairWitness + UairShape + -> commit witness columns + -> UairInstance { public_trace, commitments } + -> prove + +UairInstance + proof + -> verify +``` + +Projection to the proof field and evaluation at verifier challenges are internal +protocol phases. They should not be part of the initial public instance type. 
diff --git a/ecc-qx-commitment.md b/ecc-qx-commitment.md new file mode 100644 index 00000000..14ee89dd --- /dev/null +++ b/ecc-qx-commitment.md @@ -0,0 +1,19 @@ +# ECC Commitment for a ℚ[X]-Valued Oracle + +So you can have a ℚ[X]-valued polynomial represented as the following + +``` +f_b(X) = Σ_{j; + +fn bench_config() -> ::Config +where + MillerRabin: PrimalityTest<::Modulus>, +{ + let mut transcript = Blake3Transcript::new(); + transcript.get_random_field_cfg::::Modulus, MillerRabin>() +} + +fn f(value: u64, cfg: &::Config) -> F { + F::from_with_cfg(value, cfg) +} + +fn mle_table_from_columns(columns: Vec>) -> MleTable { + columns + .into_iter() + .map(|evaluations| DenseMultilinearExtension { + evaluations, + num_vars: SHA_ROW_VARS, + }) + .collect() +} + +fn flatten_bits(bits: Vec>>) -> MleTable { + let mut flattened = (0..bits.len() * SHA_WORD_BITS) + .map(|_| Vec::new()) + .collect::>(); + for (col_idx, rows) in bits.into_iter().enumerate() { + for row in rows { + for (bit_idx, value) in row.into_iter().enumerate() { + flattened[bit_slice_index(col_idx, bit_idx, SHA_WORD_BITS)].push(value); + } + } + } + mle_table_from_columns(flattened) +} + +fn synthetic_boolean_trace( + instance_idx: u64, + a: &F, + cfg: &::Config, +) -> ProjectedTrace { + let zero = F::zero_with_cfg(cfg); + let mut bits = vec![vec![vec![zero.clone(); SHA_WORD_BITS]; SHA_ROW_COUNT]; ShaWordCol::COUNT]; + for (col_idx, col) in bits.iter_mut().enumerate() { + for (row_idx, row) in col.iter_mut().enumerate() { + for (bit_idx, bit) in row.iter_mut().enumerate() { + let selector = instance_idx + + u64::try_from(col_idx * 17 + row_idx * 3 + bit_idx) + .expect("bench selector fits u64"); + if selector % 2 == 1 { + *bit = f(1, cfg); + } + } + } + } + let bit_slices = flatten_bits(bits); + let scalarized = scalarize_bit_slices(&bit_slices, a, cfg).unwrap(); + ProjectedTrace { + bit_slices, + scalarized, + int_columns: mle_table_from_columns(vec![ + vec![zero.clone(); SHA_ROW_COUNT]; + 
ShaIntCol::COUNT + ]), + public_columns: mle_table_from_columns(vec![ + vec![zero; SHA_ROW_COUNT]; + ShaPublicCol::COUNT + ]), + } +} + +fn zero_public(cfg: &::Config) -> ProjectedPublic { + ProjectedPublic { + columns: mle_table_from_columns(vec![ + vec![F::zero_with_cfg(cfg); SHA_ROW_COUNT]; + ShaPublicCol::COUNT + ]), + bit_slices: None, + } +} + +fn booleanity_sources() -> Vec { + vec![ + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 0, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 7, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::E, + bit: 1, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::E, + bit: 9, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::W, + bit: 2, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::W, + bit: 13, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::Sigma0, + bit: 3, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::Sigma1, + bit: 5, + }, + ] +} + +#[allow(clippy::too_many_lines)] +fn neutron_nova_sumfold_benches(c: &mut Criterion) { + let cfg = bench_config(); + let ell = 7usize; + let prefix_vars = 2usize; + let a = f(3, &cfg); + let traces = (0..(1usize << ell)) + .map(|idx| synthetic_boolean_trace(u64::try_from(idx).unwrap(), &a, &cfg)) + .collect::>(); + let publics = vec![zero_public(&cfg); traces.len()]; + let beta = vec![ + f(5, &cfg), + f(7, &cfg), + f(11, &cfg), + f(13, &cfg), + f(17, &cfg), + f(19, &cfg), + f(37, &cfg), + ]; + let r_ic = [ + f(2, &cfg), + f(3, &cfg), + f(5, &cfg), + f(7, &cfg), + f(11, &cfg), + f(13, &cfg), + f(17, &cfg), + ]; + let lambda = f(23, &cfg); + let rho = f(29, &cfg); + let xi = f(31, &cfg); + let sources = booleanity_sources(); + + let mut group = c.benchmark_group("NeutronNova SHA SumFold"); + group.sample_size(10); + + group.bench_function(BenchmarkId::new("dense_build_and_prove", ell), |bench| { + bench.iter_batched( + Blake3Transcript::new, + |mut transcript| { + let group = build_dense_sha_sumfold_group( + 
&traces, &publics, &beta, &r_ic, &a, &lambda, &rho, &xi, &sources, &cfg, + ) + .unwrap(); + let (proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut transcript, + vec![group], + ell, + &cfg, + ); + black_box(proof.claimed_sums()[0].clone()) + }, + BatchSize::SmallInput, + ); + }); + + group.bench_function( + BenchmarkId::new("production_prefix_tail_build_and_prove", prefix_vars), + |bench| { + bench.iter_batched( + || (traces.clone().into_boxed_slice(), Blake3Transcript::new()), + |(owned_traces, mut transcript)| { + let group = build_production_sha_sumfold_group_owned( + owned_traces, + &publics, + &beta, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &sources, + prefix_vars, + &cfg, + ) + .unwrap(); + let (proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut transcript, + vec![group], + ell, + &cfg, + ); + black_box(proof.claimed_sums()[0].clone()) + }, + BatchSize::SmallInput, + ); + }, + ); + + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default().sample_size(10); + targets = neutron_nova_sumfold_benches +} +criterion_main!(benches); diff --git a/piop/src/combined_poly_resolver.rs b/piop/src/combined_poly_resolver.rs index c57bb2d0..13bc89ea 100644 --- a/piop/src/combined_poly_resolver.rs +++ b/piop/src/combined_poly_resolver.rs @@ -5,6 +5,7 @@ mod structs; pub use structs::*; +use crate::projections::ScalarMap; use crate::{ CombFn, combined_poly_resolver::{ @@ -18,12 +19,11 @@ use crate::{ prover::ProverState as SumcheckProverState, }, }; -use crypto_primitives::{FromPrimitiveWithConfig, PrimeField}; +use crypto_primitives::{FromPrimitiveWithConfig, PrimeField, crypto_bigint_uint::Uint}; use itertools::Itertools; use num_traits::Zero; #[cfg(feature = "parallel")] use rayon::prelude::*; -use crate::projections::ScalarMap; use std::{cell::RefCell, marker::PhantomData, slice}; use thiserror::Error; use zinc_poly::{ @@ -33,13 +33,20 @@ use zinc_poly::{ binary::BinaryPoly, 
dynamic::over_field::{DynamicPolyFInnerProduct, DynamicPolynomialF}, }, - utils::{ArithErrors, build_eq_x_r_inner, eq_eval}, + utils::{ArithErrors, build_eq_x_r_inner, build_eq_x_r_vec, eq_eval}, }; use zinc_transcript::traits::{ConstTranscribable, Transcript}; use zinc_uair::{BitOp, TraceRow, Uair, ideal::ImpossibleIdeal}; use zinc_utils::{ - UNCHECKED, add, cfg_iter, from_ref::FromRef, inner_product::InnerProduct, - inner_transparent_field::InnerTransparentField, powers, + UNCHECKED, add, cfg_iter, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, MontgomeryProductSum4, + }, + from_ref::FromRef, + inner_product::{FieldFieldInnerProduct, InnerProduct}, + inner_transparent_field::InnerTransparentField, + powers, }; /// Materialize the bit-op virtual MLEs given by `bit_op_specs`. @@ -125,7 +132,7 @@ where continue; } if let Some(w) = &weights[i] { - acc += w.clone(); + acc += w; } } acc.into_inner() @@ -141,9 +148,97 @@ where .collect() } +/// Evaluate bit-op virtual columns directly at `point`, without materializing +/// the full virtual MLEs. +/// +/// This is the point-only counterpart of [`build_bit_op_mles`]. It keeps the +/// hot scan addition-oriented: each set input bit adds the cached `eq_r(b)` +/// into one destination bucket with sum-only DMR, then the 32 buckets are +/// projected by `alpha` with the field product-sum backend. 
+#[allow(clippy::arithmetic_side_effects)] +pub fn compute_bit_op_evals_streaming( + trace_bin_poly: &[DenseMultilinearExtension>], + bit_op_specs: &[zinc_uair::BitOpSpec], + num_total_bin: usize, + projecting_element_f: &F, + point: &[F], + field_cfg: &F::Config, +) -> Result, ArithErrors> +where + F: InnerTransparentField + MontgomeryLimbs + DelayedFieldProductSum + Send + Sync, + F::Config: Sync, +{ + if bit_op_specs.is_empty() { + return Ok(Vec::new()); + } + + assert!( + D == 32, + "BitOpSpec virtual columns require D == 32, got D = {D}", + ); + + let zero = F::zero_with_cfg(field_cfg); + let one = F::one_with_cfg(field_cfg); + let alpha_powers: Vec = powers(projecting_element_f.clone(), one, 32); + let eq_table = build_eq_x_r_vec(point, field_cfg)?; + let reducer = BarrettDelayedReduction::::new(field_cfg); + let product_sum = MontgomeryProductSum4::::new(field_cfg); + + let evals = cfg_iter!(bit_op_specs) + .map(|spec| { + assert!( + spec.source_col() < num_total_bin, + "BitOpSpec source_col {} must reference a binary_poly column \ + (num binary cols = {num_total_bin})", + spec.source_col(), + ); + + let col = &trace_bin_poly[spec.source_col()]; + let mut buckets: Vec> = vec![Uint::zero(); 32]; + for (b, cell) in col.iter().enumerate() { + let Some(eq_b) = eq_table.get(b) else { + break; + }; + for (src_bit, coeff) in cell.iter().enumerate().take(32) { + if !coeff.into_inner() { + continue; + } + if let Some(dst_bit) = bit_op_destination(spec.op(), src_bit) { + reducer.add(&mut buckets[dst_bit], eq_b); + } + } + } + + let bucket_evals: Vec = buckets.into_iter().map(|acc| reducer.reduce(acc)).collect(); + + FieldFieldInnerProduct::inner_product_with_algorithm::( + &product_sum, + &bucket_evals, + &alpha_powers, + zero.clone(), + ) + .expect("bit-op bucket and alpha-power lengths match") + }) + .collect(); + + Ok(evals) +} + +fn bit_op_destination(op: BitOp, src_bit: usize) -> Option { + match op { + BitOp::Rot(c) => Some((src_bit + c as usize) % 32), + 
BitOp::ShiftR(c) => { + let c = c as usize; + src_bit.checked_sub(c) + } + } +} + pub struct CombinedPolyResolver(PhantomData); -impl CombinedPolyResolver { +impl + CombinedPolyResolver +{ /// Build the CPR sumcheck group for use in the multi-degree sumcheck. /// /// Pre-sumcheck half of the CPR prover. Samples the folding challenge `α`, @@ -283,8 +378,7 @@ impl CombinedP // `scalar_proj_cache` for details. Lazily initialized — UAIRs // that never invoke `from_ref`/`mbs` pay only the Option's // discriminant write per call. - let cache: RefCell>> = - RefCell::new(None); + let cache: RefCell>> = RefCell::new(None); let project = |scalar: &U::Scalar| -> F { if let Some(v) = cache.borrow().as_ref().and_then(|c| c.get(scalar)) { return v; @@ -313,7 +407,7 @@ impl CombinedP ImpossibleIdeal::from_ref, ); - folder.folded_constraints * (one.clone() - selector) * eq_r + folder.finish_folded() * (one.clone() - selector) * eq_r }); Ok(( @@ -463,13 +557,11 @@ impl CombinedP let folding_challenge_powers: Vec = powers(folding_challenge, one.clone(), num_constraints); - // TODO(Alex): investigate if parallelising this is beneficial. // Compute v_0 + \alpha * v_1 + ... + \alpha ^ k * v_k. 
- let expected_sum = ic_check_subclaim + let expected_values: Vec = ic_check_subclaim .values .iter() - .zip(&folding_challenge_powers) - .map(|(claimed_value, random_coeff)| { + .map(|claimed_value| { let deg = claimed_value.degree().map_or(0, |d| add!(d, 1)); DynamicPolyFInnerProduct::inner_product::( &claimed_value.coeffs[..deg], @@ -477,9 +569,14 @@ impl CombinedP zero.clone(), ) .expect("inner product cannot fail here") - * random_coeff }) - .fold(zero.clone(), |acc, term| acc + term); + .collect(); + let expected_sum = FieldFieldInnerProduct::inner_product::( + &expected_values, + &folding_challenge_powers[..expected_values.len()], + zero.clone(), + ) + .expect("claimed values and folding powers have matching lengths"); if claimed_sum != expected_sum { return Err(CombinedPolyResolverError::WrongSumcheckSum { @@ -573,17 +670,13 @@ impl CombinedP &proof.up_evals, uair_sig.total_cols().as_column_layout(), ), - TraceRow::from_slice_with_layout_and_bit_op( - &down_combined, - down_layout, - bit_op_count, - ), + TraceRow::from_slice_with_layout_and_bit_op(&down_combined, down_layout, bit_op_count), project, |x, y| Some(project(y) * x), ImpossibleIdeal::from_ref, ); - let expected_claim_value = eq_r_value * (one - selector_value) * folder.folded_constraints; + let expected_claim_value = eq_r_value * (one - selector_value) * folder.finish_folded(); if expected_claim_value != expected_evaluation { return Err(CombinedPolyResolverError::ClaimValueDoesNotMatch { @@ -667,7 +760,9 @@ mod tests { sumcheck::multi_degree::MultiDegreeSumcheck, test_utils::{LIMBS, run_ideal_check_prover_combined, test_config}, }; - use crypto_primitives::{crypto_bigint_int::Int, crypto_bigint_monty::MontyField}; + use crypto_primitives::{ + FromWithConfig, crypto_bigint_int::Int, crypto_bigint_monty::MontyField, + }; use rand::rng; use zinc_poly::univariate::dense::DensePolynomial; use zinc_test_uair::{ @@ -675,6 +770,7 @@ mod tests { }; use zinc_transcript::Blake3Transcript; use zinc_uair::{ + 
BitOpSpec, constraint_counter::count_constraints, degree_counter::count_max_degree, ideal::{DegreeOneIdeal, Ideal, IdealCheck}, @@ -685,6 +781,69 @@ mod tests { // Once we have time we need to create a comprehensive test suite // akin to the one we have for the PCS or the sumcheck. + fn binary_col_from_u32s(patterns: &[u32]) -> DenseMultilinearExtension> { + DenseMultilinearExtension::from_evaluations_vec( + patterns.len().next_power_of_two().trailing_zeros() as usize, + patterns + .iter() + .copied() + .map(BinaryPoly::<32>::from) + .collect(), + BinaryPoly::<32>::zero(), + ) + } + + fn assert_bit_op_streaming_matches_materialized(spec: BitOpSpec) { + let cfg = test_config(); + let trace_bin_poly = vec![binary_col_from_u32s(&[ + 0x0000_0001, + 0x8000_0001, + 0x0f0f_00f0, + 0xf000_00ff, + ])]; + let specs = vec![spec]; + let projecting_element = MontyField::<4>::from_with_cfg(7u64, &cfg); + let point = vec![ + MontyField::<4>::from_with_cfg(3u64, &cfg), + MontyField::<4>::from_with_cfg(5u64, &cfg), + ]; + + let materialized = build_bit_op_mles::, 32>( + &trace_bin_poly, + &specs, + 1, + &projecting_element, + point.len(), + &cfg, + ) + .into_iter() + .map(|mle| mle.evaluate_with_config(&point, &cfg)) + .collect::, _>>() + .unwrap(); + + let streaming = compute_bit_op_evals_streaming::, 32>( + &trace_bin_poly, + &specs, + 1, + &projecting_element, + &point, + &cfg, + ) + .unwrap(); + + assert_eq!(streaming, materialized); + } + + #[test] + fn bit_op_streaming_rot_matches_materialized_mle() { + assert_bit_op_streaming_matches_materialized(BitOpSpec::new(0, BitOp::Rot(7))); + } + + #[test] + fn bit_op_streaming_shift_r_matches_materialized_mle() { + assert_bit_op_streaming_matches_materialized(BitOpSpec::new(0, BitOp::ShiftR(5))); + } + fn test_successful_verification_generic< U, IdealOverF, @@ -735,22 +894,23 @@ mod tests { project_scalars_to_field(projected_scalars, &projecting_element).unwrap(); // Prover: prepare → MultiDegreeSumcheck → finalize - let (cpr_group, 
cpr_ancillary) = CombinedPolyResolver::prepare_sumcheck_group::( - &mut prover_transcript, - evaluate_trace_to_column_mles( - &ProjectedTrace::RowMajor(projected_trace), + let (cpr_group, cpr_ancillary) = + CombinedPolyResolver::prepare_sumcheck_group::( + &mut prover_transcript, + evaluate_trace_to_column_mles( + &ProjectedTrace::RowMajor(projected_trace), + &projecting_element, + ), + &ic_prover_state.evaluation_point, + &projected_scalars, + num_constraints, + num_vars, + max_degree, + &test_config(), + &trace.binary_poly, &projecting_element, - ), - &ic_prover_state.evaluation_point, - &projected_scalars, - num_constraints, - num_vars, - max_degree, - &test_config(), - &trace.binary_poly, - &projecting_element, - ) - .expect("CPR prepare failed"); + ) + .expect("CPR prepare failed"); let (md_proof, states) = MultiDegreeSumcheck::prove_as_subprotocol( &mut prover_transcript, diff --git a/piop/src/combined_poly_resolver/folder.rs b/piop/src/combined_poly_resolver/folder.rs index 4a9afb91..92d2d5e5 100644 --- a/piop/src/combined_poly_resolver/folder.rs +++ b/piop/src/combined_poly_resolver/folder.rs @@ -1,5 +1,10 @@ use crypto_primitives::PrimeField; use zinc_uair::{ConstraintBuilder, ideal::ImpossibleIdeal}; +use zinc_utils::{ + UNCHECKED, + delayed_reduction::DelayedFieldProductSum, + inner_product::{FieldFieldInnerProduct, InnerProduct}, +}; /// There are several situations where we need to /// compute an RLC `u_0 + \alpha * u_1 + ... + \alpha ^ k * u_k`, @@ -18,31 +23,46 @@ use zinc_uair::{ConstraintBuilder, ideal::ImpossibleIdeal}; /// /// This constraint builder handles those situations. /// It's `Expr` associated type is the field `F`, so once -/// an `assert_*` method is called it adds it to the RLC -/// with the next power of the challenge `\alpha`. +/// an `assert_*` method is called it records the residual in order. +/// Call [`ConstraintFolder::finish_folded`] to compute the RLC with the +/// DMR-aware field-field product-sum backend. 
pub struct ConstraintFolder<'a, F: PrimeField> { /// A reference to precomputed powers of the challenge. challenge_powers: &'a [F], - /// Index of the current constraint, - /// and therefore the current power of the challenge. - current_constraint: usize, - /// The RLC computed so far. - pub folded_constraints: F, + /// Residuals in the exact order constraints were visited. + residuals: Vec, + /// Additive identity used as the product-sum seed. + zero: F, } impl<'a, F: PrimeField> ConstraintFolder<'a, F> { pub fn new(challenge_powers: &'a [F], zero: &F) -> Self { Self { challenge_powers, - current_constraint: 0, - folded_constraints: zero.clone(), + residuals: Vec::with_capacity(challenge_powers.len()), + zero: zero.clone(), } } #[allow(clippy::arithmetic_side_effects)] fn fold_constraint(&mut self, expr: F) { - self.folded_constraints += expr * &self.challenge_powers[self.current_constraint]; - self.current_constraint += 1; + debug_assert!( + self.residuals.len() < self.challenge_powers.len(), + "more constraint residuals than challenge powers" + ); + self.residuals.push(expr); + } + + pub fn finish_folded(self) -> F + where + F: DelayedFieldProductSum, + { + FieldFieldInnerProduct::inner_product::( + &self.residuals, + &self.challenge_powers[..self.residuals.len()], + self.zero, + ) + .expect("constraint residuals and challenge powers have matching lengths") } } @@ -72,3 +92,71 @@ impl<'a, F: PrimeField> ConstraintBuilder for ConstraintFolder<'a, F> { self.fold_constraint(expr); } } + +#[cfg(test)] +mod tests { + use super::*; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use zinc_uair::ConstraintBuilder; + + type F = MontyField<4>; + + fn cfg() -> ::Config { + crate::test_utils::test_config() + } + + fn f(value: u64) -> F { + F::from_with_cfg(value, &cfg()) + } + + fn naive_fold(residuals: &[F], powers: &[F], zero: F) -> F { + residuals + .iter() + .zip(powers) + .fold(zero, |acc, (residual, power)| { + acc + residual.clone() * 
power + }) + } + + #[test] + fn finish_folded_matches_naive_empty() { + let cfg = cfg(); + let zero = F::zero_with_cfg(&cfg); + let powers = vec![f(1), f(7), f(49)]; + let folder = ConstraintFolder::new(&powers, &zero); + + assert_eq!(folder.finish_folded(), zero); + } + + #[test] + fn finish_folded_matches_naive_single_constraint() { + let cfg = cfg(); + let zero = F::zero_with_cfg(&cfg); + let powers = vec![f(1), f(7), f(49)]; + let residuals = vec![f(11)]; + let mut folder = ConstraintFolder::new(&powers, &zero); + folder.assert_zero(residuals[0].clone()); + + assert_eq!( + folder.finish_folded(), + naive_fold(&residuals, &powers, zero) + ); + } + + #[test] + fn finish_folded_matches_naive_multiple_constraints() { + let cfg = cfg(); + let zero = F::zero_with_cfg(&cfg); + let powers = vec![f(1), f(7), f(49), f(343)]; + let residuals = vec![f(3), f(5), f(8), f(13)]; + let mut folder = ConstraintFolder::new(&powers, &zero); + for residual in &residuals { + folder.assert_zero(residual.clone()); + } + + assert_eq!( + folder.finish_folded(), + naive_fold(&residuals, &powers, zero) + ); + } +} diff --git a/piop/src/combined_poly_resolver/structs.rs b/piop/src/combined_poly_resolver/structs.rs index 0c1eab26..9b6d80b1 100644 --- a/piop/src/combined_poly_resolver/structs.rs +++ b/piop/src/combined_poly_resolver/structs.rs @@ -58,7 +58,9 @@ where let buf = self.down_evals.write_transcription_bytes_subset(buf); let buf = self.bit_slice_evals.write_transcription_bytes_subset(buf); let buf = self.bit_op_down_evals.write_transcription_bytes_subset(buf); - let buf = self.shifted_bit_slice_evals.write_transcription_bytes_subset(buf); + let buf = self + .shifted_bit_slice_evals + .write_transcription_bytes_subset(buf); assert!(buf.is_empty(), "Entire buffer should be used"); } } diff --git a/piop/src/ideal_check.rs b/piop/src/ideal_check.rs index 9c8d51fc..4f1ba7d4 100644 --- a/piop/src/ideal_check.rs +++ b/piop/src/ideal_check.rs @@ -3,13 +3,13 @@ mod batched_ideal_check; mod 
combined_poly_builder; mod structs; +pub use batched_ideal_check::{BatchedIdealCheckError, batched_ideal_check}; pub use structs::*; #[cfg(feature = "parallel")] use rayon::prelude::*; use crate::projections::{ColumnMajorTrace, RowMajorTrace, ScalarMap}; -use batched_ideal_check::*; use crypto_primitives::PrimeField; use num_traits::ConstZero; use thiserror::Error; @@ -406,6 +406,12 @@ where let mut transcription_buf: Vec = vec![0; F::Inner::NUM_BYTES]; let combined_mle_values = proof.combined_mle_values; + if combined_mle_values.len() != num_constraints { + return Err(IdealCheckError::ProofValueCount { + got: combined_mle_values.len(), + expected: num_constraints, + }); + } let evaluation_point = transcript.get_field_challenges(num_vars, field_cfg); @@ -445,6 +451,8 @@ pub enum IdealCheckError { IdealCollectorError(#[from] BatchedIdealCheckError, I>), #[error("`eq` polynomial construction failure: {0}")] EqPolyConstructionError(#[from] PolyArithErrors), + #[error("ideal-check proof value count mismatch: got {got}, expected {expected}")] + ProofValueCount { got: usize, expected: usize }, } #[cfg(test)] @@ -578,4 +586,36 @@ mod tests { |_ideal_over_ring| IdealOrZero::>::zero(), ); } + + #[test] + fn verifier_rejects_truncated_proof_values() { + let field_cfg = test_config(); + let num_vars = 2; + let mut rng = rng(); + let transcript = Blake3Transcript::new(); + type U = TestUairNoMultiplication>; + + let (mut proof, ..) 
= run_ideal_check_prover_linear::( + num_vars, + &U::generate_random_trace(num_vars, &mut rng), + &mut transcript.clone(), + ); + let num_constraints = count_constraints::(); + proof.combined_mle_values.pop(); + + let result = U::verify_as_subprotocol( + &mut transcript.clone(), + proof, + num_constraints, + num_vars, + |ideal_over_ring| ideal_over_ring.map(|i| DegreeOneIdeal::from_with_cfg(i, &field_cfg)), + &field_cfg, + ); + + assert!(matches!( + result, + Err(IdealCheckError::ProofValueCount { got, expected }) + if got + 1 == expected && expected == num_constraints + )); + } } diff --git a/piop/src/ideal_check/combined_poly_builder.rs b/piop/src/ideal_check/combined_poly_builder.rs index 955121eb..183aaac4 100644 --- a/piop/src/ideal_check/combined_poly_builder.rs +++ b/piop/src/ideal_check/combined_poly_builder.rs @@ -6,7 +6,7 @@ use crate::{ scalar_proj_cache::ScalarProjCache, }; use crypto_primitives::PrimeField; -use num_traits::{ConstZero, Zero}; +use num_traits::Zero; use std::cell::RefCell; use zinc_poly::{ EvaluationError, diff --git a/piop/src/lib.rs b/piop/src/lib.rs index b1475cb9..8872f821 100644 --- a/piop/src/lib.rs +++ b/piop/src/lib.rs @@ -2,6 +2,7 @@ pub mod combined_poly_resolver; pub mod ideal_check; pub mod lookup; pub mod multipoint_eval; +pub mod neutron_nova; pub mod projections; pub mod random_field_sumcheck; pub mod scalar_proj_cache; diff --git a/piop/src/lookup/booleanity.rs b/piop/src/lookup/booleanity.rs index 0de0f1fb..617874c2 100644 --- a/piop/src/lookup/booleanity.rs +++ b/piop/src/lookup/booleanity.rs @@ -13,7 +13,9 @@ //! `max_degree + 3`. For SHA-style UAIRs (max_degree ≥ 6) with hundreds //! of bit-slice MLEs, this is a 2–2.5× saving on step 4 alone. 
-use crypto_primitives::{FromPrimitiveWithConfig, PrimeField, semiring::boolean::Boolean}; +use crypto_primitives::{ + FromPrimitiveWithConfig, PrimeField, crypto_bigint_uint::Uint, semiring::boolean::Boolean, +}; use num_traits::Zero; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -31,7 +33,14 @@ use zinc_uair::{ VirtualBoolSpec, }; use zinc_utils::{ - cfg_into_iter, cfg_iter, inner_transparent_field::InnerTransparentField, powers, + UNCHECKED, cfg_into_iter, cfg_iter, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, + }, + inner_product::{FieldFieldInnerProduct, InnerProduct}, + inner_transparent_field::InnerTransparentField, + powers, }; /// Build the F::Inner-valued shifted bit-slice MLEs for each @@ -88,15 +97,11 @@ where /// sumcheck point `r*`. Equivalent to /// `build_shifted_bit_slice_mles(...).iter().map(evaluate_at(r*))`, /// but skips materializing the `num_shifted_specs · D` F::Inner-valued -/// MLE buffers (~`num_shifted · D · n` F::Inner allocations). Builds -/// the size-`n` `eq(r*, ·)` table once and accumulates per-bit sums -/// directly from the `BinaryPoly` trace columns in a single pass -/// per spec (t outer, bits inner — avoids `iter().nth(bit_idx)`'s -/// linear cost on custom binary-poly iterators). +/// MLE buffers (~`num_shifted · D · n` F::Inner allocations). /// -/// Used when the prover doesn't otherwise need the materialized bit- -/// slice MLEs (i.e. when no `VirtualBoolSpec` is registered — the -/// `VirtualBinaryPolySpec` path reads source binary_polys directly). +/// The hot per-bit sums are accumulated with delayed modular reduction: +/// each bit accumulator is a small `Uint<5>` integer and is reduced +/// once at the end of the spec scan. 
#[allow(clippy::arithmetic_side_effects)] pub fn compute_shifted_bit_slice_evals_streaming( trace_witness_binary_poly: &[DenseMultilinearExtension>], @@ -105,39 +110,36 @@ pub fn compute_shifted_bit_slice_evals_streaming( field_cfg: &F::Config, ) -> Result, ArithErrors> where - F: PrimeField + Send + Sync, + F: PrimeField + MontgomeryLimbs + Send + Sync, F::Config: Sync, { if shifted_specs.is_empty() { return Ok(Vec::new()); } - // Single shared eq table — one O(n) pass + O(n) memory across all - // (spec, bit) sums. let eq_table = build_eq_x_r_vec(point, field_cfg)?; + let reducer = BarrettDelayedReduction::::new(field_cfg); - let zero = F::zero_with_cfg(field_cfg); let out: Vec = cfg_iter!(shifted_specs) .flat_map(|spec| { let col = &trace_witness_binary_poly[spec.witness_col_idx]; let shift = spec.shift_amount; let n = col.evaluations.len(); - // Per-bit accumulators, one F-element each. - let mut accs: Vec = vec![zero.clone(); D]; + let mut accs: Vec> = vec![Uint::zero(); D]; for t in 0..n { let src_t = t.checked_add(shift).filter(|&v| v < n); if let Some(s) = src_t { let bp = &col.evaluations[s]; let eq_t = &eq_table[t]; - // Walk bits in their stored order; the iterator - // visits each coefficient once in O(D). 
for (bit_idx, coeff) in bp.iter().enumerate() { if coeff.into_inner() { - accs[bit_idx] = accs[bit_idx].clone() + eq_t.clone(); + reducer.add(&mut accs[bit_idx], eq_t); } } } } - accs + accs.into_iter() + .map(|acc| reducer.reduce(acc)) + .collect::>() }) .collect(); Ok(out) @@ -218,19 +220,15 @@ where VirtualBoolSource::SelfBitSlice { witness_col_idx, bit_idx, - } => &self_bit_slices[*witness_col_idx * D + *bit_idx] - .evaluations, + } => &self_bit_slices[*witness_col_idx * D + *bit_idx].evaluations, VirtualBoolSource::ShiftedBitSlice { shifted_spec_idx, bit_idx, - } => &shifted_bit_slice_mles - [*shifted_spec_idx * D + *bit_idx] - .evaluations, + } => &shifted_bit_slice_mles[*shifted_spec_idx * D + *bit_idx].evaluations, VirtualBoolSource::PublicBitSlice { public_col_idx, bit_idx, - } => &public_bit_slices[*public_col_idx * D + *bit_idx] - .evaluations, + } => &public_bit_slices[*public_col_idx * D + *bit_idx].evaluations, VirtualBoolSource::IntCol { witness_col_idx } => { &int_witness_cols[*witness_col_idx].evaluations } @@ -242,43 +240,38 @@ where match *coeff { 1 => { for t in 0..n { - evals[t] = - F::add_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::add_inner(&evals[t], &src[t], field_cfg); } } -1 => { for t in 0..n { - evals[t] = - F::sub_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::sub_inner(&evals[t], &src[t], field_cfg); } } 2 => { for t in 0..n { - evals[t] = - F::add_inner(&evals[t], &src[t], field_cfg); - evals[t] = - F::add_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::add_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::add_inner(&evals[t], &src[t], field_cfg); } } -2 => { for t in 0..n { - evals[t] = - F::sub_inner(&evals[t], &src[t], field_cfg); - evals[t] = - F::sub_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::sub_inner(&evals[t], &src[t], field_cfg); + evals[t] = F::sub_inner(&evals[t], &src[t], field_cfg); } } c => { for t in 0..n { - let term = - apply_coeff_inner::(c, &src[t], field_cfg); - evals[t] = - 
F::add_inner(&evals[t], &term, field_cfg); + let term = apply_coeff_inner::(c, &src[t], field_cfg); + evals[t] = F::add_inner(&evals[t], &term, field_cfg); } } } } - DenseMultilinearExtension { evaluations: evals, num_vars } + DenseMultilinearExtension { + evaluations: evals, + num_vars, + } }) .collect() } @@ -314,13 +307,11 @@ where VirtualBoolSource::ShiftedBitSlice { shifted_spec_idx, bit_idx, - } => &shifted_bit_slice_evals - [*shifted_spec_idx * D + *bit_idx], + } => &shifted_bit_slice_evals[*shifted_spec_idx * D + *bit_idx], VirtualBoolSource::PublicBitSlice { public_col_idx, bit_idx, - } => &public_bit_slice_evals - [*public_col_idx * D + *bit_idx], + } => &public_bit_slice_evals[*public_col_idx * D + *bit_idx], VirtualBoolSource::IntCol { witness_col_idx } => { &int_witness_up_evals[*witness_col_idx] } @@ -456,7 +447,10 @@ where } evaluations.push(BinaryPoly::::new(coeffs.as_slice())); } - DenseMultilinearExtension { evaluations, num_vars } + DenseMultilinearExtension { + evaluations, + num_vars, + } }) .collect() } @@ -729,9 +723,8 @@ where let r1_inner = r1.inner().clone(); let one_minus_r1_inner = one_minus_r1.inner().clone(); - let mut mles: Vec> = Vec::with_capacity( - 1 + self.binary_cols.len() * D + self.extra_bit_cols.len(), - ); + let mut mles: Vec> = + Vec::with_capacity(1 + self.binary_cols.len() * D + self.extra_bit_cols.len()); mles.push(eq_folded); // BinaryPoly does not impl Index on every backend @@ -867,7 +860,12 @@ pub fn prepare_booleanity_group( field_cfg: &F::Config, ) -> Result, BooleanityProverAncillary)>, BooleanityError> where - F: InnerTransparentField + FromPrimitiveWithConfig + Send + Sync + 'static, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, F::Inner: ConstTranscribable + Send + Sync + Zero + Default + Clone, F::Modulus: ConstTranscribable, { @@ -880,8 +878,7 @@ where let one = F::one_with_cfg(field_cfg); let folding_challenge: F = 
transcript.get_field_challenge(field_cfg); - let folding_challenge_powers: Vec = - powers(folding_challenge, one.clone(), num_bit_slices); + let folding_challenge_powers: Vec = powers(folding_challenge, one.clone(), num_bit_slices); // Pre-build E_other = eq(b', ic_evaluation_point[1..]) for the // round-1 fast path. The full-size eq_r is only needed for rounds @@ -914,11 +911,13 @@ where // Σ_k α^k · v_k · (v_k - 1) computed as Σ_k α^k · (v_k² - v_k) to // avoid a per-iteration `(v - one)` clone. - let mut acc = zero.clone(); - for (v, coeff) in bits.iter().zip(folding_challenge_powers.iter()) { - let v_sq = v.clone() * v.clone(); - acc = acc + coeff.clone() * (v_sq - v.clone()); - } + let residuals = booleanity_residuals(bits); + let acc = FieldFieldInnerProduct::inner_product::( + &folding_challenge_powers, + &residuals, + zero.clone(), + ) + .expect("booleanity residuals and powers have matching lengths"); acc * eq_r.clone() }); @@ -1035,7 +1034,7 @@ pub fn finalize_booleanity_verifier( field_cfg: &F::Config, ) -> Result<(), BooleanityError> where - F: InnerTransparentField, + F: InnerTransparentField + DelayedFieldProductSum, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable, { @@ -1058,14 +1057,18 @@ where let eq_r_value = eq_eval(shared_point, &ancillary.ic_evaluation_point, one.clone())?; let n_proof = bit_slice_evals.len() - closing_overrides_tail.len(); - let bool_folded = bit_slice_evals[..n_proof] + let values: Vec = bit_slice_evals[..n_proof] .iter() .chain(closing_overrides_tail.iter()) - .zip(ancillary.folding_challenge_powers.iter()) - .fold(zero, |acc, (v, coeff)| { - let v_sq = v.clone() * v.clone(); - acc + coeff.clone() * (v_sq - v.clone()) - }); + .cloned() + .collect(); + let residuals = booleanity_residuals(&values); + let bool_folded = FieldFieldInnerProduct::inner_product::( + &ancillary.folding_challenge_powers, + &residuals, + zero, + ) + .expect("booleanity residuals and powers have matching lengths"); let recomputed = 
bool_folded * eq_r_value; @@ -1095,7 +1098,7 @@ where /// each bit-slice eval against the true bit-decomposition of the /// committed parent column. #[allow(clippy::arithmetic_side_effects)] -pub fn verify_bit_decomposition_consistency( +pub fn verify_bit_decomposition_consistency( parent_evals_per_col: &[F], bit_slice_evals: &[F], projecting_element: &F, @@ -1114,24 +1117,15 @@ pub fn verify_bit_decomposition_consistency( let zero = F::zero_with_cfg(projecting_element.cfg()); let one = F::one_with_cfg(projecting_element.cfg()); - - // Powers [1, a, a^2, ..., a^{bits_per_col - 1}]. - let mut a_powers: Vec = Vec::with_capacity(bits_per_col); - let mut acc = one; - for _ in 0..bits_per_col { - a_powers.push(acc.clone()); - acc *= projecting_element; - } + let a_powers: Vec = powers(projecting_element.clone(), one, bits_per_col); for (col_idx, parent_eval) in parent_evals_per_col.iter().enumerate() { let base = col_idx * bits_per_col; - let recombined = - bit_slice_evals[base..base + bits_per_col] - .iter() - .zip(&a_powers) - .fold(zero.clone(), |acc, (bit_eval, a_pow)| { - acc + bit_eval.clone() * a_pow - }); + let recombined = project_bit_slice_chunk( + &bit_slice_evals[base..base + bits_per_col], + &a_powers, + zero.clone(), + ); if &recombined != parent_eval { return Err(BooleanityError::ConsistencyMismatch { @@ -1145,11 +1139,28 @@ pub fn verify_bit_decomposition_consistency( Ok(()) } +fn booleanity_residuals(values: &[F]) -> Vec { + values + .iter() + .map(|v| { + let v_sq = v.clone() * v.clone(); + v_sq - v.clone() + }) + .collect() +} + +fn project_bit_slice_chunk( + bit_slice_evals: &[F], + powers: &[F], + zero: F, +) -> F { + FieldFieldInnerProduct::inner_product::(bit_slice_evals, powers, zero) + .expect("bit-slice chunk and projection powers have matching lengths") +} + #[derive(Debug, Error)] pub enum BooleanityError { - #[error( - "wrong bit-slice evaluation count: got {got}, expected {expected}" - )] + #[error("wrong bit-slice evaluation count: got 
{got}, expected {expected}")] WrongBitSliceEvalCount { got: usize, expected: usize }, #[error( "bit-decomposition consistency mismatch on binary_poly column {col_idx}: got Σ a^i·bᵢ = {got:?}, expected parent eval {expected:?}" @@ -1168,9 +1179,7 @@ pub enum BooleanityError { #[cfg(test)] mod tests { use super::*; - use crypto_primitives::{ - FromWithConfig, boolean::Boolean, crypto_bigint_monty::MontyField, - }; + use crypto_primitives::{FromWithConfig, boolean::Boolean, crypto_bigint_monty::MontyField}; type F = MontyField<4>; @@ -1183,8 +1192,7 @@ mod tests { let evaluations: Vec> = patterns .iter() .map(|&p| { - let coeffs: [Boolean; 8] = - array::from_fn(|i| Boolean::new((p >> i) & 1 != 0)); + let coeffs: [Boolean; 8] = array::from_fn(|i| Boolean::new((p >> i) & 1 != 0)); BinaryPoly::<8>::new(coeffs) }) .collect(); @@ -1195,6 +1203,33 @@ mod tests { } } + #[test] + fn shifted_bit_slice_streaming_matches_materialized_mles() { + let cfg = test_cfg(); + let col = col_from_u8s(&[0b0000_0001, 0b0000_1010, 0b1010_0000, 0b1111_0000]); + let specs = [ + ShiftedBitSliceSpec::new(0, 1), + ShiftedBitSliceSpec::new(0, 2), + ]; + let point = vec![F::from_with_cfg(3u64, &cfg), F::from_with_cfg(5u64, &cfg)]; + + let materialized = + build_shifted_bit_slice_mles::(std::slice::from_ref(&col), &specs, &cfg) + .into_iter() + .map(|mle| mle.evaluate_with_config(&point, &cfg)) + .collect::, _>>() + .unwrap(); + let streaming = compute_shifted_bit_slice_evals_streaming::( + std::slice::from_ref(&col), + &specs, + &point, + &cfg, + ) + .unwrap(); + + assert_eq!(streaming, materialized); + } + #[test] fn bit_slices_round_trip_recovers_original_bits() { let cfg = test_cfg(); @@ -1214,7 +1249,10 @@ mod tests { } else { zero.clone() }; - assert_eq!(bit_slices[bit].evaluations[row], want, "row {row} bit {bit}"); + assert_eq!( + bit_slices[bit].evaluations[row], want, + "row {row} bit {bit}" + ); } } } @@ -1239,13 +1277,8 @@ mod tests { a_pow = a_pow * a.clone(); } - 
verify_bit_decomposition_consistency( - std::slice::from_ref(&parent_eval), - &bit_evals, - &a, - 8, - ) - .expect("honest decomposition should satisfy consistency check"); + verify_bit_decomposition_consistency(std::slice::from_ref(&parent_eval), &bit_evals, &a, 8) + .expect("honest decomposition should satisfy consistency check"); } #[test] @@ -1277,7 +1310,10 @@ mod tests { &a, 4, ); - assert!(matches!(res, Err(BooleanityError::ConsistencyMismatch { .. }))); + assert!(matches!( + res, + Err(BooleanityError::ConsistencyMismatch { .. }) + )); } #[test] @@ -1289,6 +1325,32 @@ mod tests { verify_bit_decomposition_consistency(&parent_evals, &bit_evals, &one, 8).unwrap(); } + #[test] + fn bit_slice_projection_helper_matches_naive_fold() { + let cfg = test_cfg(); + let zero = F::zero_with_cfg(&cfg); + let one = F::one_with_cfg(&cfg); + let a = F::from_with_cfg(9u64, &cfg); + let bit_evals = vec![ + F::from_with_cfg(1u64, &cfg), + F::from_with_cfg(0u64, &cfg), + F::from_with_cfg(1u64, &cfg), + F::from_with_cfg(1u64, &cfg), + F::from_with_cfg(0u64, &cfg), + ]; + let powers = powers(a, one, bit_evals.len()); + + let got = project_bit_slice_chunk(&bit_evals, &powers, zero.clone()); + let want = bit_evals + .iter() + .zip(&powers) + .fold(zero, |acc, (bit_eval, power)| { + acc + bit_eval.clone() * power + }); + + assert_eq!(got, want); + } + /// Cross-validate the round-1 fast path against a faithful standard /// run of `ProverState::prove_round`. Both must produce the same /// tail evaluations and the same asserted sum (zero, since this is @@ -1306,8 +1368,14 @@ mod tests { // Mix of fully-zero, fully-one, and varying patterns to exercise all // four (A, B) cases for the XOR fold structure. 
let binary_cols = vec![ - col_from_u8s(&[0b00000000, 0b00010001, 0b00100010, 0b00110011, 0b01000100, 0b01010101, 0b01100110, 0b01110111]), - col_from_u8s(&[0b11110000, 0b11100001, 0b11010010, 0b11000011, 0b10110100, 0b10100101, 0b10010110, 0b10000111]), + col_from_u8s(&[ + 0b00000000, 0b00010001, 0b00100010, 0b00110011, 0b01000100, 0b01010101, 0b01100110, + 0b01110111, + ]), + col_from_u8s(&[ + 0b11110000, 0b11100001, 0b11010010, 0b11000011, 0b10110100, 0b10100101, 0b10010110, + 0b10000111, + ]), ]; let num_vars = 3; const D: usize = 8; @@ -1430,7 +1498,8 @@ mod tests { let mut std_eq_r = eq_r_full; std_eq_r.fix_variables_with_config(slice::from_ref(&r_1), &cfg); assert_eq!( - fast_mles[0].num_vars, num_vars - 1, + fast_mles[0].num_vars, + num_vars - 1, "fast-path eq_r_folded must have num_vars - 1 variables" ); assert_eq!( @@ -1441,7 +1510,8 @@ mod tests { for (idx, mut bit_mle) in bit_slices_full.into_iter().enumerate() { bit_mle.fix_variables_with_config(slice::from_ref(&r_1), &cfg); assert_eq!( - fast_mles[1 + idx].evaluations, bit_mle.evaluations, + fast_mles[1 + idx].evaluations, + bit_mle.evaluations, "fast-path bit-slice {idx} folded value must match standard fix_variables" ); } @@ -1454,6 +1524,9 @@ mod tests { let parent_evals = vec![one.clone()]; let bit_evals: Vec = vec![one.clone(), one.clone()]; let res = verify_bit_decomposition_consistency(&parent_evals, &bit_evals, &one, 8); - assert!(matches!(res, Err(BooleanityError::WrongBitSliceEvalCount { .. }))); + assert!(matches!( + res, + Err(BooleanityError::WrongBitSliceEvalCount { .. 
}) + )); } } diff --git a/piop/src/multipoint_eval.rs b/piop/src/multipoint_eval.rs index 610224d1..de6addff 100644 --- a/piop/src/multipoint_eval.rs +++ b/piop/src/multipoint_eval.rs @@ -44,10 +44,15 @@ use zinc_poly::{ }; use zinc_transcript::{ delegate_transcribable, - traits::{ConstTranscribable, Transcript}, + traits::{ConstTranscribable, Transcribable, Transcript}, }; use zinc_uair::ShiftSpec; -use zinc_utils::{cfg_into_iter, inner_transparent_field::InnerTransparentField}; +use zinc_utils::{ + UNCHECKED, cfg_into_iter, + delayed_reduction::DelayedFieldProductSum, + inner_product::{FieldFieldInnerProduct, InnerProduct}, + inner_transparent_field::InnerTransparentField, +}; // // Data structures @@ -104,9 +109,14 @@ pub struct MultipointEval(PhantomData); impl MultipointEval where - F: InnerTransparentField + FromPrimitiveWithConfig + Send + Sync + 'static, - F::Inner: ConstTranscribable + Zero + Default + Send + Sync, - F::Modulus: ConstTranscribable, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, { /// Multi-point evaluation protocol prover. /// @@ -133,23 +143,30 @@ where // Step 1: Sample multi-point batching coefficient \alpha and column // batching coefficients \gamma_1,...,\gamma_J. 
- let alphas: Vec = transcript.get_field_challenges(num_down_cols, field_cfg); - let gammas: Vec = transcript.get_field_challenges(num_cols, field_cfg); + let alphas: Vec = + transcript.get_transcribable_field_challenges(num_down_cols, field_cfg); + let gammas: Vec = transcript.get_transcribable_field_challenges(num_cols, field_cfg); // Step 2: Build the two selector MLEs: // eq_r(b) = eq(b, r') // next_c_r_mle(b) = next_c_mle(r', b) let eq_r = build_eq_x_r_inner(eval_point, field_cfg)?; - let (next_mles, down_cols): (Vec<_>, Vec<_>) = shifts + let mut shift_groups: Vec<(usize, Vec<(usize, usize)>)> = Vec::new(); + for (alpha_idx, spec) in shifts.iter().enumerate() { + let amount = spec.shift_amount(); + if let Some((_, group)) = shift_groups + .iter_mut() + .find(|(candidate, _)| *candidate == amount) + { + group.push((alpha_idx, spec.source_col())); + } else { + shift_groups.push((amount, vec![(alpha_idx, spec.source_col())])); + } + } + let next_mles = shift_groups .iter() - .map(|spec| { - let next = build_next_c_r_mle(eval_point, spec.shift_amount(), field_cfg)?; - let col = trace_mles[spec.source_col()].clone(); - Ok((next, col)) - }) - .collect::, ArithErrors>>()? - .into_iter() - .unzip(); + .map(|(amount, _)| build_next_c_r_mle(eval_point, *amount, field_cfg)) + .collect::, ArithErrors>>()?; // Precombine up cols with gammas, precombined[b] = Σ_j γ_j trace[j][b]. 
// Multiplying eval_f by &gamma uses `Mul<&Self, Output=Self>` from @@ -179,12 +196,38 @@ where ) }; + let grouped_down_cols = shift_groups + .iter() + .map(|(_, group)| { + let evaluations: Vec<_> = cfg_into_iter!(0..1 << num_vars) + .map(|b| { + group + .iter() + .fold(zero.clone(), |acc, (alpha_idx, source_col)| { + let eval_f = F::new_unchecked_with_cfg( + trace_mles[*source_col].evaluations[b].clone(), + field_cfg, + ); + acc + eval_f * &alphas[*alpha_idx] + }) + .into_inner() + }) + .collect(); + DenseMultilinearExtension::from_evaluations_vec( + num_vars, + evaluations, + zero_inner.clone(), + ) + }) + .collect::>(); + // Step 3: Pack MLEs: [eq_r, next_mles[..], precombined, down_cols[..]] - let mut mles = Vec::with_capacity(2 + 2 * num_down_cols); + let grouped_down_cols_len = grouped_down_cols.len(); + let mut mles = Vec::with_capacity(2 + 2 * grouped_down_cols_len); mles.push(eq_r); mles.extend(next_mles); mles.push(precombined); - mles.extend(down_cols); + mles.extend(grouped_down_cols); // Step 4: Run sumcheck with degree=2. @@ -197,15 +240,12 @@ where 2, |mle_values: &[F]| { let eq_val = &mle_values[0]; - let precombined = &mle_values[num_down_cols + 1]; - alphas - .iter() - .enumerate() - .fold(eq_val.clone() * precombined, |acc, (i, alpha)| { - let next = &mle_values[1 + i]; - let down_col = &mle_values[num_down_cols + 2 + i]; - acc + alpha.clone() * next * down_col - }) + let precombined = &mle_values[grouped_down_cols_len + 1]; + (0..grouped_down_cols_len).fold(eq_val.clone() * precombined, |acc, i| { + let next = &mle_values[1 + i]; + let down_col = &mle_values[grouped_down_cols_len + 2 + i]; + acc + next.clone() * down_col + }) }, field_cfg, ); @@ -249,8 +289,9 @@ where let one = F::one_with_cfg(field_cfg); // Step 1: Sample \alpha_k and \gamma_j (must match prover). 
- let alphas: Vec = transcript.get_field_challenges(num_down_cols, field_cfg); - let gammas: Vec = transcript.get_field_challenges(num_cols, field_cfg); + let alphas: Vec = + transcript.get_transcribable_field_challenges(num_down_cols, field_cfg); + let gammas: Vec = transcript.get_transcribable_field_challenges(num_cols, field_cfg); // Step 2: Compute expected sum let expected_sum: F = @@ -315,13 +356,12 @@ where let zero = F::zero_with_cfg(field_cfg); - let batched_up: F = subclaim - .gammas - .iter() - .zip(open_evals.iter()) - .fold(zero.clone(), |acc, (gamma, eval)| { - acc + gamma.clone() * eval - }); + let batched_up: F = FieldFieldInnerProduct::inner_product::( + &subclaim.gammas, + open_evals, + zero.clone(), + ) + .expect("inner product cannot fail here"); // open_evals[j] = trace_col_j(r_0) for all committed (up) columns. // Shifted columns reuse the same opening: the shift is captured by @@ -351,22 +391,20 @@ where /// `expected_sum = \sum_j \gamma_j * up_eval_j + \sum_k \alpha_k * /// down_eval_k` -fn compute_expected_sum( +fn compute_expected_sum( up_evals: &[F], down_evals: &[F], gammas: &[F], alphas: &[F], zero: F, ) -> F { - let up_sum = gammas - .iter() - .zip(up_evals.iter()) - .fold(zero, |acc, (gamma, up)| acc + gamma.clone() * up); - - alphas - .iter() - .zip(down_evals.iter()) - .fold(up_sum, |acc, (alpha, down)| acc + alpha.clone() * down) + let up_sum = FieldFieldInnerProduct::inner_product::(gammas, up_evals, zero.clone()) + .expect("inner product cannot fail here"); + + let down_sum = FieldFieldInnerProduct::inner_product::(alphas, down_evals, zero) + .expect("inner product cannot fail here"); + + up_sum + down_sum } // diff --git a/piop/src/neutron_nova/accumulator.rs b/piop/src/neutron_nova/accumulator.rs new file mode 100644 index 00000000..a31432c5 --- /dev/null +++ b/piop/src/neutron_nova/accumulator.rs @@ -0,0 +1,372 @@ +use crypto_primitives::{PrimeField, crypto_bigint_uint::Uint}; +use num_traits::Zero; +use std::array; +use 
thiserror::Error; +use zinc_poly::{ + mle::DenseMultilinearExtension, + univariate::binary::BinaryPoly, + utils::{ArithErrors, build_eq_x_r_vec}, +}; +use zinc_utils::{ + UNCHECKED, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, MontgomeryProductSum4, + }, + inner_product::FieldFieldInnerProduct, + powers, +}; + +/// Errors produced by NeutronNova row-space accumulation helpers. +#[derive(Clone, Debug, Error)] +pub enum AccumulatorError { + #[error("row weight length mismatch: weights={weights}, rows={rows}")] + RowWeightLengthMismatch { weights: usize, rows: usize }, + #[error("bit index {bit_idx} is out of range for degree bound {degree}")] + BitIndexOutOfRange { bit_idx: usize, degree: usize }, + #[error("projection powers length mismatch: got {got}, expected at least {expected}")] + ProjectionPowersLengthMismatch { got: usize, expected: usize }, + #[error("row-weight construction failed: {0}")] + RowWeights(#[from] ArithErrors), +} + +/// Equality weights over the Boolean row space. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RowWeights { + weights: Vec, +} + +impl RowWeights { + /// Build weights `eq(r, z)` for all Boolean rows `z`. + /// + /// The zero-variable row space is supported as a single row of weight 1. + pub fn new(row_point: &[F], field_cfg: &F::Config) -> Result { + let weights = if row_point.is_empty() { + vec![F::one_with_cfg(field_cfg)] + } else { + build_eq_x_r_vec(row_point, field_cfg)? + }; + Ok(Self { weights }) + } + + /// Build row weights and zero the final row to match current CPR parity. + pub fn new_with_last_row_zero( + row_point: &[F], + field_cfg: &F::Config, + ) -> Result { + let mut weights = Self::new(row_point, field_cfg)?; + weights.zero_last_row(field_cfg); + Ok(weights) + } + + /// Set the final row weight to zero in-place. 
+ pub fn zero_last_row(&mut self, field_cfg: &F::Config) { + if let Some(last) = self.weights.last_mut() { + *last = F::zero_with_cfg(field_cfg); + } + } + + pub fn as_slice(&self) -> &[F] { + &self.weights + } + + pub fn len(&self) -> usize { + self.weights.len() + } + + pub fn is_empty(&self) -> bool { + self.weights.is_empty() + } +} + +/// DMR-backed bit buckets for one small-value binary-polynomial column. +#[derive(Clone, Debug)] +pub struct SmallValueBitAccumulator<'a, F: MontgomeryLimbs, const D: usize> { + buckets: [Uint<5>; D], + lane_accs: [F; D], + pending_adds: usize, + field_cfg: &'a F::Config, + reducer: BarrettDelayedReduction<'a, F>, +} + +impl<'a, F, const D: usize> SmallValueBitAccumulator<'a, F, D> +where + F: MontgomeryLimbs + Send + Sync, +{ + pub fn new(field_cfg: &'a F::Config) -> Self { + let reducer = BarrettDelayedReduction::::new(field_cfg); + let zero = F::zero_with_cfg(field_cfg); + Self { + buckets: [Uint::zero(); D], + lane_accs: array::from_fn(|_| zero.clone()), + pending_adds: 0, + field_cfg, + reducer, + } + } + + /// Pending unreduced DMR buckets. + /// + /// Flushed contributions live in the reduced lane accumulators, so this is + /// only useful for low-level tests and diagnostics. 
+ pub fn pending_buckets(&self) -> &[Uint<5>] { + &self.buckets + } + + pub fn add_bit_weight(&mut self, bit_idx: usize, weight: &F) -> Result<(), AccumulatorError> { + let Some(bucket) = self.buckets.get_mut(bit_idx) else { + return Err(AccumulatorError::BitIndexOutOfRange { bit_idx, degree: D }); + }; + self.reducer.add(bucket, weight); + self.pending_adds = self.pending_adds.saturating_add(1); + if self.pending_adds >= self.reducer.flush_adds() { + self.flush_buckets(); + } + Ok(()) + } + + #[allow(clippy::arithmetic_side_effects)] + pub fn add_binary_poly( + &mut self, + poly: &BinaryPoly, + weight: &F, + ) -> Result<(), AccumulatorError> { + if D <= 64 { + let mut bits = 0u64; + for (bit_idx, coeff) in poly.iter().enumerate().take(D) { + if coeff.into_inner() { + bits |= 1u64 << bit_idx; + } + } + + while bits != 0 { + let bit_idx = + usize::try_from(bits.trailing_zeros()).expect("trailing_zeros fits usize"); + self.add_bit_weight(bit_idx, weight)?; + bits &= bits - 1; + } + } else { + for (bit_idx, coeff) in poly.iter().enumerate().take(D) { + if coeff.into_inner() { + self.add_bit_weight(bit_idx, weight)?; + } + } + } + Ok(()) + } + + pub fn reduce_buckets(mut self) -> Vec { + self.flush_buckets(); + self.lane_accs.into_iter().collect() + } + + fn flush_buckets(&mut self) { + for (bucket, acc) in self.buckets.iter_mut().zip(self.lane_accs.iter_mut()) { + if bucket.is_zero() { + continue; + } + let pending = std::mem::replace(bucket, Uint::zero()); + *acc += self.reducer.reduce(pending); + } + self.pending_adds = 0; + } +} + +impl SmallValueBitAccumulator<'_, F, D> +where + F: MontgomeryLimbs + DelayedFieldProductSum + Send + Sync, +{ + pub fn project(mut self, projection_powers: &[F]) -> Result { + if projection_powers.len() < D { + return Err(AccumulatorError::ProjectionPowersLengthMismatch { + got: projection_powers.len(), + expected: D, + }); + } + + self.flush_buckets(); + let zero = F::zero_with_cfg(self.field_cfg); + let product_sum = 
MontgomeryProductSum4::::new(self.field_cfg); + Ok( + FieldFieldInnerProduct::inner_product_with_algorithm::( + &product_sum, + &self.lane_accs, + &projection_powers[..D], + zero, + ) + .expect("bucket and projection-power lengths match"), + ) + } +} + +/// Accumulate one binary-polynomial column projected by `projecting_element`. +/// +/// This computes the bucket-first form: +/// first `S_j = sum_z bit_j(col[z]) * row_weight[z]`, then +/// `sum_j S_j * projecting_element^j`. +pub fn accumulate_binary_column_projected( + column: &DenseMultilinearExtension>, + row_weights: &RowWeights, + projecting_element: &F, + field_cfg: &F::Config, +) -> Result +where + F: MontgomeryLimbs + DelayedFieldProductSum + Send + Sync, +{ + if column.evaluations.len() != row_weights.len() { + return Err(AccumulatorError::RowWeightLengthMismatch { + weights: row_weights.len(), + rows: column.evaluations.len(), + }); + } + + let one = F::one_with_cfg(field_cfg); + let projection_powers: Vec = powers(projecting_element.clone(), one, D); + let mut accumulator = SmallValueBitAccumulator::::new(field_cfg); + + for (poly, weight) in column.iter().zip(row_weights.as_slice()) { + accumulator.add_binary_poly(poly, weight)?; + } + + accumulator.project(&projection_powers) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::test_config; + use crypto_bigint::{Odd, modular::MontyParams}; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use zinc_utils::powers; + + type F = MontyField<4>; + + fn field_cfg() -> MontyParams<4> { + let modulus = crypto_bigint::Uint::<4>::from_words([ + 0xFFFF_FFFE_FFFF_FC2F, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + ]); + MontyParams::new(Odd::new(modulus).expect("secp256k1 modulus is odd")) + } + + fn f(value: u64, cfg: &MontyParams<4>) -> F { + F::from_with_cfg(value, cfg) + } + + fn binary_col(patterns: &[u32]) -> DenseMultilinearExtension> { + 
DenseMultilinearExtension::from_evaluations_vec( + usize::try_from(patterns.len().next_power_of_two().trailing_zeros()) + .expect("trailing_zeros fits usize"), + patterns + .iter() + .copied() + .map(BinaryPoly::<32>::from) + .collect(), + BinaryPoly::<32>::zero(), + ) + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_projected_sum( + column: &DenseMultilinearExtension>, + row_weights: &RowWeights, + projecting_element: &F, + cfg: &MontyParams<4>, + ) -> F { + let zero = F::zero_with_cfg(&cfg); + let one = F::one_with_cfg(&cfg); + let powers = powers(projecting_element.clone(), one, 32); + + column + .iter() + .zip(row_weights.as_slice()) + .fold(zero, |mut acc, (poly, row_weight)| { + for (bit_idx, coeff) in poly.iter().enumerate().take(32) { + if coeff.into_inner() { + acc += row_weight.clone() * &powers[bit_idx]; + } + } + acc + }) + } + + #[test] + fn projected_binary_column_matches_naive_row_space_sum() { + let cfg = field_cfg(); + let point = vec![f(3, &cfg), f(5, &cfg), f(7, &cfg)]; + let row_weights = RowWeights::new(&point, &cfg).unwrap(); + let column = binary_col(&[ + 0x0000_0001, + 0x8000_0001, + 0x0f0f_00f0, + 0xf000_00ff, + 0x0101_0101, + 0x1111_2222, + 0xdead_beef, + 0xffff_0000, + ]); + let projecting_element = f(11, &cfg); + + let got = + accumulate_binary_column_projected(&column, &row_weights, &projecting_element, &cfg) + .unwrap(); + let expected = naive_projected_sum(&column, &row_weights, &projecting_element, &cfg); + + assert_eq!(got, expected); + } + + #[test] + fn small_value_bit_accumulator_flushes_for_small_modulus() { + let cfg = test_config(); + let max = -F::from_with_cfg(1u64, &cfg); + let mut accumulator = SmallValueBitAccumulator::::new(&cfg); + let mut expected = F::zero_with_cfg(&cfg); + + for _ in 0..2048 { + accumulator.add_bit_weight(7, &max).unwrap(); + expected += &max; + } + + let lanes = accumulator.reduce_buckets(); + assert_eq!(lanes[7], expected); + assert!( + lanes + .iter() + .enumerate() + .all(|(lane, 
value)| lane == 7 || F::is_zero(value)) + ); + } + + #[test] + fn last_row_zero_helper_matches_manual_zeroing() { + let cfg = field_cfg(); + let point = vec![f(3, &cfg), f(5, &cfg)]; + + let mut manual = RowWeights::new(&point, &cfg).unwrap(); + manual.zero_last_row(&cfg); + let helper = RowWeights::new_with_last_row_zero(&point, &cfg).unwrap(); + + assert_eq!(helper, manual); + assert_eq!(helper.as_slice().last().unwrap(), &F::zero_with_cfg(&cfg)); + } + + #[test] + fn projected_binary_column_rejects_row_weight_mismatch() { + let cfg = field_cfg(); + let row_weights = RowWeights::new(&[f(3, &cfg), f(5, &cfg)], &cfg).unwrap(); + let column = binary_col(&[1, 2, 3, 4, 5, 6, 7, 8]); + + let err = accumulate_binary_column_projected(&column, &row_weights, &f(11, &cfg), &cfg) + .expect_err("mismatched row weights should be rejected"); + + assert!(matches!( + err, + AccumulatorError::RowWeightLengthMismatch { + weights: 4, + rows: 8 + } + )); + } +} diff --git a/piop/src/neutron_nova/booleanity.rs b/piop/src/neutron_nova/booleanity.rs new file mode 100644 index 00000000..e3af721e --- /dev/null +++ b/piop/src/neutron_nova/booleanity.rs @@ -0,0 +1,1267 @@ +use crate::neutron_nova::RowWeights; +use crypto_primitives::{FromPrimitiveWithConfig, PrimeField, crypto_bigint_uint::Uint}; +use num_traits::Zero; +use std::array; +use thiserror::Error; +use zinc_poly::univariate::binary::BinaryPoly; +use zinc_uair::UairTrace; +use zinc_utils::{ + UNCHECKED, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, MontgomeryProductSum4, + }, + inner_product::FieldFieldInnerProduct, +}; + +const MAX_DMR_BUCKET_ARRAYS: usize = 256; +const MAX_BOOLEANITY_PREFIX_VARS: usize = 10; + +/// Precomputed equality weights used by the booleanity accumulator. 
+#[derive(Debug)] +pub struct BooleanityWeights<'a, F: PrimeField> { + pub row_weights: &'a RowWeights, + pub tail_eq_weights: &'a [F], +} + +impl Clone for BooleanityWeights<'_, F> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for BooleanityWeights<'_, F> {} + +/// Precomputed scalarization weights for booleanity lanes. +#[derive(Debug)] +pub struct BooleanityScalarWeights<'a, F: PrimeField> { + /// Indexed as `col_idx * D + bit_idx`. + pub rho_powers: &'a [F], +} + +impl Clone for BooleanityScalarWeights<'_, F> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for BooleanityScalarWeights<'_, F> {} + +/// One point in `{0, 1, infinity}^prefix_vars`, represented as `(S, a)`. +/// +/// `support_mask` marks coordinates set to infinity. `finite_bits` stores +/// Boolean assignments in original coordinate positions. The two masks are +/// canonical and must not overlap. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ExtendedPrefixPoint { + support_mask: usize, + finite_bits: usize, +} + +impl ExtendedPrefixPoint { + pub fn new( + support_mask: usize, + finite_bits: usize, + ) -> Result { + if support_mask & finite_bits != 0 { + return Err(BooleanityAccumulatorError::ExtendedPointNotCanonical); + } + Ok(Self { + support_mask, + finite_bits, + }) + } + + pub fn support_mask(self) -> usize { + self.support_mask + } + + pub fn finite_bits(self) -> usize { + self.finite_bits + } + + pub fn support_size(self) -> usize { + usize::try_from(self.support_mask.count_ones()).expect("count_ones fits usize") + } + + pub fn is_finite_only(self) -> bool { + self.support_mask == 0 + } +} + +/// Dense table over `{0, 1, infinity}^prefix_vars`. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct BooleanityPrefixTable { + values: Vec, + ell: usize, + prefix_vars: usize, + num_binary_cols: usize, +} + +impl BooleanityPrefixTable { + pub fn values(&self) -> &[F] { + &self.values + } + + pub fn ell(&self) -> usize { + self.ell + } + + pub fn prefix_vars(&self) -> usize { + self.prefix_vars + } + + pub fn num_binary_cols(&self) -> usize { + self.num_binary_cols + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + pub fn value_at_point( + &self, + point: ExtendedPrefixPoint, + ) -> Result<&F, BooleanityAccumulatorError> { + let index = extended_point_index(point, self.prefix_vars)?; + self.values + .get(index) + .ok_or(BooleanityAccumulatorError::ExtendedPointIndexOutOfRange { + index, + domain_size: self.values.len(), + }) + } +} + +#[derive(Clone, Debug, Error)] +pub enum BooleanityAccumulatorError { + #[error("booleanity accumulator needs at least one trace")] + EmptyTraces, + #[error("trace count must be a power of two, got {len}")] + TraceCountNotPowerOfTwo { len: usize }, + #[error("prefix_vars={prefix_vars} must be at most ell={ell}")] + PrefixVarsTooLarge { prefix_vars: usize, ell: usize }, + #[error("booleanity prefix_vars={prefix_vars} exceeds supported maximum {max}")] + PrefixVarsExceedsSupported { prefix_vars: usize, max: usize }, + #[error("domain size is too large for {vars} variables")] + DomainTooLarge { vars: usize }, + #[error("{label} length mismatch: got {got}, expected {expected}")] + LengthMismatch { + label: &'static str, + got: usize, + expected: usize, + }, + #[error("trace {trace_idx} has {got} binary columns, expected {expected}")] + BinaryColumnCountMismatch { + trace_idx: usize, + got: usize, + expected: usize, + }, + #[error("trace {trace_idx} binary column {col_idx} has {got} rows, expected {expected}")] + BinaryColumnRowMismatch { + trace_idx: usize, + col_idx: usize, + got: usize, + expected: usize, + 
}, + #[error("extended point index {index} is out of range for domain size {domain_size}")] + ExtendedPointIndexOutOfRange { index: usize, domain_size: usize }, + #[error("extended point uses bits outside prefix_vars={prefix_vars}")] + ExtendedPointOutOfRange { prefix_vars: usize }, + #[error("extended point has finite bits set inside the infinity support")] + ExtendedPointNotCanonical, + #[error("accumulator bucket count overflow for {entries} entries and stride {stride}")] + BucketCountOverflow { entries: usize, stride: usize }, +} + +#[derive(Clone, Copy, Debug)] +struct ExtendedTableEntry { + table_index: usize, + point: ExtendedPrefixPoint, +} + +/// Build the optimized booleanity prefix table from small-value binary traces. +#[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)] +pub fn build_booleanity_prefix_table( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + prefix_vars: usize, + weights: BooleanityWeights<'_, F>, + scalar_weights: BooleanityScalarWeights<'_, F>, + field_cfg: &F::Config, +) -> Result, BooleanityAccumulatorError> +where + F: MontgomeryLimbs + DelayedFieldProductSum + FromPrimitiveWithConfig + Send + Sync + 'static, + PolyCoeff: Clone, + Int: Clone, +{ + let ell = validate_trace_count(traces.len())?; + if prefix_vars > ell { + return Err(BooleanityAccumulatorError::PrefixVarsTooLarge { prefix_vars, ell }); + } + if prefix_vars > MAX_BOOLEANITY_PREFIX_VARS { + return Err(BooleanityAccumulatorError::PrefixVarsExceedsSupported { + prefix_vars, + max: MAX_BOOLEANITY_PREFIX_VARS, + }); + } + + let prefix_len = binary_domain_size(prefix_vars)?; + let tail_len = binary_domain_size(ell - prefix_vars)?; + if weights.tail_eq_weights.len() != tail_len { + return Err(BooleanityAccumulatorError::LengthMismatch { + label: "tail_eq_weights", + got: weights.tail_eq_weights.len(), + expected: tail_len, + }); + } + + let num_binary_cols = validate_traces(traces, weights.row_weights.len())?; + let expected_rho_powers = num_binary_cols + 
.checked_mul(D) + .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?; + if scalar_weights.rho_powers.len() < expected_rho_powers { + return Err(BooleanityAccumulatorError::LengthMismatch { + label: "rho_powers", + got: scalar_weights.rho_powers.len(), + expected: expected_rho_powers, + }); + } + + let domain_len = ternary_domain_size(prefix_vars)?; + let mut table_values = vec![F::zero_with_cfg(field_cfg); domain_len]; + let entries_by_support_size = extended_entries_by_support_size(prefix_vars)?; + let row_count = weights.row_weights.len(); + let omega = precompute_row_tail_weights(weights, field_cfg)?; + let reducer = BarrettDelayedReduction::::new(field_cfg); + let word_count = bit_word_count(D); + let prefix_word_len = prefix_len + .checked_mul(word_count) + .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?; + let mut prefix_words = vec![0u64; prefix_word_len]; + + for support_size in 1..=prefix_vars { + let entries = &entries_by_support_size[support_size]; + if entries.is_empty() { + continue; + } + + let max_magnitude = max_delta_magnitude(prefix_vars, support_size)?; + let _ = max_magnitude + .checked_mul(max_magnitude) + .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?; + let tile_len = adaptive_tile_len(max_magnitude).min(entries.len()); + for tile in entries.chunks(tile_len) { + for col_idx in 0..num_binary_cols { + accumulate_column_tile::( + traces, + col_idx, + tile, + prefix_vars, + prefix_len, + tail_len, + row_count, + word_count, + max_magnitude, + &omega, + &mut prefix_words, + &scalar_weights.rho_powers[col_idx * D..col_idx * D + D], + field_cfg, + &reducer, + &mut table_values, + )?; + } + } + } + + Ok(BooleanityPrefixTable { + values: table_values, + ell, + prefix_vars, + num_binary_cols, + }) +} + +#[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)] +fn accumulate_column_tile( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + col_idx: usize, + tile: 
&[ExtendedTableEntry], + prefix_vars: usize, + prefix_len: usize, + tail_len: usize, + row_count: usize, + word_count: usize, + max_magnitude: usize, + omega: &[F], + prefix_words: &mut [u64], + rho_powers: &[F], + field_cfg: &F::Config, + reducer: &BarrettDelayedReduction<'_, F>, + table_values: &mut [F], +) -> Result<(), BooleanityAccumulatorError> +where + F: MontgomeryLimbs + DelayedFieldProductSum + FromPrimitiveWithConfig + Send + Sync + 'static, + PolyCoeff: Clone, + Int: Clone, +{ + let bucket_stride = max_magnitude + 1; + let bucket_count = tile.len().checked_mul(bucket_stride).ok_or( + BooleanityAccumulatorError::BucketCountOverflow { + entries: tile.len(), + stride: bucket_stride, + }, + )?; + let mut buckets: Vec<[Uint<5>; D]> = vec![[Uint::zero(); D]; bucket_count]; + let zero = F::zero_with_cfg(field_cfg); + let mut lane_accs: Vec<[F; D]> = (0..bucket_count) + .map(|_| array::from_fn(|_| zero.clone())) + .collect(); + let product_sum = MontgomeryProductSum4::::new(field_cfg); + let mut pending_adds = 0usize; + let flush_adds = reducer.flush_adds(); + + for tail in 0..tail_len { + let omega_offset = tail * row_count; + for row in 0..row_count { + gather_prefix_words::( + traces, + col_idx, + tail, + row, + prefix_vars, + prefix_len, + word_count, + prefix_words, + ); + let weight = &omega[omega_offset + row]; + + for (entry_offset, entry) in tile.iter().enumerate() { + pending_adds = pending_adds.saturating_add(accumulate_entry_deltas::( + entry.point, + entry_offset, + bucket_stride, + word_count, + prefix_words, + weight, + reducer, + &mut buckets, + )); + + if pending_adds >= flush_adds { + flush_buckets_into_lanes(&mut buckets, &mut lane_accs, reducer); + pending_adds = 0; + } + } + } + } + + flush_buckets_into_lanes(&mut buckets, &mut lane_accs, reducer); + + for (entry_offset, entry) in tile.iter().enumerate() { + for magnitude in 1..=max_magnitude { + let bucket_idx = entry_offset * bucket_stride + magnitude; + let projected = 
FieldFieldInnerProduct::inner_product_with_algorithm::( + &product_sum, + &lane_accs[bucket_idx], + rho_powers, + zero.clone(), + ) + .expect("lane accumulator and rho powers have matching lengths"); + let magnitude_square = magnitude + .checked_mul(magnitude) + .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?; + let scale = F::from_with_cfg( + u64::try_from(magnitude_square).map_err(|_| { + BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars } + })?, + field_cfg, + ); + table_values[entry.table_index] += scale * projected; + } + } + Ok(()) +} + +#[allow(clippy::arithmetic_side_effects)] +fn gather_prefix_words( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + col_idx: usize, + tail: usize, + row: usize, + prefix_vars: usize, + prefix_len: usize, + word_count: usize, + out: &mut [u64], +) where + PolyCoeff: Clone, + Int: Clone, +{ + for prefix in 0..prefix_len { + let instance_idx = prefix + (tail << prefix_vars); + let poly = &traces[instance_idx].binary_poly[col_idx].evaluations[row]; + write_poly_words( + poly, + &mut out[prefix * word_count..(prefix + 1) * word_count], + ); + } +} + +#[allow(clippy::arithmetic_side_effects)] +fn accumulate_entry_deltas( + point: ExtendedPrefixPoint, + entry_offset: usize, + bucket_stride: usize, + word_count: usize, + prefix_words: &[u64], + weight: &F, + reducer: &BarrettDelayedReduction<'_, F>, + buckets: &mut [[Uint<5>; D]], +) -> usize +where + F: MontgomeryLimbs + Send + Sync, +{ + match point.support_size() { + 1 => accumulate_support_one::( + point, + entry_offset, + bucket_stride, + word_count, + prefix_words, + weight, + reducer, + buckets, + ), + 2 => accumulate_support_two::( + point, + entry_offset, + bucket_stride, + word_count, + prefix_words, + weight, + reducer, + buckets, + ), + _ => accumulate_support_general::( + point, + entry_offset, + bucket_stride, + word_count, + prefix_words, + weight, + reducer, + buckets, + ), + } +} + +#[allow(clippy::arithmetic_side_effects)] +fn 
accumulate_support_one( + point: ExtendedPrefixPoint, + entry_offset: usize, + bucket_stride: usize, + word_count: usize, + prefix_words: &[u64], + weight: &F, + reducer: &BarrettDelayedReduction<'_, F>, + buckets: &mut [[Uint<5>; D]], +) -> usize +where + F: MontgomeryLimbs + Send + Sync, +{ + let support_bit = point.support_mask; + let base = point.finite_bits & !support_bit; + let idx0 = base; + let idx1 = base | support_bit; + let bucket_idx = entry_offset * bucket_stride + 1; + let mut adds = 0usize; + for word_idx in 0..word_count { + let mask = word_at(prefix_words, idx0, word_count, word_idx) + ^ word_at(prefix_words, idx1, word_count, word_idx); + adds += add_mask_word_to_bucket(mask, word_idx, weight, reducer, &mut buckets[bucket_idx]); + } + adds +} + +#[allow(clippy::arithmetic_side_effects)] +fn accumulate_support_two( + point: ExtendedPrefixPoint, + entry_offset: usize, + bucket_stride: usize, + word_count: usize, + prefix_words: &[u64], + weight: &F, + reducer: &BarrettDelayedReduction<'_, F>, + buckets: &mut [[Uint<5>; D]], +) -> usize +where + F: MontgomeryLimbs + Send + Sync, +{ + let first = point.support_mask & point.support_mask.wrapping_neg(); + let second = point.support_mask ^ first; + let base = point.finite_bits & !point.support_mask; + let idx00 = base; + let idx10 = base | first; + let idx01 = base | second; + let idx11 = base | first | second; + let bucket_1 = entry_offset * bucket_stride + 1; + let bucket_2 = entry_offset * bucket_stride + 2; + let mut adds = 0usize; + + for word_idx in 0..word_count { + let valid_mask = valid_word_mask::(word_idx); + let d00 = word_at(prefix_words, idx00, word_count, word_idx) & valid_mask; + let d10 = word_at(prefix_words, idx10, word_count, word_idx) & valid_mask; + let d01 = word_at(prefix_words, idx01, word_count, word_idx) & valid_mask; + let d11 = word_at(prefix_words, idx11, word_count, word_idx) & valid_mask; + let mask_1 = ((d11 ^ d00) ^ (d10 ^ d01)) & valid_mask; + let mask_2 = ((d11 & d00 & 
!d10 & !d01) | (!d11 & !d00 & d10 & d01)) & valid_mask; + adds += add_mask_word_to_bucket(mask_1, word_idx, weight, reducer, &mut buckets[bucket_1]); + adds += add_mask_word_to_bucket(mask_2, word_idx, weight, reducer, &mut buckets[bucket_2]); + } + + adds +} + +#[allow(clippy::arithmetic_side_effects)] +fn accumulate_support_general( + point: ExtendedPrefixPoint, + entry_offset: usize, + bucket_stride: usize, + word_count: usize, + prefix_words: &[u64], + weight: &F, + reducer: &BarrettDelayedReduction<'_, F>, + buckets: &mut [[Uint<5>; D]], +) -> usize +where + F: MontgomeryLimbs + Send + Sync, +{ + let mut support_bits = [0usize; usize::BITS as usize]; + let support_size = support_bit_masks_into(point.support_mask, &mut support_bits); + let base = point.finite_bits & !point.support_mask; + let mut deltas = [0i64; D]; + for vertex in 0..(1usize << support_size) { + let mut prefix = base; + for (pos, bit) in support_bits[..support_size].iter().enumerate() { + if ((vertex >> pos) & 1) == 1 { + prefix |= *bit; + } + } + let sign = if (support_size - vertex.count_ones() as usize) % 2 == 0 { + 1i64 + } else { + -1i64 + }; + for word_idx in 0..word_count { + let mut word = word_at(prefix_words, prefix, word_count, word_idx); + while word != 0 { + let bit = + usize::try_from(word.trailing_zeros()).expect("trailing_zeros fits usize"); + let lane = word_idx * 64 + bit; + if lane < D { + deltas[lane] += sign; + } + word &= word - 1; + } + } + } + + let mut adds = 0usize; + for (lane, delta) in deltas.iter().enumerate() { + let magnitude = usize::try_from(delta.unsigned_abs()).expect("delta magnitude fits usize"); + if magnitude == 0 { + continue; + } + let bucket_idx = entry_offset * bucket_stride + magnitude; + reducer.add(&mut buckets[bucket_idx][lane], weight); + adds += 1; + } + adds +} + +#[allow(clippy::arithmetic_side_effects)] +fn add_mask_word_to_bucket( + mut mask: u64, + word_idx: usize, + weight: &F, + reducer: &BarrettDelayedReduction<'_, F>, + bucket: &mut 
[Uint<5>; D], +) -> usize +where + F: MontgomeryLimbs + Send + Sync, +{ + let mut adds = 0usize; + while mask != 0 { + let bit = usize::try_from(mask.trailing_zeros()).expect("trailing_zeros fits usize"); + let lane = word_idx * 64 + bit; + if lane < D { + reducer.add(&mut bucket[lane], weight); + adds += 1; + } + mask &= mask - 1; + } + adds +} + +fn flush_buckets_into_lanes( + buckets: &mut [[Uint<5>; D]], + lane_accs: &mut [[F; D]], + reducer: &BarrettDelayedReduction<'_, F>, +) where + F: MontgomeryLimbs + Send + Sync, +{ + for (bucket_lanes, acc_lanes) in buckets.iter_mut().zip(lane_accs.iter_mut()) { + for (bucket, acc) in bucket_lanes.iter_mut().zip(acc_lanes.iter_mut()) { + if bucket.is_zero() { + continue; + } + let pending = std::mem::replace(bucket, Uint::zero()); + *acc += reducer.reduce(pending); + } + } +} + +fn validate_trace_count(len: usize) -> Result { + if len == 0 { + return Err(BooleanityAccumulatorError::EmptyTraces); + } + if !len.is_power_of_two() { + return Err(BooleanityAccumulatorError::TraceCountNotPowerOfTwo { len }); + } + Ok(usize::try_from(len.trailing_zeros()).expect("trailing_zeros fits usize")) +} + +fn validate_traces( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + expected_rows: usize, +) -> Result +where + PolyCoeff: Clone, + Int: Clone, +{ + let num_binary_cols = traces + .first() + .expect("trace count was already validated as non-empty") + .binary_poly + .len(); + for (trace_idx, trace) in traces.iter().enumerate() { + if trace.binary_poly.len() != num_binary_cols { + return Err(BooleanityAccumulatorError::BinaryColumnCountMismatch { + trace_idx, + got: trace.binary_poly.len(), + expected: num_binary_cols, + }); + } + for (col_idx, column) in trace.binary_poly.iter().enumerate() { + if column.evaluations.len() != expected_rows { + return Err(BooleanityAccumulatorError::BinaryColumnRowMismatch { + trace_idx, + col_idx, + got: column.evaluations.len(), + expected: expected_rows, + }); + } + } + } + Ok(num_binary_cols) +} + 
/// Builds the flat table of products `tail_eq_weight * row_weight`.
///
/// Entries are ordered tail-major: all rows for the first tail weight, then
/// all rows for the second, and so on. If either input is empty the result is
/// a single zero element rather than an empty vector.
///
/// # Errors
///
/// Returns `DomainTooLarge` when the table length overflows `usize`; the
/// `vars` field then carries the number of tail weights.
fn precompute_row_tail_weights<F>(
    weights: BooleanityWeights<'_, F>,
    field_cfg: &F::Config,
) -> Result<Vec<F>, BooleanityAccumulatorError>
where
    F: PrimeField,
{
    let row_count = weights.row_weights.len();
    let total = weights.tail_eq_weights.len().checked_mul(row_count).ok_or(
        BooleanityAccumulatorError::DomainTooLarge {
            vars: weights.tail_eq_weights.len(),
        },
    )?;
    let mut omega = Vec::with_capacity(total);
    for tail_weight in weights.tail_eq_weights {
        for row_weight in weights.row_weights.as_slice() {
            omega.push(tail_weight.clone() * row_weight);
        }
    }
    if omega.is_empty() {
        omega.push(F::zero_with_cfg(field_cfg));
    }
    Ok(omega)
}

/// Returns how many table entries share one tile, given that each entry needs
/// `max_magnitude + 1` bucket arrays and at most `MAX_DMR_BUCKET_ARRAYS` may
/// be live at once. The result is never below 1.
fn adaptive_tile_len(max_magnitude: usize) -> usize {
    // `saturating_add` keeps the divisor non-zero even for `usize::MAX`,
    // where `+ 1` would overflow (or wrap to a zero divisor in release).
    (MAX_DMR_BUCKET_ARRAYS / max_magnitude.saturating_add(1)).max(1)
}

/// Returns `2^(support_size - 1)`, the bucket magnitude bound used for a
/// support of the given size.
///
/// # Errors
///
/// Returns `DomainTooLarge { vars: prefix_vars }` when `support_size` is zero
/// or the shift does not fit in `usize`.
fn max_delta_magnitude(
    prefix_vars: usize,
    support_size: usize,
) -> Result<usize, BooleanityAccumulatorError> {
    let shift = support_size
        .checked_sub(1)
        .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?;
    let shift = u32::try_from(shift)
        .map_err(|_| BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?;
    1usize
        .checked_shl(shift)
        .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })
}

/// Returns `2^vars`, or `DomainTooLarge` when the shift amount is not below
/// the bit width of `usize`.
fn binary_domain_size(vars: usize) -> Result<usize, BooleanityAccumulatorError> {
    let shift =
        u32::try_from(vars).map_err(|_| BooleanityAccumulatorError::DomainTooLarge { vars })?;
    1usize
        .checked_shl(shift)
        .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars })
}

/// Returns `3^vars`, or `DomainTooLarge` when the power overflows `usize`.
pub fn ternary_domain_size(vars: usize) -> Result<usize, BooleanityAccumulatorError> {
    u32::try_from(vars)
        .ok()
        .and_then(|exponent| 3usize.checked_pow(exponent))
        .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars })
}

/// Decodes a base-3 table index into an extended prefix point.
///
/// Variable `var` reads the `var`-th ternary digit of `index` (least
/// significant first): `0` leaves both masks clear, `1` sets the bit in
/// `finite_bits`, `2` sets the bit in `support_mask`.
///
/// # Errors
///
/// Returns `DomainTooLarge` when `3^prefix_vars` overflows and
/// `ExtendedPointIndexOutOfRange` when `index >= 3^prefix_vars`; errors from
/// `ExtendedPrefixPoint::new` are forwarded.
#[allow(clippy::arithmetic_side_effects)]
pub fn extended_point_from_index(
    mut index: usize,
    prefix_vars: usize,
) -> Result<ExtendedPrefixPoint, BooleanityAccumulatorError> {
    let domain_size = ternary_domain_size(prefix_vars)?;
    if index >= domain_size {
        return Err(BooleanityAccumulatorError::ExtendedPointIndexOutOfRange {
            index,
            domain_size,
        });
    }

    let mut support_mask = 0usize;
    let mut finite_bits = 0usize;
    for var in 0..prefix_vars {
        let digit = index % 3;
        index /= 3;
        match digit {
            0 => {}
            1 => finite_bits |= 1usize << var,
            2 => support_mask |= 1usize << var,
            _ => unreachable!("ternary digit must be 0, 1, or 2"),
        }
    }
    ExtendedPrefixPoint::new(support_mask, finite_bits)
}

/// Encodes an extended prefix point as its base-3 table index; the inverse of
/// [`extended_point_from_index`].
///
/// # Errors
///
/// Returns `DomainTooLarge` when the domain does not fit in `usize`,
/// `ExtendedPointOutOfRange` when either mask has a bit at or above
/// `prefix_vars`, and `ExtendedPointNotCanonical` when the two masks overlap.
#[allow(clippy::arithmetic_side_effects)]
pub fn extended_point_index(
    point: ExtendedPrefixPoint,
    prefix_vars: usize,
) -> Result<usize, BooleanityAccumulatorError> {
    // Called only for its overflow check; the size itself is not needed.
    let _ = ternary_domain_size(prefix_vars)?;
    let allowed_bits = binary_domain_size(prefix_vars)?.saturating_sub(1);
    if point.support_mask & !allowed_bits != 0 || point.finite_bits & !allowed_bits != 0 {
        return Err(BooleanityAccumulatorError::ExtendedPointOutOfRange { prefix_vars });
    }
    if point.support_mask & point.finite_bits != 0 {
        return Err(BooleanityAccumulatorError::ExtendedPointNotCanonical);
    }

    let mut index = 0usize;
    let mut scale = 1usize;
    for var in 0..prefix_vars {
        let bit = 1usize << var;
        let digit = if point.support_mask & bit != 0 {
            2
        } else if point.finite_bits & bit != 0 {
            1
        } else {
            0
        };
        index = index
            .checked_add(digit * scale)
            .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?;
        scale = scale
            .checked_mul(3)
            .ok_or(BooleanityAccumulatorError::DomainTooLarge { vars: prefix_vars })?;
    }
    Ok(index)
}

/// Groups every extended table entry with a non-empty support by support
/// size.
///
/// The result has `prefix_vars + 1` lists; list `s` holds the entries with
/// support size `s` in increasing table-index order. List `0` stays empty
/// because entries without support are skipped.
fn extended_entries_by_support_size(
    prefix_vars: usize,
) -> Result<Vec<Vec<ExtendedTableEntry>>, BooleanityAccumulatorError> {
    let domain_len = ternary_domain_size(prefix_vars)?;
    let mut entries = vec![Vec::new(); prefix_vars + 1];
    for table_index in 0..domain_len {
        let point = extended_point_from_index(table_index, prefix_vars)?;
        let support_size = point.support_size();
        if support_size == 0 {
            continue;
        }
        entries[support_size].push(ExtendedTableEntry { table_index, point });
    }
    Ok(entries)
}

+fn support_bit_masks_into(mut support_mask: usize, out: &mut [usize]) -> usize { + let
mut len = 0usize; + while support_mask != 0 { + let bit = support_mask & support_mask.wrapping_neg(); + out[len] = bit; + len += 1; + support_mask ^= bit; + } + len +} + +#[cfg(test)] +fn support_bit_masks(support_mask: usize) -> Vec { + let mut bits = + vec![0usize; usize::try_from(support_mask.count_ones()).expect("count_ones fits usize")]; + let len = support_bit_masks_into(support_mask, &mut bits); + debug_assert_eq!(len, bits.len()); + bits +} + +fn bit_word_count(degree: usize) -> usize { + degree.div_ceil(64) +} + +#[allow(clippy::arithmetic_side_effects)] +fn write_poly_words(poly: &BinaryPoly, out: &mut [u64]) { + out.fill(0); + for (bit_idx, coeff) in poly.iter().enumerate().take(D) { + if coeff.into_inner() { + out[bit_idx / 64] |= 1u64 << (bit_idx % 64); + } + } +} + +fn word_at(prefix_words: &[u64], prefix: usize, word_count: usize, word_idx: usize) -> u64 { + prefix_words[prefix * word_count + word_idx] +} + +#[allow(clippy::arithmetic_side_effects)] +fn valid_word_mask(word_idx: usize) -> u64 { + let remaining = D.saturating_sub(word_idx * 64); + match remaining { + 0 => 0, + 1..=63 => (1u64 << remaining) - 1, + _ => u64::MAX, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{neutron_nova::build_sumfold_eq_weights, test_utils::test_config}; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use std::borrow::Cow; + use zinc_poly::mle::DenseMultilinearExtension; + use zinc_utils::powers; + + type F = MontyField<4>; + type Trace = UairTrace<'static, F, F, 32>; + + fn f(value: u64) -> F { + F::from_with_cfg(value, &test_config()) + } + + fn binary_column(patterns: &[u32]) -> DenseMultilinearExtension> { + assert!(patterns.len().is_power_of_two()); + DenseMultilinearExtension::from_evaluations_vec( + usize::try_from(patterns.len().trailing_zeros()).expect("trailing_zeros fits usize"), + patterns + .iter() + .copied() + .map(BinaryPoly::<32>::from) + .collect(), + BinaryPoly::<32>::zero(), + ) + } + + fn 
trace_from_columns(col0: &[u32], col1: &[u32]) -> Trace { + UairTrace { + binary_poly: Cow::Owned(vec![binary_column(col0), binary_column(col1)]), + arbitrary_poly: Cow::Owned(Vec::new()), + int: Cow::Owned(Vec::new()), + } + } + + fn sample_traces_ell3() -> Vec { + (0..8u32) + .map(|i| { + trace_from_columns( + &[ + 0x0000_0001 ^ i, + 0x0000_00f0 ^ (i << 4), + 0x0000_3333 ^ (i * 0x1111), + 0x8000_0001 ^ (i << 8), + ], + &[ + 0x0000_0005 ^ (i << 1), + 0x0000_0a0a ^ (i << 5), + 0x0000_f00f ^ (i * 3), + 0x0001_0010 ^ (i << 9), + ], + ) + }) + .collect() + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_table( + traces: &[Trace], + prefix_vars: usize, + weights: BooleanityWeights<'_, F>, + scalar_weights: BooleanityScalarWeights<'_, F>, + ) -> Vec { + let cfg = test_config(); + let ell = + usize::try_from(traces.len().trailing_zeros()).expect("trailing_zeros fits usize"); + let tail_len = 1usize << (ell - prefix_vars); + let domain_len = ternary_domain_size(prefix_vars).unwrap(); + let mut out = vec![F::zero_with_cfg(&cfg); domain_len]; + + for (index, value) in out.iter_mut().enumerate() { + let point = extended_point_from_index(index, prefix_vars).unwrap(); + if point.is_finite_only() { + continue; + } + for tail in 0..tail_len { + for row in 0..weights.row_weights.len() { + for col_idx in 0..traces[0].binary_poly.len() { + for bit_idx in 0..32 { + let delta = naive_delta_bit( + traces, + col_idx, + bit_idx, + tail, + row, + prefix_vars, + point, + ); + if delta == 0 { + continue; + } + let mut contribution = weights.tail_eq_weights[tail].clone(); + contribution *= &weights.row_weights.as_slice()[row]; + contribution *= &scalar_weights.rho_powers[col_idx * 32 + bit_idx]; + let delta_square = + u64::try_from(delta * delta).expect("delta square fits u64"); + contribution *= F::from_with_cfg(delta_square, &cfg); + *value += contribution; + } + } + } + } + } + out + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_delta_bit( + traces: &[Trace], + 
col_idx: usize, + bit_idx: usize, + tail: usize, + row: usize, + prefix_vars: usize, + point: ExtendedPrefixPoint, + ) -> i64 { + let support_bits = support_bit_masks(point.support_mask); + let support_size = support_bits.len(); + let base = point.finite_bits & !point.support_mask; + let mut delta = 0i64; + for vertex in 0..(1usize << support_size) { + let mut prefix = base; + for (pos, bit) in support_bits.iter().enumerate() { + if ((vertex >> pos) & 1) == 1 { + prefix |= *bit; + } + } + let sign = if (support_size - vertex.count_ones() as usize) % 2 == 0 { + 1i64 + } else { + -1i64 + }; + let instance_idx = prefix + (tail << prefix_vars); + let bit = traces[instance_idx].binary_poly[col_idx].evaluations[row] + .iter() + .nth(bit_idx) + .expect("bit index in range") + .into_inner(); + if bit { + delta += sign; + } + } + delta + } + + fn rho_powers() -> Vec { + powers(f(7), F::one_with_cfg(&test_config()), 64) + } + + #[test] + fn extended_point_index_round_trips() { + for prefix_vars in 0..=4 { + let domain_len = ternary_domain_size(prefix_vars).unwrap(); + for index in 0..domain_len { + let point = extended_point_from_index(index, prefix_vars).unwrap(); + assert_eq!(extended_point_index(point, prefix_vars).unwrap(), index); + } + } + } + + #[test] + fn extended_point_rejects_noncanonical_overlap() { + assert!(matches!( + ExtendedPrefixPoint::new(0b01, 0b01), + Err(BooleanityAccumulatorError::ExtendedPointNotCanonical) + )); + + let point = ExtendedPrefixPoint { + support_mask: 0b01, + finite_bits: 0b01, + }; + assert!(matches!( + extended_point_index(point, 1), + Err(BooleanityAccumulatorError::ExtendedPointNotCanonical) + )); + } + + #[test] + fn optimized_booleanity_table_matches_naive_for_support_one_and_two() { + let cfg = test_config(); + let traces = sample_traces_ell3(); + let beta = vec![f(3), f(5), f(11)]; + let eq_weights = build_sumfold_eq_weights(&beta, 2, &cfg).unwrap(); + let row_weights = RowWeights::new(&[f(13), f(17)], &cfg).unwrap(); + let rho = 
rho_powers(); + let weights = BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }; + let scalar_weights = BooleanityScalarWeights { rho_powers: &rho }; + + let table = build_booleanity_prefix_table(&traces, 2, weights, scalar_weights, &cfg) + .expect("booleanity table should build"); + let expected = naive_table(&traces, 2, weights, scalar_weights); + assert_eq!(table.values(), expected.as_slice()); + } + + #[test] + fn optimized_booleanity_table_matches_naive_for_general_support() { + let cfg = test_config(); + let traces = sample_traces_ell3(); + let eq_weights = build_sumfold_eq_weights(&[f(3), f(5), f(11)], 3, &cfg).unwrap(); + let row_weights = RowWeights::new(&[f(13), f(17)], &cfg).unwrap(); + let rho = rho_powers(); + let weights = BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }; + let scalar_weights = BooleanityScalarWeights { rho_powers: &rho }; + + let table = build_booleanity_prefix_table(&traces, 3, weights, scalar_weights, &cfg) + .expect("booleanity table should build"); + let expected = naive_table(&traces, 3, weights, scalar_weights); + assert_eq!(table.values(), expected.as_slice()); + } + + #[test] + fn finite_only_entries_are_zero_initially() { + let cfg = test_config(); + let traces = sample_traces_ell3(); + let eq_weights = build_sumfold_eq_weights(&[f(3), f(5), f(11)], 2, &cfg).unwrap(); + let row_weights = RowWeights::new(&[f(13), f(17)], &cfg).unwrap(); + let rho = rho_powers(); + let table = build_booleanity_prefix_table( + &traces, + 2, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .unwrap(); + + for index in 0..table.len() { + let point = extended_point_from_index(index, table.prefix_vars()).unwrap(); + if point.is_finite_only() { + assert_eq!(table.values()[index], F::zero_with_cfg(&cfg)); + } + } + } + + #[test] + fn 
booleanity_validation_errors_are_reported() { + let cfg = test_config(); + let traces = sample_traces_ell3(); + let row_weights = RowWeights::new(&[f(13), f(17)], &cfg).unwrap(); + let rho = rho_powers(); + let empty_traces: &[Trace] = &[]; + + let err = build_booleanity_prefix_table( + empty_traces, + 0, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &[F::one_with_cfg(&cfg)], + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("empty traces should be rejected"); + assert!(matches!(err, BooleanityAccumulatorError::EmptyTraces)); + + let err = build_booleanity_prefix_table( + &traces[..3], + 1, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &[F::one_with_cfg(&cfg), F::one_with_cfg(&cfg)], + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("non-power-of-two trace count should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::TraceCountNotPowerOfTwo { len: 3 } + )); + + let err = build_booleanity_prefix_table( + &traces, + 4, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &[F::one_with_cfg(&cfg)], + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("too many prefix vars should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::PrefixVarsTooLarge { + prefix_vars: 4, + ell: 3 + } + )); + + let oversized_prefix_traces = + vec![traces[0].clone(); 1usize << (MAX_BOOLEANITY_PREFIX_VARS + 1)]; + let err = build_booleanity_prefix_table( + &oversized_prefix_traces, + MAX_BOOLEANITY_PREFIX_VARS + 1, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &[F::one_with_cfg(&cfg)], + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("unsupported prefix var count should be rejected before allocation"); + assert!(matches!( + err, + BooleanityAccumulatorError::PrefixVarsExceedsSupported { + prefix_vars, + max + } if prefix_vars == 
MAX_BOOLEANITY_PREFIX_VARS + 1 + && max == MAX_BOOLEANITY_PREFIX_VARS + )); + + let err = build_booleanity_prefix_table( + &traces, + 2, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &[], + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("wrong tail weight length should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::LengthMismatch { + label: "tail_eq_weights", + got: 0, + expected: 2 + } + )); + + let short_rho = vec![F::one_with_cfg(&cfg); 3]; + let eq_weights = build_sumfold_eq_weights(&[f(3), f(5), f(11)], 2, &cfg).unwrap(); + let err = build_booleanity_prefix_table( + &traces, + 2, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + BooleanityScalarWeights { + rho_powers: &short_rho, + }, + &cfg, + ) + .expect_err("short rho powers should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::LengthMismatch { + label: "rho_powers", + got: 3, + expected: 64 + } + )); + + let mut mismatched_cols = traces.clone(); + mismatched_cols[1].binary_poly.to_mut().pop(); + let err = build_booleanity_prefix_table( + &mismatched_cols, + 2, + BooleanityWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("wrong binary column count should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::BinaryColumnCountMismatch { + trace_idx: 1, + got: 1, + expected: 2 + } + )); + + let bad_row_weights = RowWeights::new(&[f(13)], &cfg).unwrap(); + let err = build_booleanity_prefix_table( + &traces, + 2, + BooleanityWeights { + row_weights: &bad_row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + BooleanityScalarWeights { rho_powers: &rho }, + &cfg, + ) + .expect_err("wrong row count should be rejected"); + assert!(matches!( + err, + BooleanityAccumulatorError::BinaryColumnRowMismatch { .. 
} + )); + } +} diff --git a/piop/src/neutron_nova/linear_cpr.rs b/piop/src/neutron_nova/linear_cpr.rs new file mode 100644 index 00000000..272dd0d1 --- /dev/null +++ b/piop/src/neutron_nova/linear_cpr.rs @@ -0,0 +1,1491 @@ +use crate::neutron_nova::{RowWeights, sumfold::checked_domain_size}; +use crate::sumcheck::{ + multi_degree::{MultiDegreeSumcheckGroup, PrefixFastPath, PrefixRoundOutput}, + prover::ProverState as SumcheckProverState, +}; +use crypto_primitives::{FromPrimitiveWithConfig, PrimeField, crypto_bigint_uint::Uint}; +use num_traits::Zero; +use std::array; +use thiserror::Error; +use zinc_poly::{ + mle::DenseMultilinearExtension, + univariate::binary::BinaryPoly, + utils::{build_eq_x_r_inner, build_eq_x_r_vec, eq_eval}, +}; +use zinc_uair::UairTrace; +use zinc_utils::{ + UNCHECKED, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, MontgomeryProductSum4, + }, + inner_product::FieldFieldInnerProduct, + inner_transparent_field::InnerTransparentField, +}; + +use super::{LinearPrefixTable, SumFoldError}; + +const PREFIX_TILE_SIZE: usize = 8; +/// Precomputed equality weights for the instance-axis SumFold split. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SumFoldEqWeights { + pub prefix_eq_weights: Vec, + pub tail_eq_weights: Vec, +} + +/// Precomputed multiplication weights used by the linear CPR accumulator. +#[derive(Debug)] +pub struct LinearCprWeights<'a, F: PrimeField> { + pub row_weights: &'a RowWeights, + pub tail_eq_weights: &'a [F], +} + +impl Clone for LinearCprWeights<'_, F> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for LinearCprWeights<'_, F> {} + +/// Precomputed scalar weights applied after row/tail DMR reduction. 
+#[derive(Debug)] +pub struct LinearCprScalarWeights<'a, F: PrimeField> { + pub family_weights: &'a [F], + pub scalarization_powers: &'a [F], +} + +impl Clone for LinearCprScalarWeights<'_, F> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for LinearCprScalarWeights<'_, F> {} + +/// One linear CPR family described as small-coefficient binary source terms. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LinearFamilySpec { + pub family_idx: usize, + pub active_rows: Vec, + pub terms: Vec>, +} + +/// One binary source term in a linear CPR family. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LinearTermSpec { + pub source: LinearBinarySource, + pub coeffs_by_active_row: Vec>, +} + +/// Binary source read by a linear term. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum LinearBinarySource { + Column { col_idx: usize }, + ShiftedColumn { col_idx: usize, shift: usize }, +} + +impl LinearBinarySource { + fn col_idx(&self) -> usize { + match self { + Self::Column { col_idx } | Self::ShiftedColumn { col_idx, .. } => *col_idx, + } + } +} + +/// Coefficient class for post-DMR bucket weighting. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub enum CoeffClass { + Zero, + Small(i64), + Large(F), +} + +impl CoeffClass +where + F: FromPrimitiveWithConfig, +{ + fn is_zero(&self) -> bool { + matches!(self, Self::Zero | Self::Small(0)) + } + + fn to_field(&self, field_cfg: &F::Config) -> F { + match self { + Self::Zero => F::zero_with_cfg(field_cfg), + Self::Small(value) => F::from_with_cfg(*value, field_cfg), + Self::Large(value) => value.clone(), + } + } +} + +#[derive(Clone, Debug, Error)] +pub enum LinearCprAccumulatorError { + #[error("linear CPR accumulator needs at least one trace")] + EmptyTraces, + #[error("trace count must be a power of two, got {len}")] + TraceCountNotPowerOfTwo { len: usize }, + #[error("prefix_vars={prefix_vars} must be at most ell={ell}")] + PrefixVarsTooLarge { prefix_vars: usize, ell: usize }, + #[error("domain size is too large for {vars} variables")] + DomainTooLarge { vars: usize }, + #[error("{label} length mismatch: got {got}, expected {expected}")] + LengthMismatch { + label: &'static str, + got: usize, + expected: usize, + }, + #[error("trace {trace_idx} has {got} binary columns, expected {expected}")] + BinaryColumnCountMismatch { + trace_idx: usize, + got: usize, + expected: usize, + }, + #[error("trace {trace_idx} binary column {col_idx} has {got} rows, expected {expected}")] + BinaryColumnRowMismatch { + trace_idx: usize, + col_idx: usize, + got: usize, + expected: usize, + }, + #[error( + "family {family_idx} references family weight {weight_idx}, but only {len} weights exist" + )] + FamilyWeightOutOfRange { + family_idx: usize, + weight_idx: usize, + len: usize, + }, + #[error("family {family_idx} active row {row} is out of range for {rows} rows")] + ActiveRowOutOfRange { + family_idx: usize, + row: usize, + rows: usize, + }, + #[error("family {family_idx} term {term_idx} has {got} coefficients, expected {expected}")] + TermCoeffLengthMismatch { + family_idx: usize, + term_idx: usize, + got: usize, + expected: usize, + 
}, + #[error( + "family {family_idx} term {term_idx} references binary column {col_idx}, but only {cols} exist" + )] + SourceColumnOutOfRange { + family_idx: usize, + term_idx: usize, + col_idx: usize, + cols: usize, + }, + #[error("sumfold helper failed: {0}")] + SumFold(#[from] SumFoldError), + #[error("equality table construction failed: {0}")] + EqTable(#[from] zinc_poly::utils::ArithErrors), +} + +#[derive(Clone, Debug)] +struct PreparedFamily { + family_idx: usize, + active_rows: Vec, + coeff_values: Vec, + terms: Vec, +} + +#[derive(Clone, Debug)] +struct PreparedTerm { + source: LinearBinarySource, + coeff_indices_by_active_row: Vec>, +} + +/// Build prefix/tail equality weights for a SumFold instance-axis split. +pub fn build_sumfold_eq_weights( + beta: &[F], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, LinearCprAccumulatorError> +where + F: PrimeField, +{ + let ell = beta.len(); + if prefix_vars > ell { + return Err(LinearCprAccumulatorError::PrefixVarsTooLarge { prefix_vars, ell }); + } + + let prefix_eq_weights = if prefix_vars == 0 { + vec![F::one_with_cfg(field_cfg)] + } else { + zinc_poly::utils::build_eq_x_r_vec(&beta[..prefix_vars], field_cfg)? + }; + let tail_vars = ell - prefix_vars; + let tail_eq_weights = if tail_vars == 0 { + vec![F::one_with_cfg(field_cfg)] + } else { + zinc_poly::utils::build_eq_x_r_vec(&beta[prefix_vars..], field_cfg)? + }; + + Ok(SumFoldEqWeights { + prefix_eq_weights, + tail_eq_weights, + }) +} + +/// Build the optimized linear CPR prefix table from small-value binary traces. 
+#[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)] +pub fn build_linear_cpr_prefix_table( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + prefix_vars: usize, + families: &[LinearFamilySpec], + weights: LinearCprWeights<'_, F>, + scalar_weights: LinearCprScalarWeights<'_, F>, + field_cfg: &F::Config, +) -> Result, LinearCprAccumulatorError> +where + F: MontgomeryLimbs + DelayedFieldProductSum + FromPrimitiveWithConfig + Send + Sync + 'static, + PolyCoeff: Clone, + Int: Clone, +{ + let ell = validate_trace_count(traces.len())?; + if prefix_vars > ell { + return Err(LinearCprAccumulatorError::PrefixVarsTooLarge { prefix_vars, ell }); + } + + let prefix_len = checked_domain_size(prefix_vars)?; + let tail_len = checked_domain_size(ell - prefix_vars)?; + if weights.tail_eq_weights.len() != tail_len { + return Err(LinearCprAccumulatorError::LengthMismatch { + label: "tail_eq_weights", + got: weights.tail_eq_weights.len(), + expected: tail_len, + }); + } + if scalar_weights.scalarization_powers.len() < D { + return Err(LinearCprAccumulatorError::LengthMismatch { + label: "scalarization_powers", + got: scalar_weights.scalarization_powers.len(), + expected: D, + }); + } + + let num_binary_cols = validate_traces(traces, weights.row_weights.len())?; + let prepared = prepare_families( + families, + scalar_weights.family_weights.len(), + num_binary_cols, + weights.row_weights.len(), + field_cfg, + )?; + + let mut table_values = vec![F::zero_with_cfg(field_cfg); prefix_len]; + let reducer = BarrettDelayedReduction::::new(field_cfg); + + for family in &prepared { + let family_weight = &scalar_weights.family_weights[family.family_idx]; + if family.coeff_values.is_empty() { + continue; + } + + let mut tile_start = 0usize; + while tile_start < prefix_len { + let tile_len = PREFIX_TILE_SIZE.min(prefix_len - tile_start); + accumulate_family_tile::( + traces, + family, + tile_start, + tile_len, + prefix_vars, + tail_len, + &weights, + 
scalar_weights.scalarization_powers, + family_weight, + field_cfg, + &reducer, + &mut table_values, + )?; + tile_start += tile_len; + } + } + + LinearPrefixTable::from_values_for_prefix_vars(table_values, ell, prefix_vars) + .map_err(LinearCprAccumulatorError::from) +} + +/// Build the post-prefix CPR claim table `claim(r_prefix, tail)`. +/// +/// Unlike [`build_linear_cpr_prefix_table`], this does not apply tail equality +/// weights. The resumed ordinary sumcheck keeps +/// `eq(beta_prefix, r_prefix) * eq(beta_tail, tail)` as its separate equality +/// MLE, so this table only binds the CPR claim MLE along the prefix variables. +#[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)] +pub fn build_linear_cpr_prefix_bound_tail_table( + traces: &[UairTrace<'_, PolyCoeff, Int, D>], + prefix_vars: usize, + prefix_point: &[F], + families: &[LinearFamilySpec], + row_weights: &RowWeights, + scalar_weights: LinearCprScalarWeights<'_, F>, + field_cfg: &F::Config, +) -> Result, LinearCprAccumulatorError> +where + F: InnerTransparentField + FromPrimitiveWithConfig, + F::Inner: Zero, + PolyCoeff: Clone, + Int: Clone, +{ + let ell = validate_trace_count(traces.len())?; + if prefix_vars > ell { + return Err(LinearCprAccumulatorError::PrefixVarsTooLarge { prefix_vars, ell }); + } + if prefix_point.len() != prefix_vars { + return Err(LinearCprAccumulatorError::LengthMismatch { + label: "prefix_point", + got: prefix_point.len(), + expected: prefix_vars, + }); + } + if scalar_weights.scalarization_powers.len() < D { + return Err(LinearCprAccumulatorError::LengthMismatch { + label: "scalarization_powers", + got: scalar_weights.scalarization_powers.len(), + expected: D, + }); + } + + let prefix_len = checked_domain_size(prefix_vars)?; + let tail_vars = ell - prefix_vars; + let tail_len = checked_domain_size(tail_vars)?; + let prefix_weights = if prefix_vars == 0 { + vec![F::one_with_cfg(field_cfg)] + } else { + zinc_poly::utils::build_eq_x_r_vec(prefix_point, 
field_cfg)? + }; + + let num_binary_cols = validate_traces(traces, row_weights.len())?; + let prepared = prepare_families( + families, + scalar_weights.family_weights.len(), + num_binary_cols, + row_weights.len(), + field_cfg, + )?; + + let zero = F::zero_with_cfg(field_cfg); + let mut tail_values = vec![zero.clone(); tail_len]; + for (tail, tail_value) in tail_values.iter_mut().enumerate() { + for (prefix, prefix_weight) in prefix_weights.iter().enumerate().take(prefix_len) { + let instance_idx = prefix + (tail << prefix_vars); + let trace = &traces[instance_idx]; + + for family in &prepared { + let family_weight = &scalar_weights.family_weights[family.family_idx]; + let mut family_value = zero.clone(); + + for (active_pos, &row) in family.active_rows.iter().enumerate() { + let row_weight = &row_weights.as_slice()[row]; + for term in &family.terms { + let Some(coeff_idx) = term.coeff_indices_by_active_row[active_pos] else { + continue; + }; + let Some(poly) = source_poly::(trace, &term.source, row) + else { + continue; + }; + let coeff_value = &family.coeff_values[coeff_idx]; + + for (bit_idx, bit) in poly.iter().enumerate().take(D) { + if bit.into_inner() { + let mut contribution = row_weight.clone(); + contribution *= coeff_value; + contribution *= &scalar_weights.scalarization_powers[bit_idx]; + family_value += contribution; + } + } + } + } + + *tail_value += prefix_weight.clone() * family_weight * family_value; + } + } + } + + let zero_inner = zero.inner().clone(); + Ok(DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + tail_values + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + )) +} + +struct LinearCprPrefixFastPath< + F: PrimeField, + PolyCoeff: Clone + 'static, + Int: Clone + 'static, + const D: usize, +> { + traces: Vec>, + families: Vec>, + row_weights: RowWeights, + family_weights: Vec, + scalarization_powers: Vec, + beta: Vec, + prefix_vars: usize, + prefix_state: SumcheckProverState, +} + +impl 
LinearCprPrefixFastPath +where + F: MontgomeryLimbs + + InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Zero, + PolyCoeff: Clone + Send + Sync + 'static, + Int: Clone + Send + Sync + 'static, +{ + #[allow(clippy::too_many_arguments)] + fn new( + traces: Vec>, + prefix_vars: usize, + beta: Vec, + families: Vec>, + row_weights: RowWeights, + family_weights: Vec, + scalarization_powers: Vec, + field_cfg: &F::Config, + ) -> Result { + let eq_weights = build_sumfold_eq_weights(&beta, prefix_vars, field_cfg)?; + let table = build_linear_cpr_prefix_table( + &traces, + prefix_vars, + &families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + field_cfg, + )?; + let eq_prefix = build_eq_x_r_inner(&beta[..prefix_vars], field_cfg)?; + let prefix_state = + SumcheckProverState::new(vec![eq_prefix, table.to_mle(field_cfg)], prefix_vars, 2); + + Ok(Self { + traces, + families, + row_weights, + family_weights, + scalarization_powers, + beta, + prefix_vars, + prefix_state, + }) + } + + #[allow(clippy::arithmetic_side_effects)] + fn finish_tail_mles( + self, + prefix_challenges: &[F], + field_cfg: &F::Config, + ) -> Vec> { + debug_assert_eq!(prefix_challenges.len(), self.prefix_vars); + let tail_vars = self.beta.len() - self.prefix_vars; + let beta_tail_weights = build_eq_x_r_vec(&self.beta[self.prefix_vars..], field_cfg) + .expect("tail beta equality table should build"); + let eq_prefix_at_r = eq_eval( + prefix_challenges, + &self.beta[..self.prefix_vars], + F::one_with_cfg(field_cfg), + ) + .expect("prefix challenge and beta prefix lengths match"); + + let tail_claims = build_linear_cpr_prefix_bound_tail_table( + &self.traces, + self.prefix_vars, + prefix_challenges, + &self.families, + &self.row_weights, + LinearCprScalarWeights { + 
family_weights: &self.family_weights, + scalarization_powers: &self.scalarization_powers, + }, + field_cfg, + ) + .expect("validated CPR tail table should build"); + + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + let scaled_eq_tail = DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + beta_tail_weights + .iter() + .map(|weight| (eq_prefix_at_r.clone() * weight).inner().clone()) + .collect(), + zero_inner, + ); + + vec![scaled_eq_tail, tail_claims] + } +} + +impl PrefixFastPath + for LinearCprPrefixFastPath +where + F: MontgomeryLimbs + + InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Zero, + PolyCoeff: Clone + Send + Sync + 'static, + Int: Clone + Send + Sync + 'static, +{ + fn prefix_len(&self) -> usize { + self.prefix_vars + } + + fn prove_prefix_round( + &mut self, + verifier_msg: &Option, + config: &F::Config, + ) -> PrefixRoundOutput { + let msg = self.prefix_state.prove_round( + verifier_msg, + |values: &[F]| values[0].clone() * &values[1], + config, + ); + let asserted_sum = if self.prefix_state.round == 1 { + self.prefix_state.asserted_sum.clone() + } else { + None + }; + + PrefixRoundOutput { + asserted_sum, + tail_evaluations: msg.0.tail_evaluations, + } + } + + fn finish_prefix( + self: Box, + prefix_challenges: &[F], + config: &F::Config, + ) -> Vec> { + self.finish_tail_mles(prefix_challenges, config) + } +} + +/// Build a hybrid SumFold group for linear CPR over owned traces. +/// +/// `prefix_vars = 0` falls back to an ordinary full sumcheck group over dense +/// CPR claims. Positive prefixes emit CPR prefix-table sumcheck messages, bind +/// the prefix at the sampled challenges, and resume the ordinary degree-2 +/// sumcheck over the tail variables. 
+///
+/// # Errors
+///
+/// Propagates the error from `validate_trace_count`. Returns `LengthMismatch`
+/// when `beta.len()` differs from the validated variable count `ell`,
+/// `PrefixVarsTooLarge` when `prefix_vars > ell`, and
+/// `SumFoldError::HybridPrefixNeedsTail` when a positive `prefix_vars` leaves
+/// no tail variable. Errors from the table and eq builders are propagated.
+#[allow(clippy::too_many_arguments)]
+pub fn build_linear_cpr_hybrid_sumcheck_group(
+    traces: Vec>,
+    prefix_vars: usize,
+    beta: &[F],
+    families: Vec>,
+    row_weights: RowWeights,
+    family_weights: Vec,
+    scalarization_powers: Vec,
+    field_cfg: &F::Config,
+) -> Result, LinearCprAccumulatorError>
+where
+    F: MontgomeryLimbs
+        + InnerTransparentField
+        + DelayedFieldProductSum
+        + FromPrimitiveWithConfig
+        + Send
+        + Sync
+        + 'static,
+    F::Inner: Zero,
+    PolyCoeff: Clone + Send + Sync + 'static,
+    Int: Clone + Send + Sync + 'static,
+{
+    // `ell` is log2 of the number of traces.
+    let ell = validate_trace_count(traces.len())?;
+    if beta.len() != ell {
+        return Err(LinearCprAccumulatorError::LengthMismatch {
+            label: "beta",
+            got: beta.len(),
+            expected: ell,
+        });
+    }
+    if prefix_vars > ell {
+        return Err(LinearCprAccumulatorError::PrefixVarsTooLarge { prefix_vars, ell });
+    }
+    // No prefix: build the dense claims table (empty prefix point) and return
+    // an ordinary eq * claims group.
+    if prefix_vars == 0 {
+        let claims = build_linear_cpr_prefix_bound_tail_table(
+            &traces,
+            0,
+            &[],
+            &families,
+            &row_weights,
+            LinearCprScalarWeights {
+                family_weights: &family_weights,
+                scalarization_powers: &scalarization_powers,
+            },
+            field_cfg,
+        )?;
+        let eq_beta = build_eq_x_r_inner(beta, field_cfg)?;
+        return Ok(MultiDegreeSumcheckGroup::new(
+            2,
+            vec![eq_beta, claims],
+            Box::new(|values: &[F]| values[0].clone() * &values[1]),
+        ));
+    }
+    // `prefix_vars > ell` was rejected above, so only `prefix_vars == ell`
+    // reaches this branch: the hybrid path needs at least one tail variable.
+    if prefix_vars >= ell {
+        return Err(LinearCprAccumulatorError::SumFold(
+            SumFoldError::HybridPrefixNeedsTail {
+                ell0: prefix_vars,
+                ell,
+            },
+        ));
+    }
+
+    let fast_path = LinearCprPrefixFastPath::new(
+        traces,
+        prefix_vars,
+        beta.to_vec(),
+        families,
+        row_weights,
+        family_weights,
+        scalarization_powers,
+        field_cfg,
+    )?;
+    // The group starts with no MLEs of its own; the fast path supplies the
+    // prefix rounds and later hands back the tail MLEs.
+    Ok(MultiDegreeSumcheckGroup::with_prefix_fast(
+        2,
+        Vec::new(),
+        Box::new(|values: &[F]| values[0].clone() * &values[1]),
+        Box::new(fast_path),
+    ))
+}
+
+/// Accumulates one family's contribution for the prefixes
+/// `tile_start..tile_start + tile_len` into `table_values`.
+///
+/// For every (tail, active row, term, prefix) combination the weight
+/// `tail_eq_weight * row_weight` is added into one delayed-reduction lane per
+/// set bit of the source polynomial; lanes are grouped into one bucket per
+/// `(prefix_offset, coefficient class)`. After flushing, each bucket's lanes
+/// are projected onto `scalarization_powers[..D]`, scaled by the class
+/// coefficient, summed per prefix, and the sum times `family_weight` is added
+/// to `table_values[prefix]`.
+///
+/// # Errors
+///
+/// Returns `DomainTooLarge` if `tile_len * family.coeff_values.len()`
+/// overflows `usize`.
+///
+/// # Panics
+///
+/// Indexing is unchecked against the caller's inputs: `weights`, `traces`,
+/// `table_values` and `scalarization_powers[..D]` must be large enough.
+/// NOTE(review): presumably guaranteed by validation in the caller — confirm.
+#[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)]
+fn accumulate_family_tile(
+    traces: &[UairTrace<'_, PolyCoeff, Int, D>],
+    family: &PreparedFamily,
+    tile_start: usize,
+    tile_len: usize,
+    prefix_vars: usize,
+    tail_len: usize,
+    weights: &LinearCprWeights<'_, F>,
+    scalarization_powers: &[F],
+    family_weight: &F,
+    field_cfg: &F::Config,
+    reducer: &BarrettDelayedReduction<'_, F>,
+    table_values: &mut [F],
+) -> Result<(), LinearCprAccumulatorError>
+where
+    F: MontgomeryLimbs + DelayedFieldProductSum + FromPrimitiveWithConfig + Send + Sync + 'static,
+    PolyCoeff: Clone,
+    Int: Clone,
+{
+    // One bucket per (prefix offset in tile, coefficient class).
+    let bucket_count = tile_len
+        .checked_mul(family.coeff_values.len())
+        .ok_or(LinearCprAccumulatorError::DomainTooLarge { vars: prefix_vars })?;
+    // `buckets` hold unreduced wide sums; `lane_accs` hold the reduced field
+    // values they are flushed into.
+    let mut buckets: Vec<[Uint<5>; D]> = vec![[Uint::zero(); D]; bucket_count];
+    let zero = F::zero_with_cfg(field_cfg);
+    let mut lane_accs: Vec<[F; D]> = (0..bucket_count)
+        .map(|_| array::from_fn(|_| zero.clone()))
+        .collect();
+    let product_sum = MontgomeryProductSum4::::new(field_cfg);
+    let mut pending_adds = 0usize;
+    let flush_adds = reducer.flush_adds();
+
+    for tail in 0..tail_len {
+        let tail_weight = &weights.tail_eq_weights[tail];
+        for (active_pos, &row) in family.active_rows.iter().enumerate() {
+            // Combined weight shared by every term and prefix at this (tail, row).
+            let omega = tail_weight.clone() * &weights.row_weights.as_slice()[row];
+
+            for term in &family.terms {
+                // `None` marks a zero coefficient at this active row.
+                let Some(coeff_idx) = term.coeff_indices_by_active_row[active_pos] else {
+                    continue;
+                };
+
+                for prefix_offset in 0..tile_len {
+                    let prefix = tile_start + prefix_offset;
+                    // Prefix bits are the low bits of the instance index.
+                    let instance_idx = prefix + (tail << prefix_vars);
+                    let Some(poly) =
+                        source_poly::(&traces[instance_idx], &term.source, row)
+                    else {
+                        continue;
+                    };
+                    let bucket_idx = prefix_offset * family.coeff_values.len() + coeff_idx;
+                    pending_adds = pending_adds.saturating_add(add_poly_bits_to_bucket(
+                        poly,
+                        &omega,
+                        reducer,
+                        &mut buckets[bucket_idx],
+                    ));
+
+                    // Reduce once the reducer's add budget is reached.
+                    // NOTE(review): presumably this keeps the wide
+                    // accumulators from overflowing — confirm against
+                    // `BarrettDelayedReduction::flush_adds`.
+                    if pending_adds >= flush_adds {
+                        flush_buckets_into_lanes(&mut buckets, &mut lane_accs, reducer);
+                        pending_adds = 0;
+                    }
+                }
+            }
+        }
+    }
+
+    // Drain whatever is still pending.
+    flush_buckets_into_lanes(&mut buckets, &mut lane_accs, reducer);
+
+    for prefix_offset in 0..tile_len {
+        let prefix = tile_start + prefix_offset;
+        let mut family_value = zero.clone();
+        for (coeff_idx, coeff_value) in family.coeff_values.iter().enumerate() {
+            let bucket_idx = prefix_offset * family.coeff_values.len() + coeff_idx;
+            // Collapse the D bit lanes into one field element.
+            let projected = FieldFieldInnerProduct::inner_product_with_algorithm::(
+                &product_sum,
+                &lane_accs[bucket_idx],
+                &scalarization_powers[..D],
+                zero.clone(),
+            )
+            .expect("lane accumulator and scalarization powers have matching lengths");
+            family_value += coeff_value.clone() * projected;
+        }
+        table_values[prefix] += family_weight.clone() * family_value;
+    }
+
+    Ok(())
+}
+
+/// Returns log2 of `len`.
+///
+/// Errors with `EmptyTraces` when `len == 0` and with
+/// `TraceCountNotPowerOfTwo` when `len` is not a power of two.
+fn validate_trace_count(len: usize) -> Result {
+    if len == 0 {
+        return Err(LinearCprAccumulatorError::EmptyTraces);
+    }
+    if !len.is_power_of_two() {
+        return Err(LinearCprAccumulatorError::TraceCountNotPowerOfTwo { len });
+    }
+    // For a power of two, the trailing-zero count is its log2.
+    Ok(usize::try_from(len.trailing_zeros()).expect("trailing_zeros fits usize"))
+}
+
+/// Checks that every trace has as many binary columns as the first trace and
+/// that every binary column has `expected_rows` evaluations; returns the
+/// shared column count.
+///
+/// # Panics
+///
+/// Panics if `traces` is empty.
+fn validate_traces(
+    traces: &[UairTrace<'_, PolyCoeff, Int, D>],
+    expected_rows: usize,
+) -> Result
+where
+    PolyCoeff: Clone,
+    Int: Clone,
+{
+    let num_binary_cols = traces
+        .first()
+        .expect("trace count was already validated as non-empty")
+        .binary_poly
+        .len();
+    for (trace_idx, trace) in traces.iter().enumerate() {
+        if trace.binary_poly.len() != num_binary_cols {
+            return Err(LinearCprAccumulatorError::BinaryColumnCountMismatch {
+                trace_idx,
+                got: trace.binary_poly.len(),
+                expected: num_binary_cols,
+            });
+        }
+        for (col_idx, column) in trace.binary_poly.iter().enumerate() {
+            if column.evaluations.len() != expected_rows {
+                return Err(LinearCprAccumulatorError::BinaryColumnRowMismatch {
+                    trace_idx,
+                    col_idx,
+                    got: column.evaluations.len(),
+                    expected: expected_rows,
+                });
+            }
+        }
+    }
+    Ok(num_binary_cols)
+}
+
+/// Validates `families` against the given bounds and converts each one into a
+/// `PreparedFamily` whose terms refer to deduplicated coefficient classes.
+fn prepare_families(
+    families: &[LinearFamilySpec],
+    family_weight_len: usize,
+    num_binary_cols: usize,
+    row_count: usize,
+    field_cfg: &F::Config,
+) -> Result>, LinearCprAccumulatorError>
+where
+    F:
FromPrimitiveWithConfig,
+{
+    let mut prepared = Vec::with_capacity(families.len());
+    for family in families {
+        // The family index doubles as the index into the family weights.
+        if family.family_idx >= family_weight_len {
+            return Err(LinearCprAccumulatorError::FamilyWeightOutOfRange {
+                family_idx: family.family_idx,
+                weight_idx: family.family_idx,
+                len: family_weight_len,
+            });
+        }
+        for &row in &family.active_rows {
+            if row >= row_count {
+                return Err(LinearCprAccumulatorError::ActiveRowOutOfRange {
+                    family_idx: family.family_idx,
+                    row,
+                    rows: row_count,
+                });
+            }
+        }
+
+        // Distinct non-zero coefficients seen so far in this family; terms
+        // store indices into this list.
+        let mut coeff_classes = Vec::>::new();
+        let mut terms = Vec::with_capacity(family.terms.len());
+        for (term_idx, term) in family.terms.iter().enumerate() {
+            let col_idx = term.source.col_idx();
+            if col_idx >= num_binary_cols {
+                return Err(LinearCprAccumulatorError::SourceColumnOutOfRange {
+                    family_idx: family.family_idx,
+                    term_idx,
+                    col_idx,
+                    cols: num_binary_cols,
+                });
+            }
+            // Each term needs exactly one coefficient per active row.
+            if term.coeffs_by_active_row.len() != family.active_rows.len() {
+                return Err(LinearCprAccumulatorError::TermCoeffLengthMismatch {
+                    family_idx: family.family_idx,
+                    term_idx,
+                    got: term.coeffs_by_active_row.len(),
+                    expected: family.active_rows.len(),
+                });
+            }
+
+            let mut coeff_indices = Vec::with_capacity(term.coeffs_by_active_row.len());
+            for coeff in &term.coeffs_by_active_row {
+                // Zero coefficients get no class and are recorded as `None`.
+                if coeff.is_zero() {
+                    coeff_indices.push(None);
+                    continue;
+                }
+                // Linear search for an equal class; append a new one if absent.
+                let idx = match coeff_classes.iter().position(|existing| existing == coeff) {
+                    Some(idx) => idx,
+                    None => {
+                        coeff_classes.push(coeff.clone());
+                        coeff_classes.len() - 1
+                    }
+                };
+                coeff_indices.push(Some(idx));
+            }
+            terms.push(PreparedTerm {
+                source: term.source.clone(),
+                coeff_indices_by_active_row: coeff_indices,
+            });
+        }
+
+        // Convert each class to a field element once, in class-index order.
+        let coeff_values = coeff_classes
+            .iter()
+            .map(|coeff| coeff.to_field(field_cfg))
+            .collect();
+
+        prepared.push(PreparedFamily {
+            family_idx: family.family_idx,
+            active_rows: family.active_rows.clone(),
+            coeff_values,
+            terms,
+        });
+    }
+    Ok(prepared)
+}
+
+/// Looks up the binary polynomial that `source` refers to at `row`.
+///
+/// A plain column reads `row`; a shifted column reads `row + shift`. Returns
+/// `None` when that row is past the end of the column or the addition
+/// overflows. The column index itself is not checked and panics if it is out
+/// of range.
+fn source_poly<'a,
PolyCoeff, Int, const D: usize>(
+    trace: &'a UairTrace<'_, PolyCoeff, Int, D>,
+    source: &LinearBinarySource,
+    row: usize,
+) -> Option<&'a BinaryPoly>
+where
+    PolyCoeff: Clone,
+    Int: Clone,
+{
+    match source {
+        LinearBinarySource::Column { col_idx } => trace.binary_poly[*col_idx].evaluations.get(row),
+        // `checked_add` turns an overflowing shifted row into `None`.
+        LinearBinarySource::ShiftedColumn { col_idx, shift } => row
+            .checked_add(*shift)
+            .and_then(|shifted_row| trace.binary_poly[*col_idx].evaluations.get(shifted_row)),
+    }
+}
+
+/// Adds `weight` into `bucket[i]` for every set coefficient `i` of `poly`
+/// (first `D` coefficients only) and returns how many additions were made.
+#[allow(clippy::arithmetic_side_effects)]
+fn add_poly_bits_to_bucket(
+    poly: &BinaryPoly,
+    weight: &F,
+    reducer: &BarrettDelayedReduction<'_, F>,
+    bucket: &mut [Uint<5>; D],
+) -> usize
+where
+    F: MontgomeryLimbs + Send + Sync,
+{
+    if D <= 64 {
+        // Pack the coefficients into one `u64` mask so only set bits are
+        // visited below; `bit_idx < D <= 64` keeps the shift in range.
+        let mut bits = 0u64;
+        for (bit_idx, coeff) in poly.iter().enumerate().take(D) {
+            if coeff.into_inner() {
+                bits |= 1u64 << bit_idx;
+            }
+        }
+
+        let mut adds = 0usize;
+        while bits != 0 {
+            // Index of the lowest set bit.
+            let bit_idx =
+                usize::try_from(bits.trailing_zeros()).expect("trailing_zeros fits usize");
+            reducer.add(&mut bucket[bit_idx], weight);
+            // Clear the lowest set bit; `bits != 0` so the subtraction cannot wrap.
+            bits &= bits - 1;
+            adds += 1;
+        }
+        adds
+    } else {
+        // More than 64 coefficients do not fit a `u64` mask: test each one.
+        let mut adds = 0usize;
+        for (bit_idx, coeff) in poly.iter().enumerate().take(D) {
+            if coeff.into_inner() {
+                reducer.add(&mut bucket[bit_idx], weight);
+                adds += 1;
+            }
+        }
+        adds
+    }
+}
+
+/// Reduces every non-zero bucket lane, adds the result to the matching lane
+/// accumulator, and resets the bucket lane to zero.
+///
+/// Buckets and accumulators are paired positionally with `zip`, so extra
+/// entries on the longer side are ignored.
+fn flush_buckets_into_lanes(
+    buckets: &mut [[Uint<5>; D]],
+    lane_accs: &mut [[F; D]],
+    reducer: &BarrettDelayedReduction<'_, F>,
+) where
+    F: MontgomeryLimbs + Send + Sync,
+{
+    for (bucket_lanes, acc_lanes) in buckets.iter_mut().zip(lane_accs.iter_mut()) {
+        for (bucket, acc) in bucket_lanes.iter_mut().zip(acc_lanes.iter_mut()) {
+            // Nothing pending in this lane: skip the reduction.
+            if bucket.is_zero() {
+                continue;
+            }
+            // Take the pending value and leave zero behind in one step.
+            let pending = std::mem::replace(bucket, Uint::zero());
+            *acc += reducer.reduce(pending);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        neutron_nova::LinearInstanceClaims, sumcheck::multi_degree::MultiDegreeSumcheck,
+        test_utils::test_config,
+    };
+    use
crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use std::borrow::Cow; + use zinc_poly::{ + mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig}, + utils::{build_eq_x_r_vec, eq_eval}, + }; + use zinc_transcript::Blake3Transcript; + use zinc_utils::powers; + + type F = MontyField<4>; + type Trace = UairTrace<'static, F, F, 32>; + + fn f(value: u64) -> F { + F::from_with_cfg(value, &test_config()) + } + + fn binary_column(patterns: &[u32]) -> DenseMultilinearExtension> { + assert!(patterns.len().is_power_of_two()); + DenseMultilinearExtension::from_evaluations_vec( + usize::try_from(patterns.len().trailing_zeros()).expect("trailing_zeros fits usize"), + patterns + .iter() + .copied() + .map(BinaryPoly::<32>::from) + .collect(), + BinaryPoly::<32>::zero(), + ) + } + + fn trace_from_columns(col0: &[u32], col1: &[u32]) -> Trace { + UairTrace { + binary_poly: Cow::Owned(vec![binary_column(col0), binary_column(col1)]), + arbitrary_poly: Cow::Owned(Vec::new()), + int: Cow::Owned(Vec::new()), + } + } + + fn sample_traces() -> Vec { + vec![ + trace_from_columns( + &[0x0000_0001, 0x0000_0002, 0x8000_0001, 0x0001_0010], + &[0x0000_0005, 0x0000_000a, 0x0000_0101, 0x0000_1000], + ), + trace_from_columns( + &[0x0000_0003, 0x0000_0010, 0x0000_00f0, 0x0000_f000], + &[0x0000_0006, 0x0000_0009, 0x0000_0f00, 0x0000_00ff], + ), + trace_from_columns( + &[0x0000_0100, 0x0000_0201, 0x0000_0402, 0x0000_0804], + &[0x0000_0011, 0x0000_0022, 0x0000_0044, 0x0000_0088], + ), + trace_from_columns( + &[0x0000_aaaa, 0x0000_5555, 0x0000_3333, 0x0000_cccc], + &[0x0000_1234, 0x0000_4321, 0x0000_00f1, 0x0000_0f10], + ), + ] + } + + fn sample_families() -> Vec> { + vec![ + LinearFamilySpec { + family_idx: 0, + active_rows: vec![0, 1, 2], + terms: vec![ + LinearTermSpec { + source: LinearBinarySource::Column { col_idx: 0 }, + coeffs_by_active_row: vec![ + CoeffClass::Small(1), + CoeffClass::Small(-1), + CoeffClass::Zero, + ], + }, + LinearTermSpec { + source: 
LinearBinarySource::Column { col_idx: 1 }, + coeffs_by_active_row: vec![ + CoeffClass::Small(2), + CoeffClass::Small(1), + CoeffClass::Small(-2), + ], + }, + ], + }, + LinearFamilySpec { + family_idx: 1, + active_rows: vec![1, 3], + terms: vec![ + LinearTermSpec { + source: LinearBinarySource::Column { col_idx: 0 }, + coeffs_by_active_row: vec![CoeffClass::Large(f(7)), CoeffClass::Small(-1)], + }, + LinearTermSpec { + source: LinearBinarySource::ShiftedColumn { + col_idx: 1, + shift: 1, + }, + coeffs_by_active_row: vec![CoeffClass::Small(1), CoeffClass::Small(3)], + }, + ], + }, + ] + } + + fn scalar_weights() -> Vec { + powers(f(3), F::one_with_cfg(&test_config()), 32) + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_linear_cpr_claims( + traces: &[Trace], + families: &[LinearFamilySpec], + row_weights: &RowWeights, + scalar_weights: LinearCprScalarWeights<'_, F>, + ) -> Vec { + let cfg = test_config(); + let mut claims = vec![F::zero_with_cfg(&cfg); traces.len()]; + + for (trace, claim) in traces.iter().zip(claims.iter_mut()) { + for family in families { + let mut family_value = F::zero_with_cfg(&cfg); + for (active_pos, &row) in family.active_rows.iter().enumerate() { + for term in &family.terms { + let coeff = &term.coeffs_by_active_row[active_pos]; + if coeff.is_zero() { + continue; + } + let coeff_value = coeff.to_field(&cfg); + let Some(poly) = source_poly::(trace, &term.source, row) else { + continue; + }; + + for (bit_idx, bit) in poly.iter().enumerate().take(32) { + if bit.into_inner() { + let mut contribution = row_weights.as_slice()[row].clone(); + contribution *= &coeff_value; + contribution *= &scalar_weights.scalarization_powers[bit_idx]; + family_value += contribution; + } + } + } + } + *claim += scalar_weights.family_weights[family.family_idx].clone() * family_value; + } + } + + claims + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_linear_cpr_table( + traces: &[Trace], + prefix_vars: usize, + families: &[LinearFamilySpec], 
+ weights: LinearCprWeights<'_, F>, + scalar_weights: LinearCprScalarWeights<'_, F>, + ) -> Vec { + let cfg = test_config(); + let ell = + usize::try_from(traces.len().trailing_zeros()).expect("trailing_zeros fits usize"); + let prefix_len = 1usize << prefix_vars; + let tail_len = 1usize << (ell - prefix_vars); + let claims = naive_linear_cpr_claims(traces, families, weights.row_weights, scalar_weights); + let mut table = vec![F::zero_with_cfg(&cfg); prefix_len]; + + for (prefix, value) in table.iter_mut().enumerate() { + for tail in 0..tail_len { + let instance_idx = prefix + (tail << prefix_vars); + *value += weights.tail_eq_weights[tail].clone() * &claims[instance_idx]; + } + } + + table + } + + fn build_table_for_prefix_vars( + prefix_vars: usize, + ) -> (LinearPrefixTable, SumFoldEqWeights) { + let cfg = test_config(); + let traces = sample_traces(); + let families = sample_families(); + let beta = vec![f(19), f(23)]; + let eq_weights = build_sumfold_eq_weights(&beta, prefix_vars, &cfg).unwrap(); + let row_weights = RowWeights::new(&[f(11), f(13)], &cfg).unwrap(); + let family_weights = vec![f(5), f(17)]; + let scalarization_powers = scalar_weights(); + + let table = build_linear_cpr_prefix_table( + &traces, + prefix_vars, + &families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + &cfg, + ) + .unwrap(); + + let expected = naive_linear_cpr_table( + &traces, + prefix_vars, + &families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + ); + assert_eq!(table.values(), expected.as_slice()); + + (table, eq_weights) + } + + #[test] + fn optimized_linear_cpr_table_matches_expanded_formula_for_live_prefix() { + let (table, 
eq_weights) = build_table_for_prefix_vars(2); + + assert_eq!(table.ell(), 2); + assert_eq!(table.ell0(), 2); + assert_eq!( + eq_weights.tail_eq_weights, + vec![F::one_with_cfg(&test_config())] + ); + } + + #[test] + fn optimized_linear_cpr_table_matches_expanded_formula_for_windowed_prefix() { + let (table, eq_weights) = build_table_for_prefix_vars(1); + + assert_eq!(table.ell(), 2); + assert_eq!(table.ell0(), 1); + assert_eq!(table.len(), 2); + assert_eq!(eq_weights.tail_eq_weights.len(), 2); + } + + #[test] + fn optimized_linear_cpr_prefix_bound_tail_table_matches_naive_claim_binding() { + let cfg = test_config(); + let traces = sample_traces(); + let families = sample_families(); + let prefix_vars = 1; + let prefix_point = vec![f(29)]; + let row_weights = RowWeights::new(&[f(11), f(13)], &cfg).unwrap(); + let family_weights = vec![f(5), f(17)]; + let scalarization_powers = scalar_weights(); + let scalar_weights = LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }; + + let tail_mle = build_linear_cpr_prefix_bound_tail_table( + &traces, + prefix_vars, + &prefix_point, + &families, + &row_weights, + scalar_weights, + &cfg, + ) + .unwrap(); + let dense_claims = + naive_linear_cpr_claims(&traces, &families, &row_weights, scalar_weights); + let prefix_weights = build_eq_x_r_vec(&prefix_point, &cfg).unwrap(); + + let tail_len = 1usize << (2 - prefix_vars); + let mut expected = vec![F::zero_with_cfg(&cfg); tail_len]; + for (tail, value) in expected.iter_mut().enumerate() { + for (prefix, weight) in prefix_weights.iter().enumerate() { + let instance_idx = prefix + (tail << prefix_vars); + *value += weight.clone() * &dense_claims[instance_idx]; + } + } + + assert_eq!(tail_mle.num_vars, 1); + assert_eq!( + tail_mle.evaluations, + expected + .iter() + .map(|value| value.inner().clone()) + .collect::>() + ); + } + + #[test] + fn optimized_linear_cpr_hybrid_sumfold_matches_full_ordinary_sumcheck() { + let cfg = 
test_config(); + let traces = sample_traces(); + let families = sample_families(); + let beta = vec![f(19), f(23)]; + let row_weights = RowWeights::new(&[f(11), f(13)], &cfg).unwrap(); + let family_weights = vec![f(5), f(17)]; + let scalarization_powers = scalar_weights(); + let scalar_weights = LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }; + let dense_claims = + naive_linear_cpr_claims(&traces, &families, &row_weights, scalar_weights); + let claims = LinearInstanceClaims::from_claims_for_ell(dense_claims, 2).unwrap(); + + let full_group = claims.build_full_sumcheck_group(&beta, &cfg).unwrap(); + let optimized_group = build_linear_cpr_hybrid_sumcheck_group( + traces.clone(), + 1, + &beta, + families.clone(), + row_weights.clone(), + family_weights.clone(), + scalarization_powers.clone(), + &cfg, + ) + .unwrap(); + let fallback_group = build_linear_cpr_hybrid_sumcheck_group( + traces, + 0, + &beta, + families, + row_weights, + family_weights, + scalarization_powers, + &cfg, + ) + .unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + let (full_proof, _states) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![full_group], + 2, + &cfg, + ); + let mut prover_transcript = Blake3Transcript::new(); + let (optimized_proof, _states) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![optimized_group], + 2, + &cfg, + ); + assert_eq!(optimized_proof, full_proof); + let mut prover_transcript = Blake3Transcript::new(); + let (fallback_proof, _states) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![fallback_group], + 2, + &cfg, + ); + assert_eq!(fallback_proof, full_proof); + + let mut verifier_transcript = Blake3Transcript::new(); + let full_subclaims = MultiDegreeSumcheck::verify_as_subprotocol( + &mut verifier_transcript, + 2, + &full_proof, + &cfg, + ) + .expect("full linear CPR sumcheck should verify"); + let 
mut verifier_transcript = Blake3Transcript::new(); + let optimized_subclaims = MultiDegreeSumcheck::verify_as_subprotocol( + &mut verifier_transcript, + 2, + &optimized_proof, + &cfg, + ) + .expect("optimized linear CPR sumcheck should verify"); + assert_eq!(optimized_subclaims.point(), full_subclaims.point()); + assert_eq!( + optimized_subclaims.expected_evaluations(), + full_subclaims.expected_evaluations() + ); + + let point = full_subclaims.point(); + let eq_at_point = + eq_eval(point, &beta, F::one_with_cfg(&cfg)).expect("same number of variables"); + let zero_inner = F::zero_with_cfg(&cfg).inner().clone(); + let claims_mle = DenseMultilinearExtension::from_evaluations_vec( + 2, + claims + .claims() + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + ); + let claim_eval = claims_mle.evaluate_with_config(point, &cfg).unwrap(); + assert_eq!( + full_subclaims.expected_evaluations()[0], + eq_at_point * claim_eval + ); + } + + #[test] + fn optimized_linear_cpr_flushes_dmr_for_small_modulus() { + let cfg = test_config(); + let trace_count = 2048usize; + let traces: Vec<_> = (0..trace_count) + .map(|_| trace_from_columns(&[u32::MAX], &[0])) + .collect(); + let families = vec![LinearFamilySpec { + family_idx: 0, + active_rows: vec![0], + terms: vec![LinearTermSpec { + source: LinearBinarySource::Column { col_idx: 0 }, + coeffs_by_active_row: vec![CoeffClass::Small(1)], + }], + }]; + let row_weights = RowWeights::new(&[], &cfg).unwrap(); + let max = -F::from_with_cfg(1u64, &cfg); + let tail_eq_weights = vec![max; trace_count]; + let family_weights = vec![F::one_with_cfg(&cfg)]; + let scalarization_powers = scalar_weights(); + let weights = LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &tail_eq_weights, + }; + let scalar_weights = LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }; + + let table = + build_linear_cpr_prefix_table(&traces, 0, &families, weights, 
scalar_weights, &cfg) + .unwrap(); + let expected = naive_linear_cpr_table(&traces, 0, &families, weights, scalar_weights); + + assert_eq!(table.values(), expected.as_slice()); + } + + #[test] + fn optimized_linear_cpr_validation_errors_are_reported() { + let cfg = test_config(); + let traces = sample_traces(); + let families = sample_families(); + let row_weights = RowWeights::new(&[f(11), f(13)], &cfg).unwrap(); + let family_weights = vec![f(5), f(17)]; + let scalarization_powers = scalar_weights(); + + let err = build_linear_cpr_prefix_table( + &traces, + 3, + &families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &[F::one_with_cfg(&cfg)], + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + &cfg, + ) + .expect_err("too many prefix variables should be rejected"); + assert!(matches!( + err, + LinearCprAccumulatorError::PrefixVarsTooLarge { + prefix_vars: 3, + ell: 2 + } + )); + + let err = build_linear_cpr_prefix_table( + &traces, + 2, + &families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &[], + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + &cfg, + ) + .expect_err("wrong tail weight length should be rejected"); + assert!(matches!( + err, + LinearCprAccumulatorError::LengthMismatch { + label: "tail_eq_weights", + got: 0, + expected: 1 + } + )); + + let mut bad_families = sample_families(); + bad_families[0].terms[0].coeffs_by_active_row.pop(); + let eq_weights = build_sumfold_eq_weights(&[f(19), f(23)], 2, &cfg).unwrap(); + let err = build_linear_cpr_prefix_table( + &traces, + 2, + &bad_families, + LinearCprWeights { + row_weights: &row_weights, + tail_eq_weights: &eq_weights.tail_eq_weights, + }, + LinearCprScalarWeights { + family_weights: &family_weights, + scalarization_powers: &scalarization_powers, + }, + &cfg, + ) + .expect_err("wrong coefficient vector length should 
be rejected"); + assert!(matches!( + err, + LinearCprAccumulatorError::TermCoeffLengthMismatch { + family_idx: 0, + term_idx: 0, + got: 2, + expected: 3 + } + )); + } +} diff --git a/piop/src/neutron_nova/mod.rs b/piop/src/neutron_nova/mod.rs new file mode 100644 index 00000000..15f6615a --- /dev/null +++ b/piop/src/neutron_nova/mod.rs @@ -0,0 +1,62 @@ +//! NeutronNova small-value accumulation helpers. +//! +//! This module is intentionally standalone for now: it provides the linear +//! row-space accumulator and SumFold prefix-table primitives without changing +//! protocol proof objects or verifier flow. + +pub mod accumulator; +pub mod booleanity; +pub mod linear_cpr; +pub mod projection_sha; +pub mod sumfold; + +pub use accumulator::{ + AccumulatorError, RowWeights, SmallValueBitAccumulator, accumulate_binary_column_projected, +}; +pub use booleanity::{ + BooleanityAccumulatorError, BooleanityPrefixTable, BooleanityScalarWeights, BooleanityWeights, + ExtendedPrefixPoint, build_booleanity_prefix_table, extended_point_from_index, + extended_point_index, ternary_domain_size, +}; +pub use linear_cpr::{ + CoeffClass, LinearBinarySource, LinearCprAccumulatorError, LinearCprScalarWeights, + LinearCprWeights, LinearFamilySpec, LinearTermSpec, SumFoldEqWeights, + build_linear_cpr_hybrid_sumcheck_group, build_linear_cpr_prefix_bound_tail_table, + build_linear_cpr_prefix_table, build_sumfold_eq_weights, +}; +pub use projection_sha::{ + FoldedCommitments, FreshIdealEvaluationCache, InstanceFoldClaim, LinearResidualCoeffTable, + MleColumn, MleTable, NUM_NONZERO_SHA_FAMILIES, NUM_SHA_RESIDUAL_FAMILIES, ProjectedPublic, + ProjectedTrace, ProjectionFoldAccumulator, ProjectionFoldWitness, SHA_ROW_COUNT, SHA_ROW_VARS, + SHA_WORD_BITS, ShaBinaryFoldField, ShaBooleanitySource, ShaIntCol, ShaProductionIdeal, + ShaProjectionError, ShaPublicCol, ShaPublicWordCol, ShaResidualFamily, ShaWordCol, + VirtualChMajValues, beta_aggregate_nonzero_ideal_polys, + 
beta_aggregate_nonzero_ideal_polys_with_weights, bit_slice_index, build_booleanity_weights, + build_dense_sha_sumfold_group, build_dense_sha_sumfold_group_with_weights, + build_expression_folded_row_sumcheck_group, + build_expression_folded_row_sumcheck_group_with_row_weights, build_folded_row_sumcheck_group, + build_fresh_sha_ideal_cache, build_linear_residual_coeff_tables, + build_linear_residual_coeff_tables_with_row_weights, build_production_sha_sumfold_group, + build_production_sha_sumfold_group_from_prefix_accumulators, + build_production_sha_sumfold_group_owned, build_production_sha_sumfold_group_with_linear_cache, + build_production_sha_sumfold_group_with_linear_cache_and_weights, + build_sha_ideal_values_at_point, build_sha_lambda_powers, build_sha_residual_eval_powers, + build_sha_sumfold_linear_accumulator, build_sha_sumfold_quadratic_prefix_accumulator, + check_fresh_sha_ideal_cache, check_sha_ideal_values, derive_instance_fold_claim, + evaluate_fresh_sha_targets, expression_folded_row_sum, + expression_folded_row_sum_with_row_weights, expression_folded_row_sum_with_vectors, + fold_projected_traces, folded_row_integrand_sum, folded_row_integrand_values, + folded_row_integrand_values_with_row_weights, folded_row_integrand_values_with_vectors, + production_sha_booleanity_sources, production_sha_nonzero_families, + production_sha_nonzero_ideals, reconstruct_virtual_ch_maj_at_row, scalarize_bit_slices, + sha_int_at_point, sha_int_at_point_with_weights, sha_int_at_point_with_weights_unchecked, + sha_linear_residual_row_value, sha_linear_residual_sum, sha_public_at_point, + sha_public_at_point_with_weights, sha_scalarized_word_at_point, + sha_scalarized_word_at_point_with_weights, sha_word_bits_at_point, + sha_word_bits_at_point_with_weights, sha_word_bits_at_point_with_weights_unchecked, + validate_fresh_sha_ideal_polys_canonical, validate_projected_trace, + verify_folded_row_sumcheck_claim, verify_folded_scalarization_links, + 
verify_folded_scalarization_links_at_point, verify_folded_shifted_scalarization_link_at_point, + verify_fresh_sha_ideal_polys, +}; +pub use sumfold::{LinearInstanceClaims, LinearPrefixTable, SumFoldError}; diff --git a/piop/src/neutron_nova/projection_sha.rs b/piop/src/neutron_nova/projection_sha.rs new file mode 100644 index 00000000..0e8420ae --- /dev/null +++ b/piop/src/neutron_nova/projection_sha.rs @@ -0,0 +1,6493 @@ +//! Production SHA-256 ProjectionFold helpers. +//! +//! This module implements the SHA-specific data model and reference +//! computations used by the production ProjectionFold flow: +//! +//! fresh ideal checks -> SumFold over instances -> post-SumFold folding -> +//! folded row check over the 128-row SHA domain. + +use crate::ideal_check::batched_ideal_check; +use crate::neutron_nova::SumFoldError; +use crate::{ + CombFn, + sumcheck::multi_degree::{MultiDegreeSumcheckGroup, PrefixFastPath, PrefixRoundOutput}, +}; +use crypto_primitives::{ + PrimeField, crypto_bigint_boxed_monty::BoxedMontyField, crypto_bigint_monty::MontyField, + crypto_bigint_uint::Uint, +}; +use num_traits::{ConstZero, Zero}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use thiserror::Error; +use zinc_poly::{ + mle::DenseMultilinearExtension, + univariate::dynamic::over_field::{DynamicPolyFInnerProduct, DynamicPolynomialF}, + utils::{ArithErrors, build_eq_x_r_vec, eq_eval}, +}; +use zinc_uair::{ + ideal::{Ideal, IdealCheck, IdealCheckError, rotation::RotationIdeal}, + ideal_collector::IdealOrZero, +}; +use zinc_utils::{ + UNCHECKED, cfg_chunks, cfg_into_iter, cfg_iter, + delayed_reduction::{ + BarrettDelayedReduction, DelayedFieldProductSum, DelayedModularReductionAlgorithm, + MontgomeryLimbs, + }, + from_ref::FromRef, + inner_product::{FieldFieldInnerProduct, InnerProduct}, + inner_transparent_field::InnerTransparentField, + powers, +}; + +pub const SHA_ROW_VARS: usize = 7; +pub const SHA_ROW_COUNT: usize = 128; +pub const SHA_WORD_BITS: usize = 32; +pub const 
NUM_SHA_RESIDUAL_FAMILIES: usize = 18; +pub const NUM_NONZERO_SHA_FAMILIES: usize = 7; +const SHA_RESIDUAL_EVAL_POWER_COUNT: usize = 62; + +pub type MleColumn = DenseMultilinearExtension; +pub type MleTable = Vec>; + +pub trait ShaBinaryFoldField: PrimeField + Send + Sync + Sized { + fn fold_binary_mle_tables( + kind: &'static str, + tables: &[&MleTable], + theta: &[Self], + field_cfg: &Self::Config, + ) -> Result, ShaProjectionError>; +} + +impl ShaBinaryFoldField for MontyField<4> { + fn fold_binary_mle_tables( + kind: &'static str, + tables: &[&MleTable], + theta: &[Self], + field_cfg: &Self::Config, + ) -> Result, ShaProjectionError> { + fold_binary_mle_tables_montgomery(kind, tables, theta, field_cfg) + } +} + +impl ShaBinaryFoldField for BoxedMontyField { + fn fold_binary_mle_tables( + kind: &'static str, + tables: &[&MleTable], + theta: &[Self], + field_cfg: &Self::Config, + ) -> Result, ShaProjectionError> { + fold_binary_mle_tables_generic(kind, tables, theta, field_cfg) + } +} + +const NONZERO_SHA_FAMILIES: [ShaResidualFamily; NUM_NONZERO_SHA_FAMILIES] = [ + ShaResidualFamily::R0BigSigmaA, + ShaResidualFamily::R1BigSigmaE, + ShaResidualFamily::R4Schedule, + ShaResidualFamily::R5UpdateA, + ShaResidualFamily::R6UpdateE, + ShaResidualFamily::R9FeedForwardA, + ShaResidualFamily::R10FeedForwardE, +]; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ShaResidualFamily { + R0BigSigmaA, + R1BigSigmaE, + R2SmallSigma0, + R3SmallSigma1, + R4Schedule, + R5UpdateA, + R6UpdateE, + R7PinA, + R8PinE, + R9FeedForwardA, + R10FeedForwardE, + R11MessagePin, + R12CompSchedule, + R13CompUpdateA, + R14CompUpdateE, + R15CompFeedForwardA, + R16CompFeedForwardE, + R17CarryHighBits, +} + +impl ShaResidualFamily { + pub const ALL: [Self; NUM_SHA_RESIDUAL_FAMILIES] = [ + Self::R0BigSigmaA, + Self::R1BigSigmaE, + Self::R2SmallSigma0, + Self::R3SmallSigma1, + Self::R4Schedule, + Self::R5UpdateA, + Self::R6UpdateE, + Self::R7PinA, + Self::R8PinE, + Self::R9FeedForwardA, + 
Self::R10FeedForwardE,
+        Self::R11MessagePin,
+        Self::R12CompSchedule,
+        Self::R13CompUpdateA,
+        Self::R14CompUpdateE,
+        Self::R15CompFeedForwardA,
+        Self::R16CompFeedForwardE,
+        Self::R17CarryHighBits,
+    ];
+
+    /// Position of this family, 0 through 17, matching the `R<n>` prefix of
+    /// the variant name.
+    pub fn index(self) -> usize {
+        match self {
+            Self::R0BigSigmaA => 0,
+            Self::R1BigSigmaE => 1,
+            Self::R2SmallSigma0 => 2,
+            Self::R3SmallSigma1 => 3,
+            Self::R4Schedule => 4,
+            Self::R5UpdateA => 5,
+            Self::R6UpdateE => 6,
+            Self::R7PinA => 7,
+            Self::R8PinE => 8,
+            Self::R9FeedForwardA => 9,
+            Self::R10FeedForwardE => 10,
+            Self::R11MessagePin => 11,
+            Self::R12CompSchedule => 12,
+            Self::R13CompUpdateA => 13,
+            Self::R14CompUpdateE => 14,
+            Self::R15CompFeedForwardA => 15,
+            Self::R16CompFeedForwardE => 16,
+            Self::R17CarryHighBits => 17,
+        }
+    }
+
+    /// Returns `true` for the seven families that are also listed in
+    /// `NONZERO_SHA_FAMILIES`; the two lists are maintained by hand and must
+    /// be kept in sync.
+    pub fn is_nonzero_ideal(self) -> bool {
+        matches!(
+            self,
+            Self::R0BigSigmaA
+                | Self::R1BigSigmaE
+                | Self::R4Schedule
+                | Self::R5UpdateA
+                | Self::R6UpdateE
+                | Self::R9FeedForwardA
+                | Self::R10FeedForwardE
+        )
+    }
+}
+
+/// Word columns; the explicit discriminant is the value returned by
+/// `index()`.
+#[repr(usize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum ShaWordCol {
+    A = 0,
+    E = 1,
+    Sigma0 = 2,
+    Sigma1 = 3,
+    W = 4,
+    SmallSigma0 = 5,
+    SmallSigma1 = 6,
+    Uef = 7,
+    UNegEg = 8,
+    Maj = 9,
+    MuPacked = 10,
+    OvSigma0 = 11,
+    OvSigma1 = 12,
+    OvSmallSigma0 = 13,
+    OvSmallSigma1 = 14,
+    Ch2Comp = 15,
+    MajComp = 16,
+}
+
+impl ShaWordCol {
+    /// Every variant, listed in discriminant order.
+    pub const ALL: [Self; 17] = [
+        Self::A,
+        Self::E,
+        Self::Sigma0,
+        Self::Sigma1,
+        Self::W,
+        Self::SmallSigma0,
+        Self::SmallSigma1,
+        Self::Uef,
+        Self::UNegEg,
+        Self::Maj,
+        Self::MuPacked,
+        Self::OvSigma0,
+        Self::OvSigma1,
+        Self::OvSmallSigma0,
+        Self::OvSmallSigma1,
+        Self::Ch2Comp,
+        Self::MajComp,
+    ];
+
+    /// Number of variants; equals the length of `ALL`.
+    pub const COUNT: usize = 17;
+
+    /// The variant's discriminant as a `usize`.
+    pub fn index(self) -> usize {
+        self as usize
+    }
+}
+
+/// Integer columns; the explicit discriminant is the value returned by
+/// `index()`.
+#[repr(usize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum ShaIntCol {
+    CompSchedule = 0,
+    CompUpdateA = 1,
+    CompUpdateE = 2,
+    CompFeedForwardA = 3,
+    CompFeedForwardE = 4,
+}
+
+impl
ShaIntCol { + pub const ALL: [Self; 5] = [ + Self::CompSchedule, + Self::CompUpdateA, + Self::CompUpdateE, + Self::CompFeedForwardA, + Self::CompFeedForwardE, + ]; + + pub const COUNT: usize = 5; + + pub fn index(self) -> usize { + self as usize + } +} + +#[repr(usize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ShaPublicCol { + K = 0, + PAIn = 1, + PEIn = 2, + PAOut = 3, + PEOut = 4, + Message = 5, + SInit = 6, + SMsg = 7, + SSched = 8, + SUpd = 9, + SFf = 10, + SOut = 11, +} + +impl ShaPublicCol { + pub const ALL: [Self; 12] = [ + Self::K, + Self::PAIn, + Self::PEIn, + Self::PAOut, + Self::PEOut, + Self::Message, + Self::SInit, + Self::SMsg, + Self::SSched, + Self::SUpd, + Self::SFf, + Self::SOut, + ]; + + pub const COUNT: usize = 12; + + pub fn index(self) -> usize { + self as usize + } +} + +#[repr(usize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ShaPublicWordCol { + PAIn = 0, + PEIn = 1, + PAOut = 2, + PEOut = 3, + Message = 4, +} + +impl ShaPublicWordCol { + pub const ALL: [Self; 5] = [ + Self::PAIn, + Self::PEIn, + Self::PAOut, + Self::PEOut, + Self::Message, + ]; + + pub const COUNT: usize = 5; + + pub fn index(self) -> usize { + self as usize + } +} + +impl ShaPublicCol { + fn public_word_col(self) -> Option { + match self { + Self::PAIn => Some(ShaPublicWordCol::PAIn), + Self::PEIn => Some(ShaPublicWordCol::PEIn), + Self::PAOut => Some(ShaPublicWordCol::PAOut), + Self::PEOut => Some(ShaPublicWordCol::PEOut), + Self::Message => Some(ShaPublicWordCol::Message), + _ => None, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProjectedTrace { + /// Flattened as `[word_col * SHA_WORD_BITS + bit][row]`. + pub bit_slices: MleTable, + /// Indexed as `[word_col][row]`. + pub scalarized: MleTable, + /// Indexed as `[int_col][row]`. + pub int_columns: MleTable, + /// Indexed as `[public_col][row]`. 
+ pub public_columns: MleTable, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProjectedPublic { + /// Indexed as `[public_col][row]`. + pub columns: MleTable, + /// Flattened as `[public_word_col * SHA_WORD_BITS + bit][row]`. + pub bit_slices: Option>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FreshIdealEvaluationCache { + pub r_ic: [F; SHA_ROW_VARS], + pub ideal_polys: Vec<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]>, + pub taus_at_a: Vec<[F; NUM_NONZERO_SHA_FAMILIES]>, + pub fresh_targets: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LinearResidualCoeffTable { + /// Indexed by residual family. + pub coeffs: Vec>, +} + +impl LinearResidualCoeffTable +where + F: PrimeField, +{ + pub fn coeffs_for_family(&self, family: ShaResidualFamily) -> Option<&DynamicPolynomialF> { + self.coeffs.get(family.index()) + } +} + +pub fn beta_aggregate_nonzero_ideal_polys( + tables: &[LinearResidualCoeffTable], + beta: &[F], + field_cfg: &F::Config, +) -> Result<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + let weights = build_eq_x_r_vec(beta, field_cfg)?; + beta_aggregate_nonzero_ideal_polys_with_weights(tables, &weights) +} + +pub fn beta_aggregate_nonzero_ideal_polys_with_weights( + tables: &[LinearResidualCoeffTable], + beta_eq_weights: &[F], +) -> Result<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + if beta_eq_weights.len() != tables.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: tables.len(), + }); + } + + let mut aggregate: [DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES] = + std::array::from_fn(|_| DynamicPolynomialF::ZERO); + for (weight, table) in beta_eq_weights.iter().zip(tables) { + for (slot, family) in NONZERO_SHA_FAMILIES.iter().enumerate() { + let residual = + table + .coeffs + .get(family.index()) + .ok_or(ShaProjectionError::MissingColumn { + kind: 
"linear_residual_coeffs", + col: family.index(), + })?; + add_scaled_poly_assign(&mut aggregate[slot], residual, weight); + } + } + aggregate.iter_mut().for_each(DynamicPolynomialF::trim); + Ok(aggregate) +} + +pub fn build_sha_residual_eval_powers(a: &F, field_cfg: &F::Config) -> Vec +where + F: PrimeField, +{ + powers( + a.clone(), + F::one_with_cfg(field_cfg), + SHA_RESIDUAL_EVAL_POWER_COUNT, + ) +} + +pub fn build_sha_lambda_powers(lambda: &F, field_cfg: &F::Config) -> Vec +where + F: PrimeField, +{ + powers( + lambda.clone(), + F::one_with_cfg(field_cfg), + NUM_SHA_RESIDUAL_FAMILIES, + ) +} + +pub fn build_booleanity_weights( + rho: &F, + xi: &F, + source_count: usize, + field_cfg: &F::Config, +) -> Vec +where + F: PrimeField, +{ + powers(rho.clone(), F::one_with_cfg(field_cfg), source_count) + .into_iter() + .map(|rho_power| xi.clone() * rho_power) + .collect() +} + +fn selected_nonzero_sha_lambda_powers( + lambda_powers: &[F], +) -> Result<[F; NUM_NONZERO_SHA_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ShaProjectionError::MissingColumn { + kind: "lambda_powers", + col: lambda_powers.len(), + }); + } + Ok(std::array::from_fn(|slot| { + lambda_powers[NONZERO_SHA_FAMILIES[slot].index()].clone() + })) +} + +pub fn build_sha_sumfold_linear_accumulator( + tables: &[LinearResidualCoeffTable], + a_powers: &[F], + lambda_powers: &[F], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: DelayedFieldProductSum, +{ + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ShaProjectionError::MissingColumn { + kind: "lambda_powers", + col: lambda_powers.len(), + }); + } + if a_powers.len() < SHA_RESIDUAL_EVAL_POWER_COUNT { + return Err(ShaProjectionError::MissingColumn { + kind: "a_powers", + col: a_powers.len(), + }); + } + cfg_iter!(tables) + .map(|table| { + if table.coeffs.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ShaProjectionError::MissingColumn 
{ + kind: "linear_residual_coeffs", + col: table.coeffs.len(), + }); + } + let mut values: [F; NUM_SHA_RESIDUAL_FAMILIES] = + std::array::from_fn(|_| F::zero_with_cfg(field_cfg)); + for (family_idx, residual) in table.coeffs.iter().enumerate() { + values[family_idx] = evaluate_poly_at_powers_dmr(residual, &a_powers, field_cfg)?; + } + FieldFieldInnerProduct::inner_product::( + &values, + lambda_powers, + F::zero_with_cfg(field_cfg), + ) + .map_err(ShaProjectionError::from) + }) + .collect() +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct InstanceFoldClaim { + pub r_b: Vec, + pub c_sf: F, + pub final_round_sumcheck_claim: F, + pub eq_instance_weights: Vec, +} + +impl InstanceFoldClaim { + pub fn r_b(&self) -> &[F] { + &self.r_b + } + + pub fn c_sf(&self) -> &F { + &self.c_sf + } + + pub fn final_round_sumcheck_claim(&self) -> &F { + &self.final_round_sumcheck_claim + } + + pub fn eq_instance_weights(&self) -> &[F] { + &self.eq_instance_weights + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FoldedCommitments { + pub commitments: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProjectionFoldAccumulator { + pub instance_fold_claim: InstanceFoldClaim, + pub commitments: FoldedCommitments, + pub public: ProjectedPublic, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProjectionFoldWitness { + pub trace: ProjectedTrace, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ShaBooleanitySource { + WordBit { col: ShaWordCol, bit: usize }, + VirtualCh1 { bit: usize }, + VirtualCh2 { bit: usize }, + VirtualMaj { bit: usize }, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VirtualChMajValues { + pub ch1: [F; SHA_WORD_BITS], + pub ch2: [F; SHA_WORD_BITS], + pub maj: [F; SHA_WORD_BITS], +} + +#[derive(Clone, Debug)] +pub enum ShaProductionIdeal { + RotX2(RotationIdeal), + RotXw1, +} + +impl FromRef> for ShaProductionIdeal { + fn from_ref(value: &ShaProductionIdeal) -> Self { + value.clone() + } +} + +impl Ideal for 
ShaProductionIdeal {} + +impl IdealCheck> for ShaProductionIdeal { + fn contains(&self, value: &DynamicPolynomialF) -> Result { + match self { + ShaProductionIdeal::RotX2(ideal) => IdealOrZero::NonZero(ideal.clone()).contains(value), + ShaProductionIdeal::RotXw1 => { + if value.coeffs.is_empty() { + return Ok(true); + } + let one = F::one_with_cfg(value.coeffs[0].cfg()); + IdealOrZero::NonZero(RotationIdeal::::new(one)).contains(value) + } + } + } +} + +#[derive(Clone, Debug, Error)] +pub enum ShaProjectionError { + #[error("expected {expected} rows, got {got}")] + RowCount { expected: usize, got: usize }, + #[error("row index out of range: {row}")] + RowIndexOutOfRange { row: usize }, + #[error("{kind} column {col} is missing")] + MissingColumn { kind: &'static str, col: usize }, + #[error("{kind} column {col} row length mismatch: got {got}, expected {expected}")] + ColumnRowCount { + kind: &'static str, + col: usize, + got: usize, + expected: usize, + }, + #[error("word column {col} row {row} bit length mismatch: got {got}, expected {expected}")] + BitCount { + col: usize, + row: usize, + got: usize, + expected: usize, + }, + #[error("row-batching point length mismatch: got {got}, expected 7")] + RowPointLength { got: usize }, + #[error("instance count must be a power of two, got {got}")] + InstanceCountNotPowerOfTwo { got: usize }, + #[error("instance count mismatch: got {got}, expected {expected}")] + InstanceCountMismatch { got: usize, expected: usize }, + #[error("public word column presence mismatch across folded instances")] + PublicWordColumnPresenceMismatch, + #[error("folding weight count mismatch: got {got}, expected {expected}")] + FoldingWeightCount { got: usize, expected: usize }, + #[error("SumFold denominator eq(beta, r_b) is zero")] + ZeroSumFoldDenominator, + #[error("scalarization mismatch for word column {col}")] + ScalarizationMismatch { col: usize }, + #[error("folded row sumcheck claim does not match SumFold target")] + 
FoldedRowClaimMismatch, + #[error("booleanity bit index out of range: {bit}")] + BitIndexOutOfRange { bit: usize }, + #[error("non-canonical proof object: {0}")] + NonCanonicalProofObject(&'static str), + #[error("ideal membership check failed")] + IdealMembership, + #[error("polynomial evaluation failed: {0}")] + PolynomialEvaluation(#[from] zinc_poly::EvaluationError), + #[error("inner product failed: {0}")] + InnerProduct(#[from] zinc_utils::inner_product::InnerProductError), + #[error("equality table construction failed: {0}")] + EqTable(#[from] ArithErrors), + #[error("sumfold helper failed: {0}")] + SumFold(#[from] SumFoldError), +} + +pub fn production_sha_nonzero_families() -> &'static [ShaResidualFamily] { + &NONZERO_SHA_FAMILIES +} + +pub fn production_sha_nonzero_ideals( + field_cfg: &F::Config, +) -> [ShaProductionIdeal; NUM_NONZERO_SHA_FAMILIES] { + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + [ + ShaProductionIdeal::RotXw1, + ShaProductionIdeal::RotXw1, + ShaProductionIdeal::RotX2(RotationIdeal::new(two.clone())), + ShaProductionIdeal::RotX2(RotationIdeal::new(two.clone())), + ShaProductionIdeal::RotX2(RotationIdeal::new(two.clone())), + ShaProductionIdeal::RotX2(RotationIdeal::new(two.clone())), + ShaProductionIdeal::RotX2(RotationIdeal::new(two)), + ] +} + +fn production_sha_ideal_max_degree(family: ShaResidualFamily) -> Result { + match family { + ShaResidualFamily::R0BigSigmaA | ShaResidualFamily::R1BigSigmaE => Ok(61), + ShaResidualFamily::R4Schedule + | ShaResidualFamily::R5UpdateA + | ShaResidualFamily::R6UpdateE + | ShaResidualFamily::R9FeedForwardA + | ShaResidualFamily::R10FeedForwardE => Ok(31), + _ => Err(ShaProjectionError::NonCanonicalProofObject( + "unexpected nonzero SHA ideal family", + )), + } +} + +pub fn validate_fresh_sha_ideal_polys_canonical( + ideal_polys: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]], +) -> Result<(), ShaProjectionError> +where + F: PrimeField, +{ + for instance in ideal_polys { + 
for (slot, poly) in instance.iter().enumerate() { + if poly.coeffs.last().is_some_and(F::is_zero) { + return Err(ShaProjectionError::NonCanonicalProofObject( + "fresh ideal polynomial has trailing zero coefficients", + )); + } + + let family = production_sha_nonzero_families()[slot]; + let max_degree = production_sha_ideal_max_degree(family)?; + if poly.coeffs.len() > max_degree + 1 { + return Err(ShaProjectionError::NonCanonicalProofObject( + "fresh ideal polynomial exceeds production degree cap", + )); + } + } + } + Ok(()) +} + +pub fn verify_fresh_sha_ideal_polys( + ideal_polys: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]], + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField, +{ + validate_fresh_sha_ideal_polys_canonical(ideal_polys)?; + + let ideals = production_sha_nonzero_ideals(field_cfg); + for values in ideal_polys { + batched_ideal_check(&ideals, values).map_err(|_err| ShaProjectionError::IdealMembership)?; + } + Ok(()) +} + +pub fn build_sha_ideal_values_at_point( + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + field_cfg: &F::Config, +) -> Result<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + validate_trace(trace)?; + validate_public(public)?; + + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + let mut out: [DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES] = + std::array::from_fn(|_| DynamicPolynomialF::ZERO); + + for (row, row_weight) in row_weights.iter().enumerate().take(SHA_ROW_COUNT) { + let residuals = residual_polys_at_row(trace, public, row, field_cfg)?; + for (slot, family) in NONZERO_SHA_FAMILIES.iter().enumerate() { + add_scaled_poly_assign(&mut out[slot], &residuals[family.index()], row_weight); + } + } + out.iter_mut().for_each(DynamicPolynomialF::trim); + Ok(out) +} + +pub fn check_sha_ideal_values( + values: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + field_cfg: &F::Config, +) -> Result<(), 
ShaProjectionError> +where + F: PrimeField, +{ + let ideals = production_sha_nonzero_ideals(field_cfg); + batched_ideal_check(&ideals, values).map_err(|_err| ShaProjectionError::IdealMembership) +} + +#[allow(clippy::arithmetic_side_effects)] +pub fn bit_slice_index(col: usize, bit: usize, bits_per_col: usize) -> usize { + col * bits_per_col + bit +} + +fn mle_table_from_columns(columns: Vec>, num_vars: usize) -> MleTable { + columns + .into_iter() + .map(|evaluations| DenseMultilinearExtension { + evaluations, + num_vars, + }) + .collect() +} + +#[cfg(test)] +fn flatten_bit_columns( + columns: Vec>>, + bits_per_col: usize, + num_vars: usize, + kind: &'static str, +) -> Result, ShaProjectionError> { + let mut flattened = (0..columns.len() * bits_per_col) + .map(|_| Vec::new()) + .collect::>>(); + for (col_idx, rows) in columns.into_iter().enumerate() { + if rows.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind, + col: col_idx, + got: rows.len(), + expected: SHA_ROW_COUNT, + }); + } + for (row, bits) in rows.into_iter().enumerate() { + if bits.len() != bits_per_col { + return Err(ShaProjectionError::BitCount { + col: col_idx, + row, + got: bits.len(), + expected: bits_per_col, + }); + } + for (bit, value) in bits.into_iter().enumerate() { + flattened[bit_slice_index(col_idx, bit, bits_per_col)].push(value); + } + } + } + Ok(mle_table_from_columns(flattened, num_vars)) +} + +pub fn build_fresh_sha_ideal_cache( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + r_ic: [F; SHA_ROW_VARS], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if traces.len() != publics.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: publics.len(), + expected: traces.len(), + }); + } + let ideal_polys = traces + .iter() + .zip(publics) + .map(|(trace, public)| build_sha_ideal_values_at_point(trace, public, &r_ic, field_cfg)) + .collect::, _>>()?; + + Ok(FreshIdealEvaluationCache { + r_ic, 
+ ideal_polys, + taus_at_a: Vec::new(), + fresh_targets: Vec::new(), + }) +} + +pub fn build_linear_residual_coeff_tables( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + r_ic: &[F; SHA_ROW_VARS], + field_cfg: &F::Config, +) -> Result>, ShaProjectionError> +where + F: PrimeField, +{ + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + build_linear_residual_coeff_tables_with_row_weights(traces, publics, &row_weights, field_cfg) +} + +pub fn build_linear_residual_coeff_tables_with_row_weights( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + row_weights: &[F], + field_cfg: &F::Config, +) -> Result>, ShaProjectionError> +where + F: PrimeField, +{ + if traces.len() != publics.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: publics.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + cfg_iter!(traces) + .zip(cfg_iter!(publics)) + .map(|(trace, public)| { + #[cfg(debug_assertions)] + { + validate_trace(trace)?; + validate_public(public)?; + } + let constants = ShaResidualPolyConstants::new(field_cfg); + let partials = cfg_chunks!(row_weights, 64) + .enumerate() + .map(|(chunk_idx, row_weight_chunk)| { + let mut partial = FixedResidualCoeffAccumulator::new( + NUM_SHA_RESIDUAL_FAMILIES, + SHA_RESIDUAL_EVAL_POWER_COUNT, + field_cfg, + ); + let row_offset = chunk_idx * 64; + for (row_in_chunk, row_weight) in row_weight_chunk.iter().enumerate() { + let row = row_offset + row_in_chunk; + accumulate_residual_row_fixed( + &mut partial, + trace, + public, + row, + row_weight, + &constants, + field_cfg, + )?; + } + Ok(partial) + }) + .collect::, ShaProjectionError>>()?; + let mut coeffs = FixedResidualCoeffAccumulator::new( + NUM_SHA_RESIDUAL_FAMILIES, + SHA_RESIDUAL_EVAL_POWER_COUNT, + field_cfg, + ); + for partial in partials { + 
coeffs.add_assign(partial); + } + Ok(coeffs.into_table()) + }) + .collect::, _>>() +} + +pub fn check_fresh_sha_ideal_cache( + cache: &FreshIdealEvaluationCache, + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField, +{ + verify_fresh_sha_ideal_polys(&cache.ideal_polys, field_cfg) +} + +pub fn evaluate_fresh_sha_targets( + cache: &mut FreshIdealEvaluationCache, + a: &F, + lambda: &F, + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: DelayedFieldProductSum, +{ + let one = F::one_with_cfg(field_cfg); + let zero = F::zero_with_cfg(field_cfg); + let lambda_powers = powers(lambda.clone(), one, NUM_SHA_RESIDUAL_FAMILIES); + let a_powers = powers( + a.clone(), + F::one_with_cfg(field_cfg), + SHA_RESIDUAL_EVAL_POWER_COUNT, + ); + let nonzero_lambda_powers = selected_nonzero_sha_lambda_powers(&lambda_powers)?; + + cache.taus_at_a.clear(); + cache.fresh_targets.clear(); + + for ideal_polys in &cache.ideal_polys { + let mut tau_values = Vec::with_capacity(NUM_NONZERO_SHA_FAMILIES); + for poly in ideal_polys { + tau_values.push(evaluate_poly_at_powers_dmr(poly, &a_powers, field_cfg)?); + } + let taus: [F; NUM_NONZERO_SHA_FAMILIES] = tau_values + .try_into() + .unwrap_or_else(|_| unreachable!("exactly seven SHA ideal values were evaluated")); + let target = FieldFieldInnerProduct::inner_product::( + &nonzero_lambda_powers, + &taus, + zero.clone(), + ) + .map_err(ShaProjectionError::from)?; + cache.taus_at_a.push(taus); + cache.fresh_targets.push(target); + } + Ok(()) +} + +pub fn derive_instance_fold_claim( + beta: &[F], + r_b: Vec, + c_sf: F, + instance_count: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if !instance_count.is_power_of_two() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { + got: instance_count, + }); + } + let ell = usize::try_from(instance_count.trailing_zeros()).expect("ell fits usize"); + if beta.len() != ell { + return 
Err(ShaProjectionError::InstanceCountMismatch { + got: beta.len(), + expected: ell, + }); + } + if r_b.len() != ell { + return Err(ShaProjectionError::InstanceCountMismatch { + got: r_b.len(), + expected: ell, + }); + } + + let one = F::one_with_cfg(field_cfg); + let d = eq_eval(beta, &r_b, one)?; + if F::is_zero(&d) { + return Err(ShaProjectionError::ZeroSumFoldDenominator); + } + + let eq_instance_weights = build_eq_x_r_vec(&r_b, field_cfg)?; + debug_assert_eq!(eq_instance_weights.len(), instance_count); + let final_round_sumcheck_claim = c_sf.clone() / d; + Ok(InstanceFoldClaim { + r_b, + c_sf, + final_round_sumcheck_claim, + eq_instance_weights, + }) +} + +pub fn fold_projected_traces( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + sumfold: &InstanceFoldClaim, + field_cfg: &F::Config, +) -> Result<(ProjectionFoldWitness, ProjectedPublic), ShaProjectionError> +where + F: ShaBinaryFoldField, +{ + if traces.len() != publics.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: publics.len(), + expected: traces.len(), + }); + } + if sumfold.eq_instance_weights.len() != traces.len() { + return Err(ShaProjectionError::FoldingWeightCount { + got: sumfold.eq_instance_weights.len(), + expected: traces.len(), + }); + } + #[cfg(debug_assertions)] + { + for trace in traces { + validate_trace(trace)?; + } + for public in publics { + validate_public(public)?; + } + } + + let folded_public_columns = fold_mle_tables( + "public.columns", + publics.iter().map(|public| &public.columns), + &sumfold.eq_instance_weights, + field_cfg, + )?; + let folded_trace = ProjectedTrace { + bit_slices: fold_binary_mle_tables( + "bit_slices", + traces.iter().map(|trace| &trace.bit_slices), + &sumfold.eq_instance_weights, + field_cfg, + )?, + scalarized: fold_mle_tables( + "scalarized", + traces.iter().map(|trace| &trace.scalarized), + &sumfold.eq_instance_weights, + field_cfg, + )?, + int_columns: fold_mle_tables( + "int_columns", + traces.iter().map(|trace| 
&trace.int_columns), + &sumfold.eq_instance_weights, + field_cfg, + )?, + public_columns: folded_public_columns.clone(), + }; + let folded_public = ProjectedPublic { + columns: folded_public_columns, + bit_slices: fold_optional_binary_mle_tables( + "public.bit_slices", + publics.iter().map(|public| public.bit_slices.as_ref()), + &sumfold.eq_instance_weights, + field_cfg, + )?, + }; + + Ok(( + ProjectionFoldWitness { + trace: folded_trace, + }, + folded_public, + )) +} + +pub fn scalarize_bit_slices( + bit_slices: &MleTable, + a: &F, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField + MontgomeryLimbs + DelayedFieldProductSum + Send + Sync, +{ + let powers = powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + if bit_slices.len() % SHA_WORD_BITS != 0 { + return Err(ShaProjectionError::MissingColumn { + kind: "bit_slices", + col: bit_slices.len(), + }); + } + let word_col_count = bit_slices.len() / SHA_WORD_BITS; + let reducer = BarrettDelayedReduction::::new(field_cfg); + let mut words = Vec::with_capacity(word_col_count); + for col_idx in 0..word_col_count { + let mut out_col = Vec::with_capacity(SHA_ROW_COUNT); + for row in 0..SHA_ROW_COUNT { + let mut bits = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + bits.push(scalar_from_table( + "bit_slices", + bit_slices, + bit_slice_index(col_idx, bit, SHA_WORD_BITS), + row, + field_cfg, + )?); + } + out_col.push(project_binary_bits_conditional_add_dmr( + &bits, &powers, field_cfg, &reducer, + )?); + } + words.push(out_col); + } + Ok(mle_table_from_columns(words, SHA_ROW_VARS)) +} + +pub fn verify_folded_scalarization_links( + trace: &ProjectedTrace, + a: &F, + word_cols: &[ShaWordCol], + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField + DelayedFieldProductSum, +{ + validate_trace(trace)?; + let powers = powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + for col in word_cols { + let col_idx = col.index(); + for 
row in 0..SHA_ROW_COUNT { + let mut bits = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + bits.push(scalar_from_table( + "bit_slices", + &trace.bit_slices, + bit_slice_index(col_idx, bit, SHA_WORD_BITS), + row, + field_cfg, + )?); + } + let recombined = project_bits_dmr(&bits, &powers, field_cfg)?; + let scalar = + scalar_from_table("scalarized", &trace.scalarized, col_idx, row, field_cfg)?; + if recombined != scalar { + return Err(ShaProjectionError::ScalarizationMismatch { col: col_idx }); + } + } + } + Ok(()) +} + +pub fn verify_folded_scalarization_links_at_point( + trace: &ProjectedTrace, + a: &F, + r_star: &[F; SHA_ROW_VARS], + word_cols: &[ShaWordCol], + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField + DelayedFieldProductSum, +{ + for col in word_cols { + verify_folded_shifted_scalarization_link_at_point(trace, a, r_star, *col, 0, field_cfg)?; + } + Ok(()) +} + +pub fn verify_folded_shifted_scalarization_link_at_point( + trace: &ProjectedTrace, + a: &F, + r_star: &[F; SHA_ROW_VARS], + col: ShaWordCol, + shift: usize, + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField + DelayedFieldProductSum, +{ + validate_trace(trace)?; + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + let powers = powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + let mut word_eval = F::zero_with_cfg(field_cfg); + let mut bit_rows = Vec::with_capacity(SHA_ROW_COUNT); + + for (row, row_weight) in row_weights.iter().enumerate() { + word_eval += row_weight.clone() + * scalarized_word_at_shifted_or_zero(trace, col, row, shift, field_cfg)?; + let mut bits = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + bits.push(bit_at_shifted_or_zero( + trace, col, row, shift, bit, field_cfg, + )?); + } + bit_rows.push(project_bits_dmr(&bits, &powers, field_cfg)?); + } + let bit_eval = FieldFieldInnerProduct::inner_product::( + &row_weights, + &bit_rows, + 
F::zero_with_cfg(field_cfg), + )?; + + if word_eval != bit_eval { + return Err(ShaProjectionError::ScalarizationMismatch { col: col.index() }); + } + Ok(()) +} + +pub fn reconstruct_virtual_ch_maj_at_row( + trace: &ProjectedTrace, + row: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + validate_trace(trace)?; + reconstruct_virtual_ch_maj_at_row_unchecked(trace, row, field_cfg) +} + +fn reconstruct_virtual_ch_maj_at_row_unchecked( + trace: &ProjectedTrace, + row: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if row >= SHA_ROW_COUNT { + return Err(ShaProjectionError::RowIndexOutOfRange { row }); + } + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let ch1 = build_virtual_bit_array(|bit| { + Ok( + bit_at_shifted_or_zero_fast(trace, ShaWordCol::E, row, 2, bit, field_cfg) + + bit_at_shifted_or_zero_fast(trace, ShaWordCol::E, row, 1, bit, field_cfg) + - two.clone() + * bit_at_shifted_or_zero_fast(trace, ShaWordCol::Uef, row, 2, bit, field_cfg), + ) + }); + let ch2 = build_virtual_bit_array(|bit| { + Ok( + bit_at_shifted_or_zero_fast(trace, ShaWordCol::E, row, 2, bit, field_cfg) + - bit_at_shifted_or_zero_fast(trace, ShaWordCol::E, row, 0, bit, field_cfg) + + two.clone() + * bit_at_shifted_or_zero_fast( + trace, + ShaWordCol::UNegEg, + row, + 2, + bit, + field_cfg, + ) + + two.clone() + * bit_at_shifted_or_zero_fast( + trace, + ShaWordCol::Ch2Comp, + row, + 0, + bit, + field_cfg, + ), + ) + }); + let maj = build_virtual_bit_array(|bit| { + Ok( + bit_at_shifted_or_zero_fast(trace, ShaWordCol::A, row, 0, bit, field_cfg) + + bit_at_shifted_or_zero_fast(trace, ShaWordCol::A, row, 1, bit, field_cfg) + + bit_at_shifted_or_zero_fast(trace, ShaWordCol::A, row, 2, bit, field_cfg) + - two.clone() + * bit_at_shifted_or_zero_fast(trace, ShaWordCol::Maj, row, 2, bit, field_cfg) + - two.clone() + * bit_at_shifted_or_zero_fast( + trace, + ShaWordCol::MajComp, + row, + 0, 
+ bit, + field_cfg, + ), + ) + }); + + Ok(VirtualChMajValues { + ch1: ch1?, + ch2: ch2?, + maj: maj?, + }) +} + +pub fn folded_row_integrand_values( + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + validate_trace(trace)?; + validate_public(public)?; + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + folded_row_integrand_values_with_row_weights( + trace, + public, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn folded_row_integrand_values_with_row_weights( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + folded_row_integrand_values_with_vectors( + trace, + public, + row_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn folded_row_integrand_values_with_vectors( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + validate_trace(trace)?; + validate_public(public)?; 
+ if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ShaProjectionError::MissingColumn { + kind: "lambda_powers", + col: lambda_powers.len(), + }); + } + if a_powers.len() < SHA_RESIDUAL_EVAL_POWER_COUNT { + return Err(ShaProjectionError::MissingColumn { + kind: "a_powers", + col: a_powers.len(), + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ShaProjectionError::ColumnRowCount { + kind: "booleanity_weights", + col: 0, + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + let needs_virtuals = sources_need_virtuals(booleanity_sources); + + let mut out = Vec::with_capacity(SHA_ROW_COUNT); + for row in 0..SHA_ROW_COUNT { + let linear = sha_linear_residual_row_value_with_powers( + trace, + public, + row, + &a_powers, + &lambda_powers, + field_cfg, + )?; + + let mut bool_terms = Vec::with_capacity(booleanity_sources.len()); + let virtuals = if needs_virtuals { + Some(reconstruct_virtual_ch_maj_at_row_unchecked( + trace, row, field_cfg, + )?) 
+ } else { + None + }; + for source in booleanity_sources { + let d = booleanity_source_value_at_row_with_virtuals( + trace, + row, + source, + virtuals.as_ref(), + field_cfg, + )?; + let term = d.clone() * (d - F::one_with_cfg(field_cfg)); + bool_terms.push(term); + } + let bool_sum = FieldFieldInnerProduct::inner_product::( + booleanity_weights, + &bool_terms, + F::zero_with_cfg(field_cfg), + ) + .map_err(ShaProjectionError::from)?; + out.push(row_weights[row].clone() * (linear + bool_sum)); + } + Ok(out) +} + +pub fn build_folded_row_sumcheck_group( + row_integrand_values: &[F], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + if row_integrand_values.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_integrand", + col: 0, + got: row_integrand_values.len(), + expected: SHA_ROW_COUNT, + }); + } + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + let integrand = DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + row_integrand_values + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + ); + Ok(MultiDegreeSumcheckGroup::new( + 1, + vec![integrand], + Box::new(|values: &[F]| values[0].clone()), + )) +} + +pub fn folded_row_integrand_sum( + row_integrand_values: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + if row_integrand_values.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_integrand", + col: 0, + got: row_integrand_values.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(row_integrand_values + .iter() + .fold(F::zero_with_cfg(field_cfg), |acc, value| acc + value)) +} + +/// Canonical booleanity sources for the production SHA ProjectionFold flow. +/// +/// This includes every committed binary-polynomial SHA source bit and the +/// three virtual Ch/Maj residual families. 
The virtual values are reconstructed +/// from source bit slices; they are never independent witness columns. +pub fn production_sha_booleanity_sources() -> Vec { + let mut sources = Vec::with_capacity(ShaWordCol::COUNT * SHA_WORD_BITS + 3 * SHA_WORD_BITS); + for col_idx in 0..ShaWordCol::COUNT { + let col = ShaWordCol::ALL[col_idx]; + for bit in 0..SHA_WORD_BITS { + sources.push(ShaBooleanitySource::WordBit { col, bit }); + } + } + for bit in 0..SHA_WORD_BITS { + sources.push(ShaBooleanitySource::VirtualCh1 { bit }); + sources.push(ShaBooleanitySource::VirtualCh2 { bit }); + sources.push(ShaBooleanitySource::VirtualMaj { bit }); + } + sources +} + +/// Evaluate the linear SHA residual scalarization at one row. +pub fn sha_linear_residual_row_value( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row: usize, + a: &F, + lambda: &F, + field_cfg: &F::Config, +) -> Result +where + F: DelayedFieldProductSum, +{ + let lambda_powers = powers( + lambda.clone(), + F::one_with_cfg(field_cfg), + NUM_SHA_RESIDUAL_FAMILIES, + ); + let a_powers = powers( + a.clone(), + F::one_with_cfg(field_cfg), + SHA_RESIDUAL_EVAL_POWER_COUNT, + ); + sha_linear_residual_row_value_with_powers( + trace, + public, + row, + &a_powers, + &lambda_powers, + field_cfg, + ) +} + +/// Evaluate the row-weighted linear SHA residual scalarization for one +/// instance. 
/// Dense table of field values indexed by a boolean prefix point and a tail
/// index.
///
/// The entry for prefix index `p` and tail index `t` is stored at
/// `values[t * 2^prefix_vars + p]`. The first (next-to-bind) prefix variable
/// is the least-significant bit of `p`, so the two values that differ only in
/// that variable sit next to each other.
#[derive(Clone, Debug)]
struct BinaryPrefixTailTable<F> {
    /// Flattened table contents; expected length is `2^prefix_vars * tail_len`.
    values: Vec<F>,
    /// Number of prefix variables that have not been bound yet.
    prefix_vars: usize,
    /// Number of tail slots; each slot owns one contiguous run of prefix entries.
    tail_len: usize,
}
Self { + debug_assert_eq!(values.len(), binary_len(prefix_vars) * tail_len); + Self { + values, + prefix_vars, + tail_len, + } + } + + #[allow(clippy::arithmetic_side_effects)] + fn bind_first_axis(&mut self, r: &F, field_cfg: &F::Config) { + debug_assert!(self.prefix_vars > 0); + let rest_len = binary_len(self.prefix_vars - 1); + let old_prefix_len = binary_len(self.prefix_vars); + let one = F::one_with_cfg(field_cfg); + let one_minus_r = one - r; + let mut next = vec![F::zero_with_cfg(field_cfg); rest_len * self.tail_len]; + + for tail in 0..self.tail_len { + let old_tail_offset = tail * old_prefix_len; + let new_tail_offset = tail * rest_len; + for rest in 0..rest_len { + let base = old_tail_offset + (rest << 1); + next[new_tail_offset + rest] = + self.values[base].clone() * &one_minus_r + self.values[base + 1].clone() * r; + } + } + + self.values = next; + self.prefix_vars -= 1; + } + + #[allow(clippy::arithmetic_side_effects)] + fn value_with_first_axis(&self, rest: usize, tail: usize, x: &F, field_cfg: &F::Config) -> F { + debug_assert!(self.prefix_vars > 0); + let prefix_len = binary_len(self.prefix_vars); + let base = tail * prefix_len + (rest << 1); + let one = F::one_with_cfg(field_cfg); + self.values[base].clone() * (one - x) + self.values[base + 1].clone() * x + } +} + +#[derive(Clone, Debug)] +struct TernaryPrefixTailTable { + values: Vec, + prefix_vars: usize, + tail_len: usize, +} + +impl TernaryPrefixTailTable +where + F: PrimeField, +{ + fn new( + values: Vec, + prefix_vars: usize, + tail_len: usize, + ) -> Result { + debug_assert_eq!(values.len(), checked_ternary_len(prefix_vars)? 
* tail_len); + Ok(Self { + values, + prefix_vars, + tail_len, + }) + } + + #[allow(clippy::arithmetic_side_effects)] + fn bind_first_axis(&mut self, r: &F, field_cfg: &F::Config) -> Result<(), ShaProjectionError> { + debug_assert!(self.prefix_vars > 0); + let rest_len = checked_ternary_len(self.prefix_vars - 1)?; + let old_prefix_len = checked_ternary_len(self.prefix_vars)?; + let one = F::one_with_cfg(field_cfg); + let one_minus_r = one - r; + let quadratic = r.clone() * (r.clone() - F::one_with_cfg(field_cfg)); + let mut next = vec![F::zero_with_cfg(field_cfg); rest_len * self.tail_len]; + + for tail in 0..self.tail_len { + let old_tail_offset = tail * old_prefix_len; + let new_tail_offset = tail * rest_len; + for rest in 0..rest_len { + let base = old_tail_offset + rest * 3; + next[new_tail_offset + rest] = self.values[base].clone() * &one_minus_r + + self.values[base + 1].clone() * r + + self.values[base + 2].clone() * &quadratic; + } + } + + self.values = next; + self.prefix_vars -= 1; + Ok(()) + } + + #[allow(clippy::arithmetic_side_effects)] + fn value_with_first_axis( + &self, + rest: usize, + tail: usize, + x: &F, + field_cfg: &F::Config, + ) -> Result { + debug_assert!(self.prefix_vars > 0); + let prefix_len = checked_ternary_len(self.prefix_vars)?; + let base = tail * prefix_len + rest * 3; + let one = F::one_with_cfg(field_cfg); + Ok(self.values[base].clone() * (one - x) + + self.values[base + 1].clone() * x + + self.values[base + 2].clone() + * (x.clone() * (x.clone() - F::one_with_cfg(field_cfg)))) + } +} + +struct RelationSumFoldPrefixFastPath { + tail_traces: Option]>>, + beta: Vec, + booleanity_sources: Vec, + linear: BinaryPrefixTailTable, + booleanity: TernaryPrefixTailTable, + tail_eq_weights: Vec, + prefix_suffix_eq_weights: Vec>, + total_prefix_vars: usize, + round: usize, + eq_bound: F, +} + +#[derive(Clone, Debug)] +struct TernaryCoeffPlan { + support_mask: usize, + finite_bits: usize, + vertices: Vec<(usize, bool)>, +} + +impl 
RelationSumFoldPrefixFastPath +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + #[allow(clippy::too_many_arguments)] + fn new( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, + ) -> Result { + let coeff_tables = build_linear_residual_coeff_tables(traces, publics, r_ic, field_cfg)?; + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + Self::new_with_linear_cache( + traces, + publics, + beta, + r_ic, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + &coeff_tables, + field_cfg, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_owned( + traces: Box<[ProjectedTrace]>, + publics: &[ProjectedPublic], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, + ) -> Result { + let coeff_tables = build_linear_residual_coeff_tables(&traces, publics, r_ic, field_cfg)?; + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + Self::new_owned_with_linear_cache( + traces, + publics, + beta, + r_ic, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + &coeff_tables, + field_cfg, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_with_linear_cache( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], + _r_ic: &[F; SHA_ROW_VARS], + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + coeff_tables: &[LinearResidualCoeffTable], + field_cfg: &F::Config, + ) -> Result { + let ell = validate_sha_sumfold_inputs(traces, publics, beta)?; + if prefix_vars == 0 || prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + 
ell, + } + .into()); + } + if coeff_tables.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: coeff_tables.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + + let tail_vars = ell - prefix_vars; + let tail_len = binary_len(tail_vars); + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = + build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let linear_values = build_sha_sumfold_linear_accumulator( + coeff_tables, + &a_powers, + &lambda_powers, + field_cfg, + )?; + let quadratic_values = build_sha_booleanity_prefix_tail_table( + traces, + booleanity_sources, + prefix_vars, + tail_len, + row_weights, + &booleanity_weights, + field_cfg, + ); + Self::new_with_accumulators( + traces, + beta, + &linear_values, + &quadratic_values?, + booleanity_sources, + prefix_vars, + field_cfg, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_owned_with_linear_cache( + traces: Box<[ProjectedTrace]>, + publics: &[ProjectedPublic], + beta: &[F], + _r_ic: &[F; SHA_ROW_VARS], + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + coeff_tables: &[LinearResidualCoeffTable], + field_cfg: &F::Config, + ) -> Result { + let ell = validate_sha_sumfold_inputs(&traces, publics, beta)?; + if prefix_vars == 0 || prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if coeff_tables.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: coeff_tables.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + 
kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + + let tail_vars = ell - prefix_vars; + let tail_len = binary_len(tail_vars); + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = + build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let linear_values = build_sha_sumfold_linear_accumulator( + coeff_tables, + &a_powers, + &lambda_powers, + field_cfg, + )?; + let quadratic_values = build_sha_booleanity_prefix_tail_table( + &traces, + booleanity_sources, + prefix_vars, + tail_len, + row_weights, + &booleanity_weights, + field_cfg, + ); + Self::new_owned_with_accumulators( + traces, + beta, + &linear_values, + &quadratic_values?, + booleanity_sources, + prefix_vars, + field_cfg, + ) + } + + fn new_with_accumulators( + traces: &[ProjectedTrace], + beta: &[F], + linear_values: &[F], + quadratic_values: &[F], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, + ) -> Result { + let ell = validate_sha_sumfold_traces(traces, beta)?; + if prefix_vars == 0 || prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + + let tail_vars = ell - prefix_vars; + let tail_len = binary_len(tail_vars); + let linear_len = binary_len(prefix_vars) * tail_len; + if linear_values.len() != linear_len { + return Err(ShaProjectionError::InstanceCountMismatch { + got: linear_values.len(), + expected: linear_len, + }); + } + let quadratic_len = checked_ternary_len(prefix_vars)? 
* tail_len; + if quadratic_values.len() != quadratic_len { + return Err(ShaProjectionError::InstanceCountMismatch { + got: quadratic_values.len(), + expected: quadratic_len, + }); + } + + let linear = BinaryPrefixTailTable::new(linear_values.to_vec(), prefix_vars, tail_len); + let booleanity = + TernaryPrefixTailTable::new(quadratic_values.to_vec(), prefix_vars, tail_len)?; + + let tail_eq_weights = eq_weights_or_one(&beta[prefix_vars..], field_cfg)?; + let mut prefix_suffix_eq_weights = Vec::with_capacity(prefix_vars); + for round in 0..prefix_vars { + prefix_suffix_eq_weights + .push(eq_weights_or_one(&beta[round + 1..prefix_vars], field_cfg)?); + } + + let tail_traces = if tail_vars == 0 { + None + } else { + Some(traces.to_vec().into_boxed_slice()) + }; + + Ok(Self { + tail_traces, + beta: beta.to_vec(), + booleanity_sources: booleanity_sources.to_vec(), + linear, + booleanity, + tail_eq_weights, + prefix_suffix_eq_weights, + total_prefix_vars: prefix_vars, + round: 0, + eq_bound: F::one_with_cfg(field_cfg), + }) + } + + fn new_owned_with_accumulators( + traces: Box<[ProjectedTrace]>, + beta: &[F], + linear_values: &[F], + quadratic_values: &[F], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, + ) -> Result { + let mut fast_path = Self::new_with_accumulators( + &traces, + beta, + linear_values, + quadratic_values, + booleanity_sources, + prefix_vars, + field_cfg, + )?; + if fast_path.beta.len() > fast_path.total_prefix_vars { + fast_path.tail_traces = Some(traces); + } + Ok(fast_path) + } + + #[allow(clippy::arithmetic_side_effects)] + fn bind_previous_round( + &mut self, + r: &F, + field_cfg: &F::Config, + ) -> Result<(), ShaProjectionError> { + let beta_idx = self.round - 1; + self.eq_bound *= eq_one_var(&self.beta[beta_idx], r, field_cfg); + self.linear.bind_first_axis(r, field_cfg); + self.booleanity.bind_first_axis(r, field_cfg) + } + + fn round_value_at(&self, x: &F, field_cfg: &F::Config) -> Result { + 
debug_assert!(self.round < self.total_prefix_vars); + let suffix_weights = &self.prefix_suffix_eq_weights[self.round]; + let rest_len = suffix_weights.len(); + let mut acc = F::zero_with_cfg(field_cfg); + + for tail in 0..self.tail_eq_weights.len() { + for (rest, suffix_weight) in suffix_weights.iter().enumerate().take(rest_len) { + let linear = self.linear.value_with_first_axis(rest, tail, x, field_cfg); + let ternary_rest = binary_bits_to_ternary_index(rest, self.linear.prefix_vars - 1); + let booleanity = + self.booleanity + .value_with_first_axis(ternary_rest, tail, x, field_cfg)?; + acc += self.tail_eq_weights[tail].clone() * suffix_weight * (linear + booleanity); + } + } + + Ok(self.eq_bound.clone() * eq_one_var(&self.beta[self.round], x, field_cfg) * acc) + } + + #[allow(clippy::arithmetic_side_effects)] + fn finish_tail_mles( + mut self, + prefix_challenges: &[F], + field_cfg: &F::Config, + ) -> Result>, ShaProjectionError> { + debug_assert_eq!(prefix_challenges.len(), self.total_prefix_vars); + let tail_vars = self.beta.len() - self.total_prefix_vars; + if tail_vars == 0 { + return Ok(Vec::new()); + } + + while self.linear.prefix_vars > 0 { + let next_axis = self.total_prefix_vars - self.linear.prefix_vars; + let r = &prefix_challenges[next_axis]; + self.linear.bind_first_axis(r, field_cfg); + } + + let tail_len = binary_len(tail_vars); + debug_assert_eq!(self.linear.values.len(), tail_len); + + let prefix_weights = eq_weights_or_one(prefix_challenges, field_cfg)?; + let eq_prefix_at_r = eq_eval( + prefix_challenges, + &self.beta[..self.total_prefix_vars], + F::one_with_cfg(field_cfg), + )?; + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + + let mut mles = Vec::with_capacity(2 + self.booleanity_sources.len() * SHA_ROW_COUNT); + mles.push(DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + self.tail_eq_weights + .iter() + .map(|tail_weight| (eq_prefix_at_r.clone() * tail_weight).inner().clone()) + .collect(), + 
zero_inner.clone(), + )); + mles.push(DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + self.linear + .values + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); + + let traces = self + .tail_traces + .as_deref() + .expect("tail traces must be present when tail variables remain"); + let source_tail_values = bind_sha_booleanity_sources_to_prefix( + traces, + &self.booleanity_sources, + self.total_prefix_vars, + tail_len, + &prefix_weights, + field_cfg, + )?; + + for values in source_tail_values { + mles.push(DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + values.iter().map(|value| value.inner().clone()).collect(), + zero_inner.clone(), + )); + } + + Ok(mles) + } +} + +impl PrefixFastPath for RelationSumFoldPrefixFastPath +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + fn prefix_len(&self) -> usize { + self.total_prefix_vars + } + + fn prove_prefix_round( + &mut self, + verifier_msg: &Option, + config: &F::Config, + ) -> PrefixRoundOutput { + if let Some(r) = verifier_msg { + self.bind_previous_round(r, config) + .expect("validated SHA prefix table should bind"); + } + + let zero = F::zero_with_cfg(config); + let one = F::one_with_cfg(config); + let two = one.clone() + &one; + let three = two.clone() + &one; + + let p0 = self + .round_value_at(&zero, config) + .expect("validated SHA prefix table should evaluate at 0"); + let p1 = self + .round_value_at(&one, config) + .expect("validated SHA prefix table should evaluate at 1"); + let p2 = self + .round_value_at(&two, config) + .expect("validated SHA prefix table should evaluate at 2"); + let p3 = self + .round_value_at(&three, config) + .expect("validated SHA prefix table should evaluate at 3"); + + let asserted_sum = if self.round == 0 { + Some(p0 + &p1) + } else { + None + }; + self.round += 1; + + PrefixRoundOutput { + asserted_sum, + tail_evaluations: vec![p1, p2, p3], + } + } + + fn 
finish_prefix( + self: Box, + prefix_challenges: &[F], + config: &F::Config, + ) -> Vec> { + self.finish_tail_mles(prefix_challenges, config) + .expect("validated SHA prefix fast path should finish") + } +} + +/// Build the production SHA SumFold group over the instance axis. +/// +/// The group proves the expression +/// +/// `eq(beta, b) * (L(b) + xi * B(b))` +/// +/// where `L` is the row-weighted linear SHA residual scalarization and `B` +/// is built from source booleanity MLEs. Unlike a table of fresh targets, the +/// booleanity part is evaluated from source MLEs, so the terminal at `r_b` +/// is the folded booleanity expression. +#[allow(clippy::too_many_arguments)] +pub fn build_dense_sha_sumfold_group( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + let beta_eq_weights = build_eq_x_r_vec(beta, field_cfg)?; + build_dense_sha_sumfold_group_with_weights( + traces, + publics, + beta, + &beta_eq_weights, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_dense_sha_sumfold_group_with_weights( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], + beta_eq_weights: &[F], + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let _ell = validate_sha_sumfold_inputs(traces, publics, beta)?; + if beta_eq_weights.len() != traces.len() { + return 
Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let linear_values = traces + .iter() + .zip(publics.iter()) + .map(|(trace, public)| { + sha_linear_residual_sum_with_weights( + trace, + public, + &row_weights, + &a_powers, + &lambda_powers, + field_cfg, + ) + }) + .collect::, _>>()?; + build_dense_sha_sumfold_group_from_accumulators( + traces, + beta, + beta_eq_weights, + row_weights, + &linear_values, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +fn build_dense_sha_sumfold_group_from_accumulators( + traces: &[ProjectedTrace], + beta: &[F], + beta_eq_weights: &[F], + row_weights: &[F], + linear_accumulator: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let ell = validate_sha_sumfold_traces(traces, beta)?; + if beta_eq_weights.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if linear_accumulator.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: linear_accumulator.len(), + expected: traces.len(), + }); + } + if 
booleanity_weights.len() != booleanity_sources.len() { + return Err(ShaProjectionError::ColumnRowCount { + kind: "booleanity_weights", + col: 0, + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + let mut mles = Vec::with_capacity(2 + booleanity_sources.len() * SHA_ROW_COUNT); + + mles.push(DenseMultilinearExtension::from_evaluations_vec( + ell, + beta_eq_weights + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); + mles.push(DenseMultilinearExtension::from_evaluations_vec( + ell, + linear_accumulator + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); + + for source in booleanity_sources { + for row in 0..SHA_ROW_COUNT { + let values = traces + .iter() + .map(|trace| booleanity_source_value_at_row(trace, row, source, field_cfg)) + .collect::, _>>()?; + mles.push(DenseMultilinearExtension::from_evaluations_vec( + ell, + values.iter().map(|value| value.inner().clone()).collect(), + zero_inner.clone(), + )); + } + } + + Ok(MultiDegreeSumcheckGroup::new( + 3, + mles, + sha_weighted_sumfold_comb_fn(row_weights.to_vec(), booleanity_weights.to_vec(), field_cfg), + )) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_production_sha_sumfold_group_from_prefix_accumulators( + traces: &[ProjectedTrace], + beta: &[F], + beta_eq_weights: &[F], + row_weights: &[F], + linear_accumulator: &[F], + quadratic_prefix_accumulator: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let ell = validate_sha_sumfold_traces(traces, beta)?; + if beta_eq_weights.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: traces.len(), + }); + 
} + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if linear_accumulator.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: linear_accumulator.len(), + expected: traces.len(), + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ShaProjectionError::ColumnRowCount { + kind: "booleanity_weights", + col: 0, + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if prefix_vars == 0 { + if !quadratic_prefix_accumulator.is_empty() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: quadratic_prefix_accumulator.len(), + expected: 0, + }); + } + return build_dense_sha_sumfold_group_from_accumulators( + traces, + beta, + beta_eq_weights, + row_weights, + linear_accumulator, + booleanity_weights, + booleanity_sources, + field_cfg, + ); + } + + let fast_path = RelationSumFoldPrefixFastPath::new_with_accumulators( + traces, + beta, + linear_accumulator, + quadratic_prefix_accumulator, + booleanity_sources, + prefix_vars, + field_cfg, + ); + + Ok(MultiDegreeSumcheckGroup::with_prefix_fast( + 3, + Vec::new(), + sha_weighted_sumfold_comb_fn(row_weights.to_vec(), booleanity_weights.to_vec(), field_cfg), + Box::new(fast_path?), + )) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_production_sha_sumfold_group( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let ell = 
validate_sha_sumfold_inputs(traces, publics, beta)?; + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if prefix_vars == 0 { + return build_dense_sha_sumfold_group( + traces, + publics, + beta, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ); + } + + let fast_path = RelationSumFoldPrefixFastPath::new( + traces, + publics, + beta, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + field_cfg, + ); + + Ok(MultiDegreeSumcheckGroup::with_prefix_fast( + 3, + Vec::new(), + sha_sumfold_comb_fn( + build_eq_x_r_vec(r_ic, field_cfg)?, + powers( + rho.clone(), + F::one_with_cfg(field_cfg), + booleanity_sources.len(), + ), + xi.clone(), + field_cfg, + ), + Box::new(fast_path?), + )) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_production_sha_sumfold_group_with_linear_cache( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + linear_cache: &[LinearResidualCoeffTable], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + let beta_eq_weights = build_eq_x_r_vec(beta, field_cfg)?; + build_production_sha_sumfold_group_with_linear_cache_and_weights( + traces, + publics, + linear_cache, + beta, + &beta_eq_weights, + r_ic, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_production_sha_sumfold_group_with_linear_cache_and_weights( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + linear_cache: &[LinearResidualCoeffTable], + beta: &[F], + beta_eq_weights: &[F], + r_ic: &[F; SHA_ROW_VARS], + row_weights: &[F], + a: &F, 
+ lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let ell = validate_sha_sumfold_inputs(traces, publics, beta)?; + if linear_cache.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: linear_cache.len(), + expected: traces.len(), + }); + } + if beta_eq_weights.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if prefix_vars == 0 { + return build_dense_sha_sumfold_group_with_linear_cache_and_weights( + traces, + linear_cache, + beta, + beta_eq_weights, + row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ); + } + + let fast_path = RelationSumFoldPrefixFastPath::new_with_linear_cache( + traces, + publics, + beta, + r_ic, + row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + linear_cache, + field_cfg, + ); + + Ok(MultiDegreeSumcheckGroup::with_prefix_fast( + 3, + Vec::new(), + sha_sumfold_comb_fn( + row_weights.to_vec(), + powers( + rho.clone(), + F::one_with_cfg(field_cfg), + booleanity_sources.len(), + ), + xi.clone(), + field_cfg, + ), + Box::new(fast_path?), + )) +} + +#[allow(clippy::too_many_arguments)] +fn build_dense_sha_sumfold_group_with_linear_cache_and_weights( + traces: &[ProjectedTrace], + linear_cache: &[LinearResidualCoeffTable], + beta: &[F], + beta_eq_weights: &[F], + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: 
&[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + if beta_eq_weights.len() != traces.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta_eq_weights.len(), + expected: traces.len(), + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let linear_values = + build_sha_sumfold_linear_accumulator(linear_cache, &a_powers, &lambda_powers, field_cfg)?; + build_dense_sha_sumfold_group_from_accumulators( + traces, + beta, + beta_eq_weights, + row_weights, + &linear_values, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_production_sha_sumfold_group_owned( + traces: Box<[ProjectedTrace]>, + publics: &[ProjectedPublic], + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let ell = validate_sha_sumfold_inputs(&traces, publics, beta)?; + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if prefix_vars == 0 { + return build_dense_sha_sumfold_group( + &traces, + publics, + beta, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ); + } + + let fast_path = RelationSumFoldPrefixFastPath::new_owned( + traces, + publics, + beta, + r_ic, 
+ a, + lambda, + rho, + xi, + booleanity_sources, + prefix_vars, + field_cfg, + ); + + Ok(MultiDegreeSumcheckGroup::with_prefix_fast( + 3, + Vec::new(), + sha_sumfold_comb_fn( + build_eq_x_r_vec(r_ic, field_cfg)?, + powers( + rho.clone(), + F::one_with_cfg(field_cfg), + booleanity_sources.len(), + ), + xi.clone(), + field_cfg, + ), + Box::new(fast_path?), + )) +} + +fn sha_sumfold_comb_fn( + row_weights: Vec, + rho_powers: Vec, + xi: F, + field_cfg: &F::Config, +) -> CombFn +where + F: PrimeField + Send + Sync + 'static, +{ + let booleanity_weights = rho_powers + .into_iter() + .map(|rho_power| xi.clone() * rho_power) + .collect(); + sha_weighted_sumfold_comb_fn(row_weights, booleanity_weights, field_cfg) +} + +fn sha_weighted_sumfold_comb_fn( + row_weights: Vec, + booleanity_weights: Vec, + field_cfg: &F::Config, +) -> CombFn +where + F: PrimeField + Send + Sync + 'static, +{ + let zero_for_comb = F::zero_with_cfg(field_cfg); + let one_for_comb = F::one_with_cfg(field_cfg); + Box::new(move |values: &[F]| { + let eq_beta = values[0].clone(); + let linear = values[1].clone(); + let mut bool_sum = zero_for_comb.clone(); + let mut cursor = 2usize; + for booleanity_weight in &booleanity_weights { + for row_weight in &row_weights { + let d = values[cursor].clone(); + cursor += 1; + let term = d.clone() * (d - one_for_comb.clone()); + bool_sum += row_weight.clone() * booleanity_weight * term; + } + } + eq_beta * (linear + bool_sum) + }) +} + +fn validate_sha_sumfold_inputs( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + beta: &[F], +) -> Result { + if traces.is_empty() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: 0 }); + } + if traces.len() != publics.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: publics.len(), + expected: traces.len(), + }); + } + if !traces.len().is_power_of_two() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: traces.len() }); + } + let ell = 
usize::try_from(traces.len().trailing_zeros()).expect("trailing_zeros fits usize"); + if beta.len() != ell { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta.len(), + expected: ell, + }); + } + for trace in traces { + validate_trace(trace)?; + } + for public in publics { + validate_public(public)?; + } + Ok(ell) +} + +fn binary_len(vars: usize) -> usize { + 1usize + .checked_shl(u32::try_from(vars).expect("vars fits u32")) + .expect("binary domain size fits usize") +} + +fn checked_ternary_len(vars: usize) -> Result { + let mut size = 1usize; + for _ in 0..vars { + size = size + .checked_mul(3) + .ok_or(SumFoldError::DomainTooLarge { ell: vars })?; + } + Ok(size) +} + +fn eq_weights_or_one(point: &[F], field_cfg: &F::Config) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if point.is_empty() { + Ok(vec![F::one_with_cfg(field_cfg)]) + } else { + Ok(build_eq_x_r_vec(point, field_cfg)?) + } +} + +fn eq_one_var(beta: &F, x: &F, field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let one = F::one_with_cfg(field_cfg); + x.clone() * beta + (one.clone() - x) * (one - beta) +} + +fn validate_sha_sumfold_traces( + traces: &[ProjectedTrace], + beta: &[F], +) -> Result { + if traces.is_empty() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: 0 }); + } + if !traces.len().is_power_of_two() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: traces.len() }); + } + let ell = usize::try_from(traces.len().trailing_zeros()).expect("trailing_zeros fits usize"); + if beta.len() != ell { + return Err(ShaProjectionError::InstanceCountMismatch { + got: beta.len(), + expected: ell, + }); + } + #[cfg(debug_assertions)] + { + for trace in traces { + validate_trace(trace)?; + } + } + Ok(ell) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_sha_sumfold_quadratic_prefix_accumulator( + traces: &[ProjectedTrace], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + row_weights: &[F], + 
booleanity_weights: &[F], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if traces.is_empty() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: 0 }); + } + if !traces.len().is_power_of_two() { + return Err(ShaProjectionError::InstanceCountNotPowerOfTwo { got: traces.len() }); + } + let ell = usize::try_from(traces.len().trailing_zeros()).expect("trailing_zeros fits usize"); + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + } + .into()); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ShaProjectionError::ColumnRowCount { + kind: "booleanity_weights", + col: 0, + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + #[cfg(debug_assertions)] + { + for trace in traces { + validate_trace(trace)?; + } + } + if prefix_vars == 0 { + return Ok(Vec::new()); + } + + let tail_len = binary_len(ell - prefix_vars); + build_sha_booleanity_prefix_tail_table( + traces, + booleanity_sources, + prefix_vars, + tail_len, + row_weights, + booleanity_weights, + field_cfg, + ) +} + +#[allow(clippy::arithmetic_side_effects)] +fn build_sha_booleanity_prefix_tail_table( + traces: &[ProjectedTrace], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + tail_len: usize, + row_weights: &[F], + booleanity_weights: &[F], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + let prefix_len = binary_len(prefix_vars); + let ternary_len = checked_ternary_len(prefix_vars)?; + let mut table = vec![F::zero_with_cfg(field_cfg); ternary_len * tail_len]; + if booleanity_sources.is_empty() { + return Ok(table); + } + + let coeff_plans = ternary_coeff_plans(prefix_vars)?; + let word_bit_source_count = 
booleanity_sources + .iter() + .take_while(|source| matches!(source, ShaBooleanitySource::WordBit { .. })) + .count(); + let suffix_sources = &booleanity_sources[word_bit_source_count..]; + let suffix_count = suffix_sources.len(); + let suffix_needs_virtuals = sources_need_virtuals(suffix_sources); + let small_square_fields: Vec = small_square_field_table(field_cfg); + let mask_count = 1usize << prefix_len; + let mut mask_coeff_table = Vec::with_capacity(mask_count * ternary_len); + for mask in 0..mask_count { + let source_mask = u8::try_from(mask).map_err(|_| { + ShaProjectionError::NonCanonicalProofObject( + "booleanity prefix masks require at most eight prefix entries", + ) + })?; + for plan in &coeff_plans { + mask_coeff_table.push(booleanity_word_bit_mask_degree_two_coeff( + source_mask, + plan, + &small_square_fields, + field_cfg, + )); + } + } + let one = F::one_with_cfg(field_cfg); + let partials = cfg_chunks!(row_weights, 8) + .enumerate() + .map(|(chunk_idx, row_weight_chunk)| { + let row_offset = chunk_idx * 8; + let mut partial = vec![F::zero_with_cfg(field_cfg); ternary_len * tail_len]; + let mut suffix_values = vec![F::zero_with_cfg(field_cfg); prefix_len * suffix_count]; + let mut mask_weights = vec![F::zero_with_cfg(field_cfg); mask_count]; + let mut touched_masks = Vec::new(); + for tail in 0..tail_len { + for (row_in_chunk, row_weight) in row_weight_chunk.iter().enumerate() { + let row = row_offset + row_in_chunk; + for (source_idx, source) in booleanity_sources[..word_bit_source_count] + .iter() + .enumerate() + { + let ShaBooleanitySource::WordBit { col, bit } = source else { + unreachable!("word-bit prefix only contains word-bit sources"); + }; + let mask = booleanity_word_bit_prefix_mask( + traces, + *col, + *bit, + prefix_vars, + tail, + row, + field_cfg, + ); + let mask_idx = usize::from(mask); + if F::is_zero(&mask_weights[mask_idx]) { + touched_masks.push(mask_idx); + } + mask_weights[mask_idx] += booleanity_weights[source_idx].clone(); + 
} + + for &mask_idx in &touched_masks { + let source_weight = row_weight.clone() * &mask_weights[mask_idx]; + let coeff_offset = mask_idx * ternary_len; + for ternary_idx in 0..ternary_len { + let coeff = &mask_coeff_table[coeff_offset + ternary_idx]; + if F::is_zero(&coeff) { + continue; + } + partial[tail * ternary_len + ternary_idx] += + source_weight.clone() * coeff; + } + mask_weights[mask_idx] = F::zero_with_cfg(field_cfg); + } + touched_masks.clear(); + + if suffix_count != 0 { + fill_booleanity_source_prefix_values( + traces, + suffix_sources, + prefix_vars, + tail, + row, + suffix_needs_virtuals, + &mut suffix_values, + field_cfg, + )?; + } + + let mut generic_suffixes = Vec::new(); + for suffix_idx in 0..suffix_count { + let source_idx = word_bit_source_count + suffix_idx; + let booleanity_weight = &booleanity_weights[source_idx]; + if let Some(mask_idx) = booleanity_prefix_values_binary_mask( + &suffix_values, + suffix_count, + suffix_idx, + &one, + ) { + if F::is_zero(&mask_weights[mask_idx]) { + touched_masks.push(mask_idx); + } + mask_weights[mask_idx] += booleanity_weight.clone(); + } else { + generic_suffixes.push(suffix_idx); + } + } + + for &mask_idx in &touched_masks { + let source_weight = row_weight.clone() * &mask_weights[mask_idx]; + let coeff_offset = mask_idx * ternary_len; + for ternary_idx in 0..ternary_len { + let coeff = &mask_coeff_table[coeff_offset + ternary_idx]; + if F::is_zero(&coeff) { + continue; + } + partial[tail * ternary_len + ternary_idx] += + source_weight.clone() * coeff; + } + mask_weights[mask_idx] = F::zero_with_cfg(field_cfg); + } + touched_masks.clear(); + + for suffix_idx in generic_suffixes { + let source_idx = word_bit_source_count + suffix_idx; + let booleanity_weight = &booleanity_weights[source_idx]; + let source_weight = row_weight.clone() * booleanity_weight; + for (ternary_idx, plan) in coeff_plans.iter().enumerate() { + let coeff = booleanity_degree_two_coeff( + &suffix_values, + suffix_count, + suffix_idx, 
+ plan, + field_cfg, + ); + if F::is_zero(&coeff) { + continue; + } + partial[tail * ternary_len + ternary_idx] += + source_weight.clone() * coeff; + } + } + } + } + Ok(partial) + }) + .collect::, ShaProjectionError>>()?; + for partial in partials { + for (acc, value) in table.iter_mut().zip(partial) { + *acc += value; + } + } + Ok(table) +} + +#[allow(clippy::arithmetic_side_effects)] +fn booleanity_word_bit_prefix_mask( + traces: &[ProjectedTrace], + col: ShaWordCol, + bit: usize, + prefix_vars: usize, + tail: usize, + row: usize, + field_cfg: &F::Config, +) -> u8 +where + F: PrimeField, +{ + let prefix_len = binary_len(prefix_vars); + let mut value_mask = 0u8; + for prefix in 0..prefix_len { + let instance_idx = prefix + (tail << prefix_vars); + let trace = &traces[instance_idx]; + if !F::is_zero(&bit_at_shifted_or_zero_fast( + trace, col, row, 0, bit, field_cfg, + )) { + value_mask |= 1u8 << prefix; + } + } + value_mask +} + +#[allow(clippy::arithmetic_side_effects)] +fn bind_sha_booleanity_sources_to_prefix( + traces: &[ProjectedTrace], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + tail_len: usize, + prefix_weights: &[F], + field_cfg: &F::Config, +) -> Result>, ShaProjectionError> +where + F: DelayedFieldProductSum, +{ + let prefix_len = binary_len(prefix_vars); + let source_count = booleanity_sources.len(); + let needs_virtuals = sources_need_virtuals(booleanity_sources); + let mut source_values = vec![F::zero_with_cfg(field_cfg); prefix_len * source_count]; + let mut out = vec![vec![F::zero_with_cfg(field_cfg); tail_len]; source_count * SHA_ROW_COUNT]; + let mut source_column_values = Vec::with_capacity(prefix_len); + + for tail in 0..tail_len { + for row in 0..SHA_ROW_COUNT { + fill_booleanity_source_prefix_values( + traces, + booleanity_sources, + prefix_vars, + tail, + row, + needs_virtuals, + &mut source_values, + field_cfg, + )?; + for source_idx in 0..source_count { + source_column_values.clear(); + for prefix in 0..prefix_len { 
+ source_column_values + .push(source_values[prefix * source_count + source_idx].clone()); + } + let acc = FieldFieldInnerProduct::inner_product::( + prefix_weights, + &source_column_values, + F::zero_with_cfg(field_cfg), + ) + .map_err(ShaProjectionError::from)?; + out[source_idx * SHA_ROW_COUNT + row][tail] = acc; + } + } + } + + Ok(out) +} + +#[allow(clippy::arithmetic_side_effects)] +fn fill_booleanity_source_prefix_values( + traces: &[ProjectedTrace], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + tail: usize, + row: usize, + needs_virtuals: bool, + out: &mut [F], + field_cfg: &F::Config, +) -> Result<(), ShaProjectionError> +where + F: PrimeField, +{ + let prefix_len = binary_len(prefix_vars); + let source_count = booleanity_sources.len(); + debug_assert_eq!(out.len(), prefix_len * source_count); + + for prefix in 0..prefix_len { + let instance_idx = prefix + (tail << prefix_vars); + let trace = &traces[instance_idx]; + let virtuals = if needs_virtuals { + Some(reconstruct_virtual_ch_maj_at_row_unchecked( + trace, row, field_cfg, + )?) + } else { + None + }; + + for (source_idx, source) in booleanity_sources.iter().enumerate() { + let value = match source { + ShaBooleanitySource::WordBit { col, bit } => { + bit_at_shifted_or_zero(trace, *col, row, 0, *bit, field_cfg)? + } + ShaBooleanitySource::VirtualCh1 { bit } => { + virtual_bit_at(&virtuals.as_ref().expect("virtuals computed").ch1, *bit)? + } + ShaBooleanitySource::VirtualCh2 { bit } => { + virtual_bit_at(&virtuals.as_ref().expect("virtuals computed").ch2, *bit)? + } + ShaBooleanitySource::VirtualMaj { bit } => { + virtual_bit_at(&virtuals.as_ref().expect("virtuals computed").maj, *bit)? 
+ } + }; + out[prefix * source_count + source_idx] = value; + } + } + + Ok(()) +} + +#[allow(clippy::arithmetic_side_effects)] +fn booleanity_degree_two_coeff( + source_values: &[F], + source_count: usize, + source_idx: usize, + plan: &TernaryCoeffPlan, + field_cfg: &F::Config, +) -> F +where + F: PrimeField, +{ + let value_at = + |prefix: usize| -> F { source_values[prefix * source_count + source_idx].clone() }; + if plan.support_mask == 0 { + let d = value_at(plan.finite_bits); + return d.clone() * (d - F::one_with_cfg(field_cfg)); + } + + let mut delta = F::zero_with_cfg(field_cfg); + for (prefix, positive) in &plan.vertices { + if *positive { + delta += value_at(*prefix); + } else { + delta -= value_at(*prefix); + } + } + + delta.clone() * delta +} + +fn booleanity_prefix_values_binary_mask( + source_values: &[F], + source_count: usize, + source_idx: usize, + one: &F, +) -> Option +where + F: PrimeField, +{ + if source_count == 0 { + return Some(0); + } + let prefix_len = source_values.len() / source_count; + if prefix_len > usize::BITS as usize { + return None; + } + let mut mask = 0usize; + for prefix in 0..prefix_len { + let value = &source_values[prefix * source_count + source_idx]; + if F::is_zero(value) { + continue; + } + if value == one { + mask |= 1usize << prefix; + } else { + return None; + } + } + Some(mask) +} + +fn booleanity_word_bit_mask_degree_two_coeff( + source_mask: u8, + plan: &TernaryCoeffPlan, + small_square_fields: &[F], + field_cfg: &F::Config, +) -> F +where + F: PrimeField, +{ + if plan.support_mask == 0 { + return F::zero_with_cfg(field_cfg); + } + let mut delta = 0i32; + for (prefix, positive) in &plan.vertices { + if ((source_mask >> prefix) & 1) == 0 { + continue; + } + if *positive { + delta += 1; + } else { + delta -= 1; + } + } + let square = usize::try_from(delta * delta).expect("square is non-negative"); + small_square_fields + .get(square) + .cloned() + .unwrap_or_else(|| small_usize_to_field(square, field_cfg)) +} + +fn 
small_square_field_table(field_cfg: &F::Config) -> Vec +where + F: PrimeField, +{ + (0..=64) + .map(|value| small_usize_to_field(value, field_cfg)) + .collect() +} + +fn small_usize_to_field(value: usize, field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let one = F::one_with_cfg(field_cfg); + let mut out = F::zero_with_cfg(field_cfg); + for _ in 0..value { + out += &one; + } + out +} + +#[allow(clippy::arithmetic_side_effects)] +fn ternary_point_parts(mut index: usize, prefix_vars: usize) -> (usize, usize) { + let mut support_mask = 0usize; + let mut finite_bits = 0usize; + for var in 0..prefix_vars { + let digit = index % 3; + index /= 3; + match digit { + 0 => {} + 1 => finite_bits |= 1usize << var, + 2 => support_mask |= 1usize << var, + _ => unreachable!("ternary digit is always 0, 1, or 2"), + } + } + (support_mask, finite_bits) +} + +fn sources_need_virtuals(booleanity_sources: &[ShaBooleanitySource]) -> bool { + booleanity_sources.iter().any(|source| { + matches!( + source, + ShaBooleanitySource::VirtualCh1 { .. } + | ShaBooleanitySource::VirtualCh2 { .. } + | ShaBooleanitySource::VirtualMaj { .. 
} + ) + }) +} + +#[allow(clippy::arithmetic_side_effects)] +fn ternary_coeff_plans(prefix_vars: usize) -> Result, ShaProjectionError> { + let ternary_len = checked_ternary_len(prefix_vars)?; + let mut plans = Vec::with_capacity(ternary_len); + for ternary_idx in 0..ternary_len { + let (support_mask, finite_bits) = ternary_point_parts(ternary_idx, prefix_vars); + let mut vertices = Vec::new(); + if support_mask != 0 { + let mut support_bits = [0usize; usize::BITS as usize]; + let mut mask = support_mask; + let mut support_size = 0usize; + while mask != 0 { + let bit = mask & mask.wrapping_neg(); + support_bits[support_size] = bit; + support_size += 1; + mask ^= bit; + } + vertices.reserve(1usize << support_size); + for vertex in 0..(1usize << support_size) { + let mut prefix = finite_bits; + for (pos, bit) in support_bits[..support_size].iter().enumerate() { + if ((vertex >> pos) & 1) == 1 { + prefix |= *bit; + } + } + let positive = (support_size + - usize::try_from(vertex.count_ones()).expect("count_ones fits usize")) + % 2 + == 0; + vertices.push((prefix, positive)); + } + } + plans.push(TernaryCoeffPlan { + support_mask, + finite_bits, + vertices, + }); + } + Ok(plans) +} + +#[allow(clippy::arithmetic_side_effects)] +fn binary_bits_to_ternary_index(mut bits: usize, vars: usize) -> usize { + let mut index = 0usize; + let mut scale = 1usize; + for _ in 0..vars { + if bits & 1 == 1 { + index += scale; + } + bits >>= 1; + scale *= 3; + } + index +} + +/// Build the expression-backed folded row sumcheck group. +/// +/// The terminal at the verifier challenge is tied to source MLE endpoint +/// values, including booleanity sources, rather than to an opaque MLE of +/// precomputed row-integrand values. 
+#[allow(clippy::too_many_arguments)] +pub fn build_expression_folded_row_sumcheck_group( + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + build_expression_folded_row_sumcheck_group_with_row_weights( + trace, + public, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn build_expression_folded_row_sumcheck_group_with_row_weights( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + validate_trace(trace)?; + validate_public(public)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + + let zero = F::zero_with_cfg(field_cfg); + let zero_inner = zero.inner().clone(); + let mut mles = Vec::with_capacity(2 + booleanity_sources.len()); + + mles.push(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + row_weights + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); + + let linear_values = (0..SHA_ROW_COUNT) + .map(|row| sha_linear_residual_row_value(trace, public, row, a, lambda, field_cfg)) + .collect::, _>>()?; + mles.push(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + linear_values + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); 
+ + let needs_virtuals = sources_need_virtuals(booleanity_sources); + let mut source_values = (0..booleanity_sources.len()) + .map(|_| Vec::with_capacity(SHA_ROW_COUNT)) + .collect::>(); + for row in 0..SHA_ROW_COUNT { + let virtuals = if needs_virtuals { + Some(reconstruct_virtual_ch_maj_at_row_unchecked( + trace, row, field_cfg, + )?) + } else { + None + }; + for (source_idx, source) in booleanity_sources.iter().enumerate() { + source_values[source_idx].push(booleanity_source_value_at_row_with_virtuals( + trace, + row, + source, + virtuals.as_ref(), + field_cfg, + )?); + } + } + for values in source_values { + mles.push(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + values.iter().map(|value| value.inner().clone()).collect(), + zero_inner.clone(), + )); + } + + let rho_powers = powers( + rho.clone(), + F::one_with_cfg(field_cfg), + booleanity_sources.len(), + ); + let xi = xi.clone(); + let zero_for_comb = F::zero_with_cfg(field_cfg); + let one_for_comb = F::one_with_cfg(field_cfg); + Ok(MultiDegreeSumcheckGroup::new( + 3, + mles, + Box::new(move |values: &[F]| { + let row_weight = values[0].clone(); + let linear = values[1].clone(); + let mut bool_sum = zero_for_comb.clone(); + for (d, rho_power) in values[2..].iter().zip(rho_powers.iter()) { + let term = d.clone() * (d.clone() - one_for_comb.clone()); + bool_sum += rho_power.clone() * term; + } + row_weight * (linear + xi.clone() * bool_sum) + }), + )) +} + +/// Claimed sum for the expression-backed folded row check. 
+#[allow(clippy::too_many_arguments)] +pub fn expression_folded_row_sum( + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + let values = folded_row_integrand_values( + trace, + public, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + folded_row_integrand_sum(&values, field_cfg) +} + +/// Claimed sum for the expression-backed folded row check with precomputed +/// `eq(r_ic, row)` weights. +#[allow(clippy::too_many_arguments)] +pub fn expression_folded_row_sum_with_row_weights( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + let values = folded_row_integrand_values_with_row_weights( + trace, + public, + row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + folded_row_integrand_sum(&values, field_cfg) +} + +#[allow(clippy::too_many_arguments)] +pub fn expression_folded_row_sum_with_vectors( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result +where + F: InnerTransparentField + DelayedFieldProductSum, + F::Inner: Zero, +{ + let values = folded_row_integrand_values_with_vectors( + trace, + public, + row_weights, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + field_cfg, + )?; + folded_row_integrand_sum(&values, field_cfg) +} + +pub fn sha_word_bits_at_point( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + point: &[F], + 
field_cfg: &F::Config, +) -> Result<[F; SHA_WORD_BITS], ShaProjectionError> +where + F: PrimeField, +{ + if point.len() != SHA_ROW_VARS { + return Err(ShaProjectionError::RowPointLength { got: point.len() }); + } + validate_trace(trace)?; + let row_weights = build_eq_x_r_vec(point, field_cfg)?; + sha_word_bits_at_point_with_weights(trace, col, shift, &row_weights, field_cfg) +} + +pub fn sha_word_bits_at_point_with_weights( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result<[F; SHA_WORD_BITS], ShaProjectionError> +where + F: PrimeField, +{ + validate_trace(trace)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + sha_word_bits_at_point_with_weights_unchecked(trace, col, shift, row_weights, field_cfg) +} + +pub fn sha_word_bits_at_point_with_weights_unchecked( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result<[F; SHA_WORD_BITS], ShaProjectionError> +where + F: PrimeField, +{ + let mut bits: [F; SHA_WORD_BITS] = std::array::from_fn(|_| F::zero_with_cfg(field_cfg)); + for (row, row_weight) in row_weights.iter().enumerate() { + for (bit, out) in bits.iter_mut().enumerate() { + *out += row_weight.clone() + * bit_at_shifted_or_zero(trace, col, row, shift, bit, field_cfg)?; + } + } + Ok(bits) +} + +pub fn sha_scalarized_word_at_point( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + point: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + if point.len() != SHA_ROW_VARS { + return Err(ShaProjectionError::RowPointLength { got: point.len() }); + } + validate_trace(trace)?; + let row_weights = build_eq_x_r_vec(point, field_cfg)?; + sha_scalarized_word_at_point_with_weights(trace, col, shift, &row_weights, field_cfg) +} + +pub fn 
sha_scalarized_word_at_point_with_weights( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + validate_trace(trace)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let mut value = F::zero_with_cfg(field_cfg); + for (row, row_weight) in row_weights.iter().enumerate() { + value += row_weight.clone() + * scalarized_word_at_shifted_or_zero(trace, col, row, shift, field_cfg)?; + } + Ok(value) +} + +pub fn sha_int_at_point( + trace: &ProjectedTrace, + col: ShaIntCol, + point: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + if point.len() != SHA_ROW_VARS { + return Err(ShaProjectionError::RowPointLength { got: point.len() }); + } + validate_trace(trace)?; + let row_weights = build_eq_x_r_vec(point, field_cfg)?; + sha_int_at_point_with_weights(trace, col, &row_weights, field_cfg) +} + +pub fn sha_int_at_point_with_weights( + trace: &ProjectedTrace, + col: ShaIntCol, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + validate_trace(trace)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + sha_int_at_point_with_weights_unchecked(trace, col, row_weights, field_cfg) +} + +pub fn sha_int_at_point_with_weights_unchecked( + trace: &ProjectedTrace, + col: ShaIntCol, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + let mut value = F::zero_with_cfg(field_cfg); + for (row, row_weight) in row_weights.iter().enumerate() { + value += row_weight.clone() * int_scalar(trace, col, row, field_cfg)?; + } + Ok(value) +} + +pub fn sha_public_at_point( + public: &ProjectedPublic, + col: ShaPublicCol, + shift: usize, + 
point: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + if point.len() != SHA_ROW_VARS { + return Err(ShaProjectionError::RowPointLength { got: point.len() }); + } + validate_public(public)?; + let row_weights = build_eq_x_r_vec(point, field_cfg)?; + sha_public_at_point_with_weights(public, col, shift, &row_weights, field_cfg) +} + +pub fn sha_public_at_point_with_weights( + public: &ProjectedPublic, + col: ShaPublicCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + validate_public(public)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind: "row_weights", + col: 0, + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let mut value = F::zero_with_cfg(field_cfg); + for (row, row_weight) in row_weights.iter().enumerate() { + let shifted = row.checked_add(shift).unwrap_or(SHA_ROW_COUNT); + value += row_weight.clone() * public_scalar(public, col, shifted, field_cfg)?; + } + Ok(value) +} + +pub fn verify_folded_row_sumcheck_claim( + claimed_sum: &F, + final_round_sumcheck_claim: &F, +) -> Result<(), ShaProjectionError> +where + F: PrimeField, +{ + if claimed_sum != final_round_sumcheck_claim { + return Err(ShaProjectionError::FoldedRowClaimMismatch); + } + Ok(()) +} + +pub fn residual_polys_at_row( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row: usize, + field_cfg: &F::Config, +) -> Result<[DynamicPolynomialF; NUM_SHA_RESIDUAL_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + let rho_sig0 = sparse_poly::(&[10, 19, 30], field_cfg); + let rho_sig1 = sparse_poly::(&[7, 21, 26], field_cfg); + residual_polys_at_row_with_rotation_polys(trace, public, row, &rho_sig0, &rho_sig1, field_cfg) +} + +fn residual_polys_at_row_with_rotation_polys( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row: usize, + rho_sig0: &DynamicPolynomialF, + rho_sig1: &DynamicPolynomialF, + field_cfg: &F::Config, +) -> 
Result<[DynamicPolynomialF; NUM_SHA_RESIDUAL_FAMILIES], ShaProjectionError> +where + F: PrimeField, +{ + let constants = ShaResidualPolyConstants { + rho_sig0: rho_sig0.clone(), + rho_sig1: rho_sig1.clone(), + two: F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg), + low_mu_coeff: pow_two(32, field_cfg), + high_mu_w_coeff: pow_two(34, field_cfg), + high_mu_3_bit_coeff: pow_two(35, field_cfg), + high_mu_1_bit_coeff: pow_two(33, field_cfg), + }; + residual_polys_at_row_with_constants(trace, public, row, &constants, field_cfg) +} + +struct ShaResidualPolyConstants { + rho_sig0: DynamicPolynomialF, + rho_sig1: DynamicPolynomialF, + two: F, + low_mu_coeff: F, + high_mu_w_coeff: F, + high_mu_3_bit_coeff: F, + high_mu_1_bit_coeff: F, +} + +impl ShaResidualPolyConstants { + fn new(field_cfg: &F::Config) -> Self { + Self { + rho_sig0: sparse_poly::(&[10, 19, 30], field_cfg), + rho_sig1: sparse_poly::(&[7, 21, 26], field_cfg), + two: F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg), + low_mu_coeff: pow_two(32, field_cfg), + high_mu_w_coeff: pow_two(34, field_cfg), + high_mu_3_bit_coeff: pow_two(35, field_cfg), + high_mu_1_bit_coeff: pow_two(33, field_cfg), + } + } +} + +#[derive(Clone, Debug)] +struct FixedResidualCoeffAccumulator { + coeffs: Vec>, +} + +impl FixedResidualCoeffAccumulator +where + F: PrimeField, +{ + fn new(family_count: usize, coeff_count: usize, field_cfg: &F::Config) -> Self { + Self { + coeffs: (0..family_count) + .map(|_| vec![F::zero_with_cfg(field_cfg); coeff_count]) + .collect(), + } + } + + fn add_assign(&mut self, rhs: Self) { + for (dst_family, rhs_family) in self.coeffs.iter_mut().zip(rhs.coeffs) { + for (dst, rhs) in dst_family.iter_mut().zip(rhs_family) { + *dst += rhs; + } + } + } + + fn into_table(mut self) -> LinearResidualCoeffTable { + let coeffs = self + .coeffs + .drain(..) 
+ .map(|mut coeffs| { + while coeffs.last().is_some_and(F::is_zero) { + coeffs.pop(); + } + DynamicPolynomialF { coeffs } + }) + .collect(); + LinearResidualCoeffTable { coeffs } + } + + #[inline(always)] + fn add_scaled_to_family_idx( + &mut self, + family_idx: usize, + coeff_idx: usize, + value: &F, + scale: &F, + ) { + if F::is_zero(value) || F::is_zero(scale) { + return; + } + debug_assert!(family_idx < self.coeffs.len()); + debug_assert!(coeff_idx < self.coeffs[family_idx].len()); + self.coeffs[family_idx][coeff_idx] += value.clone() * scale; + } + + #[inline(always)] + fn add_scaled(&mut self, family: ShaResidualFamily, coeff_idx: usize, value: &F, scale: &F) { + self.add_scaled_to_family_idx(family.index(), coeff_idx, value, scale); + } + + #[inline(always)] + fn add_const_scaled(&mut self, family: ShaResidualFamily, value: &F, scale: &F) { + self.add_scaled(family, 0, value, scale); + } + + fn add_trace_word_scaled( + &mut self, + family: ShaResidualFamily, + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + row_shift: usize, + scale: &F, + field_cfg: &F::Config, + ) -> Result<(), ShaProjectionError> { + if F::is_zero(scale) { + return Ok(()); + } + for bit in 0..SHA_WORD_BITS { + let value = bit_at_shifted_or_zero(trace, col, row, row_shift, bit, field_cfg)?; + self.add_scaled(family, bit, &value, scale); + } + Ok(()) + } + + fn add_trace_word_rot_scaled( + &mut self, + family: ShaResidualFamily, + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + rot: usize, + scale: &F, + field_cfg: &F::Config, + ) -> Result<(), ShaProjectionError> { + debug_assert!(rot < SHA_WORD_BITS); + if F::is_zero(scale) { + return Ok(()); + } + for bit in 0..SHA_WORD_BITS { + let value = bit_at_shifted_or_zero(trace, col, row, 0, bit, field_cfg)?; + self.add_scaled(family, (bit + rot) % SHA_WORD_BITS, &value, scale); + } + Ok(()) + } + + fn add_trace_word_shift_r_scaled( + &mut self, + family: ShaResidualFamily, + trace: &ProjectedTrace, + col: ShaWordCol, + row: 
usize,
+        shift: usize,
+        scale: &F,
+        field_cfg: &F::Config,
+    ) -> Result<(), ShaProjectionError> {
+        debug_assert!(shift < SHA_WORD_BITS);
+        if F::is_zero(scale) {
+            return Ok(());
+        }
+        // Only the `SHA_WORD_BITS - shift` low output coefficients have a source bit.
+        for out_bit in 0..(SHA_WORD_BITS - shift) {
+            let value = bit_at_shifted_or_zero(trace, col, row, 0, out_bit + shift, field_cfg)?;
+            self.add_scaled(family, out_bit, &value, scale);
+        }
+        Ok(())
+    }
+
+    /// Adds `scale` times the product of the word in `col` at `row` with the
+    /// sparse polynomial whose monomial exponents are listed in `shifts`.
+    ///
+    /// Each source bit `i` is added to coefficient `i + shift` for every
+    /// `shift`; indices are not reduced, so the family's coefficient row must
+    /// be long enough for `SHA_WORD_BITS - 1 + max(shifts)` (checked only by
+    /// the debug assertions in `add_scaled_to_family_idx`). Does nothing when
+    /// `scale` is zero.
+    fn add_trace_word_sparse_product_scaled(
+        &mut self,
+        family: ShaResidualFamily,
+        trace: &ProjectedTrace,
+        col: ShaWordCol,
+        row: usize,
+        shifts: &[usize],
+        scale: &F,
+        field_cfg: &F::Config,
+    ) -> Result<(), ShaProjectionError> {
+        if F::is_zero(scale) {
+            return Ok(());
+        }
+        for bit in 0..SHA_WORD_BITS {
+            let value = bit_at_shifted_or_zero(trace, col, row, 0, bit, field_cfg)?;
+            for &shift in shifts {
+                self.add_scaled(family, bit + shift, &value, scale);
+            }
+        }
+        Ok(())
+    }
+
+    /// Adds `scale` times the integer column `col` at `row` to the constant
+    /// term of `family`. The value is read with `int_scalar`.
+    fn add_trace_int_const_scaled(
+        &mut self,
+        family: ShaResidualFamily,
+        trace: &ProjectedTrace,
+        col: ShaIntCol,
+        row: usize,
+        scale: &F,
+        field_cfg: &F::Config,
+    ) -> Result<(), ShaProjectionError> {
+        let value = int_scalar(trace, col, row, field_cfg)?;
+        self.add_const_scaled(family, &value, scale);
+        Ok(())
+    }
+
+    /// Adds `scale` times the public column `col` at `row` to the constant
+    /// term of `family`. The value is read with `public_scalar`.
+    fn add_public_scalar_const_scaled(
+        &mut self,
+        family: ShaResidualFamily,
+        public: &ProjectedPublic,
+        col: ShaPublicCol,
+        row: usize,
+        scale: &F,
+        field_cfg: &F::Config,
+    ) -> Result<(), ShaProjectionError> {
+        let value = public_scalar(public, col, row, field_cfg)?;
+        self.add_const_scaled(family, &value, scale);
+        Ok(())
+    }
+
+    /// Adds `scale` times the public column `col` at `row` to `family`,
+    /// bit-by-bit when a bit decomposition is available.
+    ///
+    /// If `col` has a public word column and `public.bit_slices` is present,
+    /// bit `i` is added to coefficient `i` (rows `>= SHA_ROW_COUNT` contribute
+    /// nothing). Otherwise the scalar value is added to the constant term via
+    /// `add_public_scalar_const_scaled`. Does nothing when `scale` is zero.
+    fn add_public_word_or_const_scaled(
+        &mut self,
+        family: ShaResidualFamily,
+        public: &ProjectedPublic,
+        col: ShaPublicCol,
+        row: usize,
+        scale: &F,
+        field_cfg: &F::Config,
+    ) -> Result<(), ShaProjectionError> {
+        if F::is_zero(scale) {
+            return Ok(());
+        }
+        // No word view for this column: fall back to the scalar form.
+        let Some(word_col) = col.public_word_col() else {
+            return self.add_public_scalar_const_scaled(family, public, col, row, scale, field_cfg);
+        };
+        let Some(bit_slices) =
&public.bit_slices else {
+            return self.add_public_scalar_const_scaled(family, public, col, row, scale, field_cfg);
+        };
+        // Rows past the table are treated as an all-zero word.
+        if row >= SHA_ROW_COUNT {
+            return Ok(());
+        }
+        let col_idx = word_col.index();
+        for bit in 0..SHA_WORD_BITS {
+            let value = scalar_from_table(
+                "public.bit_slices",
+                bit_slices,
+                bit_slice_index(col_idx, bit, SHA_WORD_BITS),
+                row,
+                field_cfg,
+            )?;
+            self.add_scaled(family, bit, &value, scale);
+        }
+        Ok(())
+    }
+}
+
+/// Returns `-value` (computed on a clone).
+#[inline(always)]
+fn neg(value: &F) -> F {
+    -value.clone()
+}
+
+/// Returns `lhs * rhs` (computed on a clone of `lhs`).
+#[inline(always)]
+fn scaled(lhs: &F, rhs: &F) -> F {
+    lhs.clone() * rhs
+}
+
+/// Adds the carry (`MuPacked`) term of one residual family to `acc`.
+///
+/// The contribution is
+/// `row_weight * (low_mu_coeff * (mu >> low_shift) - high_coeff * (mu >> high_shift))`,
+/// where `mu` is the `MuPacked` word at `row` and `>> k` is the bit-index
+/// shift performed by `add_trace_word_shift_r_scaled`.
+fn add_mu_contribution(
+    acc: &mut FixedResidualCoeffAccumulator,
+    family: ShaResidualFamily,
+    trace: &ProjectedTrace,
+    row: usize,
+    low_shift: usize,
+    high_shift: usize,
+    high_coeff: &F,
+    row_weight: &F,
+    constants: &ShaResidualPolyConstants,
+    field_cfg: &F::Config,
+) -> Result<(), ShaProjectionError>
+where
+    F: PrimeField,
+{
+    let low_scale = scaled(row_weight, &constants.low_mu_coeff);
+    // The high part is subtracted, so its scale carries the minus sign.
+    let high_scale = neg(&scaled(row_weight, high_coeff));
+    acc.add_trace_word_shift_r_scaled(
+        family,
+        trace,
+        ShaWordCol::MuPacked,
+        row,
+        low_shift,
+        &low_scale,
+        field_cfg,
+    )?;
+    acc.add_trace_word_shift_r_scaled(
+        family,
+        trace,
+        ShaWordCol::MuPacked,
+        row,
+        high_shift,
+        &high_scale,
+        field_cfg,
+    )
+}
+
+/// Adds `row_weight` times the residual polynomials of `row`, for every
+/// residual family, directly into the coefficient accumulator `acc`.
+///
+/// Each family is assembled from word columns of `trace`, integer columns and
+/// public columns, without materialising intermediate polynomials. Returns
+/// immediately when `row_weight` is zero. Lookup errors are propagated.
+// NOTE(review): the lint allowance presumably covers the plain `row + 3`
+// offsets used for the public `K` column further down — confirm.
+#[allow(clippy::arithmetic_side_effects)]
+fn accumulate_residual_row_fixed(
+    acc: &mut FixedResidualCoeffAccumulator,
+    trace: &ProjectedTrace,
+    public: &ProjectedPublic,
+    row: usize,
+    row_weight: &F,
+    constants: &ShaResidualPolyConstants,
+    field_cfg: &F::Config,
+) -> Result<(), ShaProjectionError>
+where
+    F: PrimeField,
+{
+    if F::is_zero(row_weight) {
+        return Ok(());
+    }
+
+    // Pre-negated weights for subtracted terms: -w and -2w.
+    let minus_row = neg(row_weight);
+    let minus_two_row = neg(&scaled(row_weight, &constants.two));
+
+    // R0/R1: big-sigma residuals. Multiplication by the sparse rotation
+    // polynomial is just three coefficient shifts.
+ acc.add_trace_word_sparse_product_scaled( + ShaResidualFamily::R0BigSigmaA, + trace, + ShaWordCol::A, + row, + &[10, 19, 30], + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R0BigSigmaA, + trace, + ShaWordCol::Sigma0, + row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R0BigSigmaA, + trace, + ShaWordCol::OvSigma0, + row, + 0, + &minus_two_row, + field_cfg, + )?; + + acc.add_trace_word_sparse_product_scaled( + ShaResidualFamily::R1BigSigmaE, + trace, + ShaWordCol::E, + row, + &[7, 21, 26], + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R1BigSigmaE, + trace, + ShaWordCol::Sigma1, + row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R1BigSigmaE, + trace, + ShaWordCol::OvSigma1, + row, + 0, + &minus_two_row, + field_cfg, + )?; + + // R2/R3: small-sigma residuals over the message schedule word. + for rot in [25usize, 14] { + acc.add_trace_word_rot_scaled( + ShaResidualFamily::R2SmallSigma0, + trace, + ShaWordCol::W, + row, + rot, + row_weight, + field_cfg, + )?; + } + acc.add_trace_word_shift_r_scaled( + ShaResidualFamily::R2SmallSigma0, + trace, + ShaWordCol::W, + row, + 3, + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R2SmallSigma0, + trace, + ShaWordCol::SmallSigma0, + row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R2SmallSigma0, + trace, + ShaWordCol::OvSmallSigma0, + row, + 0, + &minus_two_row, + field_cfg, + )?; + + for rot in [15usize, 13] { + acc.add_trace_word_rot_scaled( + ShaResidualFamily::R3SmallSigma1, + trace, + ShaWordCol::W, + row, + rot, + row_weight, + field_cfg, + )?; + } + acc.add_trace_word_shift_r_scaled( + ShaResidualFamily::R3SmallSigma1, + trace, + ShaWordCol::W, + row, + 10, + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R3SmallSigma1, + trace, + ShaWordCol::SmallSigma1, 
+ row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R3SmallSigma1, + trace, + ShaWordCol::OvSmallSigma1, + row, + 0, + &minus_two_row, + field_cfg, + )?; + + // R4: schedule transition. + acc.add_trace_word_scaled( + ShaResidualFamily::R4Schedule, + trace, + ShaWordCol::W, + row, + 16, + row_weight, + field_cfg, + )?; + for (col, shift) in [ + (ShaWordCol::W, 0usize), + (ShaWordCol::SmallSigma0, 1), + (ShaWordCol::W, 9), + (ShaWordCol::SmallSigma1, 14), + ] { + acc.add_trace_word_scaled( + ShaResidualFamily::R4Schedule, + trace, + col, + row, + shift, + &minus_row, + field_cfg, + )?; + } + add_mu_contribution( + acc, + ShaResidualFamily::R4Schedule, + trace, + row, + 0, + 2, + &constants.high_mu_w_coeff, + row_weight, + constants, + field_cfg, + )?; + acc.add_trace_int_const_scaled( + ShaResidualFamily::R4Schedule, + trace, + ShaIntCol::CompSchedule, + row, + row_weight, + field_cfg, + )?; + + // R5/R6: compression round updates. + acc.add_trace_word_scaled( + ShaResidualFamily::R5UpdateA, + trace, + ShaWordCol::A, + row, + 4, + row_weight, + field_cfg, + )?; + for (col, shift) in [ + (ShaWordCol::E, 0usize), + (ShaWordCol::Sigma1, 3), + (ShaWordCol::Uef, 3), + (ShaWordCol::UNegEg, 3), + (ShaWordCol::W, 0), + (ShaWordCol::Sigma0, 3), + (ShaWordCol::Maj, 3), + ] { + acc.add_trace_word_scaled( + ShaResidualFamily::R5UpdateA, + trace, + col, + row, + shift, + &minus_row, + field_cfg, + )?; + } + acc.add_public_scalar_const_scaled( + ShaResidualFamily::R5UpdateA, + public, + ShaPublicCol::K, + row + 3, + &minus_row, + field_cfg, + )?; + add_mu_contribution( + acc, + ShaResidualFamily::R5UpdateA, + trace, + row, + 2, + 5, + &constants.high_mu_3_bit_coeff, + row_weight, + constants, + field_cfg, + )?; + acc.add_trace_int_const_scaled( + ShaResidualFamily::R5UpdateA, + trace, + ShaIntCol::CompUpdateA, + row, + row_weight, + field_cfg, + )?; + + acc.add_trace_word_scaled( + ShaResidualFamily::R6UpdateE, + trace, + ShaWordCol::E, + 
row, + 4, + row_weight, + field_cfg, + )?; + for (col, shift) in [ + (ShaWordCol::A, 0usize), + (ShaWordCol::E, 0), + (ShaWordCol::Sigma1, 3), + (ShaWordCol::Uef, 3), + (ShaWordCol::UNegEg, 3), + (ShaWordCol::W, 0), + ] { + acc.add_trace_word_scaled( + ShaResidualFamily::R6UpdateE, + trace, + col, + row, + shift, + &minus_row, + field_cfg, + )?; + } + acc.add_public_scalar_const_scaled( + ShaResidualFamily::R6UpdateE, + public, + ShaPublicCol::K, + row + 3, + &minus_row, + field_cfg, + )?; + add_mu_contribution( + acc, + ShaResidualFamily::R6UpdateE, + trace, + row, + 5, + 8, + &constants.high_mu_3_bit_coeff, + row_weight, + constants, + field_cfg, + )?; + acc.add_trace_int_const_scaled( + ShaResidualFamily::R6UpdateE, + trace, + ShaIntCol::CompUpdateE, + row, + row_weight, + field_cfg, + )?; + + // R7/R8: pin input/output public words at active selector rows. + let s_init = public_scalar(public, ShaPublicCol::SInit, row, field_cfg)?; + let s_out = public_scalar(public, ShaPublicCol::SOut, row, field_cfg)?; + let init_scale = scaled(row_weight, &s_init); + let out_scale = scaled(row_weight, &s_out); + let neg_init_scale = neg(&init_scale); + let neg_out_scale = neg(&out_scale); + acc.add_trace_word_scaled( + ShaResidualFamily::R7PinA, + trace, + ShaWordCol::A, + row, + 0, + &init_scale, + field_cfg, + )?; + acc.add_public_word_or_const_scaled( + ShaResidualFamily::R7PinA, + public, + ShaPublicCol::PAIn, + row, + &neg_init_scale, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R7PinA, + trace, + ShaWordCol::A, + row, + 0, + &out_scale, + field_cfg, + )?; + acc.add_public_word_or_const_scaled( + ShaResidualFamily::R7PinA, + public, + ShaPublicCol::PAOut, + row, + &neg_out_scale, + field_cfg, + )?; + + acc.add_trace_word_scaled( + ShaResidualFamily::R8PinE, + trace, + ShaWordCol::E, + row, + 0, + &init_scale, + field_cfg, + )?; + acc.add_public_word_or_const_scaled( + ShaResidualFamily::R8PinE, + public, + ShaPublicCol::PEIn, + row, + 
&neg_init_scale, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R8PinE, + trace, + ShaWordCol::E, + row, + 0, + &out_scale, + field_cfg, + )?; + acc.add_public_word_or_const_scaled( + ShaResidualFamily::R8PinE, + public, + ShaPublicCol::PEOut, + row, + &neg_out_scale, + field_cfg, + )?; + + // R9/R10: feed-forward rows. + acc.add_trace_word_scaled( + ShaResidualFamily::R9FeedForwardA, + trace, + ShaWordCol::A, + row, + 4, + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R9FeedForwardA, + trace, + ShaWordCol::A, + row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_public_scalar_const_scaled( + ShaResidualFamily::R9FeedForwardA, + public, + ShaPublicCol::PAIn, + row, + &minus_row, + field_cfg, + )?; + add_mu_contribution( + acc, + ShaResidualFamily::R9FeedForwardA, + trace, + row, + 8, + 9, + &constants.high_mu_1_bit_coeff, + row_weight, + constants, + field_cfg, + )?; + acc.add_trace_int_const_scaled( + ShaResidualFamily::R9FeedForwardA, + trace, + ShaIntCol::CompFeedForwardA, + row, + row_weight, + field_cfg, + )?; + + acc.add_trace_word_scaled( + ShaResidualFamily::R10FeedForwardE, + trace, + ShaWordCol::E, + row, + 4, + row_weight, + field_cfg, + )?; + acc.add_trace_word_scaled( + ShaResidualFamily::R10FeedForwardE, + trace, + ShaWordCol::E, + row, + 0, + &minus_row, + field_cfg, + )?; + acc.add_public_scalar_const_scaled( + ShaResidualFamily::R10FeedForwardE, + public, + ShaPublicCol::PEIn, + row, + &minus_row, + field_cfg, + )?; + add_mu_contribution( + acc, + ShaResidualFamily::R10FeedForwardE, + trace, + row, + 9, + 10, + &constants.high_mu_1_bit_coeff, + row_weight, + constants, + field_cfg, + )?; + acc.add_trace_int_const_scaled( + ShaResidualFamily::R10FeedForwardE, + trace, + ShaIntCol::CompFeedForwardE, + row, + row_weight, + field_cfg, + )?; + + // R11-R17: selector and high-bit/carry residuals. 
+    // R11: pin W to the public message word, gated by the SMsg selector.
+    let s_msg = public_scalar(public, ShaPublicCol::SMsg, row, field_cfg)?;
+    let msg_scale = scaled(row_weight, &s_msg);
+    let neg_msg_scale = neg(&msg_scale);
+    acc.add_trace_word_scaled(
+        ShaResidualFamily::R11MessagePin,
+        trace,
+        ShaWordCol::W,
+        row,
+        0,
+        &msg_scale,
+        field_cfg,
+    )?;
+    acc.add_public_word_or_const_scaled(
+        ShaResidualFamily::R11MessagePin,
+        public,
+        ShaPublicCol::Message,
+        row,
+        &neg_msg_scale,
+        field_cfg,
+    )?;
+
+    // R12-R16: each compensation integer column, gated by its selector.
+    let s_sched = public_scalar(public, ShaPublicCol::SSched, row, field_cfg)?;
+    let s_upd = public_scalar(public, ShaPublicCol::SUpd, row, field_cfg)?;
+    let s_ff = public_scalar(public, ShaPublicCol::SFf, row, field_cfg)?;
+    acc.add_trace_int_const_scaled(
+        ShaResidualFamily::R12CompSchedule,
+        trace,
+        ShaIntCol::CompSchedule,
+        row,
+        &scaled(row_weight, &s_sched),
+        field_cfg,
+    )?;
+    acc.add_trace_int_const_scaled(
+        ShaResidualFamily::R13CompUpdateA,
+        trace,
+        ShaIntCol::CompUpdateA,
+        row,
+        &scaled(row_weight, &s_upd),
+        field_cfg,
+    )?;
+    acc.add_trace_int_const_scaled(
+        ShaResidualFamily::R14CompUpdateE,
+        trace,
+        ShaIntCol::CompUpdateE,
+        row,
+        &scaled(row_weight, &s_upd),
+        field_cfg,
+    )?;
+    acc.add_trace_int_const_scaled(
+        ShaResidualFamily::R15CompFeedForwardA,
+        trace,
+        ShaIntCol::CompFeedForwardA,
+        row,
+        &scaled(row_weight, &s_ff),
+        field_cfg,
+    )?;
+    acc.add_trace_int_const_scaled(
+        ShaResidualFamily::R16CompFeedForwardE,
+        trace,
+        ShaIntCol::CompFeedForwardE,
+        row,
+        &scaled(row_weight, &s_ff),
+        field_cfg,
+    )?;
+    // R17: the MuPacked bits from index 10 upward.
+    acc.add_trace_word_shift_r_scaled(
+        ShaResidualFamily::R17CarryHighBits,
+        trace,
+        ShaWordCol::MuPacked,
+        row,
+        10,
+        row_weight,
+        field_cfg,
+    )?;
+
+    Ok(())
+}
+
+/// Builds the residual polynomial of every family at `row` as explicit
+/// `DynamicPolynomialF` values, one per family, in family order.
+///
+/// Word columns are read with `word_poly` / `word_poly_shifted`, integer and
+/// public columns with the `*_const_poly` helpers; `constants` supplies the
+/// fixed polynomials and coefficients. Lookup errors are propagated.
+// NOTE(review): this appears to compute, term for term, the same expressions
+// that `accumulate_residual_row_fixed` accumulates — keep the two in sync.
+fn residual_polys_at_row_with_constants(
+    trace: &ProjectedTrace,
+    public: &ProjectedPublic,
+    row: usize,
+    constants: &ShaResidualPolyConstants,
+    field_cfg: &F::Config,
+) -> Result<[DynamicPolynomialF; NUM_SHA_RESIDUAL_FAMILIES], ShaProjectionError>
+where
+    F: PrimeField,
+{
+    let a = word_poly(trace,
ShaWordCol::A, row, field_cfg)?; + let e = word_poly(trace, ShaWordCol::E, row, field_cfg)?; + let sigma0 = word_poly(trace, ShaWordCol::Sigma0, row, field_cfg)?; + let sigma1 = word_poly(trace, ShaWordCol::Sigma1, row, field_cfg)?; + let w = word_poly(trace, ShaWordCol::W, row, field_cfg)?; + let small_sigma0 = word_poly(trace, ShaWordCol::SmallSigma0, row, field_cfg)?; + let small_sigma1 = word_poly(trace, ShaWordCol::SmallSigma1, row, field_cfg)?; + let ov_sigma0 = word_poly(trace, ShaWordCol::OvSigma0, row, field_cfg)?; + let ov_sigma1 = word_poly(trace, ShaWordCol::OvSigma1, row, field_cfg)?; + let ov_small_sigma0 = word_poly(trace, ShaWordCol::OvSmallSigma0, row, field_cfg)?; + let ov_small_sigma1 = word_poly(trace, ShaWordCol::OvSmallSigma1, row, field_cfg)?; + let w_rot25 = w.rot_c(25); + let w_rot14 = w.rot_c(14); + let w_rot15 = w.rot_c(15); + let w_rot13 = w.rot_c(13); + let w_shift3 = w.shift_r_c(3); + let w_shift9 = word_poly_shifted(trace, ShaWordCol::W, row, 9, field_cfg)?; + let w_shift16 = word_poly_shifted(trace, ShaWordCol::W, row, 16, field_cfg)?; + let small_sigma0_shift1 = word_poly_shifted(trace, ShaWordCol::SmallSigma0, row, 1, field_cfg)?; + let small_sigma1_shift14 = + word_poly_shifted(trace, ShaWordCol::SmallSigma1, row, 14, field_cfg)?; + let a_shift4 = word_poly_shifted(trace, ShaWordCol::A, row, 4, field_cfg)?; + let e_shift4 = word_poly_shifted(trace, ShaWordCol::E, row, 4, field_cfg)?; + let sigma0_shift3 = word_poly_shifted(trace, ShaWordCol::Sigma0, row, 3, field_cfg)?; + let sigma1_shift3 = word_poly_shifted(trace, ShaWordCol::Sigma1, row, 3, field_cfg)?; + let uef_shift3 = word_poly_shifted(trace, ShaWordCol::Uef, row, 3, field_cfg)?; + let uneg_eg_shift3 = word_poly_shifted(trace, ShaWordCol::UNegEg, row, 3, field_cfg)?; + let maj_shift3 = word_poly_shifted(trace, ShaWordCol::Maj, row, 3, field_cfg)?; + let public_k_shift3 = public_const_poly(public, ShaPublicCol::K, row + 3, field_cfg)?; + let comp_schedule = 
int_const_poly(trace, ShaIntCol::CompSchedule, row, field_cfg)?; + let comp_update_a = int_const_poly(trace, ShaIntCol::CompUpdateA, row, field_cfg)?; + let comp_update_e = int_const_poly(trace, ShaIntCol::CompUpdateE, row, field_cfg)?; + let comp_ff_a = int_const_poly(trace, ShaIntCol::CompFeedForwardA, row, field_cfg)?; + let comp_ff_e = int_const_poly(trace, ShaIntCol::CompFeedForwardE, row, field_cfg)?; + + let r0 = (&a * &constants.rho_sig0) - &sigma0 - &scale_poly(&ov_sigma0, &constants.two); + let r1 = (&e * &constants.rho_sig1) - &sigma1 - &scale_poly(&ov_sigma1, &constants.two); + let r2 = w_rot25 + &w_rot14 + &w_shift3 + - &small_sigma0 + - &scale_poly(&ov_small_sigma0, &constants.two); + let r3 = w_rot15 + &w_rot13 + &w.shift_r_c(10) + - &small_sigma1 + - &scale_poly(&ov_small_sigma1, &constants.two); + + let mu_packed = word_poly(trace, ShaWordCol::MuPacked, row, field_cfg)?; + let mu_shift2 = mu_packed.shift_r_c(2); + let mu_shift5 = mu_packed.shift_r_c(5); + let mu_shift8 = mu_packed.shift_r_c(8); + let mu_shift9 = mu_packed.shift_r_c(9); + let mu_shift10 = mu_packed.shift_r_c(10); + let mu = |low: &DynamicPolynomialF, high: &DynamicPolynomialF, high_coeff: &F| { + scale_poly(low, &constants.low_mu_coeff) - &scale_poly(high, high_coeff) + }; + let mu_w = mu(&mu_packed, &mu_shift2, &constants.high_mu_w_coeff); + let mu_a = mu(&mu_shift2, &mu_shift5, &constants.high_mu_3_bit_coeff); + let mu_e = mu(&mu_shift5, &mu_shift8, &constants.high_mu_3_bit_coeff); + let mu_ff_a = mu(&mu_shift8, &mu_shift9, &constants.high_mu_1_bit_coeff); + let mu_ff_e = mu(&mu_shift9, &mu_shift10, &constants.high_mu_1_bit_coeff); + + let r4 = w_shift16 - &w - &small_sigma0_shift1 - &w_shift9 - &small_sigma1_shift14 + + &mu_w + + &comp_schedule; + + let r5 = a_shift4.clone() + - &e + - &sigma1_shift3 + - &uef_shift3 + - &uneg_eg_shift3 + - &public_k_shift3 + - &w + - &sigma0_shift3 + - &maj_shift3 + + &mu_a + + &comp_update_a; + + let r6 = e_shift4.clone() + - &a + - &e + - 
&sigma1_shift3 + - &uef_shift3 + - &uneg_eg_shift3 + - &public_k_shift3 + - &w + + &mu_e + + &comp_update_e; + + let s_init = public_scalar(public, ShaPublicCol::SInit, row, field_cfg)?; + let s_msg = public_scalar(public, ShaPublicCol::SMsg, row, field_cfg)?; + let s_sched = public_scalar(public, ShaPublicCol::SSched, row, field_cfg)?; + let s_upd = public_scalar(public, ShaPublicCol::SUpd, row, field_cfg)?; + let s_ff = public_scalar(public, ShaPublicCol::SFf, row, field_cfg)?; + let s_out = public_scalar(public, ShaPublicCol::SOut, row, field_cfg)?; + + let r7 = scale_poly( + &(a.clone() - &public_word_or_const_poly(public, ShaPublicCol::PAIn, row, field_cfg)?), + &s_init, + ) + &scale_poly( + &(a.clone() - &public_word_or_const_poly(public, ShaPublicCol::PAOut, row, field_cfg)?), + &s_out, + ); + let r8 = scale_poly( + &(e.clone() - &public_word_or_const_poly(public, ShaPublicCol::PEIn, row, field_cfg)?), + &s_init, + ) + &scale_poly( + &(e.clone() - &public_word_or_const_poly(public, ShaPublicCol::PEOut, row, field_cfg)?), + &s_out, + ); + + let r9 = a_shift4 - &a - &public_const_poly(public, ShaPublicCol::PAIn, row, field_cfg)? + + &mu_ff_a + + &comp_ff_a; + let r10 = e_shift4 - &e - &public_const_poly(public, ShaPublicCol::PEIn, row, field_cfg)? 
+        + &mu_ff_e
+        + &comp_ff_e;
+    let r11 = scale_poly(
+        &(w - &public_word_or_const_poly(public, ShaPublicCol::Message, row, field_cfg)?),
+        &s_msg,
+    );
+
+    let r12 = scale_poly(&comp_schedule, &s_sched);
+    let r13 = scale_poly(&comp_update_a, &s_upd);
+    let r14 = scale_poly(&comp_update_e, &s_upd);
+    let r15 = scale_poly(&comp_ff_a, &s_ff);
+    let r16 = scale_poly(&comp_ff_e, &s_ff);
+    let r17 = mu_shift10;
+
+    let mut residuals = [
+        r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17,
+    ];
+    // Sums and differences may leave trailing zero coefficients; trim each result.
+    residuals.iter_mut().for_each(DynamicPolynomialF::trim);
+    debug_assert_eq!(residuals.len(), NUM_SHA_RESIDUAL_FAMILIES);
+    Ok(residuals)
+}
+
+/// Evaluates every residual polynomial of `row` against the precomputed
+/// power table `a_powers`, returning one field value per family.
+///
+/// The polynomials come from `residual_polys_at_row`; each is evaluated with
+/// `evaluate_poly_at_powers_dmr`, whose errors are propagated.
+fn residual_values_at_row_with_powers(
+    trace: &ProjectedTrace,
+    public: &ProjectedPublic,
+    row: usize,
+    a_powers: &[F],
+    field_cfg: &F::Config,
+) -> Result<[F; NUM_SHA_RESIDUAL_FAMILIES], ShaProjectionError>
+where
+    F: DelayedFieldProductSum,
+{
+    let polies = residual_polys_at_row(trace, public, row, field_cfg)?;
+    // Start from zeros, then overwrite slot by slot so `?` can be used in the loop.
+    let mut out: [F; NUM_SHA_RESIDUAL_FAMILIES] =
+        std::array::from_fn(|_| F::zero_with_cfg(field_cfg));
+    for (idx, poly) in polies.iter().enumerate() {
+        out[idx] = evaluate_poly_at_powers_dmr(poly, a_powers, field_cfg)?;
+    }
+    Ok(out)
+}
+
+/// Checks that all four tables of `trace` have the expected number of columns
+/// and that every column has the SHA row dimensions (see `validate_table`).
+fn validate_trace(trace: &ProjectedTrace) -> Result<(), ShaProjectionError> {
+    validate_table(
+        "bit_slices",
+        &trace.bit_slices,
+        ShaWordCol::COUNT * SHA_WORD_BITS,
+    )?;
+    validate_table("scalarized", &trace.scalarized, ShaWordCol::COUNT)?;
+    validate_table("int_columns", &trace.int_columns, ShaIntCol::COUNT)?;
+    validate_table("public_columns", &trace.public_columns, ShaPublicCol::COUNT)
+}
+
+/// Public entry point for the structural checks performed by `validate_trace`.
+pub fn validate_projected_trace(trace: &ProjectedTrace) -> Result<(), ShaProjectionError> {
+    validate_trace(trace)
+}
+
+/// Checks the column count and row dimensions of `public.columns` and, when
+/// present, of `public.bit_slices`.
+fn validate_public(public: &ProjectedPublic) -> Result<(), ShaProjectionError> {
+    validate_table("public.columns", &public.columns, ShaPublicCol::COUNT)?;
+    if let Some(bit_slices) = &public.bit_slices {
+        validate_table(
+
"public.bit_slices", + bit_slices, + ShaPublicWordCol::COUNT * SHA_WORD_BITS, + )?; + } + Ok(()) +} + +fn validate_table( + kind: &'static str, + columns: &MleTable, + expected_cols: usize, +) -> Result<(), ShaProjectionError> { + if columns.len() != expected_cols { + return Err(ShaProjectionError::MissingColumn { + kind, + col: columns.len(), + }); + } + for (col, values) in columns.iter().enumerate() { + if values.num_vars != SHA_ROW_VARS || values.evaluations.len() != SHA_ROW_COUNT { + return Err(ShaProjectionError::ColumnRowCount { + kind, + col, + got: values.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + } + Ok(()) +} + +fn word_poly( + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + if row >= SHA_ROW_COUNT { + return Ok(DynamicPolynomialF::ZERO); + } + let col_idx = col.index(); + let mut coeffs = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + coeffs.push(scalar_from_table( + "bit_slices", + &trace.bit_slices, + bit_slice_index(col_idx, bit, SHA_WORD_BITS), + row, + field_cfg, + )?); + } + coeffs.resize(SHA_WORD_BITS, F::zero_with_cfg(field_cfg)); + Ok(DynamicPolynomialF::new_trimmed(coeffs)) +} + +fn word_poly_shifted( + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + shift: usize, + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField, +{ + match row.checked_add(shift) { + Some(shifted) if shifted < SHA_ROW_COUNT => word_poly(trace, col, shifted, field_cfg), + _ => Ok(DynamicPolynomialF::ZERO), + } +} + +fn bit_at_shifted_or_zero( + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + shift: usize, + bit: usize, + field_cfg: &F::Config, +) -> Result +where + F: PrimeField, +{ + if bit >= SHA_WORD_BITS { + return Err(ShaProjectionError::BitIndexOutOfRange { bit }); + } + let Some(shifted) = row.checked_add(shift) else { + return Ok(F::zero_with_cfg(field_cfg)); + }; + if shifted >= 
SHA_ROW_COUNT {
+        return Ok(F::zero_with_cfg(field_cfg));
+    }
+    let col_idx = col.index();
+    scalar_from_table(
+        "bit_slices",
+        &trace.bit_slices,
+        bit_slice_index(col_idx, bit, SHA_WORD_BITS),
+        shifted,
+        field_cfg,
+    )
+}
+
+/// Infallible variant of `bit_at_shifted_or_zero`.
+///
+/// Still yields zero when `row + shift` overflows or is not below
+/// `SHA_ROW_COUNT`, but the bit index and table shape are only checked with
+/// `debug_assert!`; the value is read by direct indexing.
+// NOTE(review): callers are presumably expected to have validated the trace
+// first — confirm.
+fn bit_at_shifted_or_zero_fast(
+    trace: &ProjectedTrace,
+    col: ShaWordCol,
+    row: usize,
+    shift: usize,
+    bit: usize,
+    field_cfg: &F::Config,
+) -> F
+where
+    F: PrimeField,
+{
+    debug_assert!(bit < SHA_WORD_BITS);
+    let Some(shifted) = row.checked_add(shift) else {
+        return F::zero_with_cfg(field_cfg);
+    };
+    if shifted >= SHA_ROW_COUNT {
+        return F::zero_with_cfg(field_cfg);
+    }
+    let table_idx = bit_slice_index(col.index(), bit, SHA_WORD_BITS);
+    debug_assert!(table_idx < trace.bit_slices.len());
+    debug_assert!(shifted < trace.bit_slices[table_idx].evaluations.len());
+    trace.bit_slices[table_idx].evaluations[shifted].clone()
+}
+
+/// Returns the integer column `col` at `row` (read with `int_scalar`) wrapped
+/// as a constant polynomial.
+fn int_const_poly(
+    trace: &ProjectedTrace,
+    col: ShaIntCol,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: PrimeField,
+{
+    Ok(const_poly(
+        int_scalar(trace, col, row, field_cfg)?,
+        field_cfg,
+    ))
+}
+
+/// Returns the public column `col` at `row` (read with `public_scalar`)
+/// wrapped as a constant polynomial.
+fn public_const_poly(
+    public: &ProjectedPublic,
+    col: ShaPublicCol,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: PrimeField,
+{
+    Ok(const_poly(
+        public_scalar(public, col, row, field_cfg)?,
+        field_cfg,
+    ))
+}
+
+/// Returns the public column `col` at `row` as a word polynomial when a bit
+/// decomposition is available, and as a constant polynomial otherwise.
+///
+/// The word form is used only if `col` has a public word column and
+/// `public.bit_slices` is present; in that case rows `>= SHA_ROW_COUNT` yield
+/// the zero polynomial.
+fn public_word_or_const_poly(
+    public: &ProjectedPublic,
+    col: ShaPublicCol,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: PrimeField,
+{
+    let Some(word_col) = col.public_word_col() else {
+        return public_const_poly(public, col, row, field_cfg);
+    };
+    let Some(bit_slices) = &public.bit_slices else {
+        return public_const_poly(public, col, row, field_cfg);
+    };
+    if row >= SHA_ROW_COUNT {
+        return Ok(DynamicPolynomialF::ZERO);
+    }
+    let col_idx = word_col.index();
+    let mut bits = Vec::with_capacity(SHA_WORD_BITS);
+    for bit in 0..SHA_WORD_BITS {
+
bits.push(scalar_from_table(
+            "public.bit_slices",
+            bit_slices,
+            bit_slice_index(col_idx, bit, SHA_WORD_BITS),
+            row,
+            field_cfg,
+        )?);
+    }
+    Ok(DynamicPolynomialF::new_trimmed(bits))
+}
+
+/// Returns the integer column `col` of `trace` at `row`, or zero for rows
+/// `>= SHA_ROW_COUNT`.
+fn int_scalar(
+    trace: &ProjectedTrace,
+    col: ShaIntCol,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    if row >= SHA_ROW_COUNT {
+        return Ok(F::zero_with_cfg(field_cfg));
+    }
+    scalar_from_table(
+        "int_columns",
+        &trace.int_columns,
+        col.index(),
+        row,
+        field_cfg,
+    )
+}
+
+/// Returns the public column `col` at `row`, or zero for rows
+/// `>= SHA_ROW_COUNT`.
+fn public_scalar(
+    public: &ProjectedPublic,
+    col: ShaPublicCol,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    if row >= SHA_ROW_COUNT {
+        return Ok(F::zero_with_cfg(field_cfg));
+    }
+    scalar_from_table(
+        "public.columns",
+        &public.columns,
+        col.index(),
+        row,
+        field_cfg,
+    )
+}
+
+/// Reads entry `row` of column `col` from `columns`.
+///
+/// `kind` is only used to label errors.
+///
+/// # Errors
+///
+/// * `MissingColumn` when `col` is not a valid column index.
+/// * `ColumnRowCount` when the column does not have `SHA_ROW_VARS` variables
+///   and `SHA_ROW_COUNT` evaluations.
+fn scalar_from_table(
+    kind: &'static str,
+    columns: &MleTable,
+    col: usize,
+    row: usize,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    let values = columns
+        .get(col)
+        .ok_or(ShaProjectionError::MissingColumn { kind, col })?;
+    if values.num_vars != SHA_ROW_VARS || values.evaluations.len() != SHA_ROW_COUNT {
+        return Err(ShaProjectionError::ColumnRowCount {
+            kind,
+            col,
+            got: values.evaluations.len(),
+            expected: SHA_ROW_COUNT,
+        });
+    }
+    // The length was just checked, so the zero fallback is only taken for
+    // `row >= SHA_ROW_COUNT`.
+    Ok(values
+        .evaluations
+        .get(row)
+        .cloned()
+        .unwrap_or_else(|| F::zero_with_cfg(field_cfg)))
+}
+
+/// Wraps `value` as a (trimmed) constant polynomial. `_field_cfg` is unused.
+fn const_poly(value: F, _field_cfg: &F::Config) -> DynamicPolynomialF {
+    DynamicPolynomialF::new_trimmed([value])
+}
+
+/// Returns the polynomial with coefficient one at every index in `indices`
+/// and zero elsewhere, over `SHA_WORD_BITS` coefficients.
+///
+/// # Panics
+///
+/// Panics if any index is `>= SHA_WORD_BITS`.
+fn sparse_poly(indices: &[usize], field_cfg: &F::Config) -> DynamicPolynomialF {
+    let mut coeffs = vec![F::zero_with_cfg(field_cfg); SHA_WORD_BITS];
+    for &idx in indices {
+        coeffs[idx] = F::one_with_cfg(field_cfg);
+    }
+    DynamicPolynomialF::new_trimmed(coeffs)
+}
+
+/// Returns `poly * scalar`, short-circuiting to the zero polynomial when
+/// either operand is zero.
+fn scale_poly(poly: &DynamicPolynomialF, scalar: &F) -> DynamicPolynomialF {
+    if poly.is_zero() || F::is_zero(scalar) {
+        return DynamicPolynomialF::ZERO;
+    }
+
DynamicPolynomialF::new_trimmed(
+        poly.coeffs
+            .iter()
+            .map(|coeff| coeff.clone() * scalar)
+            .collect::>(),
+    )
+}
+
+/// Performs `acc += poly * scalar` in place.
+///
+/// Does nothing when `poly` or `scalar` is zero. `acc` is extended with zeros
+/// when it is shorter than `poly`; the result is not trimmed afterwards.
+fn add_scaled_poly_assign(
+    acc: &mut DynamicPolynomialF,
+    poly: &DynamicPolynomialF,
+    scalar: &F,
+) {
+    if poly.is_zero() || F::is_zero(scalar) {
+        return;
+    }
+    if acc.coeffs.len() < poly.coeffs.len() {
+        acc.coeffs
+            .resize_with(poly.coeffs.len(), || F::zero_with_cfg(scalar.cfg()));
+    }
+    for (dst, coeff) in acc.coeffs.iter_mut().zip(&poly.coeffs) {
+        *dst += coeff.clone() * scalar;
+    }
+}
+
+/// Returns `2^exp` in the field, computed with `exp` successive
+/// multiplications by two.
+fn pow_two(exp: usize, field_cfg: &F::Config) -> F {
+    let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg);
+    let mut out = F::one_with_cfg(field_cfg);
+    for _ in 0..exp {
+        out *= &two;
+    }
+    out
+}
+
+/// Returns the inner product of `poly`'s coefficients with the leading
+/// entries of `powers`.
+///
+/// An empty polynomial evaluates to zero.
+///
+/// # Errors
+///
+/// `NonCanonicalProofObject` when `poly` has more coefficients than there are
+/// entries in `powers`; inner-product errors are converted with `From`.
+// NOTE(review): `powers[i]` is presumably the i-th power of the evaluation
+// point — confirm against the caller that builds the table.
+fn evaluate_poly_at_powers_dmr(
+    poly: &DynamicPolynomialF,
+    powers: &[F],
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: DelayedFieldProductSum,
+{
+    if poly.coeffs.is_empty() {
+        return Ok(F::zero_with_cfg(field_cfg));
+    }
+    if poly.coeffs.len() > powers.len() {
+        return Err(ShaProjectionError::NonCanonicalProofObject(
+            "SHA polynomial exceeds precomputed scalarization power bound",
+        ));
+    }
+    DynamicPolyFInnerProduct::inner_product::(
+        &poly.coeffs,
+        &powers[..poly.coeffs.len()],
+        F::zero_with_cfg(field_cfg),
+    )
+    .map_err(ShaProjectionError::from)
+}
+
+/// Returns the inner product of `bits` with the leading entries of `powers`.
+///
+/// # Errors
+///
+/// `NonCanonicalProofObject` when `bits` is longer than `powers`;
+/// inner-product errors are converted with `From`.
+fn project_bits_dmr(
+    bits: &[F],
+    powers: &[F],
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: DelayedFieldProductSum,
+{
+    if bits.len() > powers.len() {
+        return Err(ShaProjectionError::NonCanonicalProofObject(
+            "SHA bit projection exceeds precomputed scalarization power bound",
+        ));
+    }
+    FieldFieldInnerProduct::inner_product::(
+        bits,
+        &powers[..bits.len()],
+        F::zero_with_cfg(field_cfg),
+    )
+    .map_err(ShaProjectionError::from)
+}
+
+/// Projects `bits` onto `powers` without multiplications when every entry is
+/// zero or one.
+///
+/// Powers whose bit equals one are summed through `reducer`'s delayed
+/// reduction; zero bits are skipped. On the first entry that is neither zero
+/// nor one the whole computation is delegated to `project_bits_dmr`.
+///
+/// # Errors
+///
+/// `NonCanonicalProofObject` when `bits` is longer than `powers`.
+fn project_binary_bits_conditional_add_dmr(
+    bits: &[F],
+    powers: &[F],
+    field_cfg: &F::Config,
+    reducer: &BarrettDelayedReduction<'_, F>,
+) -> Result
+where
+    F: MontgomeryLimbs + DelayedFieldProductSum + Send + Sync,
+{
+    if bits.len() > powers.len() {
+        return Err(ShaProjectionError::NonCanonicalProofObject(
+            "SHA binary bit projection exceeds precomputed scalarization power bound",
+        ));
+    }
+    let one = F::one_with_cfg(field_cfg);
+    // Unreduced sum of selected powers, flushed into `acc` periodically.
+    let mut bucket = Uint::<5>::zero();
+    let mut pending_adds = 0usize;
+    let mut acc = F::zero_with_cfg(field_cfg);
+
+    for (bit, power) in bits.iter().zip(powers.iter()) {
+        if F::is_zero(bit) {
+            continue;
+        }
+        // Not a 0/1 value: discard the partial sum and use the general path.
+        if bit != &one {
+            return project_bits_dmr(bits, powers, field_cfg);
+        }
+
+        reducer.add(&mut bucket, power);
+        pending_adds = pending_adds.saturating_add(1);
+        // Reduce once the reducer's flush threshold is reached.
+        if pending_adds >= reducer.flush_adds() {
+            let pending = std::mem::replace(&mut bucket, Uint::zero());
+            acc += reducer.reduce(pending);
+            pending_adds = 0;
+        }
+    }
+
+    // Fold in whatever is left in the bucket.
+    if !bucket.is_zero() {
+        acc += reducer.reduce(bucket);
+    }
+    Ok(acc)
+}
+
+/// Builds a `SHA_WORD_BITS`-long array by calling `f(bit)` for each bit index
+/// in order, stopping at the first error.
+fn build_virtual_bit_array(mut f: G) -> Result<[F; SHA_WORD_BITS], ShaProjectionError>
+where
+    G: FnMut(usize) -> Result,
+{
+    let mut values = Vec::with_capacity(SHA_WORD_BITS);
+    for bit in 0..SHA_WORD_BITS {
+        values.push(f(bit)?);
+    }
+    // The vector holds exactly `SHA_WORD_BITS` items here, so the conversion
+    // to a fixed-size array cannot fail.
+    Ok(values
+        .try_into()
+        .unwrap_or_else(|_| unreachable!("exactly 32 virtual bits were built")))
+}
+
+/// Returns the scalarized value of word column `col` at row `row + shift`, or
+/// zero when that sum overflows or is not below `SHA_ROW_COUNT`.
+fn scalarized_word_at_shifted_or_zero(
+    trace: &ProjectedTrace,
+    col: ShaWordCol,
+    row: usize,
+    shift: usize,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    let Some(shifted) = row.checked_add(shift) else {
+        return Ok(F::zero_with_cfg(field_cfg));
+    };
+    if shifted >= SHA_ROW_COUNT {
+        return Ok(F::zero_with_cfg(field_cfg));
+    }
+    let col_idx = col.index();
+    scalar_from_table("scalarized", &trace.scalarized, col_idx, shifted, field_cfg)
+}
+
+/// Returns the value of one booleanity source at `row`.
+///
+/// For the virtual Ch1/Ch2/Maj sources the row's virtual words are first
+/// rebuilt with `reconstruct_virtual_ch_maj_at_row`; for other sources that
+/// work is skipped. The lookup itself is done by
+/// `booleanity_source_value_at_row_with_virtuals`.
+fn booleanity_source_value_at_row(
+    trace: &ProjectedTrace,
+    row: usize,
+    source: &ShaBooleanitySource,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    let virtuals = if matches!(
+        source,
+        ShaBooleanitySource::VirtualCh1 { .. }
+            | ShaBooleanitySource::VirtualCh2 { .. }
+            | ShaBooleanitySource::VirtualMaj { ..
}
+    ) {
+        Some(reconstruct_virtual_ch_maj_at_row(trace, row, field_cfg)?)
+    } else {
+        None
+    };
+    booleanity_source_value_at_row_with_virtuals(trace, row, source, virtuals.as_ref(), field_cfg)
+}
+
+/// Returns the value of one booleanity source at `row`, taking virtual
+/// Ch1/Ch2/Maj bits from the supplied per-row cache.
+///
+/// `WordBit` sources are read from the trace and ignore `virtuals`.
+///
+/// # Panics
+///
+/// Panics when `source` is a virtual source and `virtuals` is `None`.
+///
+/// # Errors
+///
+/// `BitIndexOutOfRange` when the requested bit is outside the word, plus any
+/// lookup error from `bit_at_shifted_or_zero`.
+fn booleanity_source_value_at_row_with_virtuals(
+    trace: &ProjectedTrace,
+    row: usize,
+    source: &ShaBooleanitySource,
+    virtuals: Option<&VirtualChMajValues>,
+    field_cfg: &F::Config,
+) -> Result
+where
+    F: PrimeField,
+{
+    match source {
+        ShaBooleanitySource::WordBit { col, bit } => {
+            bit_at_shifted_or_zero(trace, *col, row, 0, *bit, field_cfg)
+        }
+        ShaBooleanitySource::VirtualCh1 { bit } => {
+            virtual_bit_at(&virtuals.expect("virtual source needs row cache").ch1, *bit)
+        }
+        ShaBooleanitySource::VirtualCh2 { bit } => {
+            virtual_bit_at(&virtuals.expect("virtual source needs row cache").ch2, *bit)
+        }
+        ShaBooleanitySource::VirtualMaj { bit } => {
+            virtual_bit_at(&virtuals.expect("virtual source needs row cache").maj, *bit)
+        }
+    }
+}
+
+/// Returns a clone of `bits[bit]`, or `BitIndexOutOfRange` when `bit` is not a
+/// valid index.
+fn virtual_bit_at(
+    bits: &[F; SHA_WORD_BITS],
+    bit: usize,
+) -> Result {
+    bits.get(bit)
+        .cloned()
+        .ok_or(ShaProjectionError::BitIndexOutOfRange { bit })
+}
+
+/// Collects `tables` and folds them with the weights `theta` using the
+/// field's own `ShaBinaryFoldField::fold_binary_mle_tables` implementation.
+fn fold_binary_mle_tables<'a, F, I>(
+    kind: &'static str,
+    tables: I,
+    theta: &[F],
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: ShaBinaryFoldField + 'a,
+    I: IntoIterator>,
+{
+    let tables = tables.into_iter().collect::>();
+    F::fold_binary_mle_tables(kind, &tables, theta, field_cfg)
+}
+
+/// Folding path for any prime field: forwards to `fold_mle_tables`.
+fn fold_binary_mle_tables_generic(
+    kind: &'static str,
+    tables: &[&MleTable],
+    theta: &[F],
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: PrimeField,
+{
+    fold_mle_tables(kind, tables.iter().copied(), theta, field_cfg)
+}
+
+/// Folding path for Montgomery-limb fields.
+///
+/// Produces, for every column and row, the `theta`-weighted combination of
+/// that entry across `tables`, computed per entry by
+/// `fold_binary_row_values_montgomery_dmr`. An empty `tables` slice yields an
+/// empty result.
+///
+/// # Errors
+///
+/// * `FoldingWeightCount` when `tables` and `theta` differ in length.
+/// * `InstanceCountMismatch` when the tables differ in column count.
+/// * `ColumnRowCount` when a column's dimensions differ from the first
+///   table's column at the same index.
+fn fold_binary_mle_tables_montgomery(
+    kind: &'static str,
+    tables: &[&MleTable],
+    theta: &[F],
+    field_cfg: &F::Config,
+) -> Result, ShaProjectionError>
+where
+    F: PrimeField + MontgomeryLimbs + Send + Sync,
+{
+    if tables.len() != theta.len() {
+        return
Err(ShaProjectionError::FoldingWeightCount { + got: theta.len(), + expected: tables.len(), + }); + } + let Some(&first) = tables.first() else { + return Ok(Vec::new()); + }; + for table in tables { + if table.len() != first.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: table.len(), + expected: first.len(), + }); + } + } + + let one = F::one_with_cfg(field_cfg); + let reducer = BarrettDelayedReduction::::new(field_cfg); + cfg_into_iter!(0..first.len()) + .map(|col_idx| { + let template = &first[col_idx]; + let mut evaluations = vec![F::zero_with_cfg(field_cfg); template.evaluations.len()]; + for table in tables { + let col = &table[col_idx]; + if col.num_vars != template.num_vars + || col.evaluations.len() != template.evaluations.len() + { + return Err(ShaProjectionError::ColumnRowCount { + kind, + col: col_idx, + got: col.evaluations.len(), + expected: template.evaluations.len(), + }); + } + } + for (row, out) in evaluations.iter_mut().enumerate() { + *out = fold_binary_row_values_montgomery_dmr( + tables, theta, col_idx, row, &one, field_cfg, &reducer, + ); + } + Ok(DenseMultilinearExtension { + evaluations, + num_vars: template.num_vars, + }) + }) + .collect::, ShaProjectionError>>() +} + +fn fold_binary_row_values_montgomery_dmr( + tables: &[&MleTable], + theta: &[F], + col_idx: usize, + row: usize, + one: &F, + field_cfg: &F::Config, + reducer: &BarrettDelayedReduction<'_, F>, +) -> F +where + F: PrimeField + MontgomeryLimbs + Send + Sync, +{ + let mut bucket = Uint::<5>::zero(); + let mut pending_adds = 0usize; + let mut acc = F::zero_with_cfg(field_cfg); + + for (table, weight) in tables.iter().zip(theta) { + let value = &table[col_idx].evaluations[row]; + if F::is_zero(value) { + continue; + } + if value != one { + return fold_row_values_naive(tables, theta, col_idx, row, field_cfg); + } + reducer.add(&mut bucket, weight); + pending_adds = pending_adds.saturating_add(1); + if pending_adds >= reducer.flush_adds() { + let pending = 
std::mem::replace(&mut bucket, Uint::zero()); + acc += reducer.reduce(pending); + pending_adds = 0; + } + } + + if !bucket.is_zero() { + acc += reducer.reduce(bucket); + } + acc +} + +fn fold_row_values_naive( + tables: &[&MleTable], + theta: &[F], + col_idx: usize, + row: usize, + field_cfg: &F::Config, +) -> F +where + F: PrimeField, +{ + let mut acc = F::zero_with_cfg(field_cfg); + for (table, weight) in tables.iter().zip(theta) { + acc += weight.clone() * &table[col_idx].evaluations[row]; + } + acc +} + +fn fold_optional_binary_mle_tables<'a, F, I>( + kind: &'static str, + tables: I, + theta: &[F], + field_cfg: &F::Config, +) -> Result>, ShaProjectionError> +where + F: ShaBinaryFoldField + 'a, + I: IntoIterator>>, +{ + let tables = tables.into_iter().collect::>(); + if tables.iter().all(Option::is_none) { + return Ok(None); + } + let mut present = Vec::with_capacity(tables.len()); + for table in tables { + let Some(table) = table else { + return Err(ShaProjectionError::PublicWordColumnPresenceMismatch); + }; + present.push(table); + } + fold_binary_mle_tables(kind, present, theta, field_cfg).map(Some) +} + +fn fold_mle_tables<'a, F, I>( + kind: &'static str, + tables: I, + theta: &[F], + field_cfg: &F::Config, +) -> Result, ShaProjectionError> +where + F: PrimeField + 'a, + I: IntoIterator>, +{ + let tables = tables.into_iter().collect::>(); + if tables.len() != theta.len() { + return Err(ShaProjectionError::FoldingWeightCount { + got: theta.len(), + expected: tables.len(), + }); + } + let Some(first) = tables.first() else { + return Ok(Vec::new()); + }; + for table in &tables { + if table.len() != first.len() { + return Err(ShaProjectionError::InstanceCountMismatch { + got: table.len(), + expected: first.len(), + }); + } + } + cfg_into_iter!(0..first.len()) + .map(|col_idx| { + let template = &first[col_idx]; + let mut evaluations = vec![F::zero_with_cfg(field_cfg); template.evaluations.len()]; + for (table, weight) in tables.iter().zip(theta) { + let col = 
&table[col_idx]; + if col.num_vars != template.num_vars + || col.evaluations.len() != template.evaluations.len() + { + return Err(ShaProjectionError::ColumnRowCount { + kind, + col: col_idx, + got: col.evaluations.len(), + expected: template.evaluations.len(), + }); + } + for (out, value) in evaluations.iter_mut().zip(&col.evaluations) { + *out += weight.clone() * value; + } + } + Ok(DenseMultilinearExtension { + evaluations, + num_vars: template.num_vars, + }) + }) + .collect::, ShaProjectionError>>() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sumcheck::multi_degree::MultiDegreeSumcheck; + use crate::test_utils::test_config; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use zinc_poly::EvaluatablePolynomial; + use zinc_transcript::Blake3Transcript; + + type F = MontyField<4>; + + fn f(value: u64) -> F { + F::from_with_cfg(value, &test_config()) + } + + fn zero_table(cols: usize) -> MleTable { + let cfg = test_config(); + mle_table_from_columns( + vec![vec![F::zero_with_cfg(&cfg); SHA_ROW_COUNT]; cols], + SHA_ROW_VARS, + ) + } + + fn set_word_bit( + trace: &mut ProjectedTrace, + col: ShaWordCol, + row: usize, + bit: usize, + value: F, + ) { + let idx = bit_slice_index(col.index(), bit, SHA_WORD_BITS); + trace.bit_slices[idx].evaluations[row] = value; + } + + fn word_bit(trace: &ProjectedTrace, col: ShaWordCol, row: usize, bit: usize) -> &F { + &trace.bit_slices[bit_slice_index(col.index(), bit, SHA_WORD_BITS)].evaluations[row] + } + + fn zero_trace() -> ProjectedTrace { + let cfg = test_config(); + let zero = F::zero_with_cfg(&cfg); + let bits = vec![vec![vec![zero.clone(); SHA_WORD_BITS]; SHA_ROW_COUNT]; ShaWordCol::COUNT]; + let bit_slices = + flatten_bit_columns(bits, SHA_WORD_BITS, SHA_ROW_VARS, "bit_slices").unwrap(); + let scalarized = scalarize_bit_slices(&bit_slices, &f(5), &cfg).unwrap(); + ProjectedTrace { + bit_slices, + scalarized, + int_columns: zero_table(ShaIntCol::COUNT), + public_columns: 
zero_table(ShaPublicCol::COUNT), + } + } + + fn zero_public() -> ProjectedPublic { + ProjectedPublic { + columns: zero_table(ShaPublicCol::COUNT), + bit_slices: None, + } + } + + fn synthetic_boolean_trace(instance_idx: u64, a: &F) -> ProjectedTrace { + let cfg = test_config(); + let zero = F::zero_with_cfg(&cfg); + let mut bits = + vec![vec![vec![zero.clone(); SHA_WORD_BITS]; SHA_ROW_COUNT]; ShaWordCol::COUNT]; + for (col_idx, col) in bits.iter_mut().enumerate() { + for (row_idx, row) in col.iter_mut().enumerate() { + for (bit_idx, bit) in row.iter_mut().enumerate() { + let selector = instance_idx + + u64::try_from(col_idx * 17 + row_idx * 3 + bit_idx) + .expect("test selector fits u64"); + if selector % 2 == 1 { + *bit = f(1); + } + } + } + } + let bit_slices = + flatten_bit_columns(bits, SHA_WORD_BITS, SHA_ROW_VARS, "bit_slices").unwrap(); + let scalarized = scalarize_bit_slices(&bit_slices, a, &cfg).unwrap(); + ProjectedTrace { + bit_slices, + scalarized, + int_columns: zero_table(ShaIntCol::COUNT), + public_columns: zero_table(ShaPublicCol::COUNT), + } + } + + fn prove_and_verify_sumfold( + group: MultiDegreeSumcheckGroup, + ell: usize, + ) -> ( + crate::sumcheck::multi_degree::MultiDegreeSumcheckProof, + Vec, + Vec, + ) { + let cfg = test_config(); + let mut prover_transcript = Blake3Transcript::new(); + let (proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![group], + ell, + &cfg, + ); + + let mut verifier_transcript = Blake3Transcript::new(); + let subclaims = + MultiDegreeSumcheck::verify_as_subprotocol(&mut verifier_transcript, ell, &proof, &cfg) + .expect("sumcheck proof should verify"); + + ( + proof, + subclaims.point().to_vec(), + subclaims.expected_evaluations().to_vec(), + ) + } + + fn naive_project_bits(bits: &[F], powers: &[F]) -> F { + bits.iter() + .zip(powers.iter()) + .fold(F::zero_with_cfg(&test_config()), |acc, (bit, power)| { + acc + bit.clone() * power + }) + } + + #[test] + fn 
dmr_bit_projection_matches_naive_for_binary_and_field_bits() { + let cfg = test_config(); + let a = f(7); + let powers = powers(a, F::one_with_cfg(&cfg), SHA_WORD_BITS); + let zero = F::zero_with_cfg(&cfg); + let mut binary_bits = vec![zero.clone(); SHA_WORD_BITS]; + binary_bits[0] = f(1); + binary_bits[5] = f(1); + binary_bits[31] = f(1); + + let binary_expected = naive_project_bits(&binary_bits, &powers); + let reducer = BarrettDelayedReduction::::new(&cfg); + assert_eq!( + project_binary_bits_conditional_add_dmr(&binary_bits, &powers, &cfg, &reducer).unwrap(), + binary_expected + ); + assert_eq!( + project_bits_dmr(&binary_bits, &powers, &cfg).unwrap(), + binary_expected + ); + + let mut field_bits = vec![zero; SHA_WORD_BITS]; + field_bits[3] = f(2); + field_bits[9] = f(11); + let field_expected = naive_project_bits(&field_bits, &powers); + assert_eq!( + project_binary_bits_conditional_add_dmr(&field_bits, &powers, &cfg, &reducer).unwrap(), + field_expected + ); + assert_eq!( + project_bits_dmr(&field_bits, &powers, &cfg).unwrap(), + field_expected + ); + } + + #[test] + fn dmr_residual_evaluation_matches_polynomial_evaluation() { + let cfg = test_config(); + let a = f(5); + let trace = synthetic_boolean_trace(3, &a); + let public = zero_public(); + let row = 17usize; + let a_powers = powers( + a.clone(), + F::one_with_cfg(&cfg), + SHA_RESIDUAL_EVAL_POWER_COUNT, + ); + let residuals = + residual_values_at_row_with_powers(&trace, &public, row, &a_powers, &cfg).unwrap(); + let polies = residual_polys_at_row(&trace, &public, row, &cfg).unwrap(); + + for (value, poly) in residuals.iter().zip(polies.iter()) { + assert_eq!(value, &poly.evaluate_at_point(&a).unwrap()); + } + } + + #[test] + fn fixed_residual_coeff_table_matches_dynamic_reference() { + let cfg = test_config(); + let a = f(5); + let mut trace = synthetic_boolean_trace(11, &a); + for (col_idx, column) in trace.int_columns.iter_mut().enumerate() { + for (row_idx, value) in 
column.evaluations.iter_mut().enumerate() { + *value = f(u64::try_from((col_idx + 3) * (row_idx % 17 + 1)).unwrap()); + } + } + + let mut public = zero_public(); + for (col_idx, column) in public.columns.iter_mut().enumerate() { + for (row_idx, value) in column.evaluations.iter_mut().enumerate() { + *value = f(u64::try_from((col_idx + 5) * (row_idx % 19 + 1)).unwrap()); + } + } + + let zero = F::zero_with_cfg(&cfg); + let mut public_bits = + vec![vec![vec![zero; SHA_WORD_BITS]; SHA_ROW_COUNT]; ShaPublicWordCol::COUNT]; + for (col_idx, col) in public_bits.iter_mut().enumerate() { + for (row_idx, row) in col.iter_mut().enumerate() { + for (bit_idx, bit) in row.iter_mut().enumerate() { + if (col_idx + row_idx + bit_idx) % 3 == 0 { + *bit = f(1); + } + } + } + } + public.bit_slices = + Some(flatten_bit_columns(public_bits, SHA_WORD_BITS, SHA_ROW_VARS, "public").unwrap()); + + let row_weights = (0..SHA_ROW_COUNT) + .map(|row| f(u64::try_from(row % 23 + 1).unwrap())) + .collect::>(); + let fixed = build_linear_residual_coeff_tables_with_row_weights( + &[trace.clone()], + &[public.clone()], + &row_weights, + &cfg, + ) + .unwrap(); + + let constants = ShaResidualPolyConstants::new(&cfg); + let mut expected = vec![DynamicPolynomialF::ZERO; NUM_SHA_RESIDUAL_FAMILIES]; + for (row, row_weight) in row_weights.iter().enumerate() { + let residuals = + residual_polys_at_row_with_constants(&trace, &public, row, &constants, &cfg) + .unwrap(); + for (family_idx, residual) in residuals.iter().enumerate() { + add_scaled_poly_assign(&mut expected[family_idx], residual, row_weight); + } + } + expected.iter_mut().for_each(DynamicPolynomialF::trim); + + assert_eq!(fixed[0].coeffs, expected); + } + + #[test] + fn dmr_fresh_sha_targets_match_reference_evaluation() { + let cfg = test_config(); + let a = f(13); + let lambda = f(17); + let mut cache = FreshIdealEvaluationCache { + r_ic: std::array::from_fn(|_| F::zero_with_cfg(&cfg)), + ideal_polys: vec![std::array::from_fn(|slot| { + 
DynamicPolynomialF::new_trimmed([ + f(u64::try_from(slot + 1).unwrap()), + f(u64::try_from(slot + 2).unwrap()), + f(u64::try_from(slot + 3).unwrap()), + ]) + })], + taus_at_a: Vec::new(), + fresh_targets: Vec::new(), + }; + + evaluate_fresh_sha_targets(&mut cache, &a, &lambda, &cfg).unwrap(); + + let lambda_powers = powers(lambda, F::one_with_cfg(&cfg), NUM_SHA_RESIDUAL_FAMILIES); + let mut expected_target = F::zero_with_cfg(&cfg); + for (slot, family) in NONZERO_SHA_FAMILIES.iter().enumerate() { + let expected_tau = cache.ideal_polys[0][slot].evaluate_at_point(&a).unwrap(); + assert_eq!(cache.taus_at_a[0][slot], expected_tau); + expected_target += lambda_powers[family.index()].clone() * expected_tau; + } + assert_eq!(cache.fresh_targets[0], expected_target); + } + + #[test] + fn zero_trace_ideal_cache_checks_and_targets_are_zero() { + let cfg = test_config(); + let trace = zero_trace(); + let public = zero_public(); + let mut r_ic = std::array::from_fn(|_| F::zero_with_cfg(&cfg)); + r_ic[0] = f(3); + r_ic[1] = f(7); + + let mut cache = + build_fresh_sha_ideal_cache(&[trace], &[public], r_ic, &cfg).expect("cache builds"); + check_fresh_sha_ideal_cache(&cache, &cfg).expect("zero ideals pass"); + evaluate_fresh_sha_targets(&mut cache, &f(11), &f(13), &cfg).unwrap(); + + assert_eq!(cache.ideal_polys.len(), 1); + for poly in &cache.ideal_polys[0] { + assert!(poly.is_zero()); + } + for tau in &cache.taus_at_a[0] { + assert_eq!(tau, &F::zero_with_cfg(&cfg)); + } + assert_eq!(cache.fresh_targets[0], F::zero_with_cfg(&cfg)); + } + + #[test] + fn beta_aggregate_with_weights_matches_wrapper() { + let cfg = test_config(); + let table = |offset: u64| LinearResidualCoeffTable { + coeffs: (0..NUM_SHA_RESIDUAL_FAMILIES) + .map(|idx| { + DynamicPolynomialF::new_trimmed([ + f(offset + idx as u64 + 1), + f(offset + idx as u64 + 101), + ]) + }) + .collect(), + }; + let tables = vec![table(0), table(1_000)]; + let beta = [f(17)]; + let beta_eq_weights = 
zinc_poly::utils::build_eq_x_r_vec(&beta, &cfg).unwrap(); + + let wrapped = beta_aggregate_nonzero_ideal_polys(&tables, &beta, &cfg).unwrap(); + let cached = + beta_aggregate_nonzero_ideal_polys_with_weights(&tables, &beta_eq_weights).unwrap(); + + assert_eq!(cached, wrapped); + } + + #[test] + fn tampered_ideal_cache_fails_membership() { + let cfg = test_config(); + let trace = zero_trace(); + let public = zero_public(); + let r_ic = std::array::from_fn(|_| F::zero_with_cfg(&cfg)); + let mut values = build_sha_ideal_values_at_point(&trace, &public, &r_ic, &cfg).unwrap(); + values[0] = DynamicPolynomialF::new_trimmed([f(1)]); + + assert!(matches!( + check_sha_ideal_values(&values, &cfg), + Err(ShaProjectionError::IdealMembership) + )); + } + + #[test] + fn fresh_sha_ideal_polys_are_verified_by_reusable_helper() { + let cfg = test_config(); + let valid_zero = vec![std::array::from_fn(|_| { + DynamicPolynomialF::new(Vec::::new()) + })]; + verify_fresh_sha_ideal_polys(&valid_zero, &cfg).expect("zero ideal set passes"); + + let mut tampered_x_minus_two = valid_zero.clone(); + tampered_x_minus_two[0][2] = DynamicPolynomialF::new_trimmed([f(1)]); + assert!(matches!( + verify_fresh_sha_ideal_polys(&tampered_x_minus_two, &cfg), + Err(ShaProjectionError::IdealMembership) + )); + + let mut trailing_zero = valid_zero.clone(); + trailing_zero[0][0] = DynamicPolynomialF::new(vec![f(1), F::zero_with_cfg(&cfg)]); + assert!(matches!( + verify_fresh_sha_ideal_polys(&trailing_zero, &cfg), + Err(ShaProjectionError::NonCanonicalProofObject(_)) + )); + + let mut high_degree = valid_zero; + high_degree[0][2] = DynamicPolynomialF::new(vec![f(1); 33]); + assert!(matches!( + verify_fresh_sha_ideal_polys(&high_degree, &cfg), + Err(ShaProjectionError::NonCanonicalProofObject(_)) + )); + } + + #[test] + fn scalarization_links_check_folded_words() { + let cfg = test_config(); + let mut trace = zero_trace(); + set_word_bit(&mut trace, ShaWordCol::A, 0, 0, f(1)); + set_word_bit(&mut trace, 
ShaWordCol::A, 0, 3, f(1)); + trace.scalarized = scalarize_bit_slices(&trace.bit_slices, &f(5), &cfg).unwrap(); + + verify_folded_scalarization_links(&trace, &f(5), &[ShaWordCol::A], &cfg) + .expect("scalarization should pass"); + + trace.scalarized[ShaWordCol::A.index()].evaluations[0] += f(1); + assert!(matches!( + verify_folded_scalarization_links(&trace, &f(5), &[ShaWordCol::A], &cfg), + Err(ShaProjectionError::ScalarizationMismatch { .. }) + )); + } + + #[test] + fn scalarization_links_check_endpoint_and_shifted_sources() { + let cfg = test_config(); + let mut trace = zero_trace(); + set_word_bit(&mut trace, ShaWordCol::A, 0, 1, f(1)); + set_word_bit(&mut trace, ShaWordCol::A, 1, 0, f(1)); + set_word_bit(&mut trace, ShaWordCol::A, 1, 2, f(1)); + trace.scalarized = scalarize_bit_slices(&trace.bit_slices, &f(3), &cfg).unwrap(); + let r_star = std::array::from_fn(|_| F::zero_with_cfg(&cfg)); + + verify_folded_scalarization_links_at_point(&trace, &f(3), &r_star, &[ShaWordCol::A], &cfg) + .expect("endpoint scalarization should pass"); + verify_folded_shifted_scalarization_link_at_point( + &trace, + &f(3), + &r_star, + ShaWordCol::A, + 1, + &cfg, + ) + .expect("shifted endpoint scalarization should pass"); + + trace.scalarized[ShaWordCol::A.index()].evaluations[0] += f(1); + assert!(matches!( + verify_folded_scalarization_links_at_point( + &trace, + &f(3), + &r_star, + &[ShaWordCol::A], + &cfg + ), + Err(ShaProjectionError::ScalarizationMismatch { .. 
}) + )); + } + + #[test] + fn instance_fold_claim_derives_weights_after_endpoint() { + let cfg = test_config(); + let beta = vec![f(2), f(3)]; + let r_b = vec![f(5), f(7)]; + let c_sf = f(11); + let out = derive_instance_fold_claim(&beta, r_b.clone(), c_sf.clone(), 4, &cfg).unwrap(); + let d = eq_eval(&beta, &r_b, F::one_with_cfg(&cfg)).unwrap(); + + assert_eq!(out.final_round_sumcheck_claim(), &(c_sf / d)); + assert_eq!( + out.eq_instance_weights(), + build_eq_x_r_vec(&r_b, &cfg).unwrap() + ); + } + + #[test] + fn folding_uses_eq_instance_weights() { + let cfg = test_config(); + let beta = vec![f(2)]; + let r_b = vec![f(3)]; + let out = derive_instance_fold_claim(&beta, r_b, f(9), 2, &cfg).unwrap(); + + let mut left = zero_trace(); + let mut right = zero_trace(); + set_word_bit(&mut left, ShaWordCol::A, 0, 0, f(1)); + set_word_bit(&mut right, ShaWordCol::A, 0, 0, f(2)); + left.scalarized = scalarize_bit_slices(&left.bit_slices, &f(5), &cfg).unwrap(); + right.scalarized = scalarize_bit_slices(&right.bit_slices, &f(5), &cfg).unwrap(); + + let (folded, _public) = fold_projected_traces( + &[left.clone(), right.clone()], + &[zero_public(), zero_public()], + &out, + &cfg, + ) + .unwrap(); + let expected = out.eq_instance_weights()[0].clone() * word_bit(&left, ShaWordCol::A, 0, 0) + + out.eq_instance_weights()[1].clone() * word_bit(&right, ShaWordCol::A, 0, 0); + assert_eq!(*word_bit(&folded.trace, ShaWordCol::A, 0, 0), expected); + } + + #[test] + fn virtual_ch_maj_reconstructs_from_source_bits() { + let cfg = test_config(); + let mut trace = zero_trace(); + set_word_bit(&mut trace, ShaWordCol::E, 2, 0, f(1)); + set_word_bit(&mut trace, ShaWordCol::E, 1, 0, f(1)); + set_word_bit(&mut trace, ShaWordCol::Uef, 2, 0, f(1)); + set_word_bit(&mut trace, ShaWordCol::A, 0, 1, f(1)); + set_word_bit(&mut trace, ShaWordCol::A, 1, 1, f(1)); + set_word_bit(&mut trace, ShaWordCol::A, 2, 1, f(1)); + set_word_bit(&mut trace, ShaWordCol::Maj, 2, 1, f(1)); + + let virtuals = 
reconstruct_virtual_ch_maj_at_row(&trace, 0, &cfg).unwrap(); + + assert_eq!(virtuals.ch1[0], F::zero_with_cfg(&cfg)); + assert_eq!(virtuals.maj[1], f(1)); + } + + #[test] + fn malformed_virtual_sources_return_errors() { + let cfg = test_config(); + let trace = zero_trace(); + assert!(matches!( + reconstruct_virtual_ch_maj_at_row(&trace, SHA_ROW_COUNT, &cfg), + Err(ShaProjectionError::RowIndexOutOfRange { .. }) + )); + + let public = zero_public(); + let r_ic = std::array::from_fn(|_| F::zero_with_cfg(&cfg)); + assert!(matches!( + folded_row_integrand_values( + &trace, + &public, + &r_ic, + &f(3), + &f(5), + &f(7), + &f(11), + &[ShaBooleanitySource::VirtualMaj { bit: SHA_WORD_BITS }], + &cfg, + ), + Err(ShaProjectionError::BitIndexOutOfRange { .. }) + )); + } + + #[test] + fn production_sha_sumfold_prefix_tail_matches_dense_sumcheck() { + let cfg = test_config(); + let ell = 3usize; + let a = f(3); + let traces = (0..(1usize << ell)) + .map(|idx| synthetic_boolean_trace(u64::try_from(idx).unwrap(), &a)) + .collect::>(); + let publics = vec![zero_public(); traces.len()]; + let beta = vec![f(5), f(7), f(11)]; + let r_ic = [f(2), f(3), f(5), f(7), f(11), f(13), f(17)]; + let lambda = f(19); + let rho = f(23); + let xi = f(29); + let sources = vec![ + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 0, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::E, + bit: 1, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::W, + bit: 2, + }, + ]; + + for prefix_vars in [1usize, 2, 3] { + let dense = build_dense_sha_sumfold_group( + &traces, &publics, &beta, &r_ic, &a, &lambda, &rho, &xi, &sources, &cfg, + ) + .unwrap(); + let optimized = build_production_sha_sumfold_group( + &traces, + &publics, + &beta, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &sources, + prefix_vars, + &cfg, + ) + .unwrap(); + + let (dense_proof, dense_point, dense_expected) = prove_and_verify_sumfold(dense, ell); + let (optimized_proof, optimized_point, optimized_expected) = + 
prove_and_verify_sumfold(optimized, ell); + + assert_eq!(optimized_proof, dense_proof); + assert_eq!(optimized_point, dense_point); + assert_eq!(optimized_expected, dense_expected); + } + } + + #[test] + fn production_sha_sumfold_feeds_folded_row_sumcheck() { + let cfg = test_config(); + let ell = 2usize; + let a = f(3); + let traces = (0..(1usize << ell)) + .map(|idx| synthetic_boolean_trace(u64::try_from(idx).unwrap(), &a)) + .collect::>(); + let publics = vec![zero_public(); traces.len()]; + let beta = vec![f(5), f(7)]; + let r_ic = [f(2), f(3), f(5), f(7), f(11), f(13), f(17)]; + let lambda = f(19); + let rho = f(23); + let xi = f(29); + let sources = vec![ + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 0, + }, + ShaBooleanitySource::WordBit { + col: ShaWordCol::E, + bit: 1, + }, + ]; + + let sumfold_group = build_production_sha_sumfold_group( + &traces, &publics, &beta, &r_ic, &a, &lambda, &rho, &xi, &sources, 1, &cfg, + ) + .unwrap(); + let (_proof, r_b, expected) = prove_and_verify_sumfold(sumfold_group, ell); + let sumfold = + derive_instance_fold_claim(&beta, r_b, expected[0].clone(), traces.len(), &cfg) + .unwrap(); + let (folded_witness, folded_public) = + fold_projected_traces(&traces, &publics, &sumfold, &cfg).unwrap(); + + let folded_claim = expression_folded_row_sum( + &folded_witness.trace, + &folded_public, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &sources, + &cfg, + ) + .unwrap(); + assert_eq!(&folded_claim, sumfold.final_round_sumcheck_claim()); + + let row_group = build_expression_folded_row_sumcheck_group( + &folded_witness.trace, + &folded_public, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &sources, + &cfg, + ) + .unwrap(); + let mut row_prover_transcript = Blake3Transcript::new(); + let (row_proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut row_prover_transcript, + vec![row_group], + SHA_ROW_VARS, + &cfg, + ); + let mut row_verifier_transcript = Blake3Transcript::new(); + 
MultiDegreeSumcheck::verify_as_subprotocol( + &mut row_verifier_transcript, + SHA_ROW_VARS, + &row_proof, + &cfg, + ) + .expect("folded row sumcheck proof should verify"); + verify_folded_row_sumcheck_claim( + &row_proof.claimed_sums()[0], + sumfold.final_round_sumcheck_claim(), + ) + .expect("folded row claim matches T'"); + } + + #[test] + fn folded_row_group_claims_row_integrand_sum() { + let cfg = test_config(); + let trace = zero_trace(); + let public = zero_public(); + let r_ic = std::array::from_fn(|_| F::zero_with_cfg(&cfg)); + let values = folded_row_integrand_values( + &trace, + &public, + &r_ic, + &f(3), + &f(5), + &f(7), + &f(11), + &[], + &cfg, + ) + .unwrap(); + let group = build_folded_row_sumcheck_group(&values, &cfg).unwrap(); + let mut transcript = Blake3Transcript::new(); + let (proof, _) = + MultiDegreeSumcheck::prove_as_subprotocol(&mut transcript, vec![group], 7, &cfg); + + assert_eq!(proof.claimed_sums()[0], F::zero_with_cfg(&cfg)); + verify_folded_row_sumcheck_claim(&proof.claimed_sums()[0], &F::zero_with_cfg(&cfg)) + .expect("row claim matches T'"); + assert!(matches!( + verify_folded_row_sumcheck_claim(&proof.claimed_sums()[0], &f(1)), + Err(ShaProjectionError::FoldedRowClaimMismatch) + )); + assert_eq!( + folded_row_integrand_sum(&values, &cfg).unwrap(), + F::zero_with_cfg(&cfg) + ); + } +} diff --git a/piop/src/neutron_nova/sumfold.rs b/piop/src/neutron_nova/sumfold.rs new file mode 100644 index 00000000..1112637b --- /dev/null +++ b/piop/src/neutron_nova/sumfold.rs @@ -0,0 +1,668 @@ +use crypto_primitives::PrimeField; +use thiserror::Error; +use zinc_poly::{ + mle::DenseMultilinearExtension, + utils::{ArithErrors, build_eq_x_r_inner, build_eq_x_r_vec, eq_eval}, +}; +use zinc_utils::{ + delayed_reduction::DelayedFieldProductSum, inner_transparent_field::InnerTransparentField, +}; + +use crate::sumcheck::{ + multi_degree::{MultiDegreeSumcheckGroup, PrefixFastPath, PrefixRoundOutput}, + prover::ProverState as SumcheckProverState, +}; + 
+/// Errors produced by linear SumFold prefix-table helpers. +#[derive(Clone, Debug, Error)] +pub enum SumFoldError { + #[error("linear instance claims cannot be empty")] + EmptyClaims, + #[error("instance count must be a power of two, got {len}")] + InstanceCountNotPowerOfTwo { len: usize }, + #[error("instance count mismatch for ell={ell}: got {got}, expected {expected}")] + InstanceCountMismatch { + ell: usize, + got: usize, + expected: usize, + }, + #[error("domain size is too large for ell={ell}")] + DomainTooLarge { ell: usize }, + #[error("ell0={ell0} must be at most ell={ell}")] + Ell0TooLarge { ell0: usize, ell: usize }, + #[error("beta length mismatch: got {got}, expected {expected}")] + BetaLengthMismatch { got: usize, expected: usize }, + #[error("hybrid sumfold requires ell0 < ell, got ell0={ell0}, ell={ell}")] + HybridPrefixNeedsTail { ell0: usize, ell: usize }, + #[error("equality table construction failed: {0}")] + EqTable(#[from] ArithErrors), +} + +/// Dense per-instance linear claims. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LinearInstanceClaims { + claims: Vec, + ell: usize, +} + +impl LinearInstanceClaims { + pub fn new(claims: Vec) -> Result { + if claims.is_empty() { + return Err(SumFoldError::EmptyClaims); + } + if !claims.len().is_power_of_two() { + return Err(SumFoldError::InstanceCountNotPowerOfTwo { len: claims.len() }); + } + + let ell = + usize::try_from(claims.len().trailing_zeros()).expect("trailing_zeros fits usize"); + Ok(Self { claims, ell }) + } + + pub fn from_claims_for_ell(claims: Vec, ell: usize) -> Result { + let expected = checked_domain_size(ell)?; + if claims.len() != expected { + return Err(SumFoldError::InstanceCountMismatch { + ell, + got: claims.len(), + expected, + }); + } + Self::new(claims) + } + + pub fn claims(&self) -> &[F] { + &self.claims + } + + pub fn ell(&self) -> usize { + self.ell + } + + pub fn len(&self) -> usize { + self.claims.len() + } + + pub fn is_empty(&self) -> bool { + self.claims.is_empty() + } +} + +/// Prefix table over the first `ell0` instance variables. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LinearPrefixTable { + values: Vec, + ell: usize, + ell0: usize, +} + +impl LinearPrefixTable { + pub(crate) fn from_values_for_prefix_vars( + values: Vec, + ell: usize, + prefix_vars: usize, + ) -> Result { + if prefix_vars > ell { + return Err(SumFoldError::Ell0TooLarge { + ell0: prefix_vars, + ell, + }); + } + let expected = checked_domain_size(prefix_vars)?; + if values.len() != expected { + return Err(SumFoldError::InstanceCountMismatch { + ell: prefix_vars, + got: values.len(), + expected, + }); + } + Ok(Self { + values, + ell, + ell0: prefix_vars, + }) + } + + #[allow(clippy::arithmetic_side_effects)] + pub fn build( + instance_claims: &LinearInstanceClaims, + beta: &[F], + ell0: usize, + field_cfg: &F::Config, + ) -> Result { + let ell = instance_claims.ell(); + if ell0 > ell { + return Err(SumFoldError::Ell0TooLarge { ell0, ell }); + } + if beta.len() != ell { + return Err(SumFoldError::BetaLengthMismatch { + got: beta.len(), + expected: ell, + }); + } + + let prefix_len = checked_domain_size(ell0)?; + let tail_vars = ell - ell0; + let tail_len = checked_domain_size(tail_vars)?; + let tail_weights = if tail_vars == 0 { + vec![F::one_with_cfg(field_cfg)] + } else { + build_eq_x_r_vec(&beta[ell0..], field_cfg)? 
+ }; + + debug_assert_eq!(tail_weights.len(), tail_len); + let mut values = vec![F::zero_with_cfg(field_cfg); prefix_len]; + for (tail_weight, claims_chunk) in tail_weights + .iter() + .zip(instance_claims.claims().chunks_exact(prefix_len)) + { + for (value, claim) in values.iter_mut().zip(claims_chunk) { + *value += tail_weight.clone() * claim; + } + } + + Ok(Self { values, ell, ell0 }) + } + + pub fn values(&self) -> &[F] { + &self.values + } + + pub fn ell(&self) -> usize { + self.ell + } + + pub fn ell0(&self) -> usize { + self.ell0 + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + pub fn to_mle(&self, field_cfg: &F::Config) -> DenseMultilinearExtension { + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + DenseMultilinearExtension::from_evaluations_vec( + self.ell0, + self.values + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + ) + } +} + +struct LinearSumFoldPrefixFastPath { + instance_claims: LinearInstanceClaims, + beta: Vec, + ell0: usize, + prefix_state: SumcheckProverState, +} + +impl LinearSumFoldPrefixFastPath +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: num_traits::Zero, +{ + fn new( + instance_claims: LinearInstanceClaims, + beta: Vec, + ell0: usize, + field_cfg: &F::Config, + ) -> Result { + let table = LinearPrefixTable::build(&instance_claims, &beta, ell0, field_cfg)?; + let eq_prefix = build_eq_x_r_inner(&beta[..ell0], field_cfg)?; + let table_mle = table.to_mle(field_cfg); + let prefix_state = SumcheckProverState::new(vec![eq_prefix, table_mle], ell0, 2); + + Ok(Self { + instance_claims, + beta, + ell0, + prefix_state, + }) + } + + #[allow(clippy::arithmetic_side_effects)] + fn finish_tail_mles( + self, + prefix_challenges: &[F], + field_cfg: &F::Config, + ) -> Vec> { + debug_assert_eq!(prefix_challenges.len(), self.ell0); + let ell = self.instance_claims.ell(); + let 
tail_vars = ell - self.ell0; + let prefix_len = checked_domain_size(self.ell0).expect("validated ell0 fits usize"); + let tail_len = checked_domain_size(tail_vars).expect("validated tail vars fit usize"); + let prefix_weights = build_eq_x_r_vec(prefix_challenges, field_cfg) + .expect("prefix challenge equality table should build"); + let beta_tail_weights = build_eq_x_r_vec(&self.beta[self.ell0..], field_cfg) + .expect("tail beta equality table should build"); + let eq_prefix_at_r = eq_eval( + prefix_challenges, + &self.beta[..self.ell0], + F::one_with_cfg(field_cfg), + ) + .expect("prefix challenge and beta prefix lengths match"); + + let mut bound_claims = vec![F::zero_with_cfg(field_cfg); tail_len]; + for (tail, value) in bound_claims.iter_mut().enumerate() { + for (prefix, weight) in prefix_weights.iter().enumerate().take(prefix_len) { + let idx = prefix + (tail << self.ell0); + *value += weight.clone() * &self.instance_claims.claims()[idx]; + } + } + + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + let scaled_eq_tail = DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + beta_tail_weights + .iter() + .map(|weight| (eq_prefix_at_r.clone() * weight).inner().clone()) + .collect(), + zero_inner.clone(), + ); + let bound_claims = DenseMultilinearExtension::from_evaluations_vec( + tail_vars, + bound_claims + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + ); + + vec![scaled_eq_tail, bound_claims] + } +} + +impl PrefixFastPath for LinearSumFoldPrefixFastPath +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: num_traits::Zero, +{ + fn prefix_len(&self) -> usize { + self.ell0 + } + + fn prove_prefix_round( + &mut self, + verifier_msg: &Option, + config: &F::Config, + ) -> PrefixRoundOutput { + let msg = self.prefix_state.prove_round( + verifier_msg, + |values: &[F]| values[0].clone() * &values[1], + config, + ); + let asserted_sum = if self.prefix_state.round == 1 { + 
self.prefix_state.asserted_sum.clone() + } else { + None + }; + + PrefixRoundOutput { + asserted_sum, + tail_evaluations: msg.0.tail_evaluations, + } + } + + fn finish_prefix( + self: Box, + prefix_challenges: &[F], + config: &F::Config, + ) -> Vec> { + self.finish_tail_mles(prefix_challenges, config) + } +} + +impl LinearInstanceClaims +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: num_traits::Zero, +{ + pub fn build_full_sumcheck_group( + &self, + beta: &[F], + field_cfg: &F::Config, + ) -> Result, SumFoldError> { + if beta.len() != self.ell { + return Err(SumFoldError::BetaLengthMismatch { + got: beta.len(), + expected: self.ell, + }); + } + + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + let eq_beta = build_eq_x_r_inner(beta, field_cfg)?; + let claims = DenseMultilinearExtension::from_evaluations_vec( + self.ell, + self.claims + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner, + ); + + Ok(MultiDegreeSumcheckGroup::new( + 2, + vec![eq_beta, claims], + Box::new(|values: &[F]| values[0].clone() * &values[1]), + )) + } + + pub fn build_hybrid_sumcheck_group( + &self, + beta: &[F], + ell0: usize, + field_cfg: &F::Config, + ) -> Result, SumFoldError> { + if beta.len() != self.ell { + return Err(SumFoldError::BetaLengthMismatch { + got: beta.len(), + expected: self.ell, + }); + } + if ell0 == 0 { + return self.build_full_sumcheck_group(beta, field_cfg); + } + if ell0 >= self.ell { + return Err(SumFoldError::HybridPrefixNeedsTail { + ell0, + ell: self.ell, + }); + } + + let fast_path = + LinearSumFoldPrefixFastPath::new(self.clone(), beta.to_vec(), ell0, field_cfg)?; + + Ok(MultiDegreeSumcheckGroup::with_prefix_fast( + 2, + Vec::new(), + Box::new(|values: &[F]| values[0].clone() * &values[1]), + Box::new(fast_path), + )) + } +} + +pub(crate) fn checked_domain_size(ell: usize) -> Result { + let shift = u32::try_from(ell).map_err(|_| SumFoldError::DomainTooLarge { ell })?; + 
1usize + .checked_shl(shift) + .ok_or(SumFoldError::DomainTooLarge { ell }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + sumcheck::multi_degree::{MultiDegreeSumcheck, MultiDegreeSumcheckProof}, + test_utils::test_config, + }; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + use zinc_poly::{ + mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig}, + utils::{build_eq_x_r_vec, eq_eval}, + }; + use zinc_transcript::Blake3Transcript; + + type F = MontyField<4>; + + fn f(value: u64) -> F { + F::from_with_cfg(value, &test_config()) + } + + fn claims_for_ell(ell: usize) -> LinearInstanceClaims { + let claims = (0..(1usize << ell)) + .map(|idx| { + let idx = u64::try_from(idx).expect("test index fits u64"); + f(idx + 2) + }) + .collect(); + LinearInstanceClaims::from_claims_for_ell(claims, ell).unwrap() + } + + #[allow(clippy::arithmetic_side_effects)] + fn direct_prefix_value( + claims: &LinearInstanceClaims, + beta: &[F], + ell0: usize, + prefix: usize, + ) -> F { + let cfg = test_config(); + let ell = claims.ell(); + let tail_vars = ell - ell0; + let tail_weights = if tail_vars == 0 { + vec![F::one_with_cfg(&cfg)] + } else { + build_eq_x_r_vec(&beta[ell0..], &cfg).unwrap() + }; + + let mut acc = F::zero_with_cfg(&cfg); + for (tail, weight) in tail_weights.iter().enumerate() { + let idx = prefix + (tail << ell0); + acc += weight.clone() * &claims.claims()[idx]; + } + acc + } + + #[allow(clippy::arithmetic_side_effects)] + fn direct_full_beta_sum(claims: &LinearInstanceClaims, beta: &[F]) -> F { + let cfg = test_config(); + let weights = build_eq_x_r_vec(beta, &cfg).unwrap(); + let mut acc = F::zero_with_cfg(&cfg); + for (weight, claim) in weights.iter().zip(claims.claims()) { + acc += weight.clone() * claim; + } + acc + } + + fn claims_mle( + claims: &LinearInstanceClaims, + ) -> DenseMultilinearExtension<::Inner> { + let cfg = test_config(); + let zero_inner = F::zero_with_cfg(&cfg).inner().clone(); + 
DenseMultilinearExtension::from_evaluations_vec( + claims.ell(), + claims + .claims() + .iter() + .map(|claim| claim.inner().clone()) + .collect(), + zero_inner, + ) + } + + fn prove_and_verify( + group: MultiDegreeSumcheckGroup, + ell: usize, + ) -> (MultiDegreeSumcheckProof, Vec, Vec) { + let cfg = test_config(); + let mut prover_transcript = Blake3Transcript::new(); + let (proof, _states) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![group], + ell, + &cfg, + ); + + let mut verifier_transcript = Blake3Transcript::new(); + let subclaims = + MultiDegreeSumcheck::verify_as_subprotocol(&mut verifier_transcript, ell, &proof, &cfg) + .expect("sumcheck proof should verify"); + + ( + proof, + subclaims.point().to_vec(), + subclaims.expected_evaluations().to_vec(), + ) + } + + fn proof_satisfies_dense_claim( + proof: &MultiDegreeSumcheckProof, + ell: usize, + beta: &[F], + claims: &LinearInstanceClaims, + ) -> bool { + let cfg = test_config(); + let mut verifier_transcript = Blake3Transcript::new(); + let Ok(subclaims) = + MultiDegreeSumcheck::verify_as_subprotocol(&mut verifier_transcript, ell, proof, &cfg) + else { + return false; + }; + + let eq_at_point = + eq_eval(subclaims.point(), beta, F::one_with_cfg(&cfg)).expect("same length"); + let claim_at_point = claims_mle(claims) + .evaluate_with_config(subclaims.point(), &cfg) + .unwrap(); + subclaims.expected_evaluations()[0] == eq_at_point * claim_at_point + } + + #[allow(clippy::arithmetic_side_effects)] + fn beta_for_ell(ell: usize) -> Vec { + (0..ell) + .map(|idx| f(3 + 2 * u64::try_from(idx).expect("test index fits u64"))) + .collect() + } + + fn assert_hybrid_matches_full_sumcheck(ell: usize, ell0: usize) { + let cfg = test_config(); + let claims = claims_for_ell(ell); + let beta = beta_for_ell(ell); + + let full_group = claims.build_full_sumcheck_group(&beta, &cfg).unwrap(); + let optimized_group = claims + .build_hybrid_sumcheck_group(&beta, ell0, &cfg) + .unwrap(); + + let 
(full_proof, full_point, full_expected) = prove_and_verify(full_group, ell); + let (optimized_proof, optimized_point, optimized_expected) = + prove_and_verify(optimized_group, ell); + + assert_eq!(optimized_proof, full_proof); + assert_eq!(optimized_point, full_point); + assert_eq!(optimized_expected, full_expected); + assert_eq!( + full_proof.claimed_sums()[0], + direct_full_beta_sum(&claims, &beta) + ); + + let eq_at_point = eq_eval(&full_point, &beta, F::one_with_cfg(&cfg)).expect("same length"); + let claim_at_point = claims_mle(&claims) + .evaluate_with_config(&full_point, &cfg) + .unwrap(); + assert_eq!(full_expected[0], eq_at_point * claim_at_point); + } + + #[test] + fn prefix_table_matches_direct_tail_fold_for_all_ell0_cases() { + let cfg = test_config(); + let claims = claims_for_ell(3); + let beta = vec![f(3), f(5), f(7)]; + + for ell0 in 0..=3 { + let table = LinearPrefixTable::build(&claims, &beta, ell0, &cfg).unwrap(); + assert_eq!(table.ell(), 3); + assert_eq!(table.ell0(), ell0); + assert_eq!(table.len(), 1usize << ell0); + + for prefix in 0..table.len() { + assert_eq!( + table.values()[prefix], + direct_prefix_value(&claims, &beta, ell0, prefix) + ); + } + } + } + + #[test] + fn hybrid_sumfold_proof_matches_full_ordinary_sumcheck() { + for (ell, ell0) in [(3, 1), (4, 2), (5, 3), (4, 0), (5, 4)] { + assert_hybrid_matches_full_sumcheck(ell, ell0); + } + } + + #[test] + fn hybrid_sumfold_rejects_tampered_prefix_and_tail_messages() { + let cfg = test_config(); + let ell = 4; + let ell0 = 2; + let claims = claims_for_ell(ell); + let beta = beta_for_ell(ell); + + let group = claims + .build_hybrid_sumcheck_group(&beta, ell0, &cfg) + .unwrap(); + let (proof, _point, _expected) = prove_and_verify(group, ell); + + let mut prefix_tampered = proof.clone(); + prefix_tampered.group_messages_mut_for_testing()[0][0] + .0 + .tail_evaluations[0] += f(1); + assert!(!proof_satisfies_dense_claim( + &prefix_tampered, + ell, + &beta, + &claims + )); + + let mut 
tail_tampered = proof; + tail_tampered.group_messages_mut_for_testing()[0][ell0] + .0 + .tail_evaluations[0] += f(1); + assert!(!proof_satisfies_dense_claim( + &tail_tampered, + ell, + &beta, + &claims + )); + } + + #[test] + fn linear_sumfold_validation_errors_are_reported() { + let cfg = test_config(); + + assert!(matches!( + LinearInstanceClaims::::new(Vec::new()), + Err(SumFoldError::EmptyClaims) + )); + assert!(matches!( + LinearInstanceClaims::new(vec![f(1), f(2), f(3)]), + Err(SumFoldError::InstanceCountNotPowerOfTwo { len: 3 }) + )); + assert!(matches!( + LinearInstanceClaims::from_claims_for_ell(vec![f(1), f(2), f(3)], 2), + Err(SumFoldError::InstanceCountMismatch { + ell: 2, + got: 3, + expected: 4 + }) + )); + + let claims = claims_for_ell(2); + assert!(matches!( + LinearPrefixTable::build(&claims, &[f(3), f(5)], 3, &cfg), + Err(SumFoldError::Ell0TooLarge { ell0: 3, ell: 2 }) + )); + assert!(matches!( + LinearPrefixTable::build(&claims, &[f(3)], 1, &cfg), + Err(SumFoldError::BetaLengthMismatch { + got: 1, + expected: 2 + }) + )); + + assert!(matches!( + claims + .build_hybrid_sumcheck_group(&[f(3), f(5)], 2, &cfg) + .err(), + Some(SumFoldError::HybridPrefixNeedsTail { ell0: 2, ell: 2 }) + )); + } +} diff --git a/piop/src/projections.rs b/piop/src/projections.rs index be5af276..97e6ce40 100644 --- a/piop/src/projections.rs +++ b/piop/src/projections.rs @@ -15,8 +15,9 @@ use zinc_poly::{ }; use zinc_uair::{Uair, UairTrace, collect_scalars::collect_scalars}; use zinc_utils::{ - UNCHECKED, cfg_extend, cfg_into_iter, cfg_iter, cfg_iter_mut, from_ref::FromRef, - inner_product::InnerProduct, powers, projectable_to_field::ProjectableToField, + UNCHECKED, cfg_extend, cfg_into_iter, cfg_iter, cfg_iter_mut, + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, inner_product::InnerProduct, + powers, projectable_to_field::ProjectableToField, }; /// HashMap specialization used for every `projected_scalars` lookup in the @@ -240,7 +241,7 @@ where /// MLEs 
(`Vec>`) for sumcheck /// compatibility. Dispatches on the trace layout internally. #[allow(clippy::arithmetic_side_effects)] -pub fn evaluate_trace_to_column_mles( +pub fn evaluate_trace_to_column_mles( trace: &ProjectedTrace, projecting_element: &F, ) -> Vec> { @@ -337,9 +338,8 @@ where { let zero_inner = F::Inner::default(); - let mut result = Vec::with_capacity( - trace.binary_poly.len() + trace.arbitrary_poly.len() + trace.int.len(), - ); + let mut result = + Vec::with_capacity(trace.binary_poly.len() + trace.arbitrary_poly.len() + trace.int.len()); let bin_proj = BinaryPoly::::prepare_projection(projecting_element); cfg_extend!( @@ -408,7 +408,7 @@ pub fn project_scalars( /// Project scalars of a UAIR along F[X] -> F. #[allow(clippy::arithmetic_side_effects)] -pub fn project_scalars_to_field( +pub fn project_scalars_to_field( scalars: ScalarMap>, projecting_element: &F, ) -> Result, (R, F, EvaluationError)> { diff --git a/piop/src/sumcheck.rs b/piop/src/sumcheck.rs index c16759ae..f339c7b8 100644 --- a/piop/src/sumcheck.rs +++ b/piop/src/sumcheck.rs @@ -170,14 +170,14 @@ impl MLSumcheck { ) -> (SumcheckProof, ProverState) where F: InnerTransparentField, - F::Inner: ConstTranscribable + Zero, - F::Modulus: ConstTranscribable, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, { if nvars == 0 { panic!("Attempt to prove a constant") } - let mut buf = vec![0; F::Inner::NUM_BYTES]; + let mut buf = vec![0; F::zero_with_cfg(config).inner().get_num_bytes()]; let nvars_field = F::from_with_cfg(nvars as u64, config); let degree_field = F::from_with_cfg(degree as u64, config); @@ -192,7 +192,7 @@ impl MLSumcheck { let prover_msg = prover_state.prove_round(&verifier_msg, &comb_fn, config); transcript.absorb_random_field_slice(&prover_msg.0.tail_evaluations, &mut buf); prover_msgs.push(prover_msg); - let next_verifier_msg = transcript.get_field_challenge(config); + let next_verifier_msg = transcript.get_transcribable_field_challenge(config); 
transcript.absorb_random_field(&next_verifier_msg, &mut buf); verifier_msg = Some(next_verifier_msg); @@ -273,14 +273,14 @@ impl MLSumcheck { config: &F::Config, ) -> Result, SumCheckError> where - F::Inner: ConstTranscribable, - F::Modulus: ConstTranscribable, + F::Inner: Transcribable, + F::Modulus: Transcribable, { if num_vars == 0 { panic!("Attempt to verify a sumcheck claim for 0 variables") } - let mut buf = vec![0; F::Inner::NUM_BYTES]; + let mut buf = vec![0; F::zero_with_cfg(config).inner().get_num_bytes()]; let (nvars_field, degree_field): (F, F) = { ( diff --git a/piop/src/sumcheck/multi_degree.rs b/piop/src/sumcheck/multi_degree.rs index 98bcedfb..2c190c64 100644 --- a/piop/src/sumcheck/multi_degree.rs +++ b/piop/src/sumcheck/multi_degree.rs @@ -15,14 +15,13 @@ use crypto_primitives::{FromPrimitiveWithConfig, PrimeField}; use num_traits::Zero; -#[cfg(feature = "parallel")] -use rayon::prelude::*; use std::marker::PhantomData; -use zinc_poly::mle::DenseMultilinearExtension; -use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript}; -use zinc_utils::{ - add, cfg_iter, cfg_iter_mut, inner_transparent_field::InnerTransparentField, mul, +use zinc_poly::{ + EvaluatablePolynomial, mle::DenseMultilinearExtension, + univariate::nat_evaluation::NatEvaluatedPoly, }; +use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript}; +use zinc_utils::{add, inner_transparent_field::InnerTransparentField, mul}; use crate::CombFn; @@ -50,6 +49,14 @@ pub struct Round1Output { pub tail_evaluations: Vec, } +/// Output of one prefix fast-path round. +pub struct PrefixRoundOutput { + /// Set only on round 1, matching [`Round1Output::asserted_sum`]. + pub asserted_sum: Option, + /// `[p_i(1), p_i(2), ..., p_i(degree)]` for this round. + pub tail_evaluations: Vec, +} + /// Optional per-group hook that lets a degree group bypass the standard /// round-1 sumcheck loop. 
Used when the round-1 polynomial has a closed /// form (e.g. booleanity zerocheck on bit-slice MLEs that are 0/1 @@ -74,13 +81,68 @@ pub trait Round1FastPath: Send + Sync { ) -> Vec>; } +/// Optional per-group hook that emits a prefix of ordinary sumcheck messages. +/// +/// A prefix fast path must be transcript-equivalent to running +/// [`SumcheckProverState::prove_round`] for rounds `1..=prefix_len`. Once the +/// last prefix challenge is sampled, it returns the MLEs already bound at the +/// full prefix challenge point, so the standard prover can continue with the +/// remaining variables. +pub trait PrefixFastPath: Send + Sync { + fn prefix_len(&self) -> usize; + + fn prove_prefix_round( + &mut self, + verifier_msg: &Option, + config: &F::Config, + ) -> PrefixRoundOutput; + + fn finish_prefix( + self: Box, + prefix_challenges: &[F], + config: &F::Config, + ) -> Vec>; +} + +struct Round1PrefixAdapter { + inner: Box>, +} + +impl PrefixFastPath for Round1PrefixAdapter { + fn prefix_len(&self) -> usize { + 1 + } + + fn prove_prefix_round( + &mut self, + verifier_msg: &Option, + config: &F::Config, + ) -> PrefixRoundOutput { + debug_assert!(verifier_msg.is_none()); + let out = self.inner.round_1_message(config); + PrefixRoundOutput { + asserted_sum: Some(out.asserted_sum), + tail_evaluations: out.tail_evaluations, + } + } + + fn finish_prefix( + self: Box, + prefix_challenges: &[F], + config: &F::Config, + ) -> Vec> { + debug_assert_eq!(prefix_challenges.len(), 1); + self.inner.fold_with_r1(&prefix_challenges[0], config) + } +} + /// A single degree group for the multi-degree sumcheck: (degree, mles, /// comb_fn). 
pub struct MultiDegreeSumcheckGroup { degree: usize, poly: Vec>, comb_fn: CombFn, - round_1_fast: Option>>, + prefix_fast: Option>>, } impl MultiDegreeSumcheckGroup { @@ -93,7 +155,7 @@ impl MultiDegreeSumcheckGroup { degree, poly, comb_fn, - round_1_fast: None, + prefix_fast: None, } } @@ -105,12 +167,34 @@ impl MultiDegreeSumcheckGroup { poly: Vec>, comb_fn: CombFn, round_1_fast: Box>, + ) -> Self + where + F: 'static, + { + Self { + degree, + poly, + comb_fn, + prefix_fast: Some(Box::new(Round1PrefixAdapter { + inner: round_1_fast, + })), + } + } + + /// Construct a group whose initial rounds are produced by a custom + /// [`PrefixFastPath`]. `poly` may be empty here — the fast path supplies + /// post-prefix MLEs via [`PrefixFastPath::finish_prefix`]. + pub fn with_prefix_fast( + degree: usize, + poly: Vec>, + comb_fn: CombFn, + prefix_fast: Box>, ) -> Self { Self { degree, poly, comb_fn, - round_1_fast: Some(round_1_fast), + prefix_fast: Some(prefix_fast), } } } @@ -135,6 +219,19 @@ impl MultiDegreeSumcheckProof { pub fn claimed_sums(&self) -> &[F] { &self.claimed_sums } + + pub fn degrees(&self) -> &[usize] { + &self.degrees + } + + pub fn group_messages(&self) -> &[Vec>] { + &self.group_messages + } + + #[cfg(test)] + pub(crate) fn group_messages_mut_for_testing(&mut self) -> &mut [Vec>] { + &mut self.group_messages + } } impl GenTranscribable for MultiDegreeSumcheckProof @@ -318,8 +415,32 @@ impl MultiDegreeSumcheck { ) -> (MultiDegreeSumcheckProof, Vec>) where F: InnerTransparentField + Send + Sync, - F::Inner: ConstTranscribable + Zero, - F::Modulus: ConstTranscribable, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, + { + let (proof, states, _) = Self::prove_as_subprotocol_with_expected_evaluations( + transcript, groups, num_vars, config, + ); + (proof, states) + } + + /// Prove a multi-degree sumcheck and also return each group's expected + /// terminal evaluation at the sampled point. 
+ #[allow(clippy::type_complexity)] + pub fn prove_as_subprotocol_with_expected_evaluations( + transcript: &mut impl Transcript, + groups: Vec>, + num_vars: usize, + config: &F::Config, + ) -> ( + MultiDegreeSumcheckProof, + Vec>, + Vec, + ) + where + F: InnerTransparentField + Send + Sync, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, { assert!( num_vars > 0, @@ -328,7 +449,7 @@ impl MultiDegreeSumcheck { assert!(!groups.is_empty(), "need at least one degree group"); let num_groups = groups.len(); - let mut buf = vec![0; F::Inner::NUM_BYTES]; + let mut buf = vec![0; F::zero_with_cfg(config).inner().get_num_bytes()]; let nvars_field = F::from_with_cfg(num_vars as u64, config); let ngroups_field = F::from_with_cfg(num_groups as u64, config); transcript.absorb_random_field(&nvars_field, &mut buf); @@ -341,79 +462,89 @@ impl MultiDegreeSumcheck { let mut prover_states: Vec> = Vec::with_capacity(num_groups); let mut comb_fns: Vec> = Vec::with_capacity(num_groups); - let mut fast_paths: Vec>>> = + let mut fast_paths: Vec>>> = Vec::with_capacity(num_groups); for group in groups { let degree_field = F::from_with_cfg(group.degree as u64, config); transcript.absorb_random_field(°ree_field, &mut buf); + if let Some(ref fp) = group.prefix_fast { + assert!( + fp.prefix_len() > 0 && fp.prefix_len() <= num_vars, + "prefix fast path length must be in 1..=num_vars" + ); + } prover_states.push(SumcheckProverState::new(group.poly, num_vars, group.degree)); comb_fns.push(group.comb_fn); - fast_paths.push(group.round_1_fast); + fast_paths.push(group.prefix_fast); } - // ---- Round 1 --------------------------------------------------- - let mut round_1_msgs: Vec> = Vec::with_capacity(num_groups); - for ((state, comb_fn), fp_slot) in prover_states - .iter_mut() - .zip(comb_fns.iter()) - .zip(fast_paths.iter()) - { - let msg = if let Some(fp) = fp_slot { - let out = fp.round_1_message(config); - debug_assert_eq!( - out.tail_evaluations.len(), - state.max_degree, - 
"fast-path round-1 tail must have length equal to group's degree" - ); - state.asserted_sum = Some(out.asserted_sum); - state.round = 1; - SumcheckProverMsg(NatEvaluatedPolyWithoutConstant::new(out.tail_evaluations)) - } else { - state.prove_round(&None, comb_fn, config) - }; - round_1_msgs.push(msg); - } - for msg in &round_1_msgs { - transcript.absorb_random_field_slice(&msg.0.tail_evaluations, &mut buf); - } - for (j, msg) in round_1_msgs.into_iter().enumerate() { - group_messages[j].push(msg); - } - let r_1: F = transcript.get_field_challenge(config); - transcript.absorb_random_field(&r_1, &mut buf); - let mut verifier_msg = Some(r_1.clone()); - - // For fast-path groups, materialize the round-1-folded MLEs and - // mark the next standard fold to be skipped. - for (state, fp_slot) in prover_states.iter_mut().zip(fast_paths.iter_mut()) { - if let Some(fp) = fp_slot.take() { - let folded = fp.fold_with_r1(&r_1, config); - state.mles = folded; - state.skip_next_fold = true; + let mut verifier_msg = None; + let mut challenges = Vec::with_capacity(num_vars); + for round_idx in 0..num_vars { + let mut round_msgs: Vec> = Vec::with_capacity(num_groups); + for group_idx in 0..num_groups { + let use_fast = fast_paths[group_idx] + .as_ref() + .is_some_and(|fp| round_idx < fp.prefix_len()); + let msg = if use_fast { + let fp = fast_paths[group_idx] + .as_mut() + .expect("fast path must exist when use_fast is true"); + let out = fp.prove_prefix_round(&verifier_msg, config); + debug_assert_eq!( + out.tail_evaluations.len(), + prover_states[group_idx].max_degree, + "prefix fast-path tail must match the group's degree" + ); + if round_idx == 0 { + prover_states[group_idx].asserted_sum = Some( + out.asserted_sum + .expect("prefix fast path must provide the first asserted sum"), + ); + } else { + debug_assert!(out.asserted_sum.is_none()); + } + prover_states[group_idx].round = round_idx + 1; + SumcheckProverMsg(NatEvaluatedPolyWithoutConstant::new(out.tail_evaluations)) + } else 
{ + prover_states[group_idx].prove_round( + &verifier_msg, + &comb_fns[group_idx], + config, + ) + }; + round_msgs.push(msg); } - } - // ---- Rounds 2..num_vars --------------------------------------- - for _ in 1..num_vars { - // Parallel: each group computes its round polynomial independently - let round_msgs: Vec> = cfg_iter_mut!(prover_states) - .zip(cfg_iter!(comb_fns)) - .map(|(state, comb_fn)| state.prove_round(&verifier_msg, comb_fn, config)) - .collect(); - - // Sequential: absorb in deterministic order, sample one shared challenge for msg in &round_msgs { transcript.absorb_random_field_slice(&msg.0.tail_evaluations, &mut buf); } - for (j, msg) in round_msgs.into_iter().enumerate() { group_messages[j].push(msg); } - let next_verifier_msg = transcript.get_field_challenge(config); - transcript.absorb_random_field(&next_verifier_msg, &mut buf); + let challenge: F = transcript.get_transcribable_field_challenge(config); + transcript.absorb_random_field(&challenge, &mut buf); + + for group_idx in 0..num_groups { + let should_finish = fast_paths[group_idx] + .as_ref() + .is_some_and(|fp| fp.prefix_len() == round_idx + 1); + if should_finish { + let fp = fast_paths[group_idx] + .take() + .expect("fast path must exist when should_finish is true"); + let mut prefix_challenges = challenges.clone(); + prefix_challenges.push(challenge.clone()); + prover_states[group_idx].mles = fp.finish_prefix(&prefix_challenges, config); + prover_states[group_idx].randomness = challenges.clone(); + prover_states[group_idx].round = round_idx + 1; + prover_states[group_idx].skip_next_fold = round_idx + 1 < num_vars; + } + } - verifier_msg = Some(next_verifier_msg); + verifier_msg = Some(challenge.clone()); + challenges.push(challenge); } prover_states.iter_mut().for_each(|state| { @@ -428,7 +559,18 @@ impl MultiDegreeSumcheck { } }); - let degrees = prover_states.iter().map(|s| s.max_degree).collect(); + let degrees = prover_states + .iter() + .map(|s| s.max_degree) + .collect::>(); + 
let expected_evaluations = group_messages + .iter() + .zip(claimed_sums.iter()) + .zip(degrees.iter()) + .map(|((messages, claimed_sum), degree)| { + Self::expected_evaluation_from_messages(claimed_sum, messages, *degree, &challenges) + }) + .collect(); ( MultiDegreeSumcheckProof { @@ -437,9 +579,44 @@ impl MultiDegreeSumcheck { degrees, }, prover_states, + expected_evaluations, ) } + fn expected_evaluation_from_messages( + claimed_sum: &F, + messages: &[SumcheckProverMsg], + degree: usize, + challenges: &[F], + ) -> F { + assert_eq!( + messages.len(), + challenges.len(), + "generated sumcheck proof should have one message per challenge" + ); + let mut expected = claimed_sum.clone(); + for (message, challenge) in messages.iter().zip(challenges.iter()) { + let tail = &message.0.tail_evaluations; + assert_eq!( + tail.len(), + degree, + "generated sumcheck message should match group degree" + ); + let constant_term = if degree == 0 { + expected.clone() + } else { + expected.clone() - tail[0].clone() + }; + let mut reconstructed_evaluations = Vec::with_capacity(tail.len() + 1); + reconstructed_evaluations.push(constant_term); + reconstructed_evaluations.extend_from_slice(tail); + expected = NatEvaluatedPoly::new(reconstructed_evaluations) + .evaluate_at_point(challenge) + .expect("generated sumcheck message should interpolate"); + } + expected + } + /// Multi-degree sumcheck verifier. 
/// /// Runs the verifier side of the sumcheck protocol for G degree groups @@ -488,8 +665,8 @@ impl MultiDegreeSumcheck { ) -> Result, SumCheckError> where F: InnerTransparentField, - F::Inner: ConstTranscribable, - F::Modulus: ConstTranscribable, + F::Inner: Transcribable, + F::Modulus: Transcribable, { assert!( num_vars > 0, @@ -498,7 +675,7 @@ impl MultiDegreeSumcheck { let num_groups = proof.degrees.len(); assert!(num_groups != 0, "need at least one degree group"); - let mut buf = vec![0; F::Inner::NUM_BYTES]; + let mut buf = vec![0; F::zero_with_cfg(config).inner().get_num_bytes()]; let nvars_field = F::from_with_cfg(num_vars as u64, config); let ngroups_field = F::from_with_cfg(num_groups as u64, config); transcript.absorb_random_field(&nvars_field, &mut buf); @@ -536,7 +713,7 @@ impl MultiDegreeSumcheck { transcript.absorb_random_field_slice(&msg[i].0.tail_evaluations, &mut buf) }); - let shared_challenge: F = transcript.get_field_challenge(config); + let shared_challenge: F = transcript.get_transcribable_field_challenge(config); transcript.absorb_random_field(&shared_challenge, &mut buf); verifier_states diff --git a/piop/src/sumcheck/prover.rs b/piop/src/sumcheck/prover.rs index 79f3fa96..e6407c33 100644 --- a/piop/src/sumcheck/prover.rs +++ b/piop/src/sumcheck/prover.rs @@ -146,20 +146,17 @@ where evals: Vec, steps: Vec, vals0: Vec, - vals1: Vec, vals: Vec, levals: Vec, } let zero = F::zero_with_cfg(config); let zero_vec_deg = vec![zero.clone(); degree + 1]; - let zero_vec_poly = vec![zero.clone(); polys.len()]; let scratch = || Scratch { evals: zero_vec_deg.clone(), - steps: zero_vec_poly.clone(), - vals0: zero_vec_poly.clone(), - vals1: zero_vec_poly.clone(), - vals: zero_vec_poly.clone(), - levals: zero_vec_deg.clone(), + steps: Vec::with_capacity(polys.len()), + vals0: Vec::with_capacity(polys.len()), + vals: Vec::with_capacity(polys.len()), + levals: Vec::with_capacity(degree + 1), }; #[cfg(not(feature = "parallel"))] @@ -182,29 +179,35 @@ where // 
My bet is that it won't affect running time, but better safe than // sorry. - s.vals0 - .iter_mut() - .zip(polys.iter()) - .for_each(|(v0, poly)| *v0.inner_mut() = poly[index].clone()); - s.levals[0] = comb_fn(&s.vals0); + s.vals0.clear(); + s.vals0.extend( + polys + .iter() + .map(|poly| F::new_unchecked_with_cfg(poly[index].clone(), config)), + ); + s.levals.clear(); + s.levals.push(comb_fn(&s.vals0)); if degree > 0 { - s.vals1 - .iter_mut() - .zip(polys.iter()) - .for_each(|(v1, poly)| *v1.inner_mut() = poly[index + 1].clone()); - s.levals[1] = comb_fn(&s.vals1); - - for (i, (v1, v0)) in s.vals1.iter().zip(s.vals0.iter()).enumerate() { - s.steps[i] = v1.clone() - v0.clone(); - s.vals[i] = v1.clone(); - } - - for eval_point in s.levals.iter_mut().take(degree + 1).skip(2) { - for poly_i in 0..polys.len() { - s.vals[poly_i] += &s.steps[poly_i]; + s.vals.clear(); + s.vals.extend( + polys + .iter() + .map(|poly| F::new_unchecked_with_cfg(poly[index + 1].clone(), config)), + ); + s.levals.push(comb_fn(&s.vals)); + + s.steps.clear(); + s.steps + .extend(s.vals.iter().zip(s.vals0.iter()).map(|(v1, v0)| { + v1.clone() - v0.clone() + })); + + for _ in 2..=degree { + for (value, step) in s.vals.iter_mut().zip(s.steps.iter()) { + *value += step; } - *eval_point = comb_fn(&s.vals); + s.levals.push(comb_fn(&s.vals)); } } diff --git a/piop/src/sumcheck/verifier.rs b/piop/src/sumcheck/verifier.rs index f9c857c4..e580e8d8 100644 --- a/piop/src/sumcheck/verifier.rs +++ b/piop/src/sumcheck/verifier.rs @@ -2,7 +2,7 @@ use crypto_primitives::{FromPrimitiveWithConfig, PrimeField}; use zinc_poly::{EvaluatablePolynomial, univariate::nat_evaluation::NatEvaluatedPoly}; -use zinc_transcript::traits::{ConstTranscribable, Transcript}; +use zinc_transcript::traits::{Transcribable, Transcript}; use zinc_utils::add; use crate::sumcheck::prover::{NatEvaluatedPolyWithoutConstant, ProverMsg}; @@ -65,9 +65,9 @@ impl VerifierState { /// [`Self::verify_round_with_challenge`]. 
Returns the sampled challenge. pub fn verify_round(&mut self, prover_msg: &ProverMsg, transcript: &mut impl Transcript) -> F where - F::Inner: ConstTranscribable, + F::Inner: Transcribable, { - let challenge: F = transcript.get_field_challenge(&self.config); + let challenge: F = transcript.get_transcribable_field_challenge(&self.config); self.verify_round_with_challenge(prover_msg, challenge.clone()); challenge } diff --git a/poly/src/mle/dense.rs b/poly/src/mle/dense.rs index 84735905..389aabff 100644 --- a/poly/src/mle/dense.rs +++ b/poly/src/mle/dense.rs @@ -257,15 +257,15 @@ where let nv = self.num_vars; let dim = partial_point.len(); - let mut r = partial_point[0].clone(); for i in 1..dim + 1 { + let r_base = &partial_point[i - 1]; for b in 0..1 << (nv - i) { - *r.inner_mut() = partial_point[i - 1].inner().clone(); if self[2 * b + 1] != self[2 * b] { // a = f(1) - f(0) let a = F::sub_inner(&self[2 * b + 1], &self[2 * b], config); // self[b] = f(0) + r * a + let mut r = r_base.clone(); r.mul_assign_by_inner(&a); self[b] = F::add_inner(&self[2 * b], r.inner(), config); } else { diff --git a/poly/src/univariate/binary_ref.rs b/poly/src/univariate/binary_ref.rs index 1370a5cc..d72d22b4 100644 --- a/poly/src/univariate/binary_ref.rs +++ b/poly/src/univariate/binary_ref.rs @@ -51,6 +51,12 @@ impl BinaryRefPoly { pub const fn inner(&self) -> &DensePolynomial { &self.0 } + + #[inline(always)] + pub fn coeff(&self, idx: usize) -> bool { + assert!(idx < DEGREE_PLUS_ONE); + self.0.coeffs[idx].inner() + } } impl From> diff --git a/poly/src/univariate/binary_u64.rs b/poly/src/univariate/binary_u64.rs index 8d5f031a..b36cf46a 100644 --- a/poly/src/univariate/binary_u64.rs +++ b/poly/src/univariate/binary_u64.rs @@ -32,6 +32,13 @@ impl BinaryU64Poly { pub const fn inner(&self) -> &u64 { &self.0 } + + #[inline(always)] + #[allow(clippy::arithmetic_side_effects)] + pub fn coeff(&self, idx: usize) -> bool { + assert!(idx < DEGREE_PLUS_ONE && idx < u64::BITS as usize); + 
!(self.0 & (1 << idx)).is_zero() + } } impl From> for u64 { diff --git a/poly/src/univariate/dynamic/over_field.rs b/poly/src/univariate/dynamic/over_field.rs index 4d21816b..c60a62cc 100644 --- a/poly/src/univariate/dynamic/over_field.rs +++ b/poly/src/univariate/dynamic/over_field.rs @@ -8,7 +8,8 @@ use std::{ use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable}; use zinc_utils::{ UNCHECKED, add, - inner_product::{InnerProduct, InnerProductError}, + delayed_reduction::DelayedFieldProductSum, + inner_product::{FieldFieldInnerProduct, InnerProduct, InnerProductError}, mul, }; @@ -128,17 +129,13 @@ impl DynamicPolynomialF { /// Inner product for dynamic polynomials over a prime field. pub struct DynamicPolyFInnerProduct; -impl InnerProduct<[F], F, F> for DynamicPolyFInnerProduct { - #[allow(clippy::arithmetic_side_effects)] +impl InnerProduct<[F], F, F> for DynamicPolyFInnerProduct { fn inner_product( lhs: &[F], rhs: &[F], zero: F, ) -> Result { - Ok(lhs - .iter() - .zip(rhs) - .fold(zero, |acc, (coeff, power)| acc + coeff.clone() * power)) + FieldFieldInnerProduct::inner_product::(lhs, rhs, zero) } } diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index 8083869d..013985b4 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -12,6 +12,7 @@ rayon = { workspace = true, optional = true } itertools = { workspace = true } num-traits = { workspace = true } thiserror = { workspace = true } +tracing = { workspace = true } zip-plus = { workspace = true } zinc-piop = { workspace = true } zinc-poly = { workspace = true } @@ -19,14 +20,19 @@ zinc-primality = { workspace = true } zinc-transcript = { workspace = true } zinc-uair = { workspace = true } zinc-utils = { workspace = true } +ark-ec = { version = "0.5.0", default-features = false } +ark-ff = { version = "0.5.0", default-features = false } [lib] bench = false [dev-dependencies] +ark-bn254 = { version = "0.5.0", default-features = false, features = ["curve"] } +ark-secp256k1 = 
{ version = "0.5.0", default-features = false } criterion = { workspace = true } crypto-bigint = { workspace = true } rand = { workspace = true } +tracing-subscriber = { workspace = true } zinc-test-uair = { workspace = true } zstd = "0.13" libc = "0.2" @@ -36,7 +42,7 @@ workspace = true [features] -parallel = ["dep:rayon", "zinc-piop/parallel", "zinc-poly/parallel", "zinc-uair/parallel", "zinc-utils/parallel"] +parallel = ["dep:rayon", "zip-plus/parallel", "zinc-piop/parallel", "zinc-poly/parallel", "zinc-uair/parallel", "zinc-utils/parallel", "zinc-test-uair/parallel"] simd = ["zinc-poly/simd", "zinc-piop/simd", "zip-plus/simd"] unchecked = [] # Switch the IPRS bench code from inverse-rate 4 (rate 1/4) to inverse-rate 8 diff --git a/protocol/benches/e2e.rs b/protocol/benches/e2e.rs index b41b0a3f..d1098eed 100644 --- a/protocol/benches/e2e.rs +++ b/protocol/benches/e2e.rs @@ -1,16 +1,26 @@ #![allow(clippy::arithmetic_side_effects)] +use ark_ec::{AffineRepr, CurveGroup, PrimeGroup}; use criterion::{ BatchSize, BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, measurement::WallTime, }; use crypto_bigint::U64; use crypto_primitives::{ - ConstIntRing, ConstIntSemiring, Field, FixedSemiring, FromWithConfig, PrimeField, - crypto_bigint_int::Int, crypto_bigint_monty::MontyField, crypto_bigint_uint::Uint, + ConstIntRing, ConstIntSemiring, ConstSemiring, Field, FixedSemiring, FromWithConfig, + PrimeField, crypto_bigint_int::Int, crypto_bigint_monty::MontyField, crypto_bigint_uint::Uint, }; +use num_traits::Zero; use rand::rng; -use std::{fmt::Debug, hint::black_box, marker::PhantomData, ops::Neg}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use std::{borrow::Cow, fmt::Debug, hint::black_box, marker::PhantomData, ops::Neg}; +use zinc_piop::neutron_nova::{ + MleTable, ProjectedPublic, ProjectedTrace, SHA_ROW_COUNT, SHA_ROW_VARS, SHA_WORD_BITS, + ShaPublicCol, bit_slice_index, +}; +use zinc_poly::mle::DenseMultilinearExtension; +use 
zinc_poly::univariate::dynamic::over_field::DynamicPolyVecF; use zinc_poly::{ ConstCoeffBitWidth, Polynomial, univariate::{ @@ -22,21 +32,38 @@ use zinc_poly::{ use zinc_primality::{MillerRabin, PrimalityTest}; use zinc_protocol::{ FoldedZincTypes, IntFoldedZincTypes4x, Proof, ZincPlusPiop, ZincTypes, + fixed_prime::field_cfg_from_curve_scalar, + pcs::{ + AllHyraxPCSTypes, AllZipPCSTypes, BinaryIntHyraxZipArbitrary, PCSCommitments, PCSParams, + PCSVerifierParams, ZincPCSTypes, + }, + production_sha::{ + LinearIdealFoldProverParams, LinearIdealFoldVerifierParams, ProductionShaError, + ProductionShaProjectionAdapter, ProductionShaWitnessPolys, UairShape, + prepare_linear_ideal_fold_witnesses, prove_prepared_linear_ideal_fold, + setup_verify_linear_ideal_fold, verify_linear_ideal_fold, + }, }; use zinc_test_uair::{ BigLinearUair, BigLinearUairWithPublicInput, BinaryDecompositionUair, EC_FP_INT_LIMBS, - EcdsaUair, GenerateRandomTrace, Sha256CompressionSliceUair, Sha256Ideal, ShaEcdsaUair, - ShaProxy, TestUairNoMultiplication, + EcdsaUair, GenerateRandomTrace, SHA256_INITIAL_STATE, Sha256CompressionSliceUair, Sha256Ideal, + Sha256MessageBlock, ShaEcdsaUair, ShaProxy, TestUairNoMultiplication, + sha256::{K_CANONICAL, cols as sha256_cols}, + sha256_padded_message_blocks, synthesize_sha256_chain_trace, synthesize_sha256_chain_witnesses, +}; +use zinc_transcript::{ + Blake3Transcript, + traits::{ConstTranscribable, Transcribable}, }; -use zinc_poly::univariate::dynamic::over_field::DynamicPolyVecF; -use zinc_transcript::traits::{ConstTranscribable, Transcribable}; use zinc_uair::{ - Uair, UairTrace, + ConstraintBuilder, PublicStructureError, TraceRow, Uair, UairSignature, UairTrace, degree_counter::count_effective_max_degree, ideal::{DegreeOneIdeal, Ideal, IdealCheck, ImpossibleIdeal, rotation::RotationIdeal}, ideal_collector::IdealOrZero, }; use zinc_utils::{ + cfg_into_iter, cfg_iter, + delayed_reduction::{BarrettDelayedReduction, DelayedModularReductionAlgorithm}, 
from_ref::FromRef, inner_product::{InnerProduct, MBSInnerProduct, ScalarProduct}, mul_by_scalar::MulByScalar, @@ -44,9 +71,14 @@ use zinc_utils::{ projectable_to_field::ProjectableToField, }; use zip_plus::{ - code::iprs::{IprsCode, PnttConfigF65537}, + code::{ + LinearCode, + iprs::{IprsCode, PnttConfigF65537}, + }, + pcs::generic::{PCS, ZipPlusPCS}, + pcs::hyrax::{BinaryLanes, DensePolyScalarLanes, HyraxBlindingMode, HyraxPCS, IntScalarLane}, pcs::structs::{ZipPlus, ZipPlusCommitment, ZipPlusParams, ZipTypes}, - utils::{eprint_bytes_size_breakdown, eprint_proof_size}, + utils::{eprint_bytes_size, eprint_bytes_size_breakdown, eprint_proof_size}, }; // @@ -186,20 +218,8 @@ struct GenericBenchZincTypes< )>, ); -impl - ZincTypes - for GenericBenchZincTypes< - Int, - CwR, - Chal, - Pt, - BinaryCombR, - CombR, - IntCombR, - Fmod, - PrimeTest, - D, - > +impl ZincTypes + for GenericBenchZincTypes where Int: ConstIntSemiring + for<'a> MulByScalar<&'a i64, CwR> @@ -317,8 +337,7 @@ where const DEGREE_PLUS_ONE: usize = 32; const INT_LIMBS: usize = U64::LIMBS; -// `fixed-prime` branch: 256-bit field modulus (4 × u64 limbs) so that the -// fixed secp256k1 base prime fits in `Fmod = Uint`. +// 256-bit field modulus (4 × u64 limbs). 
const FIELD_LIMBS: usize = U64::LIMBS * 4; type F = MontyField; @@ -430,6 +449,618 @@ fn sha256_real_project_ideal( } } +const REAL_SHA256_CHAIN_BLOCKS: usize = 8; +const REAL_SHA256_CHAIN_NUM_VARS: usize = 10; + +fn real_sha256_chain_message() -> String { + vec!["hello world"; 40].join(" ") +} + +#[allow(clippy::unwrap_used)] +fn real_sha256_chain_blocks() -> [Sha256MessageBlock; REAL_SHA256_CHAIN_BLOCKS] { + let message = real_sha256_chain_message(); + sha256_padded_message_blocks::(message.as_bytes()) + .expect("real SHA-256 benchmark fixture should canonically pad to eight blocks") +} + +#[allow(clippy::unwrap_used)] +fn real_sha256_chain_trace( + num_vars: usize, +) -> UairTrace<'static, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE> { + let (trace, _final_state) = synthesize_sha256_chain_trace::< + RealEcdsaInt, + REAL_SHA256_CHAIN_BLOCKS, + >(num_vars, SHA256_INITIAL_STATE, real_sha256_chain_blocks()) + .expect("real SHA-256 monolithic chain trace synthesis should succeed"); + trace +} + +#[derive(Clone, Debug)] +struct ProjectionShaBenchUair(PhantomData); + +impl Uair for ProjectionShaBenchUair +where + R: ConstSemiring + 'static, +{ + type Ideal = Sha256Ideal; + type Scalar = DensePolynomial; + + fn signature() -> UairSignature { + Sha256CompressionSliceUair::::signature() + } + + fn constrain_general( + b: &mut B, + up: TraceRow, + down: TraceRow, + from_ref: FromR, + mbs: MulByScalarFn, + ideal_from_ref: IFromR, + ) where + B: ConstraintBuilder, + FromR: Fn(&Self::Scalar) -> B::Expr, + MulByScalarFn: Fn(&B::Expr, &Self::Scalar) -> Option, + IFromR: Fn(&Self::Ideal) -> B::Ideal, + { + Sha256CompressionSliceUair::::constrain_general( + b, + up, + down, + from_ref, + mbs, + ideal_from_ref, + ); + } + + fn verify_public_structure( + public_trace: &UairTrace<'_, RT, IntT, D>, + num_vars: usize, + ) -> Result<(), PublicStructureError> + where + RT: Clone, + IntT: Clone + num_traits::Zero, + { + Sha256CompressionSliceUair::::verify_public_structure(public_trace, 
num_vars) + } +} + +fn projection_sha_binary_col<'a>( + public_trace: &'a UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + witness_trace: &'a UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + flat_col: usize, +) -> Result<&'a DenseMultilinearExtension>, ProductionShaError> { + if flat_col < sha256_cols::NUM_BIN_PUB { + public_trace + .binary_poly + .get(flat_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA binary public source columns", + got: public_trace.binary_poly.len(), + expected: flat_col + 1, + }) + } else { + let witness_col = flat_col - sha256_cols::NUM_BIN_PUB; + witness_trace + .binary_poly + .get(witness_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA binary witness source columns", + got: witness_trace.binary_poly.len(), + expected: witness_col + 1, + }) + } +} + +fn projection_sha_int_col<'a>( + public_trace: &'a UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + witness_trace: &'a UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + flat_col: usize, +) -> Result<&'a DenseMultilinearExtension, ProductionShaError> { + if flat_col < sha256_cols::NUM_INT_PUB { + public_trace + .int + .get(flat_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA int public source columns", + got: public_trace.int.len(), + expected: flat_col + 1, + }) + } else { + let witness_col = flat_col - sha256_cols::NUM_INT_PUB; + witness_trace + .int + .get(witness_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA int witness source columns", + got: witness_trace.int.len(), + expected: witness_col + 1, + }) + } +} + +fn projection_sha_project_binary_source( + col: &DenseMultilinearExtension>, + field_cfg: &::Config, +) -> Result>, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA binary source rows", + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + let rows = 
&col.evaluations[..SHA_ROW_COUNT]; + Ok(cfg_iter!(rows) + .map(|poly| { + poly.iter() + .take(SHA_WORD_BITS) + .map(|bit| { + if bit.into_inner() { + F::one_with_cfg(field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } + }) + .collect() + }) + .collect()) +} + +fn projection_sha_project_int_source( + col: &DenseMultilinearExtension, + field_cfg: &::Config, +) -> Result, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA int source rows", + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + let rows = &col.evaluations[..SHA_ROW_COUNT]; + Ok(cfg_iter!(rows) + .map(|value| F::from_with_cfg(value, field_cfg)) + .collect()) +} + +fn projection_sha_truncate_row_domain( + col: &DenseMultilinearExtension, + label: &'static str, +) -> Result, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label, + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(DenseMultilinearExtension { + evaluations: cfg_iter!(&col.evaluations[..SHA_ROW_COUNT]) + .cloned() + .collect(), + num_vars: SHA_ROW_VARS, + }) +} + +fn projection_sha_word_scalar_at_two(bits: &[F], field_cfg: &::Config) -> F { + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let mut power = F::one_with_cfg(field_cfg); + let mut value = F::zero_with_cfg(field_cfg); + for bit in bits { + value += bit.clone() * &power; + power *= &two; + } + value +} + +fn projection_sha_mle_table_from_columns(columns: Vec>) -> MleTable { + columns + .into_iter() + .map(|evaluations| DenseMultilinearExtension { + evaluations, + num_vars: SHA_ROW_VARS, + }) + .collect() +} + +fn projection_sha_flatten_bit_columns( + columns: Vec>>, +) -> MleTable { + let flattened = cfg_into_iter!(0..columns.len() * SHA_WORD_BITS) + .map(|flat_idx| { + let col_idx = flat_idx / SHA_WORD_BITS; + let bit = flat_idx % SHA_WORD_BITS; + columns[col_idx] + .iter() 
+ .map(|row_bits| row_bits[bit].clone()) + .collect::>() + }) + .collect::>(); + projection_sha_mle_table_from_columns(flattened) +} + +fn projection_sha_flatten_bit_column_refs( + columns: &[&[Vec]], +) -> MleTable { + let flattened = cfg_into_iter!(0..columns.len() * SHA_WORD_BITS) + .map(|flat_idx| { + let col_idx = flat_idx / SHA_WORD_BITS; + let bit = flat_idx % SHA_WORD_BITS; + columns[col_idx] + .iter() + .map(|row_bits| row_bits[bit].clone()) + .collect::>() + }) + .collect::>(); + projection_sha_mle_table_from_columns(flattened) +} + +fn projection_sha_scalarize_bit_slices( + bit_slices: &MleTable, + a: &F, + field_cfg: &::Config, +) -> Result, ProductionShaError> { + let powers = zinc_utils::powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + let word_count = bit_slices.len() / SHA_WORD_BITS; + let one = F::one_with_cfg(field_cfg); + let reducer = BarrettDelayedReduction::::new(field_cfg); + let words = cfg_into_iter!(0..word_count) + .map(|col_idx| { + let bit_cols = (0..SHA_WORD_BITS) + .map(|bit| { + let bit_col = &bit_slices[bit_slice_index(col_idx, bit, SHA_WORD_BITS)]; + if bit_col.num_vars != SHA_ROW_VARS + || bit_col.evaluations.len() != SHA_ROW_COUNT + { + return Err(ProductionShaError::LengthMismatch { + label: "SHA scalarized bit-slice rows", + got: bit_col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(bit_col) + }) + .collect::, ProductionShaError>>()?; + let mut out_col = Vec::with_capacity(SHA_ROW_COUNT); + for row in 0..SHA_ROW_COUNT { + out_col.push(projection_sha_scalarize_binary_row_dmr( + &bit_cols, row, &powers, &one, field_cfg, &reducer, + )); + } + Ok(out_col) + }) + .collect::, ProductionShaError>>()?; + Ok(projection_sha_mle_table_from_columns(words)) +} + +fn projection_sha_scalarize_binary_row_dmr( + bit_cols: &[&DenseMultilinearExtension], + row: usize, + powers: &[F], + one: &F, + field_cfg: &::Config, + reducer: &BarrettDelayedReduction<'_, F>, +) -> F { + let mut bucket = Uint::<5>::zero(); + let 
mut pending_adds = 0usize; + let mut acc = F::zero_with_cfg(field_cfg); + + for (bit_col, power) in bit_cols.iter().zip(powers) { + let bit = &bit_col.evaluations[row]; + if F::is_zero(bit) { + continue; + } + if bit != one { + return projection_sha_scalarize_row_naive(bit_cols, row, powers, field_cfg); + } + reducer.add(&mut bucket, power); + pending_adds = pending_adds.saturating_add(1); + if pending_adds >= reducer.flush_adds() { + let pending = std::mem::replace(&mut bucket, Uint::zero()); + acc += reducer.reduce(pending); + pending_adds = 0; + } + } + + if !bucket.is_zero() { + acc += reducer.reduce(bucket); + } + acc +} + +fn projection_sha_scalarize_row_naive( + bit_cols: &[&DenseMultilinearExtension], + row: usize, + powers: &[F], + field_cfg: &::Config, +) -> F { + let mut value = F::zero_with_cfg(field_cfg); + for (bit_col, power) in bit_cols.iter().zip(powers) { + value += bit_col.evaluations[row].clone() * power; + } + value +} + +fn projection_sha_selector_expected( + selector: ShaPublicCol, + row: usize, + field_cfg: &::Config, +) -> F { + let active = match selector { + ShaPublicCol::SInit => row < 4, + ShaPublicCol::SMsg => row < 16, + ShaPublicCol::SSched => row < 48, + ShaPublicCol::SUpd => row < 64, + ShaPublicCol::SFf => (64..68).contains(&row), + ShaPublicCol::SOut => (68..72).contains(&row), + _ => false, + }; + if active { + F::one_with_cfg(field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } +} + +fn projection_sha_k_expected(row: usize, field_cfg: &::Config) -> F { + if (3..67).contains(&row) { + F::from_with_cfg(K_CANONICAL[row - 3] as u64, field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } +} + +fn projection_sha_projected_public_from_sources( + pa_a: &[Vec], + pa_e: &[Vec], + message: &[Vec], + field_cfg: &::Config, +) -> MleTable { + let mut columns = vec![vec![F::zero_with_cfg(field_cfg); SHA_ROW_COUNT]; ShaPublicCol::COUNT]; + for row in 0..SHA_ROW_COUNT { + columns[ShaPublicCol::K.index()][row] = 
projection_sha_k_expected(row, field_cfg); + columns[ShaPublicCol::PAIn.index()][row] = + projection_sha_word_scalar_at_two(&pa_a[row], field_cfg); + columns[ShaPublicCol::PEIn.index()][row] = + projection_sha_word_scalar_at_two(&pa_e[row], field_cfg); + columns[ShaPublicCol::PAOut.index()][row] = + projection_sha_word_scalar_at_two(&pa_a[row], field_cfg); + columns[ShaPublicCol::PEOut.index()][row] = + projection_sha_word_scalar_at_two(&pa_e[row], field_cfg); + columns[ShaPublicCol::Message.index()][row] = + projection_sha_word_scalar_at_two(&message[row], field_cfg); + } + for selector in [ + ShaPublicCol::SInit, + ShaPublicCol::SMsg, + ShaPublicCol::SSched, + ShaPublicCol::SUpd, + ShaPublicCol::SFf, + ShaPublicCol::SOut, + ] { + for row in 0..SHA_ROW_COUNT { + columns[selector.index()][row] = + projection_sha_selector_expected(selector, row, field_cfg); + } + } + projection_sha_mle_table_from_columns(columns) +} + +impl ProductionShaProjectionAdapter + for ProjectionShaBenchUair +{ + fn project_production_sha_public( + _shape: &UairShape, + public_trace: &UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + field_cfg: &::Config, + ) -> Result, ProductionShaError> { + let empty_witness = UairTrace { + binary_poly: Cow::Borrowed(&[]), + arbitrary_poly: Cow::Borrowed(&[]), + int: Cow::Borrowed(&[]), + }; + let pa_a = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_A)?, + field_cfg, + )?; + let pa_e = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_E)?, + field_cfg, + )?; + let message = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_M)?, + field_cfg, + )?; + let public_columns = + projection_sha_projected_public_from_sources(&pa_a, &pa_e, &message, field_cfg); + let public_bit_columns = [ + pa_a.as_slice(), + pa_e.as_slice(), + pa_a.as_slice(), + pa_e.as_slice(), + 
message.as_slice(), + ]; + Ok(ProjectedPublic { + columns: public_columns, + bit_slices: Some(projection_sha_flatten_bit_column_refs(&public_bit_columns)), + }) + } + + fn project_production_sha_witness( + _shape: &UairShape, + public_trace: &UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + witness_trace: &UairTrace<'_, RealEcdsaInt, RealEcdsaInt, DEGREE_PLUS_ONE>, + field_cfg: &::Config, + ) -> Result< + ( + ProjectedTrace, + ProjectedPublic, + ProductionShaWitnessPolys, + ), + ProductionShaError, + > { + let word_sources = [ + sha256_cols::W_A, + sha256_cols::W_E, + sha256_cols::W_SIG0, + sha256_cols::W_SIG1, + sha256_cols::W_W, + sha256_cols::W_LSIG0, + sha256_cols::W_LSIG1, + sha256_cols::W_U_EF, + sha256_cols::W_U_NEG_E_G, + sha256_cols::W_MAJ, + sha256_cols::W_MU_PACKED, + sha256_cols::PA_OV_SIG0, + sha256_cols::PA_OV_SIG1, + sha256_cols::PA_OV_LSIG0, + sha256_cols::PA_OV_LSIG1, + sha256_cols::PA_R_CH2_COMP, + sha256_cols::PA_R_MAJ_COMP, + ]; + let int_sources = [ + sha256_cols::PA_C_C7, + sha256_cols::PA_C_C8, + sha256_cols::PA_C_C9, + sha256_cols::PA_C_FF_A, + sha256_cols::PA_C_FF_E, + ]; + + let word_cols = cfg_iter!(&word_sources) + .map(|&col| projection_sha_binary_col(public_trace, witness_trace, col)) + .collect::, _>>()?; + let int_cols = cfg_iter!(&int_sources) + .map(|&col| projection_sha_int_col(public_trace, witness_trace, col)) + .collect::, _>>()?; + + let bit_columns = cfg_iter!(&word_cols) + .map(|&col| projection_sha_project_binary_source(col, field_cfg)) + .collect::, _>>()?; + let bit_slices = projection_sha_flatten_bit_columns(bit_columns); + let scalarized = projection_sha_scalarize_bit_slices( + &bit_slices, + &F::from_with_cfg(2u64, field_cfg), + field_cfg, + )?; + let pa_a = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, witness_trace, sha256_cols::PA_A)?, + field_cfg, + )?; + let pa_e = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, witness_trace, 
sha256_cols::PA_E)?, + field_cfg, + )?; + let message = projection_sha_project_binary_source( + projection_sha_binary_col(public_trace, witness_trace, sha256_cols::PA_M)?, + field_cfg, + )?; + let public_columns = + projection_sha_projected_public_from_sources(&pa_a, &pa_e, &message, field_cfg); + let int_columns = cfg_iter!(&int_cols) + .map(|&col| projection_sha_project_int_source(col, field_cfg)) + .collect::, _>>()?; + let public_bit_columns = [ + pa_a.as_slice(), + pa_e.as_slice(), + pa_a.as_slice(), + pa_e.as_slice(), + message.as_slice(), + ]; + + let trace = ProjectedTrace { + bit_slices, + scalarized, + int_columns: projection_sha_mle_table_from_columns(int_columns), + public_columns: public_columns.clone(), + }; + let public = ProjectedPublic { + columns: public_columns, + bit_slices: Some(projection_sha_flatten_bit_column_refs(&public_bit_columns)), + }; + Ok(( + trace, + public, + ProductionShaWitnessPolys { + binary: cfg_iter!(&word_cols) + .map(|&col| { + projection_sha_truncate_row_domain( + col, + "SHA binary witness row-domain projection", + ) + }) + .collect::, _>>()?, + arbitrary: Vec::new(), + int: cfg_iter!(&int_cols) + .map(|&col| { + projection_sha_truncate_row_domain( + col, + "SHA int witness row-domain projection", + ) + }) + .collect::, _>>()?, + }, + )) + } +} + +fn projection_sha_hyrax_key_pair( + width: usize, + offset: u64, +) -> ( + zip_plus::pcs::hyrax::HyraxCommitmentKey, + zip_plus::pcs::hyrax::HyraxVerifierKey, +) +where + C: AffineRepr, + Lanes: Clone + Debug + Send + Sync, +{ + let generator = C::Group::generator(); + let bases = (0..width) + .map(|idx| { + let scalar = C::ScalarField::from( + offset + u64::try_from(idx).expect("Hyrax basis index fits u64") + 1, + ); + (generator * scalar).into_affine() + }) + .collect::>(); + let h = generator + * C::ScalarField::from(offset + u64::try_from(width).expect("Hyrax width fits u64") + 1); + HyraxPCS::::setup_from_bases_with_blinding( + width, + bases, + h, + 
HyraxBlindingMode::Unblinded, + ) + .expect("Hyrax benchmark setup must be valid") +} + +fn projection_sha_hyrax_pcs_params() -> ( + PCSParams, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>, + PCSVerifierParams, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>, +) +where + C: AffineRepr, +{ + let width = SHA_ROW_COUNT; + let (binary_ck, binary_vk) = projection_sha_hyrax_key_pair::(width, 0); + let (arbitrary_ck, arbitrary_vk) = + projection_sha_hyrax_key_pair::(width, 1_000); + let (int_ck, int_vk) = projection_sha_hyrax_key_pair::(width, 2_000); + + ( + PCSParams::, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE> { + binary: binary_ck, + arbitrary: arbitrary_ck, + int: int_ck, + }, + PCSVerifierParams::, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE> { + binary: binary_vk, + arbitrary: arbitrary_vk, + int: int_vk, + }, + ) +} + // // End-to-end benchmarks (total prove/verify time) // @@ -744,11 +1375,295 @@ fn do_bench_steps( ); step_bench!( - "Verify" / "6: Lifted evals", - setup = || v_mp_evaled.clone(), - run = |s| s.step6_lifted_evals::(), + "Verify" / "6: Lifted evals", + setup = || v_mp_evaled.clone(), + run = |s| s.step6_lifted_evals::(), + ); + + step_bench!( + "Verify" / "7: PCS verify", + setup = || v_lifted.clone(), + run = |s| s.step7_pcs_verify::(), + ); +} + +fn append_transcribable_bytes(out: &mut Vec, value: &T) { + let offset = out.len(); + out.resize(offset + T::LENGTH_NUM_BYTES + value.get_num_bytes(), 0); + let rest = value.write_transcription_bytes_subset(&mut out[offset..]); + assert!(rest.is_empty(), "transcription buffer should be exact"); +} + +fn generic_pcs_proof_raw_bytes( + proof: &Proof>, +) -> Vec +where + Zt: ZincTypes, + P: ZincPCSTypes, +{ + let mut out = Vec::new(); + <

>::BinaryPCS as PCS< + F, + BinaryPoly, + DEGREE_PLUS_ONE, + >>::write_commitment_bytes(&proof.commitments.binary, &mut out); + <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + DEGREE_PLUS_ONE, + >>::write_commitment_bytes(&proof.commitments.arbitrary, &mut out); + <

>::IntPCS as PCS< + F, + Zt::Int, + DEGREE_PLUS_ONE, + >>::write_commitment_bytes(&proof.commitments.int, &mut out); + + let zip_len = u32::try_from(proof.zip.len()).expect("zip length must fit into u32"); + out.extend_from_slice(&zip_len.to_le_bytes()); + out.extend_from_slice(&proof.zip); + append_transcribable_bytes(&mut out, &proof.ideal_check); + append_transcribable_bytes(&mut out, &proof.resolver); + append_transcribable_bytes(&mut out, &proof.combined_sumcheck); + append_transcribable_bytes(&mut out, &proof.multipoint_eval); + append_transcribable_bytes( + &mut out, + DynamicPolyVecF::reinterpret(&proof.witness_lifted_evals), + ); + out +} + +fn eprint_generic_pcs_proof_size( + label: &str, + proof: &Proof>, +) where + Zt: ZincTypes, + P: ZincPCSTypes, +{ + let raw = generic_pcs_proof_raw_bytes::(proof); + eprint_bytes_size(label, &raw); +} + +#[allow(clippy::too_many_arguments)] +fn do_bench_pcs_e2e( + group: &mut BenchmarkGroup, + label: &str, + num_vars: usize, + pp: &PCSParams, + vp: &PCSVerifierParams, + trace: &UairTrace<'static, Zt::Int, Zt::Int, DEGREE_PLUS_ONE>, + field_cfg: ::Config, + project_scalar: impl Fn(&U::Scalar, &::Config) -> DynamicPolynomialF + + Copy + + Sync, + project_ideal: impl Fn(&IdealOrZero, &::Config) -> IdealOverF + Copy, +) where + Zt: ZincTypes, + Zt::Int: ProjectableToField + num_traits::Zero, + ::Cw: ProjectableToField, + ::Eval: ProjectableToField, + ::Cw: ProjectableToField, + ::Cw: ProjectableToField, + F: FromWithConfig + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> FromWithConfig<&'a Zt::Chal> + + for<'a> FromWithConfig<&'a Zt::Pt> + + for<'a> MulByScalar<&'a F> + + FromRef + + Send + + Sync + + 'static, + F: for<'a> FromWithConfig<&'a Zt::Int>, + ::Modulus: ConstTranscribable + FromRef, + U: Uair + 'static, + IdealOverF: Ideal + IdealCheck>, + P: ZincPCSTypes, +{ + let params = format!("{label}/nvars={num_vars}"); + + macro_rules! 
zinc_plus { + () => { + ZincPlusPiop:: + }; + } + + macro_rules! bench_prove { + ($label:literal, $mle_first:expr) => { + group.bench_function(BenchmarkId::new($label, ¶ms), |bench| { + bench.iter(|| { + black_box(::prove_with_pcs_and_field_cfg::< + P, + { $mle_first }, + PERFORM_CHECKS, + >( + pp, trace, num_vars, project_scalar, field_cfg.clone() + )) + .expect("Prover failed"); + }); + }); + }; + } + + bench_prove!("Prove (Combined)", false); + if count_effective_max_degree::() <= 1 { + bench_prove!("Prove (MLE-first)", true); + } + + let proof = ::prove_with_pcs_and_field_cfg::( + pp, + trace, + num_vars, + project_scalar, + field_cfg.clone(), + ) + .expect("proof generation for verifier bench"); + + let sig = U::signature(); + let public_trace = trace.public(&sig); + + group.bench_function(BenchmarkId::new("Verify", ¶ms), |bench| { + bench.iter_batched( + || proof.clone(), + |proof| { + black_box(::verify_with_pcs_and_field_cfg::< + P, + IdealOverF, + PERFORM_CHECKS, + >( + vp, + proof, + &public_trace, + num_vars, + project_scalar, + project_ideal, + field_cfg.clone(), + )) + .expect("Verifier failed"); + }, + BatchSize::SmallInput, + ); + }); + + eprint_generic_pcs_proof_size::(¶ms, &proof); +} + +#[allow(clippy::too_many_arguments, clippy::unwrap_used)] +fn do_bench_pcs_steps( + group: &mut BenchmarkGroup, + label: &str, + num_vars: usize, + pp: &PCSParams, + vp: &PCSVerifierParams, + trace: &UairTrace<'static, Zt::Int, Zt::Int, DEGREE_PLUS_ONE>, + field_cfg: ::Config, + project_scalar: fn(&U::Scalar, &::Config) -> DynamicPolynomialF, + project_ideal: impl Fn(&IdealOrZero, &::Config) -> IdealOverF + Copy, +) where + Zt: ZincTypes, + Zt::Int: ProjectableToField + num_traits::Zero, + ::Cw: ProjectableToField, + ::Eval: ProjectableToField, + ::Cw: ProjectableToField, + ::Cw: ProjectableToField, + F: FromWithConfig + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> FromWithConfig<&'a ::CombR> + + for<'a> 
FromWithConfig<&'a Zt::Chal> + + for<'a> FromWithConfig<&'a Zt::Pt> + + for<'a> MulByScalar<&'a F> + + FromRef + + Send + + Sync + + 'static, + F: for<'a> FromWithConfig<&'a Zt::Int>, + ::Modulus: ConstTranscribable + FromRef, + U: Uair + 'static, + IdealOverF: Ideal + IdealCheck>, + P: ZincPCSTypes, +{ + let params = format!("{label}/nvars={num_vars}"); + + macro_rules! step_bench { + ($side:literal / $step_name:literal, setup = || $setup:expr, run = |$s:ident| $run:expr $(,)?) => { + group.bench_function( + BenchmarkId::new(format!("{}/{}", $side, $step_name), ¶ms), + |b| { + b.iter_batched( + || $setup, + |$s| { + black_box($run).expect("step failed"); + }, + BatchSize::SmallInput, + ); + }, + ); + }; + } + + macro_rules! piop { + () => { + ZincPlusPiop:: + }; + } + + let p_committed = ::step0_commit_with_pcs::

(pp, trace, num_vars).unwrap(); + let p_projected = p_committed + .clone() + .step1_combined_with_field_cfg(project_scalar, field_cfg.clone()) + .unwrap(); + let p_ideal_checked = p_projected.clone().step2_ideal_check().unwrap(); + let p_eval_projected = p_ideal_checked.clone().step3_eval_projection().unwrap(); + let p_sumchecked = p_eval_projected.clone().step4_sumcheck().unwrap(); + let p_mp_evaled = p_sumchecked.clone().step5_multipoint_eval().unwrap(); + let p_lifted = p_mp_evaled.clone().step6_lift_and_project().unwrap(); + + step_bench!( + "Prove" / "0: Commit", + setup = || {}, + run = |_s| ::step0_commit_with_pcs::

(pp, trace, num_vars), + ); + + step_bench!( + "Prove" / "7: PCS open", + setup = || p_lifted.clone(), + run = |s| s.step7_pcs_open::(), ); + let proof = ::prove_with_pcs_and_field_cfg::( + pp, + trace, + num_vars, + project_scalar, + field_cfg.clone(), + ) + .expect("proof generation for verifier bench"); + let sig = U::signature(); + let public_trace = trace.public(&sig); + let v_transcript = ::step0_reconstruct_transcript_with_pcs::( + vp, + proof, + &public_trace, + num_vars, + ) + .unwrap(); + let v_prime_projected = v_transcript + .clone() + .step1_prime_projection_with_field_cfg(field_cfg.clone()) + .unwrap(); + let v_ideal_checked = v_prime_projected + .clone() + .step2_ideal_check(project_ideal) + .unwrap(); + let v_eval_projected = v_ideal_checked + .clone() + .step3_eval_projection(project_scalar) + .unwrap(); + let v_sumchecked = v_eval_projected.clone().step4_sumcheck_verify().unwrap(); + let v_mp_evaled = v_sumchecked.clone().step5_multipoint_eval::().unwrap(); + let v_lifted = v_mp_evaled.clone().step6_lifted_evals::().unwrap(); + step_bench!( "Verify" / "7: PCS verify", setup = || v_lifted.clone(), @@ -839,11 +1754,10 @@ fn bench_real_ecdsa_e2e(group: &mut BenchmarkGroup, num_vars: usize) { let trace = U::generate_random_trace(num_vars, &mut rng); let pp = setup_pp_real_ecdsa(num_vars); - let proj_ideal = |_: &IdealOrZero<::Ideal>, - _: &::Config| - -> ImpossibleIdeal { - unreachable!("EcdsaUair has only assert_zero constraints") - }; + let proj_ideal = + |_: &IdealOrZero<::Ideal>, _: &::Config| -> ImpossibleIdeal { + unreachable!("EcdsaUair has only assert_zero constraints") + }; do_bench_e2e::( group, @@ -863,11 +1777,10 @@ fn bench_real_ecdsa_steps(group: &mut BenchmarkGroup, num_vars: usize) let trace = U::generate_random_trace(num_vars, &mut rng); let pp = setup_pp_real_ecdsa(num_vars); - let proj_ideal = |_: &IdealOrZero<::Ideal>, - _: &::Config| - -> ImpossibleIdeal { - unreachable!("EcdsaUair has only assert_zero constraints") - }; + let 
proj_ideal = + |_: &IdealOrZero<::Ideal>, _: &::Config| -> ImpossibleIdeal { + unreachable!("EcdsaUair has only assert_zero constraints") + }; do_bench_steps::( group, @@ -883,13 +1796,12 @@ fn bench_real_ecdsa_steps(group: &mut BenchmarkGroup, num_vars: usize) fn bench_real_sha256_e2e(group: &mut BenchmarkGroup, num_vars: usize) { type U = Sha256CompressionSliceUair; - let mut rng = rng(); - let trace = U::generate_random_trace(num_vars, &mut rng); + let trace = real_sha256_chain_trace(num_vars); let pp = setup_pp_real_ecdsa(num_vars); do_bench_e2e::( group, - "RealSha256", + "RealSha256Chain8", num_vars, &pp, &trace, @@ -901,13 +1813,12 @@ fn bench_real_sha256_e2e(group: &mut BenchmarkGroup, num_vars: usize) fn bench_real_sha256_steps(group: &mut BenchmarkGroup, num_vars: usize) { type U = Sha256CompressionSliceUair; - let mut rng = rng(); - let trace = U::generate_random_trace(num_vars, &mut rng); + let trace = real_sha256_chain_trace(num_vars); let pp = setup_pp_real_ecdsa(num_vars); do_bench_steps::( group, - "RealSha256", + "RealSha256Chain8", num_vars, &pp, &trace, @@ -916,6 +1827,404 @@ fn bench_real_sha256_steps(group: &mut BenchmarkGroup, num_vars: usize ); } +fn zip_pcs_params( + num_vars: usize, +) -> ( + PCSParams, + PCSVerifierParams, +) { + let pp = setup_pp_real_ecdsa(num_vars); + ( + PCSParams:: { + binary: pp.0.clone(), + arbitrary: pp.1.clone(), + int: pp.2.clone(), + }, + PCSVerifierParams:: { + binary: pp.0, + arbitrary: pp.1, + int: pp.2, + }, + ) +} + +fn hyrax_pcs_params( + num_vars: usize, +) -> ( + PCSParams, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>, + PCSVerifierParams, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>, +) +where + BinaryIntHyraxZipArbitrary: ZincPCSTypes< + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + BinaryPCS = HyraxPCS, + ArbitraryPCS = ZipPlusPCS< + >::ArbitraryZt, + >::ArbitraryLc, + >, + IntPCS = HyraxPCS, + >, +{ + let pp = setup_pp_real_ecdsa(num_vars); + let binary_width = pp.0.linear_code.row_len(); + let 
int_width = pp.2.linear_code.row_len(); + let (binary, binary_vk) = HyraxPCS::::setup( + binary_width, + b"zinc-plus-bench-real-sha256-hyrax-binary", + HyraxBlindingMode::Unblinded, + ) + .expect("Hyrax binary benchmark setup must be valid"); + let (int, int_vk) = HyraxPCS::::setup( + int_width, + b"zinc-plus-bench-real-sha256-hyrax-int", + HyraxBlindingMode::Unblinded, + ) + .expect("Hyrax int benchmark setup must be valid"); + ( + PCSParams::, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE> { + binary, + arbitrary: pp.1.clone(), + int, + }, + PCSVerifierParams::< + BinaryIntHyraxZipArbitrary, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + > { + binary: binary_vk, + arbitrary: pp.1, + int: int_vk, + }, + ) +} + +fn bench_real_sha256_pcs_curve_e2e( + group: &mut BenchmarkGroup, + num_vars: usize, + zip_label: &str, + hyrax_label: &str, +) where + BinaryIntHyraxZipArbitrary: ZincPCSTypes< + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + BinaryPCS = HyraxPCS, + ArbitraryPCS = ZipPlusPCS< + >::ArbitraryZt, + >::ArbitraryLc, + >, + IntPCS = HyraxPCS, + >, +{ + type U = Sha256CompressionSliceUair; + + let trace = real_sha256_chain_trace(num_vars); + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + C, + >(); + + let (zip_pp, zip_vp) = zip_pcs_params(num_vars); + do_bench_pcs_e2e::( + group, + zip_label, + num_vars, + &zip_pp, + &zip_vp, + &trace, + field_cfg.clone(), + zinc_protocol::project_scalar_fn, + sha256_real_project_ideal, + ); + + let (hyrax_pp, hyrax_vp) = hyrax_pcs_params::(num_vars); + do_bench_pcs_e2e::>( + group, + hyrax_label, + num_vars, + &hyrax_pp, + &hyrax_vp, + &trace, + field_cfg, + zinc_protocol::project_scalar_fn, + sha256_real_project_ideal, + ); +} + +fn bench_real_sha256_pcs_curve_steps( + group: &mut BenchmarkGroup, + num_vars: usize, + zip_label: &str, + hyrax_label: &str, +) where + BinaryIntHyraxZipArbitrary: ZincPCSTypes< + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + BinaryPCS = HyraxPCS, + ArbitraryPCS = 
ZipPlusPCS< + >::ArbitraryZt, + >::ArbitraryLc, + >, + IntPCS = HyraxPCS, + >, +{ + type U = Sha256CompressionSliceUair; + + let trace = real_sha256_chain_trace(num_vars); + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + C, + >(); + + let (zip_pp, zip_vp) = zip_pcs_params(num_vars); + do_bench_pcs_steps::( + group, + zip_label, + num_vars, + &zip_pp, + &zip_vp, + &trace, + field_cfg.clone(), + zinc_protocol::project_scalar_fn, + sha256_real_project_ideal, + ); + + let (hyrax_pp, hyrax_vp) = hyrax_pcs_params::(num_vars); + do_bench_pcs_steps::>( + group, + hyrax_label, + num_vars, + &hyrax_pp, + &hyrax_vp, + &trace, + field_cfg, + zinc_protocol::project_scalar_fn, + sha256_real_project_ideal, + ); +} + +fn bench_real_sha256_pcs_e2e(group: &mut BenchmarkGroup, num_vars: usize) { + bench_real_sha256_pcs_curve_e2e::( + group, + num_vars, + "RealSha256Chain8PCS/ZipBn254Fr", + "RealSha256Chain8PCS/HyraxBn254Unblinded", + ); + bench_real_sha256_pcs_curve_e2e::( + group, + num_vars, + "RealSha256Chain8PCS/ZipSecp256k1Fr", + "RealSha256Chain8PCS/HyraxSecp256k1Unblinded", + ); +} + +fn bench_real_sha256_pcs_steps(group: &mut BenchmarkGroup, num_vars: usize) { + bench_real_sha256_pcs_curve_steps::( + group, + num_vars, + "RealSha256Chain8PCS/ZipBn254Fr", + "RealSha256Chain8PCS/HyraxBn254Unblinded", + ); + bench_real_sha256_pcs_curve_steps::( + group, + num_vars, + "RealSha256Chain8PCS/ZipSecp256k1Fr", + "RealSha256Chain8PCS/HyraxSecp256k1Unblinded", + ); +} + +fn bench_og_sha256_zip_compare(group: &mut BenchmarkGroup, num_vars: usize) { + type U = Sha256CompressionSliceUair; + + let trace = real_sha256_chain_trace(num_vars); + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_bn254::G1Affine, + >(); + let (zip_pp, zip_vp) = zip_pcs_params(num_vars); + + do_bench_pcs_e2e::( + group, + "OG-ZincPlus-ZipBn254/SHA256Chain8", + num_vars, + &zip_pp, + &zip_vp, + &trace, + field_cfg, + zinc_protocol::project_scalar_fn, + sha256_real_project_ideal, 
+ ); +} + +fn bench_projectionfold_sha256_concise_hyrax(group: &mut BenchmarkGroup, label: &str) +where + C: AffineRepr + Send + Sync + 'static, + F: zip_plus::pcs::hyrax::HyraxFieldBridge, +{ + type P = AllHyraxPCSTypes; + type U = ProjectionShaBenchUair; + + let message_blocks = real_sha256_chain_blocks(); + let (_mono_trace, mono_final_state) = + synthesize_sha256_chain_trace::( + REAL_SHA256_CHAIN_NUM_VARS, + SHA256_INITIAL_STATE, + message_blocks, + ) + .expect("monolithic N=8 SHA trace synthesis should succeed"); + let (witnesses, projection_final_state) = synthesize_sha256_chain_witnesses::< + RealEcdsaInt, + REAL_SHA256_CHAIN_BLOCKS, + >(SHA256_INITIAL_STATE, message_blocks) + .expect("ProjectionFold SHA witness synthesis should succeed"); + assert_eq!(mono_final_state, projection_final_state); + + let shape = UairShape::::new(SHA_ROW_VARS); + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + C, + >(); + let (pcs_params, pcs_verifier_params) = projection_sha_hyrax_pcs_params::(); + let pp = + LinearIdealFoldProverParams::, U, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>::new( + pcs_params, + field_cfg.clone(), + 3, + ); + let vs = + setup_verify_linear_ideal_fold::, U, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>( + LinearIdealFoldVerifierParams::new(pcs_verifier_params, field_cfg), + shape.clone(), + ) + .expect("ProjectionFold SHA verifier setup succeeds"); + + let params = format!("{label}/SHA256Chain8/row-vars={SHA_ROW_VARS}"); + let prepared_instances = prepare_linear_ideal_fold_witnesses::< + U, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + >(&shape, &witnesses, &pp.field_cfg) + .expect("ProjectionFold SHA witness preparation should succeed"); + + group.bench_function(BenchmarkId::new("Prove", ¶ms), |bench| { + bench.iter(|| { + let mut transcript = Blake3Transcript::new(); + black_box(prove_prepared_linear_ideal_fold::< + P, + U, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + >(&pp, &shape, &prepared_instances, &mut 
transcript)) + .expect("ProjectionFold Concise prover failed"); + }); + }); + + let mut prover_transcript = Blake3Transcript::new(); + let output = prove_prepared_linear_ideal_fold::< + P, + U, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + >( + &pp, + &shape, + &prepared_instances, + &mut prover_transcript, + ) + .expect("proof generation for ProjectionFold verifier bench"); + + let mut verifier_transcript = Blake3Transcript::new(); + let verified = + verify_linear_ideal_fold::, U, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>( + &vs, + &output.fresh_instances, + &output.proof, + &mut verifier_transcript, + ) + .expect("ProjectionFold verifier preflight failed"); + assert_eq!(verified.target, output.folded_instance.target); + assert_eq!(verified.public, output.folded_instance.public); + + eprintln!(" ProjectionFold Concise tracing ({params}):"); + let subscriber = tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_target(true) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE) + .finish(); + tracing::subscriber::with_default(subscriber, || { + let mut prover_transcript = Blake3Transcript::new(); + let traced_output = prove_prepared_linear_ideal_fold::< + P, + U, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + >( + &pp, + &shape, + &prepared_instances, + &mut prover_transcript, + ) + .expect("ProjectionFold traced prover failed"); + + let mut verifier_transcript = Blake3Transcript::new(); + let traced_verified = + verify_linear_ideal_fold::, U, RealEcdsaBenchZincTypes, F, DEGREE_PLUS_ONE>( + &vs, + &traced_output.fresh_instances, + &traced_output.proof, + &mut verifier_transcript, + ) + .expect("ProjectionFold traced verifier failed"); + assert_eq!(traced_verified.target, traced_output.folded_instance.target); + assert_eq!(traced_verified.public, traced_output.folded_instance.public); + }); + + group.bench_function(BenchmarkId::new("Verify", ¶ms), |bench| { + bench.iter(|| { + let mut transcript = Blake3Transcript::new(); 
+ black_box(verify_linear_ideal_fold::< + P, + U, + RealEcdsaBenchZincTypes, + F, + DEGREE_PLUS_ONE, + >( + &vs, + &output.fresh_instances, + &output.proof, + &mut transcript, + )) + .expect("ProjectionFold Concise verifier failed"); + }); + }); +} + +fn bench_projectionfold_sha256_concise_hyrax_bn254(group: &mut BenchmarkGroup) { + bench_projectionfold_sha256_concise_hyrax::( + group, + "ProjectionFoldConcise-HyraxBn254", + ); +} + +fn bench_projectionfold_sha256_concise_hyrax_secp256k1(group: &mut BenchmarkGroup) { + bench_projectionfold_sha256_concise_hyrax::( + group, + "ProjectionFoldConcise-HyraxSecp256k1", + ); +} + fn bench_real_sha_ecdsa_e2e(group: &mut BenchmarkGroup, num_vars: usize) { type U = ShaEcdsaUair; @@ -994,19 +2303,19 @@ fn e2e_benches(c: &mut Criterion) { // bench_no_mult_e2e(&mut group, 8); // bench_no_mult_e2e(&mut group, 10); // bench_no_mult_e2e(&mut group, 12); -// + // // bench_binary_decomposition_e2e(&mut group, 8); // bench_binary_decomposition_e2e(&mut group, 10); // bench_binary_decomposition_e2e(&mut group, 12); -// + // // bench_big_linear_e2e(&mut group, 8); // bench_big_linear_e2e(&mut group, 10); // bench_big_linear_e2e(&mut group, 12); -// + // // bench_big_linear_public_input_e2e(&mut group, 8); // bench_big_linear_public_input_e2e(&mut group, 10); // bench_big_linear_public_input_e2e(&mut group, 12); -// + // // bench_sha_proxy_e2e(&mut group, 8); // bench_sha_proxy_e2e(&mut group, 10); // bench_sha_proxy_e2e(&mut group, 12); @@ -1014,7 +2323,8 @@ fn e2e_benches(c: &mut Criterion) { // Real UAIRs ported from main-gamma. Trace size for ECDSA needs >= 256 // rows (Shamir loop), so num_vars=9 is the smallest meaningful size. 
// bench_real_ecdsa_e2e(&mut group, 9); - bench_real_sha256_e2e(&mut group, 9); + bench_real_sha256_e2e(&mut group, REAL_SHA256_CHAIN_NUM_VARS); + bench_real_sha256_pcs_e2e(&mut group, REAL_SHA256_CHAIN_NUM_VARS); bench_real_sha_ecdsa_e2e(&mut group, 9); group.finish(); @@ -1026,19 +2336,19 @@ fn e2e_steps_benches(c: &mut Criterion) { // bench_no_mult_steps(&mut group, 8); // bench_no_mult_steps(&mut group, 10); // bench_no_mult_steps(&mut group, 12); -// + // // bench_binary_decomposition_steps(&mut group, 8); // bench_binary_decomposition_steps(&mut group, 10); // bench_binary_decomposition_steps(&mut group, 12); -// + // // bench_big_linear_steps(&mut group, 8); // bench_big_linear_steps(&mut group, 10); // bench_big_linear_steps(&mut group, 12); -// + // // bench_big_linear_public_input_steps(&mut group, 8); // bench_big_linear_public_input_steps(&mut group, 10); // bench_big_linear_public_input_steps(&mut group, 12); -// + // // bench_sha_proxy_steps(&mut group, 8); // bench_sha_proxy_steps(&mut group, 10); // bench_sha_proxy_steps(&mut group, 12); @@ -1046,12 +2356,23 @@ fn e2e_steps_benches(c: &mut Criterion) { // Real UAIRs ported from main-gamma. See `e2e_benches` for the // num_vars=9 lower-bound rationale. bench_real_ecdsa_steps(&mut group, 9); - bench_real_sha256_steps(&mut group, 9); + bench_real_sha256_steps(&mut group, REAL_SHA256_CHAIN_NUM_VARS); + bench_real_sha256_pcs_steps(&mut group, REAL_SHA256_CHAIN_NUM_VARS); bench_real_sha_ecdsa_steps(&mut group, 9); group.finish(); } +fn sha256_proving_system_compare_benches(c: &mut Criterion) { + let mut group = c.benchmark_group("SHA-256 Proving System Comparison"); + + bench_og_sha256_zip_compare(&mut group, REAL_SHA256_CHAIN_NUM_VARS); + bench_projectionfold_sha256_concise_hyrax_bn254(&mut group); + bench_projectionfold_sha256_concise_hyrax_secp256k1(&mut group); + + group.finish(); +} + // // Folded Zip+ (1× fold) — total prove/verify benchmark. 
// @@ -1221,7 +2542,6 @@ type FoldedPp4x = ( >, ); - /// 4× folded e2e bench: routes binary AND int through `MultiZip3` for /// shared-Merkle collapse, then opens at the doubly-extended point /// `(r_0 ‖ γ₁ ‖ γ₂)`. Calls [`prove_folded_4x`] / [`verify_folded_4x`]. @@ -1314,39 +2634,57 @@ fn do_bench_e2e_folded_4x( let sig = U::signature(); let public_trace = trace.public(&sig); - group.bench_function( - BenchmarkId::new("Verify (folded 4×)", ¶ms), - |bench| { - bench.iter_batched( - || proof.clone(), - |proof| { - black_box( - zinc_protocol::verifier::verify_folded_4x::< - ZtF, - U, - F, - IdealOverF, - DEGREE_PLUS_ONE, - HALF_DEGREE_PLUS_ONE, - QUARTER_DEGREE_PLUS_ONE, - EC_FP_INT_LIMBS, - INT_QUARTER_LIMBS_BENCH, - PERFORM_CHECKS, - >( - pp, - proof, - &public_trace, - num_vars, - project_scalar, - project_ideal, - ), - ) - .expect("Folded 4× verifier failed"); - }, - BatchSize::SmallInput, - ); - }, - ); + eprintln!(" Folded 4× tracing ({params}):"); + let subscriber = tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_target(true) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE) + .finish(); + tracing::subscriber::with_default(subscriber, || { + let (_traced_proof, _traced_timings) = + zinc_protocol::prover::prove_folded_4x_with_timings::< + ZtF, + U, + F, + DEGREE_PLUS_ONE, + HALF_DEGREE_PLUS_ONE, + QUARTER_DEGREE_PLUS_ONE, + EC_FP_INT_LIMBS, + INT_QUARTER_LIMBS_BENCH, + false, + PERFORM_CHECKS, + >(pp, trace, num_vars, project_scalar) + .expect("Folded 4× traced prover failed"); + }); + + group.bench_function(BenchmarkId::new("Verify (folded 4×)", ¶ms), |bench| { + bench.iter_batched( + || proof.clone(), + |proof| { + black_box(zinc_protocol::verifier::verify_folded_4x::< + ZtF, + U, + F, + IdealOverF, + DEGREE_PLUS_ONE, + HALF_DEGREE_PLUS_ONE, + QUARTER_DEGREE_PLUS_ONE, + EC_FP_INT_LIMBS, + INT_QUARTER_LIMBS_BENCH, + PERFORM_CHECKS, + >( + pp, + proof, + &public_trace, + num_vars, + project_scalar, + project_ideal, + )) + 
.expect("Folded 4× verifier failed"); + }, + BatchSize::SmallInput, + ); + }); let label_full = format!("Folded 4×/{params}"); eprint_proof_size(&label_full, &proof); @@ -1436,7 +2774,7 @@ fn eprint_folded_4x_per_region_prove_timings( > + 'static, S: Fn(&U::Scalar, &::Config) -> DynamicPolynomialF + Copy + Sync, { - use zinc_protocol::prover::{prove_folded_4x_with_timings, FoldedProveTimings}; + use zinc_protocol::prover::{FoldedProveTimings, prove_folded_4x_with_timings}; const N: u32 = 100; @@ -1580,9 +2918,7 @@ fn eprint_folded_4x_per_region_verify_timings( S: Fn(&U::Scalar, &::Config) -> DynamicPolynomialF + Copy + Sync, I: Fn(&IdealOrZero, &::Config) -> IdealOverF + Copy, { - use zinc_protocol::verifier::{ - verify_folded_4x_with_timings, FoldedVerifyTimings, - }; + use zinc_protocol::verifier::{FoldedVerifyTimings, verify_folded_4x_with_timings}; const N: u32 = 100; @@ -1685,7 +3021,6 @@ fn eprint_folded_4x_per_region_verify_timings( ); } - /// Serialize each `Proof` component into its own byte buffer and report /// per-part raw + zstd-compressed sizes, so we can see how much each part /// of the proof contributes to the total size. Sizes match the per-field @@ -1705,8 +3040,9 @@ where } // 3 commitments concatenated (each ConstTranscribable, no length prefix). - let mut commits = - Vec::with_capacity(3_usize.saturating_mul(::NUM_BYTES)); + let mut commits = Vec::with_capacity( + 3_usize.saturating_mul(::NUM_BYTES), + ); commits.extend_from_slice(&to_bytes(&proof.commitments.0)); commits.extend_from_slice(&to_bytes(&proof.commitments.1)); commits.extend_from_slice(&to_bytes(&proof.commitments.2)); @@ -1879,8 +3215,6 @@ fn eprint_folded_4x_zip_substep_breakdown( ); } - - // // Real-UAIR folded benches (1× and 4×). 
These reuse the generic // `do_bench_e2e_folded` / `do_bench_e2e_folded_4x` helpers above with @@ -1908,13 +3242,7 @@ impl FoldedZincTypes for BenchFoldedRealE Int<5>, DensePolynomial, HALF_DEGREE_PLUS_ONE>, BinaryPolyInnerProduct, - DensePolyInnerProduct< - Int<5>, - Self::Chal, - Int<5>, - MBSInnerProduct, - HALF_DEGREE_PLUS_ONE, - >, + DensePolyInnerProduct, Self::Chal, Int<5>, MBSInnerProduct, HALF_DEGREE_PLUS_ONE>, MBSInnerProduct, >; @@ -1926,7 +3254,6 @@ impl FoldedZincTypes for BenchFoldedRealE type IntLc = >::IntLc; } - // // 4× int-fold variant of the bench Zinc-types. Implements // `IntFoldedZincTypes4x` so that `prove_folded_4x` / @@ -1994,13 +3321,7 @@ impl Int<5>, DensePolynomial, QUARTER_DEGREE_PLUS_ONE>, BinaryPolyInnerProduct, - DensePolyInnerProduct< - Int<5>, - Self::Chal, - Int<5>, - MBSInnerProduct, - QUARTER_DEGREE_PLUS_ONE, - >, + DensePolyInnerProduct, Self::Chal, Int<5>, MBSInnerProduct, QUARTER_DEGREE_PLUS_ONE>, MBSInnerProduct, >; type ArbitraryZt = >::ArbitraryZt; @@ -2085,11 +3406,10 @@ fn bench_real_ecdsa_e2e_folded(group: &mut BenchmarkGroup, num_vars: u let trace = U::generate_random_trace(num_vars, &mut rng); let pp = setup_folded_pp_real_ecdsa(num_vars); - let proj_ideal = |_: &IdealOrZero<::Ideal>, - _: &::Config| - -> ImpossibleIdeal { - unreachable!("EcdsaUair has only assert_zero constraints") - }; + let proj_ideal = + |_: &IdealOrZero<::Ideal>, _: &::Config| -> ImpossibleIdeal { + unreachable!("EcdsaUair has only assert_zero constraints") + }; do_bench_e2e_folded::( group, @@ -2105,13 +3425,12 @@ fn bench_real_ecdsa_e2e_folded(group: &mut BenchmarkGroup, num_vars: u fn bench_real_sha256_e2e_folded(group: &mut BenchmarkGroup, num_vars: usize) { type U = Sha256CompressionSliceUair; - let mut rng = rng(); - let trace = U::generate_random_trace(num_vars, &mut rng); + let trace = real_sha256_chain_trace(num_vars); let pp = setup_folded_pp_real_ecdsa(num_vars); do_bench_e2e_folded::( group, - "RealSha256", + "RealSha256Chain8", 
num_vars, &pp, &trace, @@ -2138,14 +3457,10 @@ fn bench_real_sha_ecdsa_e2e_folded(group: &mut BenchmarkGroup, num_var ); } - /// ShaEcdsa 4× folded: binary AND int both quartered /// (BinaryPoly<8> / Int<2>) and committed under one Merkle tree /// via `MultiZip3`. One Merkle path per opening instead of three. -fn bench_real_sha_ecdsa_e2e_folded_4x( - group: &mut BenchmarkGroup, - num_vars: usize, -) { +fn bench_real_sha_ecdsa_e2e_folded_4x(group: &mut BenchmarkGroup, num_vars: usize) { type U = ShaEcdsaUair; let mut rng = rng(); @@ -2181,7 +3496,7 @@ fn e2e_folded_benches(c: &mut Criterion) { // bench_sha_proxy_e2e_folded(&mut group, 12); bench_real_ecdsa_e2e_folded(&mut group, 9); - bench_real_sha256_e2e_folded(&mut group, 9); + bench_real_sha256_e2e_folded(&mut group, REAL_SHA256_CHAIN_NUM_VARS); bench_real_sha_ecdsa_e2e_folded(&mut group, 9); group.finish(); @@ -2223,7 +3538,6 @@ fn print_peak_rss(label: &str) { eprintln!("[{label}] peak RSS: {bytes} B ({mib:.2} MiB / {gib:.3} GiB)"); } - criterion_group! { name = e2e; config = Criterion::default().sample_size(500); @@ -2234,6 +3548,11 @@ criterion_group! { config = Criterion::default().sample_size(100); targets = e2e_steps_benches } +criterion_group! { + name = sha256_compare; + config = Criterion::default().sample_size(20); + targets = sha256_proving_system_compare_benches +} criterion_group! { name = e2e_folded; config = Criterion::default().sample_size(500); @@ -2244,4 +3563,4 @@ criterion_group! { config = Criterion::default().sample_size(500); targets = e2e_folded_4x_benches } -criterion_main!(e2e, e2e_steps, e2e_folded, e2e_folded_4x); +criterion_main!(e2e, e2e_steps, sha256_compare, e2e_folded, e2e_folded_4x); diff --git a/protocol/src/fixed_prime.rs b/protocol/src/fixed_prime.rs index eb11da95..7bca1dd2 100644 --- a/protocol/src/fixed_prime.rs +++ b/protocol/src/fixed_prime.rs @@ -14,6 +14,8 @@ //! are honest mod `p`). Do not reuse this branch for other applications //! 
without re-doing the soundness analysis. +use ark_ec::AffineRepr; +use ark_ff::{BigInteger, PrimeField as ArkPrimeField}; use crypto_primitives::PrimeField; use zinc_transcript::traits::ConstTranscribable; use zinc_utils::from_ref::FromRef; @@ -24,10 +26,8 @@ use zinc_utils::from_ref::FromRef; /// as little-endian limb chunks (see `transcript::traits` impl), so we /// store the prime in that same order. pub const SECP256K1_P_LE_BYTES: [u8; 32] = [ - 0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, ]; /// Build `F::Config` from the secp256k1 base prime, replacing the @@ -50,8 +50,39 @@ where "Fmod must be exactly 256 bits to hold the secp256k1 base prime", ); let prime = FMod::read_transcription_bytes_exact(&SECP256K1_P_LE_BYTES); + F::make_cfg(&F::Modulus::from_ref(&prime)).expect("secp256k1 base field prime is prime") +} + +/// Build `F::Config` from the scalar field of an arkworks curve. +/// +/// This is the field configuration Hyrax must use: PCS scalar operations +/// happen in `C::ScalarField`, so the PIOP field modulus must match it. 
+pub fn field_cfg_from_curve_scalar() -> F::Config +where + F: PrimeField, + FMod: ConstTranscribable, + F::Modulus: FromRef, + C: AffineRepr, +{ + let prime = fmod_from_curve_scalar::(); F::make_cfg(&F::Modulus::from_ref(&prime)) - .expect("secp256k1 base field prime is prime") + .expect("curve scalar modulus must define a valid prime field") +} + +fn fmod_from_curve_scalar() -> FMod +where + FMod: ConstTranscribable, + C: AffineRepr, +{ + let modulus_bytes = ::MODULUS.to_bytes_le(); + assert!( + modulus_bytes.len() <= FMod::NUM_BYTES, + "curve scalar modulus does not fit in the protocol modulus type", + ); + + let mut bytes = vec![0u8; FMod::NUM_BYTES]; + bytes[..modulus_bytes.len()].copy_from_slice(&modulus_bytes); + FMod::read_transcription_bytes_exact(&bytes) } #[cfg(test)] @@ -83,4 +114,23 @@ mod tests { fn secp256k1_field_cfg_constructs() { let _cfg = secp256k1_field_cfg::, Uint<4>>(); } + + #[test] + fn curve_scalar_field_cfg_constructs() { + let bn_cfg = field_cfg_from_curve_scalar::, Uint<4>, ark_bn254::G1Affine>(); + let secp_cfg = + field_cfg_from_curve_scalar::, Uint<4>, ark_secp256k1::Affine>(); + + let bn_modulus = MontyField::<4>::one_with_cfg(&bn_cfg).modulus(); + let secp_modulus = MontyField::<4>::one_with_cfg(&secp_cfg).modulus(); + + assert_eq!( + bn_modulus, + fmod_from_curve_scalar::, ark_bn254::G1Affine>(), + ); + assert_eq!( + secp_modulus, + fmod_from_curve_scalar::, ark_secp256k1::Affine>(), + ); + } } diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 22dd910e..0ae6c9be 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -18,13 +18,18 @@ //! 
- Step 7: Zip+ PCS open/verify at r_0 pub mod fixed_prime; +mod multipoint_reduction; +pub mod pcs; +pub mod production_sha; pub mod prover; pub mod verifier; #[cfg(feature = "parallel")] use rayon::prelude::*; -use crypto_primitives::{ConstIntRing, ConstIntSemiring, FromWithConfig, PrimeField, Semiring}; +use crypto_primitives::{ + ConstIntRing, ConstIntSemiring, FromWithConfig, PrimeField, Semiring, crypto_bigint_uint::Uint, +}; use std::{fmt::Debug, marker::PhantomData}; use thiserror::Error; use zinc_piop::{ @@ -47,7 +52,13 @@ use zinc_poly::{ use zinc_primality::PrimalityTest; use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript}; use zinc_uair::{Uair, ideal::Ideal}; -use zinc_utils::{cfg_extend, cfg_into_iter, cfg_iter, named::Named}; +use zinc_utils::{ + cfg_extend, cfg_into_iter, cfg_iter, + delayed_reduction::{ + BarrettDelayedReduction, DelayedModularReductionAlgorithm, MontgomeryLimbs, + }, + named::Named, +}; use zip_plus::{ ZipError, code::LinearCode, @@ -60,9 +71,9 @@ use zip_plus::{ /// Full proof produced by the Zinc+ PIOP for UCS. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Proof { +pub struct Proof { /// Zip+ commitments to the witness columns. - pub commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + pub commitments: Commitments, /// Serialized PCS proof data (Zip+ proving transcripts). pub zip: Vec, /// Randomized ideal check proof. @@ -472,6 +483,8 @@ pub enum ProtocolError { Pcs(#[from] ZipError), #[error("PCS verification failed at column {0}: {1}")] PcsVerification(usize, ZipError), + #[error("PCS proof has trailing bytes: consumed {consumed} of {total}")] + PcsProofTrailingBytes { consumed: usize, total: usize }, } // @@ -505,12 +518,16 @@ fn absorb_public_columns( /// Binary columns exploit the 0/1 structure for conditional additions only. /// The `eq(point, *)` table is built once and reused across all columns. 
#[allow(clippy::arithmetic_side_effects)] -fn compute_lifted_evals( +fn compute_lifted_evals( point: &[F], trace_bin_poly: &[DenseMultilinearExtension>], projected_trace: &ProjectedTrace, field_cfg: &F::Config, -) -> Vec> { +) -> Vec> +where + F: PrimeField + MontgomeryLimbs + Send + Sync, + F::Config: Sync, +{ compute_lifted_evals_capped::(point, trace_bin_poly, projected_trace, field_cfg, None) } @@ -524,46 +541,55 @@ fn compute_lifted_evals( /// section here is wasted work. Pass `Some(num_total_arb_cols)` to stop /// the non-binary iter right after arbitrary cols. #[allow(clippy::arithmetic_side_effects)] -pub fn compute_lifted_evals_capped( +pub fn compute_lifted_evals_capped( point: &[F], trace_bin_poly: &[DenseMultilinearExtension>], projected_trace: &ProjectedTrace, field_cfg: &F::Config, non_binary_cap: Option, -) -> Vec> { +) -> Vec> +where + F: PrimeField + MontgomeryLimbs + Send + Sync, + F::Config: Sync, +{ let eq_table = zinc_poly::utils::build_eq_x_r_vec(point, field_cfg) .expect("compute_lifted_evals: eq table build failed"); + let reducer = BarrettDelayedReduction::::new(field_cfg); let n_bin = trace_bin_poly.len(); let zero = F::zero_with_cfg(field_cfg); // Binary columns: exploit 0/1 structure for conditional additions. - // Pack each entry's up-to-64 boolean coefficients into a u64 so we - // can (a) skip entries that are identically zero, and (b) walk only - // the SET bits via `trailing_zeros` + Brian Kernighan's clear-lowest - // instead of branching on every slot. - debug_assert!(D <= 64, "compute_lifted_evals: bitmask packing assumes D <= 64"); + // For D <= 64, pack each row into a bitmask so zero rows are skipped + // and set bits are walked directly. Each set bit adds cached eq[b] into + // a DMR accumulator, reducing once per output coefficient. 
let mut result: Vec> = cfg_iter!(trace_bin_poly) .map(|col| { - let mut coeffs = vec![zero.clone(); D]; + let mut coeffs = vec![Uint::<5>::default(); D]; for (b, entry) in col.iter().enumerate() { - let mut bits: u64 = 0; - for (l, coeff) in entry.iter().enumerate().take(D) { - if coeff.into_inner() { - bits |= 1u64 << l; - } - } - if bits == 0 { - continue; - } let eq_b = &eq_table[b]; - let mut remaining = bits; - while remaining != 0 { - let l = remaining.trailing_zeros() as usize; - coeffs[l] += eq_b; - remaining &= remaining - 1; + if D <= 64 { + let mut bits: u64 = 0; + for (l, coeff) in entry.iter().enumerate().take(D) { + if coeff.into_inner() { + bits |= 1u64 << l; + } + } + let mut remaining = bits; + while remaining != 0 { + let l = remaining.trailing_zeros() as usize; + reducer.add(&mut coeffs[l], eq_b); + remaining &= remaining - 1; + } + } else { + for (l, coeff) in entry.iter().enumerate().take(D) { + if coeff.into_inner() { + reducer.add(&mut coeffs[l], eq_b); + } + } } } + let coeffs: Vec = coeffs.into_iter().map(|acc| reducer.reduce(acc)).collect(); DynamicPolynomialF::new_trimmed(coeffs) }) .collect(); @@ -636,8 +662,7 @@ pub fn compute_int_fold_lifted_evals( field_cfg: &F::Config, ) -> Vec> where - F: PrimeField - + for<'a> FromWithConfig<&'a crypto_primitives::crypto_bigint_int::Int>, + F: PrimeField + for<'a> FromWithConfig<&'a crypto_primitives::crypto_bigint_int::Int>, { use crypto_primitives::crypto_bigint_int::Int; assert!(HALF_H >= 2); @@ -777,9 +802,15 @@ where #[cfg(test)] mod tests { use super::*; + use crate::{ + fixed_prime::field_cfg_from_curve_scalar, + pcs::{AllZipPCSTypes, BinaryHyraxZipRest, PCSParams, PCSVerifierParams, ZincPCSTypes}, + }; + use ark_ec::AffineRepr; use crypto_bigint::U64; use crypto_primitives::{ - Field, crypto_bigint_int::Int, crypto_bigint_monty::MontyField, crypto_bigint_uint::Uint, + Field, FromWithConfig, boolean::Boolean, crypto_bigint_int::Int, + crypto_bigint_monty::MontyField, 
crypto_bigint_uint::Uint, }; use rand::rng; use zinc_piop::{ @@ -789,9 +820,10 @@ mod tests { use zinc_primality::MillerRabin; use zinc_test_uair::{ BigLinearUair, BigLinearUairWithPublicInput, BinaryDecompositionUair, BitOpRotUair, - EC_FP_INT_LIMBS, GenerateRandomTrace, Sha256CompressionSliceUair, Sha256Ideal, - ShaEcdsaUair, TestUairMixedDegrees, TestUairMixedShifts, TestUairNoMultiplication, - TestUairSimpleMultiplication, + EC_FP_INT_LIMBS, GenerateRandomTrace, SHA256_INITIAL_STATE, Sha256CompressionSliceUair, + Sha256Ideal, Sha256MessageBlock, Sha256State, ShaEcdsaUair, TestUairMixedDegrees, + TestUairMixedShifts, TestUairNoMultiplication, TestUairSimpleMultiplication, + sha256_compress_native, sha256_padded_message_blocks, synthesize_sha256_chain_witnesses, }; use zinc_uair::{ ideal::{DegreeOneIdeal, rotation::RotationIdeal}, @@ -805,16 +837,20 @@ mod tests { }; use zip_plus::{ code::{ + LinearCode, iprs::{IprsCode, PnttConfigF65537}, raa::{RaaCode, RaaConfig}, }, - pcs::structs::{ZipPlus, ZipPlusParams}, + pcs::{ + generic::ZipPlusPCS, + hyrax::{BinaryLanes, HyraxBlindingMode, HyraxPCS}, + structs::{ZipPlus, ZipPlusParams}, + }, pcs_transcript::PcsProverTranscript, }; const INT_LIMBS: usize = U64::LIMBS; - // `fixed-prime` branch: 256-bit field modulus (4 × u64 limbs) so the - // hardcoded secp256k1 base prime fits in `Fmod = Uint`. + // 256-bit field modulus (4 × u64 limbs). 
const FIELD_LIMBS: usize = U64::LIMBS * 4; const DEGREE_PLUS_ONE: usize = 32; @@ -851,6 +887,67 @@ mod tests { type F = MontyField; + #[test] + fn lifted_binary_eval_matches_definition() { + let field_cfg = fixed_prime::secp256k1_field_cfg::>(); + let point = vec![ + F::from_with_cfg(7u64, &field_cfg), + F::from_with_cfg(11u64, &field_cfg), + ]; + let trace_bin_poly: Vec>> = [ + [0b0000_0001, 0b0000_1010, 0b1010_0000, 0b1111_0000], + [0b1111_1111, 0b0101_0101, 0b0011_0011, 0b0000_0000], + ] + .into_iter() + .map(|patterns| { + let evaluations: Vec> = patterns + .into_iter() + .map(|p| { + let coeffs: [Boolean; 8] = + std::array::from_fn(|i| Boolean::new((p >> i) & 1 != 0)); + BinaryPoly::<8>::new(coeffs) + }) + .collect(); + DenseMultilinearExtension { + num_vars: evaluations.len().next_power_of_two().trailing_zeros() as usize, + evaluations, + } + }) + .collect(); + let projected_trace = ProjectedTrace::RowMajor(vec![ + vec![ + DynamicPolynomialF::new_trimmed(vec![F::zero_with_cfg(&field_cfg)]), + DynamicPolynomialF::new_trimmed(vec![F::zero_with_cfg(&field_cfg)]), + ]; + 4 + ]); + + let eq_table = zinc_poly::utils::build_eq_x_r_vec(&point, &field_cfg).unwrap(); + let expected: Vec<_> = trace_bin_poly + .iter() + .map(|col| { + let mut coeffs = vec![F::zero_with_cfg(&field_cfg); 8]; + for (row_idx, entry) in col.iter().enumerate() { + for (bit_idx, coeff) in entry.iter().enumerate() { + if coeff.into_inner() { + coeffs[bit_idx] += &eq_table[row_idx]; + } + } + } + DynamicPolynomialF::new_trimmed(coeffs) + }) + .collect(); + let lifted = compute_lifted_evals_capped::( + &point, + &trace_bin_poly, + &projected_trace, + &field_cfg, + None, + ); + + assert_eq!(lifted, expected); + } + #[derive(Debug, Clone)] pub struct BinPolyZipTypes {} impl ZipTypes for BinPolyZipTypes { @@ -1034,7 +1131,7 @@ mod tests { check_verification: impl Fn(Result<(), ProtocolError>>>), ) where Zt: ZincTypes, - Zt::Int: num_traits::Zero, + Zt::Int: ProjectableToField + num_traits::Zero, 
::Cw: ProjectableToField, ::Eval: ProjectableToField, ::Cw: ProjectableToField, @@ -1349,9 +1446,7 @@ mod tests { |res| { assert!(matches!( res.unwrap_err(), - ProtocolError::Resolver( - CombinedPolyResolverError::WrongSumcheckSum { .. } - ) + ProtocolError::Resolver(CombinedPolyResolverError::WrongSumcheckSum { .. }) )); }, ); @@ -1800,10 +1895,7 @@ mod tests { tamper(&mut proof); - let verification_result = ZincPlusPiop::::verify::< - _, - CHECKED, - >( + let verification_result = ZincPlusPiop::::verify::<_, CHECKED>( &pp, proof, &public_trace, @@ -1814,6 +1906,285 @@ mod tests { check_verification(verification_result); } + #[allow(clippy::type_complexity)] + fn sha256_zip_pcs_params( + num_vars: usize, + ) -> ( + PCSParams, + PCSVerifierParams, + ) { + let pp = setup_pp::( + num_vars, + ( + make_iprs(num_vars), + make_iprs(num_vars), + make_iprs(num_vars), + ), + ); + ( + PCSParams:: { + binary: pp.0.clone(), + arbitrary: pp.1.clone(), + int: pp.2.clone(), + }, + PCSVerifierParams:: { + binary: pp.0, + arbitrary: pp.1, + int: pp.2, + }, + ) + } + + #[allow(clippy::type_complexity)] + fn sha256_hyrax_pcs_params( + num_vars: usize, + ) -> ( + PCSParams, TestShaEcdsaZincTypes, F, DEGREE_PLUS_ONE>, + PCSVerifierParams, TestShaEcdsaZincTypes, F, DEGREE_PLUS_ONE>, + ) + where + BinaryHyraxZipRest: ZincPCSTypes< + TestShaEcdsaZincTypes, + F, + DEGREE_PLUS_ONE, + BinaryPCS = HyraxPCS, + ArbitraryPCS = ZipPlusPCS< + >::ArbitraryZt, + >::ArbitraryLc, + >, + IntPCS = ZipPlusPCS< + >::IntZt, + >::IntLc, + >, + >, + { + let pp = setup_pp::( + num_vars, + ( + make_iprs(num_vars), + make_iprs(num_vars), + make_iprs(num_vars), + ), + ); + let width = pp.0.linear_code.row_len(); + let (binary, binary_vk) = HyraxPCS::::setup( + width, + b"zinc-plus-test-sha256-hyrax", + HyraxBlindingMode::Unblinded, + ) + .expect("Hyrax setup must be valid"); + ( + PCSParams::, TestShaEcdsaZincTypes, F, DEGREE_PLUS_ONE> { + binary, + arbitrary: pp.1.clone(), + int: pp.2.clone(), + }, + 
PCSVerifierParams::, TestShaEcdsaZincTypes, F, DEGREE_PLUS_ONE> { + binary: binary_vk, + arbitrary: pp.1, + int: pp.2, + }, + ) + } + + fn run_sha256_pcs_round_trip

( + pp: &PCSParams, + vp: &PCSVerifierParams, + field_cfg: ::Config, + ) where + P: ZincPCSTypes, + { + type U = Sha256CompressionSliceUair; + + const NUM_VARS: usize = 9; + + let mut rng = rng(); + let trace = U::generate_random_trace(NUM_VARS, &mut rng); + let public_trace = trace.public(&U::signature()); + + let proof = + ZincPlusPiop::::prove_with_pcs_and_field_cfg::< + P, + false, + CHECKED, + >(pp, &trace, NUM_VARS, project_scalar_fn, field_cfg.clone()) + .expect("SHA PCS prover failed"); + + ZincPlusPiop::::verify_with_pcs_and_field_cfg::< + P, + Sha256Ideal, + CHECKED, + >( + vp, + proof, + &public_trace, + NUM_VARS, + project_scalar_fn, + sha256_test_project_ideal, + field_cfg, + ) + .expect("SHA PCS verifier rejected an honest proof"); + } + + #[test] + fn test_synthesized_sha256_single_witness_round_trip() { + type U = Sha256CompressionSliceUair; + + const NUM_VARS: usize = 9; + let initial_state: Sha256State = SHA256_INITIAL_STATE; + let message_blocks: [Sha256MessageBlock; 1] = + sha256_padded_message_blocks(b"zinc-plus synthesized SHA witness") + .expect("test message should canonically pad to one SHA-256 block"); + + let (witnesses, final_state) = + synthesize_sha256_chain_witnesses::(initial_state, message_blocks) + .expect("synthesized SHA witness generation should succeed"); + assert_eq!( + final_state, + sha256_compress_native(initial_state, message_blocks[0]) + ); + + let (pp, vp) = sha256_zip_pcs_params(NUM_VARS); + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_bn254::G1Affine, + >(); + let public_trace = witnesses[0].trace.public(&U::signature()); + let proof = ZincPlusPiop::::prove_with_pcs_and_field_cfg::< + AllZipPCSTypes, + false, + CHECKED, + >( + &pp, + &witnesses[0].trace, + NUM_VARS, + project_scalar_fn, + field_cfg.clone(), + ) + .expect("synthesized SHA PCS prover failed"); + + ZincPlusPiop::::verify_with_pcs_and_field_cfg::< + AllZipPCSTypes, + Sha256Ideal, + CHECKED, + >( + &vp, + proof, + &public_trace, + 
NUM_VARS, + project_scalar_fn, + sha256_test_project_ideal, + field_cfg, + ) + .expect("SHA PCS verifier rejected a synthesized witness proof"); + } + + #[test] + fn test_real_sha256_pcs_zip_bn_round_trip() { + const NUM_VARS: usize = 9; + + let bn_field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_bn254::G1Affine, + >(); + let (zip_bn_pp, zip_bn_vp) = sha256_zip_pcs_params(NUM_VARS); + run_sha256_pcs_round_trip::(&zip_bn_pp, &zip_bn_vp, bn_field_cfg); + } + + #[test] + fn test_real_sha256_pcs_zip_secp256k1_round_trip() { + const NUM_VARS: usize = 9; + + let secp_field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_secp256k1::Affine, + >(); + let (zip_secp_pp, zip_secp_vp) = sha256_zip_pcs_params(NUM_VARS); + run_sha256_pcs_round_trip::(&zip_secp_pp, &zip_secp_vp, secp_field_cfg); + } + + #[test] + fn test_real_sha256_pcs_hyrax_bn_round_trip() { + const NUM_VARS: usize = 9; + + let bn_field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_bn254::G1Affine, + >(); + let (hyrax_bn_pp, hyrax_bn_vp) = sha256_hyrax_pcs_params::(NUM_VARS); + run_sha256_pcs_round_trip::>( + &hyrax_bn_pp, + &hyrax_bn_vp, + bn_field_cfg, + ); + } + + #[test] + fn test_real_sha256_pcs_hyrax_secp256k1_round_trip() { + const NUM_VARS: usize = 9; + + let secp_field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_secp256k1::Affine, + >(); + let (hyrax_secp_pp, hyrax_secp_vp) = + sha256_hyrax_pcs_params::(NUM_VARS); + run_sha256_pcs_round_trip::>( + &hyrax_secp_pp, + &hyrax_secp_vp, + secp_field_cfg, + ); + } + + #[test] + fn test_real_sha256_rejects_trailing_pcs_bytes() { + type U = Sha256CompressionSliceUair; + const NUM_VARS: usize = 9; + + let field_cfg = field_cfg_from_curve_scalar::< + F, + >::Fmod, + ark_bn254::G1Affine, + >(); + let (pp, vp) = sha256_zip_pcs_params(NUM_VARS); + let mut rng = rng(); + let trace = U::generate_random_trace(NUM_VARS, &mut rng); + let public_trace = trace.public(&U::signature()); + + let mut proof = + 
ZincPlusPiop::::prove_with_pcs_and_field_cfg::< + AllZipPCSTypes, + false, + CHECKED, + >(&pp, &trace, NUM_VARS, project_scalar_fn, field_cfg.clone()) + .expect("SHA PCS prover failed"); + proof.zip.extend_from_slice(b"trailing pcs bytes"); + + let result = + ZincPlusPiop::::verify_with_pcs_and_field_cfg::< + AllZipPCSTypes, + Sha256Ideal, + CHECKED, + >( + &vp, + proof, + &public_trace, + NUM_VARS, + project_scalar_fn, + sha256_test_project_ideal, + field_cfg, + ); + assert!(matches!( + result, + Err(ProtocolError::PcsProofTrailingBytes { .. }) + )); + } + /// `num_vars` for SHA-ECDSA tests. ECDSA's Shamir scalar /// multiplication needs `n_rows > 256`, so `num_vars >= 9`. const SHA_ECDSA_NUM_VARS: usize = 9; @@ -1894,9 +2265,7 @@ mod tests { serialized_len.div_ceil(1024), ); let mut transcript = transcript.into_verification_transcript(); - let proof_2 = transcript - .read() - .expect("Failed to deserialize proof"); + let proof_2 = transcript.read().expect("Failed to deserialize proof"); assert_eq!(proof, proof_2); verify_folded_4x::< @@ -2116,7 +2485,7 @@ mod tests { HALF_DEGREE_PLUS_ONE, >>::IntLc, >, - ) { + ){ let split_size = 1 << (num_vars + 1); let normal_size = 1 << num_vars; ( @@ -2213,5 +2582,4 @@ mod tests { >; type ArrCombRDotChal = MBSInnerProduct; } - } diff --git a/protocol/src/multipoint_reduction.rs b/protocol/src/multipoint_reduction.rs new file mode 100644 index 00000000..89b9fb26 --- /dev/null +++ b/protocol/src/multipoint_reduction.rs @@ -0,0 +1,62 @@ +use crypto_primitives::FromPrimitiveWithConfig; +use num_traits::Zero; +use zinc_piop::multipoint_eval::{ + MultipointEval, MultipointEvalError, Proof as MultipointEvalProof, + Subclaim as MultipointSubclaim, +}; +use zinc_poly::mle::DenseMultilinearExtension; +use zinc_transcript::traits::{Transcribable, Transcript}; +use zinc_uair::ShiftSpec; +use zinc_utils::{ + delayed_reduction::DelayedFieldProductSum, inner_transparent_field::InnerTransparentField, +}; + +pub(crate) fn 
prove_multipoint_reduction( + transcript: &mut impl Transcript, + trace_mles: &[DenseMultilinearExtension], + eval_point: &[F], + up_evals: &[F], + down_evals: &[F], + shifts: &[ShiftSpec], + field_cfg: &F::Config, +) -> Result<(MultipointEvalProof, Vec), MultipointEvalError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + let (proof, state) = MultipointEval::prove_as_subprotocol( + transcript, trace_mles, eval_point, up_evals, down_evals, shifts, field_cfg, + )?; + Ok((proof, state.eval_point)) +} + +pub(crate) fn verify_multipoint_reduction( + transcript: &mut impl Transcript, + proof: MultipointEvalProof, + eval_point: &[F], + up_evals: &[F], + down_evals: &[F], + shifts: &[ShiftSpec], + num_vars: usize, + field_cfg: &F::Config, +) -> Result, MultipointEvalError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + MultipointEval::verify_as_subprotocol( + transcript, proof, eval_point, up_evals, down_evals, shifts, num_vars, field_cfg, + ) +} diff --git a/protocol/src/pcs.rs b/protocol/src/pcs.rs new file mode 100644 index 00000000..40455261 --- /dev/null +++ b/protocol/src/pcs.rs @@ -0,0 +1,192 @@ +use std::{fmt::Debug, marker::PhantomData}; + +use ark_ec::AffineRepr; +use crypto_primitives::PrimeField; +use zinc_poly::univariate::{binary::BinaryPoly, dense::DensePolynomial}; +use zip_plus::pcs::{ + generic::{PCS, ZipPlusPCS}, + hyrax::{BinaryLanes, DensePolyScalarLanes, HyraxPCS, IntScalarLane}, + structs::ZipPlusCommitment, +}; + +use crate::ZincTypes; + +pub type ZipPCSCommitments = (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment); + +pub trait ZincPCSTypes: Clone + Debug + Send + Sync +where + Zt: ZincTypes, + F: PrimeField, +{ + 
type BinaryPCS: PCS, D>; + type ArbitraryPCS: PCS, D>; + type IntPCS: PCS; +} + +#[derive(Clone, Debug)] +pub struct AllZipPCSTypes; + +impl ZincPCSTypes for AllZipPCSTypes +where + Zt: ZincTypes, + F: PrimeField, + ZipPlusPCS: PCS, D>, + ZipPlusPCS: PCS, D>, + ZipPlusPCS: PCS, +{ + type BinaryPCS = ZipPlusPCS; + type ArbitraryPCS = ZipPlusPCS; + type IntPCS = ZipPlusPCS; +} + +#[derive(Clone, Debug)] +pub struct BinaryHyraxZipRest(PhantomData); + +impl ZincPCSTypes for BinaryHyraxZipRest +where + Zt: ZincTypes, + F: PrimeField, + C: AffineRepr, + HyraxPCS: PCS, D>, + ZipPlusPCS: PCS, D>, + ZipPlusPCS: PCS, +{ + type BinaryPCS = HyraxPCS; + type ArbitraryPCS = ZipPlusPCS; + type IntPCS = ZipPlusPCS; +} + +/// Homomorphic PCS bundle for production ProjectionFold paths. +/// +/// Every witness domain uses Hyrax/MSM commitments so verifier-derived +/// instance-axis folds can be opened against folded prover data. +#[derive(Clone, Debug)] +pub struct AllHyraxPCSTypes(PhantomData); + +impl ZincPCSTypes for AllHyraxPCSTypes +where + Zt: ZincTypes, + F: PrimeField, + C: AffineRepr, + HyraxPCS: PCS, D>, + HyraxPCS: PCS, D>, + HyraxPCS: PCS, +{ + type BinaryPCS = HyraxPCS; + type ArbitraryPCS = HyraxPCS; + type IntPCS = HyraxPCS; +} + +#[derive(Clone, Debug)] +pub struct BinaryIntHyraxZipArbitrary(PhantomData); + +impl ZincPCSTypes for BinaryIntHyraxZipArbitrary +where + Zt: ZincTypes, + F: PrimeField, + C: AffineRepr, + HyraxPCS: PCS, D>, + ZipPlusPCS: PCS, D>, + HyraxPCS: PCS, +{ + type BinaryPCS = HyraxPCS; + type ArbitraryPCS = ZipPlusPCS; + type IntPCS = HyraxPCS; +} + +#[derive(Clone, Debug)] +pub struct PCSParams +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: + <

>::BinaryPCS as PCS, D>>::CommitmentKey, + pub arbitrary: <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + D, + >>::CommitmentKey, + pub int: <

>::IntPCS as PCS>::CommitmentKey, +} + +#[derive(Clone, Debug)] +pub struct PCSVerifierParams +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: <

>::BinaryPCS as PCS, D>>::VerifierKey, + pub arbitrary: <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + D, + >>::VerifierKey, + pub int: <

>::IntPCS as PCS>::VerifierKey, +} + +#[derive(Clone, Debug)] +pub struct PCSCommitments +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: <

>::BinaryPCS as PCS, D>>::Commitment, + pub arbitrary: <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + D, + >>::Commitment, + pub int: <

>::IntPCS as PCS>::Commitment, +} + +#[derive(Clone, Debug)] +pub struct PCSProverData +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: <

>::BinaryPCS as PCS, D>>::ProverData, + pub arbitrary: <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + D, + >>::ProverData, + pub int: <

>::IntPCS as PCS>::ProverData, +} + +#[derive(Clone, Debug)] +pub struct PCSOpeningProof +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub binary: + <

>::BinaryPCS as PCS, D>>::OpeningProof, + pub arbitrary: <

>::ArbitraryPCS as PCS< + F, + DensePolynomial, + D, + >>::OpeningProof, + pub int: <

>::IntPCS as PCS>::OpeningProof, +} + +impl Default for PCSOpeningProof +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + fn default() -> Self { + Self { + binary: Default::default(), + arbitrary: Default::default(), + int: Default::default(), + } + } +} diff --git a/protocol/src/production_sha.rs b/protocol/src/production_sha.rs new file mode 100644 index 00000000..32a707fa --- /dev/null +++ b/protocol/src/production_sha.rs @@ -0,0 +1,9021 @@ +//! Production SHA ProjectionFold protocol helpers. +//! +//! This module is intentionally separate from the existing single-instance +//! `Proof`: production ProjectionFold has a different transcript order and +//! derives folded commitments only after SumFold fixes the instance-axis point. + +use std::{borrow::Cow, io::Cursor, marker::PhantomData}; + +use crate::{ + ZincTypes, + multipoint_reduction::{prove_multipoint_reduction, verify_multipoint_reduction}, + pcs::{ + AllHyraxPCSTypes, PCSCommitments, PCSOpeningProof, PCSParams, PCSProverData, + PCSVerifierParams, ZincPCSTypes, + }, +}; +use ark_ec::AffineRepr; +use crypto_primitives::{FromPrimitiveWithConfig, PrimeField}; +use num_traits::{ConstZero, Zero}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use thiserror::Error; +#[cfg(debug_assertions)] +use zinc_piop::neutron_nova::expression_folded_row_sum_with_vectors; +#[cfg(debug_assertions)] +use zinc_piop::neutron_nova::validate_projected_trace; +use zinc_piop::{ + combined_poly_resolver::Proof as CombinedPolyResolverProof, + ideal_check::Proof as IdealCheckProof, + multipoint_eval::{ + MultipointEval, MultipointEvalError, Proof as MultipointEvalProof, + Subclaim as MultipointSubclaim, + }, + neutron_nova::SumFoldError, + neutron_nova::{ + InstanceFoldClaim, LinearResidualCoeffTable, MleTable, NUM_NONZERO_SHA_FAMILIES, + NUM_SHA_RESIDUAL_FAMILIES, ProjectedPublic, ProjectedTrace, ProjectionFoldWitness, + SHA_ROW_COUNT, SHA_ROW_VARS, SHA_WORD_BITS, ShaBinaryFoldField, ShaBooleanitySource, + 
ShaIntCol, ShaProjectionError, ShaPublicCol, ShaPublicWordCol, ShaResidualFamily, + ShaWordCol, beta_aggregate_nonzero_ideal_polys_with_weights, bit_slice_index, + build_booleanity_weights, build_dense_sha_sumfold_group, build_folded_row_sumcheck_group, + build_linear_residual_coeff_tables_with_row_weights, + build_production_sha_sumfold_group_from_prefix_accumulators, build_sha_lambda_powers, + build_sha_residual_eval_powers, build_sha_sumfold_linear_accumulator, + build_sha_sumfold_quadratic_prefix_accumulator, derive_instance_fold_claim, + expression_folded_row_sum_with_row_weights, fold_projected_traces, + folded_row_integrand_sum, production_sha_booleanity_sources, + production_sha_nonzero_families, sha_int_at_point_with_weights_unchecked, + sha_public_at_point, sha_public_at_point_with_weights, + sha_word_bits_at_point_with_weights_unchecked, verify_folded_row_sumcheck_claim, + verify_fresh_sha_ideal_polys, + }, + sumcheck::{ + SumCheckError, + multi_degree::{MultiDegreeSumcheck, MultiDegreeSumcheckGroup, MultiDegreeSumcheckProof}, + }, +}; +use zinc_poly::{ + EvaluatablePolynomial, EvaluationError, + mle::DenseMultilinearExtension, + univariate::{ + binary::BinaryPoly, + dense::DensePolynomial, + dynamic::over_field::{DynamicPolyFInnerProduct, DynamicPolynomialF}, + nat_evaluation::NatEvaluatedPoly, + }, + utils::{ArithErrors, build_eq_x_r_vec, eq_eval}, +}; +use zinc_transcript::Blake3Transcript; +use zinc_transcript::traits::{GenTranscribable, Transcribable, Transcript}; +use zinc_uair::{ShiftSpec, Uair, UairSignature, UairTrace, UairWitness}; +use zinc_utils::{ + UNCHECKED, cfg_into_iter, cfg_iter, delayed_reduction::DelayedFieldProductSum, + inner_product::FieldFieldInnerProduct, inner_product::InnerProduct, + inner_transparent_field::InnerTransparentField, +}; +use zip_plus::{ + ZipError, + pcs::{ + generic::{FoldablePCS, PCS}, + hyrax::{BinaryLanes, DensePolyScalarLanes, HyraxFieldBridge, HyraxPCS, IntScalarLane}, + }, + 
pcs_transcript::{PcsProverTranscript, PcsVerifierTranscript}, +}; + +/// Serialized production ProjectionFold proof object. +/// +/// This object carries verifier messages and claimed evaluations only. Folding +/// weights, batching powers, folded accumulator values, and prover working +/// caches are derived from the transcript/setup or kept as prover-local state. +#[derive(Clone, Debug)] +pub struct ProductionLinearIdealFoldProof +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub instance_commitments: Vec>, + pub ideal_check: IdealCheckProof, + pub sumfold_proof: MultiDegreeSumcheckProof, + pub resolver: CombinedPolyResolverProof, + pub combined_sumcheck: MultiDegreeSumcheckProof, + pub multipoint_eval: MultipointEvalProof, + pub witness_lifted_evals: Vec>, + pub opening_proof: PCSOpeningProof, +} + +#[derive(Clone, Debug)] +pub struct ProductionShaWitnessPolys +where + Zt: ZincTypes, +{ + pub binary: MleTable>, + pub arbitrary: MleTable>, + pub int: MleTable, +} + +#[derive(Clone, Debug)] +pub struct ProductionShaProverInstance +where + Zt: ZincTypes, + F: PrimeField, +{ + pub trace: ProjectedTrace, + pub public: ProjectedPublic, + pub witness_polys: ProductionShaWitnessPolys, +} + +#[derive(Clone, Debug)] +pub struct PreparedProductionShaProverInstance +where + Zt: ZincTypes, + F: PrimeField, +{ + pub public_trace: UairTrace<'static, Zt::Int, Zt::Int, D>, + pub instance: ProductionShaProverInstance, +} + +pub trait ProductionShaProjectionAdapter: Uair +where + Zt: ZincTypes, + F: PrimeField, +{ + fn production_sha_pcs_batch_sizes() -> (usize, usize, usize) { + (ShaWordCol::COUNT, 0, ShaIntCol::COUNT) + } + + fn project_production_sha_public( + shape: &UairShape, + public_trace: &UairTrace<'_, Zt::Int, Zt::Int, D>, + field_cfg: &F::Config, + ) -> Result, ProductionShaError> + where + Self: Sized; + + fn project_production_sha_witness( + shape: &UairShape, + public_trace: &UairTrace<'_, Zt::Int, Zt::Int, D>, + witness_trace: &UairTrace<'_, 
Zt::Int, Zt::Int, D>, + field_cfg: &F::Config, + ) -> Result< + ( + ProjectedTrace, + ProjectedPublic, + ProductionShaWitnessPolys, + ), + ProductionShaError, + > + where + Self: Sized; +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShaEndpointEvals { + pub sources: Vec>, + pub int_sources: Vec>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShaSourceEndpointEval { + pub col: ShaWordCol, + pub shift: usize, + pub scalarized: F, + pub bits: [F; 32], +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShaIntEndpointEval { + pub col: ShaIntCol, + pub scalar: F, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ShaMpSource { + WordBit { col: ShaWordCol, bit: usize }, + Int { col: ShaIntCol }, + Public { col: ShaPublicCol }, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ShaMpShiftSource { + WordBit { + col: ShaWordCol, + bit: usize, + shift: usize, + }, + Public { + col: ShaPublicCol, + shift: usize, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShaMultipointLayout { + pub sources: Vec, + pub shifts: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VirtualChMajEndpoint { + pub ch1: [F; 32], + pub ch2: [F; 32], + pub maj: [F; 32], +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProductionShaChallenges { + pub r_ic: [F; SHA_ROW_VARS], + pub a: F, + pub lambda: F, + pub rho: F, + pub xi: F, + pub beta: Vec, +} + +const SHA256_ROUND_CONSTANTS: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 
0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +const SHA_IDEAL_EVAL_POWER_COUNT: usize = 62; + +const PRODUCTION_SHA_FRESH_BATCH_DOMAIN: &[u8] = b"PF_CONCISE_SHA256_FRESH_BATCH_V1"; + +#[derive(Clone, Debug)] +pub struct UairShape { + pub num_vars: usize, + pub signature: UairSignature, + _marker: PhantomData, +} + +impl UairShape { + pub fn new(num_vars: usize) -> Self { + Self { + num_vars, + signature: U::signature(), + _marker: PhantomData, + } + } +} + +#[derive(Clone, Debug)] +pub struct UairInstance<'a, PolyCoeff: Clone, Int: Clone, Commitments, const D: usize> { + pub public_trace: UairTrace<'a, PolyCoeff, Int, D>, + pub commitments: Commitments, +} + +#[derive(Clone, Debug)] +pub struct LinearIdealFoldProverParams +where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub pcs_params: PCSParams, + pub field_cfg: F::Config, + pub prefix_vars: usize, + _marker: PhantomData, +} + +impl LinearIdealFoldProverParams +where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub fn new( + pcs_params: PCSParams, + field_cfg: F::Config, + prefix_vars: usize, + ) -> Self { + Self { + pcs_params, + field_cfg, + prefix_vars, + _marker: PhantomData, + } + } +} + +#[derive(Clone, Debug)] +pub struct LinearIdealFoldVerifierParams +where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub pcs_params: PCSVerifierParams, + pub field_cfg: F::Config, + _marker: PhantomData, +} + +impl LinearIdealFoldVerifierParams +where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub fn new(pcs_params: PCSVerifierParams, field_cfg: F::Config) -> Self { + Self { + pcs_params, + field_cfg, + _marker: PhantomData, + } + } +} + +#[derive(Clone, Debug)] +pub struct VerifiedLinearIdealFoldSetup 
+where + U: Uair, + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub pcs_params: PCSVerifierParams, + pub shape: UairShape, + pub field_cfg: F::Config, +} + +#[derive(Clone, Debug)] +pub struct LinearIdealFoldProveOutput { + pub fresh_instances: Vec, + pub folded_instance: FoldedInstance, + pub folded_witness: FoldedWitness, + pub proof: Proof, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FoldedLinearIdealInstance { + pub target: F, + pub commitments: Commitments, + pub public: Public, +} + +#[derive(Clone, Debug)] +pub struct FoldedLinearIdealWitness { + pub witness: Witness, +} + +type ProductionShaFreshArtifacts = ( + Vec< + UairInstance< + 'static, + >::Int, + >::Int, + PCSCommitments, + D, + >, + >, + Vec>, + Vec>, + Vec>, + Vec>, +); + +type ProductionShaSumfoldAccumulators = (Vec, Vec, MultiDegreeSumcheckGroup); + +type ProductionShaFoldAfterSumfold = ( + ProjectionFoldWitness, + ProjectedPublic, + F, + InstanceFoldClaim, + PCSProverData, +); + +type ProductionShaEndpointMultipoint = ( + CombinedPolyResolverProof, + ShaEndpointEvals, + MultipointEvalProof, + Vec, +); + +type ProductionShaPcsOpening = + (Vec>, PCSOpeningProof); + +type ProductionShaVerifierAggregate = ( + [DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + F, + F, + F, + F, + F, +); + +#[derive(Clone, Debug)] +pub struct ProductionShaFoldedWitness +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + pub trace: ProjectedTrace, + pub opening_witness: PCSProverData, +} + +pub trait ProductionShaFoldedPcsOpen: ZincPCSTypes +where + Zt: ZincTypes, + F: PrimeField, +{ + #[allow(clippy::too_many_arguments)] + fn prove_folded_pcs_opening( + pcs_params: &PCSParams, + instance_commitments: &[PCSCommitments], + fold_weights: &[F], + folded_trace: &ProjectedTrace, + folded_prover_data: &PCSProverData, + r_0: &[F], + folded_lifted_evals: &[DynamicPolynomialF], + field_cfg: &F::Config, + ) -> Result, ProductionShaError> + where + Self: Sized; +} + +impl 
ProductionShaFoldedPcsOpen for AllHyraxPCSTypes +where + Zt: ZincTypes, + F: HyraxFieldBridge + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, + C: AffineRepr, + HyraxPCS: PCS< + F, + BinaryPoly, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, + HyraxPCS: PCS< + F, + DensePolynomial, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, + HyraxPCS: PCS< + F, + Zt::Int, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, +{ + fn prove_folded_pcs_opening( + pcs_params: &PCSParams, + instance_commitments: &[PCSCommitments], + fold_weights: &[F], + folded_trace: &ProjectedTrace, + folded_prover_data: &PCSProverData, + r_0: &[F], + folded_lifted_evals: &[DynamicPolynomialF], + field_cfg: &F::Config, + ) -> Result, ProductionShaError> { + prove_production_sha_hyrax_pcs_opening::( + pcs_params, + instance_commitments, + fold_weights, + folded_trace, + folded_prover_data, + r_0, + folded_lifted_evals, + field_cfg, + ) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VerifiedShaSumFold { + pub r_b: Vec, + pub c_sf: F, +} + +pub type LinearIdealFoldError = ProductionShaError; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FoldedRowSumcheckOutput { + pub r_star: Vec, + pub r_star_eq_weights: Vec, + pub terminal_value: F, + pub endpoint_evals: Option>, +} + +#[derive(Debug, Error)] +pub enum ProductionShaError { + #[error("production SHA requires at least two fresh instances, got {0}")] + InstanceCountTooSmall(usize), + #[error("instance count must be a power of two, got {0}")] + InstanceCountNotPowerOfTwo(usize), + #[error("length mismatch for {label}: got {got}, expected {expected}")] + LengthMismatch { + label: &'static str, + got: usize, + 
expected: usize, + }, + #[error("non-canonical proof object: {0}")] + NonCanonicalProofObject(&'static str), + #[error("production SHA public selector column {col:?} is not boolean at row {row}")] + NonBooleanPublicSelector { col: ShaPublicCol, row: usize }, + #[error("production SHA public selector column {col:?} is all zero")] + EmptyPublicSelector { col: ShaPublicCol }, + #[error( + "production SHA public selector column {col:?} does not match the fixed row layout at row {row}" + )] + InvalidPublicSelector { col: ShaPublicCol, row: usize }, + #[error("production SHA public K column does not match SHA-256 constants at row {row}")] + InvalidRoundConstant { row: usize }, + #[error("production SHA requires {expected}-bit word polynomials, got D={got}")] + UnsupportedProductionShaWordDegree { got: usize, expected: usize }, + #[error("unsupported production SHA PCS shape: {0}")] + UnsupportedProductionShaPcsShape(&'static str), + #[error("production SHA prover not implemented: {0}")] + ProverNotImplemented(&'static str), + #[error("PCS opening transcript has trailing bytes")] + TrailingPcsOpeningBytes, + #[error("{label} expected exactly one sumcheck group, got {got}")] + UnexpectedSumcheckGroupCount { label: &'static str, got: usize }, + #[error("SumFold proof has degree {degree}, expected at most 3")] + SumFoldDegreeTooHigh { degree: usize }, + #[error("SumFold terminal evaluation mismatch")] + SumFoldTerminalMismatch, + #[error("row sumcheck proof has degree {degree}, expected at most 3")] + RowSumcheckDegreeTooHigh { degree: usize }, + #[error("row sumcheck terminal evaluation mismatch")] + RowSumcheckTerminalMismatch, + #[error("endpoint scalarization mismatch for {col:?} shift {shift}")] + EndpointScalarizationMismatch { col: ShaWordCol, shift: usize }, + #[error("missing endpoint eval for {col:?} shift {shift}")] + MissingEndpointEval { col: ShaWordCol, shift: usize }, + #[error("ideal membership failed")] + IdealMembership, + #[error("PCS error: {0}")] + 
Pcs(#[from] ZipError), + #[error("sumcheck error: {0}")] + Sumcheck(#[from] SumCheckError), + #[error("multipoint evaluation error: {0}")] + Multipoint(#[from] MultipointEvalError), + #[error("SumFold error: {0}")] + SumFold(#[from] SumFoldError), + #[error("SHA projection error: {0}")] + ShaProjection(#[from] ShaProjectionError), + #[error("equality polynomial error: {0}")] + Eq(#[from] ArithErrors), + #[error("polynomial evaluation error: {0}")] + PolyEval(#[from] EvaluationError), +} + +pub fn absorb_projected_sha_publics( + transcript: &mut impl Transcript, + publics: &[zinc_piop::neutron_nova::ProjectedPublic], + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + let mut encoded = Vec::with_capacity( + publics.len() + * ShaPublicWordCol::COUNT + * SHA_ROW_COUNT + * F::Inner::get_num_bytes(F::zero_with_cfg(field_cfg).inner()), + ); + let zero = F::zero_with_cfg(field_cfg); + + fn push_u64(buf: &mut Vec, value: usize) { + buf.extend_from_slice(&(value as u64).to_le_bytes()); + } + + fn push_field_inners(buf: &mut Vec, values: &[F], scratch: &mut [u8]) + where + F: PrimeField, + F::Inner: Transcribable, + { + for value in values { + value.inner().write_transcription_bytes_exact(scratch); + buf.extend_from_slice(scratch); + } + } + + transcript.absorb_slice(b"production_sha_publics_begin"); + encoded.extend_from_slice(b"compact_v1"); + zero.modulus() + .write_transcription_bytes_exact(&mut field_buf); + encoded.extend_from_slice(&field_buf); + push_u64(&mut encoded, publics.len()); + push_u64(&mut encoded, ShaPublicCol::COUNT); + push_u64(&mut encoded, ShaPublicWordCol::COUNT); + push_u64(&mut encoded, SHA_ROW_COUNT); + for (instance_idx, public) in publics.iter().enumerate() { + push_u64(&mut encoded, instance_idx); + push_u64(&mut encoded, public.columns.len()); + match &public.bit_slices { + Some(bit_slices) => { + encoded.push(1); + 
push_u64(&mut encoded, bit_slices.len()); + } + None => encoded.push(0), + } + for public_col in production_sha_public_word_column_map() { + let col_idx = public_col.index(); + let col = &public.columns[col_idx]; + push_u64(&mut encoded, col_idx); + push_u64(&mut encoded, col.evaluations.len()); + push_field_inners::(&mut encoded, &col.evaluations, &mut field_buf); + } + } + transcript.absorb_slice(&encoded); + transcript.absorb_slice(b"production_sha_publics_end"); +} + +fn absorb_uair_shape_metadata(transcript: &mut impl Transcript, shape: &UairShape) { + let sig = &shape.signature; + + transcript.absorb_slice(b"uair_shape_metadata_begin"); + transcript.absorb_slice(&(shape.num_vars as u64).to_le_bytes()); + absorb_uair_column_counts( + transcript, + sig.total_cols().num_binary_poly_cols(), + sig.total_cols().num_arbitrary_poly_cols(), + sig.total_cols().num_int_cols(), + ); + absorb_uair_column_counts( + transcript, + sig.public_cols().num_binary_poly_cols(), + sig.public_cols().num_arbitrary_poly_cols(), + sig.public_cols().num_int_cols(), + ); + absorb_uair_column_counts( + transcript, + sig.witness_cols().num_binary_poly_cols(), + sig.witness_cols().num_arbitrary_poly_cols(), + sig.witness_cols().num_int_cols(), + ); + transcript.absorb_slice(&(sig.shifts().len() as u64).to_le_bytes()); + for shift in sig.shifts() { + transcript.absorb_slice(&(shift.source_col() as u64).to_le_bytes()); + transcript.absorb_slice(&(shift.shift_amount() as u64).to_le_bytes()); + } + transcript.absorb_slice(b"uair_shape_metadata_end"); +} + +fn absorb_uair_column_counts( + transcript: &mut impl Transcript, + binary: usize, + arbitrary: usize, + int: usize, +) { + transcript.absorb_slice(&(binary as u64).to_le_bytes()); + transcript.absorb_slice(&(arbitrary as u64).to_le_bytes()); + transcript.absorb_slice(&(int as u64).to_le_bytes()); +} + +fn runtime_field_transcript_buf(field_cfg: &F::Config) -> Vec +where + F: PrimeField, + F::Inner: Transcribable, +{ + vec![0u8; 
F::zero_with_cfg(field_cfg).inner().get_num_bytes()] +} + +fn absorb_sha_resolver_proof( + transcript: &mut impl Transcript, + proof: &CombinedPolyResolverProof, + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + fn absorb_vec( + transcript: &mut impl Transcript, + label: &'static [u8], + values: &[F], + field_buf: &mut [u8], + ) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + transcript.absorb_slice(label); + transcript.absorb_slice(&(values.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(values, field_buf); + } + + transcript.absorb_slice(b"production_sha_resolver_begin"); + absorb_vec(transcript, b"up", &proof.up_evals, &mut field_buf); + absorb_vec(transcript, b"down", &proof.down_evals, &mut field_buf); + absorb_vec( + transcript, + b"bit_slice", + &proof.bit_slice_evals, + &mut field_buf, + ); + absorb_vec( + transcript, + b"bit_op_down", + &proof.bit_op_down_evals, + &mut field_buf, + ); + absorb_vec( + transcript, + b"shifted_bit_slice", + &proof.shifted_bit_slice_evals, + &mut field_buf, + ); + transcript.absorb_slice(b"production_sha_resolver_end"); +} + +fn absorb_public_uair_trace( + transcript: &mut impl Transcript, + instance_idx: usize, + trace: &UairTrace<'_, Zt::Int, Zt::Int, D>, +) where + Zt: ZincTypes, +{ + fn push_u64(buf: &mut Vec, value: usize) { + buf.extend_from_slice(&(value as u64).to_le_bytes()); + } + + fn push_transcribable(buf: &mut Vec, value: &T, scratch: &mut Vec) { + let len = value.get_num_bytes() + T::LENGTH_NUM_BYTES; + scratch.resize(len, 0); + value.write_transcription_bytes_subset(scratch); + buf.extend_from_slice(scratch); + } + + let mut encoded = Vec::new(); + let mut scratch = Vec::new(); + encoded.extend_from_slice(b"compact_v1"); + push_u64(&mut encoded, instance_idx); + push_u64(&mut encoded, trace.binary_poly.len()); + for (col_idx, 
col) in trace.binary_poly.iter().enumerate() { + push_u64(&mut encoded, col_idx); + push_u64(&mut encoded, col.num_vars); + push_u64(&mut encoded, col.evaluations.len()); + for value in &col.evaluations { + push_transcribable(&mut encoded, value, &mut scratch); + } + } + push_u64(&mut encoded, trace.arbitrary_poly.len()); + for (col_idx, col) in trace.arbitrary_poly.iter().enumerate() { + push_u64(&mut encoded, col_idx); + push_u64(&mut encoded, col.num_vars); + push_u64(&mut encoded, col.evaluations.len()); + for poly in &col.evaluations { + for coeff in poly.iter() { + push_transcribable(&mut encoded, coeff, &mut scratch); + } + } + } + push_u64(&mut encoded, trace.int.len()); + for (col_idx, col) in trace.int.iter().enumerate() { + push_u64(&mut encoded, col_idx); + push_u64(&mut encoded, col.num_vars); + push_u64(&mut encoded, col.evaluations.len()); + for value in &col.evaluations { + push_transcribable(&mut encoded, value, &mut scratch); + } + } + + transcript.absorb_slice(b"uair_public_trace_begin"); + transcript.absorb_slice(&encoded); + transcript.absorb_slice(b"uair_public_trace_end"); +} + +fn absorb_production_sha_statement_metadata(transcript: &mut impl Transcript) { + transcript.absorb_slice(PRODUCTION_SHA_FRESH_BATCH_DOMAIN); + transcript.absorb_slice(b"production_sha_statement_metadata_begin"); + + transcript.absorb_slice(b"row_layout"); + transcript.absorb_slice(&(SHA_ROW_VARS as u64).to_le_bytes()); + transcript.absorb_slice(&(SHA_ROW_COUNT as u64).to_le_bytes()); + for (start, end) in [(0u64, 3u64), (0, 15), (0, 47), (0, 63), (64, 67), (68, 71)] { + transcript.absorb_slice(&start.to_le_bytes()); + transcript.absorb_slice(&end.to_le_bytes()); + } + + transcript.absorb_slice(b"sha_word_column_order"); + for col in ShaWordCol::ALL { + transcript.absorb_slice(&(col.index() as u64).to_le_bytes()); + } + transcript.absorb_slice(b"sha_int_column_order"); + for col in ShaIntCol::ALL { + transcript.absorb_slice(&(col.index() as u64).to_le_bytes()); + } + 
transcript.absorb_slice(b"sha_public_column_order"); + for col in ShaPublicCol::ALL { + transcript.absorb_slice(&(col.index() as u64).to_le_bytes()); + } + + transcript.absorb_slice(b"sha_residual_family_order"); + for family in ShaResidualFamily::ALL { + transcript.absorb_slice(&(family.index() as u64).to_le_bytes()); + } + transcript.absorb_slice(b"sha_nonzero_ideal_ids"); + for family in production_sha_nonzero_families() { + transcript.absorb_slice(&(family.index() as u64).to_le_bytes()); + let ideal_id: &[u8] = match family { + ShaResidualFamily::R0BigSigmaA | ShaResidualFamily::R1BigSigmaE => b"X32_MINUS_1", + ShaResidualFamily::R4Schedule + | ShaResidualFamily::R5UpdateA + | ShaResidualFamily::R6UpdateE + | ShaResidualFamily::R9FeedForwardA + | ShaResidualFamily::R10FeedForwardE => b"X_MINUS_2", + _ => b"UNEXPECTED_NONZERO_IDEAL", + }; + transcript.absorb_slice(ideal_id); + } + + transcript.absorb_slice(b"sha256_k_constants"); + for constant in SHA256_ROUND_CONSTANTS { + transcript.absorb_slice(&(constant as u64).to_le_bytes()); + } + + transcript.absorb_slice(b"production_sha_statement_metadata_end"); +} + +pub fn absorb_production_sha_commitments( + transcript: &mut impl Transcript, + label: &'static [u8], + commitments: &[PCSCommitments], +) where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + transcript.absorb_slice(label); + transcript.absorb_slice(&(commitments.len() as u64).to_le_bytes()); + for (instance_idx, commitment) in commitments.iter().enumerate() { + transcript.absorb_slice(&(instance_idx as u64).to_le_bytes()); + P::BinaryPCS::absorb_commitment(transcript, &commitment.binary); + P::ArbitraryPCS::absorb_commitment(transcript, &commitment.arbitrary); + P::IntPCS::absorb_commitment(transcript, &commitment.int); + } +} + +pub fn absorb_derived_production_sha_commitments( + transcript: &mut impl Transcript, + label: &'static [u8], + commitments: &[PCSCommitments], + weights: &[F], + field_cfg: &F::Config, +) where + Zt: ZincTypes, + F: 
PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(label); + transcript.absorb_slice(b"derived_from_fresh_v1"); + transcript.absorb_slice(&(commitments.len() as u64).to_le_bytes()); + transcript.absorb_slice(&(weights.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(weights, &mut field_buf); + absorb_production_sha_commitments::( + transcript, + b"production_sha_derived_folded_commitment_sources", + commitments, + ); +} + +pub fn absorb_fresh_sha_ideal_polys( + transcript: &mut impl Transcript, + ideal_polys: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]], + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(b"production_sha_fresh_ideals_begin"); + transcript.absorb_slice(&(ideal_polys.len() as u64).to_le_bytes()); + for (instance_idx, instance) in ideal_polys.iter().enumerate() { + transcript.absorb_slice(&(instance_idx as u64).to_le_bytes()); + transcript.absorb_slice(&(instance.len() as u64).to_le_bytes()); + for (family_idx, poly) in instance.iter().enumerate() { + transcript.absorb_slice(&(family_idx as u64).to_le_bytes()); + transcript.absorb_slice(&(poly.coeffs.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(&poly.coeffs, &mut field_buf); + } + } + transcript.absorb_slice(b"production_sha_fresh_ideals_end"); +} + +pub fn absorb_aggregate_sha_ideal_polys( + transcript: &mut impl Transcript, + ideal_polys: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(b"production_sha_aggregate_ideals_begin"); + transcript.absorb_slice(&(ideal_polys.len() as 
u64).to_le_bytes()); + for (family_idx, poly) in ideal_polys.iter().enumerate() { + transcript.absorb_slice(&(family_idx as u64).to_le_bytes()); + transcript.absorb_slice(&(poly.coeffs.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(&poly.coeffs, &mut field_buf); + } + transcript.absorb_slice(b"production_sha_aggregate_ideals_end"); +} + +pub fn absorb_sha_endpoint_evals( + transcript: &mut impl Transcript, + endpoint_evals: &ShaEndpointEvals, + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(b"production_sha_endpoint_evals_begin"); + transcript.absorb_slice(&(endpoint_evals.sources.len() as u64).to_le_bytes()); + for source in &endpoint_evals.sources { + transcript.absorb_slice(&(source.col.index() as u64).to_le_bytes()); + transcript.absorb_slice(&(source.shift as u64).to_le_bytes()); + transcript.absorb_random_field(&source.scalarized, &mut field_buf); + transcript.absorb_random_field_slice(&source.bits, &mut field_buf); + } + transcript.absorb_slice(&(endpoint_evals.int_sources.len() as u64).to_le_bytes()); + for source in &endpoint_evals.int_sources { + transcript.absorb_slice(&(source.col.index() as u64).to_le_bytes()); + transcript.absorb_random_field(&source.scalar, &mut field_buf); + } + transcript.absorb_slice(b"production_sha_endpoint_evals_end"); +} + +pub fn absorb_folded_lifted_evals( + transcript: &mut impl Transcript, + lifted_evals: &[DynamicPolynomialF], + field_cfg: &F::Config, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(b"production_sha_folded_lifted_evals_begin"); + transcript.absorb_slice(&(lifted_evals.len() as u64).to_le_bytes()); + for (idx, lifted_eval) in lifted_evals.iter().enumerate() { + transcript.absorb_slice(&(idx as 
u64).to_le_bytes()); + transcript.absorb_slice(&(lifted_eval.coeffs.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(&lifted_eval.coeffs, &mut field_buf); + } + transcript.absorb_slice(b"production_sha_folded_lifted_evals_end"); +} + +pub fn sample_pre_ideal_challenge( + transcript: &mut impl Transcript, + field_cfg: &F::Config, +) -> [F; SHA_ROW_VARS] +where + F: DelayedFieldProductSum, + F::Inner: Transcribable, +{ + std::array::from_fn(|_| transcript.get_transcribable_field_challenge(field_cfg)) +} + +pub fn sample_instance_batch_challenge( + transcript: &mut impl Transcript, + instance_count: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, + F::Inner: Transcribable, +{ + if !instance_count.is_power_of_two() { + return Err(ProductionShaError::InstanceCountNotPowerOfTwo( + instance_count, + )); + } + let ell = usize::try_from(instance_count.trailing_zeros()).expect("ell fits usize"); + Ok(transcript.get_transcribable_field_challenges(ell, field_cfg)) +} + +pub fn sample_post_aggregate_ideal_challenges( + transcript: &mut impl Transcript, + field_cfg: &F::Config, +) -> (F, F, F, F) +where + F: PrimeField, + F::Inner: Transcribable, +{ + ( + transcript.get_transcribable_field_challenge(field_cfg), + transcript.get_transcribable_field_challenge(field_cfg), + transcript.get_transcribable_field_challenge(field_cfg), + transcript.get_transcribable_field_challenge(field_cfg), + ) +} + +pub fn sample_post_ideal_challenges( + transcript: &mut impl Transcript, + instance_count: usize, + field_cfg: &F::Config, +) -> Result<(F, F, F, F, Vec), ProductionShaError> +where + F: PrimeField, + F::Inner: Transcribable, +{ + let (a, lambda, rho, xi) = sample_post_aggregate_ideal_challenges(transcript, field_cfg); + Ok(( + a, + lambda, + rho, + xi, + sample_instance_batch_challenge(transcript, instance_count, field_cfg)?, + )) +} + +pub fn check_fresh_sha_ideal_membership( + ideal_polys: &[[DynamicPolynomialF; 
NUM_NONZERO_SHA_FAMILIES]], + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: DelayedFieldProductSum, +{ + verify_fresh_sha_ideal_polys(ideal_polys, field_cfg)?; + Ok(()) +} + +pub fn check_aggregate_sha_ideal_membership( + ideal_polys: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + verify_fresh_sha_ideal_polys(std::slice::from_ref(ideal_polys), field_cfg)?; + Ok(()) +} + +fn ensure_production_sha_word_degree() -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + if D != SHA_WORD_BITS { + return Err(ProductionShaError::UnsupportedProductionShaWordDegree { + got: D, + expected: SHA_WORD_BITS, + }); + } + Ok(()) +} + +fn validate_production_sha_batch_sizes( + binary: usize, + arbitrary: usize, + int: usize, +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + if binary != ShaWordCol::COUNT { + return Err(ProductionShaError::UnsupportedProductionShaPcsShape( + "production SHA expects one binary commitment batch per SHA word column", + )); + } + if arbitrary != 0 { + return Err(ProductionShaError::UnsupportedProductionShaPcsShape( + "production SHA expects no arbitrary witness columns", + )); + } + if int != ShaIntCol::COUNT { + return Err(ProductionShaError::UnsupportedProductionShaPcsShape( + "production SHA expects one int commitment batch per SHA int column", + )); + } + Ok(()) +} + +pub fn commit_production_sha_instance( + pcs_params: &PCSParams, + witness_polys: &ProductionShaWitnessPolys, +) -> Result<(PCSProverData, PCSCommitments), ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + ensure_production_sha_word_degree::()?; + validate_production_sha_batch_sizes::( + witness_polys.binary.len(), + witness_polys.arbitrary.len(), + 
witness_polys.int.len(), + )?; + let (binary_data, binary_commitment) = + P::BinaryPCS::commit(&pcs_params.binary, &witness_polys.binary)?; + let (arbitrary_data, arbitrary_commitment) = + P::ArbitraryPCS::commit(&pcs_params.arbitrary, &witness_polys.arbitrary)?; + let (int_data, int_commitment) = P::IntPCS::commit(&pcs_params.int, &witness_polys.int)?; + Ok(( + PCSProverData { + binary: binary_data, + arbitrary: arbitrary_data, + int: int_data, + }, + PCSCommitments { + binary: binary_commitment, + arbitrary: arbitrary_commitment, + int: int_commitment, + }, + )) +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_linear_ideal_fold( + pp: &LinearIdealFoldProverParams, + shape: &UairShape, + witnesses: &[UairWitness<'_, Zt::Int, Zt::Int, D>], + transcript: &mut impl Transcript, +) -> Result< + LinearIdealFoldProveOutput< + UairInstance<'static, Zt::Int, Zt::Int, PCSCommitments, D>, + FoldedLinearIdealInstance>, + FoldedLinearIdealWitness>, + ProductionLinearIdealFoldProof, + >, + LinearIdealFoldError, +> +where + U: Uair + ProductionShaProjectionAdapter + Sync, + Zt: ZincTypes, + F: InnerTransparentField + + DelayedFieldProductSum + + ShaBinaryFoldField + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P: ProductionShaFoldedPcsOpen, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + let field_cfg = &pp.field_cfg; + ensure_production_sha_word_degree::()?; + let prepared = + prepare_linear_ideal_fold_witnesses::(shape, witnesses, field_cfg)?; + prove_prepared_linear_ideal_fold::(pp, shape, &prepared, transcript) +} + +pub fn prepare_linear_ideal_fold_witnesses( + shape: &UairShape, + witnesses: &[UairWitness<'_, Zt::Int, Zt::Int, D>], + field_cfg: &F::Config, +) -> Result>, LinearIdealFoldError> +where + U: Uair + ProductionShaProjectionAdapter + Sync, + Zt: ZincTypes, + F: PrimeField + Send + Sync, 
+{ + if witnesses.len() < 2 { + return Err(ProductionShaError::InstanceCountTooSmall(witnesses.len())); + } + if !witnesses.len().is_power_of_two() { + return Err(ProductionShaError::InstanceCountNotPowerOfTwo( + witnesses.len(), + )); + } + + cfg_iter!(witnesses) + .map(|witness| { + let public_trace = public_uair_trace_view(&witness.trace, &shape.signature)?; + let witness_trace = witness_uair_trace_view(&witness.trace, &shape.signature)?; + let (trace, public, witness_polys) = + U::project_production_sha_witness(shape, &public_trace, &witness_trace, field_cfg)?; + Ok(PreparedProductionShaProverInstance { + public_trace: own_uair_trace(&public_trace), + instance: ProductionShaProverInstance { + trace, + public, + witness_polys, + }, + }) + }) + .collect() +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_prepared_linear_ideal_fold( + pp: &LinearIdealFoldProverParams, + shape: &UairShape, + prepared_instances: &[PreparedProductionShaProverInstance], + transcript: &mut impl Transcript, +) -> Result< + LinearIdealFoldProveOutput< + UairInstance<'static, Zt::Int, Zt::Int, PCSCommitments, D>, + FoldedLinearIdealInstance>, + FoldedLinearIdealWitness>, + ProductionLinearIdealFoldProof, + >, + LinearIdealFoldError, +> +where + U: Uair, + Zt: ZincTypes, + F: InnerTransparentField + + DelayedFieldProductSum + + ShaBinaryFoldField + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P: ProductionShaFoldedPcsOpen, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + let field_cfg = &pp.field_cfg; + ensure_production_sha_word_degree::()?; + let instance_count = prepared_instances.len(); + if instance_count < 2 { + return Err(ProductionShaError::InstanceCountTooSmall(instance_count)); + } + if !instance_count.is_power_of_two() { + return Err(ProductionShaError::InstanceCountNotPowerOfTwo( + instance_count, + )); 
+ } + + let booleanity_sources = production_sha_booleanity_sources(); + absorb_production_sha_statement_metadata(transcript); + absorb_uair_shape_metadata(transcript, shape); + + let (fresh_instances, instance_commitments, instance_prover_data, traces, publics) = + prove_prepared_fresh_instances_phase::( + &pp.pcs_params, + prepared_instances, + transcript, + field_cfg, + )?; + validate_production_sha_publics(&publics, field_cfg)?; + + tracing::info_span!( + target: "zinc_protocol::production_sha", + "absorb_fresh_commitments", + side = "prove", + phase = "absorb_fresh_commitments", + ) + .in_scope(|| { + absorb_production_sha_commitments::( + transcript, + b"production_sha_fresh_commitments", + &instance_commitments, + ) + }); + tracing::info_span!( + target: "zinc_protocol::production_sha", + "absorb_projected_publics", + side = "prove", + phase = "absorb_projected_publics", + ) + .in_scope(|| absorb_projected_sha_publics(transcript, &publics, field_cfg)); + + let r_ic = sample_pre_ideal_challenge(transcript, field_cfg); + let r_ic_eq_weights = build_eq_x_r_vec(&r_ic, field_cfg)?; + let coeff_tables = + build_residual_coeff_tables_phase(&traces, &publics, &r_ic_eq_weights, field_cfg)?; + + let beta = sample_instance_batch_challenge(transcript, instance_count, field_cfg)?; + let beta_eq_weights = build_eq_x_r_vec(&beta, field_cfg)?; + let (ideal_check, aggregate_ideal_polys) = + prove_aggregate_ideal_phase(&coeff_tables, &beta_eq_weights, transcript, field_cfg)?; + + let (a, lambda, rho, xi) = sample_post_aggregate_ideal_challenges(transcript, field_cfg); + let a_powers = build_sha_residual_eval_powers(&a, field_cfg); + let lambda_powers = build_sha_lambda_powers(&lambda, field_cfg); + let booleanity_weights = + build_booleanity_weights(&rho, &xi, booleanity_sources.len(), field_cfg); + let initial_claim = evaluate_aggregate_sha_ideal_claim_with_powers( + &aggregate_ideal_polys, + &a_powers, + &lambda_powers, + field_cfg, + )?; + + let (_linear_accumulator, 
_quadratic_prefix_accumulator, sumfold_group) = + build_sumfold_accumulators_phase( + &traces, + &beta, + &beta_eq_weights, + &r_ic_eq_weights, + &coeff_tables, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + pp.prefix_vars, + field_cfg, + )?; + + let (sumfold_proof, sumfold_r_b, sumfold_c_sf) = prove_sumfold_phase( + transcript, + sumfold_group, + &initial_claim, + beta.len(), + field_cfg, + )?; + + let sumfold_output = derive_instance_fold_claim( + &beta, + sumfold_r_b.clone(), + sumfold_c_sf, + instance_count, + field_cfg, + )?; + + let (folded, folded_public, row_claim, sumfold_output, folded_prover_data) = + prove_fold_after_sumfold_phase::( + &traces, + &publics, + sumfold_output, + &r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + &instance_prover_data, + field_cfg, + )?; + absorb_derived_production_sha_commitments::( + transcript, + b"production_sha_derived_folded_commitments", + &instance_commitments, + sumfold_output.eq_instance_weights(), + field_cfg, + ); + + verify_folded_row_sumcheck_claim(&row_claim, sumfold_output.final_round_sumcheck_claim())?; + let (combined_sumcheck, row_output) = prove_row_sumcheck_phase( + transcript, + &folded.trace, + &folded_public, + &r_ic, + &r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + &row_claim, + field_cfg, + )?; + + let (resolver, _resolver_endpoint_evals, multipoint_eval, r_0) = + prove_endpoint_multipoint_phase( + transcript, + &folded.trace, + &folded_public, + &row_output, + &r_ic, + &a, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + field_cfg, + )?; + + let r_0_eq_weights = build_eq_x_r_vec(&r_0, field_cfg)?; + let (witness_lifted_evals, opening_proof) = prove_pcs_opening_phase::( + transcript, + &folded.trace, + &instance_commitments, + sumfold_output.eq_instance_weights(), + &folded_prover_data, + &r_0, + &r_0_eq_weights, + &pp.pcs_params, + field_cfg, 
+ )?; + + Ok(LinearIdealFoldProveOutput { + fresh_instances, + folded_instance: FoldedLinearIdealInstance { + target: sumfold_output.final_round_sumcheck_claim().clone(), + commitments: (), + public: folded_public, + }, + folded_witness: FoldedLinearIdealWitness { + witness: ProductionShaFoldedWitness { + trace: folded.trace, + opening_witness: folded_prover_data, + }, + }, + proof: ProductionLinearIdealFoldProof { + instance_commitments, + ideal_check, + sumfold_proof, + resolver, + combined_sumcheck, + multipoint_eval, + witness_lifted_evals, + opening_proof, + }, + }) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "fresh_instances", instances = prepared_instances.len()) +)] +#[allow(clippy::too_many_arguments)] +fn prove_prepared_fresh_instances_phase( + pcs_params: &PCSParams, + prepared_instances: &[PreparedProductionShaProverInstance], + transcript: &mut impl Transcript, + _field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + for (instance_idx, prepared) in prepared_instances.iter().enumerate() { + absorb_public_uair_trace::(transcript, instance_idx, &prepared.public_trace); + } + + let artifacts = cfg_iter!(prepared_instances) + .map(|prepared| { + let (data, commitment) = tracing::info_span!( + target: "zinc_protocol::production_sha", + "fresh_commit_instance", + side = "prove", + phase = "fresh_commit_instance", + ) + .in_scope(|| { + commit_production_sha_instance::( + pcs_params, + &prepared.instance.witness_polys, + ) + })?; + + Ok(( + UairInstance { + public_trace: prepared.public_trace.clone(), + commitments: commitment.clone(), + }, + commitment, + data, + prepared.instance.trace.clone(), + prepared.instance.public.clone(), + )) + }) + .collect::, ProductionShaError>>()?; + + let mut fresh_instances = 
Vec::with_capacity(artifacts.len()); + let mut instance_commitments = Vec::with_capacity(artifacts.len()); + let mut instance_prover_data = Vec::with_capacity(artifacts.len()); + let mut traces = Vec::with_capacity(artifacts.len()); + let mut publics = Vec::with_capacity(artifacts.len()); + + for (fresh_instance, commitment, data, trace, public) in artifacts { + fresh_instances.push(fresh_instance); + instance_commitments.push(commitment); + instance_prover_data.push(data); + traces.push(trace); + publics.push(public); + } + + Ok(( + fresh_instances, + instance_commitments, + instance_prover_data, + traces, + publics, + )) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "residual_coeff_tables", instances = traces.len()) +)] +fn build_residual_coeff_tables_phase( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + r_ic_eq_weights: &[F], + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + build_linear_residual_coeff_tables_with_row_weights(traces, publics, r_ic_eq_weights, field_cfg) + .map_err(ProductionShaError::from) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "aggregate_ideal", instances = coeff_tables.len()) +)] +fn prove_aggregate_ideal_phase( + coeff_tables: &[LinearResidualCoeffTable], + beta_eq_weights: &[F], + transcript: &mut impl Transcript, + field_cfg: &F::Config, +) -> Result< + ( + IdealCheckProof, + [DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + ), + ProductionShaError, +> +where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let aggregate_ideal_polys = + beta_aggregate_nonzero_ideal_polys_with_weights(coeff_tables, beta_eq_weights)?; + let ideal_check = IdealCheckProof { + combined_mle_values: 
aggregate_ideal_polys.iter().cloned().collect(), + }; + #[cfg(debug_assertions)] + check_aggregate_sha_ideal_membership(&aggregate_ideal_polys, field_cfg)?; + absorb_aggregate_sha_ideal_polys(transcript, &aggregate_ideal_polys, field_cfg); + Ok((ideal_check, aggregate_ideal_polys)) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields( + side = "prove", + phase = "sumfold_accumulators", + instances = traces.len(), + prefix_vars, + ) +)] +#[allow(clippy::too_many_arguments)] +fn build_sumfold_accumulators_phase( + traces: &[ProjectedTrace], + beta: &[F], + beta_eq_weights: &[F], + r_ic_eq_weights: &[F], + coeff_tables: &[LinearResidualCoeffTable], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let linear_accumulator = tracing::info_span!( + target: "zinc_protocol::production_sha", + "sumfold_linear_accumulator", + side = "prove", + phase = "sumfold_linear_accumulator", + ) + .in_scope(|| { + build_sha_sumfold_linear_accumulator(coeff_tables, a_powers, lambda_powers, field_cfg) + })?; + let quadratic_prefix_accumulator = tracing::info_span!( + target: "zinc_protocol::production_sha", + "sumfold_quadratic_prefix_accumulator", + side = "prove", + phase = "sumfold_quadratic_prefix_accumulator", + ) + .in_scope(|| { + build_sha_sumfold_quadratic_prefix_accumulator( + traces, + booleanity_sources, + prefix_vars, + r_ic_eq_weights, + booleanity_weights, + field_cfg, + ) + })?; + let sumfold_group = tracing::info_span!( + target: "zinc_protocol::production_sha", + "sumfold_group", + side = "prove", + phase = "sumfold_group", + ) + .in_scope(|| { + build_production_sha_sumfold_group_from_prefix_accumulators( + traces, + beta, + beta_eq_weights, + r_ic_eq_weights, + 
&linear_accumulator, + &quadratic_prefix_accumulator, + booleanity_weights, + booleanity_sources, + prefix_vars, + field_cfg, + ) + })?; + Ok(( + linear_accumulator, + quadratic_prefix_accumulator, + sumfold_group, + )) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "sumfold_prove", instance_vars) +)] +fn prove_sumfold_phase( + transcript: &mut impl Transcript, + sumfold_group: MultiDegreeSumcheckGroup, + initial_claim: &F, + instance_vars: usize, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, Vec, F), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Send + Sync, + F::Modulus: Transcribable, +{ + prove_optimized_sha_sumfold_with_weights( + transcript, + sumfold_group, + initial_claim, + instance_vars, + field_cfg, + ) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "fold_after_sumfold", instances = traces.len()) +)] +#[allow(clippy::too_many_arguments)] +fn prove_fold_after_sumfold_phase( + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + sumfold_output: InstanceFoldClaim, + _r_ic_eq_weights: &[F], + _a_powers: &[F], + _lambda_powers: &[F], + _booleanity_weights: &[F], + _booleanity_sources: &[ShaBooleanitySource], + instance_prover_data: &[PCSProverData], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: InnerTransparentField + DelayedFieldProductSum + ShaBinaryFoldField + Send + Sync + 'static, + F::Inner: Zero, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + let (folded, folded_public) = tracing::info_span!( + target: "zinc_protocol::production_sha", + "fold_projected_traces", + side = "prove", + phase = "fold_projected_traces", 
+ ) + .in_scope(|| fold_projected_traces(traces, publics, &sumfold_output, field_cfg))?; + let row_claim = tracing::info_span!( + target: "zinc_protocol::production_sha", + "fold_row_claim", + side = "prove", + phase = "fold_row_claim", + ) + .in_scope(|| -> Result> { + let row_claim = sumfold_output.final_round_sumcheck_claim().clone(); + #[cfg(debug_assertions)] + { + let recomputed = production_sha_folded_row_sum_fast( + &folded.trace, + &folded_public, + _r_ic_eq_weights, + _a_powers, + _lambda_powers, + _booleanity_weights, + _booleanity_sources, + field_cfg, + )?; + if recomputed != row_claim { + return Err(ShaProjectionError::FoldedRowClaimMismatch.into()); + } + } + Ok(row_claim) + })?; + let folded_prover_data = tracing::info_span!( + target: "zinc_protocol::production_sha", + "fold_prover_data", + side = "prove", + phase = "fold_prover_data", + ) + .in_scope(|| { + fold_pcs_prover_data::( + instance_prover_data, + sumfold_output.eq_instance_weights(), + field_cfg, + ) + })?; + + Ok(( + folded, + folded_public, + row_claim, + sumfold_output, + folded_prover_data, + )) +} + +#[allow(clippy::too_many_arguments)] +#[cfg(debug_assertions)] +fn production_sha_folded_row_sum_fast( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + #[cfg(debug_assertions)] + validate_projected_trace(trace)?; + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if a_powers.len() < SHA_IDEAL_EVAL_POWER_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "a powers", + got: a_powers.len(), + expected: SHA_IDEAL_EVAL_POWER_COUNT, + }); + } + if lambda_powers.len() 
!= NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "lambda powers", + got: lambda_powers.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "booleanity weights", + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + + let weight_vec = |shift: usize| { + (0..SHA_WORD_BITS) + .map(|bit| { + if bit >= shift { + a_powers[bit - shift].clone() + } else { + F::zero_with_cfg(field_cfg) + } + }) + .collect::>() + }; + let rot_vec = |shift: usize| { + (0..SHA_WORD_BITS) + .map(|bit| a_powers[(bit + shift) % SHA_WORD_BITS].clone()) + .collect::>() + }; + + let word_weights = a_powers[..SHA_WORD_BITS].to_vec(); + let rot25_weights = rot_vec(25); + let rot14_weights = rot_vec(14); + let rot15_weights = rot_vec(15); + let rot13_weights = rot_vec(13); + let shift0_weights = weight_vec(0); + let shift2_weights = weight_vec(2); + let shift3_weights = weight_vec(3); + let shift5_weights = weight_vec(5); + let shift8_weights = weight_vec(8); + let shift9_weights = weight_vec(9); + let shift10_weights = weight_vec(10); + let rho_sig0 = a_powers[10].clone() + &a_powers[19] + &a_powers[30]; + let rho_sig1 = a_powers[7].clone() + &a_powers[21] + &a_powers[26]; + let low_mu_coeff = production_sha_pow_two(32, field_cfg); + let high_mu_w_coeff = production_sha_pow_two(34, field_cfg); + let high_mu_3_bit_coeff = production_sha_pow_two(35, field_cfg); + let high_mu_1_bit_coeff = production_sha_pow_two(33, field_cfg); + let one = F::one_with_cfg(field_cfg); + let two = one.clone() + &one; + + let values = cfg_iter!(row_weights) + .enumerate() + .map(|(row, row_weight)| { + let word_eval_with = |col: ShaWordCol, shift: usize, weights: &[F]| { + trace_word_eval_at_row_with_weights(trace, col, row, shift, weights, field_cfg) + }; + let word_eval = + |col: ShaWordCol, shift: usize| word_eval_with(col, shift, 
&word_weights); + let public_word_eval = |col: ShaPublicCol| { + public_word_or_const_eval_at_row(public, col, row, &word_weights, field_cfg) + }; + + let a_word = word_eval(ShaWordCol::A, 0)?; + let e_word = word_eval(ShaWordCol::E, 0)?; + let sigma0 = word_eval(ShaWordCol::Sigma0, 0)?; + let sigma1 = word_eval(ShaWordCol::Sigma1, 0)?; + let w = word_eval(ShaWordCol::W, 0)?; + let small_sigma0 = word_eval(ShaWordCol::SmallSigma0, 0)?; + let small_sigma1 = word_eval(ShaWordCol::SmallSigma1, 0)?; + let ov_sigma0 = word_eval(ShaWordCol::OvSigma0, 0)?; + let ov_sigma1 = word_eval(ShaWordCol::OvSigma1, 0)?; + let ov_small_sigma0 = word_eval(ShaWordCol::OvSmallSigma0, 0)?; + let ov_small_sigma1 = word_eval(ShaWordCol::OvSmallSigma1, 0)?; + + let mu = |low_weights: &[F], high_weights: &[F], high_coeff: &F| { + Ok::>( + word_eval_with(ShaWordCol::MuPacked, 0, low_weights)? * &low_mu_coeff + - word_eval_with(ShaWordCol::MuPacked, 0, high_weights)? * high_coeff, + ) + }; + let mu_w = mu(&shift0_weights, &shift2_weights, &high_mu_w_coeff)?; + let mu_a = mu(&shift2_weights, &shift5_weights, &high_mu_3_bit_coeff)?; + let mu_e = mu(&shift5_weights, &shift8_weights, &high_mu_3_bit_coeff)?; + let mu_ff_a = mu(&shift8_weights, &shift9_weights, &high_mu_1_bit_coeff)?; + let mu_ff_e = mu(&shift9_weights, &shift10_weights, &high_mu_1_bit_coeff)?; + + let r0 = a_word.clone() * &rho_sig0 - &sigma0 - two.clone() * &ov_sigma0; + let r1 = e_word.clone() * &rho_sig1 - &sigma1 - two.clone() * &ov_sigma1; + let r2 = word_eval_with(ShaWordCol::W, 0, &rot25_weights)? + + word_eval_with(ShaWordCol::W, 0, &rot14_weights)? + + word_eval_with(ShaWordCol::W, 0, &shift3_weights)? + - &small_sigma0 + - two.clone() * &ov_small_sigma0; + let r3 = word_eval_with(ShaWordCol::W, 0, &rot15_weights)? + + word_eval_with(ShaWordCol::W, 0, &rot13_weights)? + + word_eval_with(ShaWordCol::W, 0, &shift10_weights)? + - &small_sigma1 + - two.clone() * &ov_small_sigma1; + let r4 = word_eval(ShaWordCol::W, 16)? 
+ - &w + - word_eval(ShaWordCol::SmallSigma0, 1)? + - word_eval(ShaWordCol::W, 9)? + - word_eval(ShaWordCol::SmallSigma1, 14)? + + &mu_w + + trace_int_at_row(trace, ShaIntCol::CompSchedule, row, field_cfg)?; + let r5 = word_eval(ShaWordCol::A, 4)? + - &e_word + - word_eval(ShaWordCol::Sigma1, 3)? + - word_eval(ShaWordCol::Uef, 3)? + - word_eval(ShaWordCol::UNegEg, 3)? + - public_scalar_at_row(public, ShaPublicCol::K, row, 3, field_cfg)? + - &w + - word_eval(ShaWordCol::Sigma0, 3)? + - word_eval(ShaWordCol::Maj, 3)? + + &mu_a + + trace_int_at_row(trace, ShaIntCol::CompUpdateA, row, field_cfg)?; + let r6 = word_eval(ShaWordCol::E, 4)? + - &a_word + - &e_word + - word_eval(ShaWordCol::Sigma1, 3)? + - word_eval(ShaWordCol::Uef, 3)? + - word_eval(ShaWordCol::UNegEg, 3)? + - public_scalar_at_row(public, ShaPublicCol::K, row, 3, field_cfg)? + - &w + + &mu_e + + trace_int_at_row(trace, ShaIntCol::CompUpdateE, row, field_cfg)?; + + let s_init = public_scalar_at_row(public, ShaPublicCol::SInit, row, 0, field_cfg)?; + let s_msg = public_scalar_at_row(public, ShaPublicCol::SMsg, row, 0, field_cfg)?; + let s_sched = public_scalar_at_row(public, ShaPublicCol::SSched, row, 0, field_cfg)?; + let s_upd = public_scalar_at_row(public, ShaPublicCol::SUpd, row, 0, field_cfg)?; + let s_ff = public_scalar_at_row(public, ShaPublicCol::SFf, row, 0, field_cfg)?; + let s_out = public_scalar_at_row(public, ShaPublicCol::SOut, row, 0, field_cfg)?; + + let r7 = (a_word.clone() - public_word_eval(ShaPublicCol::PAIn)?) * &s_init + + (a_word.clone() - public_word_eval(ShaPublicCol::PAOut)?) * &s_out; + let r8 = (e_word.clone() - public_word_eval(ShaPublicCol::PEIn)?) * &s_init + + (e_word.clone() - public_word_eval(ShaPublicCol::PEOut)?) * &s_out; + let r9 = word_eval(ShaWordCol::A, 4)? + - &a_word + - public_scalar_at_row(public, ShaPublicCol::PAIn, row, 0, field_cfg)? + + &mu_ff_a + + trace_int_at_row(trace, ShaIntCol::CompFeedForwardA, row, field_cfg)?; + let r10 = word_eval(ShaWordCol::E, 4)? 
+ - &e_word + - public_scalar_at_row(public, ShaPublicCol::PEIn, row, 0, field_cfg)? + + &mu_ff_e + + trace_int_at_row(trace, ShaIntCol::CompFeedForwardE, row, field_cfg)?; + let r11 = (w - public_word_eval(ShaPublicCol::Message)?) * &s_msg; + let r12 = trace_int_at_row(trace, ShaIntCol::CompSchedule, row, field_cfg)? * &s_sched; + let r13 = trace_int_at_row(trace, ShaIntCol::CompUpdateA, row, field_cfg)? * &s_upd; + let r14 = trace_int_at_row(trace, ShaIntCol::CompUpdateE, row, field_cfg)? * &s_upd; + let r15 = trace_int_at_row(trace, ShaIntCol::CompFeedForwardA, row, field_cfg)? * &s_ff; + let r16 = trace_int_at_row(trace, ShaIntCol::CompFeedForwardE, row, field_cfg)? * &s_ff; + let r17 = word_eval_with(ShaWordCol::MuPacked, 0, &shift10_weights)?; + + let residuals = [ + r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17, + ]; + let linear = FieldFieldInnerProduct::inner_product::( + &residuals, + lambda_powers, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "production SHA row residual dot product failed", + ) + })?; + + let mut bool_sum = F::zero_with_cfg(field_cfg); + for (source, weight) in booleanity_sources.iter().zip(booleanity_weights.iter()) { + let d = booleanity_source_value_at_fast(trace, row, source, field_cfg)?; + bool_sum += weight.clone() * (d.clone() * (d - one.clone())); + } + + Ok(row_weight.clone() * (linear + bool_sum)) + }) + .collect::, ProductionShaError>>()?; + + folded_row_integrand_sum(&values, field_cfg).map_err(ProductionShaError::from) +} + +#[cfg(debug_assertions)] +fn trace_word_eval_at_row_with_weights( + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + shift: usize, + weights: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + let mut acc = F::zero_with_cfg(field_cfg); + for (bit, weight) in weights.iter().enumerate().take(SHA_WORD_BITS) { + acc += trace_word_bit_at_row(trace, col, row, shift, bit, field_cfg)? 
* weight; + } + Ok(acc) +} + +#[cfg(debug_assertions)] +fn public_word_or_const_eval_at_row( + public: &ProjectedPublic, + col: ShaPublicCol, + row: usize, + weights: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + let Some(_col_idx) = public_word_col_index(col) else { + return public_scalar_at_row(public, col, row, 0, field_cfg); + }; + if public.bit_slices.is_none() { + return public_scalar_at_row(public, col, row, 0, field_cfg); + } + let mut acc = F::zero_with_cfg(field_cfg); + for (bit, weight) in weights.iter().enumerate().take(SHA_WORD_BITS) { + acc += public_word_bit_at_row(public, col, row, bit, field_cfg)? * weight; + } + Ok(acc) +} + +#[cfg(debug_assertions)] +fn booleanity_source_value_at_fast( + trace: &ProjectedTrace, + row: usize, + source: &ShaBooleanitySource, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + let one = F::one_with_cfg(field_cfg); + let two = one.clone() + &one; + match *source { + ShaBooleanitySource::WordBit { col, bit } => { + trace_word_bit_at_row(trace, col, row, 0, bit, field_cfg) + } + ShaBooleanitySource::VirtualCh1 { bit } => { + Ok( + trace_word_bit_at_row(trace, ShaWordCol::E, row, 2, bit, field_cfg)? + + &trace_word_bit_at_row(trace, ShaWordCol::E, row, 1, bit, field_cfg)? + - two.clone() + * trace_word_bit_at_row(trace, ShaWordCol::Uef, row, 2, bit, field_cfg)?, + ) + } + ShaBooleanitySource::VirtualCh2 { bit } => { + Ok( + trace_word_bit_at_row(trace, ShaWordCol::E, row, 2, bit, field_cfg)? + - &trace_word_bit_at_row(trace, ShaWordCol::E, row, 0, bit, field_cfg)? + + two.clone() + * trace_word_bit_at_row(trace, ShaWordCol::UNegEg, row, 2, bit, field_cfg)? + + two.clone() + * trace_word_bit_at_row( + trace, + ShaWordCol::Ch2Comp, + row, + 0, + bit, + field_cfg, + )?, + ) + } + ShaBooleanitySource::VirtualMaj { bit } => { + Ok( + trace_word_bit_at_row(trace, ShaWordCol::A, row, 0, bit, field_cfg)? + + &trace_word_bit_at_row(trace, ShaWordCol::A, row, 1, bit, field_cfg)? 
+ + &trace_word_bit_at_row(trace, ShaWordCol::A, row, 2, bit, field_cfg)? + - two.clone() + * trace_word_bit_at_row(trace, ShaWordCol::Maj, row, 2, bit, field_cfg)? + - two.clone() + * trace_word_bit_at_row( + trace, + ShaWordCol::MajComp, + row, + 0, + bit, + field_cfg, + )?, + ) + } + } +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "row_sumcheck") +)] +#[allow(clippy::too_many_arguments)] +fn prove_row_sumcheck_phase( + transcript: &mut impl Transcript, + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_ic_eq_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + row_claim: &F, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, FoldedRowSumcheckOutput), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let (combined_sumcheck, row_output) = + prove_expression_folded_row_sumcheck_with_output_and_vectors( + transcript, + trace, + public, + r_ic, + r_ic_eq_weights, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + field_cfg, + )?; + verify_folded_row_sumcheck_claim(&combined_sumcheck.claimed_sums()[0], row_claim)?; + Ok((combined_sumcheck, row_output)) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "endpoint_multipoint") +)] +#[allow(clippy::too_many_arguments)] +fn prove_endpoint_multipoint_phase( + transcript: &mut impl Transcript, + trace: &ProjectedTrace, + folded_public: &ProjectedPublic, + row_output: &FoldedRowSumcheckOutput, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + 
field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + #[cfg(not(debug_assertions))] + let _ = ( + r_ic, + a, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + ); + + let endpoint_evals = tracing::info_span!( + target: "zinc_protocol::production_sha", + "endpoint_build_evals", + side = "prove", + phase = "endpoint_build_evals", + ) + .in_scope(|| { + row_output.endpoint_evals.clone().map_or_else( + || { + build_sha_endpoint_evals_from_trace_with_row_weights( + trace, + &row_output.r_star_eq_weights, + a, + field_cfg, + ) + }, + Ok, + ) + })?; + let resolver = tracing::info_span!( + target: "zinc_protocol::production_sha", + "endpoint_resolver", + side = "prove", + phase = "endpoint_resolver", + ) + .in_scope(|| { + let resolver = sha_resolver_from_endpoint_evals(&endpoint_evals)?; + absorb_sha_resolver_proof(transcript, &resolver, field_cfg); + Ok::<_, ProductionShaError>(resolver) + })?; + #[cfg(debug_assertions)] + { + let resolver_endpoint_evals = sha_endpoint_evals_from_resolver(&resolver, a, field_cfg)?; + let terminal = tracing::info_span!( + target: "zinc_protocol::production_sha", + "endpoint_terminal", + side = "prove", + phase = "endpoint_terminal", + ) + .in_scope(|| { + reconstruct_folded_row_terminal_from_endpoints_with_vectors( + &resolver_endpoint_evals, + folded_public, + r_ic, + &row_output.r_star, + &row_output.r_star_eq_weights, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + field_cfg, + ) + })?; + verify_folded_row_terminal_value(row_output, &terminal)?; + } + + let (multipoint_eval, r_0) = tracing::info_span!( + target: "zinc_protocol::production_sha", + "endpoint_reduce", + side = "prove", + phase = "endpoint_reduce", + ) + .in_scope(|| { + 
prove_sha_endpoint_multipoint_with_row_weights( + transcript, + trace, + folded_public, + &endpoint_evals, + &row_output.r_star, + &row_output.r_star_eq_weights, + field_cfg, + ) + })?; + Ok((resolver, endpoint_evals, multipoint_eval, r_0)) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "prove", phase = "pcs_opening") +)] +#[allow(clippy::too_many_arguments)] +fn prove_pcs_opening_phase( + transcript: &mut impl Transcript, + folded_trace: &ProjectedTrace, + instance_commitments: &[PCSCommitments], + fold_weights: &[F], + folded_prover_data: &PCSProverData, + r_0: &[F], + r_0_eq_weights: &[F], + pcs_params: &PCSParams, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P: ProductionShaFoldedPcsOpen, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + let witness_lifted_evals = tracing::info_span!( + target: "zinc_protocol::production_sha", + "pcs_lifted_evals", + side = "prove", + phase = "pcs_lifted_evals", + ) + .in_scope(|| { + build_folded_sha_pcs_lifted_evals_with_row_weights(folded_trace, r_0_eq_weights, field_cfg) + })?; + tracing::info_span!( + target: "zinc_protocol::production_sha", + "pcs_absorb_lifted_evals", + side = "prove", + phase = "pcs_absorb_lifted_evals", + ) + .in_scope(|| absorb_folded_lifted_evals(transcript, &witness_lifted_evals, field_cfg)); + let opening_proof = tracing::info_span!( + target: "zinc_protocol::production_sha", + "pcs_open_core", + side = "prove", + phase = "pcs_open_core", + ) + .in_scope(|| { + P::prove_folded_pcs_opening( + pcs_params, + instance_commitments, + fold_weights, + folded_trace, + folded_prover_data, + r_0, + &witness_lifted_evals, + field_cfg, + ) + })?; + Ok((witness_lifted_evals, opening_proof)) +} + +fn public_uair_trace_view<'a, PolyCoeff, 
Int, F, const D: usize>( + trace: &'a UairTrace<'_, PolyCoeff, Int, D>, + sig: &UairSignature, +) -> Result, ProductionShaError> +where + PolyCoeff: Clone, + Int: Clone, + F: PrimeField, +{ + let public = sig.public_cols(); + validate_uair_trace_shape(trace, sig)?; + Ok(UairTrace { + binary_poly: Cow::Borrowed( + trace + .binary_poly + .get(..public.num_binary_poly_cols()) + .ok_or(ProductionShaError::LengthMismatch { + label: "UAIR public binary columns", + got: trace.binary_poly.len(), + expected: public.num_binary_poly_cols(), + })?, + ), + arbitrary_poly: Cow::Borrowed( + trace + .arbitrary_poly + .get(..public.num_arbitrary_poly_cols()) + .ok_or(ProductionShaError::LengthMismatch { + label: "UAIR public arbitrary columns", + got: trace.arbitrary_poly.len(), + expected: public.num_arbitrary_poly_cols(), + })?, + ), + int: Cow::Borrowed(trace.int.get(..public.num_int_cols()).ok_or( + ProductionShaError::LengthMismatch { + label: "UAIR public int columns", + got: trace.int.len(), + expected: public.num_int_cols(), + }, + )?), + }) +} + +fn witness_uair_trace_view<'a, PolyCoeff, Int, F, const D: usize>( + trace: &'a UairTrace<'_, PolyCoeff, Int, D>, + sig: &UairSignature, +) -> Result, ProductionShaError> +where + PolyCoeff: Clone, + Int: Clone, + F: PrimeField, +{ + let public = sig.public_cols(); + let total = sig.total_cols(); + validate_uair_trace_shape(trace, sig)?; + Ok(UairTrace { + binary_poly: Cow::Borrowed( + trace + .binary_poly + .get(public.num_binary_poly_cols()..total.num_binary_poly_cols()) + .ok_or(ProductionShaError::LengthMismatch { + label: "UAIR witness binary columns", + got: trace.binary_poly.len(), + expected: total.num_binary_poly_cols(), + })?, + ), + arbitrary_poly: Cow::Borrowed( + trace + .arbitrary_poly + .get(public.num_arbitrary_poly_cols()..total.num_arbitrary_poly_cols()) + .ok_or(ProductionShaError::LengthMismatch { + label: "UAIR witness arbitrary columns", + got: trace.arbitrary_poly.len(), + expected: 
total.num_arbitrary_poly_cols(), + })?, + ), + int: Cow::Borrowed( + trace + .int + .get(public.num_int_cols()..total.num_int_cols()) + .ok_or(ProductionShaError::LengthMismatch { + label: "UAIR witness int columns", + got: trace.int.len(), + expected: total.num_int_cols(), + })?, + ), + }) +} + +fn validate_uair_trace_shape( + trace: &UairTrace<'_, PolyCoeff, Int, D>, + sig: &UairSignature, +) -> Result<(), ProductionShaError> +where + PolyCoeff: Clone, + Int: Clone, + F: PrimeField, +{ + let total = sig.total_cols(); + if trace.binary_poly.len() != total.num_binary_poly_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR binary columns", + got: trace.binary_poly.len(), + expected: total.num_binary_poly_cols(), + }); + } + if trace.arbitrary_poly.len() != total.num_arbitrary_poly_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR arbitrary columns", + got: trace.arbitrary_poly.len(), + expected: total.num_arbitrary_poly_cols(), + }); + } + if trace.int.len() != total.num_int_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR int columns", + got: trace.int.len(), + expected: total.num_int_cols(), + }); + } + Ok(()) +} + +fn own_uair_trace( + trace: &UairTrace<'_, PolyCoeff, Int, D>, +) -> UairTrace<'static, PolyCoeff, Int, D> +where + PolyCoeff: Clone, + Int: Clone, +{ + UairTrace { + binary_poly: Cow::Owned(trace.binary_poly.iter().cloned().collect()), + arbitrary_poly: Cow::Owned(trace.arbitrary_poly.iter().cloned().collect()), + int: Cow::Owned(trace.int.iter().cloned().collect()), + } +} + +pub fn setup_verify_linear_ideal_fold( + params: LinearIdealFoldVerifierParams, + shape: UairShape, +) -> Result, LinearIdealFoldError> +where + U: Uair + ProductionShaProjectionAdapter, + Zt: ZincTypes, + F: PrimeField + FromPrimitiveWithConfig, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: 
FoldablePCS, +{ + ensure_production_sha_word_degree::()?; + if shape.num_vars != SHA_ROW_VARS { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA row variables", + got: shape.num_vars, + expected: SHA_ROW_VARS, + }); + } + + let (binary, arbitrary, int) = U::production_sha_pcs_batch_sizes(); + validate_production_sha_batch_sizes::(binary, arbitrary, int)?; + + Ok(VerifiedLinearIdealFoldSetup { + pcs_params: params.pcs_params, + shape, + field_cfg: params.field_cfg, + }) +} + +pub fn verify_linear_ideal_fold( + vs: &VerifiedLinearIdealFoldSetup, + instances: &[UairInstance<'_, Zt::Int, Zt::Int, PCSCommitments, D>], + proof: &ProductionLinearIdealFoldProof, + transcript: &mut impl Transcript, +) -> Result< + FoldedLinearIdealInstance, ProjectedPublic>, + LinearIdealFoldError, +> +where + U: Uair + ProductionShaProjectionAdapter, + Zt: ZincTypes, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + let field_cfg = &vs.field_cfg; + ensure_production_sha_word_degree::()?; + if instances.len() < 2 { + return Err(ProductionShaError::InstanceCountTooSmall(instances.len())); + } + if !instances.len().is_power_of_two() { + return Err(ProductionShaError::InstanceCountNotPowerOfTwo( + instances.len(), + )); + } + if proof.instance_commitments.len() != instances.len() { + return Err(ProductionShaError::LengthMismatch { + label: "proof commitments/instances", + got: proof.instance_commitments.len(), + expected: instances.len(), + }); + } + + absorb_production_sha_statement_metadata(transcript); + absorb_uair_shape_metadata(transcript, &vs.shape); + + let publics = + verify_public_projection_phase::(vs, instances, proof, transcript)?; + + let booleanity_sources = 
production_sha_booleanity_sources(); + + let r_ic = sample_pre_ideal_challenge(transcript, field_cfg); + let beta = sample_instance_batch_challenge(transcript, instances.len(), field_cfg)?; + + let (_aggregate_ideal_polys, a, lambda, rho, xi, initial_claim) = + verify_aggregate_ideal_phase(&proof.ideal_check, transcript, field_cfg)?; + + let sumfold_output = verify_sumfold_phase( + transcript, + &proof.sumfold_proof, + &initial_claim, + &beta, + beta.len(), + instances.len(), + field_cfg, + )?; + + absorb_derived_production_sha_commitments::( + transcript, + b"production_sha_derived_folded_commitments", + &proof.instance_commitments, + sumfold_output.eq_instance_weights(), + field_cfg, + ); + + let row_output = verify_row_sumcheck_phase( + transcript, + &proof.combined_sumcheck, + sumfold_output.final_round_sumcheck_claim(), + field_cfg, + )?; + + let folded_public = + verify_fold_publics_phase(&publics, sumfold_output.eq_instance_weights(), field_cfg)?; + + let subclaim = verify_endpoint_multipoint_phase( + transcript, + proof, + &folded_public, + &row_output, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + field_cfg, + )?; + + let folded_commitments = verify_fold_commitments_phase::( + &proof.instance_commitments, + sumfold_output.eq_instance_weights(), + field_cfg, + )?; + verify_pcs_phase::( + transcript, + &vs.pcs_params, + &proof.instance_commitments, + sumfold_output.eq_instance_weights(), + &folded_commitments, + &subclaim.sumcheck_subclaim.point, + &proof.witness_lifted_evals, + &proof.opening_proof, + field_cfg, + )?; + + Ok(FoldedLinearIdealInstance { + target: sumfold_output.final_round_sumcheck_claim().clone(), + commitments: folded_commitments, + public: folded_public, + }) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "public_projection", instances = instances.len()) +)] +fn verify_public_projection_phase( + vs: &VerifiedLinearIdealFoldSetup, 
+ instances: &[UairInstance<'_, Zt::Int, Zt::Int, PCSCommitments, D>], + proof: &ProductionLinearIdealFoldProof, + transcript: &mut impl Transcript, +) -> Result>, ProductionShaError> +where + U: Uair + ProductionShaProjectionAdapter, + Zt: ZincTypes, + F: PrimeField + FromPrimitiveWithConfig, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, +{ + let field_cfg = &vs.field_cfg; + let mut publics = Vec::with_capacity(instances.len()); + for (instance_idx, instance) in instances.iter().enumerate() { + validate_public_uair_trace_shape::( + &instance.public_trace, + &vs.shape.signature, + )?; + if !pcs_commitments_match::( + &instance.commitments, + &proof.instance_commitments[instance_idx], + ) { + return Err(ProductionShaError::NonCanonicalProofObject( + "instance commitments do not match proof commitments", + )); + } + absorb_public_uair_trace::(transcript, instance_idx, &instance.public_trace); + publics.push(U::project_production_sha_public( + &vs.shape, + &instance.public_trace, + field_cfg, + )?); + } + + validate_production_sha_publics(&publics, field_cfg)?; + absorb_production_sha_commitments::( + transcript, + b"production_sha_fresh_commitments", + &proof.instance_commitments, + ); + absorb_projected_sha_publics(transcript, &publics, field_cfg); + Ok(publics) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "aggregate_ideal_verify") +)] +fn verify_aggregate_ideal_phase( + proof: &IdealCheckProof, + transcript: &mut impl Transcript, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + let aggregate_ideal_polys = aggregate_sha_ideal_polys_from_proof(proof)?; + check_aggregate_sha_ideal_membership(&aggregate_ideal_polys, field_cfg)?; + absorb_aggregate_sha_ideal_polys(transcript, &aggregate_ideal_polys, field_cfg); + + let (a, lambda, 
rho, xi) = sample_post_aggregate_ideal_challenges(transcript, field_cfg); + let initial_claim = + evaluate_aggregate_sha_ideal_claim(&aggregate_ideal_polys, &a, &lambda, field_cfg)?; + Ok((aggregate_ideal_polys, a, lambda, rho, xi, initial_claim)) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "sumfold_verify", instance_vars, instances) +)] +fn verify_sumfold_phase( + transcript: &mut impl Transcript, + proof: &MultiDegreeSumcheckProof, + initial_claim: &F, + beta: &[F], + instance_vars: usize, + instances: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + let verified_sumfold = + verify_full_sha_sumfold(transcript, proof, initial_claim, instance_vars, field_cfg)?; + Ok(derive_instance_fold_claim( + beta, + verified_sumfold.r_b, + verified_sumfold.c_sf, + instances, + field_cfg, + )?) 
+} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "fold_after_sumfold", instances = commitments.len()) +)] +fn verify_fold_commitments_phase( + commitments: &[PCSCommitments], + weights: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + fold_pcs_commitments::(commitments, weights, field_cfg) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "row_sumcheck_verify") +)] +fn verify_row_sumcheck_phase( + transcript: &mut impl Transcript, + proof: &MultiDegreeSumcheckProof, + final_round_sumcheck_claim: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + verify_folded_row_sumcheck(transcript, proof, final_round_sumcheck_claim, field_cfg) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "fold_after_sumfold", instances = publics.len()) +)] +fn verify_fold_publics_phase( + publics: &[ProjectedPublic], + weights: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + fold_projected_publics(publics, weights, field_cfg) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "endpoint_multipoint_verify") +)] +#[allow(clippy::too_many_arguments)] +fn verify_endpoint_multipoint_phase( + transcript: &mut impl Transcript, + proof: 
&ProductionLinearIdealFoldProof, + folded_public: &ProjectedPublic, + row_output: &FoldedRowSumcheckOutput, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, + P: ZincPCSTypes, +{ + absorb_sha_resolver_proof(transcript, &proof.resolver, field_cfg); + let endpoint_evals = sha_endpoint_evals_from_resolver(&proof.resolver, a, field_cfg)?; + let terminal = reconstruct_folded_row_terminal_from_endpoints( + &endpoint_evals, + folded_public, + r_ic, + &row_output.r_star, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + verify_folded_row_terminal_value(row_output, &terminal)?; + + let (subclaim, shift_specs) = verify_sha_endpoint_multipoint( + transcript, + &proof.multipoint_eval, + &endpoint_evals, + folded_public, + &row_output.r_star, + field_cfg, + )?; + let open_evals = multipoint_open_evals_from_pcs_lifted( + &proof.witness_lifted_evals, + &production_sha_multipoint_layout(), + folded_public, + &subclaim.sumcheck_subclaim.point, + field_cfg, + )?; + verify_sha_endpoint_multipoint_open_evals(&subclaim, &open_evals, &shift_specs, field_cfg)?; + Ok(subclaim) +} + +#[tracing::instrument( + target = "zinc_protocol::production_sha", + level = "info", + skip_all, + fields(side = "verify", phase = "pcs_verify") +)] +fn verify_pcs_phase( + transcript: &mut impl Transcript, + pcs_params: &PCSVerifierParams, + instance_commitments: &[PCSCommitments], + fold_weights: &[F], + folded_commitments: &PCSCommitments, + point: &[F], + witness_lifted_evals: &[DynamicPolynomialF], + opening_proof: &PCSOpeningProof, + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + 
F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + absorb_folded_lifted_evals(transcript, witness_lifted_evals, field_cfg); + verify_production_sha_pcs_opening::( + pcs_params, + instance_commitments, + fold_weights, + folded_commitments, + point, + witness_lifted_evals, + opening_proof, + field_cfg, + ) +} + +fn validate_public_uair_trace_shape( + trace: &UairTrace<'_, PolyCoeff, Int, D>, + sig: &UairSignature, +) -> Result<(), ProductionShaError> +where + PolyCoeff: Clone, + Int: Clone, + F: PrimeField, +{ + let public = sig.public_cols(); + if trace.binary_poly.len() != public.num_binary_poly_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR public binary columns", + got: trace.binary_poly.len(), + expected: public.num_binary_poly_cols(), + }); + } + if trace.arbitrary_poly.len() != public.num_arbitrary_poly_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR public arbitrary columns", + got: trace.arbitrary_poly.len(), + expected: public.num_arbitrary_poly_cols(), + }); + } + if trace.int.len() != public.num_int_cols() { + return Err(ProductionShaError::LengthMismatch { + label: "UAIR public int columns", + got: trace.int.len(), + expected: public.num_int_cols(), + }); + } + Ok(()) +} + +fn pcs_commitments_match( + lhs: &PCSCommitments, + rhs: &PCSCommitments, +) -> bool +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, +{ + fn commitment_bytes(commitment: &C, write: W) -> Vec + where + W: FnOnce(&C, &mut Vec), + { + let mut bytes = Vec::new(); + write(commitment, &mut bytes); + bytes + } + + commitment_bytes(&lhs.binary, P::BinaryPCS::write_commitment_bytes) + == commitment_bytes(&rhs.binary, P::BinaryPCS::write_commitment_bytes) + && commitment_bytes(&lhs.arbitrary, P::ArbitraryPCS::write_commitment_bytes) + == commitment_bytes(&rhs.arbitrary, P::ArbitraryPCS::write_commitment_bytes) 
+ && commitment_bytes(&lhs.int, P::IntPCS::write_commitment_bytes) + == commitment_bytes(&rhs.int, P::IntPCS::write_commitment_bytes) +} + +fn aggregate_sha_ideal_polys_from_proof( + proof: &IdealCheckProof, +) -> Result<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], ProductionShaError> +where + F: PrimeField, +{ + let got = proof.combined_mle_values.len(); + proof + .combined_mle_values + .clone() + .try_into() + .map_err(|_| ProductionShaError::LengthMismatch { + label: "aggregate SHA ideal polynomial count", + got, + expected: NUM_NONZERO_SHA_FAMILIES, + }) +} + +fn scalarize_sha_endpoint_bits(bits: &[F; SHA_WORD_BITS], a: &F, field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let powers = zinc_utils::powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + bits.iter() + .zip(powers.iter()) + .fold(F::zero_with_cfg(field_cfg), |acc, (bit, power)| { + acc + bit.clone() * power + }) +} + +fn sha_endpoint_evals_from_resolver( + resolver: &CombinedPolyResolverProof, + a: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + if !resolver.down_evals.is_empty() { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA resolver down evals", + got: resolver.down_evals.len(), + expected: 0, + }); + } + if !resolver.bit_op_down_evals.is_empty() { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA resolver bit-op down evals", + got: resolver.bit_op_down_evals.len(), + expected: 0, + }); + } + + let word_sources = production_sha_endpoint_word_sources(); + let unshifted_words = word_sources.iter().filter(|(_, shift)| *shift == 0).count(); + let shifted_words = word_sources.len() - unshifted_words; + let expected_unshifted_bits = unshifted_words * SHA_WORD_BITS; + let expected_shifted_bits = shifted_words * SHA_WORD_BITS; + if resolver.bit_slice_evals.len() != expected_unshifted_bits { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA resolver 
unshifted bit slices", + got: resolver.bit_slice_evals.len(), + expected: expected_unshifted_bits, + }); + } + if resolver.shifted_bit_slice_evals.len() != expected_shifted_bits { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA resolver shifted bit slices", + got: resolver.shifted_bit_slice_evals.len(), + expected: expected_shifted_bits, + }); + } + + let mut unshifted_idx = 0usize; + let mut shifted_idx = 0usize; + let mut sources = Vec::with_capacity(word_sources.len()); + for (col, shift) in word_sources { + let bit_slice = if shift == 0 { + let start = unshifted_idx * SHA_WORD_BITS; + unshifted_idx += 1; + &resolver.bit_slice_evals[start..start + SHA_WORD_BITS] + } else { + let start = shifted_idx * SHA_WORD_BITS; + shifted_idx += 1; + &resolver.shifted_bit_slice_evals[start..start + SHA_WORD_BITS] + }; + let bits: [F; SHA_WORD_BITS] = std::array::from_fn(|idx| bit_slice[idx].clone()); + let scalarized = scalarize_sha_endpoint_bits(&bits, a, field_cfg); + sources.push(ShaSourceEndpointEval { + col, + shift, + scalarized, + bits, + }); + } + + let int_sources = production_sha_endpoint_int_sources(); + if resolver.up_evals.len() != int_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "production SHA resolver int evals", + got: resolver.up_evals.len(), + expected: int_sources.len(), + }); + } + let int_sources = int_sources + .into_iter() + .zip(resolver.up_evals.iter()) + .map(|(col, scalar)| ShaIntEndpointEval { + col, + scalar: scalar.clone(), + }) + .collect(); + + Ok(ShaEndpointEvals { + sources, + int_sources, + }) +} + +fn sha_resolver_from_endpoint_evals( + endpoint_evals: &ShaEndpointEvals, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + validate_sha_endpoint_layout(endpoint_evals)?; + + let mut bit_slice_evals = Vec::new(); + let mut shifted_bit_slice_evals = Vec::new(); + for source in &endpoint_evals.sources { + if source.shift == 0 { + 
bit_slice_evals.extend(source.bits.iter().cloned()); + } else { + shifted_bit_slice_evals.extend(source.bits.iter().cloned()); + } + } + + Ok(CombinedPolyResolverProof { + up_evals: endpoint_evals + .int_sources + .iter() + .map(|source| source.scalar.clone()) + .collect(), + down_evals: Vec::new(), + bit_slice_evals, + bit_op_down_evals: Vec::new(), + shifted_bit_slice_evals, + }) +} + +fn absorb_derived_pcs_commitment( + transcript: &mut impl Transcript, + label: &'static [u8], + commitments: &[&Pcs::Commitment], + weights: &[F], + field_cfg: &F::Config, +) where + Pcs: PCS, + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + Eval: Clone + std::fmt::Debug + Send + Sync, +{ + let mut field_buf = runtime_field_transcript_buf::(field_cfg); + transcript.absorb_slice(label); + transcript.absorb_slice(b"derived_pcs_commitment_v1"); + transcript.absorb_slice(&(commitments.len() as u64).to_le_bytes()); + transcript.absorb_slice(&(weights.len() as u64).to_le_bytes()); + transcript.absorb_random_field_slice(weights, &mut field_buf); + for (idx, commitment) in commitments.iter().enumerate() { + transcript.absorb_slice(&(idx as u64).to_le_bytes()); + Pcs::absorb_commitment(transcript, commitment); + } + transcript.absorb_slice(b"derived_pcs_commitment_end"); +} + +fn verify_production_sha_pcs_opening( + pcs_params: &PCSVerifierParams, + instance_commitments: &[PCSCommitments], + fold_weights: &[F], + folded_commitments: &PCSCommitments, + r_0: &[F], + folded_lifted_evals: &[DynamicPolynomialF], + opening_proof: &PCSOpeningProof, + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + P: ZincPCSTypes, +{ + ensure_production_sha_word_degree::()?; + validate_production_sha_batch_sizes::( + P::BinaryPCS::batch_size(&folded_commitments.binary), + P::ArbitraryPCS::batch_size(&folded_commitments.arbitrary), + P::IntPCS::batch_size(&folded_commitments.int), 
+ )?; + let (binary_lifted, int_lifted) = split_folded_sha_pcs_lifted_evals(folded_lifted_evals)?; + let arbitrary_lifted: &[DynamicPolynomialF] = &[]; + + let mut transcript = PcsVerifierTranscript { + fs_transcript: Blake3Transcript::default(), + stream: Cursor::default(), + }; + let mut transcription_buf = vec![0u8; F::zero_with_cfg(field_cfg).inner().get_num_bytes()]; + + let binary_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.binary) + .collect::>(); + absorb_derived_pcs_commitment::, D>( + &mut transcript.fs_transcript, + b"production_sha_pcs_binary", + &binary_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + binary_lifted, + &mut transcription_buf, + ); + P::BinaryPCS::verify_open::( + &mut transcript, + &pcs_params.binary, + &folded_commitments.binary, + r_0, + binary_lifted, + &opening_proof.binary, + field_cfg, + )?; + + let arbitrary_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.arbitrary) + .collect::>(); + absorb_derived_pcs_commitment::, D>( + &mut transcript.fs_transcript, + b"production_sha_pcs_arbitrary", + &arbitrary_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + arbitrary_lifted, + &mut transcription_buf, + ); + P::ArbitraryPCS::verify_open::( + &mut transcript, + &pcs_params.arbitrary, + &folded_commitments.arbitrary, + r_0, + arbitrary_lifted, + &opening_proof.arbitrary, + field_cfg, + )?; + + let int_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.int) + .collect::>(); + absorb_derived_pcs_commitment::( + &mut transcript.fs_transcript, + b"production_sha_pcs_int", + &int_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + int_lifted, + &mut transcription_buf, + ); + P::IntPCS::verify_open::( + &mut transcript, + &pcs_params.int, + &folded_commitments.int, + r_0, + int_lifted, + 
&opening_proof.int, + field_cfg, + )?; + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn prove_production_sha_hyrax_pcs_opening( + pcs_params: &PCSParams, Zt, F, D>, + instance_commitments: &[PCSCommitments, Zt, F, D>], + fold_weights: &[F], + folded_trace: &ProjectedTrace, + folded_prover_data: &PCSProverData, Zt, F, D>, + r_0: &[F], + folded_lifted_evals: &[DynamicPolynomialF], + field_cfg: &F::Config, +) -> Result, Zt, F, D>, ProductionShaError> +where + Zt: ZincTypes, + F: HyraxFieldBridge + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, + C: AffineRepr, + HyraxPCS: PCS< + F, + BinaryPoly, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, + HyraxPCS: PCS< + F, + DensePolynomial, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, + HyraxPCS: PCS< + F, + Zt::Int, + D, + CommitmentKey = zip_plus::pcs::hyrax::HyraxCommitmentKey, + ProverData = zip_plus::pcs::hyrax::HyraxProverData, + OpeningProof = Vec, + >, +{ + ensure_production_sha_word_degree::()?; + validate_production_sha_batch_sizes::(ShaWordCol::COUNT, 0, ShaIntCol::COUNT)?; + let (binary_lifted, int_lifted) = split_folded_sha_pcs_lifted_evals(folded_lifted_evals)?; + let arbitrary_lifted: &[DynamicPolynomialF] = &[]; + + let arbitrary_scalar_lanes: Vec>> = Vec::new(); + let binary_field_lanes = folded_sha_binary_field_lanes(folded_trace); + let int_field_lanes = folded_sha_int_field_lanes(folded_trace); + + let mut transcript = PcsProverTranscript { + fs_transcript: Blake3Transcript::default(), + stream: Cursor::default(), + }; + let mut transcription_buf = vec![0u8; F::zero_with_cfg(field_cfg).inner().get_num_bytes()]; + + let binary_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.binary) + .collect::>(); + absorb_derived_pcs_commitment::, F, 
BinaryPoly, D>( + &mut transcript.fs_transcript, + b"production_sha_pcs_binary", + &binary_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + binary_lifted, + &mut transcription_buf, + ); + let binary_start = transcript.stream.position() as usize; + HyraxPCS::::prove_open_field_lanes_single_row::( + &mut transcript, + &pcs_params.binary, + &binary_field_lanes, + r_0, + &folded_prover_data.binary, + field_cfg, + )?; + let binary_end = transcript.stream.position() as usize; + let binary = transcript.stream.get_ref()[binary_start..binary_end].to_vec(); + + let arbitrary_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.arbitrary) + .collect::>(); + absorb_derived_pcs_commitment::< + HyraxPCS, + F, + DensePolynomial, + D, + >( + &mut transcript.fs_transcript, + b"production_sha_pcs_arbitrary", + &arbitrary_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + arbitrary_lifted, + &mut transcription_buf, + ); + let arbitrary_start = transcript.stream.position() as usize; + HyraxPCS::::prove_open_scalar_lanes::( + &mut transcript, + &pcs_params.arbitrary, + &arbitrary_scalar_lanes, + r_0, + &folded_prover_data.arbitrary, + field_cfg, + )?; + let arbitrary_end = transcript.stream.position() as usize; + let arbitrary = transcript.stream.get_ref()[arbitrary_start..arbitrary_end].to_vec(); + + let int_commitments = instance_commitments + .iter() + .map(|commitment| &commitment.int) + .collect::>(); + absorb_derived_pcs_commitment::, F, Zt::Int, D>( + &mut transcript.fs_transcript, + b"production_sha_pcs_int", + &int_commitments, + fold_weights, + field_cfg, + ); + absorb_pcs_lifted_evals( + &mut transcript.fs_transcript, + int_lifted, + &mut transcription_buf, + ); + let int_start = transcript.stream.position() as usize; + HyraxPCS::::prove_open_field_lanes_single_row::( + &mut transcript, + &pcs_params.int, + &int_field_lanes, + r_0, + 
&folded_prover_data.int, + field_cfg, + )?; + let int_end = transcript.stream.position() as usize; + let int = transcript.stream.get_ref()[int_start..int_end].to_vec(); + + Ok(PCSOpeningProof { + binary, + arbitrary, + int, + }) +} + +#[cfg(test)] +fn evaluate_fresh_targets_from_ideal_polys( + ideal_polys: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]], + a: &F, + lambda: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + let lambda_powers = zinc_utils::powers( + lambda.clone(), + F::one_with_cfg(field_cfg), + NUM_SHA_RESIDUAL_FAMILIES, + ); + let a_powers = zinc_utils::powers( + a.clone(), + F::one_with_cfg(field_cfg), + SHA_IDEAL_EVAL_POWER_COUNT, + ); + let nonzero_lambda_powers = selected_nonzero_sha_lambda_powers(&lambda_powers)?; + ideal_polys + .iter() + .map(|instance| { + let mut values: [F; NUM_NONZERO_SHA_FAMILIES] = + std::array::from_fn(|_| F::zero_with_cfg(field_cfg)); + for (slot, poly) in instance.iter().enumerate() { + values[slot] = evaluate_production_sha_poly_at_powers(poly, &a_powers, field_cfg)?; + } + FieldFieldInnerProduct::inner_product::( + &values, + &nonzero_lambda_powers, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "production SHA nonzero-family dot product failed", + ) + }) + }) + .collect() +} + +#[allow(dead_code)] +fn beta_aggregate_sha_ideal_polys( + ideal_polys: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]], + beta: &[F], + field_cfg: &F::Config, +) -> Result<[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], ProductionShaError> +where + F: PrimeField, +{ + let weights = build_eq_x_r_vec(beta, field_cfg)?; + if weights.len() != ideal_polys.len() { + return Err(ProductionShaError::LengthMismatch { + label: "beta weights/fresh ideal polys", + got: weights.len(), + expected: ideal_polys.len(), + }); + } + + let mut aggregate: [DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES] = + std::array::from_fn(|_| 
DynamicPolynomialF::ZERO); + for (weight, instance) in weights.iter().zip(ideal_polys) { + for (slot, poly) in instance.iter().enumerate() { + let weighted = scale_production_sha_poly(poly, weight); + aggregate[slot] += &weighted; + } + } + aggregate.iter_mut().for_each(DynamicPolynomialF::trim); + Ok(aggregate) +} + +#[allow(dead_code)] +fn scale_production_sha_poly(poly: &DynamicPolynomialF, scalar: &F) -> DynamicPolynomialF +where + F: PrimeField, +{ + DynamicPolynomialF::new_trimmed( + poly.coeffs + .iter() + .map(|coeff| coeff.clone() * scalar) + .collect::>(), + ) +} + +fn evaluate_aggregate_sha_ideal_claim( + ideal_polys: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + a: &F, + lambda: &F, + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + let lambda_powers = zinc_utils::powers( + lambda.clone(), + F::one_with_cfg(field_cfg), + NUM_SHA_RESIDUAL_FAMILIES, + ); + let a_powers = zinc_utils::powers( + a.clone(), + F::one_with_cfg(field_cfg), + SHA_IDEAL_EVAL_POWER_COUNT, + ); + evaluate_aggregate_sha_ideal_claim_with_powers( + ideal_polys, + &a_powers, + &lambda_powers, + field_cfg, + ) +} + +fn evaluate_aggregate_sha_ideal_claim_with_powers( + ideal_polys: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES], + a_powers: &[F], + lambda_powers: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if a_powers.len() < SHA_IDEAL_EVAL_POWER_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "aggregate SHA ideal a powers", + got: a_powers.len(), + expected: SHA_IDEAL_EVAL_POWER_COUNT, + }); + } + let mut values: [F; NUM_NONZERO_SHA_FAMILIES] = + std::array::from_fn(|_| F::zero_with_cfg(field_cfg)); + for (slot, poly) in ideal_polys.iter().enumerate() { + values[slot] = evaluate_production_sha_poly_at_powers(poly, a_powers, field_cfg)?; + } + lambda_weighted_nonzero_sha_values(&values, lambda_powers, field_cfg) +} + +fn selected_nonzero_sha_lambda_powers( + lambda_powers: &[F], +) -> 
Result<[F; NUM_NONZERO_SHA_FAMILIES], ProductionShaError> +where + F: PrimeField, +{ + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "lambda powers", + got: lambda_powers.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + Ok(std::array::from_fn(|slot| { + lambda_powers[production_sha_nonzero_families()[slot].index()].clone() + })) +} + +fn lambda_weighted_nonzero_sha_values( + values: &[F; NUM_NONZERO_SHA_FAMILIES], + lambda_powers: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + let weights = selected_nonzero_sha_lambda_powers(lambda_powers)?; + FieldFieldInnerProduct::inner_product::( + values, + &weights, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "production SHA nonzero-family dot product failed", + ) + }) +} + +fn lambda_weighted_sha_residual_polys_at_powers( + residuals: &[DynamicPolynomialF], + a_powers: &[F], + lambda_powers: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if residuals.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "SHA residual families", + got: residuals.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "lambda powers", + got: lambda_powers.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + let mut values: [F; NUM_SHA_RESIDUAL_FAMILIES] = + std::array::from_fn(|_| F::zero_with_cfg(field_cfg)); + for (idx, residual) in residuals.iter().enumerate() { + values[idx] = evaluate_production_sha_poly_at_powers(residual, a_powers, field_cfg)?; + } + FieldFieldInnerProduct::inner_product::( + &values, + lambda_powers, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "production SHA residual-family dot product failed", + 
) + }) +} + +fn evaluate_production_sha_poly_at_powers( + poly: &DynamicPolynomialF, + powers: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if poly.coeffs.is_empty() { + return Ok(F::zero_with_cfg(field_cfg)); + } + if poly.coeffs.len() > powers.len() { + return Err(ProductionShaError::NonCanonicalProofObject( + "production SHA polynomial exceeds scalarization power bound", + )); + } + DynamicPolyFInnerProduct::inner_product::( + &poly.coeffs, + &powers[..poly.coeffs.len()], + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject("production SHA polynomial dot product failed") + }) +} + +#[allow(dead_code)] +fn eq_weighted_sum( + point: &[F], + values: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + let expected = 1usize + .checked_shl(u32::try_from(point.len()).map_err(|_| { + ProductionShaError::LengthMismatch { + label: "eq point", + got: point.len(), + expected: usize::BITS as usize, + } + })?) 
+ .ok_or(ProductionShaError::LengthMismatch { + label: "eq point", + got: point.len(), + expected: usize::BITS as usize, + })?; + if values.len() != expected { + return Err(ProductionShaError::LengthMismatch { + label: "eq-weighted values", + got: values.len(), + expected, + }); + } + let weights = build_eq_x_r_vec(point, field_cfg)?; + FieldFieldInnerProduct::inner_product::( + &weights, + values, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject("eq-weighted value dot product failed") + }) +} + +fn fold_mle_tables( + kind: &'static str, + tables: &[&MleTable], + theta: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + if tables.len() != theta.len() { + return Err(ProductionShaError::LengthMismatch { + label: kind, + got: tables.len(), + expected: theta.len(), + }); + } + let first = tables.first().ok_or(ProductionShaError::LengthMismatch { + label: kind, + got: 0, + expected: 1, + })?; + let col_count = first.len(); + let first_col = first.first().ok_or(ProductionShaError::LengthMismatch { + label: kind, + got: 0, + expected: 1, + })?; + let row_count = first_col.evaluations.len(); + let num_vars = first_col.num_vars; + let mut folded = vec![vec![F::zero_with_cfg(field_cfg); row_count]; col_count]; + for (table, weight) in tables.iter().zip(theta.iter()) { + if table.len() != col_count { + return Err(ProductionShaError::LengthMismatch { + label: kind, + got: table.len(), + expected: col_count, + }); + } + for (col_idx, column) in table.iter().enumerate() { + if column.num_vars != num_vars { + return Err(ProductionShaError::LengthMismatch { + label: kind, + got: column.num_vars, + expected: num_vars, + }); + } + if column.evaluations.len() != row_count { + return Err(ProductionShaError::LengthMismatch { + label: kind, + got: column.evaluations.len(), + expected: row_count, + }); + } + for (out, value) in folded[col_idx].iter_mut().zip(column.evaluations.iter()) { + *out += 
weight.clone() * value; + } + } + } + Ok(folded + .into_iter() + .map(|evaluations| DenseMultilinearExtension { + evaluations, + num_vars, + }) + .collect()) +} + +fn fold_optional_mle_tables( + kind: &'static str, + tables: &[Option<&MleTable>], + theta: &[F], + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: PrimeField, +{ + let has_table = tables + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: kind, + got: 0, + expected: 1, + })? + .is_some(); + if !has_table { + if tables.iter().any(Option::is_some) { + return Err(ProductionShaError::LengthMismatch { + label: kind, + got: 1, + expected: 0, + }); + } + return Ok(None); + } + let present = tables + .iter() + .map(|table| { + table.ok_or(ProductionShaError::LengthMismatch { + label: kind, + got: 0, + expected: 1, + }) + }) + .collect::, _>>()?; + fold_mle_tables(kind, &present, theta, field_cfg).map(Some) +} + +fn fold_projected_publics( + publics: &[ProjectedPublic], + theta: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + if publics.len() != theta.len() { + return Err(ProductionShaError::LengthMismatch { + label: "publics/theta", + got: publics.len(), + expected: theta.len(), + }); + } + let first = publics.first().ok_or(ProductionShaError::LengthMismatch { + label: "publics", + got: 0, + expected: 1, + })?; + let columns = fold_mle_tables( + "public columns", + &publics + .iter() + .map(|public| &public.columns) + .collect::>(), + theta, + field_cfg, + )?; + let bit_slices = fold_optional_mle_tables( + "public bit slices", + &publics + .iter() + .map(|public| public.bit_slices.as_ref()) + .collect::>(), + theta, + field_cfg, + )?; + if first.bit_slices.is_none() != bit_slices.is_none() { + return Err(ProductionShaError::LengthMismatch { + label: "public bit slice presence", + got: usize::from(bit_slices.is_some()), + expected: usize::from(first.bit_slices.is_some()), + }); + } + Ok(ProjectedPublic { + columns, + bit_slices, + 
}) +} + +fn validate_production_sha_publics( + publics: &[ProjectedPublic], + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: PrimeField + FromPrimitiveWithConfig, +{ + for public in publics { + if public.columns.len() != ShaPublicCol::COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public column count", + got: public.columns.len(), + expected: ShaPublicCol::COUNT, + }); + } + for col in &public.columns { + if col.num_vars != SHA_ROW_VARS || col.evaluations.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public row count", + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + } + for selector in [ + ShaPublicCol::SInit, + ShaPublicCol::SMsg, + ShaPublicCol::SSched, + ShaPublicCol::SUpd, + ShaPublicCol::SFf, + ShaPublicCol::SOut, + ] { + let col = &public.columns[selector.index()]; + for (row, value) in col.evaluations.iter().enumerate() { + let expected = production_sha_selector_expected(selector, row, field_cfg); + if value != &expected { + if *value != F::zero_with_cfg(field_cfg) && *value != F::one_with_cfg(field_cfg) + { + return Err(ProductionShaError::NonBooleanPublicSelector { + col: selector, + row, + }); + } + return Err(ProductionShaError::InvalidPublicSelector { col: selector, row }); + } + } + } + + let k_col = &public.columns[ShaPublicCol::K.index()]; + for (row, value) in k_col.evaluations.iter().enumerate() { + let expected = production_sha_k_expected(row, field_cfg); + if value != &expected { + return Err(ProductionShaError::InvalidRoundConstant { row }); + } + } + + validate_production_sha_public_word_columns(public, field_cfg)?; + } + Ok(()) +} + +fn validate_production_sha_public_word_columns( + public: &ProjectedPublic, + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + let bit_slices = + public + .bit_slices + .as_ref() + .ok_or(ProductionShaError::NonCanonicalProofObject( + "production SHA public word 
columns are required", + ))?; + if bit_slices.len() != ShaPublicWordCol::COUNT * SHA_WORD_BITS { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public word column count", + got: bit_slices.len(), + expected: ShaPublicWordCol::COUNT * SHA_WORD_BITS, + }); + } + + for (word_idx, public_col) in production_sha_public_word_column_map().iter().enumerate() { + let scalar_col = &public.columns[public_col.index()]; + for row in 0..SHA_ROW_COUNT { + let mut bits = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + let table_idx = bit_slice_index(word_idx, bit, SHA_WORD_BITS); + let bit_col = + bit_slices + .get(table_idx) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA public word bit column", + got: table_idx, + expected: bit_slices.len(), + })?; + if bit_col.num_vars != SHA_ROW_VARS || bit_col.evaluations.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public word row count", + got: bit_col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + let bit = bit_col.evaluations[row].clone(); + if bit != F::zero_with_cfg(field_cfg) && bit != F::one_with_cfg(field_cfg) { + return Err(ProductionShaError::NonCanonicalProofObject( + "production SHA public word bit is not boolean", + )); + } + bits.push(bit); + } + if bits.len() != SHA_WORD_BITS { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public word bit count", + got: bits.len(), + expected: SHA_WORD_BITS, + }); + } + let scalarized = scalarize_sha_public_word_bits_at_two(&bits, field_cfg); + if scalarized != scalar_col.evaluations[row] { + return Err(ProductionShaError::NonCanonicalProofObject( + "production SHA public word bits do not match scalar public column", + )); + } + } + debug_assert_eq!(Some(word_idx), public_word_col_index(*public_col)); + } + Ok(()) +} + +fn production_sha_public_word_column_map() -> [ShaPublicCol; ShaPublicWordCol::COUNT] { + [ + ShaPublicCol::PAIn, + ShaPublicCol::PEIn, + 
ShaPublicCol::PAOut, + ShaPublicCol::PEOut, + ShaPublicCol::Message, + ] +} + +fn scalarize_sha_public_word_bits_at_two(bits: &[F], field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let mut power = F::one_with_cfg(field_cfg); + let mut out = F::zero_with_cfg(field_cfg); + for bit in bits { + out += bit.clone() * &power; + power *= &two; + } + out +} + +fn production_sha_selector_expected( + selector: ShaPublicCol, + row: usize, + field_cfg: &F::Config, +) -> F +where + F: PrimeField, +{ + let active = match selector { + ShaPublicCol::SInit => row < 4, + ShaPublicCol::SMsg => row < 16, + ShaPublicCol::SSched => row < 48, + ShaPublicCol::SUpd => row < 64, + ShaPublicCol::SFf => (64..68).contains(&row), + ShaPublicCol::SOut => (68..72).contains(&row), + _ => false, + }; + if active { + F::one_with_cfg(field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } +} + +fn production_sha_k_expected(row: usize, field_cfg: &F::Config) -> F +where + F: PrimeField + FromPrimitiveWithConfig, +{ + if (3..67).contains(&row) { + F::from_with_cfg(SHA256_ROUND_CONSTANTS[row - 3] as u64, field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } +} + +#[allow(dead_code)] +fn build_folded_sha_pcs_lifted_evals( + folded_trace: &ProjectedTrace, + r_0: &[F], + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: PrimeField + DelayedFieldProductSum, +{ + let row_weights = build_eq_x_r_vec(r_0, field_cfg)?; + build_folded_sha_pcs_lifted_evals_with_row_weights(folded_trace, &row_weights, field_cfg) +} + +fn build_folded_sha_pcs_lifted_evals_with_row_weights( + folded_trace: &ProjectedTrace, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: PrimeField + DelayedFieldProductSum, +{ + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + 
#[cfg(debug_assertions)] + validate_projected_trace(folded_trace)?; + let word_lifted = cfg_iter!(&ShaWordCol::ALL) + .map(|&col| { + let coeffs = sha_word_bits_at_point_with_weights_unchecked( + folded_trace, + col, + 0, + row_weights, + field_cfg, + )? + .to_vec(); + Ok(DynamicPolynomialF::new_trimmed(coeffs)) + }) + .collect::, ProductionShaError>>()?; + let int_lifted = cfg_iter!(&ShaIntCol::ALL) + .map(|&col| { + Ok(DynamicPolynomialF::new_trimmed([ + sha_int_at_point_with_weights_unchecked(folded_trace, col, row_weights, field_cfg)?, + ])) + }) + .collect::, ProductionShaError>>()?; + Ok(word_lifted.into_iter().chain(int_lifted).collect()) +} + +fn split_folded_sha_pcs_lifted_evals( + lifted_evals: &[DynamicPolynomialF], +) -> Result<(&[DynamicPolynomialF], &[DynamicPolynomialF]), ProductionShaError> +where + F: PrimeField, +{ + let expected = ShaWordCol::COUNT + ShaIntCol::COUNT; + if lifted_evals.len() != expected { + return Err(ProductionShaError::LengthMismatch { + label: "folded SHA PCS lifted evals", + got: lifted_evals.len(), + expected, + }); + } + validate_folded_sha_pcs_lifted_evals_canonical(lifted_evals)?; + Ok(lifted_evals.split_at(ShaWordCol::COUNT)) +} + +fn validate_folded_sha_pcs_lifted_evals_canonical( + lifted_evals: &[DynamicPolynomialF], +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + for (idx, lifted_eval) in lifted_evals.iter().enumerate() { + let max_len = if idx < ShaWordCol::COUNT { + SHA_WORD_BITS + } else { + 1 + }; + if lifted_eval.coeffs.len() > max_len { + return Err(ProductionShaError::NonCanonicalProofObject( + "folded SHA lifted eval has too many coefficients", + )); + } + if lifted_eval.coeffs.last().is_some_and(F::is_zero) { + return Err(ProductionShaError::NonCanonicalProofObject( + "folded SHA lifted eval has trailing zero coefficients", + )); + } + } + Ok(()) +} + +#[allow(dead_code)] +fn folded_sha_binary_scalar_lanes( + folded_trace: &ProjectedTrace, +) -> Result>>, ZipError> +where + C: AffineRepr, + 
F: HyraxFieldBridge, +{ + let lanes = cfg_into_iter!(0..ShaWordCol::COUNT * SHA_WORD_BITS) + .map(|flat_idx| { + let col_idx = flat_idx / SHA_WORD_BITS; + let bit = flat_idx % SHA_WORD_BITS; + let column = &folded_trace.bit_slices[bit_slice_index(col_idx, bit, SHA_WORD_BITS)]; + column + .evaluations + .iter() + .map(F::field_to_scalar) + .collect::, _>>() + }) + .collect::, _>>()?; + + let mut out = Vec::with_capacity(ShaWordCol::COUNT); + let mut lanes = lanes.into_iter(); + for _ in 0..ShaWordCol::COUNT { + let mut col_lanes = Vec::with_capacity(SHA_WORD_BITS); + for _ in 0..SHA_WORD_BITS { + col_lanes.push( + lanes + .next() + .expect("flat binary scalar lane count is exact"), + ); + } + out.push(col_lanes); + } + Ok(out) +} + +#[allow(dead_code)] +fn folded_sha_int_scalar_lanes( + folded_trace: &ProjectedTrace, +) -> Result>>, ZipError> +where + C: AffineRepr, + F: HyraxFieldBridge, +{ + cfg_iter!(&ShaIntCol::ALL) + .map(|col| { + let column = &folded_trace.int_columns[col.index()]; + let lane = column + .evaluations + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + Ok(vec![lane]) + }) + .collect() +} + +fn folded_sha_binary_field_lanes(folded_trace: &ProjectedTrace) -> Vec> +where + F: PrimeField, +{ + ShaWordCol::ALL + .iter() + .map(|col| { + (0..SHA_WORD_BITS) + .map(|bit| { + folded_trace.bit_slices[bit_slice_index(col.index(), bit, SHA_WORD_BITS)] + .evaluations + .as_slice() + }) + .collect::>() + }) + .collect() +} + +fn folded_sha_int_field_lanes(folded_trace: &ProjectedTrace) -> Vec> +where + F: PrimeField, +{ + ShaIntCol::ALL + .iter() + .map(|col| vec![folded_trace.int_columns[col.index()].evaluations.as_slice()]) + .collect() +} + +fn absorb_pcs_lifted_evals( + transcript: &mut impl Transcript, + lifted_evals: &[DynamicPolynomialF], + transcription_buf: &mut Vec, +) where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, +{ + for lifted_eval in lifted_evals { + 
transcript.absorb_random_field_slice(&lifted_eval.coeffs, transcription_buf); + } +} + +/** Produces one opening evaluation per entry of layout.sources. Public sources are evaluated with sha_public_at_point at r_0; WordBit sources read coefficient bit of the lifted evaluation of that word column; Int sources read the first coefficient of the lifted evaluation stored after the ShaWordCol::COUNT word entries. An absent coefficient is read as zero. Fails if lifted_evals has the wrong length or is non-canonical. */ fn multipoint_open_evals_from_pcs_lifted( + lifted_evals: &[DynamicPolynomialF], + layout: &ShaMultipointLayout, + folded_public: &ProjectedPublic, + r_0: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + /* Called only for its length and canonical-form checks; the returned halves are discarded. */ split_folded_sha_pcs_lifted_evals(lifted_evals)?; + layout + .sources + .iter() + .map(|source| match *source { + ShaMpSource::Public { col } => { + sha_public_at_point(folded_public, col, 0, r_0, field_cfg) + .map_err(ProductionShaError::from) + } + ShaMpSource::WordBit { col, bit } => Ok(lifted_evals[col.index()] + .coeffs + .get(bit) + .cloned() + .unwrap_or_else(|| F::zero_with_cfg(field_cfg))), + ShaMpSource::Int { col } => Ok(lifted_evals[ShaWordCol::COUNT + col.index()] + .coeffs + .first() + .cloned() + .unwrap_or_else(|| F::zero_with_cfg(field_cfg))), + }) + .collect() +} + +/** Prover side of the SumFold over linear instance targets. Builds LinearInstanceClaims from fresh_targets, derives the hybrid sumcheck group from beta and prefix_vars, runs the sumcheck as a subprotocol over claims.ell() variables, takes the randomness of the first prover state as r_b, evaluates sumfold_expected_eval there and returns the proof together with the derived InstanceFoldClaim. */ pub fn prove_sha_sumfold_targets( + transcript: &mut impl Transcript, + fresh_targets: &[F], + beta: &[F], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, InstanceFoldClaim), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let claims = zinc_piop::neutron_nova::LinearInstanceClaims::new(fresh_targets.to_vec())?; + let group = claims.build_hybrid_sumcheck_group(beta, prefix_vars, field_cfg)?; + let (proof, states) = + MultiDegreeSumcheck::prove_as_subprotocol(transcript, vec![group], claims.ell(), field_cfg); + let r_b = states + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: "sumfold states", + got: 0, + expected: 1, + })? 
+ .randomness + .clone(); + let c_sf = sumfold_expected_eval(beta, fresh_targets, &r_b, field_cfg)?; + let output = derive_instance_fold_claim(beta, r_b, c_sf, fresh_targets.len(), field_cfg)?; + Ok((proof, output)) +} + +/** Verifier side of the SumFold over linear instance targets. Requires a single sumcheck group, rejects any round degree above 3 with SumFoldDegreeTooHigh, verifies the sumcheck over claims.ell() variables, and compares the expected evaluation of the subclaim with sumfold_expected_eval recomputed at the verifier point r_b (SumFoldTerminalMismatch on disagreement). Returns the derived InstanceFoldClaim. */ pub fn verify_sha_sumfold_targets( + transcript: &mut impl Transcript, + proof: &MultiDegreeSumcheckProof, + fresh_targets: &[F], + beta: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + require_single_sumcheck_group(proof, "SHA SumFold")?; + /* Reference pattern copies each degree out of the borrowed slice so it can be compared and moved into the error. */ for &degree in proof.degrees() { + if degree > 3 { + return Err(ProductionShaError::SumFoldDegreeTooHigh { degree }); + } + } + let claims = zinc_piop::neutron_nova::LinearInstanceClaims::new(fresh_targets.to_vec())?; + let subclaims = + MultiDegreeSumcheck::verify_as_subprotocol(transcript, claims.ell(), proof, field_cfg)?; + let r_b = subclaims.point().to_vec(); + /* Index 0 assumes the single-group check above implies exactly one expected evaluation; confirm against require_single_sumcheck_group. */ let c_sf = subclaims.expected_evaluations()[0].clone(); + if c_sf != sumfold_expected_eval(beta, fresh_targets, &r_b, field_cfg)? { + return Err(ProductionShaError::SumFoldTerminalMismatch); + } + Ok(derive_instance_fold_claim( + beta, + r_b, + c_sf, + fresh_targets.len(), + field_cfg, + )?) 
+} + +/** Prover side of the dense (unoptimized) SHA SumFold. Builds the dense sumcheck group from all traces and publics, proves it over beta.len() variables, and fails with SumFoldTerminalMismatch when the claimed sum of the proof differs from initial_claim. The traces are then folded at the sumcheck randomness r_b and the final claim is set to eq_eval(beta, r_b) times the folded row sum. The challenges a, lambda, rho and xi are only forwarded to the group builder and to expression_folded_row_sum. */ #[allow(clippy::too_many_arguments)] +pub fn prove_full_sha_sumfold( + transcript: &mut impl Transcript, + traces: &[ProjectedTrace], + publics: &[ProjectedPublic], + initial_claim: &F, + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, InstanceFoldClaim), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + ShaBinaryFoldField + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let group = build_dense_sha_sumfold_group( + traces, + publics, + beta, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + /* The number of sumcheck variables equals the number of beta challenges. */ let ell = beta.len(); + let (proof, states) = + MultiDegreeSumcheck::prove_as_subprotocol(transcript, vec![group], ell, field_cfg); + require_single_sumcheck_group(&proof, "SHA SumFold")?; + /* Checked after proving, so a mismatch is reported only once the transcript has already been advanced. */ if proof.claimed_sums()[0] != *initial_claim { + return Err(ProductionShaError::SumFoldTerminalMismatch); + } + + let r_b = states + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: "sumfold states", + got: 0, + expected: 1, + })? 
+ .randomness + .clone(); + /* Placeholder claim with c_sf set to one; it is passed only to fold_projected_traces, and the real c_sf is computed below. */ let provisional = derive_instance_fold_claim( + beta, + r_b.clone(), + F::one_with_cfg(field_cfg), + traces.len(), + field_cfg, + )?; + let (folded, folded_public) = + zinc_piop::neutron_nova::fold_projected_traces(traces, publics, &provisional, field_cfg)?; + let post_sumfold_claim = zinc_piop::neutron_nova::expression_folded_row_sum( + &folded.trace, + &folded_public, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + let d = eq_eval(beta, &r_b, F::one_with_cfg(field_cfg))?; + let c_sf = d * post_sumfold_claim; + Ok(( + proof, + derive_instance_fold_claim(beta, r_b, c_sf, traces.len(), field_cfg)?, + )) +} + +/** Prover side of the optimized SHA SumFold. Precomputes the eq weights for beta and r_ic, the a and lambda power vectors, the booleanity weights, a linear accumulator from coeff_tables and a quadratic prefix accumulator from the traces, builds the sumcheck group from those accumulators and proves it via prove_optimized_sha_sumfold_with_weights. The _publics argument is read only by the debug-build consistency check. Returns the proof and the InstanceFoldClaim derived from the sumcheck randomness and expected evaluation. */ #[allow(clippy::too_many_arguments)] +pub fn prove_optimized_sha_sumfold( + transcript: &mut impl Transcript, + traces: &[ProjectedTrace], + _publics: &[ProjectedPublic], + initial_claim: &F, + beta: &[F], + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + coeff_tables: &[LinearResidualCoeffTable], + prefix_vars: usize, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, InstanceFoldClaim), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + ShaBinaryFoldField + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let beta_eq_weights = build_eq_x_r_vec(beta, field_cfg)?; + let r_ic_eq_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let linear_accumulator = + build_sha_sumfold_linear_accumulator(coeff_tables, &a_powers, &lambda_powers, field_cfg)?; + let quadratic_prefix_accumulator = build_sha_sumfold_quadratic_prefix_accumulator( + traces, + booleanity_sources, + prefix_vars, + &r_ic_eq_weights, 
+ &booleanity_weights, + field_cfg, + )?; + let group = build_production_sha_sumfold_group_from_prefix_accumulators( + traces, + beta, + &beta_eq_weights, + &r_ic_eq_weights, + &linear_accumulator, + &quadratic_prefix_accumulator, + &booleanity_weights, + booleanity_sources, + prefix_vars, + field_cfg, + )?; + let (proof, r_b, c_sf) = prove_optimized_sha_sumfold_with_weights( + transcript, + group, + initial_claim, + beta.len(), + field_cfg, + )?; + /* Debug builds only: re-fold the traces, recompute the folded row sum and check it against the final-round claim of the provisional fold claim. */ #[cfg(debug_assertions)] + { + let provisional = + derive_instance_fold_claim(beta, r_b.clone(), c_sf.clone(), traces.len(), field_cfg)?; + let (folded, folded_public) = zinc_piop::neutron_nova::fold_projected_traces( + traces, + _publics, + &provisional, + field_cfg, + )?; + let row_claim = expression_folded_row_sum_with_vectors( + &folded.trace, + &folded_public, + &r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + booleanity_sources, + field_cfg, + )?; + verify_folded_row_sumcheck_claim(provisional.final_round_sumcheck_claim(), &row_claim)?; + } + Ok(( + proof, + derive_instance_fold_claim(beta, r_b, c_sf, traces.len(), field_cfg)?, + )) +} + +/** Runs the SumFold sumcheck for one prebuilt group over instance_vars variables and returns (proof, randomness of the first prover state, the single expected evaluation). Fails with SumFoldTerminalMismatch when the claimed sum of the proof differs from initial_claim, and with LengthMismatch when the prover does not report exactly one expected evaluation or reports no state. */ pub fn prove_optimized_sha_sumfold_with_weights( + transcript: &mut impl Transcript, + group: MultiDegreeSumcheckGroup, + initial_claim: &F, + instance_vars: usize, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, Vec, F), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let (proof, states, expected_evaluations) = + MultiDegreeSumcheck::prove_as_subprotocol_with_expected_evaluations( + transcript, + vec![group], + instance_vars, + field_cfg, + ); + require_single_sumcheck_group(&proof, "SHA SumFold")?; + if proof.claimed_sums()[0] != *initial_claim { + return Err(ProductionShaError::SumFoldTerminalMismatch); + } + if expected_evaluations.len() != 1 { + return 
Err(ProductionShaError::LengthMismatch { + label: "sumfold expected evaluations", + got: expected_evaluations.len(), + expected: 1, + }); + } + + Ok(( + proof, + states + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: "sumfold states", + got: 0, + expected: 1, + })? + .randomness + .clone(), + expected_evaluations[0].clone(), + )) +} + +/** Derives the InstanceFoldClaim for a known folded row claim: c_sf is eq_eval(beta, r_b) with unit scale multiplied by row_claim, then handed to derive_instance_fold_claim. Errors from either helper are converted into ProductionShaError. */ pub fn derive_instance_fold_claim_from_row_claim( + beta: &[F], + r_b: Vec, + row_claim: &F, + instance_count: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let d = eq_eval(beta, &r_b, F::one_with_cfg(field_cfg))?; + let c_sf = d * row_claim; + derive_instance_fold_claim(beta, r_b, c_sf, instance_count, field_cfg) + .map_err(ProductionShaError::from) +} + +/** Verifier side of the full SHA SumFold sumcheck. Requires a single sumcheck group, rejects any round degree above 3 with SumFoldDegreeTooHigh, requires the first claimed sum to equal initial_claim (SumFoldTerminalMismatch otherwise), then verifies the sumcheck over instance_vars variables. Returns the verifier point r_b and the first expected evaluation c_sf; this function does not itself check c_sf against a recomputed value. */ pub fn verify_full_sha_sumfold( + transcript: &mut impl Transcript, + proof: &MultiDegreeSumcheckProof, + initial_claim: &F, + instance_vars: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + require_single_sumcheck_group(proof, "SHA SumFold")?; + /* Reference pattern copies each degree out of the borrowed slice so it can be compared and moved into the error. */ for &degree in proof.degrees() { + if degree > 3 { + return Err(ProductionShaError::SumFoldDegreeTooHigh { degree }); + } + } + let Some(claimed_sum) = proof.claimed_sums().first() else { + return Err(ProductionShaError::LengthMismatch { + label: "SHA SumFold claimed sums", + got: 0, + expected: 1, + }); + }; + if claimed_sum != initial_claim { + return Err(ProductionShaError::SumFoldTerminalMismatch); + } + + let subclaims = + MultiDegreeSumcheck::verify_as_subprotocol(transcript, instance_vars, proof, field_cfg)?; + let r_b = subclaims.point().to_vec(); + let c_sf = subclaims.expected_evaluations()[0].clone(); + Ok(VerifiedShaSumFold { r_b, c_sf }) +} + +/** Folds a list of PCS commitments with the weights theta, separately for the binary, arbitrary and int components, using fold_commitment_refs of the respective PCS. Returns LengthMismatch when commitments and theta differ in length. */ pub fn fold_pcs_commitments( + commitments: &[PCSCommitments], + theta: &[F], + field_cfg: &F::Config, +) -> Result, 
ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + if commitments.len() != theta.len() { + return Err(ProductionShaError::LengthMismatch { + label: "commitments/theta", + got: commitments.len(), + expected: theta.len(), + }); + } + /* Collect references per component so the commitments themselves are not cloned. */ let binary = commitments + .iter() + .map(|commitment| &commitment.binary) + .collect::>(); + let arbitrary = commitments + .iter() + .map(|commitment| &commitment.arbitrary) + .collect::>(); + let int = commitments + .iter() + .map(|commitment| &commitment.int) + .collect::>(); + Ok(PCSCommitments { + binary: P::BinaryPCS::fold_commitment_refs(&binary, theta, field_cfg)?, + arbitrary: P::ArbitraryPCS::fold_commitment_refs(&arbitrary, theta, field_cfg)?, + int: P::IntPCS::fold_commitment_refs(&int, theta, field_cfg)?, + }) +} + +/** Folds a list of PCS prover data with the weights theta, separately for the binary, arbitrary and int components, using fold_prover_data of the respective PCS. Each component is cloned out of its container first. Returns LengthMismatch when prover_data and theta differ in length. */ pub fn fold_pcs_prover_data( + prover_data: &[PCSProverData], + theta: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + Zt: ZincTypes, + F: PrimeField, + P: ZincPCSTypes, + P::BinaryPCS: FoldablePCS, D>, + P::ArbitraryPCS: FoldablePCS, D>, + P::IntPCS: FoldablePCS, +{ + if prover_data.len() != theta.len() { + return Err(ProductionShaError::LengthMismatch { + label: "prover_data/theta", + got: prover_data.len(), + expected: theta.len(), + }); + } + let binary = prover_data + .iter() + .map(|data| data.binary.clone()) + .collect::>(); + let arbitrary = prover_data + .iter() + .map(|data| data.arbitrary.clone()) + .collect::>(); + let int = prover_data + .iter() + .map(|data| data.int.clone()) + .collect::>(); + Ok(PCSProverData { + binary: P::BinaryPCS::fold_prover_data(&binary, theta, field_cfg)?, + arbitrary: P::ArbitraryPCS::fold_prover_data(&arbitrary, theta, field_cfg)?, + int: P::IntPCS::fold_prover_data(&int, theta, field_cfg)?, + }) +} + +/** Proves the folded row sumcheck from precomputed row integrand values. First sums the integrand values and checks the sum against post_sumfold_claim via verify_folded_row_sumcheck_claim, then builds the sumcheck group and proves it over SHA_ROW_VARS variables. Returns only the proof; the prover states are dropped. */ pub fn prove_folded_row_sumcheck( + transcript: &mut impl Transcript, + row_integrand_values: &[F], + 
post_sumfold_claim: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let claimed = folded_row_integrand_sum(row_integrand_values, field_cfg)?; + verify_folded_row_sumcheck_claim(&claimed, post_sumfold_claim)?; + let group = build_folded_row_sumcheck_group(row_integrand_values, field_cfg)?; + let (proof, _) = + MultiDegreeSumcheck::prove_as_subprotocol(transcript, vec![group], SHA_ROW_VARS, field_cfg); + Ok(proof) +} + +/** Reverse lookup tables built by RowExpressionLayout::offsets: they map a column index (and, for word and public scalar columns, a row shift) to the position of that source in the matching source list, or to ROW_EXPR_MISSING_SOURCE when the source is not part of the layout. */ #[derive(Clone)] +struct RowExpressionOffsets { + word: [[usize; ROW_EXPR_WORD_SHIFT_SLOTS]; ShaWordCol::COUNT], + int: [usize; ShaIntCol::COUNT], + public_scalar: [[usize; ROW_EXPR_PUBLIC_SHIFT_SLOTS]; ShaPublicCol::COUNT], + public_word: [usize; ShaPublicCol::COUNT], +} + +/* Sentinel stored in RowExpressionOffsets for a (column, shift) pair that has no source. */ const ROW_EXPR_MISSING_SOURCE: usize = usize::MAX; +/* Number of shift slots per word column, so valid word shifts are 0 through 16. */ const ROW_EXPR_WORD_SHIFT_SLOTS: usize = 17; +/* Number of shift slots per public scalar column, so valid public shifts are 0 through 3. */ const ROW_EXPR_PUBLIC_SHIFT_SLOTS: usize = 4; + +/** Ordered source lists for the row-expression sumcheck plus the starting offset of each source family in the flat per-row value vector. */ #[derive(Clone)] +struct RowExpressionLayout { + word_sources: Vec<(ShaWordCol, usize)>, + int_sources: Vec, + public_scalar_sources: Vec<(ShaPublicCol, usize)>, + public_word_sources: Vec, + word_offset: usize, + int_offset: usize, + public_scalar_offset: usize, + public_word_offset: usize, +} + +impl RowExpressionLayout { + /** Builds the production layout: endpoint word and int sources from the production_sha_endpoint helpers, every public column at shift 0 plus K at shift 3 as public scalars, and the five public word columns PAIn, PEIn, PAOut, PEOut and Message. */ fn new() -> Self { + let word_sources = production_sha_endpoint_word_sources(); + let int_sources = production_sha_endpoint_int_sources(); + let mut public_scalar_sources = ShaPublicCol::ALL + .iter() + .copied() + .map(|col| (col, 0)) + .collect::>(); + public_scalar_sources.push((ShaPublicCol::K, 3)); + let public_word_sources = vec![ + ShaPublicCol::PAIn, + ShaPublicCol::PEIn, + ShaPublicCol::PAOut, + ShaPublicCol::PEOut, + ShaPublicCol::Message, + ]; + /* Word values start at index 1; index 0 is presumably left for the row-weight value that the sumcheck group builder places first (confirm against the builder). */ let word_offset = 1; + let int_offset = word_offset + word_sources.len() * SHA_WORD_BITS; + let public_scalar_offset = int_offset + int_sources.len(); + let 
public_word_offset = public_scalar_offset + public_scalar_sources.len(); + Self { + word_sources, + int_sources, + public_scalar_sources, + public_word_sources, + word_offset, + int_offset, + public_scalar_offset, + public_word_offset, + } + } + + /** Builds the reverse lookup tables for this layout. Every table starts filled with ROW_EXPR_MISSING_SOURCE and each listed source then records its position in its source list. Indexing panics if a shift is not below the slot count of its table. */ fn offsets(&self) -> RowExpressionOffsets { + let mut word = [[ROW_EXPR_MISSING_SOURCE; ROW_EXPR_WORD_SHIFT_SLOTS]; ShaWordCol::COUNT]; + for (idx, &(col, shift)) in self.word_sources.iter().enumerate() { + word[col.index()][shift] = idx; + } + + let mut int = [ROW_EXPR_MISSING_SOURCE; ShaIntCol::COUNT]; + for (idx, &col) in self.int_sources.iter().enumerate() { + int[col.index()] = idx; + } + + let mut public_scalar = + [[ROW_EXPR_MISSING_SOURCE; ROW_EXPR_PUBLIC_SHIFT_SLOTS]; ShaPublicCol::COUNT]; + for (idx, &(col, shift)) in self.public_scalar_sources.iter().enumerate() { + public_scalar[col.index()][shift] = idx; + } + + let mut public_word = [ROW_EXPR_MISSING_SOURCE; ShaPublicCol::COUNT]; + for (idx, &col) in self.public_word_sources.iter().enumerate() { + public_word[col.index()] = idx; + } + + RowExpressionOffsets { + word, + int, + public_scalar, + public_word, + } + } +} + +/** Debug-build helper: reads bit number bit of word column col at row + shift. Returns LengthMismatch when bit is not below SHA_WORD_BITS or when the trace lacks the requested cell, and zero when row + shift overflows or is not below SHA_ROW_COUNT. */ #[cfg(debug_assertions)] +fn trace_word_bit_at_row( + trace: &ProjectedTrace, + col: ShaWordCol, + row: usize, + shift: usize, + bit: usize, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + if bit >= SHA_WORD_BITS { + return Err(ProductionShaError::LengthMismatch { + label: "SHA word bit index", + got: bit, + expected: SHA_WORD_BITS, + }); + } + /* Reads past the last row are treated as zero rather than as an error. */ let Some(shifted) = row.checked_add(shift) else { + return Ok(F::zero_with_cfg(field_cfg)); + }; + if shifted >= SHA_ROW_COUNT { + return Ok(F::zero_with_cfg(field_cfg)); + } + trace + .bit_slices + .get(bit_slice_index(col.index(), bit, SHA_WORD_BITS)) + .and_then(|column| column.evaluations.get(shifted)) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA trace word bit", + got: shifted, + expected: SHA_ROW_COUNT, + }) +} + +/** Debug-build helper: reads int column col at the given row. Returns zero when row is not below SHA_ROW_COUNT and LengthMismatch when the trace lacks the requested cell. */ #[cfg(debug_assertions)] +fn 
trace_int_at_row( + trace: &ProjectedTrace, + col: ShaIntCol, + row: usize, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + if row >= SHA_ROW_COUNT { + return Ok(F::zero_with_cfg(field_cfg)); + } + trace + .int_columns + .get(col.index()) + .and_then(|column| column.evaluations.get(row)) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA trace int row", + got: row, + expected: SHA_ROW_COUNT, + }) +} + +/** Debug-build helper: reads public scalar column col at row + shift. Returns zero when row + shift overflows or is not below SHA_ROW_COUNT, and LengthMismatch when the public data lacks the requested cell. */ #[cfg(debug_assertions)] +fn public_scalar_at_row( + public: &ProjectedPublic, + col: ShaPublicCol, + row: usize, + shift: usize, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + let Some(shifted) = row.checked_add(shift) else { + return Ok(F::zero_with_cfg(field_cfg)); + }; + if shifted >= SHA_ROW_COUNT { + return Ok(F::zero_with_cfg(field_cfg)); + } + public + .columns + .get(col.index()) + .and_then(|column| column.evaluations.get(shifted)) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA public scalar row", + got: shifted, + expected: SHA_ROW_COUNT, + }) +} + +/** Maps a public column to its position among the bit-sliced public word columns (PAIn 0, PEIn 1, PAOut 2, PEOut 3, Message 4). Returns None for every other public column. */ fn public_word_col_index(col: ShaPublicCol) -> Option { + match col { + ShaPublicCol::PAIn => Some(0), + ShaPublicCol::PEIn => Some(1), + ShaPublicCol::PAOut => Some(2), + ShaPublicCol::PEOut => Some(3), + ShaPublicCol::Message => Some(4), + _ => None, + } +} + +/** Debug-build helper: reads bit number bit of public word column col at the given row. Returns LengthMismatch when bit is not below SHA_WORD_BITS or the cell is missing, zero when row is not below SHA_ROW_COUNT, and NonCanonicalProofObject when col is not a public word column or the public data carries no bit slices. */ #[cfg(debug_assertions)] +fn public_word_bit_at_row( + public: &ProjectedPublic, + col: ShaPublicCol, + row: usize, + bit: usize, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + if bit >= SHA_WORD_BITS { + return Err(ProductionShaError::LengthMismatch { + label: "SHA public word bit index", + got: bit, + expected: SHA_WORD_BITS, + }); + } + if row >= SHA_ROW_COUNT { + return Ok(F::zero_with_cfg(field_cfg)); + } + let col_idx = public_word_col_index(col).ok_or(ProductionShaError::NonCanonicalProofObject( + "SHA public column is not a public word", + ))?; + let bit_slices = + public + .bit_slices + .as_ref() + 
.ok_or(ProductionShaError::NonCanonicalProofObject( + "production SHA public word columns are required", + ))?; + let table_idx = bit_slice_index(col_idx, bit, SHA_WORD_BITS); + bit_slices + .get(table_idx) + .and_then(|column| column.evaluations.get(row)) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA public word bit row", + got: row, + expected: SHA_ROW_COUNT, + }) +} + +/** Returns two raised to the power exp as a field element, computed by multiplying one by two exp times. */ fn production_sha_pow_two(exp: usize, field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let mut out = F::one_with_cfg(field_cfg); + for _ in 0..exp { + out *= &two; + } + out +} + +/** Builds a dense MLE over SHA_ROW_VARS variables from column col_idx of table, shifted up by shift rows: entry i holds the inner value of the column at row i + shift, and the tail is padded with zero_inner up to SHA_ROW_COUNT entries. Returns LengthMismatch (tagged with label) when col_idx is out of range or the column does not have SHA_ROW_VARS variables and SHA_ROW_COUNT evaluations. The zero parameter is accepted but explicitly discarded. */ fn row_expr_mle_from_table_shift( + label: &'static str, + table: &MleTable, + col_idx: usize, + shift: usize, + zero: &F, + zero_inner: &F::Inner, +) -> Result, ProductionShaError> +where + F: InnerTransparentField, +{ + let column = table + .get(col_idx) + .ok_or(ProductionShaError::LengthMismatch { + label, + got: col_idx, + expected: table.len(), + })?; + if column.num_vars != SHA_ROW_VARS || column.evaluations.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label, + got: column.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + + let mut evaluations = Vec::with_capacity(SHA_ROW_COUNT); + /* The guard keeps the range slice below from panicking when shift reaches past the column; in that case the result is all zeros. */ if shift < SHA_ROW_COUNT { + evaluations.extend( + column.evaluations[shift..] 
+ .iter() + .map(|value| value.inner().clone()), + ); + } + evaluations.resize(SHA_ROW_COUNT, zero_inner.clone()); + debug_assert_eq!(evaluations.len(), SHA_ROW_COUNT); + let _ = zero; + Ok(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + evaluations, + zero_inner.clone(), + )) +} + +/** Builds the row-expression sumcheck group from the raw challenges: expands r_ic into eq row weights with build_eq_x_r_vec and delegates to the with_row_weights variant. */ #[allow(clippy::too_many_arguments)] +fn build_production_sha_row_expression_sumcheck_group( + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let row_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + build_production_sha_row_expression_sumcheck_group_with_row_weights( + trace, + public, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ) +} + +/** Builds the row-expression sumcheck group from precomputed row weights: expands a, lambda and (rho, xi) into the a power vector, the lambda power vector and the booleanity weights, then delegates to the with_vectors variant. */ #[allow(clippy::too_many_arguments)] +fn build_production_sha_row_expression_sumcheck_group_with_row_weights( + trace: &ProjectedTrace, + public: &ProjectedPublic, + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + build_production_sha_row_expression_sumcheck_group_with_vectors( + trace, + public, + row_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +/** Builds the row-expression sumcheck group from fully expanded vectors. Validates the vector lengths, materializes one MLE per layout source (row weights first, then word bits, int columns, public scalars and public word bits), and pairs them with a closure that evaluates, per row, the row weight times the sum of the lambda-weighted residuals r0 through r17 and the weighted booleanity terms. Returns LengthMismatch for mis-sized inputs or columns and NonCanonicalProofObject when the public word bit slices are missing. */ #[allow(clippy::too_many_arguments)] +fn build_production_sha_row_expression_sumcheck_group_with_vectors( + trace: &ProjectedTrace, + public: &ProjectedPublic, + 
row_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + DelayedFieldProductSum + Send + Sync + 'static, + F::Inner: Zero, +{ + /* Validate the lengths of all weight and power vectors up front. */ if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + /* Only a lower bound is enforced here: a_powers may be longer than SHA_IDEAL_EVAL_POWER_COUNT. */ if a_powers.len() < SHA_IDEAL_EVAL_POWER_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "a powers", + got: a_powers.len(), + expected: SHA_IDEAL_EVAL_POWER_COUNT, + }); + } + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "lambda powers", + got: lambda_powers.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "booleanity weights", + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + let zero = F::zero_with_cfg(field_cfg); + let zero_inner = zero.inner().clone(); + let layout = RowExpressionLayout::new(); + /* Capacity: one row-weight MLE, one MLE per bit of each word source, one per int source, one per public scalar source and one per bit of each public word source. */ let mut mles = Vec::with_capacity( + 1 + layout.word_sources.len() * SHA_WORD_BITS + + layout.int_sources.len() + + layout.public_scalar_sources.len() + + layout.public_word_sources.len() * SHA_WORD_BITS, + ); + + /* The row weights are pushed first, so they occupy slot 0 of the MLE list. */ mles.push(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + row_weights + .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )); + + /* Word sources: SHA_WORD_BITS bit MLEs per (column, shift) pair, in layout order. */ for &(col, shift) in &layout.word_sources { + for bit in 0..SHA_WORD_BITS { + mles.push(row_expr_mle_from_table_shift( + "SHA trace word bit", + &trace.bit_slices, + bit_slice_index(col.index(), bit, SHA_WORD_BITS), + shift, + &zero, + &zero_inner, + )?); + } + } + + /* Int sources: one unshifted MLE per int column. */ for &col in &layout.int_sources { + mles.push(row_expr_mle_from_table_shift( + "SHA trace int", + 
&trace.int_columns, + col.index(), + 0, + &zero, + &zero_inner, + )?); + } + + /* Public scalar sources: one MLE per (column, shift) pair. */ for &(col, shift) in &layout.public_scalar_sources { + mles.push(row_expr_mle_from_table_shift( + "SHA public scalar", + &public.columns, + col.index(), + shift, + &zero, + &zero_inner, + )?); + } + + /* Public word sources: bit-sliced and unshifted; public.bit_slices must be present. */ for &col in &layout.public_word_sources { + let bit_slices = + public + .bit_slices + .as_ref() + .ok_or(ProductionShaError::NonCanonicalProofObject( + "production SHA public word columns are required", + ))?; + let col_idx = public_word_col_index(col).ok_or( + ProductionShaError::NonCanonicalProofObject("SHA public column is not a public word"), + )?; + for bit in 0..SHA_WORD_BITS { + mles.push(row_expr_mle_from_table_shift( + "SHA public word bit", + bit_slices, + bit_slice_index(col_idx, bit, SHA_WORD_BITS), + 0, + &zero, + &zero_inner, + )?); + } + } + + let offsets = layout.offsets(); + /* Default per-bit weights: the first SHA_WORD_BITS entries of a_powers. */ let word_weights = a_powers[..SHA_WORD_BITS].to_vec(); + /* Weight vector in which bit b receives a_powers at index (b + shift) modulo SHA_WORD_BITS. */ let rot_weights = |shift: usize| { + (0..SHA_WORD_BITS) + .map(|bit| a_powers[(bit + shift) % SHA_WORD_BITS].clone()) + .collect::>() + }; + /* Weight vector in which bits below shift receive zero and bit b at or above shift receives a_powers at index b - shift. */ let shift_weights = |shift: usize| { + (0..SHA_WORD_BITS) + .map(|bit| { + if bit >= shift { + a_powers[bit - shift].clone() + } else { + F::zero_with_cfg(field_cfg) + } + }) + .collect::>() + }; + let rot25_weights = rot_weights(25); + let rot14_weights = rot_weights(14); + let rot15_weights = rot_weights(15); + let rot13_weights = rot_weights(13); + let shift3_weights = shift_weights(3); + let shift10_weights = shift_weights(10); + let shift0_weights = shift_weights(0); + let shift2_weights = shift_weights(2); + let shift5_weights = shift_weights(5); + let shift8_weights = shift_weights(8); + let shift9_weights = shift_weights(9); + /* Owned copies so the boxed move closure below can capture them. */ let lambda_powers = lambda_powers.to_vec(); + let booleanity_weights = booleanity_weights.to_vec(); + let booleanity_sources = booleanity_sources.to_vec(); + let one = F::one_with_cfg(field_cfg); + let two = one.clone() + &one; + let rho_sig0 = a_powers[10].clone() + 
&a_powers[19] + &a_powers[30]; + let rho_sig1 = a_powers[7].clone() + &a_powers[21] + &a_powers[26]; + /* Powers of two used to scale the low and high parts of the MuPacked terms. */ let low_mu_coeff = production_sha_pow_two(32, field_cfg); + let high_mu_w_coeff = production_sha_pow_two(34, field_cfg); + let high_mu_3_bit_coeff = production_sha_pow_two(35, field_cfg); + let high_mu_1_bit_coeff = production_sha_pow_two(33, field_cfg); + + /* NOTE(review): the first argument 3 is presumably the degree bound of the group, and values is presumably one entry per MLE in push order, as the offset arithmetic below assumes; confirm against MultiDegreeSumcheckGroup::new. */ Ok(MultiDegreeSumcheckGroup::new( + 3, + mles, + Box::new(move |values: &[F]| { + let zero = zero.clone(); + /* Inner product of two slices, accumulated from zero; equal lengths are only debug-asserted. */ let dot = |lhs: &[F], rhs: &[F]| { + debug_assert_eq!(lhs.len(), rhs.len()); + lhs.iter() + .zip(rhs.iter()) + .fold(zero.clone(), |acc, (left, right)| { + acc + left.clone() * right + }) + }; + /* Position of a (word column, shift) source in the layout; a missing source is only caught by the debug assertion. */ let word_source_idx = |col: ShaWordCol, shift: usize| { + let idx = offsets.word[col.index()][shift]; + debug_assert_ne!(idx, ROW_EXPR_MISSING_SOURCE); + idx + }; + /* Slice of the SHA_WORD_BITS values that belong to one (word column, shift) source. */ let word_bits = |col: ShaWordCol, shift: usize| { + let source_idx = word_source_idx(col, shift); + let base = layout.word_offset + source_idx * SHA_WORD_BITS; + &values[base..base + SHA_WORD_BITS] + }; + let word_eval = + |col: ShaWordCol, shift: usize| dot(word_bits(col, shift), &word_weights); + let word_eval_with = + |col: ShaWordCol, shift: usize, weights: &[F]| dot(word_bits(col, shift), weights); + let word_bit = + |col: ShaWordCol, shift: usize, bit: usize| word_bits(col, shift)[bit].clone(); + let int_value = |col: ShaIntCol| { + let idx = offsets.int[col.index()]; + debug_assert_ne!(idx, ROW_EXPR_MISSING_SOURCE); + values[layout.int_offset + idx].clone() + }; + let public_scalar = |col: ShaPublicCol, shift: usize| { + let idx = offsets.public_scalar[col.index()][shift]; + debug_assert_ne!(idx, ROW_EXPR_MISSING_SOURCE); + values[layout.public_scalar_offset + idx].clone() + }; + /* Uses the bit-sliced public word when the column has one, otherwise falls back to the public scalar at shift 0. */ let public_word_or_const_eval = |col: ShaPublicCol| { + let idx = offsets.public_word[col.index()]; + if idx == ROW_EXPR_MISSING_SOURCE { + public_scalar(col, 0) + } else { + let base = layout.public_word_offset + idx * SHA_WORD_BITS; + dot(&values[base..base + 
SHA_WORD_BITS], &word_weights) + } + }; + + /* Unshifted word evaluations under the default bit weights. */ let a_word = word_eval(ShaWordCol::A, 0); + let e_word = word_eval(ShaWordCol::E, 0); + let sigma0 = word_eval(ShaWordCol::Sigma0, 0); + let sigma1 = word_eval(ShaWordCol::Sigma1, 0); + let w = word_eval(ShaWordCol::W, 0); + let small_sigma0 = word_eval(ShaWordCol::SmallSigma0, 0); + let small_sigma1 = word_eval(ShaWordCol::SmallSigma1, 0); + let ov_sigma0 = word_eval(ShaWordCol::OvSigma0, 0); + let ov_sigma1 = word_eval(ShaWordCol::OvSigma1, 0); + let ov_small_sigma0 = word_eval(ShaWordCol::OvSmallSigma0, 0); + let ov_small_sigma1 = word_eval(ShaWordCol::OvSmallSigma1, 0); + + /* Difference of two weighted evaluations of the MuPacked word: the low_weights evaluation scaled by low_mu_coeff minus the high_weights evaluation scaled by high_coeff. */ let mu = |low_weights: &[F], high_weights: &[F], high_coeff: &F| { + word_eval_with(ShaWordCol::MuPacked, 0, low_weights) * &low_mu_coeff + - word_eval_with(ShaWordCol::MuPacked, 0, high_weights) * high_coeff + }; + let mu_w = mu(&shift0_weights, &shift2_weights, &high_mu_w_coeff); + let mu_a = mu(&shift2_weights, &shift5_weights, &high_mu_3_bit_coeff); + let mu_e = mu(&shift5_weights, &shift8_weights, &high_mu_3_bit_coeff); + let mu_ff_a = mu(&shift8_weights, &shift9_weights, &high_mu_1_bit_coeff); + let mu_ff_e = mu(&shift9_weights, &shift10_weights, &high_mu_1_bit_coeff); + + /* Evaluations of W under rotated and shifted weight vectors (same row, different bit weights). */ let w_rot25 = word_eval_with(ShaWordCol::W, 0, &rot25_weights); + let w_rot14 = word_eval_with(ShaWordCol::W, 0, &rot14_weights); + let w_shift3 = word_eval_with(ShaWordCol::W, 0, &shift3_weights); + let w_rot15 = word_eval_with(ShaWordCol::W, 0, &rot15_weights); + let w_rot13 = word_eval_with(ShaWordCol::W, 0, &rot13_weights); + let w_shift10 = word_eval_with(ShaWordCol::W, 0, &shift10_weights); + /* From here on the numeric suffix is a row shift (second argument of word_eval), not a bit shift. */ let w_shift16 = word_eval(ShaWordCol::W, 16); + let w_shift9 = word_eval(ShaWordCol::W, 9); + let small_sigma0_shift1 = word_eval(ShaWordCol::SmallSigma0, 1); + let small_sigma1_shift14 = word_eval(ShaWordCol::SmallSigma1, 14); + let a_shift4 = word_eval(ShaWordCol::A, 4); + let e_shift4 = word_eval(ShaWordCol::E, 4); + let sigma0_shift3 = word_eval(ShaWordCol::Sigma0, 3); + let 
sigma1_shift3 = word_eval(ShaWordCol::Sigma1, 3); + let uef_shift3 = word_eval(ShaWordCol::Uef, 3); + let uneg_eg_shift3 = word_eval(ShaWordCol::UNegEg, 3); + let maj_shift3 = word_eval(ShaWordCol::Maj, 3); + let public_k_shift3 = public_scalar(ShaPublicCol::K, 3); + + /* Residual expressions r0 through r17; they are combined with lambda_powers further down, so their order must match the order of the lambda powers. */ let r0 = a_word.clone() * &rho_sig0 - &sigma0 - two.clone() * &ov_sigma0; + let r1 = e_word.clone() * &rho_sig1 - &sigma1 - two.clone() * &ov_sigma1; + let r2 = w_rot25 + w_rot14 + w_shift3 - &small_sigma0 + - two.clone() * &ov_small_sigma0; + let r3 = w_rot15 + w_rot13 + w_shift10 - &small_sigma1 + - two.clone() * &ov_small_sigma1; + let r4 = w_shift16 - &w - small_sigma0_shift1 - w_shift9 - small_sigma1_shift14 + + &mu_w + + int_value(ShaIntCol::CompSchedule); + let r5 = a_shift4.clone() + - &e_word + - &sigma1_shift3 + - &uef_shift3 + - &uneg_eg_shift3 + - &public_k_shift3 + - &w + - &sigma0_shift3 + - &maj_shift3 + + &mu_a + + int_value(ShaIntCol::CompUpdateA); + let r6 = e_shift4.clone() + - &a_word + - &e_word + - &sigma1_shift3 + - &uef_shift3 + - &uneg_eg_shift3 + - &public_k_shift3 + - &w + + &mu_e + + int_value(ShaIntCol::CompUpdateE); + + /* Public scalar columns at shift 0 that multiply into residuals r7, r8 and r11 through r16. */ let s_init = public_scalar(ShaPublicCol::SInit, 0); + let s_msg = public_scalar(ShaPublicCol::SMsg, 0); + let s_sched = public_scalar(ShaPublicCol::SSched, 0); + let s_upd = public_scalar(ShaPublicCol::SUpd, 0); + let s_ff = public_scalar(ShaPublicCol::SFf, 0); + let s_out = public_scalar(ShaPublicCol::SOut, 0); + + let r7 = (a_word.clone() - public_word_or_const_eval(ShaPublicCol::PAIn)) * &s_init + + (a_word.clone() - public_word_or_const_eval(ShaPublicCol::PAOut)) * &s_out; + let r8 = (e_word.clone() - public_word_or_const_eval(ShaPublicCol::PEIn)) * &s_init + + (e_word.clone() - public_word_or_const_eval(ShaPublicCol::PEOut)) * &s_out; + /* NOTE(review): r9 and r10 read PAIn and PEIn through public_scalar, whereas r7 and r8 read the same columns through the bit-sliced public word path; confirm this difference is intended. */ let r9 = a_shift4 - &a_word - public_scalar(ShaPublicCol::PAIn, 0) + + &mu_ff_a + + int_value(ShaIntCol::CompFeedForwardA); + let r10 = e_shift4 - &e_word - public_scalar(ShaPublicCol::PEIn, 0) + + &mu_ff_e + + 
int_value(ShaIntCol::CompFeedForwardE); + let r11 = (w - public_word_or_const_eval(ShaPublicCol::Message)) * &s_msg; + let r12 = int_value(ShaIntCol::CompSchedule) * &s_sched; + let r13 = int_value(ShaIntCol::CompUpdateA) * &s_upd; + let r14 = int_value(ShaIntCol::CompUpdateE) * &s_upd; + let r15 = int_value(ShaIntCol::CompFeedForwardA) * &s_ff; + let r16 = int_value(ShaIntCol::CompFeedForwardE) * &s_ff; + let r17 = word_eval_with(ShaWordCol::MuPacked, 0, &shift10_weights); + let residuals = [ + r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17, + ]; + /* Linear part: inner product of the 18 residuals with lambda_powers. The expect relies on lambda_powers having been length-checked against NUM_SHA_RESIDUAL_FAMILIES, which presumably equals 18 (confirm). */ let linear = FieldFieldInnerProduct::inner_product::( + &residuals, + &lambda_powers, + zero.clone(), + ) + .expect("row expression residual dot product lengths match"); + + /* Booleanity part: sum over all sources of weight times d times (d - 1), which vanishes when every d is 0 or 1. */ let mut bool_sum = zero.clone(); + for (source, weight) in booleanity_sources.iter().zip(booleanity_weights.iter()) { + let d = match *source { + ShaBooleanitySource::WordBit { col, bit } => word_bit(col, 0, bit), + ShaBooleanitySource::VirtualCh1 { bit: bit_idx } => { + word_bit(ShaWordCol::E, 2, bit_idx) + &word_bit(ShaWordCol::E, 1, bit_idx) + - two.clone() * word_bit(ShaWordCol::Uef, 2, bit_idx) + } + ShaBooleanitySource::VirtualCh2 { bit: bit_idx } => { + word_bit(ShaWordCol::E, 2, bit_idx) - &word_bit(ShaWordCol::E, 0, bit_idx) + + two.clone() * word_bit(ShaWordCol::UNegEg, 2, bit_idx) + + two.clone() * word_bit(ShaWordCol::Ch2Comp, 0, bit_idx) + } + ShaBooleanitySource::VirtualMaj { bit: bit_idx } => { + word_bit(ShaWordCol::A, 0, bit_idx) + + &word_bit(ShaWordCol::A, 1, bit_idx) + + &word_bit(ShaWordCol::A, 2, bit_idx) + - two.clone() * word_bit(ShaWordCol::Maj, 2, bit_idx) + - two.clone() * word_bit(ShaWordCol::MajComp, 0, bit_idx) + } + }; + bool_sum += weight.clone() * (d.clone() * (d - one.clone())); + } + + /* Scale the combined expression by values[0], the slot that holds the row weight. */ values[0].clone() * (linear + bool_sum) + }), + )) +} + +/** Proves the folded row sumcheck directly from the folded trace and public data. Recomputes the folded row sum with expression_folded_row_sum, checks it against post_sumfold_claim via verify_folded_row_sumcheck_claim, builds the row-expression sumcheck group and proves it over SHA_ROW_VARS variables. Returns only the proof; the prover states are dropped. */ #[allow(clippy::too_many_arguments)] +pub fn prove_expression_folded_row_sumcheck( + transcript: &mut impl Transcript, + trace: &ProjectedTrace, 
+ public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + post_sumfold_claim: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let claimed = zinc_piop::neutron_nova::expression_folded_row_sum( + trace, + public, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + verify_folded_row_sumcheck_claim(&claimed, post_sumfold_claim)?; + let group = build_production_sha_row_expression_sumcheck_group( + trace, + public, + r_ic, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + let (proof, _) = + MultiDegreeSumcheck::prove_as_subprotocol(transcript, vec![group], SHA_ROW_VARS, field_cfg); + Ok(proof) +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_expression_folded_row_sumcheck_with_output( + transcript: &mut impl Transcript, + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + post_sumfold_claim: &F, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, FoldedRowSumcheckOutput), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let r_ic_eq_weights = build_eq_x_r_vec(r_ic, field_cfg)?; + prove_expression_folded_row_sumcheck_with_output_and_weights( + transcript, + trace, + public, + r_ic, + &r_ic_eq_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + post_sumfold_claim, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_expression_folded_row_sumcheck_with_output_and_weights( + transcript: &mut impl Transcript, + trace: 
&ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_ic_eq_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + post_sumfold_claim: &F, + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, FoldedRowSumcheckOutput), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + let a_powers = build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + let claimed = expression_folded_row_sum_with_row_weights( + trace, + public, + r_ic_eq_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + )?; + verify_folded_row_sumcheck_claim(&claimed, post_sumfold_claim)?; + prove_expression_folded_row_sumcheck_with_output_and_vectors( + transcript, + trace, + public, + r_ic, + r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +fn row_sumcheck_terminal_from_proof( + proof: &MultiDegreeSumcheckProof, + challenges: &[F], +) -> Result> +where + F: FromPrimitiveWithConfig, +{ + if proof.group_messages().len() != 1 { + return Err(ProductionShaError::UnexpectedSumcheckGroupCount { + label: "row sumcheck terminal", + got: proof.group_messages().len(), + }); + } + if proof.claimed_sums().len() != 1 { + return Err(ProductionShaError::UnexpectedSumcheckGroupCount { + label: "row sumcheck terminal claimed sums", + got: proof.claimed_sums().len(), + }); + } + let degree = proof + .degrees() + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: "row sumcheck terminal degrees", + got: 0, + expected: 1, + })?; + let messages = proof + .group_messages() + .first() + .expect("checked row sumcheck group count"); + if 
messages.len() != challenges.len() { + return Err(ProductionShaError::LengthMismatch { + label: "row sumcheck terminal rounds", + got: messages.len(), + expected: challenges.len(), + }); + } + + let mut expected = proof.claimed_sums()[0].clone(); + for (message, challenge) in messages.iter().zip(challenges) { + let tail = &message.0.tail_evaluations; + if tail.len() != *degree { + return Err(ProductionShaError::LengthMismatch { + label: "row sumcheck terminal degree", + got: tail.len(), + expected: *degree, + }); + } + let constant = match tail.first() { + Some(p1) => expected.clone() - p1, + None => expected.clone(), + }; + let mut evaluations = Vec::with_capacity(tail.len() + 1); + evaluations.push(constant); + evaluations.extend_from_slice(tail); + expected = NatEvaluatedPoly::new(evaluations).evaluate_at_point(challenge)?; + } + + Ok(expected) +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_expression_folded_row_sumcheck_with_output_and_vectors( + transcript: &mut impl Transcript, + trace: &ProjectedTrace, + public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_ic_eq_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result<(MultiDegreeSumcheckProof, FoldedRowSumcheckOutput), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + #[cfg(not(debug_assertions))] + let _ = r_ic; + + let group = tracing::info_span!( + target: "zinc_protocol::production_sha", + "row_sumcheck_build_group", + side = "prove", + phase = "row_sumcheck_build_group", + ) + .in_scope(|| { + build_production_sha_row_expression_sumcheck_group_with_vectors( + trace, + public, + r_ic_eq_weights, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + field_cfg, + ) + })?; + let (proof, states) = 
tracing::info_span!( + target: "zinc_protocol::production_sha", + "row_sumcheck_prove_core", + side = "prove", + phase = "row_sumcheck_prove_core", + ) + .in_scope(|| { + MultiDegreeSumcheck::prove_as_subprotocol(transcript, vec![group], SHA_ROW_VARS, field_cfg) + }); + let r_star = states + .first() + .ok_or(ProductionShaError::LengthMismatch { + label: "folded row states", + got: 0, + expected: 1, + })? + .randomness + .clone(); + let r_star_eq_weights = build_eq_x_r_vec(&r_star, field_cfg)?; + let a = a_powers.get(1).ok_or(ProductionShaError::LengthMismatch { + label: "a powers", + got: a_powers.len(), + expected: 2, + })?; + let endpoint_evals = tracing::info_span!( + target: "zinc_protocol::production_sha", + "row_sumcheck_endpoint_evals", + side = "prove", + phase = "row_sumcheck_endpoint_evals", + ) + .in_scope(|| { + build_sha_endpoint_evals_from_trace_with_row_weights( + trace, + &r_star_eq_weights, + a, + field_cfg, + ) + })?; + let terminal_value = tracing::info_span!( + target: "zinc_protocol::production_sha", + "row_sumcheck_terminal", + side = "prove", + phase = "row_sumcheck_terminal", + ) + .in_scope(|| row_sumcheck_terminal_from_proof(&proof, &r_star))?; + #[cfg(debug_assertions)] + { + let reconstructed_terminal = tracing::info_span!( + target: "zinc_protocol::production_sha", + "row_sumcheck_terminal_debug", + side = "prove", + phase = "row_sumcheck_terminal_debug", + ) + .in_scope(|| { + reconstruct_folded_row_terminal_from_endpoints_with_vectors( + &endpoint_evals, + public, + r_ic, + &r_star, + &r_star_eq_weights, + a_powers, + lambda_powers, + booleanity_weights, + booleanity_sources, + field_cfg, + ) + })?; + if terminal_value != reconstructed_terminal { + return Err(ProductionShaError::RowSumcheckTerminalMismatch); + } + } + Ok(( + proof, + FoldedRowSumcheckOutput { + r_star, + r_star_eq_weights, + terminal_value, + endpoint_evals: Some(endpoint_evals), + }, + )) +} + +pub fn verify_folded_row_sumcheck( + transcript: &mut impl Transcript, + 
proof: &MultiDegreeSumcheckProof, + post_sumfold_claim: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero, + F::Modulus: Transcribable, +{ + require_single_sumcheck_group(proof, "folded row sumcheck")?; + for &degree in proof.degrees() { + if degree > 3 { + return Err(ProductionShaError::RowSumcheckDegreeTooHigh { degree }); + } + } + let Some(claimed_sum) = proof.claimed_sums().first() else { + return Err(ProductionShaError::LengthMismatch { + label: "folded row claimed sums", + got: 0, + expected: 1, + }); + }; + verify_folded_row_sumcheck_claim(claimed_sum, post_sumfold_claim)?; + let subclaims = + MultiDegreeSumcheck::verify_as_subprotocol(transcript, SHA_ROW_VARS, proof, field_cfg)?; + let r_star = subclaims.point().to_vec(); + let r_star_eq_weights = build_eq_x_r_vec(&r_star, field_cfg)?; + Ok(FoldedRowSumcheckOutput { + r_star, + r_star_eq_weights, + terminal_value: subclaims.expected_evaluations()[0].clone(), + endpoint_evals: None, + }) +} + +pub fn verify_folded_row_terminal_value( + output: &FoldedRowSumcheckOutput, + reconstructed_terminal: &F, +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + if &output.terminal_value != reconstructed_terminal { + return Err(ProductionShaError::RowSumcheckTerminalMismatch); + } + Ok(()) +} + +pub fn production_sha_endpoint_word_sources() -> Vec<(ShaWordCol, usize)> { + let mut sources = Vec::new(); + let mut push = |col, shift| { + if !sources.contains(&(col, shift)) { + sources.push((col, shift)); + } + }; + + for col in ShaWordCol::ALL { + push(col, 0); + } + for (col, shifts) in [ + (ShaWordCol::A, &[1usize, 2, 4][..]), + (ShaWordCol::E, &[1usize, 2, 4][..]), + (ShaWordCol::Sigma0, &[3usize][..]), + (ShaWordCol::Sigma1, &[3usize][..]), + (ShaWordCol::W, &[9usize, 16][..]), + (ShaWordCol::SmallSigma0, &[1usize][..]), + 
(ShaWordCol::SmallSigma1, &[14usize][..]), + (ShaWordCol::Uef, &[2usize, 3][..]), + (ShaWordCol::UNegEg, &[2usize, 3][..]), + (ShaWordCol::Maj, &[2usize, 3][..]), + ] { + for &shift in shifts { + push(col, shift); + } + } + sources +} + +pub fn production_sha_endpoint_int_sources() -> Vec { + ShaIntCol::ALL.to_vec() +} + +pub fn production_sha_multipoint_layout() -> ShaMultipointLayout { + let mut sources = Vec::new(); + let mut push_source = |source| { + if !sources.contains(&source) { + sources.push(source); + } + }; + + for (col, _) in production_sha_endpoint_word_sources() { + for bit in 0..32 { + push_source(ShaMpSource::WordBit { col, bit }); + } + } + for col in production_sha_endpoint_int_sources() { + push_source(ShaMpSource::Int { col }); + } + + let mut shifts = Vec::new(); + let mut push_shift = |shift| { + if !shifts.contains(&shift) { + shifts.push(shift); + } + }; + for (col, shift) in production_sha_endpoint_word_sources() { + if shift == 0 { + continue; + } + for bit in 0..32 { + push_shift(ShaMpShiftSource::WordBit { col, bit, shift }); + } + } + + ShaMultipointLayout { sources, shifts } +} + +pub fn build_sha_endpoint_evals_from_trace( + trace: &ProjectedTrace, + r_star: &[F], + a: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + build_sha_endpoint_evals_from_trace_with_row_weights(trace, &row_weights, a, field_cfg) +} + +pub fn build_sha_endpoint_evals_from_trace_with_row_weights( + trace: &ProjectedTrace, + row_weights: &[F], + a: &F, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + #[cfg(debug_assertions)] + validate_projected_trace(trace)?; + let mut sources = Vec::new(); + for (col, shift) in 
production_sha_endpoint_word_sources() { + let bits = + sha_endpoint_word_bits_with_row_weights(trace, col, shift, row_weights, field_cfg)?; + sources.push(ShaSourceEndpointEval { + col, + shift, + scalarized: scalarize_sha_endpoint_bits(&bits, a, field_cfg), + bits, + }); + } + let mut int_sources = Vec::new(); + for col in production_sha_endpoint_int_sources() { + int_sources.push(ShaIntEndpointEval { + col, + scalar: sha_endpoint_int_with_row_weights(trace, col, row_weights, field_cfg)?, + }); + } + Ok(ShaEndpointEvals { + sources, + int_sources, + }) +} + +fn sha_endpoint_word_bits_with_row_weights( + trace: &ProjectedTrace, + col: ShaWordCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result<[F; SHA_WORD_BITS], ProductionShaError> +where + F: DelayedFieldProductSum, +{ + let active_len = SHA_ROW_COUNT.saturating_sub(shift); + if active_len == 0 { + return Ok(std::array::from_fn(|_| F::zero_with_cfg(field_cfg))); + } + let weights = &row_weights[..active_len]; + let values_start = shift; + let values_end = shift + active_len; + let bits = cfg_into_iter!(0..SHA_WORD_BITS) + .map(|bit| { + let table_idx = bit_slice_index(col.index(), bit, SHA_WORD_BITS); + let column = + trace + .bit_slices + .get(table_idx) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA endpoint bit column", + got: table_idx, + expected: trace.bit_slices.len(), + })?; + FieldFieldInnerProduct::inner_product::( + weights, + &column.evaluations[values_start..values_end], + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "SHA endpoint bit row-weight dot product failed", + ) + }) + }) + .collect::, _>>()?; + bits.try_into() + .map_err(|bits: Vec| ProductionShaError::LengthMismatch { + label: "SHA endpoint word bits", + got: bits.len(), + expected: SHA_WORD_BITS, + }) +} + +fn sha_endpoint_int_with_row_weights( + trace: &ProjectedTrace, + col: ShaIntCol, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result> 
+where + F: DelayedFieldProductSum, +{ + let column = trace + .int_columns + .get(col.index()) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA endpoint int column", + got: col.index(), + expected: trace.int_columns.len(), + })?; + FieldFieldInnerProduct::inner_product::( + row_weights, + &column.evaluations, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "SHA endpoint int row-weight dot product failed", + ) + }) +} + +pub fn validate_sha_endpoint_layout( + endpoint_evals: &ShaEndpointEvals, +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + let word_sources = production_sha_endpoint_word_sources(); + if endpoint_evals.sources.len() != word_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "SHA endpoint word-source count", + got: endpoint_evals.sources.len(), + expected: word_sources.len(), + }); + } + for (got, expected) in endpoint_evals.sources.iter().zip(word_sources.iter()) { + if (got.col, got.shift) != *expected { + return Err(ProductionShaError::NonCanonicalProofObject( + "SHA endpoint word sources are not in canonical order", + )); + } + } + + let int_sources = production_sha_endpoint_int_sources(); + if endpoint_evals.int_sources.len() != int_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "SHA endpoint int-source count", + got: endpoint_evals.int_sources.len(), + expected: int_sources.len(), + }); + } + for (got, expected) in endpoint_evals.int_sources.iter().zip(int_sources.iter()) { + if got.col != *expected { + return Err(ProductionShaError::NonCanonicalProofObject( + "SHA endpoint int sources are not in canonical order", + )); + } + } + Ok(()) +} + +pub fn prove_sha_endpoint_multipoint( + transcript: &mut impl Transcript, + folded_trace: &ProjectedTrace, + folded_public: &ProjectedPublic, + endpoint_evals: &ShaEndpointEvals, + r_star: &[F], + field_cfg: &F::Config, +) -> Result<(MultipointEvalProof, Vec), 
ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + prove_sha_endpoint_multipoint_with_row_weights( + transcript, + folded_trace, + folded_public, + endpoint_evals, + r_star, + &row_weights, + field_cfg, + ) +} + +pub fn prove_sha_endpoint_multipoint_with_row_weights( + transcript: &mut impl Transcript, + folded_trace: &ProjectedTrace, + folded_public: &ProjectedPublic, + endpoint_evals: &ShaEndpointEvals, + r_star: &[F], + row_weights: &[F], + field_cfg: &F::Config, +) -> Result<(MultipointEvalProof, Vec), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + validate_sha_endpoint_layout(endpoint_evals)?; + let layout = tracing::info_span!( + target: "zinc_protocol::production_sha", + "multipoint_layout", + side = "prove", + phase = "multipoint_layout", + ) + .in_scope(production_sha_multipoint_layout); + let trace_mles = tracing::info_span!( + target: "zinc_protocol::production_sha", + "multipoint_trace_mles", + side = "prove", + phase = "multipoint_trace_mles", + sources = layout.sources.len(), + ) + .in_scope(|| sha_multipoint_trace_mles(folded_trace, folded_public, &layout, field_cfg))?; + let up_evals = tracing::info_span!( + target: "zinc_protocol::production_sha", + "multipoint_up_evals", + side = "prove", + phase = "multipoint_up_evals", + sources = layout.sources.len(), + ) + .in_scope(|| { + sha_multipoint_up_evals_with_row_weights( + endpoint_evals, + folded_public, + 
row_weights, + &layout, + field_cfg, + ) + })?; + let (shift_specs, down_evals) = tracing::info_span!( + target: "zinc_protocol::production_sha", + "multipoint_down_evals", + side = "prove", + phase = "multipoint_down_evals", + shifts = layout.shifts.len(), + ) + .in_scope(|| { + sha_multipoint_shift_specs_and_down_evals_with_row_weights( + endpoint_evals, + folded_public, + row_weights, + &layout, + field_cfg, + ) + })?; + tracing::info_span!( + target: "zinc_protocol::production_sha", + "multipoint_sumcheck", + side = "prove", + phase = "multipoint_sumcheck", + sources = trace_mles.len(), + shifts = shift_specs.len(), + ) + .in_scope(|| { + prove_multipoint_reduction( + transcript, + &trace_mles, + r_star, + &up_evals, + &down_evals, + &shift_specs, + field_cfg, + ) + .map_err(ProductionShaError::from) + }) +} + +pub fn verify_sha_endpoint_multipoint( + transcript: &mut impl Transcript, + proof: &MultipointEvalProof, + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_star: &[F], + field_cfg: &F::Config, +) -> Result<(MultipointSubclaim, Vec), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + validate_sha_endpoint_layout(endpoint_evals)?; + let layout = production_sha_multipoint_layout(); + let up_evals = + sha_multipoint_up_evals(endpoint_evals, folded_public, r_star, &layout, field_cfg)?; + let (shift_specs, down_evals) = sha_multipoint_shift_specs_and_down_evals( + endpoint_evals, + folded_public, + r_star, + &layout, + field_cfg, + )?; + let subclaim = verify_multipoint_reduction( + transcript, + proof.clone(), + r_star, + &up_evals, + &down_evals, + &shift_specs, + SHA_ROW_VARS, + field_cfg, + )?; + Ok((subclaim, shift_specs)) +} + +pub fn verify_sha_endpoint_multipoint_open_evals( + subclaim: &MultipointSubclaim, + open_evals: &[F], + shift_specs: 
&[ShiftSpec], + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + Send + + Sync + + 'static, + F::Inner: Transcribable + Zero + Default + Send + Sync, + F::Modulus: Transcribable, +{ + Ok(MultipointEval::verify_subclaim( + subclaim, + open_evals, + shift_specs, + field_cfg, + )?) +} + +#[allow(clippy::too_many_arguments)] +pub fn reconstruct_folded_row_terminal_from_endpoints( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_star: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if r_star.len() != SHA_ROW_VARS { + return Err(ProductionShaError::LengthMismatch { + label: "r_star", + got: r_star.len(), + expected: SHA_ROW_VARS, + }); + } + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + reconstruct_folded_row_terminal_from_endpoints_with_row_weights( + endpoint_evals, + folded_public, + r_ic, + r_star, + &row_weights, + a, + lambda, + rho, + xi, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn reconstruct_folded_row_terminal_from_endpoints_with_row_weights( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_star: &[F], + row_weights: &[F], + a: &F, + lambda: &F, + rho: &F, + xi: &F, + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if r_star.len() != SHA_ROW_VARS { + return Err(ProductionShaError::LengthMismatch { + label: "r_star", + got: r_star.len(), + expected: SHA_ROW_VARS, + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let a_powers = 
build_sha_residual_eval_powers(a, field_cfg); + let lambda_powers = build_sha_lambda_powers(lambda, field_cfg); + let booleanity_weights = build_booleanity_weights(rho, xi, booleanity_sources.len(), field_cfg); + reconstruct_folded_row_terminal_from_endpoints_with_vectors( + endpoint_evals, + folded_public, + r_ic, + r_star, + row_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + booleanity_sources, + field_cfg, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn reconstruct_folded_row_terminal_from_endpoints_with_vectors( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_ic: &[F; SHA_ROW_VARS], + r_star: &[F], + row_weights: &[F], + a_powers: &[F], + lambda_powers: &[F], + booleanity_weights: &[F], + booleanity_sources: &[ShaBooleanitySource], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + if r_star.len() != SHA_ROW_VARS { + return Err(ProductionShaError::LengthMismatch { + label: "r_star", + got: r_star.len(), + expected: SHA_ROW_VARS, + }); + } + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + if a_powers.len() < SHA_IDEAL_EVAL_POWER_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "a powers", + got: a_powers.len(), + expected: SHA_IDEAL_EVAL_POWER_COUNT, + }); + } + if lambda_powers.len() != NUM_SHA_RESIDUAL_FAMILIES { + return Err(ProductionShaError::LengthMismatch { + label: "lambda powers", + got: lambda_powers.len(), + expected: NUM_SHA_RESIDUAL_FAMILIES, + }); + } + if booleanity_weights.len() != booleanity_sources.len() { + return Err(ProductionShaError::LengthMismatch { + label: "booleanity weights", + got: booleanity_weights.len(), + expected: booleanity_sources.len(), + }); + } + validate_sha_endpoint_layout(endpoint_evals)?; + verify_endpoint_scalarization_with_powers(endpoint_evals, a_powers, field_cfg)?; + + let residuals = 
residual_polys_from_endpoints_with_row_weights( + endpoint_evals, + folded_public, + row_weights, + field_cfg, + )?; + let linear = lambda_weighted_sha_residual_polys_at_powers( + &residuals, + a_powers, + lambda_powers, + field_cfg, + )?; + + let mut bool_terms = Vec::with_capacity(booleanity_sources.len()); + for source in booleanity_sources { + let d = booleanity_endpoint_value(endpoint_evals, source, field_cfg)?; + bool_terms.push(d.clone() * (d - F::one_with_cfg(field_cfg))); + } + let bool_sum = FieldFieldInnerProduct::inner_product::( + booleanity_weights, + &bool_terms, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject("endpoint booleanity dot product failed") + })?; + + let row_weight = eq_eval(r_ic, r_star, F::one_with_cfg(field_cfg))?; + Ok(row_weight * (linear + bool_sum)) +} + +pub fn verify_endpoint_scalarization( + endpoint_evals: &ShaEndpointEvals, + a: &F, + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: DelayedFieldProductSum, +{ + let powers = zinc_utils::powers(a.clone(), F::one_with_cfg(field_cfg), 32); + verify_endpoint_scalarization_with_powers(endpoint_evals, &powers, field_cfg) +} + +pub fn verify_endpoint_scalarization_with_powers( + endpoint_evals: &ShaEndpointEvals, + a_powers: &[F], + field_cfg: &F::Config, +) -> Result<(), ProductionShaError> +where + F: DelayedFieldProductSum, +{ + if a_powers.len() < SHA_WORD_BITS { + return Err(ProductionShaError::LengthMismatch { + label: "endpoint scalarization powers", + got: a_powers.len(), + expected: SHA_WORD_BITS, + }); + } + for source in &endpoint_evals.sources { + let recombined = zinc_utils::inner_product::FieldFieldInnerProduct::inner_product::< + UNCHECKED, + >( + &source.bits, + &a_powers[..SHA_WORD_BITS], + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject("endpoint scalarization dot product failed") + })?; + if recombined != source.scalarized { + return 
Err(ProductionShaError::EndpointScalarizationMismatch { + col: source.col, + shift: source.shift, + }); + } + } + Ok(()) +} + +pub fn reconstruct_virtual_ch_maj_endpoint( + endpoint_evals: &ShaEndpointEvals, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let bits = |col, shift| source_bits(endpoint_evals, col, shift); + let e0 = bits(ShaWordCol::E, 0)?; + let e1 = bits(ShaWordCol::E, 1)?; + let e2 = bits(ShaWordCol::E, 2)?; + let a0 = bits(ShaWordCol::A, 0)?; + let a1 = bits(ShaWordCol::A, 1)?; + let a2 = bits(ShaWordCol::A, 2)?; + let uef2 = bits(ShaWordCol::Uef, 2)?; + let uneg_eg2 = bits(ShaWordCol::UNegEg, 2)?; + let ch2_comp0 = bits(ShaWordCol::Ch2Comp, 0)?; + let maj2 = bits(ShaWordCol::Maj, 2)?; + let maj_comp0 = bits(ShaWordCol::MajComp, 0)?; + + Ok(VirtualChMajEndpoint { + ch1: std::array::from_fn(|idx| e2[idx].clone() + &e1[idx] - two.clone() * &uef2[idx]), + ch2: std::array::from_fn(|idx| { + e2[idx].clone() - &e0[idx] + + two.clone() * &uneg_eg2[idx] + + two.clone() * &ch2_comp0[idx] + }), + maj: std::array::from_fn(|idx| { + a0[idx].clone() + &a1[idx] + &a2[idx] + - two.clone() * &maj2[idx] + - two.clone() * &maj_comp0[idx] + }), + }) +} + +pub fn booleanity_endpoint_value( + endpoint_evals: &ShaEndpointEvals, + source: &ShaBooleanitySource, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + match source { + ShaBooleanitySource::WordBit { col, bit } => Ok(source_bits(endpoint_evals, *col, 0)? + .get(*bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "endpoint bit", + got: *bit, + expected: 32, + })?), + ShaBooleanitySource::VirtualCh1 { bit } => Ok(reconstruct_virtual_ch_maj_endpoint( + endpoint_evals, + field_cfg, + )? 
+ .ch1 + .get(*bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "virtual Ch1 bit", + got: *bit, + expected: 32, + })?), + ShaBooleanitySource::VirtualCh2 { bit } => Ok(reconstruct_virtual_ch_maj_endpoint( + endpoint_evals, + field_cfg, + )? + .ch2 + .get(*bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "virtual Ch2 bit", + got: *bit, + expected: 32, + })?), + ShaBooleanitySource::VirtualMaj { bit } => Ok(reconstruct_virtual_ch_maj_endpoint( + endpoint_evals, + field_cfg, + )? + .maj + .get(*bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "virtual Maj bit", + got: *bit, + expected: 32, + })?), + } +} + +fn sha_multipoint_trace_mles( + folded_trace: &ProjectedTrace, + folded_public: &ProjectedPublic, + layout: &ShaMultipointLayout, + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: InnerTransparentField + Sync, + F::Inner: Send + Sync, +{ + let zero_inner = F::zero_with_cfg(field_cfg).inner().clone(); + cfg_iter!(layout.sources) + .map(|source| { + Ok(DenseMultilinearExtension::from_evaluations_vec( + SHA_ROW_VARS, + sha_mp_source_column_values(folded_trace, folded_public, *source)? 
+ .iter() + .map(|value| value.inner().clone()) + .collect(), + zero_inner.clone(), + )) + }) + .collect() +} + +fn sha_multipoint_up_evals( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_star: &[F], + layout: &ShaMultipointLayout, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + sha_multipoint_up_evals_with_row_weights( + endpoint_evals, + folded_public, + &row_weights, + layout, + field_cfg, + ) +} + +fn sha_multipoint_up_evals_with_row_weights( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + row_weights: &[F], + layout: &ShaMultipointLayout, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + layout + .sources + .iter() + .map(|source| { + sha_mp_source_endpoint_value_with_row_weights( + endpoint_evals, + folded_public, + row_weights, + *source, + field_cfg, + ) + }) + .collect() +} + +fn sha_multipoint_shift_specs_and_down_evals( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + r_star: &[F], + layout: &ShaMultipointLayout, + field_cfg: &F::Config, +) -> Result<(Vec, Vec), ProductionShaError> +where + F: PrimeField, +{ + let row_weights = build_eq_x_r_vec(r_star, field_cfg)?; + sha_multipoint_shift_specs_and_down_evals_with_row_weights( + endpoint_evals, + folded_public, + &row_weights, + layout, + field_cfg, + ) +} + +fn sha_multipoint_shift_specs_and_down_evals_with_row_weights( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + row_weights: &[F], + layout: &ShaMultipointLayout, + field_cfg: &F::Config, +) -> Result<(Vec, Vec), ProductionShaError> +where + F: PrimeField, +{ + let mut specs = Vec::with_capacity(layout.shifts.len()); + let mut evals = Vec::with_capacity(layout.shifts.len()); + for shift in &layout.shifts { + let (source, amount, value) = match *shift { + ShaMpShiftSource::WordBit { col, bit, shift } => ( + 
ShaMpSource::WordBit { col, bit }, + shift, + source_bits(endpoint_evals, col, shift)? + .get(bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "shifted word bit", + got: bit, + expected: 32, + })?, + ), + ShaMpShiftSource::Public { col, shift } => ( + ShaMpSource::Public { col }, + shift, + sha_public_at_point_with_weights( + folded_public, + col, + shift, + row_weights, + field_cfg, + )?, + ), + }; + let source_idx = layout + .sources + .iter() + .position(|candidate| *candidate == source) + .ok_or(ProductionShaError::LengthMismatch { + label: "multipoint shift source", + got: 0, + expected: 1, + })?; + specs.push(ShiftSpec::new(source_idx, amount)); + evals.push(value); + } + Ok((specs, evals)) +} + +fn sha_mp_source_column_values<'a, F>( + folded_trace: &'a ProjectedTrace, + folded_public: &'a ProjectedPublic, + source: ShaMpSource, +) -> Result<&'a [F], ProductionShaError> +where + F: PrimeField, +{ + let values = + match source { + ShaMpSource::WordBit { col, bit } => folded_trace + .bit_slices + .get(bit_slice_index(col.index(), bit, SHA_WORD_BITS)) + .ok_or(ProductionShaError::LengthMismatch { + label: "multipoint word bit source", + got: bit_slice_index(col.index(), bit, SHA_WORD_BITS), + expected: folded_trace.bit_slices.len(), + })?, + ShaMpSource::Int { col } => folded_trace.int_columns.get(col.index()).ok_or( + ProductionShaError::LengthMismatch { + label: "multipoint int source", + got: col.index(), + expected: folded_trace.int_columns.len(), + }, + )?, + ShaMpSource::Public { col } => folded_public.columns.get(col.index()).ok_or( + ProductionShaError::LengthMismatch { + label: "multipoint public source", + got: col.index(), + expected: folded_public.columns.len(), + }, + )?, + }; + if values.num_vars != SHA_ROW_VARS || values.evaluations.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "multipoint source rows", + got: values.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + 
Ok(&values.evaluations) +} + +fn sha_mp_source_endpoint_value_with_row_weights( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + row_weights: &[F], + source: ShaMpSource, + field_cfg: &F::Config, +) -> Result> +where + F: PrimeField, +{ + match source { + ShaMpSource::WordBit { col, bit } => source_bits(endpoint_evals, col, 0)? + .get(bit) + .cloned() + .ok_or(ProductionShaError::LengthMismatch { + label: "endpoint word bit", + got: bit, + expected: 32, + }), + ShaMpSource::Int { col } => endpoint_evals + .int_sources + .iter() + .find(|source| source.col == col) + .map(|source| source.scalar.clone()) + .ok_or(ProductionShaError::LengthMismatch { + label: "endpoint int source", + got: endpoint_evals.int_sources.len(), + expected: ShaIntCol::COUNT, + }), + ShaMpSource::Public { col } => Ok(sha_public_at_point_with_weights( + folded_public, + col, + 0, + row_weights, + field_cfg, + )?), + } +} + +fn residual_polys_from_endpoints_with_row_weights( + endpoint_evals: &ShaEndpointEvals, + folded_public: &ProjectedPublic, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result>, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + if row_weights.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "row weights", + got: row_weights.len(), + expected: SHA_ROW_COUNT, + }); + } + let one = F::one_with_cfg(field_cfg); + let two = one.clone() + &one; + let rho_sig0 = sparse_endpoint_poly::(&[10, 19, 30], field_cfg); + let rho_sig1 = sparse_endpoint_poly::(&[7, 21, 26], field_cfg); + + let a = endpoint_word_poly(endpoint_evals, ShaWordCol::A, 0, field_cfg)?; + let e = endpoint_word_poly(endpoint_evals, ShaWordCol::E, 0, field_cfg)?; + let sigma0 = endpoint_word_poly(endpoint_evals, ShaWordCol::Sigma0, 0, field_cfg)?; + let sigma1 = endpoint_word_poly(endpoint_evals, ShaWordCol::Sigma1, 0, field_cfg)?; + let w = endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?; + let small_sigma0 = 
endpoint_word_poly(endpoint_evals, ShaWordCol::SmallSigma0, 0, field_cfg)?; + let small_sigma1 = endpoint_word_poly(endpoint_evals, ShaWordCol::SmallSigma1, 0, field_cfg)?; + let ov_sigma0 = endpoint_word_poly(endpoint_evals, ShaWordCol::OvSigma0, 0, field_cfg)?; + let ov_sigma1 = endpoint_word_poly(endpoint_evals, ShaWordCol::OvSigma1, 0, field_cfg)?; + let ov_small_sigma0 = + endpoint_word_poly(endpoint_evals, ShaWordCol::OvSmallSigma0, 0, field_cfg)?; + let ov_small_sigma1 = + endpoint_word_poly(endpoint_evals, ShaWordCol::OvSmallSigma1, 0, field_cfg)?; + + let r0 = (&a * &rho_sig0) - &sigma0 - &scale_endpoint_poly(&ov_sigma0, &two); + let r1 = (&e * &rho_sig1) - &sigma1 - &scale_endpoint_poly(&ov_sigma1, &two); + let r2 = endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.rot_c(25) + + &endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.rot_c(14) + + &endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.shift_r_c(3) + - &small_sigma0 + - &scale_endpoint_poly(&ov_small_sigma0, &two); + let r3 = endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.rot_c(15) + + &endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.rot_c(13) + + &endpoint_word_poly(endpoint_evals, ShaWordCol::W, 0, field_cfg)?.shift_r_c(10) + - &small_sigma1 + - &scale_endpoint_poly(&ov_small_sigma1, &two); + + let mu_w = endpoint_mu_contribution(endpoint_evals, 0, 2, field_cfg)?; + let mu_a = endpoint_mu_contribution(endpoint_evals, 2, 5, field_cfg)?; + let mu_e = endpoint_mu_contribution(endpoint_evals, 5, 8, field_cfg)?; + let mu_ff_a = endpoint_mu_contribution(endpoint_evals, 8, 9, field_cfg)?; + let mu_ff_e = endpoint_mu_contribution(endpoint_evals, 9, 10, field_cfg)?; + + let r4 = endpoint_word_poly(endpoint_evals, ShaWordCol::W, 16, field_cfg)? + - &w + - &endpoint_word_poly(endpoint_evals, ShaWordCol::SmallSigma0, 1, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::W, 9, field_cfg)? 
+ - &endpoint_word_poly(endpoint_evals, ShaWordCol::SmallSigma1, 14, field_cfg)? + + &mu_w + + &endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompSchedule, field_cfg)?; + + let r5 = endpoint_word_poly(endpoint_evals, ShaWordCol::A, 4, field_cfg)? + - &e + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Sigma1, 3, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Uef, 3, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::UNegEg, 3, field_cfg)? + - &endpoint_public_const_poly_with_row_weights( + folded_public, + ShaPublicCol::K, + 3, + row_weights, + field_cfg, + )? + - &w + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Sigma0, 3, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Maj, 3, field_cfg)? + + &mu_a + + &endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompUpdateA, field_cfg)?; + + let r6 = endpoint_word_poly(endpoint_evals, ShaWordCol::E, 4, field_cfg)? + - &a + - &e + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Sigma1, 3, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::Uef, 3, field_cfg)? + - &endpoint_word_poly(endpoint_evals, ShaWordCol::UNegEg, 3, field_cfg)? + - &endpoint_public_const_poly_with_row_weights( + folded_public, + ShaPublicCol::K, + 3, + row_weights, + field_cfg, + )? 
+ - &w + + &mu_e + + &endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompUpdateE, field_cfg)?; + + let s_init = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SInit, + 0, + row_weights, + field_cfg, + )?; + let s_msg = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SMsg, + 0, + row_weights, + field_cfg, + )?; + let s_sched = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SSched, + 0, + row_weights, + field_cfg, + )?; + let s_upd = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SUpd, + 0, + row_weights, + field_cfg, + )?; + let s_ff = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SFf, + 0, + row_weights, + field_cfg, + )?; + let s_out = sha_public_at_point_with_weights( + folded_public, + ShaPublicCol::SOut, + 0, + row_weights, + field_cfg, + )?; + + let r7 = scale_endpoint_poly( + &(a.clone() + - &endpoint_public_word_or_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PAIn, + row_weights, + field_cfg, + )?), + &s_init, + ) + &scale_endpoint_poly( + &(a.clone() + - &endpoint_public_word_or_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PAOut, + row_weights, + field_cfg, + )?), + &s_out, + ); + let r8 = scale_endpoint_poly( + &(e.clone() + - &endpoint_public_word_or_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PEIn, + row_weights, + field_cfg, + )?), + &s_init, + ) + &scale_endpoint_poly( + &(e.clone() + - &endpoint_public_word_or_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PEOut, + row_weights, + field_cfg, + )?), + &s_out, + ); + + let r9 = endpoint_word_poly(endpoint_evals, ShaWordCol::A, 4, field_cfg)? + - &a + - &endpoint_public_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PAIn, + 0, + row_weights, + field_cfg, + )? 
+ + &mu_ff_a + + &endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompFeedForwardA, field_cfg)?; + let r10 = endpoint_word_poly(endpoint_evals, ShaWordCol::E, 4, field_cfg)? + - &e + - &endpoint_public_const_poly_with_row_weights( + folded_public, + ShaPublicCol::PEIn, + 0, + row_weights, + field_cfg, + )? + + &mu_ff_e + + &endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompFeedForwardE, field_cfg)?; + let r11 = scale_endpoint_poly( + &(w - &endpoint_public_word_or_const_poly_with_row_weights( + folded_public, + ShaPublicCol::Message, + row_weights, + field_cfg, + )?), + &s_msg, + ); + + let comp_schedule = + endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompSchedule, field_cfg)?; + let comp_update_a = endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompUpdateA, field_cfg)?; + let comp_update_e = endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompUpdateE, field_cfg)?; + let comp_ff_a = + endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompFeedForwardA, field_cfg)?; + let comp_ff_e = + endpoint_int_const_poly(endpoint_evals, ShaIntCol::CompFeedForwardE, field_cfg)?; + + let r12 = scale_endpoint_poly(&comp_schedule, &s_sched); + let r13 = scale_endpoint_poly(&comp_update_a, &s_upd); + let r14 = scale_endpoint_poly(&comp_update_e, &s_upd); + let r15 = scale_endpoint_poly(&comp_ff_a, &s_ff); + let r16 = scale_endpoint_poly(&comp_ff_e, &s_ff); + let r17 = endpoint_word_poly(endpoint_evals, ShaWordCol::MuPacked, 0, field_cfg)?.shift_r_c(10); + + let mut residuals = vec![ + r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17, + ]; + residuals.iter_mut().for_each(DynamicPolynomialF::trim); + Ok(residuals) +} + +fn endpoint_word_poly( + endpoint_evals: &ShaEndpointEvals, + col: ShaWordCol, + shift: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let mut coeffs = source_bits(endpoint_evals, col, shift)?.to_vec(); + coeffs.resize(32, F::zero_with_cfg(field_cfg)); + 
Ok(DynamicPolynomialF::new_trimmed(coeffs)) +} + +fn endpoint_int_const_poly( + endpoint_evals: &ShaEndpointEvals, + col: ShaIntCol, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let value = endpoint_evals + .int_sources + .iter() + .find(|source| source.col == col) + .map(|source| source.scalar.clone()) + .ok_or(ProductionShaError::LengthMismatch { + label: "endpoint int source", + got: endpoint_evals.int_sources.len(), + expected: ShaIntCol::COUNT, + })?; + Ok(endpoint_const_poly(value, field_cfg)) +} + +fn endpoint_public_const_poly_with_row_weights( + folded_public: &ProjectedPublic, + col: ShaPublicCol, + shift: usize, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + Ok(endpoint_const_poly( + sha_public_at_point_with_weights(folded_public, col, shift, row_weights, field_cfg)?, + field_cfg, + )) +} + +fn endpoint_public_word_or_const_poly_with_row_weights( + folded_public: &ProjectedPublic, + col: ShaPublicCol, + row_weights: &[F], + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: DelayedFieldProductSum, +{ + let Some(col_idx) = public_word_col_index(col) else { + return endpoint_public_const_poly_with_row_weights( + folded_public, + col, + 0, + row_weights, + field_cfg, + ); + }; + let bit_slices = + folded_public + .bit_slices + .as_ref() + .ok_or(ProductionShaError::NonCanonicalProofObject( + "production SHA public word columns are required", + ))?; + let mut coeffs = Vec::with_capacity(SHA_WORD_BITS); + for bit in 0..SHA_WORD_BITS { + let table_idx = bit_slice_index(col_idx, bit, SHA_WORD_BITS); + let bit_column = bit_slices + .get(table_idx) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA public word bit column", + got: table_idx, + expected: bit_slices.len(), + })?; + if bit_column.num_vars != SHA_ROW_VARS || bit_column.evaluations.len() != SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA 
public word row count", + got: bit_column.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + coeffs.push( + FieldFieldInnerProduct::inner_product::( + row_weights, + &bit_column.evaluations, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject( + "SHA public word row-weight dot product failed", + ) + })?, + ); + } + Ok(DynamicPolynomialF::new_trimmed(coeffs)) +} + +fn endpoint_mu_contribution( + endpoint_evals: &ShaEndpointEvals, + low: usize, + high: usize, + field_cfg: &F::Config, +) -> Result, ProductionShaError> +where + F: PrimeField, +{ + let packed = endpoint_word_poly(endpoint_evals, ShaWordCol::MuPacked, 0, field_cfg)? + .shift_r_c(low as u32); + let tail = endpoint_word_poly(endpoint_evals, ShaWordCol::MuPacked, 0, field_cfg)? + .shift_r_c(high as u32); + let low_coeff = endpoint_pow_two(32, field_cfg); + let high_coeff = endpoint_pow_two(32 + high - low, field_cfg); + Ok(scale_endpoint_poly(&packed, &low_coeff) - &scale_endpoint_poly(&tail, &high_coeff)) +} + +fn sparse_endpoint_poly(positions: &[usize], field_cfg: &F::Config) -> DynamicPolynomialF +where + F: PrimeField, +{ + let mut coeffs = vec![F::zero_with_cfg(field_cfg); 32]; + for &pos in positions { + coeffs[pos] = F::one_with_cfg(field_cfg); + } + DynamicPolynomialF::new_trimmed(coeffs) +} + +fn scale_endpoint_poly(poly: &DynamicPolynomialF, scalar: &F) -> DynamicPolynomialF +where + F: PrimeField, +{ + if poly.is_zero() || F::is_zero(scalar) { + return DynamicPolynomialF::ZERO; + } + DynamicPolynomialF::new_trimmed( + poly.coeffs + .iter() + .map(|coeff| coeff.clone() * scalar) + .collect::>(), + ) +} + +fn endpoint_const_poly(value: F, field_cfg: &F::Config) -> DynamicPolynomialF +where + F: PrimeField, +{ + if F::is_zero(&value) { + DynamicPolynomialF::ZERO + } else { + let _ = field_cfg; + DynamicPolynomialF::constant_poly(value) + } +} + +fn endpoint_pow_two(exp: usize, field_cfg: &F::Config) -> F +where + F: PrimeField, +{ + let two = 
F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let mut out = F::one_with_cfg(field_cfg); + for _ in 0..exp { + out *= &two; + } + out +} + +fn source_bits( + endpoint_evals: &ShaEndpointEvals, + col: ShaWordCol, + shift: usize, +) -> Result<&[F; 32], ProductionShaError> +where + F: PrimeField, +{ + endpoint_evals + .sources + .iter() + .find(|source| source.col == col && source.shift == shift) + .map(|source| &source.bits) + .ok_or(ProductionShaError::MissingEndpointEval { col, shift }) +} + +fn sumfold_expected_eval( + beta: &[F], + fresh_targets: &[F], + r_b: &[F], + field_cfg: &F::Config, +) -> Result> +where + F: DelayedFieldProductSum, +{ + let d = eq_eval(beta, r_b, F::one_with_cfg(field_cfg))?; + let weights = build_eq_x_r_vec(r_b, field_cfg)?; + if weights.len() != fresh_targets.len() { + return Err(ProductionShaError::LengthMismatch { + label: "sumfold target weights", + got: weights.len(), + expected: fresh_targets.len(), + }); + } + let claim_at_r = FieldFieldInnerProduct::inner_product::( + &weights, + fresh_targets, + F::zero_with_cfg(field_cfg), + ) + .map_err(|_| { + ProductionShaError::NonCanonicalProofObject("SumFold expected-value dot product failed") + })?; + Ok(d * claim_at_r) +} + +fn require_single_sumcheck_group( + proof: &MultiDegreeSumcheckProof, + label: &'static str, +) -> Result<(), ProductionShaError> +where + F: PrimeField, +{ + let group_count = proof.degrees().len(); + if group_count != 1 { + return Err(ProductionShaError::UnexpectedSumcheckGroupCount { + label, + got: group_count, + }); + } + let claimed_count = proof.claimed_sums().len(); + if claimed_count != 1 { + return Err(ProductionShaError::LengthMismatch { + label: "sumcheck claimed sums", + got: claimed_count, + expected: 1, + }); + } + Ok(()) +} + +#[allow(dead_code)] +fn family_weight_index(family: ShaResidualFamily) -> usize { + family.index() +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{fixed_prime, pcs::AllHyraxPCSTypes}; + use 
ark_ec::{CurveGroup, PrimeGroup}; + use core::fmt::Debug; + use crypto_primitives::{ + FromWithConfig, crypto_bigint_boxed_monty::BoxedMontyField, crypto_bigint_int::Int, + crypto_bigint_uint::Uint, + }; + use zinc_piop::neutron_nova::{ + SHA_ROW_COUNT, SHA_WORD_BITS, expression_folded_row_sum, fold_projected_traces, + }; + use zinc_poly::mle::MultilinearExtensionWithConfig; + use zinc_poly::univariate::{binary::BinaryPolyInnerProduct, dense::DensePolyInnerProduct}; + use zinc_primality::MillerRabin; + use zinc_test_uair::{ + EC_FP_INT_LIMBS, SHA256_INITIAL_STATE, Sha256CompressionSliceUair, + sha256::cols as sha256_cols, sha256_padded_message_blocks, + synthesize_sha256_chain_witnesses, + }; + use zinc_transcript::{Blake3Transcript, traits::Transcript}; + use zinc_utils::inner_product::{MBSInnerProduct, ScalarProduct}; + use zip_plus::{ + code::iprs::{IprsCode, PnttConfigF65537}, + pcs::{ + hyrax::{HyraxBlindingMode, HyraxPCS}, + structs::ZipTypes, + }, + }; + + type F = BoxedMontyField; + type ShaInt = Int; + const TEST_DEGREE_PLUS_ONE: usize = 32; + const TEST_REP: usize = 4; + const TEST_CHECKED: bool = false; + const TEST_FIELD_LIMBS: usize = 4; + const TEST_NUM_COLUMN_OPENINGS: usize = 150; + + fn cfg() -> ::Config { + fixed_prime::secp256k1_field_cfg::>() + } + + fn f(value: u64) -> F { + F::from_with_cfg(value, &cfg()) + } + + #[derive(Debug, Clone)] + struct TestShaBinaryZipTypes; + + impl ZipTypes for TestShaBinaryZipTypes { + const NUM_COLUMN_OPENINGS: usize = TEST_NUM_COLUMN_OPENINGS; + type Eval = BinaryPoly; + type Cw = DensePolynomial; + type Fmod = Uint; + type PrimeTest = MillerRabin; + type Chal = i128; + type Pt = i128; + type CombR = Int<{ EC_FP_INT_LIMBS * 4 }>; + type Comb = DensePolynomial; + type EvalDotChal = BinaryPolyInnerProduct; + type CombDotChal = DensePolyInnerProduct< + Self::CombR, + Self::Chal, + Self::CombR, + MBSInnerProduct, + TEST_DEGREE_PLUS_ONE, + >; + type ArrCombRDotChal = MBSInnerProduct; + } + + #[derive(Debug, Clone)] + 
struct TestShaArbitraryZipTypes; + + impl ZipTypes for TestShaArbitraryZipTypes { + const NUM_COLUMN_OPENINGS: usize = TEST_NUM_COLUMN_OPENINGS; + type Eval = DensePolynomial; + type Cw = DensePolynomial, TEST_DEGREE_PLUS_ONE>; + type Fmod = Uint; + type PrimeTest = MillerRabin; + type Chal = i128; + type Pt = i128; + type CombR = Int<{ EC_FP_INT_LIMBS * 4 }>; + type Comb = DensePolynomial; + type EvalDotChal = DensePolyInnerProduct< + ShaInt, + Self::Chal, + Self::CombR, + MBSInnerProduct, + TEST_DEGREE_PLUS_ONE, + >; + type CombDotChal = DensePolyInnerProduct< + Self::CombR, + Self::Chal, + Self::CombR, + MBSInnerProduct, + TEST_DEGREE_PLUS_ONE, + >; + type ArrCombRDotChal = MBSInnerProduct; + } + + #[derive(Debug, Clone)] + struct TestShaIntZipTypes; + + impl ZipTypes for TestShaIntZipTypes { + const NUM_COLUMN_OPENINGS: usize = TEST_NUM_COLUMN_OPENINGS; + type Eval = ShaInt; + type Cw = Int<6>; + type Fmod = Uint; + type PrimeTest = MillerRabin; + type Chal = i128; + type Pt = i128; + type CombR = Int<{ EC_FP_INT_LIMBS * 4 }>; + type Comb = Self::CombR; + type EvalDotChal = ScalarProduct; + type CombDotChal = ScalarProduct; + type ArrCombRDotChal = MBSInnerProduct; + } + + #[derive(Clone, Debug)] + struct TestShaZincTypes; + + impl ZincTypes for TestShaZincTypes { + type Int = ShaInt; + type Chal = i128; + type Pt = i128; + type Fmod = Uint; + type PrimeTest = MillerRabin; + + type BinaryZt = TestShaBinaryZipTypes; + type ArbitraryZt = TestShaArbitraryZipTypes; + type IntZt = TestShaIntZipTypes; + + type BinaryLc = IprsCode; + type ArbitraryLc = IprsCode; + type IntLc = IprsCode; + } + + fn sha_binary_col<'a>( + public_trace: &'a UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + witness_trace: &'a UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + flat_col: usize, + ) -> Result< + &'a DenseMultilinearExtension>, + ProductionShaError, + > { + if flat_col < sha256_cols::NUM_BIN_PUB { + public_trace + .binary_poly + .get(flat_col) + 
.ok_or(ProductionShaError::LengthMismatch { + label: "SHA binary public source columns", + got: public_trace.binary_poly.len(), + expected: flat_col + 1, + }) + } else { + let witness_col = flat_col - sha256_cols::NUM_BIN_PUB; + witness_trace + .binary_poly + .get(witness_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA binary witness source columns", + got: witness_trace.binary_poly.len(), + expected: witness_col + 1, + }) + } + } + + fn sha_int_col<'a>( + public_trace: &'a UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + witness_trace: &'a UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + flat_col: usize, + ) -> Result<&'a DenseMultilinearExtension, ProductionShaError> { + if flat_col < sha256_cols::NUM_INT_PUB { + public_trace + .int + .get(flat_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA int public source columns", + got: public_trace.int.len(), + expected: flat_col + 1, + }) + } else { + let witness_col = flat_col - sha256_cols::NUM_INT_PUB; + witness_trace + .int + .get(witness_col) + .ok_or(ProductionShaError::LengthMismatch { + label: "SHA int witness source columns", + got: witness_trace.int.len(), + expected: witness_col + 1, + }) + } + } + + fn project_binary_source( + col: &DenseMultilinearExtension>, + field_cfg: &::Config, + ) -> Result>, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label: "SHA binary source rows", + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(col + .evaluations + .iter() + .take(SHA_ROW_COUNT) + .map(|poly| { + poly.iter() + .take(SHA_WORD_BITS) + .map(|bit| { + if bit.into_inner() { + F::one_with_cfg(field_cfg) + } else { + F::zero_with_cfg(field_cfg) + } + }) + .collect() + }) + .collect()) + } + + fn project_int_source( + col: &DenseMultilinearExtension, + field_cfg: &::Config, + ) -> Result, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return 
Err(ProductionShaError::LengthMismatch { + label: "SHA int source rows", + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(col + .evaluations + .iter() + .take(SHA_ROW_COUNT) + .map(|value| F::from_with_cfg(value, field_cfg)) + .collect()) + } + + fn truncate_sha_row_domain( + col: &DenseMultilinearExtension, + label: &'static str, + ) -> Result, ProductionShaError> { + if col.evaluations.len() < SHA_ROW_COUNT { + return Err(ProductionShaError::LengthMismatch { + label, + got: col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + Ok(DenseMultilinearExtension { + evaluations: col.evaluations[..SHA_ROW_COUNT].to_vec(), + num_vars: SHA_ROW_VARS, + }) + } + + fn word_scalar_at_two(bits: &[F], field_cfg: &::Config) -> F { + let two = F::one_with_cfg(field_cfg) + F::one_with_cfg(field_cfg); + let mut power = F::one_with_cfg(field_cfg); + let mut value = F::zero_with_cfg(field_cfg); + for bit in bits { + value += bit.clone() * &power; + power *= &two; + } + value + } + + fn mle_table_from_columns(columns: Vec>) -> MleTable { + columns + .into_iter() + .map(|evaluations| DenseMultilinearExtension { + evaluations, + num_vars: SHA_ROW_VARS, + }) + .collect() + } + + fn flatten_bit_columns(columns: Vec>>) -> MleTable { + let mut flattened = (0..columns.len() * SHA_WORD_BITS) + .map(|_| Vec::with_capacity(SHA_ROW_COUNT)) + .collect::>(); + for (col_idx, rows) in columns.into_iter().enumerate() { + for bits in rows { + for (bit, value) in bits.into_iter().enumerate() { + flattened[bit_slice_index(col_idx, bit, SHA_WORD_BITS)].push(value); + } + } + } + mle_table_from_columns(flattened) + } + + fn scalarize_bit_slices_plain( + bit_slices: &MleTable, + a: &F, + field_cfg: &::Config, + ) -> Result, ProductionShaError> { + let powers = zinc_utils::powers(a.clone(), F::one_with_cfg(field_cfg), SHA_WORD_BITS); + let word_count = bit_slices.len() / SHA_WORD_BITS; + let mut words = Vec::with_capacity(word_count); + for col_idx in 0..word_count { + let mut 
out_col = Vec::with_capacity(SHA_ROW_COUNT); + for row in 0..SHA_ROW_COUNT { + let mut value = F::zero_with_cfg(field_cfg); + for (bit, power) in powers.iter().enumerate() { + let bit_col = &bit_slices[bit_slice_index(col_idx, bit, SHA_WORD_BITS)]; + if bit_col.num_vars != SHA_ROW_VARS + || bit_col.evaluations.len() != SHA_ROW_COUNT + { + return Err(ProductionShaError::LengthMismatch { + label: "SHA scalarized bit-slice rows", + got: bit_col.evaluations.len(), + expected: SHA_ROW_COUNT, + }); + } + value += bit_col.evaluations[row].clone() * power; + } + out_col.push(value); + } + words.push(out_col); + } + Ok(mle_table_from_columns(words)) + } + + fn projected_public_from_sources( + pa_a: &[Vec], + pa_e: &[Vec], + message: &[Vec], + field_cfg: &::Config, + ) -> MleTable { + let mut columns = + vec![vec![F::zero_with_cfg(field_cfg); SHA_ROW_COUNT]; ShaPublicCol::COUNT]; + for row in 0..SHA_ROW_COUNT { + columns[ShaPublicCol::K.index()][row] = production_sha_k_expected(row, field_cfg); + columns[ShaPublicCol::PAIn.index()][row] = word_scalar_at_two(&pa_a[row], field_cfg); + columns[ShaPublicCol::PEIn.index()][row] = word_scalar_at_two(&pa_e[row], field_cfg); + columns[ShaPublicCol::PAOut.index()][row] = word_scalar_at_two(&pa_a[row], field_cfg); + columns[ShaPublicCol::PEOut.index()][row] = word_scalar_at_two(&pa_e[row], field_cfg); + columns[ShaPublicCol::Message.index()][row] = + word_scalar_at_two(&message[row], field_cfg); + } + for selector in [ + ShaPublicCol::SInit, + ShaPublicCol::SMsg, + ShaPublicCol::SSched, + ShaPublicCol::SUpd, + ShaPublicCol::SFf, + ShaPublicCol::SOut, + ] { + for row in 0..SHA_ROW_COUNT { + columns[selector.index()][row] = + production_sha_selector_expected(selector, row, field_cfg); + } + } + mle_table_from_columns(columns) + } + + impl ProductionShaProjectionAdapter + for Sha256CompressionSliceUair + { + fn project_production_sha_public( + _shape: &UairShape, + public_trace: &UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + 
field_cfg: &::Config, + ) -> Result, ProductionShaError> { + let empty_witness = UairTrace { + binary_poly: Cow::Borrowed(&[]), + arbitrary_poly: Cow::Borrowed(&[]), + int: Cow::Borrowed(&[]), + }; + let pa_a = project_binary_source( + sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_A)?, + field_cfg, + )?; + let pa_e = project_binary_source( + sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_E)?, + field_cfg, + )?; + let message = project_binary_source( + sha_binary_col(public_trace, &empty_witness, sha256_cols::PA_M)?, + field_cfg, + )?; + let public_columns = projected_public_from_sources(&pa_a, &pa_e, &message, field_cfg); + Ok(ProjectedPublic { + columns: public_columns, + bit_slices: Some(flatten_bit_columns(vec![ + pa_a.clone(), + pa_e.clone(), + pa_a, + pa_e, + message, + ])), + }) + } + + fn project_production_sha_witness( + _shape: &UairShape, + public_trace: &UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + witness_trace: &UairTrace<'_, ShaInt, ShaInt, TEST_DEGREE_PLUS_ONE>, + field_cfg: &::Config, + ) -> Result< + ( + ProjectedTrace, + ProjectedPublic, + ProductionShaWitnessPolys, + ), + ProductionShaError, + > { + let word_sources = [ + sha256_cols::W_A, + sha256_cols::W_E, + sha256_cols::W_SIG0, + sha256_cols::W_SIG1, + sha256_cols::W_W, + sha256_cols::W_LSIG0, + sha256_cols::W_LSIG1, + sha256_cols::W_U_EF, + sha256_cols::W_U_NEG_E_G, + sha256_cols::W_MAJ, + sha256_cols::W_MU_PACKED, + sha256_cols::PA_OV_SIG0, + sha256_cols::PA_OV_SIG1, + sha256_cols::PA_OV_LSIG0, + sha256_cols::PA_OV_LSIG1, + sha256_cols::PA_R_CH2_COMP, + sha256_cols::PA_R_MAJ_COMP, + ]; + let int_sources = [ + sha256_cols::PA_C_C7, + sha256_cols::PA_C_C8, + sha256_cols::PA_C_C9, + sha256_cols::PA_C_FF_A, + sha256_cols::PA_C_FF_E, + ]; + + let bit_columns = word_sources + .iter() + .map(|&col| { + project_binary_source( + sha_binary_col(public_trace, witness_trace, col)?, + field_cfg, + ) + }) + .collect::, _>>()?; + let bit_slices = 
flatten_bit_columns(bit_columns.clone()); + let scalarized = scalarize_bit_slices_plain( + &bit_slices, + &F::from_with_cfg(2u64, field_cfg), + field_cfg, + )?; + let pa_a = project_binary_source( + sha_binary_col(public_trace, witness_trace, sha256_cols::PA_A)?, + field_cfg, + )?; + let pa_e = project_binary_source( + sha_binary_col(public_trace, witness_trace, sha256_cols::PA_E)?, + field_cfg, + )?; + let message = project_binary_source( + sha_binary_col(public_trace, witness_trace, sha256_cols::PA_M)?, + field_cfg, + )?; + let public_columns = projected_public_from_sources(&pa_a, &pa_e, &message, field_cfg); + let int_columns = int_sources + .iter() + .map(|&col| { + project_int_source(sha_int_col(public_trace, witness_trace, col)?, field_cfg) + }) + .collect::, _>>()?; + + let trace = ProjectedTrace { + bit_slices, + scalarized, + int_columns: mle_table_from_columns(int_columns.clone()), + public_columns: public_columns.clone(), + }; + let public = ProjectedPublic { + columns: public_columns, + bit_slices: Some(flatten_bit_columns(vec![ + pa_a.clone(), + pa_e.clone(), + pa_a, + pa_e, + message, + ])), + }; + Ok(( + trace, + public, + ProductionShaWitnessPolys { + binary: word_sources + .iter() + .map(|&col| { + truncate_sha_row_domain( + sha_binary_col(public_trace, witness_trace, col)?, + "SHA binary witness row-domain projection", + ) + }) + .collect::, _>>()?, + arbitrary: Vec::new(), + int: int_sources + .iter() + .map(|&col| { + truncate_sha_row_domain( + sha_int_col(public_trace, witness_trace, col)?, + "SHA int witness row-domain projection", + ) + }) + .collect::, _>>()?, + }, + )) + } + } + + fn zero_trace_with_scalar_challenge(a: &F) -> ProjectedTrace { + let field_cfg = cfg(); + let zero = F::zero_with_cfg(&field_cfg); + let bit_slices = + flatten_bit_columns(vec![ + vec![vec![zero.clone(); SHA_WORD_BITS]; SHA_ROW_COUNT]; + ShaWordCol::COUNT + ]); + let scalarized = scalarize_bit_slices_plain(&bit_slices, a, &field_cfg).unwrap(); + ProjectedTrace { + 
bit_slices, + scalarized, + int_columns: mle_table_from_columns(vec![ + vec![zero.clone(); SHA_ROW_COUNT]; + ShaIntCol::COUNT + ]), + public_columns: mle_table_from_columns(vec![ + vec![zero; SHA_ROW_COUNT]; + ShaPublicCol::COUNT + ]), + } + } + + fn zero_public() -> ProjectedPublic { + let field_cfg = cfg(); + ProjectedPublic { + columns: mle_table_from_columns(vec![ + vec![F::zero_with_cfg(&field_cfg); SHA_ROW_COUNT]; + ShaPublicCol::COUNT + ]), + bit_slices: Some(flatten_bit_columns(vec![ + vec![ + vec![ + F::zero_with_cfg( + &field_cfg + ); + SHA_WORD_BITS + ]; + SHA_ROW_COUNT + ]; + ShaPublicWordCol::COUNT + ])), + } + } + + fn fixed_layout_public() -> ProjectedPublic { + let field_cfg = cfg(); + let mut public = zero_public(); + for selector in [ + ShaPublicCol::SInit, + ShaPublicCol::SMsg, + ShaPublicCol::SSched, + ShaPublicCol::SUpd, + ShaPublicCol::SFf, + ShaPublicCol::SOut, + ] { + for row in 0..SHA_ROW_COUNT { + public.columns[selector.index()].evaluations[row] = + production_sha_selector_expected(selector, row, &field_cfg); + } + } + for row in 0..SHA_ROW_COUNT { + public.columns[ShaPublicCol::K.index()].evaluations[row] = + production_sha_k_expected(row, &field_cfg); + } + public + } + + fn sparse_r_ic() -> [F; SHA_ROW_VARS] { + std::array::from_fn(|idx| f(idx as u64 + 2)) + } + + fn rescalarize_endpoint_source(source: &mut ShaSourceEndpointEval, a: &F) { + let field_cfg = cfg(); + let powers = zinc_utils::powers(a.clone(), F::one_with_cfg(&field_cfg), 32); + source.scalarized = source + .bits + .iter() + .zip(powers.iter()) + .fold(F::zero_with_cfg(&field_cfg), |acc, (bit, power)| { + acc + bit.clone() * power + }); + } + + fn endpoint_source(col: ShaWordCol, shift: usize, seed: u64) -> ShaSourceEndpointEval { + let field_cfg = cfg(); + let bits = std::array::from_fn(|idx| f(seed + idx as u64 + 1)); + let powers = zinc_utils::powers(f(7), F::one_with_cfg(&field_cfg), 32); + let scalarized = bits + .iter() + .zip(powers.iter()) + 
.fold(F::zero_with_cfg(&field_cfg), |acc, (bit, power)| { + acc + bit.clone() * power + }); + ShaSourceEndpointEval { + col, + shift, + scalarized, + bits, + } + } + + fn endpoint_evals_for_virtuals() -> ShaEndpointEvals { + ShaEndpointEvals { + sources: vec![ + endpoint_source(ShaWordCol::E, 0, 10), + endpoint_source(ShaWordCol::E, 1, 20), + endpoint_source(ShaWordCol::E, 2, 30), + endpoint_source(ShaWordCol::A, 0, 40), + endpoint_source(ShaWordCol::A, 1, 50), + endpoint_source(ShaWordCol::A, 2, 60), + endpoint_source(ShaWordCol::Uef, 2, 70), + endpoint_source(ShaWordCol::UNegEg, 2, 80), + endpoint_source(ShaWordCol::Ch2Comp, 0, 90), + endpoint_source(ShaWordCol::Maj, 2, 100), + endpoint_source(ShaWordCol::MajComp, 0, 110), + ], + int_sources: Vec::new(), + } + } + + fn hyrax_key_pair( + width: usize, + offset: u64, + ) -> ( + zip_plus::pcs::hyrax::HyraxCommitmentKey, + zip_plus::pcs::hyrax::HyraxVerifierKey, + ) + where + C: AffineRepr, + Lanes: Clone + Debug + Send + Sync, + { + let generator = C::Group::generator(); + let bases = (0..width) + .map(|idx| { + let scalar = C::ScalarField::from( + offset + u64::try_from(idx).expect("Hyrax basis index fits u64") + 1, + ); + (generator * scalar).into_affine() + }) + .collect::>(); + let h = generator + * C::ScalarField::from( + offset + u64::try_from(width).expect("Hyrax width fits u64") + 1, + ); + HyraxPCS::::setup_from_bases_with_blinding( + width, + bases, + h, + HyraxBlindingMode::Unblinded, + ) + .expect("Hyrax test setup must be valid") + } + + fn all_hyrax_test_pcs_params() -> ( + PCSParams, TestShaZincTypes, F, TEST_DEGREE_PLUS_ONE>, + PCSVerifierParams, TestShaZincTypes, F, TEST_DEGREE_PLUS_ONE>, + ) + where + C: AffineRepr, + AllHyraxPCSTypes: ZincPCSTypes< + TestShaZincTypes, + F, + TEST_DEGREE_PLUS_ONE, + BinaryPCS = HyraxPCS, + ArbitraryPCS = HyraxPCS, + IntPCS = HyraxPCS, + >, + { + let width = SHA_ROW_COUNT; + let (binary_ck, binary_vk) = hyrax_key_pair::(width, 0); + let (arbitrary_ck, arbitrary_vk) 
= hyrax_key_pair::(width, 1_000); + let (int_ck, int_vk) = hyrax_key_pair::(width, 2_000); + + ( + PCSParams::, TestShaZincTypes, F, TEST_DEGREE_PLUS_ONE> { + binary: binary_ck, + arbitrary: arbitrary_ck, + int: int_ck, + }, + PCSVerifierParams::, TestShaZincTypes, F, TEST_DEGREE_PLUS_ONE> { + binary: binary_vk, + arbitrary: arbitrary_vk, + int: int_vk, + }, + ) + } + + #[test] + fn linear_ideal_fold_proves_and_verifies_eight_sha_instances_with_hyrax() { + type C = ark_bn254::G1Affine; + type P = AllHyraxPCSTypes; + type U = Sha256CompressionSliceUair; + + let field_cfg = fixed_prime::field_cfg_from_curve_scalar::, C>(); + let initial_state = SHA256_INITIAL_STATE; + let message = vec!["hello world"; 40].join(" "); + let message_blocks = sha256_padded_message_blocks::<8>(message.as_bytes()) + .expect("test message should canonically pad to 8 SHA-256 blocks"); + let (witnesses, _final_state) = + synthesize_sha256_chain_witnesses::(initial_state, message_blocks) + .expect("SHA-256 UAIR witnesses synthesize"); + let shape = UairShape::::new(SHA_ROW_VARS); + let (pcs_params, pcs_verifier_params) = all_hyrax_test_pcs_params::(); + let pp = + LinearIdealFoldProverParams::::new( + pcs_params, + field_cfg.clone(), + 3, + ); + let vs = setup_verify_linear_ideal_fold::( + LinearIdealFoldVerifierParams::new(pcs_verifier_params, field_cfg), + shape.clone(), + ) + .expect("production SHA verifier setup succeeds"); + + let mut prover_transcript = Blake3Transcript::new(); + let output = prove_linear_ideal_fold::( + &pp, + &shape, + &witnesses, + &mut prover_transcript, + ) + .expect("production SHA ProjectionFold proof succeeds"); + + let mut verifier_transcript = Blake3Transcript::new(); + let verified = verify_linear_ideal_fold::( + &vs, + &output.fresh_instances, + &output.proof, + &mut verifier_transcript, + ) + .expect("production SHA ProjectionFold proof verifies"); + + assert_eq!(verified.target, output.folded_instance.target); + assert_eq!(verified.public, 
output.folded_instance.public); + } + + #[test] + fn optimized_sumfold_claim_feeds_folded_row_sumcheck_with_tail_for_eight_sha_instances() { + type U = Sha256CompressionSliceUair; + + let field_cfg = cfg(); + let initial_state = SHA256_INITIAL_STATE; + let message = vec!["hello world"; 40].join(" "); + let message_blocks = sha256_padded_message_blocks::<8>(message.as_bytes()) + .expect("test message should canonically pad to eight SHA-256 blocks"); + let (witnesses, _final_state) = + synthesize_sha256_chain_witnesses::(initial_state, message_blocks) + .expect("SHA-256 UAIR witnesses synthesize"); + let shape = UairShape::::new(SHA_ROW_VARS); + + let (traces, publics): (Vec<_>, Vec<_>) = witnesses + .iter() + .map(|witness| { + let public_trace = + public_uair_trace_view::( + &witness.trace, + &shape.signature, + ) + .unwrap(); + let witness_trace = + witness_uair_trace_view::( + &witness.trace, + &shape.signature, + ) + .unwrap(); + let (trace, public, _witness_polys) = U::project_production_sha_witness( + &shape, + &public_trace, + &witness_trace, + &field_cfg, + ) + .unwrap(); + (trace, public) + }) + .unzip(); + validate_production_sha_publics(&publics, &field_cfg).unwrap(); + + let r_ic = sparse_r_ic(); + let r_ic_eq_weights = build_eq_x_r_vec(&r_ic, &field_cfg).unwrap(); + let coeff_tables = build_linear_residual_coeff_tables_with_row_weights( + &traces, + &publics, + &r_ic_eq_weights, + &field_cfg, + ) + .unwrap(); + let beta = vec![f(13), f(17), f(19)]; + let beta_eq_weights = build_eq_x_r_vec(&beta, &field_cfg).unwrap(); + let aggregate_ideal_polys = + beta_aggregate_nonzero_ideal_polys_with_weights(&coeff_tables, &beta_eq_weights) + .unwrap(); + let ideal_check = IdealCheckProof { + combined_mle_values: aggregate_ideal_polys.iter().cloned().collect(), + }; + let aggregate_ideal_polys = aggregate_sha_ideal_polys_from_proof(&ideal_check).unwrap(); + check_aggregate_sha_ideal_membership(&aggregate_ideal_polys, &field_cfg).unwrap(); + + let a = f(5); + let 
lambda = f(7); + let rho = f(11); + let xi = f(13); + let booleanity_sources = production_sha_booleanity_sources(); + let a_powers = build_sha_residual_eval_powers(&a, &field_cfg); + let lambda_powers = build_sha_lambda_powers(&lambda, &field_cfg); + let booleanity_weights = + build_booleanity_weights(&rho, &xi, booleanity_sources.len(), &field_cfg); + let initial_claim = + evaluate_aggregate_sha_ideal_claim(&aggregate_ideal_polys, &a, &lambda, &field_cfg) + .unwrap(); + let linear_accumulator = build_sha_sumfold_linear_accumulator( + &coeff_tables, + &a_powers, + &lambda_powers, + &field_cfg, + ) + .unwrap(); + let prefix_vars = 2; + let quadratic_prefix_accumulator = build_sha_sumfold_quadratic_prefix_accumulator( + &traces, + &booleanity_sources, + prefix_vars, + &r_ic_eq_weights, + &booleanity_weights, + &field_cfg, + ) + .unwrap(); + assert_eq!(linear_accumulator.len(), traces.len()); + assert_eq!(quadratic_prefix_accumulator.len(), 18); + + let group = build_production_sha_sumfold_group_from_prefix_accumulators( + &traces, + &beta, + &beta_eq_weights, + &r_ic_eq_weights, + &linear_accumulator, + &quadratic_prefix_accumulator, + &booleanity_weights, + &booleanity_sources, + prefix_vars, + &field_cfg, + ) + .unwrap(); + let mut sumfold_prover_transcript = Blake3Transcript::new(); + sumfold_prover_transcript.absorb_slice(b"sha-sumfold-row-bridge"); + let (sumfold_proof, r_b, c_sf) = prove_optimized_sha_sumfold_with_weights( + &mut sumfold_prover_transcript, + group, + &initial_claim, + beta.len(), + &field_cfg, + ) + .unwrap(); + + let mut sumfold_verifier_transcript = Blake3Transcript::new(); + sumfold_verifier_transcript.absorb_slice(b"sha-sumfold-row-bridge"); + let verified_sumfold = verify_full_sha_sumfold( + &mut sumfold_verifier_transcript, + &sumfold_proof, + &initial_claim, + beta.len(), + &field_cfg, + ) + .unwrap(); + assert_eq!(verified_sumfold.r_b, r_b); + + let provisional = + derive_instance_fold_claim(&beta, r_b.clone(), c_sf.clone(), 
traces.len(), &field_cfg) + .unwrap(); + let (folded, folded_public) = + fold_projected_traces(&traces, &publics, &provisional, &field_cfg).unwrap(); + let row_claim = expression_folded_row_sum_with_vectors( + &folded.trace, + &folded_public, + &r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + let sumfold_output = + derive_instance_fold_claim(&beta, r_b, c_sf, traces.len(), &field_cfg).unwrap(); + assert_eq!(verified_sumfold.c_sf, *sumfold_output.c_sf()); + assert_eq!(sumfold_output.final_round_sumcheck_claim(), &row_claim); + + let mut row_prover_transcript = Blake3Transcript::new(); + row_prover_transcript.absorb_slice(b"sha-folded-row-bridge"); + let (row_proof, row_output) = prove_expression_folded_row_sumcheck_with_output_and_vectors( + &mut row_prover_transcript, + &folded.trace, + &folded_public, + &r_ic, + &r_ic_eq_weights, + &a_powers, + &lambda_powers, + &booleanity_weights, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + assert_eq!(row_proof.claimed_sums(), &[row_claim.clone()]); + + let mut row_verifier_transcript = Blake3Transcript::new(); + row_verifier_transcript.absorb_slice(b"sha-folded-row-bridge"); + let verified_row = verify_folded_row_sumcheck( + &mut row_verifier_transcript, + &row_proof, + &row_claim, + &field_cfg, + ) + .unwrap(); + verify_folded_row_terminal_value(&verified_row, &row_output.terminal_value).unwrap(); + } + + #[test] + fn fresh_ideal_coefficients_are_bound_before_a() { + let field_cfg = cfg(); + let ideals = vec![std::array::from_fn(|idx| { + DynamicPolynomialF::new_trimmed(vec![f(idx as u64 + 1), f(99)]) + })]; + let mut tampered = ideals.clone(); + tampered[0][0].coeffs[0] += f(1); + + let sample_a = |values: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]]| { + let mut transcript = Blake3Transcript::new(); + transcript.absorb_slice(b"fresh-commitments-and-public-inputs"); + let _r_ic = sample_pre_ideal_challenge::(&mut transcript, 
&field_cfg); + absorb_fresh_sha_ideal_polys(&mut transcript, values, &field_cfg); + let (a, _, _, _, _) = + sample_post_ideal_challenges::(&mut transcript, 1, &field_cfg).unwrap(); + a + }; + + assert_ne!(sample_a(&ideals), sample_a(&tampered)); + } + + #[test] + fn fresh_ideal_absorption_binds_polynomial_slot_structure() { + let field_cfg = cfg(); + let mut packed = vec![std::array::from_fn(|_| { + DynamicPolynomialF::new(Vec::::new()) + })]; + packed[0][0] = DynamicPolynomialF::new(vec![f(1), f(2)]); + + let mut split = vec![std::array::from_fn(|_| { + DynamicPolynomialF::new(Vec::::new()) + })]; + split[0][0] = DynamicPolynomialF::new(vec![f(1)]); + split[0][1] = DynamicPolynomialF::new(vec![f(2)]); + + let sample_a = |values: &[[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]]| { + let mut transcript = Blake3Transcript::new(); + transcript.absorb_slice(b"fresh-commitments-and-public-inputs"); + let _r_ic = sample_pre_ideal_challenge::(&mut transcript, &field_cfg); + absorb_fresh_sha_ideal_polys(&mut transcript, values, &field_cfg); + let (a, _, _, _, _) = + sample_post_ideal_challenges::(&mut transcript, 1, &field_cfg).unwrap(); + a + }; + + assert_ne!(sample_a(&packed), sample_a(&split)); + } + + #[test] + fn aggregate_ideal_claim_matches_old_per_instance_targets() { + let field_cfg = cfg(); + let mut ideal_polys = Vec::new(); + for instance_idx in 0..4 { + ideal_polys.push(std::array::from_fn(|slot| { + let family = production_sha_nonzero_families()[slot]; + match family { + ShaResidualFamily::R0BigSigmaA | ShaResidualFamily::R1BigSigmaE => { + let c = f((instance_idx * 10 + slot + 1) as u64); + let mut coeffs = vec![F::zero_with_cfg(&field_cfg); 33]; + coeffs[0] = -c.clone(); + coeffs[32] = c; + DynamicPolynomialF::new_trimmed(coeffs) + } + _ => { + let c = f((instance_idx * 10 + slot + 1) as u64); + DynamicPolynomialF::new_trimmed(vec![-f(2) * &c, c]) + } + } + })); + } + let beta = vec![f(3), f(5)]; + let a = f(7); + let lambda = f(11); + + let aggregate = 
beta_aggregate_sha_ideal_polys(&ideal_polys, &beta, &field_cfg).unwrap(); + let aggregate_claim = + evaluate_aggregate_sha_ideal_claim(&aggregate, &a, &lambda, &field_cfg).unwrap(); + let fresh_targets = + evaluate_fresh_targets_from_ideal_polys(&ideal_polys, &a, &lambda, &field_cfg).unwrap(); + let old_claim = eq_weighted_sum(&beta, &fresh_targets, &field_cfg).unwrap(); + + assert_eq!(aggregate_claim, old_claim); + } + + #[test] + fn aggregate_ideal_membership_rejects_wrong_family_polynomial() { + let field_cfg = cfg(); + let mut aggregate = std::array::from_fn(|slot| { + let family = production_sha_nonzero_families()[slot]; + match family { + ShaResidualFamily::R0BigSigmaA | ShaResidualFamily::R1BigSigmaE => { + let mut coeffs = vec![F::zero_with_cfg(&field_cfg); 33]; + coeffs[0] = -f(3); + coeffs[32] = f(3); + DynamicPolynomialF::new_trimmed(coeffs) + } + _ => DynamicPolynomialF::new_trimmed(vec![-f(10), f(5)]), + } + }); + check_aggregate_sha_ideal_membership(&aggregate, &field_cfg).unwrap(); + + aggregate[2] = DynamicPolynomialF::new_trimmed(vec![f(1)]); + assert!(matches!( + check_aggregate_sha_ideal_membership(&aggregate, &field_cfg), + Err(ProductionShaError::ShaProjection( + ShaProjectionError::IdealMembership + )) + )); + } + + #[test] + fn aggregate_ideal_absorption_precedes_scalarization_challenges() { + let field_cfg = cfg(); + let aggregate = std::array::from_fn(|slot| { + DynamicPolynomialF::new_trimmed(vec![f(slot as u64 + 1), f(slot as u64 + 2)]) + }); + let mut tampered = aggregate.clone(); + tampered[0].coeffs[0] += f(1); + + let sample_a = |values: &[DynamicPolynomialF; NUM_NONZERO_SHA_FAMILIES]| { + let mut transcript = Blake3Transcript::new(); + transcript.absorb_slice(b"fresh-commitments-and-public-inputs"); + let _r_ic = sample_pre_ideal_challenge::(&mut transcript, &field_cfg); + let _beta = + sample_instance_batch_challenge::(&mut transcript, 4, &field_cfg).unwrap(); + absorb_aggregate_sha_ideal_polys(&mut transcript, values, &field_cfg); 
+ let (a, _, _, _) = + sample_post_aggregate_ideal_challenges::(&mut transcript, &field_cfg); + a + }; + + assert_ne!(sample_a(&aggregate), sample_a(&tampered)); + } + + #[test] + fn production_public_validation_requires_fixed_selectors_and_k() { + let field_cfg = cfg(); + let valid = fixed_layout_public(); + validate_production_sha_publics(std::slice::from_ref(&valid), &field_cfg).unwrap(); + + let mut bad_selector = valid.clone(); + bad_selector.columns[ShaPublicCol::SOut.index()].evaluations[0] = f(1); + assert!(matches!( + validate_production_sha_publics(&[bad_selector], &field_cfg), + Err(ProductionShaError::InvalidPublicSelector { + col: ShaPublicCol::SOut, + row: 0 + }) + )); + + let mut non_boolean_selector = valid.clone(); + non_boolean_selector.columns[ShaPublicCol::SInit.index()].evaluations[0] = f(2); + assert!(matches!( + validate_production_sha_publics(&[non_boolean_selector], &field_cfg), + Err(ProductionShaError::NonBooleanPublicSelector { + col: ShaPublicCol::SInit, + row: 0 + }) + )); + + let mut bad_k = valid; + bad_k.columns[ShaPublicCol::K.index()].evaluations[3] += f(1); + assert!(matches!( + validate_production_sha_publics(&[bad_k], &field_cfg), + Err(ProductionShaError::InvalidRoundConstant { row: 3 }) + )); + } + + #[test] + fn sumfold_outputs_instance_fold_point_before_weights() { + let field_cfg = cfg(); + let fresh_targets = vec![f(2), f(5), f(7), f(11)]; + let beta = vec![f(13), f(17)]; + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"bound-before-sumfold"); + let (proof, prover_output) = + prove_sha_sumfold_targets(&mut prover_transcript, &fresh_targets, &beta, 1, &field_cfg) + .unwrap(); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"bound-before-sumfold"); + let verifier_output = verify_sha_sumfold_targets( + &mut verifier_transcript, + &proof, + &fresh_targets, + &beta, + &field_cfg, + ) + .unwrap(); + + assert_eq!(verifier_output, 
prover_output); + assert_eq!( + prover_output.eq_instance_weights(), + build_eq_x_r_vec(prover_output.r_b(), &field_cfg).unwrap() + ); + assert_eq!( + prover_output.eq_instance_weights().len(), + fresh_targets.len() + ); + + let d = eq_eval(&beta, prover_output.r_b(), F::one_with_cfg(&field_cfg)).unwrap(); + assert_eq!( + prover_output.c_sf(), + &(d * prover_output.final_round_sumcheck_claim()) + ); + + let mut bad_targets = fresh_targets; + bad_targets[0] += f(1); + let mut bad_transcript = Blake3Transcript::new(); + bad_transcript.absorb_slice(b"bound-before-sumfold"); + assert!( + verify_sha_sumfold_targets( + &mut bad_transcript, + &proof, + &bad_targets, + &beta, + &field_cfg + ) + .is_err() + ); + } + + #[test] + fn sumfold_verifier_rejects_extra_groups() { + let field_cfg = cfg(); + let fresh_targets = vec![f(2), f(5), f(7), f(11)]; + let beta = vec![f(13), f(17)]; + let claims = + zinc_piop::neutron_nova::LinearInstanceClaims::new(fresh_targets.clone()).unwrap(); + let group_0 = claims + .build_hybrid_sumcheck_group(&beta, 1, &field_cfg) + .unwrap(); + let group_1 = claims + .build_hybrid_sumcheck_group(&beta, 1, &field_cfg) + .unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"bound-before-sumfold"); + let (proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![group_0, group_1], + claims.ell(), + &field_cfg, + ); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"bound-before-sumfold"); + assert!(matches!( + verify_sha_sumfold_targets( + &mut verifier_transcript, + &proof, + &fresh_targets, + &beta, + &field_cfg, + ), + Err(ProductionShaError::UnexpectedSumcheckGroupCount { + label: "SHA SumFold", + got: 2 + }) + )); + } + + #[test] + fn full_sha_sumfold_derives_fold_weights_after_instance_sumcheck() { + let field_cfg = cfg(); + let a = f(5); + let traces = vec![ + zero_trace_with_scalar_challenge(&a), + 
zero_trace_with_scalar_challenge(&a), + ]; + let publics = vec![zero_public(), zero_public()]; + let beta = vec![f(13)]; + let r_ic = sparse_r_ic(); + let lambda = f(17); + let rho = f(19); + let xi = f(23); + let booleanity_sources = vec![ + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 0, + }, + ShaBooleanitySource::VirtualMaj { bit: 0 }, + ]; + let initial_claim = F::zero_with_cfg(&field_cfg); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"full-sha-sumfold-context"); + let (proof, prover_output) = prove_full_sha_sumfold( + &mut prover_transcript, + &traces, + &publics, + &initial_claim, + &beta, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"full-sha-sumfold-context"); + let verified_sumfold = verify_full_sha_sumfold( + &mut verifier_transcript, + &proof, + &initial_claim, + beta.len(), + &field_cfg, + ) + .unwrap(); + let verifier_output = derive_instance_fold_claim( + &beta, + verified_sumfold.r_b, + verified_sumfold.c_sf, + traces.len(), + &field_cfg, + ) + .unwrap(); + + assert_eq!(verifier_output, prover_output); + assert_eq!( + prover_output.eq_instance_weights(), + build_eq_x_r_vec(prover_output.r_b(), &field_cfg).unwrap() + ); + assert_eq!(prover_output.eq_instance_weights().len(), traces.len()); + + let (folded, folded_public) = + fold_projected_traces(&traces, &publics, &prover_output, &field_cfg).unwrap(); + let folded_sum = expression_folded_row_sum( + &folded.trace, + &folded_public, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + assert_eq!(prover_output.final_round_sumcheck_claim(), &folded_sum); + + let mut bad_transcript = Blake3Transcript::new(); + bad_transcript.absorb_slice(b"full-sha-sumfold-context"); + assert!( + verify_full_sha_sumfold(&mut bad_transcript, &proof, &f(1), beta.len(), 
&field_cfg) + .is_err() + ); + } + + #[test] + fn folded_row_sumcheck_claim_matches_folded_integrand_sum() { + let field_cfg = cfg(); + let row_integrand_values = (0..(1usize << SHA_ROW_VARS)) + .map(|idx| f((idx as u64).wrapping_mul(3) + 1)) + .collect::>(); + let final_round_sumcheck_claim = + folded_row_integrand_sum(&row_integrand_values, &field_cfg).unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"folded-row-context"); + let proof = prove_folded_row_sumcheck( + &mut prover_transcript, + &row_integrand_values, + &final_round_sumcheck_claim, + &field_cfg, + ) + .unwrap(); + assert_eq!(proof.claimed_sums(), &[final_round_sumcheck_claim.clone()]); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"folded-row-context"); + let output = verify_folded_row_sumcheck( + &mut verifier_transcript, + &proof, + &final_round_sumcheck_claim, + &field_cfg, + ) + .unwrap(); + assert_eq!(output.r_star.len(), SHA_ROW_VARS); + + let row_weights = build_eq_x_r_vec(&output.r_star, &field_cfg).unwrap(); + let terminal = row_weights + .iter() + .zip(row_integrand_values.iter()) + .fold(F::zero_with_cfg(&field_cfg), |acc, (weight, value)| { + acc + weight.clone() * value + }); + verify_folded_row_terminal_value(&output, &terminal).unwrap(); + + let mut bad_terminal = terminal; + bad_terminal += f(1); + assert!(verify_folded_row_terminal_value(&output, &bad_terminal).is_err()); + + let mut bad_claim = final_round_sumcheck_claim; + bad_claim += f(1); + let mut bad_transcript = Blake3Transcript::new(); + bad_transcript.absorb_slice(b"folded-row-context"); + assert!( + verify_folded_row_sumcheck(&mut bad_transcript, &proof, &bad_claim, &field_cfg) + .is_err() + ); + } + + #[test] + fn folded_row_verifier_rejects_extra_groups() { + let field_cfg = cfg(); + let row_integrand_values = (0..(1usize << SHA_ROW_VARS)) + .map(|idx| f((idx as u64).wrapping_mul(5) + 9)) + .collect::>(); + let 
post_sumfold_claim = + folded_row_integrand_sum(&row_integrand_values, &field_cfg).unwrap(); + let group_0 = build_folded_row_sumcheck_group(&row_integrand_values, &field_cfg).unwrap(); + let group_1 = build_folded_row_sumcheck_group(&row_integrand_values, &field_cfg).unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"folded-row-context"); + let (proof, _) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut prover_transcript, + vec![group_0, group_1], + SHA_ROW_VARS, + &field_cfg, + ); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"folded-row-context"); + assert!(matches!( + verify_folded_row_sumcheck( + &mut verifier_transcript, + &proof, + &post_sumfold_claim, + &field_cfg + ), + Err(ProductionShaError::UnexpectedSumcheckGroupCount { + label: "folded row sumcheck", + got: 2 + }) + )); + } + + #[test] + fn expression_folded_row_terminal_is_reconstructed_from_endpoints() { + let field_cfg = cfg(); + let a = f(5); + let trace = zero_trace_with_scalar_challenge(&a); + let public = zero_public(); + let r_ic = sparse_r_ic(); + let lambda = f(7); + let rho = f(11); + let xi = f(13); + let booleanity_sources = vec![ + ShaBooleanitySource::WordBit { + col: ShaWordCol::A, + bit: 0, + }, + ShaBooleanitySource::VirtualCh1 { bit: 0 }, + ]; + let post_sumfold_claim = expression_folded_row_sum( + &trace, + &public, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"expression-row-context"); + let proof = prove_expression_folded_row_sumcheck( + &mut prover_transcript, + &trace, + &public, + &r_ic, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &post_sumfold_claim, + &field_cfg, + ) + .unwrap(); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"expression-row-context"); + let output = 
verify_folded_row_sumcheck( + &mut verifier_transcript, + &proof, + &post_sumfold_claim, + &field_cfg, + ) + .unwrap(); + let endpoint_evals = + build_sha_endpoint_evals_from_trace(&trace, &output.r_star, &a, &field_cfg).unwrap(); + let terminal = reconstruct_folded_row_terminal_from_endpoints( + &endpoint_evals, + &public, + &r_ic, + &output.r_star, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &field_cfg, + ) + .unwrap(); + verify_folded_row_terminal_value(&output, &terminal).unwrap(); + + let mut bad_terminal = terminal; + bad_terminal += f(1); + assert!(verify_folded_row_terminal_value(&output, &bad_terminal).is_err()); + + let mut bad_endpoints = endpoint_evals; + bad_endpoints.sources[0].bits[0] += f(1); + assert!( + reconstruct_folded_row_terminal_from_endpoints( + &bad_endpoints, + &public, + &r_ic, + &output.r_star, + &a, + &lambda, + &rho, + &xi, + &booleanity_sources, + &field_cfg + ) + .is_err() + ); + } + + #[test] + fn endpoint_multipoint_reduces_all_sources_and_rejects_bad_openings() { + let field_cfg = cfg(); + let a = f(5); + let trace = zero_trace_with_scalar_challenge(&a); + let public = zero_public(); + let r_star = vec![f(2), f(3), f(5), f(7), f(11), f(13), f(17)]; + let endpoint_evals = + build_sha_endpoint_evals_from_trace(&trace, &r_star, &a, &field_cfg).unwrap(); + + let mut prover_transcript = Blake3Transcript::new(); + prover_transcript.absorb_slice(b"endpoint-multipoint-context"); + let (proof, r_0) = prove_sha_endpoint_multipoint( + &mut prover_transcript, + &trace, + &public, + &endpoint_evals, + &r_star, + &field_cfg, + ) + .unwrap(); + + let mut verifier_transcript = Blake3Transcript::new(); + verifier_transcript.absorb_slice(b"endpoint-multipoint-context"); + let (subclaim, shift_specs) = verify_sha_endpoint_multipoint( + &mut verifier_transcript, + &proof, + &endpoint_evals, + &public, + &r_star, + &field_cfg, + ) + .unwrap(); + assert_eq!(subclaim.sumcheck_subclaim.point, r_0); + + let layout = 
production_sha_multipoint_layout(); + let trace_mles = sha_multipoint_trace_mles(&trace, &public, &layout, &field_cfg).unwrap(); + let open_evals = trace_mles + .iter() + .map(|mle| mle.clone().evaluate_with_config(&r_0, &field_cfg).unwrap()) + .collect::>(); + verify_sha_endpoint_multipoint_open_evals(&subclaim, &open_evals, &shift_specs, &field_cfg) + .unwrap(); + + let mut bad_open_evals = open_evals.clone(); + bad_open_evals[0] += f(1); + assert!( + verify_sha_endpoint_multipoint_open_evals( + &subclaim, + &bad_open_evals, + &shift_specs, + &field_cfg + ) + .is_err() + ); + + let mut bad_endpoint_evals = endpoint_evals; + bad_endpoint_evals.sources[0].bits[0] += f(1); + rescalarize_endpoint_source(&mut bad_endpoint_evals.sources[0], &a); + let mut bad_verifier_transcript = Blake3Transcript::new(); + bad_verifier_transcript.absorb_slice(b"endpoint-multipoint-context"); + assert!( + verify_sha_endpoint_multipoint( + &mut bad_verifier_transcript, + &proof, + &bad_endpoint_evals, + &public, + &r_star, + &field_cfg + ) + .is_err() + ); + } + + #[test] + fn fresh_ideal_objects_must_be_trimmed_and_degree_capped() { + let field_cfg = cfg(); + let mut ideals = vec![std::array::from_fn(|_| { + DynamicPolynomialF::new(Vec::::new()) + })]; + ideals[0][0] = DynamicPolynomialF::new(vec![f(1), F::zero_with_cfg(&field_cfg)]); + assert!(matches!( + check_fresh_sha_ideal_membership(&ideals, &field_cfg), + Err(ProductionShaError::ShaProjection( + ShaProjectionError::NonCanonicalProofObject(_) + )) + )); + + let mut high_degree = vec![std::array::from_fn(|_| { + DynamicPolynomialF::new(Vec::::new()) + })]; + high_degree[0][2] = DynamicPolynomialF::new(vec![f(1); 33]); + assert!(matches!( + check_fresh_sha_ideal_membership(&high_degree, &field_cfg), + Err(ProductionShaError::ShaProjection( + ShaProjectionError::NonCanonicalProofObject(_) + )) + )); + } + + #[test] + fn endpoint_layout_must_be_exact_and_canonical() { + let field_cfg = cfg(); + let a = f(5); + let trace = 
zero_trace_with_scalar_challenge(&a); + let r_star = vec![f(2), f(3), f(5), f(7), f(11), f(13), f(17)]; + let endpoint_evals = + build_sha_endpoint_evals_from_trace(&trace, &r_star, &a, &field_cfg).unwrap(); + validate_sha_endpoint_layout(&endpoint_evals).unwrap(); + + let mut missing = endpoint_evals.clone(); + missing.sources.pop(); + assert!(validate_sha_endpoint_layout(&missing).is_err()); + + let mut reordered = endpoint_evals; + reordered.sources.swap(0, 1); + assert!(matches!( + validate_sha_endpoint_layout(&reordered), + Err(ProductionShaError::NonCanonicalProofObject(_)) + )); + } + + #[test] + fn pcs_lifted_evals_drive_multipoint_sources_and_recompute_publics() { + let field_cfg = cfg(); + let mut public = zero_public(); + public.columns[ShaPublicCol::K.index()].evaluations[0] = f(99); + let r_0 = vec![F::zero_with_cfg(&field_cfg); SHA_ROW_VARS]; + let layout = production_sha_multipoint_layout(); + + let mut lifted = vec![ + DynamicPolynomialF::new_trimmed(vec![F::zero_with_cfg(&field_cfg)]); + ShaWordCol::COUNT + ShaIntCol::COUNT + ]; + lifted[ShaWordCol::A.index()] = DynamicPolynomialF::new_trimmed(vec![f(3)]); + lifted[ShaWordCol::COUNT + ShaIntCol::CompSchedule.index()] = + DynamicPolynomialF::new_trimmed(vec![f(7)]); + + let open_evals = + multipoint_open_evals_from_pcs_lifted(&lifted, &layout, &public, &r_0, &field_cfg) + .unwrap(); + let a0_idx = layout + .sources + .iter() + .position(|source| { + *source + == ShaMpSource::WordBit { + col: ShaWordCol::A, + bit: 0, + } + }) + .unwrap(); + let int_idx = layout + .sources + .iter() + .position(|source| { + *source + == ShaMpSource::Int { + col: ShaIntCol::CompSchedule, + } + }) + .unwrap(); + assert!(!layout.sources.iter().any(|source| matches!( + source, + ShaMpSource::Public { + col: ShaPublicCol::K + } + ))); + + assert_eq!(open_evals[a0_idx], f(3)); + assert_eq!(open_evals[int_idx], f(7)); + assert_eq!( + sha_public_at_point(&public, ShaPublicCol::K, 0, &r_0, &field_cfg).unwrap(), + f(99) + ); + 
} + + #[test] + fn folded_lifted_evals_must_be_canonical_and_32_bit() { + let field_cfg = cfg(); + let mut lifted = vec![DynamicPolynomialF::ZERO; ShaWordCol::COUNT + ShaIntCol::COUNT]; + split_folded_sha_pcs_lifted_evals(&lifted).unwrap(); + ensure_production_sha_word_degree::().unwrap(); + assert!(matches!( + ensure_production_sha_word_degree::(), + Err(ProductionShaError::UnsupportedProductionShaWordDegree { + got: 8, + expected: 32 + }) + )); + + lifted[ShaWordCol::A.index()] = + DynamicPolynomialF::new(vec![F::zero_with_cfg(&field_cfg); SHA_WORD_BITS + 1]); + assert!(matches!( + split_folded_sha_pcs_lifted_evals(&lifted), + Err(ProductionShaError::NonCanonicalProofObject(_)) + )); + + lifted[ShaWordCol::A.index()] = + DynamicPolynomialF::new(vec![f(1), F::zero_with_cfg(&field_cfg)]); + assert!(matches!( + split_folded_sha_pcs_lifted_evals(&lifted), + Err(ProductionShaError::NonCanonicalProofObject(_)) + )); + + lifted[ShaWordCol::A.index()] = DynamicPolynomialF::ZERO; + lifted[ShaWordCol::COUNT + ShaIntCol::CompSchedule.index()] = + DynamicPolynomialF::new(vec![f(1), f(2)]); + assert!(matches!( + split_folded_sha_pcs_lifted_evals(&lifted), + Err(ProductionShaError::NonCanonicalProofObject(_)) + )); + } + + #[test] + fn production_sha_requires_exact_commitment_batch_sizes() { + validate_production_sha_batch_sizes::(ShaWordCol::COUNT, 0, ShaIntCol::COUNT).unwrap(); + + assert!(matches!( + validate_production_sha_batch_sizes::(0, 0, ShaIntCol::COUNT), + Err(ProductionShaError::UnsupportedProductionShaPcsShape(_)) + )); + assert!(matches!( + validate_production_sha_batch_sizes::(ShaWordCol::COUNT, 1, ShaIntCol::COUNT), + Err(ProductionShaError::UnsupportedProductionShaPcsShape(_)) + )); + assert!(matches!( + validate_production_sha_batch_sizes::(ShaWordCol::COUNT, 0, 0), + Err(ProductionShaError::UnsupportedProductionShaPcsShape(_)) + )); + } + + #[test] + fn scalarization_and_virtual_endpoints_use_source_bits_only() { + let field_cfg = cfg(); + let mut 
endpoint_evals = endpoint_evals_for_virtuals(); + verify_endpoint_scalarization(&endpoint_evals, &f(7), &field_cfg).unwrap(); + + let virtuals = reconstruct_virtual_ch_maj_endpoint(&endpoint_evals, &field_cfg).unwrap(); + let two = f(2); + for bit in 0..SHA_WORD_BITS { + let e0 = source_bits(&endpoint_evals, ShaWordCol::E, 0).unwrap()[bit].clone(); + let e1 = source_bits(&endpoint_evals, ShaWordCol::E, 1).unwrap()[bit].clone(); + let e2 = source_bits(&endpoint_evals, ShaWordCol::E, 2).unwrap()[bit].clone(); + let a0 = source_bits(&endpoint_evals, ShaWordCol::A, 0).unwrap()[bit].clone(); + let a1 = source_bits(&endpoint_evals, ShaWordCol::A, 1).unwrap()[bit].clone(); + let a2 = source_bits(&endpoint_evals, ShaWordCol::A, 2).unwrap()[bit].clone(); + let uef2 = source_bits(&endpoint_evals, ShaWordCol::Uef, 2).unwrap()[bit].clone(); + let uneg_eg2 = + source_bits(&endpoint_evals, ShaWordCol::UNegEg, 2).unwrap()[bit].clone(); + let ch2_comp0 = + source_bits(&endpoint_evals, ShaWordCol::Ch2Comp, 0).unwrap()[bit].clone(); + let maj2 = source_bits(&endpoint_evals, ShaWordCol::Maj, 2).unwrap()[bit].clone(); + let maj_comp0 = + source_bits(&endpoint_evals, ShaWordCol::MajComp, 0).unwrap()[bit].clone(); + + assert_eq!(virtuals.ch1[bit], e2.clone() + e1 - two.clone() * uef2); + assert_eq!( + virtuals.ch2[bit], + e2 - e0 + two.clone() * uneg_eg2 + two.clone() * ch2_comp0 + ); + assert_eq!( + virtuals.maj[bit], + a0 + a1 + a2 - two.clone() * maj2 - two.clone() * maj_comp0 + ); + } + + endpoint_evals.sources[0].scalarized += f(1); + assert!(verify_endpoint_scalarization(&endpoint_evals, &f(7), &field_cfg).is_err()); + } +} diff --git a/protocol/src/prover.rs b/protocol/src/prover.rs index 44b265d8..a2be0524 100644 --- a/protocol/src/prover.rs +++ b/protocol/src/prover.rs @@ -3,7 +3,7 @@ use crypto_primitives::{ ConstIntSemiring, FromPrimitiveWithConfig, FromWithConfig, crypto_bigint_int::Int, }; use num_traits::Zero; -use std::fmt::Debug; +use std::{fmt::Debug, io::Cursor}; use 
zinc_piop::{ combined_poly_resolver::CombinedPolyResolver, ideal_check::IdealCheckProtocol, @@ -13,7 +13,7 @@ use zinc_piop::{ compute_shifted_bit_slice_evals_streaming, finalize_booleanity_prover, prepare_booleanity_group, }, - multipoint_eval::{MultipointEval, Proof as MultipointEvalProof}, + multipoint_eval::Proof as MultipointEvalProof, projections::{ ColumnMajorTrace, ProjectedTrace, RowMajorTrace, ScalarMap, evaluate_trace_to_column_mles_fast, project_scalars, project_scalars_to_field, @@ -21,28 +21,37 @@ use zinc_piop::{ }, sumcheck::multi_degree::MultiDegreeSumcheck, }; +use zinc_poly::{mle::DenseMultilinearExtension, univariate::binary::BinaryPoly}; use zinc_poly::{ - mle::MultilinearExtensionWithConfig, - univariate::dynamic::over_field::DynamicPolynomialF, + mle::MultilinearExtensionWithConfig, univariate::dynamic::over_field::DynamicPolynomialF, +}; +use zinc_transcript::{ + Blake3Transcript, + traits::{ConstTranscribable, Transcript}, }; -use zinc_transcript::traits::{ConstTranscribable, Transcript}; use zinc_uair::{ Uair, UairSignature, UairTrace, constraint_counter::count_constraints, degree_counter::count_max_degree, }; use zinc_utils::{ - add, cfg_join, from_ref::FromRef, inner_transparent_field::InnerTransparentField, - mul_by_scalar::MulByScalar, projectable_to_field::ProjectableToField, + add, cfg_join, delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, + inner_transparent_field::InnerTransparentField, mul_by_scalar::MulByScalar, + projectable_to_field::ProjectableToField, }; use zip_plus::{ pcs::{ ZipPlusProveByteBreakdown, + generic::PCS, multi_zip::MultiZip3, structs::{ZipPlus, ZipPlusHint, ZipPlusParams, ZipTypes}, }, pcs_transcript::PcsProverTranscript, }; -use zinc_poly::{mle::DenseMultilinearExtension, univariate::binary::BinaryPoly}; + +use crate::{ + multipoint_reduction::prove_multipoint_reduction, + pcs::{AllZipPCSTypes, PCSCommitments, PCSParams, PCSProverData, ZincPCSTypes}, +}; /// Drop the witness binary_poly columns 
the UAIR opted out of (sorted, /// dedup'd `skip_indices` relative to `witness_cols`) and return the @@ -73,24 +82,25 @@ fn filter_booleanity_witness( /// Fiat-Shamir transcript, PCS parameters/hints/commitments, and trace /// reference. #[derive(Clone, Debug)] -pub struct ProverBase<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { +pub struct ProverBase< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { num_vars: usize, uair_signature: UairSignature, pcs_transcript: PcsProverTranscript, trace: &'a UairTrace<'static, Zt::Int, Zt::Int, D>, // Commitment info - pp_bin: &'a ZipPlusParams, - pp_arb: &'a ZipPlusParams, - pp_int: &'a ZipPlusParams, - hint_bin: Option::Cw>>, - hint_arb: Option::Cw>>, - hint_int: Option::Cw>>, - commitment_bin: ZipPlusCommitment, - commitment_arb: ZipPlusCommitment, - commitment_int: ZipPlusCommitment, - - _phantom: PhantomData<(U, F)>, + pcs_params: PCSParams, + pcs_data: PCSProverData, + pcs_commitments: PCSCommitments, + + _phantom: PhantomData<(U, F, P)>, } // @@ -100,8 +110,15 @@ pub struct ProverBase<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usi /// After step 1 via [`step1_combined`](ProverCommitted::step1_combined) /// (row-major / "combined" projection). `project_scalar` has been consumed. #[derive(Clone, Debug)] -pub struct ProverProjectedCombined<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverProjectedCombined< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: RowMajorTrace, projected_scalars_fx: ScalarMap>, @@ -110,8 +127,15 @@ pub struct ProverProjectedCombined<'a, Zt: ZincTypes, U: Uair, F: PrimeField, /// After step 1 via [`step1_mle_first`](ProverCommitted::step1_mle_first) /// (column-major / MLE-first projection). 
`project_scalar` has been consumed. #[derive(Clone, Debug)] -pub struct ProverProjectedMleFirst<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverProjectedMleFirst< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: ColumnMajorTrace, projected_scalars_fx: ScalarMap>, @@ -123,8 +147,15 @@ pub struct ProverProjectedMleFirst<'a, Zt: ZincTypes, U: Uair, F: PrimeField, /// through the combined-poly lane (row-major). `project_scalar` has been /// consumed. #[derive(Clone, Debug)] -pub struct ProverProjectedHybrid<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverProjectedHybrid< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, row_major_trace: RowMajorTrace, column_major_trace: ColumnMajorTrace, @@ -133,8 +164,15 @@ pub struct ProverProjectedHybrid<'a, Zt: ZincTypes, U: Uair, F: PrimeField, c /// After step 2 (ideal check). #[derive(Clone, Debug)] -pub struct ProverIdealChecked<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverIdealChecked< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: ProjectedTrace, projected_scalars_fx: ScalarMap>, @@ -146,8 +184,15 @@ pub struct ProverIdealChecked<'a, Zt: ZincTypes, U: Uair, F: PrimeField, cons /// After step 3 (eval projection). `projected_scalars_fx` has been consumed. 
#[derive(Clone, Debug)] -pub struct ProverEvalProjected<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverEvalProjected< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: ProjectedTrace, ic_proof: IdealCheckProof, @@ -164,8 +209,15 @@ pub struct ProverEvalProjected<'a, Zt: ZincTypes, U: Uair, F: PrimeField, con /// After step 4 (sumcheck). #[allow(clippy::type_complexity)] #[derive(Clone, Debug)] -pub struct ProverSumchecked<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverSumchecked< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: ProjectedTrace, ic_proof: IdealCheckProof, @@ -185,8 +237,15 @@ pub struct ProverSumchecked<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const /// After step 5 (multipoint eval). #[derive(Clone, Debug)] -pub struct ProverMultipointEvaled<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverMultipointEvaled< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, projected_trace: ProjectedTrace, ic_proof: IdealCheckProof, @@ -201,8 +260,15 @@ pub struct ProverMultipointEvaled<'a, Zt: ZincTypes, U: Uair, F: PrimeField, /// After step 6 (lift-and-project). 
#[derive(Clone, Debug)] -pub struct ProverLifted<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverLifted< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, field_cfg: F::Config, ic_proof: IdealCheckProof, cpr_proof: CombinedPolyResolverProof, @@ -215,8 +281,8 @@ pub struct ProverLifted<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: u lifted_evals: Vec>, } -impl<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> - ProverLifted<'a, Zt, U, F, D> +impl<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize, P: ZincPCSTypes> + ProverLifted<'a, Zt, U, F, D, P> { /// PIOP evaluation point `r_0` produced by step 5. Used by external /// per-step bench harnesses (folded paths) to seed step 7 setups. @@ -230,8 +296,15 @@ impl<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> /// Ready for generating the final proof object in /// [`finish`](ProverPcsOpened::finish). #[derive(Clone, Debug)] -pub struct ProverPcsOpened<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D: usize> { - base: ProverBase<'a, Zt, U, F, D>, +pub struct ProverPcsOpened< + 'a, + Zt: ZincTypes, + U: Uair, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: ProverBase<'a, Zt, U, F, D, P>, ic_proof: IdealCheckProof, cpr_proof: CombinedPolyResolverProof, combined_sumcheck: MultiDegreeSumcheckProof, @@ -248,13 +321,16 @@ pub struct ProverPcsOpened<'a, Zt: ZincTypes, U: Uair, F: PrimeField, const D /// define them macro_rules! 
impl_with_type_bounds { ($type_name:ident { $($code:tt)* }) => { - impl<'a, Zt, U, F, const D: usize> $type_name<'a, Zt, U, F, D> + impl<'a, Zt, U, F, const D: usize, P> $type_name<'a, Zt, U, F, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, Zt::Int: ProjectableToField, ::Eval: ProjectableToField, U: Uair + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Zt::Int> + for<'b> FromWithConfig<&'b ::CombR> @@ -278,39 +354,84 @@ macro_rules! impl_with_type_bounds { impl ZincPlusPiop where Zt: ZincTypes, - U: Uair, - F: PrimeField, - F::Inner: ConstTranscribable, + Zt::Int: ProjectableToField, + ::Eval: ProjectableToField, + U: Uair + 'static, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + for<'b> FromWithConfig<&'b Zt::Int> + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b Zt::Chal> + + for<'b> MulByScalar<&'b F> + + FromRef + + Send + + Sync + + 'static, + F::Inner: + ConstIntSemiring + ConstTranscribable + FromRef + Send + Sync + Zero + Default, + F::Modulus: ConstTranscribable + FromRef, { - /// Step 0: Prover entry point. - /// Commit *witness* columns via Zip+ PCS, absorb roots and public - /// data into the Fiat-Shamir transcript. + /// Step 0: Prover entry point using the default all-Zip PCS bundle. 
#[allow(clippy::type_complexity)] pub fn step0_commit<'a>( - (pp_bin, pp_arb, pp_int): &'a ( + pp: &'a ( ZipPlusParams, ZipPlusParams, ZipPlusParams, ), trace: &'a UairTrace<'static, Zt::Int, Zt::Int, D>, num_vars: usize, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> + where + AllZipPCSTypes: ZincPCSTypes< + Zt, + F, + D, + BinaryPCS = zip_plus::pcs::generic::ZipPlusPCS, + ArbitraryPCS = zip_plus::pcs::generic::ZipPlusPCS, + IntPCS = zip_plus::pcs::generic::ZipPlusPCS, + >, + { + let pcs_params = PCSParams:: { + binary: pp.0.clone(), + arbitrary: pp.1.clone(), + int: pp.2.clone(), + }; + Self::step0_commit_with_pcs::(&pcs_params, trace, num_vars) + } + + /// Step 0 with an explicit PCS bundle. + pub fn step0_commit_with_pcs<'a, P>( + pcs_params: &PCSParams, + trace: &'a UairTrace<'static, Zt::Int, Zt::Int, D>, + num_vars: usize, + ) -> Result, ProtocolError> + where + P: ZincPCSTypes, + { let uair_signature = U::signature(); let public_trace = trace.public(&uair_signature); let witness_trace = trace.witness(&uair_signature); let (res_bin, (res_arb, res_int)) = cfg_join!( - commit_optionally(pp_bin, &witness_trace.binary_poly), - commit_optionally(pp_arb, &witness_trace.arbitrary_poly), - commit_optionally(pp_int, &witness_trace.int), + P::BinaryPCS::commit(&pcs_params.binary, &witness_trace.binary_poly), + P::ArbitraryPCS::commit(&pcs_params.arbitrary, &witness_trace.arbitrary_poly), + P::IntPCS::commit(&pcs_params.int, &witness_trace.int), ); - let (hint_bin, commitment_bin) = res_bin?; - let (hint_arb, commitment_arb) = res_arb?; - let (hint_int, commitment_int) = res_int?; + let (data_bin, commitment_bin) = res_bin?; + let (data_arb, commitment_arb) = res_arb?; + let (data_int, commitment_int) = res_int?; - let mut pcs_transcript = PcsProverTranscript::new_from_commitments( - [&commitment_bin, &commitment_arb, &commitment_int].into_iter(), - ); + let mut pcs_transcript = PcsProverTranscript { + fs_transcript: Blake3Transcript::default(), + 
stream: Cursor::default(), + }; + P::BinaryPCS::absorb_commitment(&mut pcs_transcript.fs_transcript, &commitment_bin); + P::ArbitraryPCS::absorb_commitment(&mut pcs_transcript.fs_transcript, &commitment_arb); + P::IntPCS::absorb_commitment(&mut pcs_transcript.fs_transcript, &commitment_int); absorb_public_columns(&mut pcs_transcript.fs_transcript, &public_trace.binary_poly); absorb_public_columns( @@ -324,15 +445,17 @@ where uair_signature, pcs_transcript, trace, - pp_bin, - pp_arb, - pp_int, - hint_bin, - hint_arb, - hint_int, - commitment_bin, - commitment_arb, - commitment_int, + pcs_params: pcs_params.clone(), + pcs_data: PCSProverData { + binary: data_bin, + arbitrary: data_arb, + int: data_int, + }, + pcs_commitments: PCSCommitments { + binary: commitment_bin, + arbitrary: commitment_arb, + int: commitment_int, + }, _phantom: PhantomData, }) } @@ -351,6 +474,17 @@ impl_with_type_bounds!(ProverBase // See `crate::fixed_prime` for the soundness caveat. let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); + self.project_common_with_field_cfg(project_scalar, field_cfg) + } + + fn project_common_with_field_cfg< + S: Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + >( + &mut self, + project_scalar: S, + field_cfg: F::Config, + ) -> Result<(F::Config, ScalarMap>), ProtocolError> + { let projected_scalars_fx = project_scalars::(|s| project_scalar(s, &field_cfg)); Ok((field_cfg, projected_scalars_fx)) } @@ -362,7 +496,7 @@ impl_with_type_bounds!(ProverBase pub fn step1_combined DynamicPolynomialF + Sync>( mut self, project_scalar: S, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let (field_cfg, projected_scalars_fx) = self.project_common(project_scalar)?; let projected_trace = project_trace_coeffs_row_major(self.trace, &field_cfg); @@ -374,6 +508,25 @@ impl_with_type_bounds!(ProverBase }) } + pub fn step1_combined_with_field_cfg< + S: Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + >( + mut self, + project_scalar: S, + 
field_cfg: F::Config, + ) -> Result, ProtocolError> { + let (field_cfg, projected_scalars_fx) = + self.project_common_with_field_cfg(project_scalar, field_cfg)?; + + let projected_trace = project_trace_coeffs_row_major(self.trace, &field_cfg); + Ok(ProverProjectedCombined { + base: self, + field_cfg, + projected_trace, + projected_scalars_fx, + }) + } + /// Step 1 (MLE-first / column-major): Prime projection /// (`\phi_q`: `Z[X] -> F_q[X]`). Samples a random prime, projects the /// full trace and scalars using the column-major layout. @@ -381,7 +534,7 @@ impl_with_type_bounds!(ProverBase pub fn step1_mle_first DynamicPolynomialF + Sync>( mut self, project_scalar: S, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let (field_cfg, projected_scalars_fx) = self.project_common(project_scalar)?; let projected_trace = project_trace_coeffs_column_major(self.trace, &field_cfg); @@ -393,6 +546,25 @@ impl_with_type_bounds!(ProverBase }) } + pub fn step1_mle_first_with_field_cfg< + S: Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + >( + mut self, + project_scalar: S, + field_cfg: F::Config, + ) -> Result, ProtocolError> { + let (field_cfg, projected_scalars_fx) = + self.project_common_with_field_cfg(project_scalar, field_cfg)?; + + let projected_trace = project_trace_coeffs_column_major(self.trace, &field_cfg); + Ok(ProverProjectedMleFirst { + base: self, + field_cfg, + projected_trace, + projected_scalars_fx, + }) + } + /// Step 1 (hybrid): Prime projection that produces **both** layouts. 
/// Used when the UAIR has a mix of linear and non-linear constraints, /// so the ideal-check can route them through their respective fast/slow @@ -401,7 +573,7 @@ impl_with_type_bounds!(ProverBase pub fn step1_hybrid DynamicPolynomialF + Sync>( mut self, project_scalar: S, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let (field_cfg, projected_scalars_fx) = self.project_common(project_scalar)?; let row_major_trace = project_trace_coeffs_row_major(self.trace, &field_cfg); @@ -414,6 +586,27 @@ impl_with_type_bounds!(ProverBase projected_scalars_fx, }) } + + pub fn step1_hybrid_with_field_cfg< + S: Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + >( + mut self, + project_scalar: S, + field_cfg: F::Config, + ) -> Result, ProtocolError> { + let (field_cfg, projected_scalars_fx) = + self.project_common_with_field_cfg(project_scalar, field_cfg)?; + + let row_major_trace = project_trace_coeffs_row_major(self.trace, &field_cfg); + let column_major_trace = project_trace_coeffs_column_major(self.trace, &field_cfg); + Ok(ProverProjectedHybrid { + base: self, + field_cfg, + row_major_trace, + column_major_trace, + projected_scalars_fx, + }) + } }); impl_with_type_bounds!(ProverProjectedCombined @@ -422,7 +615,7 @@ impl_with_type_bounds!(ProverProjectedCombined /// trace. Works for both linear and non-linear constraints. pub fn step2_ideal_check( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); let (ic_proof, ic_prover_state) = U::prove_combined( @@ -451,7 +644,7 @@ impl_with_type_bounds!(ProverProjectedMleFirst /// trace. Only suitable for linear constraints. pub fn step2_ideal_check( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); let (ic_proof, ic_prover_state) = U::prove_linear( @@ -483,7 +676,7 @@ impl_with_type_bounds!(ProverProjectedHybrid /// linear and non-linear constraints. 
pub fn step2_ideal_check( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); let (ic_proof, ic_prover_state) = U::prove_hybrid( @@ -516,7 +709,7 @@ impl_with_type_bounds!(ProverIdealChecked /// `a in F_q`, evaluates polynomials at `X = a`. pub fn step3_eval_projection( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let projecting_element: Zt::Chal = self.base.pcs_transcript.fs_transcript.get_challenge(); let projecting_element_f: F = F::from_with_cfg(&projecting_element, &self.field_cfg); @@ -551,7 +744,7 @@ impl_with_type_bounds!(ProverEvalProjected /// Produces `up_evals` and `down_evals` (CPR) and lookup auxiliary witnesses at `r*`. pub fn step4_sumcheck( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); // Sumcheck protocol degree must accommodate the actual fold // polynomial's per-variable degree, including `assert_zero` @@ -778,7 +971,7 @@ impl_with_type_bounds!(ProverSumchecked /// to the source's `lifted_eval` (free arithmetic in F_q[X]). 
pub fn step5_multipoint_eval( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let sig = &self.base.uair_signature; // Materialize the bit-op virtual MLEs (under ψ_α) — same shape @@ -804,7 +997,7 @@ impl_with_type_bounds!(ProverSumchecked let mut up_evals = self.cpr_proof.up_evals.clone(); up_evals.extend(self.cpr_proof.bit_op_down_evals.iter().cloned()); - let (mp_proof, mp_prover_state) = MultipointEval::prove_as_subprotocol( + let (mp_proof, r_0) = prove_multipoint_reduction( &mut self.base.pcs_transcript.fs_transcript, &sources, &self.cpr_eval_point, @@ -823,7 +1016,7 @@ impl_with_type_bounds!(ProverSumchecked combined_sumcheck: self.combined_sumcheck, lookup_proof: self.lookup_proof, mp_proof, - r_0: mp_prover_state.eval_point, + r_0, }) } }); @@ -834,7 +1027,7 @@ impl_with_type_bounds!(ProverMultipointEvaled /// evaluations at `r_0` in `F_q[X]` and absorbs them into the transcript. pub fn step6_lift_and_project( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { // Compute per-column polynomial MLE evaluations at r_0 in F_q[X] // (after \phi_q but before \psi_a). The verifier derives the scalar // open_evals via \psi_a for the sumcheck consistency check, and @@ -873,39 +1066,33 @@ impl_with_type_bounds!(ProverLifted /// Step 7: PCS open at `r_0` (witness columns only). 
pub fn step7_pcs_open( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let witness_trace = self.base.trace.witness(&self.base.uair_signature); - if let Some(hint_bin) = &self.base.hint_bin { - let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( - &mut self.base.pcs_transcript, - self.base.pp_bin, - &witness_trace.binary_poly, - &self.r_0, - hint_bin, - &self.field_cfg, - )?; - } - if let Some(hint_arb) = &self.base.hint_arb { - let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( - &mut self.base.pcs_transcript, - self.base.pp_arb, - &witness_trace.arbitrary_poly, - &self.r_0, - hint_arb, - &self.field_cfg, - )?; - } - if let Some(hint_int) = &self.base.hint_int { - let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( - &mut self.base.pcs_transcript, - self.base.pp_int, - &witness_trace.int, - &self.r_0, - hint_int, - &self.field_cfg, - )?; - } + P::BinaryPCS::prove_open::( + &mut self.base.pcs_transcript, + &self.base.pcs_params.binary, + &witness_trace.binary_poly, + &self.r_0, + &self.base.pcs_data.binary, + &self.field_cfg, + )?; + P::ArbitraryPCS::prove_open::( + &mut self.base.pcs_transcript, + &self.base.pcs_params.arbitrary, + &witness_trace.arbitrary_poly, + &self.r_0, + &self.base.pcs_data.arbitrary, + &self.field_cfg, + )?; + P::IntPCS::prove_open::( + &mut self.base.pcs_transcript, + &self.base.pcs_params.int, + &witness_trace.int, + &self.r_0, + &self.base.pcs_data.int, + &self.field_cfg, + )?; Ok(ProverPcsOpened { base: self.base, @@ -922,14 +1109,10 @@ impl_with_type_bounds!(ProverLifted impl_with_type_bounds!(ProverPcsOpened { /// Assemble the final proof from accumulated state. 
- pub fn finish(self) -> Result, ProtocolError> { + pub fn finish(self) -> Result>, ProtocolError> { let sig = self.base.uair_signature; let zip_proof = self.base.pcs_transcript.stream.into_inner(); - let commitments = ( - self.base.commitment_bin, - self.base.commitment_arb, - self.base.commitment_int, - ); + let commitments = self.base.pcs_commitments; let lifted_evals = self.lifted_evals; @@ -975,6 +1158,8 @@ where Zt::Int: ProjectableToField, ::Eval: ProjectableToField, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Zt::Int> + for<'a> FromWithConfig<&'a ::CombR> @@ -1024,7 +1209,60 @@ where num_vars: usize, project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, ) -> Result, ProtocolError> { - let committed = Self::step0_commit(pp, trace, num_vars)?; + let pcs_params = PCSParams:: { + binary: pp.0.clone(), + arbitrary: pp.1.clone(), + int: pp.2.clone(), + }; + let proof = Self::prove_with_pcs::( + &pcs_params, + trace, + num_vars, + project_scalar, + )?; + let commitments = proof.commitments; + Ok(Proof { + commitments: (commitments.binary, commitments.arbitrary, commitments.int), + zip: proof.zip, + ideal_check: proof.ideal_check, + resolver: proof.resolver, + combined_sumcheck: proof.combined_sumcheck, + multipoint_eval: proof.multipoint_eval, + witness_lifted_evals: proof.witness_lifted_evals, + lookup_proof: proof.lookup_proof, + }) + } + + pub fn prove_with_pcs( + pp: &PCSParams, + trace: &UairTrace<'static, Zt::Int, Zt::Int, D>, + num_vars: usize, + project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + ) -> Result>, ProtocolError> + where + P: ZincPCSTypes, + { + let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); + Self::prove_with_pcs_and_field_cfg::( + pp, + trace, + num_vars, + project_scalar, + field_cfg, + ) + } + + pub fn prove_with_pcs_and_field_cfg( + pp: &PCSParams, + trace: &UairTrace<'static, Zt::Int, 
Zt::Int, D>, + num_vars: usize, + project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + field_cfg: F::Config, + ) -> Result>, ProtocolError> + where + P: ZincPCSTypes, + { + let committed = Self::step0_commit_with_pcs::

(pp, trace, num_vars)?; let ideal_checked = if MLE_FIRST { // Classify constraints by degree, ignoring zero-ideal (their @@ -1039,22 +1277,26 @@ where if i.is_zero_ideal() { continue; } - if *m { any_linear = true } else { any_nonlinear = true } + if *m { + any_linear = true + } else { + any_nonlinear = true + } } match (any_linear, any_nonlinear) { (true, false) => committed - .step1_mle_first(project_scalar)? + .step1_mle_first_with_field_cfg(project_scalar, field_cfg)? .step2_ideal_check()?, (false, _) => committed - .step1_combined(project_scalar)? + .step1_combined_with_field_cfg(project_scalar, field_cfg)? .step2_ideal_check()?, (true, true) => committed - .step1_hybrid(project_scalar)? + .step1_hybrid_with_field_cfg(project_scalar, field_cfg)? .step2_ideal_check()?, } } else { committed - .step1_combined(project_scalar)? + .step1_combined_with_field_cfg(project_scalar, field_cfg)? .step2_ideal_check()? }; @@ -1127,6 +1369,8 @@ where BinaryPoly: ProjectableToField, U: Uair + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a ZtF::Int> + for<'a> FromWithConfig<&'a ::CombR> @@ -1266,9 +1510,8 @@ where let projected_trace_f = evaluate_trace_to_column_mles_fast(trace, &projecting_element_f, &field_cfg); - let projected_scalars_f = - project_scalars_to_field(projected_scalars_fx, &projecting_element_f) - .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; + let projected_scalars_f = project_scalars_to_field(projected_scalars_fx, &projecting_element_f) + .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; // ── Step 4: CPR + booleanity multi-degree sumcheck ────────────────── let max_degree = count_max_degree::(); @@ -1312,14 +1555,10 @@ where let virtual_mles = if virtual_specs.is_empty() { Vec::new() } else { - let self_bit_slices = compute_bit_slices_flat::( - &trace.binary_poly[num_pub_bin..], - &field_cfg, - ); - let public_bit_slices = 
compute_bit_slices_flat::( - &trace.binary_poly[..num_pub_bin], - &field_cfg, - ); + let self_bit_slices = + compute_bit_slices_flat::(&trace.binary_poly[num_pub_bin..], &field_cfg); + let public_bit_slices = + compute_bit_slices_flat::(&trace.binary_poly[..num_pub_bin], &field_cfg); let int_witness_cols: Vec<_> = (0..num_wit_int) .map(|i| projected_trace_f[int_offset + num_pub_int + i].clone()) .collect(); @@ -1422,7 +1661,7 @@ where sources.extend(bit_op_mles); let mut up_evals_with_bit_op = cpr_proof.up_evals.clone(); up_evals_with_bit_op.extend(cpr_proof.bit_op_down_evals.iter().cloned()); - let (mp_proof, mp_prover_state) = MultipointEval::prove_as_subprotocol( + let (mp_proof, r_0) = prove_multipoint_reduction( &mut pcs_transcript.fs_transcript, &sources, &cpr_eval_point, @@ -1431,7 +1670,6 @@ where uair_signature.shifts(), &field_cfg, )?; - let r_0 = mp_prover_state.eval_point; // ── Step 6: Lift-and-project + sample γ for folding ───────────────── let lifted_evals = @@ -1464,15 +1702,14 @@ where )?; } if let Some(hint_arb) = &hint_arb { - let _ = - ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( - &mut pcs_transcript, - pp_arb, - &witness_trace.arbitrary_poly, - &r_0, - hint_arb, - &field_cfg, - )?; + let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + &mut pcs_transcript, + pp_arb, + &witness_trace.arbitrary_poly, + &r_0, + hint_arb, + &field_cfg, + )?; } if let Some(hint_int) = &hint_int { let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( @@ -1533,8 +1770,6 @@ where // from `witness_lifted_evals` coefficient eighths (see `verify_folded_4x`). // - - /// Per-domain (binary / arbitrary / integer) byte breakdown of the /// PCS bytes written during step 7 of [`prove_folded_4x`]. 
Each domain /// holds its own [`ZipPlusProveByteBreakdown`] (sums of the four @@ -1548,7 +1783,6 @@ pub struct FoldedProveZipBreakdown { pub int: ZipPlusProveByteBreakdown, } - /// Per-region wall-time breakdown of a single [`prove_folded_4x`] run, /// populated by [`prove_folded_4x_with_timings`]. Useful as a /// criterion-bypassing diagnostic — each step's `Duration` is measured @@ -1613,6 +1847,29 @@ impl FoldedProveTimings { } } +#[tracing::instrument( + target = "zinc_protocol::prover", + level = "info", + skip_all, + fields( + side = "prove", + prover = "folded_4x", + phase = _phase, + num_vars = _num_vars, + mle_first = _mle_first, + check_for_overflow = _check_for_overflow + ) +)] +fn trace_folded_4x_prover_phase( + _phase: &'static str, + _num_vars: usize, + _mle_first: bool, + _check_for_overflow: bool, + run: impl FnOnce() -> T, +) -> T { + run() +} + pub fn prove_folded_4x< ZtF, U, @@ -1643,6 +1900,8 @@ where BinaryPoly: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Int> + for<'a> FromWithConfig<&'a Int> @@ -1708,6 +1967,8 @@ where BinaryPoly: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Int> + for<'a> FromWithConfig<&'a Int> @@ -1737,9 +1998,19 @@ where INT_QUARTER_LIMBS, MLE_FIRST, CHECK_FOR_OVERFLOW, - >(pp, trace, num_vars, project_scalar, Some(&mut timings), None)?; + >( + pp, + trace, + num_vars, + project_scalar, + Some(&mut timings), + None, + )?; - let (_compressed, dt) = zip_plus::utils::serialize_and_compress(&proof); + let (_compressed, dt) = + trace_folded_4x_prover_phase("compress", num_vars, MLE_FIRST, CHECK_FOR_OVERFLOW, || { + zip_plus::utils::serialize_and_compress(&proof) + }); timings.step8_compress = dt; Ok((proof, timings)) @@ -1779,6 +2050,8 @@ where BinaryPoly: ProjectableToField, U: 
Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Int> + for<'a> FromWithConfig<&'a Int> @@ -1808,7 +2081,14 @@ where INT_QUARTER_LIMBS, MLE_FIRST, CHECK_FOR_OVERFLOW, - >(pp, trace, num_vars, project_scalar, None, Some(&mut breakdown))?; + >( + pp, + trace, + num_vars, + project_scalar, + None, + Some(&mut breakdown), + )?; Ok((proof, breakdown)) } @@ -1846,6 +2126,8 @@ where BinaryPoly: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Int> + for<'a> FromWithConfig<&'a Int> @@ -1870,27 +2152,9 @@ where // ── Step 0: Commit (twice-split binary, arb, quartered int) ───────── let _t_step0 = std::time::Instant::now(); - let split1: Vec>> = - zip_plus::pcs::folding::split_columns::(&witness_trace.binary_poly); - let split_binary_witness: Vec>> = - zip_plus::pcs::folding::split_columns::(&split1); - drop(split1); - let split_int_witness: Vec>> = - zip_plus::pcs::folding::split_int_columns_4x::( - &witness_trace.int, - ); - - // Shared-Merkle dispatch (same criterion as the 1× int-fold path): - // ≥2 non-empty batches AND arb is empty / has matching codeword. 
- let arb_compatible = witness_trace.arbitrary_poly.is_empty() - || pp_arb.linear_code.codeword_len() == pp_bin_split2.linear_code.codeword_len(); - let bin_nonempty = !split_binary_witness.is_empty(); - let int_nonempty = !split_int_witness.is_empty(); - let arb_nonempty = !witness_trace.arbitrary_poly.is_empty(); - let nonempty_count = (bin_nonempty as u8) + (arb_nonempty as u8) + (int_nonempty as u8); - let use_multi = nonempty_count >= 2 && arb_compatible; - let ( + split_binary_witness, + split_int_witness, hint_bin_split, hint_arb, hint_int_split, @@ -1898,104 +2162,179 @@ where commitment_bin, commitment_arb, commitment_int, - ) = if use_multi { - let (multi, comm_bin, comm_arb, comm_int) = MultiZip3::< - ZtF::BinaryZt, - ZtF::ArbitraryZt, - ZtF::IntZt, - ZtF::BinaryLc, - ZtF::ArbitraryLc, - ZtF::IntLc, - >::commit( - pp_bin_split2, - pp_arb, - pp_int_split4, - &split_binary_witness, - &witness_trace.arbitrary_poly, - &split_int_witness, - )?; - (None, None, None, Some(multi), comm_bin, comm_arb, comm_int) - } else { - let (res_bin, (res_arb, res_int)) = cfg_join!( - commit_optionally(pp_bin_split2, &split_binary_witness), - commit_optionally(pp_arb, &witness_trace.arbitrary_poly), - commit_optionally(pp_int_split4, &split_int_witness), - ); - let (hb, cb) = res_bin?; - let (ha, ca) = res_arb?; - let (hi, ci) = res_int?; - (hb, ha, hi, None, cb, ca, ci) - }; + mut pcs_transcript, + ) = trace_folded_4x_prover_phase("commit", num_vars, MLE_FIRST, CHECK_FOR_OVERFLOW, || { + let split1: Vec>> = + zip_plus::pcs::folding::split_columns::(&witness_trace.binary_poly); + let split_binary_witness: Vec>> = + zip_plus::pcs::folding::split_columns::(&split1); + drop(split1); + let split_int_witness: Vec>> = + zip_plus::pcs::folding::split_int_columns_4x::( + &witness_trace.int, + ); - let mut pcs_transcript = PcsProverTranscript::new_from_commitments( - [&commitment_bin, &commitment_arb, &commitment_int].into_iter(), - ); + // Shared-Merkle dispatch (same criterion as 
the 1× int-fold path): + // ≥2 non-empty batches AND arb is empty / has matching codeword. + let arb_compatible = witness_trace.arbitrary_poly.is_empty() + || pp_arb.linear_code.codeword_len() == pp_bin_split2.linear_code.codeword_len(); + let bin_nonempty = !split_binary_witness.is_empty(); + let int_nonempty = !split_int_witness.is_empty(); + let arb_nonempty = !witness_trace.arbitrary_poly.is_empty(); + let nonempty_count = (bin_nonempty as u8) + (arb_nonempty as u8) + (int_nonempty as u8); + let use_multi = nonempty_count >= 2 && arb_compatible; + + let ( + hint_bin_split, + hint_arb, + hint_int_split, + multi_hint, + commitment_bin, + commitment_arb, + commitment_int, + ) = if use_multi { + let (multi, comm_bin, comm_arb, comm_int) = MultiZip3::< + ZtF::BinaryZt, + ZtF::ArbitraryZt, + ZtF::IntZt, + ZtF::BinaryLc, + ZtF::ArbitraryLc, + ZtF::IntLc, + >::commit( + pp_bin_split2, + pp_arb, + pp_int_split4, + &split_binary_witness, + &witness_trace.arbitrary_poly, + &split_int_witness, + )?; + (None, None, None, Some(multi), comm_bin, comm_arb, comm_int) + } else { + let (res_bin, (res_arb, res_int)) = cfg_join!( + commit_optionally(pp_bin_split2, &split_binary_witness), + commit_optionally(pp_arb, &witness_trace.arbitrary_poly), + commit_optionally(pp_int_split4, &split_int_witness), + ); + let (hb, cb) = res_bin?; + let (ha, ca) = res_arb?; + let (hi, ci) = res_int?; + (hb, ha, hi, None, cb, ca, ci) + }; - absorb_public_columns(&mut pcs_transcript.fs_transcript, &public_trace.binary_poly); - absorb_public_columns( - &mut pcs_transcript.fs_transcript, - &public_trace.arbitrary_poly, - ); - absorb_public_columns(&mut pcs_transcript.fs_transcript, &public_trace.int); - if let Some(t) = timings.as_mut() { - t.step0_commit = _t_step0.elapsed(); - } + let mut pcs_transcript = PcsProverTranscript::new_from_commitments( + [&commitment_bin, &commitment_arb, &commitment_int].into_iter(), + ); - // ── Step 1: Prime projection ──────────────────────────────────────── - let 
_t_step1 = std::time::Instant::now(); - let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); - let projected_scalars_fx = project_scalars::(|s| project_scalar(s, &field_cfg)); - if let Some(t) = timings.as_mut() { - t.step1_prime_projection = _t_step1.elapsed(); - } + absorb_public_columns(&mut pcs_transcript.fs_transcript, &public_trace.binary_poly); + absorb_public_columns( + &mut pcs_transcript.fs_transcript, + &public_trace.arbitrary_poly, + ); + absorb_public_columns(&mut pcs_transcript.fs_transcript, &public_trace.int); + + Ok::<_, ProtocolError>(( + split_binary_witness, + split_int_witness, + hint_bin_split, + hint_arb, + hint_int_split, + multi_hint, + commitment_bin, + commitment_arb, + commitment_int, + pcs_transcript, + )) + })?; + if let Some(t) = timings.as_mut() { + t.step0_commit = _t_step0.elapsed(); + } + + // ── Step 1: Prime projection ──────────────────────────────────────── + let _t_step1 = std::time::Instant::now(); + let (field_cfg, projected_scalars_fx) = trace_folded_4x_prover_phase( + "prime_projection", + num_vars, + MLE_FIRST, + CHECK_FOR_OVERFLOW, + || { + let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); + let projected_scalars_fx = project_scalars::(|s| project_scalar(s, &field_cfg)); + (field_cfg, projected_scalars_fx) + }, + ); + if let Some(t) = timings.as_mut() { + t.step1_prime_projection = _t_step1.elapsed(); + } // ── Step 2: Ideal check ───────────────────────────────────────────── let _t_step2 = std::time::Instant::now(); let num_constraints = count_constraints::(); - let (ic_proof, ic_prover_state, projected_trace) = if MLE_FIRST { - let mask = zinc_uair::degree_counter::linear_constraint_mask::(); - let ideals = zinc_uair::ideal_collector::collect_ideals::(num_constraints).ideals; - let (mut any_linear, mut any_nonlinear) = (false, false); - for (m, i) in mask.iter().zip(ideals.iter()) { - if i.is_zero_ideal() { - continue; - } - if *m { - any_linear = true + let (ic_proof, ic_prover_state, 
projected_trace) = trace_folded_4x_prover_phase( + "ideal_check", + num_vars, + MLE_FIRST, + CHECK_FOR_OVERFLOW, + || { + let out = if MLE_FIRST { + let mask = zinc_uair::degree_counter::linear_constraint_mask::(); + let ideals = + zinc_uair::ideal_collector::collect_ideals::(num_constraints).ideals; + let (mut any_linear, mut any_nonlinear) = (false, false); + for (m, i) in mask.iter().zip(ideals.iter()) { + if i.is_zero_ideal() { + continue; + } + if *m { + any_linear = true + } else { + any_nonlinear = true + } + } + match (any_linear, any_nonlinear) { + (true, false) => { + let projected_trace_cm = + project_trace_coeffs_column_major(trace, &field_cfg); + let (p, s) = U::prove_linear( + &mut pcs_transcript.fs_transcript, + &projected_trace_cm, + &projected_scalars_fx, + num_constraints, + num_vars, + &field_cfg, + )?; + (p, s, ProjectedTrace::ColumnMajor(projected_trace_cm)) + } + (true, true) => { + let (rm, cm) = cfg_join!( + project_trace_coeffs_row_major::(trace, &field_cfg), + project_trace_coeffs_column_major(trace, &field_cfg), + ); + let (p, s) = U::prove_hybrid( + &mut pcs_transcript.fs_transcript, + &rm, + &cm, + &projected_scalars_fx, + num_constraints, + num_vars, + &field_cfg, + )?; + (p, s, ProjectedTrace::RowMajor(rm)) + } + (false, _) => { + let projected_trace_rm = + project_trace_coeffs_row_major::(trace, &field_cfg); + let (p, s) = U::prove_combined( + &mut pcs_transcript.fs_transcript, + &projected_trace_rm, + &projected_scalars_fx, + num_constraints, + num_vars, + &field_cfg, + )?; + (p, s, ProjectedTrace::RowMajor(projected_trace_rm)) + } + } } else { - any_nonlinear = true - } - } - match (any_linear, any_nonlinear) { - (true, false) => { - let projected_trace_cm = project_trace_coeffs_column_major(trace, &field_cfg); - let (p, s) = U::prove_linear( - &mut pcs_transcript.fs_transcript, - &projected_trace_cm, - &projected_scalars_fx, - num_constraints, - num_vars, - &field_cfg, - )?; - (p, s, 
ProjectedTrace::ColumnMajor(projected_trace_cm)) - } - (true, true) => { - let (rm, cm) = cfg_join!( - project_trace_coeffs_row_major::(trace, &field_cfg), - project_trace_coeffs_column_major(trace, &field_cfg), - ); - let (p, s) = U::prove_hybrid( - &mut pcs_transcript.fs_transcript, - &rm, - &cm, - &projected_scalars_fx, - num_constraints, - num_vars, - &field_cfg, - )?; - (p, s, ProjectedTrace::RowMajor(rm)) - } - (false, _) => { let projected_trace_rm = project_trace_coeffs_row_major::(trace, &field_cfg); let (p, s) = U::prove_combined( @@ -2007,20 +2346,10 @@ where &field_cfg, )?; (p, s, ProjectedTrace::RowMajor(projected_trace_rm)) - } - } - } else { - let projected_trace_rm = project_trace_coeffs_row_major::(trace, &field_cfg); - let (p, s) = U::prove_combined( - &mut pcs_transcript.fs_transcript, - &projected_trace_rm, - &projected_scalars_fx, - num_constraints, - num_vars, - &field_cfg, - )?; - (p, s, ProjectedTrace::RowMajor(projected_trace_rm)) - }; + }; + Ok::<_, ProtocolError>(out) + }, + )?; let ic_eval_point = ic_prover_state.evaluation_point; if let Some(t) = timings.as_mut() { t.step2_ideal_check = _t_step2.elapsed(); @@ -2028,166 +2357,211 @@ where // ── Step 3: Eval projection (ψ_a) ─────────────────────────────────── let _t_step3 = std::time::Instant::now(); - let projecting_element: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); - let projecting_element_f: F = F::from_with_cfg(&projecting_element, &field_cfg); - - let projected_trace_f = - evaluate_trace_to_column_mles_fast(trace, &projecting_element_f, &field_cfg); - let projected_scalars_f = - project_scalars_to_field(projected_scalars_fx, &projecting_element_f) - .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; + let (projecting_element_f, projected_trace_f, projected_scalars_f) = + trace_folded_4x_prover_phase( + "eval_projection", + num_vars, + MLE_FIRST, + CHECK_FOR_OVERFLOW, + || { + let projecting_element: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); 
+ let projecting_element_f: F = F::from_with_cfg(&projecting_element, &field_cfg); + + let projected_trace_f = + evaluate_trace_to_column_mles_fast(trace, &projecting_element_f, &field_cfg); + let projected_scalars_f = + project_scalars_to_field(projected_scalars_fx, &projecting_element_f) + .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; + + Ok::<_, ProtocolError>(( + projecting_element_f, + projected_trace_f, + projected_scalars_f, + )) + }, + )?; if let Some(t) = timings.as_mut() { t.step3_eval_projection = _t_step3.elapsed(); } // ── Step 4: CPR + booleanity multi-degree sumcheck ────────────────── let _t_step4 = std::time::Instant::now(); - let max_degree = count_max_degree::(); - let (cpr_group, cpr_ancillary) = CombinedPolyResolver::prepare_sumcheck_group::( - &mut pcs_transcript.fs_transcript, - projected_trace_f.clone(), - &ic_eval_point, - &projected_scalars_f, - num_constraints, - num_vars, - max_degree, - &field_cfg, - &trace.binary_poly, - &projecting_element_f, - )?; + let (cpr_proof, cpr_eval_point, combined_sumcheck, lookup_proof) = + trace_folded_4x_prover_phase("sumcheck", num_vars, MLE_FIRST, CHECK_FOR_OVERFLOW, || { + let max_degree = count_max_degree::(); + let (cpr_group, cpr_ancillary) = CombinedPolyResolver::prepare_sumcheck_group::( + &mut pcs_transcript.fs_transcript, + projected_trace_f.clone(), + &ic_eval_point, + &projected_scalars_f, + num_constraints, + num_vars, + max_degree, + &field_cfg, + &trace.binary_poly, + &projecting_element_f, + )?; - let num_pub_bin = uair_signature.public_cols().num_binary_poly_cols(); - let num_pub_int = uair_signature.public_cols().num_int_cols(); - let num_wit_int = uair_signature.witness_cols().num_int_cols(); - let int_offset = trace.binary_poly.len() + trace.arbitrary_poly.len(); - let int_bit_cols: Vec<_> = uair_signature - .int_witness_bit_cols() - .iter() - .map(|&idx| projected_trace_f[int_offset + idx].clone()) - .collect(); - let shifted_bit_slice_mles = 
build_shifted_bit_slice_mles::( - &trace.binary_poly[num_pub_bin..], - uair_signature.shifted_bit_slice_specs(), - &field_cfg, - ); - let virtual_specs = uair_signature.virtual_booleanity_cols(); - let virtual_mles = if virtual_specs.is_empty() { - Vec::new() - } else { - let self_bit_slices = compute_bit_slices_flat::( - &trace.binary_poly[num_pub_bin..], - &field_cfg, - ); - let public_bit_slices = compute_bit_slices_flat::( - &trace.binary_poly[..num_pub_bin], - &field_cfg, - ); - let int_witness_cols: Vec<_> = (0..num_wit_int) - .map(|i| projected_trace_f[int_offset + num_pub_int + i].clone()) - .collect(); - build_virtual_booleanity_mles::( - &self_bit_slices, - &shifted_bit_slice_mles, - &public_bit_slices, - &int_witness_cols, - virtual_specs, - &field_cfg, - ) - }; - let mut extra_bit_cols = int_bit_cols; - extra_bit_cols.extend(virtual_mles); - let virtual_bp_specs = uair_signature.virtual_binary_poly_cols(); - let virtual_binary_mles = build_virtual_binary_poly_mles::( - &trace.binary_poly[num_pub_bin..], - &trace.binary_poly[..num_pub_bin], - uair_signature.shifted_bit_slice_specs(), - virtual_bp_specs, - ); - let kept_witness_binary = filter_booleanity_witness( - &trace.binary_poly[num_pub_bin..], - uair_signature.booleanity_skip_indices(), - ); - let booleanity_binary_cols: Vec<_> = kept_witness_binary - .iter() - .chain(virtual_binary_mles.iter()) - .cloned() - .collect(); - let bool_prep = prepare_booleanity_group::( - &mut pcs_transcript.fs_transcript, - &booleanity_binary_cols, - &extra_bit_cols, - &ic_eval_point, - &field_cfg, - ) - .map_err(ProtocolError::Booleanity)?; + let num_pub_bin = uair_signature.public_cols().num_binary_poly_cols(); + let num_pub_int = uair_signature.public_cols().num_int_cols(); + let num_wit_int = uair_signature.witness_cols().num_int_cols(); + let int_offset = trace.binary_poly.len() + trace.arbitrary_poly.len(); + let int_bit_cols: Vec<_> = uair_signature + .int_witness_bit_cols() + .iter() + .map(|&idx| 
projected_trace_f[int_offset + idx].clone()) + .collect(); + let virtual_specs = uair_signature.virtual_booleanity_cols(); + let shifted_bit_slice_mles = if virtual_specs.is_empty() { + Vec::new() + } else { + build_shifted_bit_slice_mles::( + &trace.binary_poly[num_pub_bin..], + uair_signature.shifted_bit_slice_specs(), + &field_cfg, + ) + }; + let virtual_mles = if virtual_specs.is_empty() { + Vec::new() + } else { + let self_bit_slices = + compute_bit_slices_flat::(&trace.binary_poly[num_pub_bin..], &field_cfg); + let public_bit_slices = + compute_bit_slices_flat::(&trace.binary_poly[..num_pub_bin], &field_cfg); + let int_witness_cols: Vec<_> = (0..num_wit_int) + .map(|i| projected_trace_f[int_offset + num_pub_int + i].clone()) + .collect(); + build_virtual_booleanity_mles::( + &self_bit_slices, + &shifted_bit_slice_mles, + &public_bit_slices, + &int_witness_cols, + virtual_specs, + &field_cfg, + ) + }; + let mut extra_bit_cols = int_bit_cols; + extra_bit_cols.extend(virtual_mles); + let virtual_bp_specs = uair_signature.virtual_binary_poly_cols(); + let virtual_binary_mles = build_virtual_binary_poly_mles::( + &trace.binary_poly[num_pub_bin..], + &trace.binary_poly[..num_pub_bin], + uair_signature.shifted_bit_slice_specs(), + virtual_bp_specs, + ); + let kept_witness_binary = filter_booleanity_witness( + &trace.binary_poly[num_pub_bin..], + uair_signature.booleanity_skip_indices(), + ); + let booleanity_binary_cols: Vec<_> = kept_witness_binary + .iter() + .chain(virtual_binary_mles.iter()) + .cloned() + .collect(); + let bool_prep = prepare_booleanity_group::( + &mut pcs_transcript.fs_transcript, + &booleanity_binary_cols, + &extra_bit_cols, + &ic_eval_point, + &field_cfg, + ) + .map_err(ProtocolError::Booleanity)?; - let mut groups = vec![cpr_group]; - let mut bool_ancillary_opt = None; - if let Some((bg, ba)) = bool_prep { - groups.push(bg); - bool_ancillary_opt = Some(ba); - } + let mut groups = vec![cpr_group]; + let mut bool_ancillary_opt = None; + if let 
Some((bg, ba)) = bool_prep { + groups.push(bg); + bool_ancillary_opt = Some(ba); + } - let (combined_sumcheck, mut md_states) = MultiDegreeSumcheck::prove_as_subprotocol( - &mut pcs_transcript.fs_transcript, - groups, - num_vars, - &field_cfg, - ); - let cpr_state = md_states.remove(0); - let (mut cpr_proof, cpr_prover_state) = CombinedPolyResolver::finalize_prover( - &mut pcs_transcript.fs_transcript, - cpr_state, - cpr_ancillary, - &field_cfg, - )?; - let shifted_bit_slice_evals: Vec = shifted_bit_slice_mles - .into_iter() - .map(|mle| mle.evaluate_with_config(&cpr_prover_state.evaluation_point, &field_cfg)) - .collect::, _>>() - .map_err(ProtocolError::ShiftedBitSliceEval)?; - cpr_proof.shifted_bit_slice_evals = shifted_bit_slice_evals; - if let Some(ba) = bool_ancillary_opt { - let bool_state = md_states.remove(0); - let bit_slice_evals = finalize_booleanity_prover( - &mut pcs_transcript.fs_transcript, - bool_state, - ba, - &field_cfg, - ) - .map_err(ProtocolError::Booleanity)?; - cpr_proof.bit_slice_evals = bit_slice_evals; - } - let lookup_proof: Option> = None; + let (combined_sumcheck, mut md_states) = MultiDegreeSumcheck::prove_as_subprotocol( + &mut pcs_transcript.fs_transcript, + groups, + num_vars, + &field_cfg, + ); + let cpr_state = md_states.remove(0); + let (mut cpr_proof, cpr_prover_state) = CombinedPolyResolver::finalize_prover( + &mut pcs_transcript.fs_transcript, + cpr_state, + cpr_ancillary, + &field_cfg, + )?; + let shifted_bit_slice_evals: Vec = if shifted_bit_slice_mles.is_empty() { + compute_shifted_bit_slice_evals_streaming::( + &trace.binary_poly[num_pub_bin..], + uair_signature.shifted_bit_slice_specs(), + &cpr_prover_state.evaluation_point, + &field_cfg, + ) + .map_err(|e| ProtocolError::Booleanity(e.into()))? + } else { + shifted_bit_slice_mles + .into_iter() + .map(|mle| { + mle.evaluate_with_config(&cpr_prover_state.evaluation_point, &field_cfg) + }) + .collect::, _>>() + .map_err(ProtocolError::ShiftedBitSliceEval)? 
+ }; + cpr_proof.shifted_bit_slice_evals = shifted_bit_slice_evals; + if let Some(ba) = bool_ancillary_opt { + let bool_state = md_states.remove(0); + let bit_slice_evals = finalize_booleanity_prover( + &mut pcs_transcript.fs_transcript, + bool_state, + ba, + &field_cfg, + ) + .map_err(ProtocolError::Booleanity)?; + cpr_proof.bit_slice_evals = bit_slice_evals; + } + let lookup_proof: Option> = None; + + Ok::<_, ProtocolError>(( + cpr_proof, + cpr_prover_state.evaluation_point, + combined_sumcheck, + lookup_proof, + )) + })?; if let Some(t) = timings.as_mut() { t.step4_sumcheck = _t_step4.elapsed(); } // ── Step 5: Multi-point evaluation sumcheck ───────────────────────── let _t_step5 = std::time::Instant::now(); - let cpr_eval_point = cpr_prover_state.evaluation_point.clone(); - let bit_op_mles = zinc_piop::combined_poly_resolver::build_bit_op_mles::( - &trace.binary_poly, - uair_signature.bit_op_specs(), - uair_signature.total_cols().num_binary_poly_cols(), - &projecting_element_f, + let (mp_proof, r_0) = trace_folded_4x_prover_phase( + "multipoint_eval", num_vars, - &field_cfg, - ); - let mut sources = projected_trace_f.clone(); - sources.extend(bit_op_mles); - let mut up_evals_with_bit_op = cpr_proof.up_evals.clone(); - up_evals_with_bit_op.extend(cpr_proof.bit_op_down_evals.iter().cloned()); - let (mp_proof, mp_prover_state) = MultipointEval::prove_as_subprotocol( - &mut pcs_transcript.fs_transcript, - &sources, - &cpr_eval_point, - &up_evals_with_bit_op, - &cpr_proof.down_evals, - uair_signature.shifts(), - &field_cfg, + MLE_FIRST, + CHECK_FOR_OVERFLOW, + || { + let bit_op_mles = zinc_piop::combined_poly_resolver::build_bit_op_mles::( + &trace.binary_poly, + uair_signature.bit_op_specs(), + uair_signature.total_cols().num_binary_poly_cols(), + &projecting_element_f, + num_vars, + &field_cfg, + ); + let mut sources = projected_trace_f.clone(); + sources.extend(bit_op_mles); + let mut up_evals_with_bit_op = cpr_proof.up_evals.clone(); + 
up_evals_with_bit_op.extend(cpr_proof.bit_op_down_evals.iter().cloned()); + let (mp_proof, r_0) = prove_multipoint_reduction( + &mut pcs_transcript.fs_transcript, + &sources, + &cpr_eval_point, + &up_evals_with_bit_op, + &cpr_proof.down_evals, + uair_signature.shifts(), + &field_cfg, + )?; + + Ok::<_, ProtocolError>((mp_proof, r_0)) + }, )?; - let r_0 = mp_prover_state.eval_point; if let Some(t) = timings.as_mut() { t.step5_multipoint_eval = _t_step5.elapsed(); } @@ -2197,209 +2571,226 @@ where // the int section is replaced below with 4-coeff bar_us, so the // standard 1-coeff int compute would be wasted work. let _t_step6 = std::time::Instant::now(); - let total_cols = uair_signature.total_cols(); - let num_total_bin = total_cols.num_binary_poly_cols(); - let num_total_arb = total_cols.num_arbitrary_poly_cols(); - let mut lifted_evals = crate::compute_lifted_evals_capped::( - &r_0, - &trace.binary_poly, - &projected_trace, - &field_cfg, - Some(num_total_arb), + let (num_total_bin, num_total_arb, lifted_evals, gamma1, gamma2) = trace_folded_4x_prover_phase( + "lift_and_project", + num_vars, + MLE_FIRST, + CHECK_FOR_OVERFLOW, + || { + let total_cols = uair_signature.total_cols(); + let num_total_bin = total_cols.num_binary_poly_cols(); + let num_total_arb = total_cols.num_arbitrary_poly_cols(); + let mut lifted_evals = crate::compute_lifted_evals_capped::( + &r_0, + &trace.binary_poly, + &projected_trace, + &field_cfg, + Some(num_total_arb), + ); + let int_lifted_evals_4coeff: Vec> = + crate::compute_int_fold_4x_lifted_evals::( + &r_0, &trace.int, &field_cfg, + ); + // Append the 4-coeff int section. 
+ lifted_evals.extend(int_lifted_evals_4coeff); + let _int_section_offset = num_total_bin + num_total_arb; + + let mut transcription_buf: Vec = vec![0; F::Inner::NUM_BYTES]; + for bar_u in &lifted_evals { + pcs_transcript + .fs_transcript + .absorb_random_field_slice(&bar_u.coeffs, &mut transcription_buf); + } + let gamma1: F = { + let g_chal: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); + F::from_with_cfg(&g_chal, &field_cfg) + }; + let gamma2: F = { + let g_chal: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); + F::from_with_cfg(&g_chal, &field_cfg) + }; + + (num_total_bin, num_total_arb, lifted_evals, gamma1, gamma2) + }, ); - let int_lifted_evals_4coeff: Vec> = - crate::compute_int_fold_4x_lifted_evals::( - &r_0, - &trace.int, - &field_cfg, - ); - // Append the 4-coeff int section. - lifted_evals.extend(int_lifted_evals_4coeff); - let _int_section_offset = num_total_bin + num_total_arb; - - let mut transcription_buf: Vec = vec![0; F::Inner::NUM_BYTES]; - for bar_u in &lifted_evals { - pcs_transcript - .fs_transcript - .absorb_random_field_slice(&bar_u.coeffs, &mut transcription_buf); - } - let gamma1: F = { - let g_chal: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); - F::from_with_cfg(&g_chal, &field_cfg) - }; - let gamma2: F = { - let g_chal: ZtF::Chal = pcs_transcript.fs_transcript.get_challenge(); - F::from_with_cfg(&g_chal, &field_cfg) - }; if let Some(t) = timings.as_mut() { t.step6_lift_and_project = _t_step6.elapsed(); } // ── Step 7: PCS open ──────────────────────────────────────────────── let _t_step7 = std::time::Instant::now(); - let mut r0_ext = r_0.clone(); - r0_ext.push(gamma1); - r0_ext.push(gamma2); + trace_folded_4x_prover_phase("pcs_open", num_vars, MLE_FIRST, CHECK_FOR_OVERFLOW, || { + let mut r0_ext = r_0.clone(); + r0_ext.push(gamma1); + r0_ext.push(gamma2); - if let Some(multi) = &multi_hint { - if let Some(bd) = zip_breakdown.as_deref_mut() { - let _ = MultiZip3::< - ZtF::BinaryZt, - ZtF::ArbitraryZt, - 
ZtF::IntZt, - ZtF::BinaryLc, - ZtF::ArbitraryLc, - ZtF::IntLc, - >::prove_f_with_byte_breakdown::( - &mut pcs_transcript, - pp_bin_split2, - pp_arb, - pp_int_split4, - &split_binary_witness, - &witness_trace.arbitrary_poly, - &split_int_witness, - &r0_ext, - multi, - &field_cfg, - &mut bd.bin, - &mut bd.arb, - &mut bd.int, - )?; - } else { - let _ = MultiZip3::< - ZtF::BinaryZt, - ZtF::ArbitraryZt, - ZtF::IntZt, - ZtF::BinaryLc, - ZtF::ArbitraryLc, - ZtF::IntLc, - >::prove_f::( - &mut pcs_transcript, - pp_bin_split2, - pp_arb, - pp_int_split4, - &split_binary_witness, - &witness_trace.arbitrary_poly, - &split_int_witness, - &r0_ext, - multi, - &field_cfg, - )?; - } - } else { - if let Some(hint_bin) = &hint_bin_split { + if let Some(multi) = &multi_hint { if let Some(bd) = zip_breakdown.as_deref_mut() { - let _ = ZipPlus::::prove_f_with_byte_breakdown::< - _, - CHECK_FOR_OVERFLOW, - >( - &mut pcs_transcript, - pp_bin_split2, - &split_binary_witness, - &r0_ext, - hint_bin, - &field_cfg, - &mut bd.bin, - )?; - } else { - let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + let _ = MultiZip3::< + ZtF::BinaryZt, + ZtF::ArbitraryZt, + ZtF::IntZt, + ZtF::BinaryLc, + ZtF::ArbitraryLc, + ZtF::IntLc, + >::prove_f_with_byte_breakdown::( &mut pcs_transcript, pp_bin_split2, - &split_binary_witness, - &r0_ext, - hint_bin, - &field_cfg, - )?; - } - } - if let Some(hint_arb) = &hint_arb { - if let Some(bd) = zip_breakdown.as_deref_mut() { - let _ = ZipPlus::::prove_f_with_byte_breakdown::< - _, - CHECK_FOR_OVERFLOW, - >( - &mut pcs_transcript, - pp_arb, - &witness_trace.arbitrary_poly, - &r_0, - hint_arb, - &field_cfg, - &mut bd.arb, - )?; - } else { - let _ = ZipPlus::::prove_f::< - _, - CHECK_FOR_OVERFLOW, - >( - &mut pcs_transcript, pp_arb, - &witness_trace.arbitrary_poly, - &r_0, - hint_arb, - &field_cfg, - )?; - } - } - if let Some(hint_int) = &hint_int_split { - if let Some(bd) = zip_breakdown.as_deref_mut() { - let _ = ZipPlus::::prove_f_with_byte_breakdown::< - _, - 
CHECK_FOR_OVERFLOW, - >( - &mut pcs_transcript, pp_int_split4, + &split_binary_witness, + &witness_trace.arbitrary_poly, &split_int_witness, &r0_ext, - hint_int, + multi, &field_cfg, + &mut bd.bin, + &mut bd.arb, &mut bd.int, )?; } else { - let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + let _ = MultiZip3::< + ZtF::BinaryZt, + ZtF::ArbitraryZt, + ZtF::IntZt, + ZtF::BinaryLc, + ZtF::ArbitraryLc, + ZtF::IntLc, + >::prove_f::( &mut pcs_transcript, + pp_bin_split2, + pp_arb, pp_int_split4, + &split_binary_witness, + &witness_trace.arbitrary_poly, &split_int_witness, &r0_ext, - hint_int, + multi, &field_cfg, )?; } + } else { + if let Some(hint_bin) = &hint_bin_split { + if let Some(bd) = zip_breakdown.as_deref_mut() { + let _ = ZipPlus::::prove_f_with_byte_breakdown::< + _, + CHECK_FOR_OVERFLOW, + >( + &mut pcs_transcript, + pp_bin_split2, + &split_binary_witness, + &r0_ext, + hint_bin, + &field_cfg, + &mut bd.bin, + )?; + } else { + let _ = + ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + &mut pcs_transcript, + pp_bin_split2, + &split_binary_witness, + &r0_ext, + hint_bin, + &field_cfg, + )?; + } + } + if let Some(hint_arb) = &hint_arb { + if let Some(bd) = zip_breakdown.as_deref_mut() { + let _ = + ZipPlus::::prove_f_with_byte_breakdown::< + _, + CHECK_FOR_OVERFLOW, + >( + &mut pcs_transcript, + pp_arb, + &witness_trace.arbitrary_poly, + &r_0, + hint_arb, + &field_cfg, + &mut bd.arb, + )?; + } else { + let _ = ZipPlus::::prove_f::< + _, + CHECK_FOR_OVERFLOW, + >( + &mut pcs_transcript, + pp_arb, + &witness_trace.arbitrary_poly, + &r_0, + hint_arb, + &field_cfg, + )?; + } + } + if let Some(hint_int) = &hint_int_split { + if let Some(bd) = zip_breakdown.as_deref_mut() { + let _ = ZipPlus::::prove_f_with_byte_breakdown::< + _, + CHECK_FOR_OVERFLOW, + >( + &mut pcs_transcript, + pp_int_split4, + &split_int_witness, + &r0_ext, + hint_int, + &field_cfg, + &mut bd.int, + )?; + } else { + let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + &mut pcs_transcript, + 
pp_int_split4, + &split_int_witness, + &r0_ext, + hint_int, + &field_cfg, + )?; + } + } } - } + + Ok::<_, ProtocolError>(()) + })?; if let Some(t) = timings.as_mut() { t.step7_pcs_open = _t_step7.elapsed(); } // ── Assemble the proof ────────────────────────────────────────────── let _t_assembly = std::time::Instant::now(); - let zip_proof = pcs_transcript.stream.into_inner(); - let commitments = (commitment_bin, commitment_arb, commitment_int); - - let pub_cols = uair_signature.public_cols(); - let num_pub_bin = pub_cols.num_binary_poly_cols(); - let num_pub_arb = pub_cols.num_arbitrary_poly_cols(); - let num_pub_int = pub_cols.num_int_cols(); - let witness = uair_signature.witness_cols(); - let witness_arb_offset = add!(num_total_bin, num_pub_arb); - let witness_arb_end = add!(witness_arb_offset, witness.num_arbitrary_poly_cols()); - let witness_int_offset = add!(add!(num_total_bin, num_total_arb), num_pub_int); - let witness_lifted_evals: Vec<_> = lifted_evals[num_pub_bin..num_total_bin] - .iter() - .chain(&lifted_evals[witness_arb_offset..witness_arb_end]) - .chain(&lifted_evals[witness_int_offset..]) - .cloned() - .collect(); + let proof = + trace_folded_4x_prover_phase("assembly", num_vars, MLE_FIRST, CHECK_FOR_OVERFLOW, || { + let zip_proof = pcs_transcript.stream.into_inner(); + let commitments = (commitment_bin, commitment_arb, commitment_int); + + let pub_cols = uair_signature.public_cols(); + let num_pub_bin = pub_cols.num_binary_poly_cols(); + let num_pub_arb = pub_cols.num_arbitrary_poly_cols(); + let num_pub_int = pub_cols.num_int_cols(); + let witness = uair_signature.witness_cols(); + let witness_arb_offset = add!(num_total_bin, num_pub_arb); + let witness_arb_end = add!(witness_arb_offset, witness.num_arbitrary_poly_cols()); + let witness_int_offset = add!(add!(num_total_bin, num_total_arb), num_pub_int); + let witness_lifted_evals: Vec<_> = lifted_evals[num_pub_bin..num_total_bin] + .iter() + 
.chain(&lifted_evals[witness_arb_offset..witness_arb_end]) + .chain(&lifted_evals[witness_int_offset..]) + .cloned() + .collect(); - let proof = Proof { - commitments, - ideal_check: ic_proof, - resolver: cpr_proof, - combined_sumcheck, - multipoint_eval: mp_proof, - zip: zip_proof, - witness_lifted_evals, - lookup_proof, - }; + Proof { + commitments, + ideal_check: ic_proof, + resolver: cpr_proof, + combined_sumcheck, + multipoint_eval: mp_proof, + zip: zip_proof, + witness_lifted_evals, + lookup_proof, + } + }); if let Some(t) = timings.as_mut() { t.assembly = _t_assembly.elapsed(); } diff --git a/protocol/src/verifier.rs b/protocol/src/verifier.rs index 855d682a..91453de9 100644 --- a/protocol/src/verifier.rs +++ b/protocol/src/verifier.rs @@ -34,14 +34,26 @@ use zinc_uair::{ ideal_collector::IdealOrZero, }; use zinc_utils::{ - add, cfg_join, from_ref::FromRef, inner_transparent_field::InnerTransparentField, - mul_by_scalar::MulByScalar, projectable_to_field::ProjectableToField, + add, cfg_join, + delayed_reduction::{DelayedFieldProductSum, MontgomeryLimbs}, + from_ref::FromRef, + inner_transparent_field::InnerTransparentField, + mul_by_scalar::MulByScalar, + projectable_to_field::ProjectableToField, }; use zip_plus::{ - pcs::structs::{ZipPlus, ZipPlusParams, ZipTypes}, + pcs::{ + generic::PCS, + structs::{ZipPlus, ZipPlusParams, ZipTypes}, + }, pcs_transcript::PcsVerifierTranscript, }; +use crate::{ + multipoint_reduction::verify_multipoint_reduction, + pcs::{AllZipPCSTypes, PCSCommitments, PCSVerifierParams, ZincPCSTypes}, +}; + /// Drop the witness binary_poly column evals the UAIR opted out of /// (sorted, dedup'd `skip_indices` relative to the witness slice). 
The /// surviving evals line up positionally with the bit-slice blocks @@ -62,22 +74,39 @@ fn filter_skipped_parent_evals( .collect() } +fn ensure_pcs_stream_consumed( + pcs_transcript: &PcsVerifierTranscript, +) -> Result<(), ProtocolError> { + let consumed = usize::try_from(pcs_transcript.stream.position()).unwrap_or(usize::MAX); + let total = pcs_transcript.stream.get_ref().len(); + if consumed != total { + return Err(ProtocolError::PcsProofTrailingBytes { consumed, total }); + } + Ok(()) +} + // // Shared base // /// Persistent verifier infrastructure carried across every step. #[derive(Clone, Debug)] -pub struct VerifierBase<'a, Zt: ZincTypes, const D: usize> { +pub struct VerifierBase< + 'a, + Zt: ZincTypes, + F: PrimeField, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { num_vars: usize, uair_signature: UairSignature, pcs_transcript: PcsVerifierTranscript, public_trace: &'a UairTrace<'a, Zt::Int, Zt::Int, D>, // Commitment info - vp_bin: &'a ZipPlusParams, - vp_arb: &'a ZipPlusParams, - vp_int: &'a ZipPlusParams, + vp: PCSVerifierParams, + + _phantom: PhantomData<(F, P)>, } // @@ -93,11 +122,12 @@ pub struct VerifierTranscriptReconstructed< F: PrimeField, IdealOverF, const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, > { - base: VerifierBase<'a, Zt, D>, + base: VerifierBase<'a, Zt, F, D, P>, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_ideal_check: IdealCheckProof, proof_resolver: CombinedPolyResolverProof, proof_combined_sumcheck: MultiDegreeSumcheckProof, @@ -116,12 +146,13 @@ pub struct VerifierPrimeProjected< F: PrimeField, IdealOverF, const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, > { - base: VerifierBase<'a, Zt, D>, + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_ideal_check: 
IdealCheckProof, proof_resolver: CombinedPolyResolverProof, proof_combined_sumcheck: MultiDegreeSumcheckProof, @@ -140,13 +171,14 @@ pub struct VerifierIdealChecked< F: PrimeField, IdealOverF, const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, > { - base: VerifierBase<'a, Zt, D>, + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, ic_subclaim: ideal_check::VerifierSubclaim, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_resolver: CombinedPolyResolverProof, proof_combined_sumcheck: MultiDegreeSumcheckProof, proof_multipoint_eval: MultipointEvalProof, @@ -164,15 +196,16 @@ pub struct VerifierEvalProjected< F: PrimeField, IdealOverF, const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, > { - base: VerifierBase<'a, Zt, D>, + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, ic_subclaim: ideal_check::VerifierSubclaim, projecting_element_f: F, projected_scalars_f: ScalarMap, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_resolver: CombinedPolyResolverProof, proof_combined_sumcheck: MultiDegreeSumcheckProof, proof_multipoint_eval: MultipointEvalProof, @@ -183,14 +216,21 @@ pub struct VerifierEvalProjected< /// After step 4 (sumcheck verify). 
#[derive(Clone, Debug)] -pub struct VerifierSumchecked<'a, Zt: ZincTypes, F: PrimeField, IdealOverF, const D: usize> { - base: VerifierBase<'a, Zt, D>, +pub struct VerifierSumchecked< + 'a, + Zt: ZincTypes, + F: PrimeField, + IdealOverF, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, projecting_element_f: F, cpr_subclaim: combined_poly_resolver::VerifierSubclaim, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_multipoint_eval: MultipointEvalProof, proof_witness_lifted_evals: Vec>, proof_lookup_proof: Option>, @@ -199,15 +239,21 @@ pub struct VerifierSumchecked<'a, Zt: ZincTypes, F: PrimeField, IdealOverF, c /// After step 5 (multi-point eval). #[derive(Clone, Debug)] -pub struct VerifierMultipointEvaled<'a, Zt: ZincTypes, F: PrimeField, IdealOverF, const D: usize> -{ - base: VerifierBase<'a, Zt, D>, +pub struct VerifierMultipointEvaled< + 'a, + Zt: ZincTypes, + F: PrimeField, + IdealOverF, + const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, +> { + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, projecting_element_f: F, mp_subclaim: multipoint_eval::Subclaim, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_witness_lifted_evals: Vec>, proof_lookup_proof: Option>, _phantom: PhantomData, @@ -222,14 +268,15 @@ pub struct VerifierLiftedEvalsChecked< F: PrimeField, IdealOverF, const D: usize, + P: ZincPCSTypes = AllZipPCSTypes, > { - base: VerifierBase<'a, Zt, D>, + base: VerifierBase<'a, Zt, F, D, P>, field_cfg: F::Config, mp_subclaim: multipoint_eval::Subclaim, all_lifted_evals: Vec>, // Proof leftovers - proof_commitments: (ZipPlusCommitment, ZipPlusCommitment, ZipPlusCommitment), + proof_commitments: PCSCommitments, proof_lookup_proof: Option>, _phantom: PhantomData, } @@ -249,19 +296,28 @@ impl 
ZincPlusPiop where Zt: ZincTypes, U: Uair, - F: PrimeField, + F: PrimeField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b ::CombR> + + for<'b> FromWithConfig<&'b Zt::Chal> + + for<'b> MulByScalar<&'b F> + + FromRef, F::Inner: ConstTranscribable, + F::Modulus: FromRef, { /// Step 0: Verifier entry point. /// Reconstruct Fiat-Shamir transcript from commitments and public data. #[allow(clippy::type_complexity)] pub fn step0_reconstruct_transcript<'a, IdealOverF>( - (vp_bin, vp_arb, vp_int): &'a ( + (vp_bin, vp_arb, vp_int): &( ZipPlusParams, ZipPlusParams, ZipPlusParams, ), - mut proof: Proof, + proof: Proof, public_trace: &'a UairTrace<'a, Zt::Int, Zt::Int, D>, num_vars: usize, ) -> Result< @@ -270,6 +326,57 @@ where > where IdealOverF: Ideal, + AllZipPCSTypes: ZincPCSTypes< + Zt, + F, + D, + BinaryPCS = zip_plus::pcs::generic::ZipPlusPCS, + ArbitraryPCS = zip_plus::pcs::generic::ZipPlusPCS, + IntPCS = zip_plus::pcs::generic::ZipPlusPCS, + >, + { + let pcs_vp = PCSVerifierParams:: { + binary: vp_bin.clone(), + arbitrary: vp_arb.clone(), + int: vp_int.clone(), + }; + let commitments = proof.commitments; + let proof = Proof { + commitments: PCSCommitments:: { + binary: commitments.0, + arbitrary: commitments.1, + int: commitments.2, + }, + zip: proof.zip, + ideal_check: proof.ideal_check, + resolver: proof.resolver, + combined_sumcheck: proof.combined_sumcheck, + multipoint_eval: proof.multipoint_eval, + witness_lifted_evals: proof.witness_lifted_evals, + lookup_proof: proof.lookup_proof, + }; + Self::step0_reconstruct_transcript_with_pcs::( + &pcs_vp, + proof, + public_trace, + num_vars, + ) + } + + /// Step 0 with an explicit PCS bundle. 
+ #[allow(clippy::type_complexity)] + pub fn step0_reconstruct_transcript_with_pcs<'a, IdealOverF, P>( + vp: &PCSVerifierParams, + mut proof: Proof>, + public_trace: &'a UairTrace<'a, Zt::Int, Zt::Int, D>, + num_vars: usize, + ) -> Result< + VerifierTranscriptReconstructed<'a, Zt, U, F, IdealOverF, D, P>, + ProtocolError, + > + where + P: ZincPCSTypes, + IdealOverF: Ideal, { let zip_proof = std::mem::take(&mut proof.zip); let mut base = VerifierBase { @@ -280,18 +387,22 @@ where fs_transcript: Blake3Transcript::default(), stream: Cursor::new(zip_proof), }, - vp_bin, - vp_arb, - vp_int, + vp: vp.clone(), + _phantom: PhantomData, }; - for comm in [ - &proof.commitments.0, - &proof.commitments.1, - &proof.commitments.2, - ] { - base.pcs_transcript.fs_transcript.absorb_slice(&comm.root); - } + P::BinaryPCS::absorb_commitment( + &mut base.pcs_transcript.fs_transcript, + &proof.commitments.binary, + ); + P::ArbitraryPCS::absorb_commitment( + &mut base.pcs_transcript.fs_transcript, + &proof.commitments.arbitrary, + ); + P::IntPCS::absorb_commitment( + &mut base.pcs_transcript.fs_transcript, + &proof.commitments.int, + ); absorb_public_columns( &mut base.pcs_transcript.fs_transcript, @@ -320,10 +431,11 @@ where } } -impl<'a, Zt, U, F, IdealOverF, const D: usize> - VerifierTranscriptReconstructed<'a, Zt, U, F, IdealOverF, D> +impl<'a, Zt, U, F, IdealOverF, const D: usize, P> + VerifierTranscriptReconstructed<'a, Zt, U, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, F: InnerTransparentField + FromPrimitiveWithConfig + FromRef + Send + Sync + 'static, F::Inner: ConstIntSemiring + ConstTranscribable + Send + Sync + Zero + Default, F::Modulus: ConstTranscribable + FromRef, @@ -335,13 +447,21 @@ where #[allow(clippy::type_complexity)] pub fn step1_prime_projection( self, - ) -> Result, ProtocolError> + ) -> Result, ProtocolError> { // `fixed-prime` branch: use the secp256k1 base field prime as the // projecting prime instead of drawing one from the transcript. 
// See `crate::fixed_prime` for the soundness caveat. let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); + self.step1_prime_projection_with_field_cfg(field_cfg) + } + + pub fn step1_prime_projection_with_field_cfg( + self, + field_cfg: F::Config, + ) -> Result, ProtocolError> + { Ok(VerifierPrimeProjected { base: self.base, field_cfg, @@ -357,9 +477,11 @@ where } } -impl<'a, Zt, U, F, IdealOverF, const D: usize> VerifierPrimeProjected<'a, Zt, U, F, IdealOverF, D> +impl<'a, Zt, U, F, IdealOverF, const D: usize, P> + VerifierPrimeProjected<'a, Zt, U, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, Zt::Int: ProjectableToField, ::Eval: ProjectableToField, F: InnerTransparentField @@ -384,7 +506,7 @@ where pub fn step2_ideal_check( mut self, project_ideal: impl Fn(&IdealOrZero, &F::Config) -> IdealOverF, - ) -> Result, ProtocolError> + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); @@ -412,10 +534,13 @@ where } } -impl<'a, Zt, U, F, IdealOverF, const D: usize> VerifierIdealChecked<'a, Zt, U, F, IdealOverF, D> +impl<'a, Zt, U, F, IdealOverF, const D: usize, P> + VerifierIdealChecked<'a, Zt, U, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, F: InnerTransparentField + + DelayedFieldProductSum + for<'b> FromWithConfig<&'b Zt::Chal> + FromRef + Send @@ -430,7 +555,7 @@ where pub fn step3_eval_projection( mut self, project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, - ) -> Result, ProtocolError> + ) -> Result, ProtocolError> { let projecting_element: Zt::Chal = self.base.pcs_transcript.fs_transcript.get_challenge(); let projecting_element_f: F = F::from_with_cfg(&projecting_element, &self.field_cfg); @@ -457,12 +582,15 @@ where } } -impl<'a, Zt, U, F, IdealOverF, const D: usize> VerifierEvalProjected<'a, Zt, U, F, IdealOverF, D> +impl<'a, Zt, U, F, IdealOverF, const D: usize, P> + VerifierEvalProjected<'a, Zt, U, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, Zt::Int: 
ProjectableToField, ::Eval: ProjectableToField, F: InnerTransparentField + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Zt::Int> + for<'b> FromWithConfig<&'b ::CombR> @@ -482,31 +610,26 @@ where /// Step 4: Sumcheck verification (CPR + algebraic booleanity). pub fn step4_sumcheck_verify( mut self, - ) -> Result, ProtocolError> { + ) -> Result, ProtocolError> { let num_constraints = count_constraints::(); let num_pub_bin = self .base .uair_signature .public_cols() .num_binary_poly_cols(); - let num_total_bin = - self.base.uair_signature.total_cols().num_binary_poly_cols(); + let num_total_bin = self.base.uair_signature.total_cols().num_binary_poly_cols(); let bool_skip = self.base.uair_signature.booleanity_skip_indices(); // Booleanity covers: witness binary_poly cols (minus // `booleanity_skip_indices`), packed virtual binary_poly cols, // declared int bit cols, and virtual booleanity linear-combo cols. - let num_int_bit_cols = - self.base.uair_signature.int_witness_bit_cols().len(); - let num_virtual_cols = - self.base.uair_signature.virtual_booleanity_cols().len(); - let num_virtual_bp_cols = - self.base.uair_signature.virtual_binary_poly_cols().len(); + let num_int_bit_cols = self.base.uair_signature.int_witness_bit_cols().len(); + let num_virtual_cols = self.base.uair_signature.virtual_booleanity_cols().len(); + let num_virtual_bp_cols = self.base.uair_signature.virtual_binary_poly_cols().len(); let num_bit_slices = ((num_total_bin - num_pub_bin) - bool_skip.len()) * D + num_virtual_bp_cols * D + num_int_bit_cols + num_virtual_cols; - let num_shifted_bit_slices = - self.base.uair_signature.shifted_bit_slice_specs().len() * D; + let num_shifted_bit_slices = self.base.uair_signature.shifted_bit_slice_specs().len() * D; let cpr_verifier_ancillary = CombinedPolyResolver::prepare_verifier::( &mut self.base.pcs_transcript.fs_transcript, @@ -524,8 +647,7 @@ where // 4b: Booleanity verifier prep — samples α_b, validates that the // 
booleanity group's claimed sum is zero (zerocheck). let bool_verifier_ancillary_opt = if num_bit_slices > 0 { - let bool_claimed_sum = - self.proof_combined_sumcheck.claimed_sums()[1].clone(); + let bool_claimed_sum = self.proof_combined_sumcheck.claimed_sums()[1].clone(); prepare_booleanity_verifier::( &mut self.base.pcs_transcript.fs_transcript, bool_claimed_sum, @@ -564,7 +686,11 @@ where // booleanity-bound MLE to the committed sources without a // separate equality check. let int_offset = self.base.uair_signature.total_cols().num_binary_poly_cols() - + self.base.uair_signature.total_cols().num_arbitrary_poly_cols(); + + self + .base + .uair_signature + .total_cols() + .num_arbitrary_poly_cols(); let num_pub_int = self.base.uair_signature.public_cols().num_int_cols(); let num_wit_int = self.base.uair_signature.witness_cols().num_int_cols(); let num_binary_bit_slices_for_overrides = (num_total_bin - num_pub_bin) * D; @@ -574,21 +700,20 @@ where // Public binary_poly bit slice evals at the shared sumcheck // point — verifier computes locally from public_trace; reused // by both virtual-bool and virtual-binary-poly overrides. - let public_bit_slice_evals: Vec = if !virtual_bp_specs.is_empty() - || !virtual_specs.is_empty() - { - let public_bit_slice_mles = compute_bit_slices_flat::( - &self.base.public_trace.binary_poly, - &self.field_cfg, - ); - public_bit_slice_mles - .into_iter() - .map(|mle| mle.evaluate_with_config(md_subclaims.point(), &self.field_cfg)) - .collect::, _>>() - .map_err(ProtocolError::ShiftedBitSliceEval)? - } else { - Vec::new() - }; + let public_bit_slice_evals: Vec = + if !virtual_bp_specs.is_empty() || !virtual_specs.is_empty() { + let public_bit_slice_mles = compute_bit_slices_flat::( + &self.base.public_trace.binary_poly, + &self.field_cfg, + ); + public_bit_slice_mles + .into_iter() + .map(|mle| mle.evaluate_with_config(md_subclaims.point(), &self.field_cfg)) + .collect::, _>>() + .map_err(ProtocolError::ShiftedBitSliceEval)? 
+ } else { + Vec::new() + }; // closing_overrides_tail layout (in trailing-position order): // [virtual_binary_poly_per_bit (V_b * D), @@ -617,9 +742,7 @@ where ); if !virtual_specs.is_empty() { let int_witness_up_evals: Vec = (0..num_wit_int) - .map(|i| { - cpr_subclaim.up_evals[int_offset + num_pub_int + i].clone() - }) + .map(|i| cpr_subclaim.up_evals[int_offset + num_pub_int + i].clone()) .collect(); let virtual_overrides = compute_virtual_closing_overrides::( virtual_specs, @@ -670,8 +793,7 @@ where // Shifted bit-slice consistency: tie each spec's emitted bit // slices to the corresponding `down_eval` (= parent col at // shifted point) via the same projection-element trick. - let shifted_down_indices = - self.base.uair_signature.shifted_bit_slice_down_indices(); + let shifted_down_indices = self.base.uair_signature.shifted_bit_slice_down_indices(); let shifted_parent_evals: Vec = shifted_down_indices .iter() .map(|&i| cpr_subclaim.down_evals[i].clone()) @@ -700,10 +822,17 @@ where } } -impl<'a, Zt, F, IdealOverF, const D: usize> VerifierSumchecked<'a, Zt, F, IdealOverF, D> +impl<'a, Zt, F, IdealOverF, const D: usize, P> VerifierSumchecked<'a, Zt, F, IdealOverF, D, P> where Zt: ZincTypes, - F: InnerTransparentField + FromPrimitiveWithConfig + FromRef + Send + Sync + 'static, + P: ZincPCSTypes, + F: InnerTransparentField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + FromRef + + Send + + Sync + + 'static, F::Inner: ConstIntSemiring + ConstTranscribable + Send + Sync + Zero + Default, F::Modulus: ConstTranscribable + FromRef, IdealOverF: Ideal, @@ -719,14 +848,14 @@ where /// lifted eval (free arithmetic in F_q[X]). 
pub fn step5_multipoint_eval( mut self, - ) -> Result, ProtocolError> + ) -> Result, ProtocolError> { let cpr_eval_point = self.cpr_subclaim.evaluation_point.clone(); let mut up_evals_with_bit_op = self.cpr_subclaim.up_evals.clone(); up_evals_with_bit_op.extend(self.cpr_subclaim.bit_op_down_evals.iter().cloned()); - let mp_subclaim = MultipointEval::verify_as_subprotocol( + let mp_subclaim = verify_multipoint_reduction( &mut self.base.pcs_transcript.fs_transcript, self.proof_multipoint_eval, &cpr_eval_point, @@ -750,12 +879,15 @@ where } } -impl<'a, Zt, F, IdealOverF, const D: usize> VerifierMultipointEvaled<'a, Zt, F, IdealOverF, D> +impl<'a, Zt, F, IdealOverF, const D: usize, P> VerifierMultipointEvaled<'a, Zt, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, Zt::Int: ProjectableToField, ::Eval: ProjectableToField, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Zt::Int> + for<'b> FromWithConfig<&'b Zt::Chal> @@ -781,7 +913,7 @@ where /// `MultipointEval(ClaimMismatch)` from `verify_subclaim`. pub fn step6_lifted_evals( mut self, - ) -> Result, ProtocolError> + ) -> Result, ProtocolError> { let r_0 = &self.mp_subclaim.sumcheck_subclaim.point; @@ -869,15 +1001,18 @@ where } } -impl<'a, Zt, F, IdealOverF, const D: usize> VerifierLiftedEvalsChecked<'a, Zt, F, IdealOverF, D> +impl<'a, Zt, F, IdealOverF, const D: usize, P> + VerifierLiftedEvalsChecked<'a, Zt, F, IdealOverF, D, P> where Zt: ZincTypes, + P: ZincPCSTypes, Zt::Int: ProjectableToField, ::Cw: ProjectableToField, ::Eval: ProjectableToField, ::Cw: ProjectableToField, ::Cw: ProjectableToField, F: InnerTransparentField + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Zt::Int> + for<'b> FromWithConfig<&'b ::CombR> @@ -913,60 +1048,38 @@ where let field_cfg = &self.field_cfg; let all_lifted_evals = &self.all_lifted_evals; - macro_rules! 
verify_pcs_batch { - ($Zt:ty, $Lc:ty, $vp:expr, $idx:tt, [$evals_range:expr]) => {{ - let comm = &commitments.$idx; - if comm.batch_size > 0 { - let per_poly_alphas = ZipPlus::<$Zt, $Lc>::sample_alphas( - &mut pcs_transcript.fs_transcript, - comm.batch_size, - ); - let mut eval_f = F::zero_with_cfg(field_cfg); - for (bar_u, alphas) in all_lifted_evals[$evals_range] - .iter() - .zip(per_poly_alphas.iter()) - { - for (coeff, alpha) in bar_u.coeffs.iter().zip(alphas.iter()) { - let mut term = F::from_with_cfg(alpha, field_cfg); - term *= coeff; - eval_f += &term; - } - } - ZipPlus::<$Zt, $Lc>::verify_with_alphas::( - pcs_transcript, - $vp, - comm, - field_cfg, - r_0, - &eval_f, - &per_poly_alphas, - ) - .map_err(|e| ProtocolError::PcsVerification($idx, e))?; - } - }}; - } + P::BinaryPCS::verify_open::( + pcs_transcript, + &self.base.vp.binary, + &commitments.binary, + r_0, + &all_lifted_evals[num_pub_bin..num_total_bin], + &Default::default(), + field_cfg, + ) + .map_err(|e| ProtocolError::PcsVerification(0, e))?; + P::ArbitraryPCS::verify_open::( + pcs_transcript, + &self.base.vp.arbitrary, + &commitments.arbitrary, + r_0, + &all_lifted_evals[add!(num_total_bin, num_pub_arb)..add!(num_total_bin, num_total_arb)], + &Default::default(), + field_cfg, + ) + .map_err(|e| ProtocolError::PcsVerification(1, e))?; + P::IntPCS::verify_open::( + pcs_transcript, + &self.base.vp.int, + &commitments.int, + r_0, + &all_lifted_evals[add!(add!(num_total_bin, num_total_arb), num_pub_int)..], + &Default::default(), + field_cfg, + ) + .map_err(|e| ProtocolError::PcsVerification(2, e))?; - verify_pcs_batch!( - Zt::BinaryZt, - Zt::BinaryLc, - self.base.vp_bin, - 0, - [num_pub_bin..num_total_bin] - ); - verify_pcs_batch!( - Zt::ArbitraryZt, - Zt::ArbitraryLc, - self.base.vp_arb, - 1, - [add!(num_total_bin, num_pub_arb)..add!(num_total_bin, num_total_arb)] - ); - verify_pcs_batch!( - Zt::IntZt, - Zt::IntLc, - self.base.vp_int, - 2, - [add!(add!(num_total_bin, num_total_arb), num_pub_int)..] 
- ); + ensure_pcs_stream_consumed::(pcs_transcript)?; Ok(VerifierPcsVerified { _phantom: PhantomData, @@ -992,6 +1105,8 @@ where ::Cw: ProjectableToField, ::Cw: ProjectableToField, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'a> FromWithConfig<&'a Zt::Int> + for<'a> FromWithConfig<&'a ::CombR> @@ -1028,24 +1143,95 @@ where ) -> Result<(), ProtocolError> where IdealOverF: Ideal + IdealCheck>, + AllZipPCSTypes: ZincPCSTypes< + Zt, + F, + D, + BinaryPCS = zip_plus::pcs::generic::ZipPlusPCS, + ArbitraryPCS = zip_plus::pcs::generic::ZipPlusPCS, + IntPCS = zip_plus::pcs::generic::ZipPlusPCS, + >, + { + let pcs_vp = PCSVerifierParams:: { + binary: vp.0.clone(), + arbitrary: vp.1.clone(), + int: vp.2.clone(), + }; + let commitments = proof.commitments; + let proof = Proof { + commitments: PCSCommitments:: { + binary: commitments.0, + arbitrary: commitments.1, + int: commitments.2, + }, + zip: proof.zip, + ideal_check: proof.ideal_check, + resolver: proof.resolver, + combined_sumcheck: proof.combined_sumcheck, + multipoint_eval: proof.multipoint_eval, + witness_lifted_evals: proof.witness_lifted_evals, + lookup_proof: proof.lookup_proof, + }; + + Self::verify_with_pcs::( + &pcs_vp, + proof, + public_trace, + num_vars, + project_scalar, + project_ideal, + ) + } + + #[allow(clippy::too_many_arguments, clippy::type_complexity)] + pub fn verify_with_pcs( + vp: &PCSVerifierParams, + proof: Proof>, + public_trace: &UairTrace, + num_vars: usize, + project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + project_ideal: impl Fn(&IdealOrZero, &F::Config) -> IdealOverF, + ) -> Result<(), ProtocolError> + where + P: ZincPCSTypes, + IdealOverF: Ideal + IdealCheck>, + { + let field_cfg = crate::fixed_prime::secp256k1_field_cfg::(); + Self::verify_with_pcs_and_field_cfg::( + vp, + proof, + public_trace, + num_vars, + project_scalar, + project_ideal, + field_cfg, + ) + } + + 
#[allow(clippy::too_many_arguments, clippy::type_complexity)] + pub fn verify_with_pcs_and_field_cfg( + vp: &PCSVerifierParams, + proof: Proof>, + public_trace: &UairTrace, + num_vars: usize, + project_scalar: impl Fn(&U::Scalar, &F::Config) -> DynamicPolynomialF + Sync, + project_ideal: impl Fn(&IdealOrZero, &F::Config) -> IdealOverF, + field_cfg: F::Config, + ) -> Result<(), ProtocolError> + where + P: ZincPCSTypes, + IdealOverF: Ideal + IdealCheck>, { - // Verifier-side public-column structural checks. UAIRs that - // need to enforce structural properties of public columns - // (compensator-zero on active rows, tail-corrector-zero on - // inner rows, etc.) discharge them here, by direct row-wise - // inspection of public_trace, before any algebraic check - // begins. Default impl is a no-op for UAIRs that don't need - // such checks. U::verify_public_structure(public_trace, num_vars) .map_err(ProtocolError::PublicStructure)?; - ZincPlusPiop::::step0_reconstruct_transcript::( + ZincPlusPiop::::step0_reconstruct_transcript_with_pcs::( vp, proof, public_trace, num_vars, )? - .step1_prime_projection()? + .step1_prime_projection_with_field_cfg(field_cfg)? .step2_ideal_check(project_ideal)? .step3_eval_projection(project_scalar)? .step4_sumcheck_verify()? @@ -1098,6 +1284,8 @@ where ::Cw: ProjectableToField, U: Uair + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b ZtF::Int> + for<'b> FromWithConfig<&'b ::CombR> @@ -1116,8 +1304,7 @@ where // Verifier-side public-column structural checks (compensator/ // corrector zero-pinning, etc.). UAIRs that don't need extra // structural checks fall through this with a no-op default impl. 
- U::verify_public_structure(public_trace, num_vars) - .map_err(ProtocolError::PublicStructure)?; + U::verify_public_structure(public_trace, num_vars).map_err(ProtocolError::PublicStructure)?; // ── Step 0: Reconstruct transcript ────────────────────────────────── let zip_proof = std::mem::take(&mut proof.zip); @@ -1164,9 +1351,8 @@ where let projecting_element_f: F = F::from_with_cfg(&projecting_element, &field_cfg); let projected_scalars_fx = project_scalars::(|s| project_scalar(s, &field_cfg)); - let projected_scalars_f = - project_scalars_to_field(projected_scalars_fx, &projecting_element_f) - .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; + let projected_scalars_f = project_scalars_to_field(projected_scalars_fx, &projecting_element_f) + .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; // ── Step 4: Sumcheck verify (CPR + algebraic booleanity) ──────────── let num_pub_bin = uair_signature.public_cols().num_binary_poly_cols(); @@ -1179,8 +1365,7 @@ where + num_virtual_bp_cols * D + num_int_bit_cols + num_virtual_cols; - let num_shifted_bit_slices = - uair_signature.shifted_bit_slice_specs().len() * D; + let num_shifted_bit_slices = uair_signature.shifted_bit_slice_specs().len() * D; let cpr_verifier_ancillary = CombinedPolyResolver::prepare_verifier::( &mut pcs_transcript.fs_transcript, &proof.resolver, @@ -1234,21 +1419,18 @@ where let num_binary_bit_slices = (num_total_bin - num_pub_bin) * D; let virtual_bp_specs = uair_signature.virtual_binary_poly_cols(); let virtual_specs = uair_signature.virtual_booleanity_cols(); - let public_bit_slice_evals: Vec = if !virtual_bp_specs.is_empty() - || !virtual_specs.is_empty() - { - let public_bit_slice_mles = compute_bit_slices_flat::( - &public_trace.binary_poly, - &field_cfg, - ); - public_bit_slice_mles - .into_iter() - .map(|mle| mle.evaluate_with_config(md_subclaims.point(), &field_cfg)) - .collect::, _>>() - .map_err(ProtocolError::ShiftedBitSliceEval)? 
- } else { - Vec::new() - }; + let public_bit_slice_evals: Vec = + if !virtual_bp_specs.is_empty() || !virtual_specs.is_empty() { + let public_bit_slice_mles = + compute_bit_slices_flat::(&public_trace.binary_poly, &field_cfg); + public_bit_slice_mles + .into_iter() + .map(|mle| mle.evaluate_with_config(md_subclaims.point(), &field_cfg)) + .collect::, _>>() + .map_err(ProtocolError::ShiftedBitSliceEval)? + } else { + Vec::new() + }; let mut closing_overrides_tail: Vec = Vec::new(); if !virtual_bp_specs.is_empty() { let virtual_bp_overrides = compute_virtual_binary_poly_closing_overrides::( @@ -1324,7 +1506,7 @@ where let cpr_eval_point = cpr_subclaim.evaluation_point.clone(); let mut up_evals_with_bit_op = cpr_subclaim.up_evals.clone(); up_evals_with_bit_op.extend(cpr_subclaim.bit_op_down_evals.iter().cloned()); - let mp_subclaim = MultipointEval::verify_as_subprotocol( + let mp_subclaim = verify_multipoint_reduction( &mut pcs_transcript.fs_transcript, proof.multipoint_eval, &cpr_eval_point, @@ -1428,11 +1610,10 @@ where { let comm = &proof.commitments.0; if comm.batch_size > 0 { - let per_poly_alphas = - ZipPlus::::sample_alphas( - &mut pcs_transcript.fs_transcript, - comm.batch_size, - ); + let per_poly_alphas = ZipPlus::::sample_alphas( + &mut pcs_transcript.fs_transcript, + comm.batch_size, + ); let one = F::one_with_cfg(&field_cfg); let one_minus_gamma = one - gamma.clone(); @@ -1485,11 +1666,10 @@ where { let comm = &proof.commitments.1; if comm.batch_size > 0 { - let per_poly_alphas = - ZipPlus::::sample_alphas( - &mut pcs_transcript.fs_transcript, - comm.batch_size, - ); + let per_poly_alphas = ZipPlus::::sample_alphas( + &mut pcs_transcript.fs_transcript, + comm.batch_size, + ); let mut eval_f = F::zero_with_cfg(&field_cfg); for (bar_u, alphas) in all_lifted_evals [add!(num_total_bin, num_pub_arb)..add!(num_total_bin, num_total_arb)] @@ -1551,6 +1731,7 @@ where } } + ensure_pcs_stream_consumed::(&pcs_transcript)?; Ok(()) } @@ -1586,7 +1767,6 @@ where // 
witness columns. // - /// Per-region wall-time breakdown of a single [`verify_folded_4x`] run, /// populated by [`verify_folded_4x_with_timings`]. Mirrors /// [`crate::prover::FoldedProveTimings`] step-for-step; summing the @@ -1672,6 +1852,8 @@ where ::Cw: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Int> + for<'b> FromWithConfig<&'b Int> @@ -1746,6 +1928,8 @@ where ::Cw: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Int> + for<'b> FromWithConfig<&'b Int> @@ -1821,6 +2005,8 @@ where ::Cw: ProjectableToField, U: Uair, D>> + 'static, F: InnerTransparentField + + MontgomeryLimbs + + DelayedFieldProductSum + FromPrimitiveWithConfig + for<'b> FromWithConfig<&'b Int> + for<'b> FromWithConfig<&'b Int> @@ -1894,9 +2080,8 @@ where let projecting_element_f: F = F::from_with_cfg(&projecting_element, &field_cfg); let projected_scalars_fx = project_scalars::(|s| project_scalar(s, &field_cfg)); - let projected_scalars_f = - project_scalars_to_field(projected_scalars_fx, &projecting_element_f) - .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; + let projected_scalars_f = project_scalars_to_field(projected_scalars_fx, &projecting_element_f) + .map_err(|(_s, _f, e)| ProtocolError::ScalarProjection(e))?; if let Some(t) = timings.as_mut() { t.step3_eval_projection = _t_step3.elapsed(); } @@ -2055,7 +2240,7 @@ where let cpr_eval_point = cpr_subclaim.evaluation_point.clone(); let mut up_evals_with_bit_op = cpr_subclaim.up_evals.clone(); up_evals_with_bit_op.extend(cpr_subclaim.bit_op_down_evals.iter().cloned()); - let mp_subclaim = MultipointEval::verify_as_subprotocol( + let mp_subclaim = verify_multipoint_reduction( &mut pcs_transcript.fs_transcript, proof.multipoint_eval, &cpr_eval_point, @@ -2081,11 +2266,10 @@ where let 
num_wit_arb = wit_cols.num_arbitrary_poly_cols(); let public_lifted = if add!(add!(num_pub_bin, num_pub_arb), num_pub_int) > 0 { - let projected_public = - project_trace_coeffs_row_major::, Int, D>( - public_trace, - &field_cfg, - ); + let projected_public = project_trace_coeffs_row_major::, Int, D>( + public_trace, + &field_cfg, + ); let mut lifted = crate::compute_lifted_evals::( &r_0, &public_trace.binary_poly, @@ -2093,12 +2277,11 @@ where &field_cfg, ); if num_pub_int > 0 { - let int_4coeff = - crate::compute_int_fold_4x_lifted_evals::( - &r_0, - &public_trace.int, - &field_cfg, - ); + let int_4coeff = crate::compute_int_fold_4x_lifted_evals::< + F, + INT_LIMBS, + INT_QUARTER_LIMBS, + >(&r_0, &public_trace.int, &field_cfg); let int_off = num_pub_bin + num_pub_arb; for (i, bar_u) in int_4coeff.into_iter().enumerate() { lifted[int_off + i] = bar_u; @@ -2233,7 +2416,10 @@ where // Closure: compute eval_f for the binary path's 4× fold. let bin_eval_f = |alphas: &[Vec<::Chal>]| -> F { let mut eval_f = zero.clone(); - for (bar_u, a) in all_lifted_evals[bin_range.clone()].iter().zip(alphas.iter()) { + for (bar_u, a) in all_lifted_evals[bin_range.clone()] + .iter() + .zip(alphas.iter()) + { debug_assert_eq!(a.len(), QUARTER_D); let mut c00 = zero.clone(); let mut c10 = zero.clone(); @@ -2283,7 +2469,10 @@ where // c00=α·c[0], c10=α·c[2], c01=α·c[1], c11=α·c[3]. let int_eval_f = |alphas: &[Vec<::Chal>]| -> F { let mut eval_f = zero.clone(); - for (bar_u, a) in all_lifted_evals[int_range.clone()].iter().zip(alphas.iter()) { + for (bar_u, a) in all_lifted_evals[int_range.clone()] + .iter() + .zip(alphas.iter()) + { debug_assert_eq!(a.len(), 1); let a_0: F = F::from_with_cfg(&a[0], &field_cfg); let z = || F::zero_with_cfg(&field_cfg); @@ -2326,7 +2515,10 @@ where // Closure: arb's eval_f (standard ). 
let arb_eval_f = |alphas: &[Vec<::Chal>]| -> F { let mut eval_f = F::zero_with_cfg(&field_cfg); - for (bar_u, a) in all_lifted_evals[arb_range.clone()].iter().zip(alphas.iter()) { + for (bar_u, a) in all_lifted_evals[arb_range.clone()] + .iter() + .zip(alphas.iter()) + { for (coeff, alpha) in bar_u.coeffs.iter().zip(a.iter()) { let mut term = F::from_with_cfg(alpha, &field_cfg); term *= coeff; @@ -2423,7 +2615,9 @@ where ZipPlus::::verify_pre_open_finalize::< F, CHECK_FOR_OVERFLOW, - >(vp_bin_split2, &field_cfg, &r0_ext, &eval_f, reads) + >( + vp_bin_split2, &field_cfg, &r0_ext, &eval_f, reads + ) .map_err(|e| ProtocolError::PcsVerification(0, e))?, )), _ => Ok(None), @@ -2448,7 +2642,9 @@ where ZipPlus::::verify_pre_open_finalize::< F, CHECK_FOR_OVERFLOW, - >(vp_int_split4, &field_cfg, &r0_ext, &eval_f, reads) + >( + vp_int_split4, &field_cfg, &r0_ext, &eval_f, reads + ) .map_err(|e| ProtocolError::PcsVerification(2, e))?, )), _ => Ok(None), @@ -2486,11 +2682,10 @@ where } else { // Per-instance fallback. 
if proof.commitments.0.batch_size > 0 { - let alphas = - ZipPlus::::sample_alphas( - &mut pcs_transcript.fs_transcript, - proof.commitments.0.batch_size, - ); + let alphas = ZipPlus::::sample_alphas( + &mut pcs_transcript.fs_transcript, + proof.commitments.0.batch_size, + ); let eval_f = bin_eval_f(&alphas); ZipPlus::::verify_with_alphas::( &mut pcs_transcript, @@ -2504,11 +2699,10 @@ where .map_err(|e| ProtocolError::PcsVerification(0, e))?; } if proof.commitments.1.batch_size > 0 { - let alphas = - ZipPlus::::sample_alphas( - &mut pcs_transcript.fs_transcript, - proof.commitments.1.batch_size, - ); + let alphas = ZipPlus::::sample_alphas( + &mut pcs_transcript.fs_transcript, + proof.commitments.1.batch_size, + ); let eval_f = arb_eval_f(&alphas); ZipPlus::::verify_with_alphas::< F, @@ -2546,5 +2740,6 @@ where t.step7_pcs_verify = _t_step7.elapsed(); } + ensure_pcs_stream_consumed::(&pcs_transcript)?; Ok(()) } diff --git a/test-uair/Cargo.toml b/test-uair/Cargo.toml index 3e89df0e..e9cf7b91 100644 --- a/test-uair/Cargo.toml +++ b/test-uair/Cargo.toml @@ -16,6 +16,7 @@ crypto-primitives = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } +rayon = { workspace = true, optional = true } zinc-poly = { workspace = true } zinc-uair = { workspace = true } zinc-utils = { workspace = true } @@ -23,4 +24,6 @@ zinc-utils = { workspace = true } [lints] workspace = true +[features] +parallel = ["dep:rayon", "zinc-utils/parallel"] diff --git a/test-uair/src/ecdsa.rs b/test-uair/src/ecdsa.rs index 1ad1b121..3f0aab83 100644 --- a/test-uair/src/ecdsa.rs +++ b/test-uair/src/ecdsa.rs @@ -64,8 +64,7 @@ use rand::RngCore; use zinc_poly::{mle::DenseMultilinearExtension, univariate::dense::DensePolynomial}; use zinc_uair::{ ConstraintBuilder, PublicColumnLayout, ShiftSpec, TotalColumnLayout, TraceRow, Uair, - UairSignature, UairTrace, - ideal::ImpossibleIdeal, + UairSignature, UairTrace, ideal::ImpossibleIdeal, }; use 
crate::GenerateRandomTrace; @@ -315,15 +314,12 @@ where // C-D4: Y_pa − 12·X³·Y² + 3·X²·X_pa + 8·Y⁴ = 0 let x3_y_sq = x_sq.clone() * &x_y_sq; - let twelve_x3_y_sq = - mbs(&x3_y_sq, &twelve_scalar).expect("12·X³·Y² overflow"); + let twelve_x3_y_sq = mbs(&x3_y_sq, &twelve_scalar).expect("12·X³·Y² overflow"); let x_sq_x_pa = x_sq.clone() * x_pa; - let three_x2_xpa = - mbs(&x_sq_x_pa, &three_scalar).expect("3·X²·X_pa overflow"); + let three_x2_xpa = mbs(&x_sq_x_pa, &three_scalar).expect("3·X²·X_pa overflow"); let y_pow4 = y_sq.clone() * &y_sq; let eight_y_pow4 = mbs(&y_pow4, &eight_scalar).expect("8·Y⁴ overflow"); - let d4_inner = - y_pa.clone() - &twelve_x3_y_sq + &three_x2_xpa + &eight_y_pow4; + let d4_inner = y_pa.clone() - &twelve_x3_y_sq + &three_x2_xpa + &eight_y_pow4; b.assert_zero(s_active.clone() * &d4_inner); // =================================================================== @@ -376,12 +372,10 @@ where // C-O2 (Y): down.Y − Y_pa − S_ADD·(3·D·X_pa·C² + D·C³ − D³ − Y_pa·C³ − Y_pa) = 0 let d_cube = d.clone() * &d_sq; let d_x_pa_c_sq = d.clone() * &x_pa_c_sq; - let three_d_x_pa_c_sq = - mbs(&d_x_pa_c_sq, &three_scalar).expect("3·D·X_pa·C² overflow"); + let three_d_x_pa_c_sq = mbs(&d_x_pa_c_sq, &three_scalar).expect("3·D·X_pa·C² overflow"); let d_c_cube = d.clone() * &c_cube; let y_pa_c_cube = y_pa.clone() * &c_cube; - let y_add_minus_y_pa = - three_d_x_pa_c_sq + &d_c_cube - &d_cube - &y_pa_c_cube - y_pa; + let y_add_minus_y_pa = three_d_x_pa_c_sq + &d_c_cube - &d_cube - &y_pa_c_cube - y_pa; let s_add_y = s_add.clone() * &y_add_minus_y_pa; let o2_inner = down_y.clone() - y_pa - &s_add_y; b.assert_zero(s_active.clone() * &o2_inner); @@ -494,13 +488,11 @@ fn rand_nonzero_fp(rng: &mut Rng) -> CbUint) -> CbUint { let p_odd = Odd::new(SECP256K1_P_UINT).expect("p is odd"); - a.invert_odd_mod(&p_odd).expect("a has no inverse mod p (a == 0?)") + a.invert_odd_mod(&p_odd) + .expect("a has no inverse mod p (a == 0?)") } -fn mul_mod_p( - a: &CbUint, - b: &CbUint, -) -> 
CbUint { +fn mul_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = a.widening_mul(b).into(); let p_wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = SECP256K1_P_UINT.resize(); let p_wide_nz = NonZero::new(p_wide).expect("p is nonzero"); @@ -526,10 +518,7 @@ fn p_geq(a: &CbUint) -> bool { a.checked_sub(&SECP256K1_P_UINT).is_some().into() } -fn sub_mod_p( - a: &CbUint, - b: &CbUint, -) -> CbUint { +fn sub_mod_p(a: &CbUint, b: &CbUint) -> CbUint { use crypto_bigint::CheckedSub; let p_nz = NonZero::new(SECP256K1_P_UINT).expect("p is nonzero"); if a.checked_sub(b).is_some().into() { @@ -848,8 +837,15 @@ mod tests { assert_eq!(count_max_degree::(), 7); let degrees = count_constraint_degrees::(); // Spot-checks: at least one deg-7 (Y addend constraint); 3 init deg-2. - assert!(degrees.iter().any(|&d| d == 7), "expected at least one deg-7"); - assert_eq!(degrees.iter().filter(|&&d| d == 2).count(), 3, "init = 3 deg-2"); + assert!( + degrees.iter().any(|&d| d == 7), + "expected at least one deg-7" + ); + assert_eq!( + degrees.iter().filter(|&&d| d == 2).count(), + 3, + "init = 3 deg-2" + ); } /// Witness gen produces a trace where every constraint vanishes @@ -858,8 +854,10 @@ mod tests { fn witness_satisfies_constraints_mod_p() { let num_vars = 9; let mut r = rng(); - let trace = > as GenerateRandomTrace<32>>:: - generate_random_trace(num_vars, &mut r); + let trace = + > as GenerateRandomTrace<32>>::generate_random_trace( + num_vars, &mut r, + ); let n_rows = 1 << num_vars; assert_eq!(trace.int.len(), cols::NUM_INT); @@ -909,17 +907,41 @@ mod tests { // Output: down.X = next R = expected.next_x. 
if t + 1 < n_rows { - assert_eq!(read_uint(cols::W_X, t + 1), expected.next_x, "next X at {t}"); - assert_eq!(read_uint(cols::W_Y, t + 1), expected.next_y, "next Y at {t}"); - assert_eq!(read_uint(cols::W_Z, t + 1), expected.next_z, "next Z at {t}"); + assert_eq!( + read_uint(cols::W_X, t + 1), + expected.next_x, + "next X at {t}" + ); + assert_eq!( + read_uint(cols::W_Y, t + 1), + expected.next_y, + "next Y at {t}" + ); + assert_eq!( + read_uint(cols::W_Z, t + 1), + expected.next_z, + "next Z at {t}" + ); } } // Init boundary: row 0's R = PA_R_INIT. if init { - assert_eq!(read_uint(cols::W_X, t), read_uint(cols::PA_R_INIT_X, t), "init X at {t}"); - assert_eq!(read_uint(cols::W_Y, t), read_uint(cols::PA_R_INIT_Y, t), "init Y at {t}"); - assert_eq!(read_uint(cols::W_Z, t), read_uint(cols::PA_R_INIT_Z, t), "init Z at {t}"); + assert_eq!( + read_uint(cols::W_X, t), + read_uint(cols::PA_R_INIT_X, t), + "init X at {t}" + ); + assert_eq!( + read_uint(cols::W_Y, t), + read_uint(cols::PA_R_INIT_Y, t), + "init Y at {t}" + ); + assert_eq!( + read_uint(cols::W_Z, t), + read_uint(cols::PA_R_INIT_Z, t), + "init Z at {t}" + ); } // Final-row check: no in-circuit constraint after dropping diff --git a/test-uair/src/ecdsa_addition.rs b/test-uair/src/ecdsa_addition.rs index e0ab3c7e..4a930fa0 100644 --- a/test-uair/src/ecdsa_addition.rs +++ b/test-uair/src/ecdsa_addition.rs @@ -54,8 +54,7 @@ use rand::RngCore; use zinc_poly::{mle::DenseMultilinearExtension, univariate::dense::DensePolynomial}; use zinc_uair::{ ConstraintBuilder, PublicColumnLayout, TotalColumnLayout, TraceRow, Uair, UairSignature, - UairTrace, - ideal::ImpossibleIdeal, + UairTrace, ideal::ImpossibleIdeal, }; use crate::GenerateRandomTrace; @@ -311,10 +310,7 @@ fn rand_fp(rng: &mut Rng) -> CbUint { raw.rem_vartime(&p_nz) } -fn mul_mod_p( - a: &CbUint, - b: &CbUint, -) -> CbUint { +fn mul_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = a.widening_mul(b).into(); let p_wide: 
CbUint<{ EC_FP_INT_LIMBS * 2 }> = SECP256K1_P_UINT.resize(); let p_wide_nz = NonZero::new(p_wide).expect("p is nonzero"); @@ -339,10 +335,7 @@ fn p_geq(a: &CbUint) -> bool { a.checked_sub(&SECP256K1_P_UINT).is_some().into() } -fn sub_mod_p( - a: &CbUint, - b: &CbUint, -) -> CbUint { +fn sub_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let p_nz = NonZero::new(SECP256K1_P_UINT).expect("p is nonzero"); if a.checked_sub(b).is_some().into() { a.wrapping_sub(b).rem_vartime(&p_nz) diff --git a/test-uair/src/ecdsa_affine.rs b/test-uair/src/ecdsa_affine.rs index 745fa4c9..a8d93f1d 100644 --- a/test-uair/src/ecdsa_affine.rs +++ b/test-uair/src/ecdsa_affine.rs @@ -46,8 +46,7 @@ use rand::RngCore; use zinc_poly::{mle::DenseMultilinearExtension, univariate::dense::DensePolynomial}; use zinc_uair::{ ConstraintBuilder, PublicColumnLayout, TotalColumnLayout, TraceRow, Uair, UairSignature, - UairTrace, - ideal::ImpossibleIdeal, + UairTrace, ideal::ImpossibleIdeal, }; use crate::GenerateRandomTrace; @@ -247,10 +246,7 @@ fn rand_nonzero_fp(rng: &mut Rng) -> CbUint, - b: &CbUint, -) -> CbUint { +fn mul_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = a.widening_mul(b).into(); let p_wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = SECP256K1_P_UINT.resize(); let p_wide_nz = NonZero::new(p_wide).expect("p is nonzero"); @@ -368,10 +364,18 @@ mod tests { // C1: Z · Z_inv = 1 let one_uint: CbUint = CbUint::ONE; - assert_eq!(mul_mod_p(&z, &z_inv), one_uint, "C1 (Z·Z_inv=1) at row {row}"); + assert_eq!( + mul_mod_p(&z, &z_inv), + one_uint, + "C1 (Z·Z_inv=1) at row {row}" + ); // C2: Z_inv_sq = Z_inv² - assert_eq!(z_inv_sq, mul_mod_p(&z_inv, &z_inv), "C2 (Z_inv²) at row {row}"); + assert_eq!( + z_inv_sq, + mul_mod_p(&z_inv, &z_inv), + "C2 (Z_inv²) at row {row}" + ); // C3: Z_inv_cube = Z_inv · Z_inv_sq assert_eq!( diff --git a/test-uair/src/ecdsa_doubling.rs b/test-uair/src/ecdsa_doubling.rs index b5c94966..5f36e545 100644 --- a/test-uair/src/ecdsa_doubling.rs +++ 
b/test-uair/src/ecdsa_doubling.rs @@ -66,8 +66,7 @@ use rand::RngCore; use zinc_poly::{mle::DenseMultilinearExtension, univariate::dense::DensePolynomial}; use zinc_uair::{ ConstraintBuilder, PublicColumnLayout, ShiftSpec, TotalColumnLayout, TraceRow, Uair, - UairSignature, UairTrace, - ideal::ImpossibleIdeal, + UairSignature, UairTrace, ideal::ImpossibleIdeal, }; use crate::GenerateRandomTrace; @@ -215,8 +214,7 @@ where let x_sq_x_s = x_sq.clone() * &xs; // X²·X·S = X³·S let twelve_x3s = mbs(&x_sq_x_s, &twelve_scalar).expect("12·X³·S overflow"); let x_sq_xmid = x_sq.clone() * x_mid; - let three_xsq_xmid = - mbs(&x_sq_xmid, &three_scalar).expect("3·X²·X_mid overflow"); + let three_xsq_xmid = mbs(&x_sq_xmid, &three_scalar).expect("3·X²·X_mid overflow"); let s_sq = s_w.clone() * s_w; let eight_s_sq = mbs(&s_sq, &eight_scalar).expect("8·S² overflow"); let c4_inner = y_mid.clone() - &twelve_x3s + &three_xsq_xmid + &eight_s_sq; @@ -326,10 +324,7 @@ fn rand_fp(rng: &mut Rng) -> CbUint { } /// `(a · b) mod p`. -fn mul_mod_p( - a: &CbUint, - b: &CbUint, -) -> CbUint { +fn mul_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = a.widening_mul(b).into(); let p_wide: CbUint<{ EC_FP_INT_LIMBS * 2 }> = SECP256K1_P_UINT.resize(); let p_wide_nz = NonZero::new(p_wide).expect("p is nonzero"); @@ -356,10 +351,7 @@ fn p_geq(a: &CbUint) -> bool { } /// `(a − b) mod p`, allowing `a < b`. 
-fn sub_mod_p( - a: &CbUint, - b: &CbUint, -) -> CbUint { +fn sub_mod_p(a: &CbUint, b: &CbUint) -> CbUint { let p_nz = NonZero::new(SECP256K1_P_UINT).expect("p is nonzero"); if a.checked_sub(b).is_some().into() { a.wrapping_sub(b).rem_vartime(&p_nz) diff --git a/test-uair/src/lib.rs b/test-uair/src/lib.rs index a6951178..47de863c 100644 --- a/test-uair/src/lib.rs +++ b/test-uair/src/lib.rs @@ -12,8 +12,13 @@ pub use ecdsa_addition::JacobianAdditionUair; pub use ecdsa_affine::AffineConversionUair; pub use ecdsa_doubling::{EC_FP_INT_LIMBS, EcdsaFpRing, JacobianDoublingUair}; pub use generate_trace::*; -pub use sha256::{Sha256CompressionSliceUair, Sha256Ideal}; pub use sha_ecdsa::ShaEcdsaUair; +pub use sha256::{ + SHA256_INITIAL_STATE, Sha256CompressionSliceUair, Sha256Ideal, Sha256MessageBlock, Sha256State, + Sha256WitnessError, sha256_compress_native, sha256_padded_message_blocks, + synthesize_one_sha256_compression_trace, synthesize_sha256_chain_trace, + synthesize_sha256_chain_witnesses, +}; use crypto_primitives::{ConstSemiring, FixedSemiring, Semiring, boolean::Boolean}; use num_traits::Zero; @@ -1008,10 +1013,7 @@ where { let two_ideal = ideal_from_ref(&DegreeOneIdeal::new(R::from(2))); // V[i] − Rot(7)(W[i]) ≡ 0 mod (X − 2) - b.assert_in_ideal( - up.binary_poly[1].clone() - &down.bit_op[0], - &two_ideal, - ); + b.assert_in_ideal(up.binary_poly[1].clone() - &down.bit_op[0], &two_ideal); } } diff --git a/test-uair/src/sha256.rs b/test-uair/src/sha256.rs index 8c9a5bc6..de6c6e80 100644 --- a/test-uair/src/sha256.rs +++ b/test-uair/src/sha256.rs @@ -163,24 +163,25 @@ //! these would let a verifier pin the initial compression state. The //! init boundary currently only constrains `a[0] = pa_a[0]`. 
-use core::marker::PhantomData; +use core::{fmt, marker::PhantomData}; use crypto_primitives::{ConstSemiring, PrimeField, Semiring}; use rand::RngCore; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use zinc_poly::{ mle::DenseMultilinearExtension, univariate::{ - binary::BinaryPoly, dense::DensePolynomial, - dynamic::over_field::DynamicPolynomialF, + binary::BinaryPoly, dense::DensePolynomial, dynamic::over_field::DynamicPolynomialF, }, }; use zinc_uair::{ BitOp, BitOpSpec, ConstraintBuilder, LookupColumnSpec, PublicColumnLayout, PublicStructureError, ShiftSpec, ShiftedBitSliceSpec, TotalColumnLayout, TraceRow, Uair, - UairSignature, UairTrace, VirtualBinaryPolySource, VirtualBinaryPolySpec, + UairSignature, UairTrace, UairWitness, VirtualBinaryPolySource, VirtualBinaryPolySpec, ideal::{Ideal, IdealCheck, IdealCheckError, rotation::RotationIdeal}, }; -use zinc_utils::from_ref::FromRef; +use zinc_utils::{cfg_into_iter, from_ref::FromRef}; use crate::GenerateRandomTrace; @@ -248,6 +249,81 @@ where #[derive(Clone, Debug)] pub struct Sha256CompressionSliceUair(PhantomData); +/// Canonical SHA-256 compression state `[a, b, c, d, e, f, g, h]`. +pub type Sha256State = [u32; 8]; + +/// Canonical SHA-256 initial hash state H_0 (FIPS 180-4 section 5.3.3). +pub const SHA256_INITIAL_STATE: Sha256State = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +/// One 512-bit SHA-256 message block as sixteen big-endian words. +pub type Sha256MessageBlock = [u32; 16]; + +/// Errors surfaced by deterministic SHA-256 witness synthesis. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Sha256WitnessError { + /// A byte message is too long to encode its bit length in SHA-256's u64 + /// length suffix. + MessageBitLengthOverflow { bytes: usize }, + /// The canonical SHA-256 padding of a byte message did not produce the + /// caller-requested fixed number of blocks. 
+ PaddedBlockCountMismatch { expected: usize, got: usize }, + /// The requested packed trace would not fit in the requested MLE size. + TraceTooSmall { + num_compressions: usize, + active_rows: usize, + num_vars: usize, + }, + /// A synthesized compression trace did not return the expected terminal state. + FinalStateMismatch { + index: usize, + expected: Sha256State, + got: Sha256State, + }, + /// Internal conversion from `Vec` to fixed-size array failed. + InternalLengthMismatch { expected: usize, got: usize }, +} + +impl fmt::Display for Sha256WitnessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MessageBitLengthOverflow { bytes } => { + write!( + f, + "message of {bytes} byte(s) is too long for SHA-256 padding" + ) + } + Self::PaddedBlockCountMismatch { expected, got } => write!( + f, + "expected canonical SHA-256 padding to produce {expected} block(s), got {got}" + ), + Self::TraceTooSmall { + num_compressions, + active_rows, + num_vars, + } => write!( + f, + "trace too small for {num_compressions} SHA-256 compression(s): \ + {active_rows} active rows do not fit in 2^{num_vars} rows" + ), + Self::FinalStateMismatch { + index, + expected, + got, + } => write!( + f, + "SHA-256 witness {index} ended at {got:08x?}, expected {expected:08x?}" + ), + Self::InternalLengthMismatch { expected, got } => { + write!(f, "expected {expected} synthesized witness(es), got {got}") + } + } + } +} + +impl std::error::Error for Sha256WitnessError {} + /// Column indices within the flat trace (binary || arbitrary || int). /// /// All polynomial columns are bit-polynomials (stored as `binary_poly`); @@ -367,7 +443,7 @@ pub mod cols { // witness and the verifier no longer needs an out-of-band // public-structure check on them. 
pub const S_ACTIVE_SCHED: usize = 4; // public: 1 on C7's active range [start, start+48) per compression - pub const S_ACTIVE_UPD: usize = 5; // public: 1 on C8/C9's active range [start, start+64) per compression + pub const S_ACTIVE_UPD: usize = 5; // public: 1 on C8/C9's active range [start, start+64) per compression // (For C12/C13 the active range is the junction window // [start+64, start+68), which is exactly where `S_FEEDFORWARD` is // 1 — no separate selector needed.) @@ -387,11 +463,11 @@ pub mod cols { // - SCHED (C7): k ∈ [start, start + 48) // - UPD (C8/C9): k ∈ [start, start + 64) // - JUNCTION (C12/C13): k ∈ [start + 64, start + 68) - pub const PA_C_C7: usize = 6; // witness: compensator for C7 (sched_anch) - pub const PA_C_C8: usize = 7; // witness: compensator for C8 (upd_anch a) - pub const PA_C_C9: usize = 8; // witness: compensator for C9 (upd_anch e) - pub const PA_C_FF_A: usize = 9; // witness: compensator for C12 (feed-forward a-half) - pub const PA_C_FF_E: usize = 10; // witness: compensator for C13 (feed-forward e-half) + pub const PA_C_C7: usize = 6; // witness: compensator for C7 (sched_anch) + pub const PA_C_C8: usize = 7; // witness: compensator for C8 (upd_anch a) + pub const PA_C_C9: usize = 8; // witness: compensator for C9 (upd_anch e) + pub const PA_C_FF_A: usize = 9; // witness: compensator for C12 (feed-forward a-half) + pub const PA_C_FF_E: usize = 10; // witness: compensator for C13 (feed-forward e-half) /// Total number of int columns. /// @@ -464,7 +540,8 @@ where ShiftSpec::new(cols::FLAT_W_E, 4), // w_sig1: Sigma_1(e[t]) at anchor t-3. ShiftSpec::new(cols::FLAT_W_SIG1, 3), - // w_W: message-schedule 9, 16 AND register-update 3. + // w_W: message-schedule shifts 9 and 16. Shift 3 is retained + // for signature-slot stability; round updates consume up.w_W. 
ShiftSpec::new(cols::FLAT_W_W, 3), ShiftSpec::new(cols::FLAT_W_W, 9), ShiftSpec::new(cols::FLAT_W_W, 16), @@ -509,11 +586,11 @@ where // `down.bit_op` and the (X^32 − 1) modular lift goes away: // ROT/SHIFTR virtual columns are mod X^32 by construction. let bit_op_specs: Vec = vec![ - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(25)), // σ_0: ROTR^7 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(14)), // σ_0: ROTR^18 - BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(3)), // σ_0: SHR^3 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(15)), // σ_1: ROTR^17 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(13)), // σ_1: ROTR^19 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(25)), // σ_0: ROTR^7 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(14)), // σ_0: ROTR^18 + BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(3)), // σ_0: SHR^3 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(15)), // σ_1: ROTR^17 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(13)), // σ_1: ROTR^19 BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(10)), // σ_1: SHR^10 // Bit-op virtuals over W_MU_PACKED for extracting each // carry from its bit slice. The C7/C8/C9/C12/C13 @@ -521,10 +598,10 @@ where // `2^32 · ShiftR(k_low) − 2^{32+w} · ShiftR(k_low+w)`. // ShiftR(10) is also assert_zero'd by C22 to pin // positions 10..31 to 0. 
- BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(2)), // skips mu_W - BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(5)), // skips mu_W + mu_a - BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(8)), // skips through mu_e - BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(9)), // keeps mu_ff_e + (high) + BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(2)), // skips mu_W + BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(5)), // skips mu_W + mu_a + BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(8)), // skips through mu_e + BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(9)), // keeps mu_ff_e + (high) BitOpSpec::new(cols::FLAT_W_MU_PACKED, BitOp::ShiftR(10)), // (high) — must be 0 ]; // Witness-relative col indices (post-public) for virtual specs. @@ -722,7 +799,7 @@ where let _down_w_e_sh2 = &down.binary_poly[5]; let down_w_e_sh4 = &down.binary_poly[6]; let down_w_sig1_sh3 = &down.binary_poly[7]; - let down_w_w_sh3 = &down.binary_poly[8]; + let _down_w_w_sh3 = &down.binary_poly[8]; let down_w_w_sh9 = &down.binary_poly[9]; let down_w_w_sh16 = &down.binary_poly[10]; let down_w_lsig0_sh1 = &down.binary_poly[11]; @@ -748,10 +825,10 @@ where let down_w_shr3 = &down.bit_op[4]; // σ_0: SHR^3 let down_w_shr10 = &down.bit_op[5]; // σ_1: SHR^10 // Bit-extraction shifts on W_MU_PACKED. - let down_w_mu_packed_shr2 = &down.bit_op[6]; // skips mu_W - let down_w_mu_packed_shr5 = &down.bit_op[7]; // skips mu_W + mu_a - let down_w_mu_packed_shr8 = &down.bit_op[8]; // skips through mu_e - let down_w_mu_packed_shr9 = &down.bit_op[9]; // keeps only mu_ff_e + (high) + let down_w_mu_packed_shr2 = &down.bit_op[6]; // skips mu_W + let down_w_mu_packed_shr5 = &down.bit_op[7]; // skips mu_W + mu_a + let down_w_mu_packed_shr8 = &down.bit_op[8]; // skips through mu_e + let down_w_mu_packed_shr9 = &down.bit_op[9]; // keeps only mu_ff_e + (high) let down_w_mu_packed_shr10 = &down.bit_op[10]; // (high) — must be 0 // Ideals. 
@@ -777,8 +854,7 @@ where let const_2_to_35 = const_scalar::(pow_two::(35)); // mu_X contributions (each evaluates to `2^32 · mu_X` at X=2). - let mu_w_contrib = mbs(w_mu_packed, &const_2_to_32) - .expect("2^32 · w_mu_packed overflow") + let mu_w_contrib = mbs(w_mu_packed, &const_2_to_32).expect("2^32 · w_mu_packed overflow") - &mbs(down_w_mu_packed_shr2, &const_2_to_34) .expect("2^34 · ShiftR(2)(w_mu_packed) overflow"); let mu_a_contrib = mbs(down_w_mu_packed_shr2, &const_2_to_32) @@ -807,7 +883,8 @@ where // Constraint 1: Sigma_0 rotation, Q[X]-lifted. // (a_hat · rho_sig0 − sig0_hat − 2 · ov_sig0) ∈ (X^32 − 1) b.assert_in_ideal( - mbs(w_a, &rho_sig0).expect("a · rho_sig0 overflow") - w_sig0 + mbs(w_a, &rho_sig0).expect("a · rho_sig0 overflow") + - w_sig0 - &mbs(pa_ov_sig0, &two_scalar).expect("2 · ov_sig0 overflow"), &ideal_rot_xw1, ); @@ -815,7 +892,8 @@ where // Constraint 2: Sigma_1 rotation, Q[X]-lifted. // (e_hat · rho_sig1 − sig1_hat − 2 · ov_sig1) ∈ (X^32 − 1) b.assert_in_ideal( - mbs(w_e, &rho_sig1).expect("e · rho_sig1 overflow") - w_sig1 + mbs(w_e, &rho_sig1).expect("e · rho_sig1 overflow") + - w_sig1 - &mbs(pa_ov_sig1, &two_scalar).expect("2 · ov_sig1 overflow"), &ideal_rot_xw1, ); @@ -828,14 +906,16 @@ where // coefficient sum {0..3} → bit XOR. // ROT^25(W) + ROT^14(W) + SHIFTR^3(W) − lsig0 − 2 · pa_ov_lsig0 == 0 b.assert_zero( - down_w_rot25.clone() + down_w_rot14 + down_w_shr3 - w_lsig0 + down_w_rot25.clone() + down_w_rot14 + down_w_shr3 + - w_lsig0 - &mbs(pa_ov_lsig0, &two_scalar).expect("2 · ov_lsig0 overflow"), ); // Constraint 6 (was σ_1 (X^32 − 1) ideal-lift): σ_1 analogue of C4. 
// ROT^15(W) + ROT^13(W) + SHIFTR^10(W) − lsig1 − 2 · pa_ov_lsig1 == 0 b.assert_zero( - down_w_rot15.clone() + down_w_rot13 + down_w_shr10 - w_lsig1 + down_w_rot15.clone() + down_w_rot13 + down_w_shr10 + - w_lsig1 - &mbs(pa_ov_lsig1, &two_scalar).expect("2 · ov_lsig1 overflow"), ); @@ -857,12 +937,9 @@ where // mu_W is now read from the up row (chained-comp re-anchoring // stores each carry at its constraint's anchor row, not at // spec-row t). `mu_w_contrib` evaluates to `2^32 · mu_W` at X=2. - let sched_inner = down_w_w_sh16.clone() - - w_big_w - - down_w_lsig0_sh1 - - down_w_w_sh9 - - down_w_lsig1_sh14 - + &mu_w_contrib; + let sched_inner = + down_w_w_sh16.clone() - w_big_w - down_w_lsig0_sh1 - down_w_w_sh9 - down_w_lsig1_sh14 + + &mu_w_contrib; b.assert_in_ideal(sched_inner + pa_c_c7, &ideal_rot_x2); // Constraint 8: Register-update for `a`, anchored at k = t − 3. @@ -878,7 +955,7 @@ where // a[t] = down.w_a^↓3 Sigma_1(e[t]) = down.w_sig1^↓3 // Sigma_0(a[t]) = down.w_sig0^↓3 u_ef[t] = down.w_u_ef^↓3 // u_{¬e,g}[t] = down.w_u_neg_e_g^↓3 - // Maj[t] = down.w_maj^↓3 W[t] = down.w_W^↓3 + // Maj[t] = down.w_maj^↓3 W[t] = up.w_W // K[t] = down.pa_K^↓3 mu_a[t] = down.w_mu_a^↓3 // pa_c_c8 is the witness compensator (see C7 note); zero-on-active // pinned in-circuit by C19: `pa_c_c8 · S_ACTIVE_UPD = 0`. @@ -888,10 +965,10 @@ where - down_w_u_ef_sh3 // Ch[t] = u_ef + u_{¬e,g} - down_w_u_neg_e_g_sh3 - down_pa_k_sh3 // K[t] - - down_w_w_sh3 // W[t] + - w_big_w // W[t] - down_w_sig0_sh3 // Sigma_0(a[t]) - down_w_maj_sh3 // Maj[t] - + &mu_a_contrib; // = 2^32 · mu_a (bits 2-4 of W_MU_PACKED) + + &mu_a_contrib; // = 2^32 · mu_a (bits 2-4 of W_MU_PACKED) b.assert_in_ideal(a_update_inner + pa_c_c8, &ideal_rot_x2); // Constraint 9: Register-update for `e`, anchored at k = t − 3. 
@@ -907,8 +984,8 @@ where - down_w_u_ef_sh3 // Ch[t] = u_ef + u_{¬e,g} - down_w_u_neg_e_g_sh3 - down_pa_k_sh3 - - down_w_w_sh3 - + &mu_e_contrib; // = 2^32 · mu_e (bits 5-7 of W_MU_PACKED) + - w_big_w + + &mu_e_contrib; // = 2^32 · mu_e (bits 5-7 of W_MU_PACKED) b.assert_in_ideal(e_update_inner + pa_c_c9, &ideal_rot_x2); // C13–C15 (B_1/B_2/B_3 materialization identities) are gone: @@ -958,18 +1035,12 @@ where // Keeps C12 at degree 1 in the trace MLEs (preserving MLE-first // eligibility) and avoids a multiplicative selector that would // push the effective max degree to 2. - let ff_a_inner = down_w_a_sh4.clone() - - w_a - - pa_a - + &mu_ff_a_contrib; // = 2^32 · mu_ff_a (bit 8 of W_MU_PACKED) + let ff_a_inner = down_w_a_sh4.clone() - w_a - pa_a + &mu_ff_a_contrib; // = 2^32 · mu_ff_a (bit 8 of W_MU_PACKED) b.assert_in_ideal(ff_a_inner + pa_c_ff_a, &ideal_rot_x2); // Constraint 13 (feed-forward, e-family). Mirrors C12 on the // e-half via `pa_c_ff_e`. mu_ff_e from bit 9 of W_MU_PACKED. - let ff_e_inner = down_w_e_sh4.clone() - - w_e - - pa_e - + &mu_ff_e_contrib; + let ff_e_inner = down_w_e_sh4.clone() - w_e - pa_e + &mu_ff_e_contrib; b.assert_in_ideal(ff_e_inner + pa_c_ff_e, &ideal_rot_x2); // Constraint 16: message init (Table 9 row 77). For each @@ -1104,17 +1175,14 @@ fn pow_two(k: u32) -> R { /// so the C8/C9 read at anchor `k = start + j` (which references /// `down.pa_K^↓3 = pa_K[k+3]`) lands on the right round constant. 
pub const K_CANONICAL: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, - 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, - 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, - 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, - 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, - 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, - 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, - 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, - 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, ]; #[inline] @@ -1168,12 +1236,7 @@ fn maj(x: u32, y: u32, z: u32) -> u32 { /// Per the module doc, for the rho patterns we use (3 or 2 nonzero terms) /// each per-position quotient fits in `{0, 1}`; the returned word is /// therefore a valid 32-bit bit-polynomial. 
-fn rotation_overflow( - input_bits: u32, - rho_positions: &[usize], - s0_bits: u32, - out_bits: u32, -) -> u32 { +fn rotation_overflow(input_bits: u32, rho_positions: &[usize], s0_bits: u32, out_bits: u32) -> u32 { // Compute the Z[X] product coefficients of `input · rho`. let mut prod = [0u32; 64]; for i in 0..32 { @@ -1246,585 +1309,799 @@ fn lsig1_overflow(w_val: u32, lsig1_val: u32) -> u32 { rotation_overflow(w_val, &[13, 15], shr10, lsig1_val) } -// --------------------------------------------------------------------------- -// GenerateRandomTrace for the slice. -// --------------------------------------------------------------------------- +fn state_to_trace_halves(state: Sha256State) -> ([u32; 4], [u32; 4]) { + ( + [state[3], state[2], state[1], state[0]], + [state[7], state[6], state[5], state[4]], + ) +} -impl GenerateRandomTrace<32> for Sha256CompressionSliceUair +fn trace_halves_to_state(h_a: [u32; 4], h_e: [u32; 4]) -> Sha256State { + [ + h_a[3], h_a[2], h_a[1], h_a[0], h_e[3], h_e[2], h_e[1], h_e[0], + ] +} + +/// Native SHA-256 compression of one 512-bit message block. +/// +/// The returned state is `initial_state + round_state` componentwise, in +/// canonical `[a, b, c, d, e, f, g, h]` order. 
+pub fn sha256_compress_native( + initial_state: Sha256State, + message_block: Sha256MessageBlock, +) -> Sha256State { + let mut w = [0u32; cols::ROUNDS_PER_COMP]; + w[..16].copy_from_slice(&message_block); + for t in 16..cols::ROUNDS_PER_COMP { + w[t] = w[t - 16] + .wrapping_add(small_sigma0(w[t - 15])) + .wrapping_add(w[t - 7]) + .wrapping_add(small_sigma1(w[t - 2])); + } + + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = initial_state; + for t in 0..cols::ROUNDS_PER_COMP { + let t1 = h + .wrapping_add(big_sigma1(e)) + .wrapping_add(ch(e, f, g)) + .wrapping_add(K_CANONICAL[t]) + .wrapping_add(w[t]); + let t2 = big_sigma0(a).wrapping_add(maj(a, b, c)); + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + } + + [ + initial_state[0].wrapping_add(a), + initial_state[1].wrapping_add(b), + initial_state[2].wrapping_add(c), + initial_state[3].wrapping_add(d), + initial_state[4].wrapping_add(e), + initial_state[5].wrapping_add(f), + initial_state[6].wrapping_add(g), + initial_state[7].wrapping_add(h), + ] +} + +/// Canonically pad `message` as SHA-256 input and return exactly `N` 512-bit +/// blocks as big-endian words. +/// +/// This intentionally rejects messages whose canonical padding would require a +/// different number of blocks; callers that ask for `[Sha256MessageBlock; 8]` +/// should pass a message whose SHA-256 padding really spans eight blocks. 
+pub fn sha256_padded_message_blocks( + message: &[u8], +) -> Result<[Sha256MessageBlock; N], Sha256WitnessError> { + const BLOCK_BYTES: usize = 64; + const LENGTH_BYTES: usize = 8; + const WORD_BYTES: usize = 4; + + let message_len_u64 = + u64::try_from(message.len()).map_err(|_| Sha256WitnessError::MessageBitLengthOverflow { + bytes: message.len(), + })?; + let bit_len = + message_len_u64 + .checked_mul(8) + .ok_or(Sha256WitnessError::MessageBitLengthOverflow { + bytes: message.len(), + })?; + let padded_len = message + .len() + .checked_add(1) + .and_then(|len| len.checked_add(LENGTH_BYTES)) + .ok_or(Sha256WitnessError::MessageBitLengthOverflow { + bytes: message.len(), + })?; + let required_blocks = padded_len.div_ceil(BLOCK_BYTES); + if required_blocks != N { + return Err(Sha256WitnessError::PaddedBlockCountMismatch { + expected: N, + got: required_blocks, + }); + } + + let requested_len = + N.checked_mul(BLOCK_BYTES) + .ok_or(Sha256WitnessError::MessageBitLengthOverflow { + bytes: message.len(), + })?; + let mut padded = vec![0u8; requested_len]; + padded[..message.len()].copy_from_slice(message); + padded[message.len()] = 0x80; + padded[requested_len - LENGTH_BYTES..].copy_from_slice(&bit_len.to_be_bytes()); + + let mut blocks = [[0u32; 16]; N]; + for (block_idx, block) in blocks.iter_mut().enumerate() { + let block_start = block_idx * BLOCK_BYTES; + for (word_idx, word) in block.iter_mut().enumerate() { + let word_start = block_start + word_idx * WORD_BYTES; + *word = u32::from_be_bytes([ + padded[word_start], + padded[word_start + 1], + padded[word_start + 2], + padded[word_start + 3], + ]); + } + } + Ok(blocks) +} + +/// Synthesize one fresh SHA-256 compression UAIR trace. 
+pub fn synthesize_one_sha256_compression_trace( + initial_state: Sha256State, + message_block: Sha256MessageBlock, +) -> Result<(UairTrace<'static, R, R, 32>, Sha256State), Sha256WitnessError> where - R: ConstSemiring + From + 'static, + R: ConstSemiring + From + Clone + Send + Sync + 'static, { - type PolyCoeff = R; - type Int = R; - - fn generate_random_trace( - num_vars: usize, - rng: &mut Rng, - ) -> UairTrace<'static, R, R, 32> { - let n = 1usize << num_vars; - assert!( - num_vars >= cols::MIN_NUM_VARS, - "trace too small for {} chained compressions: need num_vars ≥ {}, got {num_vars}", - cols::NUM_COMPRESSIONS, - cols::MIN_NUM_VARS, - ); + synthesize_sha256_compression_chain_trace::( + cols::MIN_NUM_VARS, + initial_state, + &[message_block], + ) +} - // ===== Chained-compression layout ===== - // - // Run NUM_COMPRESSIONS independent SHA-256 compressions chained - // via the spec's feed-forward addition `H_{i+1} = compress(H_i, - // M_i) + H_i mod 2^32` componentwise. Compression i ∈ [0, N) uses - // rows [i·RPC, (i+1)·RPC) where RPC = ROWS_PER_COMP = 68: - // - rows [start, start+4): init prefix (= H_i, pinned to pa_a/pa_e - // by S_INIT_PREFIX). Under the shift- - // register convention, w_a[start+j] holds - // H_i's (d, c, b, a) for j=0..3, w_e[start+j] - // holds H_i's (h, g, f, e). - // - rows [start+4, start+68): 64 round-update outputs. - // - rows [start+64, start+68): "junction window" — w_a/w_e hold - // internal_final_i; pa_a/pa_e hold a SECOND - // copy of H_i so the feed-forward constraint - // can read the prior init via `up.pa_a`. - // After the last compression, rows [N·RPC, N·RPC+4) hold the H_N output - // prefix, pinned by S_INIT_PREFIX in the same way. - // - // Slack rows [N·RPC + 4, n) are zero-padded; all SHA constraints - // are inactive there (compensators absorb C7/C8/C9; selectors - // gate off C13–C15 and the boundary/junction families). 
- let big_n = cols::NUM_COMPRESSIONS; - let rpc = cols::ROWS_PER_COMP; - let rounds = cols::ROUNDS_PER_COMP; - - // Trace-row buffers, all length n, zero-initialized. - let mut a_vals = vec![0u32; n]; - let mut e_vals = vec![0u32; n]; - let mut w_vals = vec![0u32; n]; - let mut k_vals = vec![0u32; n]; - let mut mu_w_vals = vec![0u32; n]; - let mut mu_a_vals = vec![0u32; n]; - let mut mu_e_vals = vec![0u32; n]; - let mut mu_junction_a_vals = vec![0u32; n]; - let mut mu_junction_e_vals = vec![0u32; n]; - - // pa_a / pa_e: H_i values at init-prefix rows (gated by - // S_INIT_PREFIX) AND at junction rows (read by the feed-forward - // constraint). Both copies hold the same H_i values; they live - // at different rows for different constraint uses. - let mut pa_a_vals = vec![0u32; n]; - let mut pa_e_vals = vec![0u32; n]; - // pa_m: per-compression message-block words. Holds M_i[0..16] - // at rows [start, start+16) for compression i; zero elsewhere. - // Pinned to w_W at those rows by C16 (s_msg_init). - let mut pa_m_vals = vec![0u32; n]; - - // H_0: random initial state for testing. Stored as two 4-arrays - // (d, c, b, a) for the a-half and (h, g, f, e) for the e-half, in - // the order they appear at the init prefix rows (so index j → row - // `start + j` directly). - let mut h_a: [u32; 4] = [ - rng.next_u32(), - rng.next_u32(), - rng.next_u32(), - rng.next_u32(), - ]; - let mut h_e: [u32; 4] = [ - rng.next_u32(), - rng.next_u32(), - rng.next_u32(), - rng.next_u32(), - ]; +/// Synthesize one monolithic SHA-256 compression-chain trace. +/// +/// This packs all `N` chained compressions into a single UAIR trace, using +/// `num_vars` to choose the MLE row domain. It is the deterministic, +/// message-driven counterpart to [`GenerateRandomTrace`] for benchmark cases +/// that need to compare a monolithic proof against a batched/folded proof over +/// the same SHA-256 chain. 
+pub fn synthesize_sha256_chain_trace( + num_vars: usize, + initial_state: Sha256State, + message_blocks: [Sha256MessageBlock; N], +) -> Result<(UairTrace<'static, R, R, 32>, Sha256State), Sha256WitnessError> +where + R: ConstSemiring + From + Clone + 'static, +{ + synthesize_sha256_compression_chain_trace::(num_vars, initial_state, &message_blocks) +} - for i in 0..big_n { - let start = i * rpc; +/// Synthesize `N` ordered fresh SHA-256 compression witnesses for a chain. +/// +/// State computation is sequential (`H_{i+1} = compress(H_i, M_i)`); once all +/// states are known, trace synthesis is parallelized when the crate's +/// `parallel` feature is enabled. +pub fn synthesize_sha256_chain_witnesses( + initial_state: Sha256State, + message_blocks: [Sha256MessageBlock; N], +) -> Result<([UairWitness<'static, R, R, 32>; N], Sha256State), Sha256WitnessError> +where + R: ConstSemiring + From + Clone + Send + Sync + 'static, +{ + let mut states = Vec::with_capacity(N + 1); + states.push(initial_state); + let mut state = initial_state; + for message_block in &message_blocks { + state = sha256_compress_native(state, *message_block); + states.push(state); + } - // 1) Init prefix [start, start+4): pin to H_i. - for j in 0..4 { - a_vals[start + j] = h_a[j]; - e_vals[start + j] = h_e[j]; - pa_a_vals[start + j] = h_a[j]; - pa_e_vals[start + j] = h_e[j]; + let witnesses_vec = cfg_into_iter!(0..N) + .map(|index| { + let (trace, got) = + synthesize_one_sha256_compression_trace::(states[index], message_blocks[index])?; + let expected = states[index + 1]; + if got != expected { + return Err(Sha256WitnessError::FinalStateMismatch { + index, + expected, + got, + }); } + Ok(UairWitness { trace }) + }) + .collect::, Sha256WitnessError>>()?; + + let got = witnesses_vec.len(); + let witnesses = witnesses_vec + .try_into() + .map_err(|_| Sha256WitnessError::InternalLengthMismatch { expected: N, got })?; + Ok((witnesses, states[N])) +} - // 2) Per-compression message block. 
16 random seeds (which - // also populate the public pa_m column so C16 pins them), - // then 48 derived via the SHA-256 message-schedule - // recurrence — contained entirely within compression i's - // window. - for j in 0..16 { - let m_word = rng.next_u32(); - w_vals[start + j] = m_word; - pa_m_vals[start + j] = m_word; - } - for j in 16..rpc { - let t = start + j; - let sum_u64: u64 = (w_vals[t - 16] as u64) - + (small_sigma0(w_vals[t - 15]) as u64) - + (w_vals[t - 7] as u64) - + (small_sigma1(w_vals[t - 2]) as u64); - w_vals[t] = sum_u64 as u32; - let carry = (sum_u64 >> 32) as u32; - debug_assert!(carry <= 3, "message-schedule carry out of [0,3]: {carry}"); - // Store mu_W at the C7 anchor row k = t − 16 (not at - // spec row t) so C7 reads it via `up.w_mu_packed` bits - // 0-1 with no shift. - mu_w_vals[t - 16] = carry; - } +fn synthesize_sha256_compression_chain_trace( + num_vars: usize, + initial_state: Sha256State, + message_blocks: &[Sha256MessageBlock], +) -> Result<(UairTrace<'static, R, R, 32>, Sha256State), Sha256WitnessError> +where + R: ConstSemiring + From + Clone + 'static, +{ + let n = 1usize << num_vars; + let big_n = message_blocks.len(); + let rpc = cols::ROWS_PER_COMP; + let rounds = cols::ROUNDS_PER_COMP; + let active_rows = big_n * rpc + 4; + if active_rows > n { + return Err(Sha256WitnessError::TraceTooSmall { + num_compressions: big_n, + active_rows, + num_vars, + }); + } - // 3) Per-compression round constants. Cycle the canonical - // SHA-256 K table per compression at rows - // `[start + 3, start + 67)` so that C8/C9 at active - // anchors `k ∈ [start, start + 64)` (which read - // `down.pa_K^↓3 = pa_K[k+3]`) see `K_CANONICAL[k - start]`. - // Rows `start..start+3` and `start+67` are not read by - // any active anchor of compression i, so they're left - // as zero. (The compensator pa_c_c8/c9 absorbs whatever - // those rows contain.) 
- for j in 0..cols::ROUNDS_PER_COMP { - k_vals[start + 3 + j] = K_CANONICAL[j]; - } + // ===== Chained-compression layout ===== + // + // Run NUM_COMPRESSIONS independent SHA-256 compressions chained + // via the spec's feed-forward addition `H_{i+1} = compress(H_i, + // M_i) + H_i mod 2^32` componentwise. Compression i ∈ [0, N) uses + // rows [i·RPC, (i+1)·RPC) where RPC = ROWS_PER_COMP = 68: + // - rows [start, start+4): init prefix (= H_i, pinned to pa_a/pa_e + // by S_INIT_PREFIX). Under the shift- + // register convention, w_a[start+j] holds + // H_i's (d, c, b, a) for j=0..3, w_e[start+j] + // holds H_i's (h, g, f, e). + // - rows [start+4, start+68): 64 round-update outputs. + // - rows [start+64, start+68): "junction window" — w_a/w_e hold + // internal_final_i; pa_a/pa_e hold a SECOND + // copy of H_i so the feed-forward constraint + // can read the prior init via `up.pa_a`. + // After the last compression, rows [N·RPC, N·RPC+4) hold the H_N output + // prefix, pinned by S_INIT_PREFIX in the same way. + // + // Slack rows [N·RPC + 4, n) are zero-padded; all SHA constraints + // are inactive there (compensators absorb C7/C8/C9; selectors + // gate off C13–C15 and the boundary/junction families). + // Trace-row buffers, all length n, zero-initialized. + let mut a_vals = vec![0u32; n]; + let mut e_vals = vec![0u32; n]; + let mut w_vals = vec![0u32; n]; + let mut k_vals = vec![0u32; n]; + let mut mu_w_vals = vec![0u32; n]; + let mut mu_a_vals = vec![0u32; n]; + let mut mu_e_vals = vec![0u32; n]; + let mut mu_junction_a_vals = vec![0u32; n]; + let mut mu_junction_e_vals = vec![0u32; n]; + + // pa_a / pa_e: H_i values at init-prefix rows (gated by + // S_INIT_PREFIX) AND at junction rows (read by the feed-forward + // constraint). Both copies hold the same H_i values; they live + // at different rows for different constraint uses. + let mut pa_a_vals = vec![0u32; n]; + let mut pa_e_vals = vec![0u32; n]; + // pa_m: per-compression message-block words. 
Holds M_i[0..16] + // at rows [start, start+16) for compression i; zero elsewhere. + // Pinned to w_W at those rows by C16 (s_msg_init). + let mut pa_m_vals = vec![0u32; n]; + + // H_0: caller-supplied initial state. Stored as two 4-arrays + // (d, c, b, a) for the a-half and (h, g, f, e) for the e-half, in + // the order they appear at the init prefix rows (so index j → row + // `start + j` directly). + let (mut h_a, mut h_e) = state_to_trace_halves(initial_state); + + for i in 0..big_n { + let start = i * rpc; + + // 1) Init prefix [start, start+4): pin to H_i. + for j in 0..4 { + a_vals[start + j] = h_a[j]; + e_vals[start + j] = h_e[j]; + pa_a_vals[start + j] = h_a[j]; + pa_e_vals[start + j] = h_e[j]; + } + + // 2) Per-compression message block. 16 caller-supplied seeds + // (which also populate the public pa_m column so C16 pins them), + // then 48 derived via the SHA-256 message-schedule + // recurrence — contained entirely within compression i's + // window. + for j in 0..16 { + let m_word = message_blocks[i][j]; + w_vals[start + j] = m_word; + pa_m_vals[start + j] = m_word; + } + for j in 16..rpc { + let t = start + j; + let sum_u64: u64 = (w_vals[t - 16] as u64) + + (small_sigma0(w_vals[t - 15]) as u64) + + (w_vals[t - 7] as u64) + + (small_sigma1(w_vals[t - 2]) as u64); + w_vals[t] = sum_u64 as u32; + let carry = (sum_u64 >> 32) as u32; + debug_assert!(carry <= 3, "message-schedule carry out of [0,3]: {carry}"); + // Store mu_W at the C7 anchor row k = t − 16 (not at + // spec row t) so C7 reads it via `up.w_mu_packed` bits + // 0-1 with no shift. + mu_w_vals[t - 16] = carry; + } + + // 3) Per-compression round constants. Cycle the canonical + // SHA-256 K table per compression at rows + // `[start + 3, start + 67)` so that C8/C9 at active + // anchors `k ∈ [start, start + 64)` (which read + // `down.pa_K^↓3 = pa_K[k+3]`) see `K_CANONICAL[k - start]`. 
+ // Rows `start..start+3` and `start+67` are not read by + // any active anchor of compression i, so they're left + // as zero. (The compensator pa_c_c8/c9 absorbs whatever + // those rows contain.) + for j in 0..cols::ROUNDS_PER_COMP { + k_vals[start + 3 + j] = K_CANONICAL[j]; + } - // 4) Round-update: 64 rounds, anchor k = start+0..=start+63 - // produces a[k+4]/e[k+4] from the 4-row window a[k..=k+3] - // / e[k..=k+3]. All back-references stay within - // compression i (the first round's reads land on the init - // prefix at [start, start+4); the last round's reads land - // on rows [start+60, start+64)). - // - // Bounds: T1 = h + Σ_1(e) + Ch + K + W (5 terms of <2^32). - // T2 = Σ_0(a) + Maj (2 terms). - // a_sum = T1 + T2 (7 terms ⇒ mu_a ∈ {0..=6}). - // e_sum = d + T1 (6 terms ⇒ mu_e ∈ {0..=5}). - for j in 0..rounds { - let k = start + j; - let t = k + 3; // spec round number under the t = k+3 anchor convention - - let a_t = a_vals[k + 3]; // a[t] - let a_t1 = a_vals[k + 2]; // a[t-1] = b - let a_t2 = a_vals[k + 1]; // a[t-2] = c - let e_t = e_vals[k + 3]; // e[t] - let e_t1 = e_vals[k + 2]; // e[t-1] = f - let e_t2 = e_vals[k + 1]; // e[t-2] = g - - let sig0_a_t = big_sigma0(a_t); - let sig1_e_t = big_sigma1(e_t); - let ch_t = ch(e_t, e_t1, e_t2); - let maj_t = maj(a_t, a_t1, a_t2); - - let t1: u64 = (e_vals[k] as u64) // h = e[t-3] + // 4) Round-update: 64 rounds, anchor k = start+0..=start+63 + // produces a[k+4]/e[k+4] from the 4-row window a[k..=k+3] + // / e[k..=k+3]. All back-references stay within + // compression i (the first round's reads land on the init + // prefix at [start, start+4); the last round's reads land + // on rows [start+60, start+64)). + // + // Bounds: T1 = h + Σ_1(e) + Ch + K + W (5 terms of <2^32). + // T2 = Σ_0(a) + Maj (2 terms). + // a_sum = T1 + T2 (7 terms ⇒ mu_a ∈ {0..=6}). + // e_sum = d + T1 (6 terms ⇒ mu_e ∈ {0..=5}). 
+ for j in 0..rounds { + let k = start + j; + let t = k + 3; // register/K row under the t = k+3 anchor convention + + let a_t = a_vals[k + 3]; // a[t] + let a_t1 = a_vals[k + 2]; // a[t-1] = b + let a_t2 = a_vals[k + 1]; // a[t-2] = c + let e_t = e_vals[k + 3]; // e[t] + let e_t1 = e_vals[k + 2]; // e[t-1] = f + let e_t2 = e_vals[k + 1]; // e[t-2] = g + + let sig0_a_t = big_sigma0(a_t); + let sig1_e_t = big_sigma1(e_t); + let ch_t = ch(e_t, e_t1, e_t2); + let maj_t = maj(a_t, a_t1, a_t2); + + let t1: u64 = (e_vals[k] as u64) // h = e[t-3] + (sig1_e_t as u64) + (ch_t as u64) + (k_vals[t] as u64) - + (w_vals[t] as u64); - let t2: u64 = (sig0_a_t as u64) + (maj_t as u64); - let a_sum: u64 = t1 + t2; - let e_sum: u64 = (a_vals[k] as u64) + t1; // d + T1, d = a[t-3] - - a_vals[k + 4] = a_sum as u32; - e_vals[k + 4] = e_sum as u32; - let mu_a_t = (a_sum >> 32) as u32; - let mu_e_t = (e_sum >> 32) as u32; - debug_assert!(mu_a_t <= 6, "mu_a out of [0,6]: {mu_a_t}"); - debug_assert!(mu_e_t <= 5, "mu_e out of [0,5]: {mu_e_t}"); - // Store mu_a/mu_e at the C8/C9 anchor row k (not at - // spec row t = k+3) so C8/C9 read via `up.w_mu_packed` - // bits 2-4 / 5-7 with no shift. - mu_a_vals[k] = mu_a_t; - mu_e_vals[k] = mu_e_t; - } - - // 5) Feed-forward: H_{i+1} = internal_final_i + H_i mod 2^32 - // componentwise. internal_final_i lives at rows - // [start+64, start+68); we place a second copy of H_i in - // pa_a/pa_e at the same rows (so the feed-forward - // constraint can read the prior init via `up.pa_a`), and - // record the per-component carry in w_mu_junction_{a,e}. - // Each carry is in {0, 1} since both summands are < 2^32. 
- let mut h_a_next: [u32; 4] = [0; 4]; - let mut h_e_next: [u32; 4] = [0; 4]; - for j in 0..4 { - let internal_a = a_vals[start + 64 + j]; - let internal_e = e_vals[start + 64 + j]; - let prior_a = h_a[j]; - let prior_e = h_e[j]; - let sum_a: u64 = (internal_a as u64) + (prior_a as u64); - let sum_e: u64 = (internal_e as u64) + (prior_e as u64); - h_a_next[j] = sum_a as u32; - h_e_next[j] = sum_e as u32; - let carry_a = (sum_a >> 32) as u32; - let carry_e = (sum_e >> 32) as u32; - debug_assert!(carry_a <= 1, "feed-forward a-carry out of {{0,1}}: {carry_a}"); - debug_assert!(carry_e <= 1, "feed-forward e-carry out of {{0,1}}: {carry_e}"); - - pa_a_vals[start + 64 + j] = prior_a; - pa_e_vals[start + 64 + j] = prior_e; - mu_junction_a_vals[start + 64 + j] = carry_a; - mu_junction_e_vals[start + 64 + j] = carry_e; - } - h_a = h_a_next; - h_e = h_e_next; + + (w_vals[k] as u64); + let t2: u64 = (sig0_a_t as u64) + (maj_t as u64); + let a_sum: u64 = t1 + t2; + let e_sum: u64 = (a_vals[k] as u64) + t1; // d + T1, d = a[t-3] + + a_vals[k + 4] = a_sum as u32; + e_vals[k + 4] = e_sum as u32; + let mu_a_t = (a_sum >> 32) as u32; + let mu_e_t = (e_sum >> 32) as u32; + debug_assert!(mu_a_t <= 6, "mu_a out of [0,6]: {mu_a_t}"); + debug_assert!(mu_e_t <= 5, "mu_e out of [0,5]: {mu_e_t}"); + // Store mu_a/mu_e at the C8/C9 anchor row k (not at + // spec row t = k+3) so C8/C9 read via `up.w_mu_packed` + // bits 2-4 / 5-7 with no shift. + mu_a_vals[k] = mu_a_t; + mu_e_vals[k] = mu_e_t; } - // 6) H_N output prefix at rows [big_n·rpc, big_n·rpc + 4): pin - // to H_N (the final compression's output) so the verifier can - // read the digest from the public columns. - let h_out_start = big_n * rpc; + // 5) Feed-forward: H_{i+1} = internal_final_i + H_i mod 2^32 + // componentwise. 
internal_final_i lives at rows + // [start+64, start+68); we place a second copy of H_i in + // pa_a/pa_e at the same rows (so the feed-forward + // constraint can read the prior init via `up.pa_a`), and + // record the per-component carry in w_mu_junction_{a,e}. + // Each carry is in {0, 1} since both summands are < 2^32. + let mut h_a_next: [u32; 4] = [0; 4]; + let mut h_e_next: [u32; 4] = [0; 4]; for j in 0..4 { - a_vals[h_out_start + j] = h_a[j]; - e_vals[h_out_start + j] = h_e[j]; - pa_a_vals[h_out_start + j] = h_a[j]; - pa_e_vals[h_out_start + j] = h_e[j]; + let internal_a = a_vals[start + 64 + j]; + let internal_e = e_vals[start + 64 + j]; + let prior_a = h_a[j]; + let prior_e = h_e[j]; + let sum_a: u64 = (internal_a as u64) + (prior_a as u64); + let sum_e: u64 = (internal_e as u64) + (prior_e as u64); + h_a_next[j] = sum_a as u32; + h_e_next[j] = sum_e as u32; + let carry_a = (sum_a >> 32) as u32; + let carry_e = (sum_e >> 32) as u32; + debug_assert!( + carry_a <= 1, + "feed-forward a-carry out of {{0,1}}: {carry_a}" + ); + debug_assert!( + carry_e <= 1, + "feed-forward e-carry out of {{0,1}}: {carry_e}" + ); + + pa_a_vals[start + 64 + j] = prior_a; + pa_e_vals[start + 64 + j] = prior_e; + mu_junction_a_vals[start + 64 + j] = carry_a; + mu_junction_e_vals[start + 64 + j] = carry_e; } + h_a = h_a_next; + h_e = h_e_next; + } - // ===== Per-row Ch / Maj operand witnesses ===== - // - // Computed honestly on every row from a_vals / e_vals contents. - // The truth-table values must hold on every row (not only - // SHA-active ones) to keep the Ch/Maj virtual residuals - // (`r_ch1` / `r_ch2` / `r_maj`, declared in `signature()`'s - // `with_virtual_binary_poly_cols`) bit-valid per coefficient - // across compression-junction boundaries: the booleanity - // sumcheck checks every row, including ones the spec doesn't - // care about. 
- let u_ef_vals: Vec = (0..n) - .map(|t| if t >= 1 { e_vals[t] & e_vals[t - 1] } else { 0 }) - .collect(); - let u_neg_e_g_vals: Vec = (0..n) - .map(|t| if t >= 2 { (!e_vals[t]) & e_vals[t - 2] } else { 0 }) - .collect(); - let maj_vals: Vec = (0..n) - .map(|t| if t >= 2 { maj(a_vals[t], a_vals[t - 1], a_vals[t - 2]) } else { 0 }) - .collect(); - - // ===== Tail compensators for the Ch (63) / Maj (64) virtual residuals ===== - // - // Zero on every row except `k ∈ {n−2, n−1}` where the length-2 - // forward shifts in r_ch2 / r_maj read into off-trace zero- - // padding and the residual would slip outside `{0,1}` per - // coefficient. Match the compensator logic in option-a-virtual- - // residuals (8787cbd): - // r_ch2 (alt complement form) at boundary k = n-2 / n-1: - // u_{¬e,g}[k+2] = 0 (off-trace), e[k+2] = 0, e[k] real. - // residual = -e[k] + 2·comp_ch2 ∈ {0,1} ⇒ comp_ch2[k] = e[k]. - // r_maj at boundary k = n-2: - // a[k+2] = Maj[k+2] = 0 (off-trace), a[k+1] real. - // residual = a[k] + a[k+1] − 2·comp_maj ∈ {0,1} - // ⇒ comp_maj[k] = AND(a[k], a[k+1]). - // r_maj at k = n-1: a[k+1] = a[k+2] = 0, residual = a[k] ∈ - // {0,1} already; comp_maj = 0. - let mut pa_r_ch2_comp_vals: Vec = vec![0; n]; - let mut pa_r_maj_comp_vals: Vec = vec![0; n]; - for k in 0..n { - let off_kp1 = k + 1 >= n; - let off_kp2 = k + 2 >= n; - if off_kp2 { - pa_r_ch2_comp_vals[k] = e_vals[k]; + // 6) H_N output prefix at rows [big_n·rpc, big_n·rpc + 4): pin + // to H_N (the final compression's output) so the verifier can + // read the digest from the public columns. + let h_out_start = big_n * rpc; + for j in 0..4 { + a_vals[h_out_start + j] = h_a[j]; + e_vals[h_out_start + j] = h_e[j]; + pa_a_vals[h_out_start + j] = h_a[j]; + pa_e_vals[h_out_start + j] = h_e[j]; + } + + // ===== Per-row Ch / Maj operand witnesses ===== + // + // Computed honestly on every row from a_vals / e_vals contents. 
+ // The truth-table values must hold on every row (not only + // SHA-active ones) to keep the Ch/Maj virtual residuals + // (`r_ch1` / `r_ch2` / `r_maj`, declared in `signature()`'s + // `with_virtual_binary_poly_cols`) bit-valid per coefficient + // across compression-junction boundaries: the booleanity + // sumcheck checks every row, including ones the spec doesn't + // care about. + let u_ef_vals: Vec = cfg_into_iter!(0..n) + .map(|t| if t >= 1 { e_vals[t] & e_vals[t - 1] } else { 0 }) + .collect(); + let u_neg_e_g_vals: Vec = cfg_into_iter!(0..n) + .map(|t| { + if t >= 2 { + (!e_vals[t]) & e_vals[t - 2] + } else { + 0 } - if off_kp2 && !off_kp1 { - pa_r_maj_comp_vals[k] = a_vals[k] & a_vals[k + 1]; + }) + .collect(); + let maj_vals: Vec = cfg_into_iter!(0..n) + .map(|t| { + if t >= 2 { + maj(a_vals[t], a_vals[t - 1], a_vals[t - 2]) + } else { + 0 } - } - - // Derived values. - let sig0_vals: Vec = a_vals.iter().copied().map(big_sigma0).collect(); - let sig1_vals: Vec = e_vals.iter().copied().map(big_sigma1).collect(); - let lsig0_vals: Vec = w_vals.iter().copied().map(small_sigma0).collect(); - let lsig1_vals: Vec = w_vals.iter().copied().map(small_sigma1).collect(); - - let ov_sig0_vals: Vec = a_vals - .iter() - .zip(&sig0_vals) - .map(|(&a, &s)| sigma0_overflow(a, s)) - .collect(); - let ov_sig1_vals: Vec = e_vals - .iter() - .zip(&sig1_vals) - .map(|(&e, &s)| sigma1_overflow(e, s)) - .collect(); - let ov_lsig0_vals: Vec = w_vals - .iter() - .zip(&lsig0_vals) - .map(|(&w, &l)| lsig0_overflow(w, l)) - .collect(); - let ov_lsig1_vals: Vec = w_vals - .iter() - .zip(&lsig1_vals) - .map(|(&w, &l)| lsig1_overflow(w, l)) - .collect(); - - // The σ_0/σ_1 right-shift decomposition columns S_i / T_i are - // gone — their role (carrying SHR(W, k) for the F_2[X] sum) is - // taken over by the `BitOp::ShiftR(k)` virtual columns over W. 
- // `lsig0_overflow` / `lsig1_overflow` already compute the - // matching `pa_ov_lsig{0,1}` per-bit values for the new - // constraint (the algebraic identity is unchanged). - - // Pack all 5 carries per row into the W_MU_PACKED binary_poly - // column. Each carry was stored at its constraint's anchor row - // (mu_W at C7-anchor k = t-16, mu_a/mu_e at C8/C9-anchor k = - // t-3, mu_ff_a/e at junction-anchor k). Bit layout: - // bits 0-1: mu_W, 2-4: mu_a, 5-7: mu_e, 8: mu_ff_a, 9: mu_ff_e. - // Positions 10..31 stay 0 (pinned by C22's high-bits-zero - // assert_zero on ShiftR(10)(W_MU_PACKED)). - let w_mu_packed_vals: Vec = (0..n) - .map(|k| { - (mu_w_vals[k] & 0b11) - | ((mu_a_vals[k] & 0b111) << 2) - | ((mu_e_vals[k] & 0b111) << 5) - | ((mu_junction_a_vals[k] & 0b1) << 8) - | ((mu_junction_e_vals[k] & 0b1) << 9) - }) - .collect(); - - let to_bits = |v: &[u32]| -> Vec> { - v.iter().copied().map(BinaryPoly::<32>::from).collect() - }; - - let to_bin_mle = |col: Vec>| -> DenseMultilinearExtension< - BinaryPoly<32>, - > { col.into_iter().collect() }; - - // Layout: 8 public bin_poly cols (PA_A, PA_E, PA_OV_SIG0, - // PA_OV_SIG1, PA_OV_LSIG0, PA_OV_LSIG1, PA_R_CH2_COMP, - // PA_R_MAJ_COMP) + 10 witness cols. pa_a / pa_e were populated - // above with H_i values at init-prefix rows (for compression i - // and the H_N output block) AND at junction rows (where the - // feed-forward constraint reads the prior H_i). The two - // PA_R_*_COMP columns are zero except on the trace tail. 
- let binary_poly = vec![ - to_bin_mle(to_bits(&pa_a_vals)), - to_bin_mle(to_bits(&pa_e_vals)), - to_bin_mle(to_bits(&ov_sig0_vals)), - to_bin_mle(to_bits(&ov_sig1_vals)), - to_bin_mle(to_bits(&ov_lsig0_vals)), - to_bin_mle(to_bits(&ov_lsig1_vals)), - to_bin_mle(to_bits(&pa_r_ch2_comp_vals)), - to_bin_mle(to_bits(&pa_r_maj_comp_vals)), - to_bin_mle(to_bits(&pa_m_vals)), - to_bin_mle(to_bits(&a_vals)), - to_bin_mle(to_bits(&sig0_vals)), - to_bin_mle(to_bits(&e_vals)), - to_bin_mle(to_bits(&sig1_vals)), - to_bin_mle(to_bits(&w_vals)), - to_bin_mle(to_bits(&lsig0_vals)), - to_bin_mle(to_bits(&lsig1_vals)), - to_bin_mle(to_bits(&u_ef_vals)), - to_bin_mle(to_bits(&u_neg_e_g_vals)), - to_bin_mle(to_bits(&maj_vals)), - to_bin_mle(to_bits(&w_mu_packed_vals)), - ]; + }) + .collect(); - // ===== Selectors ===== - // - // s_init_prefix: 1 on the init-prefix windows for every compression - // (4 rows × NUM_COMPRESSIONS) plus the H_N output - // block (4 more rows). Pins w_a / w_e to pa_a / pa_e. - // s_feedforward: 1 on the junction windows [start+64, start+68) for - // every compression. Gates the SHA-256 inter- - // compression addition constraint. - let mut s_init_prefix_col: Vec = (0..n).map(|_| R::ZERO).collect(); - for i in 0..=big_n { - // i = big_n: the H_N output block. - for j in 0..4 { - s_init_prefix_col[i * rpc + j] = R::ONE; + // ===== Tail compensators for the Ch (63) / Maj (64) virtual residuals ===== + // + // Zero on every row except `k ∈ {n−2, n−1}` where the length-2 + // forward shifts in r_ch2 / r_maj read into off-trace zero- + // padding and the residual would slip outside `{0,1}` per + // coefficient. Match the compensator logic in option-a-virtual- + // residuals (8787cbd): + // r_ch2 (alt complement form) at boundary k = n-2 / n-1: + // u_{¬e,g}[k+2] = 0 (off-trace), e[k+2] = 0, e[k] real. + // residual = -e[k] + 2·comp_ch2 ∈ {0,1} ⇒ comp_ch2[k] = e[k]. + // r_maj at boundary k = n-2: + // a[k+2] = Maj[k+2] = 0 (off-trace), a[k+1] real. 
+ // residual = a[k] + a[k+1] − 2·comp_maj ∈ {0,1} + // ⇒ comp_maj[k] = AND(a[k], a[k+1]). + // r_maj at k = n-1: a[k+1] = a[k+2] = 0, residual = a[k] ∈ + // {0,1} already; comp_maj = 0. + let pa_r_ch2_comp_vals: Vec = cfg_into_iter!(0..n) + .map(|k| if k + 2 >= n { e_vals[k] } else { 0 }) + .collect(); + let pa_r_maj_comp_vals: Vec = cfg_into_iter!(0..n) + .map(|k| { + if k + 2 >= n && k + 1 < n { + a_vals[k] & a_vals[k + 1] + } else { + 0 } + }) + .collect(); + + // Derived values. + let sig0_vals: Vec = cfg_into_iter!(0..n) + .map(|t| big_sigma0(a_vals[t])) + .collect(); + let sig1_vals: Vec = cfg_into_iter!(0..n) + .map(|t| big_sigma1(e_vals[t])) + .collect(); + let lsig0_vals: Vec = cfg_into_iter!(0..n) + .map(|t| small_sigma0(w_vals[t])) + .collect(); + let lsig1_vals: Vec = cfg_into_iter!(0..n) + .map(|t| small_sigma1(w_vals[t])) + .collect(); + + let ov_sig0_vals: Vec = cfg_into_iter!(0..n) + .map(|t| sigma0_overflow(a_vals[t], sig0_vals[t])) + .collect(); + let ov_sig1_vals: Vec = cfg_into_iter!(0..n) + .map(|t| sigma1_overflow(e_vals[t], sig1_vals[t])) + .collect(); + let ov_lsig0_vals: Vec = cfg_into_iter!(0..n) + .map(|t| lsig0_overflow(w_vals[t], lsig0_vals[t])) + .collect(); + let ov_lsig1_vals: Vec = cfg_into_iter!(0..n) + .map(|t| lsig1_overflow(w_vals[t], lsig1_vals[t])) + .collect(); + + // The σ_0/σ_1 right-shift decomposition columns S_i / T_i are + // gone — their role (carrying SHR(W, k) for the F_2[X] sum) is + // taken over by the `BitOp::ShiftR(k)` virtual columns over W. + // `lsig0_overflow` / `lsig1_overflow` already compute the + // matching `pa_ov_lsig{0,1}` per-bit values for the new + // constraint (the algebraic identity is unchanged). + + // Pack all 5 carries per row into the W_MU_PACKED binary_poly + // column. Each carry was stored at its constraint's anchor row + // (mu_W at C7-anchor k = t-16, mu_a/mu_e at C8/C9-anchor k = + // t-3, mu_ff_a/e at junction-anchor k). 
Bit layout: + // bits 0-1: mu_W, 2-4: mu_a, 5-7: mu_e, 8: mu_ff_a, 9: mu_ff_e. + // Positions 10..31 stay 0 (pinned by C22's high-bits-zero + // assert_zero on ShiftR(10)(W_MU_PACKED)). + let w_mu_packed_vals: Vec = cfg_into_iter!(0..n) + .map(|k| { + (mu_w_vals[k] & 0b11) + | ((mu_a_vals[k] & 0b111) << 2) + | ((mu_e_vals[k] & 0b111) << 5) + | ((mu_junction_a_vals[k] & 0b1) << 8) + | ((mu_junction_e_vals[k] & 0b1) << 9) + }) + .collect(); + + let to_bits = |v: &[u32]| -> Vec> { + cfg_into_iter!(0..v.len()) + .map(|idx| BinaryPoly::<32>::from(v[idx])) + .collect() + }; + + let to_bin_mle = |col: Vec>| -> DenseMultilinearExtension> { + col.into_iter().collect() + }; + + // Layout: 20 bin_poly cols, in the order of the `vec!` that follows: + // the 8 public cols (PA_A, PA_E, PA_OV_SIG0, PA_OV_SIG1, + // PA_OV_LSIG0, PA_OV_LSIG1, PA_R_CH2_COMP, PA_R_MAJ_COMP), then + // pa_m, then the 11 remaining cols (a_vals .. w_mu_packed_vals). + // NOTE(review): pa_m is described as public where it is filled in; + // confirm the public/witness boundary against `signature()`. + // pa_a / pa_e were populated + // above with H_i values at init-prefix rows (for compression i + // and the H_N output block) AND at junction rows (where the + // feed-forward constraint reads the prior H_i). The two + // PA_R_*_COMP columns are zero except on the trace tail.
+ let binary_poly = vec![ + to_bin_mle(to_bits(&pa_a_vals)), + to_bin_mle(to_bits(&pa_e_vals)), + to_bin_mle(to_bits(&ov_sig0_vals)), + to_bin_mle(to_bits(&ov_sig1_vals)), + to_bin_mle(to_bits(&ov_lsig0_vals)), + to_bin_mle(to_bits(&ov_lsig1_vals)), + to_bin_mle(to_bits(&pa_r_ch2_comp_vals)), + to_bin_mle(to_bits(&pa_r_maj_comp_vals)), + to_bin_mle(to_bits(&pa_m_vals)), + to_bin_mle(to_bits(&a_vals)), + to_bin_mle(to_bits(&sig0_vals)), + to_bin_mle(to_bits(&e_vals)), + to_bin_mle(to_bits(&sig1_vals)), + to_bin_mle(to_bits(&w_vals)), + to_bin_mle(to_bits(&lsig0_vals)), + to_bin_mle(to_bits(&lsig1_vals)), + to_bin_mle(to_bits(&u_ef_vals)), + to_bin_mle(to_bits(&u_neg_e_g_vals)), + to_bin_mle(to_bits(&maj_vals)), + to_bin_mle(to_bits(&w_mu_packed_vals)), + ]; + + // ===== Selectors ===== + // + // s_init_prefix: 1 on the init-prefix windows for every compression + // (4 rows × NUM_COMPRESSIONS) plus the H_N output + // block (4 more rows). Pins w_a / w_e to pa_a / pa_e. + // s_feedforward: 1 on the junction windows [start+64, start+68) for + // every compression. Gates the SHA-256 inter- + // compression addition constraint. + let mut s_init_prefix_col: Vec = (0..n).map(|_| R::ZERO).collect(); + for i in 0..=big_n { + // i = big_n: the H_N output block. + for j in 0..4 { + s_init_prefix_col[i * rpc + j] = R::ONE; } - let mut s_feedforward_col: Vec = (0..n).map(|_| R::ZERO).collect(); - for i in 0..big_n { - for j in 0..4 { - s_feedforward_col[i * rpc + 64 + j] = R::ONE; - } + } + let mut s_feedforward_col: Vec = (0..n).map(|_| R::ZERO).collect(); + for i in 0..big_n { + for j in 0..4 { + s_feedforward_col[i * rpc + 64 + j] = R::ONE; } - // s_msg_init: 1 on the 16 message-block-seed rows of every - // compression, 0 elsewhere. Gates C16 (`w_W − pa_m == 0`). 
- let mut s_msg_init_col: Vec = (0..n).map(|_| R::ZERO).collect(); - for i in 0..big_n { - for j in 0..16 { - s_msg_init_col[i * rpc + j] = R::ONE; - } + } + // s_msg_init: 1 on the 16 message-block-seed rows of every + // compression, 0 elsewhere. Gates C16 (`w_W − pa_m == 0`). + let mut s_msg_init_col: Vec = (0..n).map(|_| R::ZERO).collect(); + for i in 0..big_n { + for j in 0..16 { + s_msg_init_col[i * rpc + j] = R::ONE; } - // s_active_sched / s_active_upd: pin each compensator to 0 on - // its constraint's honest active range. Read by the - // `pa_c_* · s_active_* == 0` zero-ideal constraints in - // `constrain_general`. - // - // s_active_sched: 1 on C7's 48 anchors per compression - // [start, start + ROUNDS_PER_COMP - 16), 0 elsewhere. - // s_active_upd: 1 on C8/C9's 64 anchors per compression - // [start, start + ROUNDS_PER_COMP), 0 elsewhere. - let mut s_active_sched_col: Vec = (0..n).map(|_| R::ZERO).collect(); - let mut s_active_upd_col: Vec = (0..n).map(|_| R::ZERO).collect(); - for i in 0..big_n { - let start = i * rpc; - for j in 0..(rounds - 16) { - s_active_sched_col[start + j] = R::ONE; - } - for j in 0..rounds { - s_active_upd_col[start + j] = R::ONE; - } + } + // s_active_sched / s_active_upd: pin each compensator to 0 on + // its constraint's honest active range. Read by the + // `pa_c_* · s_active_* == 0` zero-ideal constraints in + // `constrain_general`. + // + // s_active_sched: 1 on C7's 48 anchors per compression + // [start, start + ROUNDS_PER_COMP - 16), 0 elsewhere. + // s_active_upd: 1 on C8/C9's 64 anchors per compression + // [start, start + ROUNDS_PER_COMP), 0 elsewhere. 
+ let mut s_active_sched_col: Vec = (0..n).map(|_| R::ZERO).collect(); + let mut s_active_upd_col: Vec = (0..n).map(|_| R::ZERO).collect(); + for i in 0..big_n { + let start = i * rpc; + for j in 0..(rounds - 16) { + s_active_sched_col[start + j] = R::ONE; } - // (PA_C_FF_{A,E} reuse `s_feedforward_col` as their - // compensator-zero selector — it is already 1 exactly on the - // junction window where the feed-forward addition holds - // honestly.) - - let k_col: Vec = k_vals.iter().copied().map(R::from).collect(); - // mu_w_vals / mu_a_vals / mu_e_vals / mu_junction_{a,e}_vals - // are no longer materialized as separate int columns — they're - // packed into the W_MU_PACKED binary_poly column above. - - // ----- Compensator columns (replace s_sched_anch / s_upd_anch). ----- - // - // For each constraint Cᵢ ∈ {C7, C8, C9}, we publish a public column - // `pa_c_cᵢ[k]` with the property that (innerᵢ + pa_c_cᵢ) ∈ (X − 2) - // on every row k. Concretely we pick `pa_c_cᵢ[k] = −innerᵢ(2)` mod - // R's modulus; the protocol projects R into the random field, so - // the negation lands as `−innerᵢ(2) mod p` — exactly the value - // needed for the constraint to lie in (X − 2). - // - // On the corresponding active range (where the original selector - // was 1), the SHA recurrence makes `innerᵢ(2) = 0` for an honest - // prover, so `pa_c_cᵢ[k] = 0` automatically. On inactive rows the - // compensator absorbs whatever `innerᵢ(2)` happens to be. 
- let two_to_32: R = R::from(0x10000u32) * &R::from(0x10000u32); - let load = |arr: &[u32], idx: usize| -> R { - if idx < n { R::from(arr[idx]) } else { R::ZERO } - }; - - // C7: inner(2) = w_W[k+16] − w_W[k] − lsig0[k+1] − w_W[k+9] - // − lsig1[k+14] + 2^32 · mu_W[k+16] - let pa_c_c7_col: Vec = (0..n) - .map(|k| { - let w_k16 = load(&w_vals, k + 16); - let w_k = load(&w_vals, k); - let lsig0_k1 = load(&lsig0_vals, k + 1); - let w_k9 = load(&w_vals, k + 9); - let lsig1_k14 = load(&lsig1_vals, k + 14); - // mu_W stored at C7-anchor row k (= round t = k+16 was - // formerly stored at k+16; with chained-comp re-anchoring - // it's now at row k). - let mu_k = load(&mu_w_vals, k); - let two32_mu = two_to_32.clone() * &mu_k; - // comp = −inner(2) = w_k + lsig0_k1 + w_k9 + lsig1_k14 - // − 2^32·mu_k16 − w_k16 - w_k + &lsig0_k1 + &w_k9 + &lsig1_k14 - &two32_mu - &w_k16 - }) - .collect(); - - // C8: inner(2) = w_a[k+4] − w_e[k] − sig1[k+3] − Ch[k+3] − K[k+3] - // − W[k+3] − sig0[k+3] − maj[k+3] + 2^32 · mu_a[k+3] - // with Ch[k+3] = u_ef[k+3] + u_{¬e,g}[k+3]. - let pa_c_c8_col: Vec = (0..n) - .map(|k| { - let w_a_k4 = load(&a_vals, k + 4); - let w_e_k = load(&e_vals, k); - let sig1_k3 = load(&sig1_vals, k + 3); - let u_ef_k3 = load(&u_ef_vals, k + 3); - let u_neg_e_g_k3 = load(&u_neg_e_g_vals, k + 3); - let k_k3 = load(&k_vals, k + 3); - let w_k3 = load(&w_vals, k + 3); - let sig0_k3 = load(&sig0_vals, k + 3); - let maj_k3 = load(&maj_vals, k + 3); - // mu_a stored at C8-anchor row k (= round t = k+3 was - // formerly stored at k+3; now at row k). - let mu_a_k = load(&mu_a_vals, k); - let two32_mu = two_to_32.clone() * &mu_a_k; - w_e_k - + &sig1_k3 - + &u_ef_k3 - + &u_neg_e_g_k3 - + &k_k3 - + &w_k3 - + &sig0_k3 - + &maj_k3 - - &two32_mu - - &w_a_k4 - }) - .collect(); - - // C9: inner(2) = w_e[k+4] − w_a[k] − w_e[k] − sig1[k+3] − Ch[k+3] - // − K[k+3] − W[k+3] + 2^32 · mu_e[k+3] - // with Ch[k+3] = u_ef[k+3] + u_{¬e,g}[k+3]. 
- let pa_c_c9_col: Vec = (0..n) - .map(|k| { - let w_e_k4 = load(&e_vals, k + 4); - let w_a_k = load(&a_vals, k); - let w_e_k = load(&e_vals, k); - let sig1_k3 = load(&sig1_vals, k + 3); - let u_ef_k3 = load(&u_ef_vals, k + 3); - let u_neg_e_g_k3 = load(&u_neg_e_g_vals, k + 3); - let k_k3 = load(&k_vals, k + 3); - let w_k3 = load(&w_vals, k + 3); - // mu_e stored at C9-anchor row k (analogous to mu_a). - let mu_e_k = load(&mu_e_vals, k); - let two32_mu = two_to_32.clone() * &mu_e_k; - w_a_k - + &w_e_k - + &sig1_k3 - + &u_ef_k3 - + &u_neg_e_g_k3 - + &k_k3 - + &w_k3 - - &two32_mu - - &w_e_k4 - }) - .collect(); - - // C12/C13 feed-forward compensators. inner_a(2) at row k = - // w_a[k+4] − w_a[k] − pa_a[k] + 2^32 · mu_junction_a[k] - // (e-half symmetric). On junction rows the SHA-256 feed-forward - // makes inner = 0 honestly, so the compensator is 0. Off- - // junction (init prefix straddle, round-update windows, slack) - // it absorbs whatever inner happens to be so that - // `(inner + pa_c_ff) ∈ (X − 2)` everywhere. - let pa_c_ff_a_col: Vec = (0..n) - .map(|k| { - let w_a_k4 = load(&a_vals, k + 4); - let w_a_k = load(&a_vals, k); - let pa_a_k = load(&pa_a_vals, k); - let mu_ff_k = load(&mu_junction_a_vals, k); - let two32_mu = two_to_32.clone() * &mu_ff_k; - // comp = −inner(2) = w_a_k + pa_a_k − 2^32·mu_ff_k − w_a_k4 - w_a_k + &pa_a_k - &two32_mu - &w_a_k4 - }) - .collect(); - let pa_c_ff_e_col: Vec = (0..n) - .map(|k| { - let w_e_k4 = load(&e_vals, k + 4); - let w_e_k = load(&e_vals, k); - let pa_e_k = load(&pa_e_vals, k); - let mu_ff_k = load(&mu_junction_e_vals, k); - let two32_mu = two_to_32.clone() * &mu_ff_k; - w_e_k + &pa_e_k - &two32_mu - &w_e_k4 - }) - .collect(); - - let to_int_mle = |col: Vec| -> DenseMultilinearExtension { - col.into_iter().collect() - }; - // Layout: public int prefix (selectors + K + active-range - // selectors) followed by witness int suffix (the five linear- - // constraint compensators). Order matches cols::S_INIT_PREFIX.. 
- // PA_C_FF_E. The 5 prior int carry columns (mu_W/a/e/ - // junction_a/e) are gone — packed into W_MU_PACKED above. - let int = vec![ - to_int_mle(s_init_prefix_col), - to_int_mle(s_feedforward_col), - to_int_mle(s_msg_init_col), - to_int_mle(k_col), - to_int_mle(s_active_sched_col), - to_int_mle(s_active_upd_col), - to_int_mle(pa_c_c7_col), - to_int_mle(pa_c_c8_col), - to_int_mle(pa_c_c9_col), - to_int_mle(pa_c_ff_a_col), - to_int_mle(pa_c_ff_e_col), - ]; - + for j in 0..rounds { + s_active_upd_col[start + j] = R::ONE; + } + } + // (PA_C_FF_{A,E} reuse `s_feedforward_col` as their + // compensator-zero selector — it is already 1 exactly on the + // junction window where the feed-forward addition holds + // honestly.) + + let k_col: Vec = cfg_into_iter!(0..n) + .map(|idx| R::from(k_vals[idx])) + .collect(); + // mu_w_vals / mu_a_vals / mu_e_vals / mu_junction_{a,e}_vals + // are no longer materialized as separate int columns — they're + // packed into the W_MU_PACKED binary_poly column above. + + // ----- Compensator columns (replace s_sched_anch / s_upd_anch). ----- + // + // For each constraint Cᵢ ∈ {C7, C8, C9}, we publish a public column + // `pa_c_cᵢ[k]` with the property that (innerᵢ + pa_c_cᵢ) ∈ (X − 2) + // on every row k. Concretely we pick `pa_c_cᵢ[k] = −innerᵢ(2)` mod + // R's modulus; the protocol projects R into the random field, so + // the negation lands as `−innerᵢ(2) mod p` — exactly the value + // needed for the constraint to lie in (X − 2). + // + // On the corresponding active range (where the original selector + // was 1), the SHA recurrence makes `innerᵢ(2) = 0` for an honest + // prover, so `pa_c_cᵢ[k] = 0` automatically. On inactive rows the + // compensator absorbs whatever `innerᵢ(2)` happens to be. 
+ let two_to_32: R = R::from(0x10000u32) * &R::from(0x10000u32); + let load = |arr: &[u32], idx: usize| -> R { if idx < n { R::from(arr[idx]) } else { R::ZERO } }; + + // C7: inner(2) = w_W[k+16] − w_W[k] − lsig0[k+1] − w_W[k+9] + // − lsig1[k+14] + 2^32 · mu_W[k] + let pa_c_c7_col: Vec = cfg_into_iter!(0..n) + .map(|k| { + let w_k16 = load(&w_vals, k + 16); + let w_k = load(&w_vals, k); + let lsig0_k1 = load(&lsig0_vals, k + 1); + let w_k9 = load(&w_vals, k + 9); + let lsig1_k14 = load(&lsig1_vals, k + 14); + // mu_W stored at C7-anchor row k (= round t = k+16 was + // formerly stored at k+16; with chained-comp re-anchoring + // it's now at row k). + let mu_k = load(&mu_w_vals, k); + let two32_mu = two_to_32.clone() * &mu_k; + // comp = −inner(2) = w_k + lsig0_k1 + w_k9 + lsig1_k14 + // − 2^32·mu_k − w_k16 + w_k + &lsig0_k1 + &w_k9 + &lsig1_k14 - &two32_mu - &w_k16 + }) + .collect(); + + // C8: inner(2) = w_a[k+4] − w_e[k] − sig1[k+3] − Ch[k+3] − K[k+3] + // − W[k] − sig0[k+3] − maj[k+3] + 2^32 · mu_a[k] + // with Ch[k+3] = u_ef[k+3] + u_{¬e,g}[k+3]. 
+ let pa_c_c9_col: Vec = cfg_into_iter!(0..n) + .map(|k| { + let w_e_k4 = load(&e_vals, k + 4); + let w_a_k = load(&a_vals, k); + let w_e_k = load(&e_vals, k); + let sig1_k3 = load(&sig1_vals, k + 3); + let u_ef_k3 = load(&u_ef_vals, k + 3); + let u_neg_e_g_k3 = load(&u_neg_e_g_vals, k + 3); + let k_k3 = load(&k_vals, k + 3); + let w_k = load(&w_vals, k); + // mu_e stored at C9-anchor row k (analogous to mu_a). + let mu_e_k = load(&mu_e_vals, k); + let two32_mu = two_to_32.clone() * &mu_e_k; + w_a_k + &w_e_k + &sig1_k3 + &u_ef_k3 + &u_neg_e_g_k3 + &k_k3 + &w_k + - &two32_mu + - &w_e_k4 + }) + .collect(); + + // C12/C13 feed-forward compensators. inner_a(2) at row k = + // w_a[k+4] − w_a[k] − pa_a[k] + 2^32 · mu_junction_a[k] + // (e-half symmetric). On junction rows the SHA-256 feed-forward + // makes inner = 0 honestly, so the compensator is 0. Off- + // junction (init prefix straddle, round-update windows, slack) + // it absorbs whatever inner happens to be so that + // `(inner + pa_c_ff) ∈ (X − 2)` everywhere. + let pa_c_ff_a_col: Vec = cfg_into_iter!(0..n) + .map(|k| { + let w_a_k4 = load(&a_vals, k + 4); + let w_a_k = load(&a_vals, k); + let pa_a_k = load(&pa_a_vals, k); + let mu_ff_k = load(&mu_junction_a_vals, k); + let two32_mu = two_to_32.clone() * &mu_ff_k; + // comp = −inner(2) = w_a_k + pa_a_k − 2^32·mu_ff_k − w_a_k4 + w_a_k + &pa_a_k - &two32_mu - &w_a_k4 + }) + .collect(); + let pa_c_ff_e_col: Vec = cfg_into_iter!(0..n) + .map(|k| { + let w_e_k4 = load(&e_vals, k + 4); + let w_e_k = load(&e_vals, k); + let pa_e_k = load(&pa_e_vals, k); + let mu_ff_k = load(&mu_junction_e_vals, k); + let two32_mu = two_to_32.clone() * &mu_ff_k; + w_e_k + &pa_e_k - &two32_mu - &w_e_k4 + }) + .collect(); + + let to_int_mle = |col: Vec| -> DenseMultilinearExtension { col.into_iter().collect() }; + // Layout: public int prefix (selectors + K + active-range + // selectors) followed by witness int suffix (the five linear- + // constraint compensators). 
Order matches cols::S_INIT_PREFIX.. + // PA_C_FF_E. The 5 prior int carry columns (mu_W/a/e/ + // junction_a/e) are gone — packed into W_MU_PACKED above. + let int = vec![ + to_int_mle(s_init_prefix_col), + to_int_mle(s_feedforward_col), + to_int_mle(s_msg_init_col), + to_int_mle(k_col), + to_int_mle(s_active_sched_col), + to_int_mle(s_active_upd_col), + to_int_mle(pa_c_c7_col), + to_int_mle(pa_c_c8_col), + to_int_mle(pa_c_c9_col), + to_int_mle(pa_c_ff_a_col), + to_int_mle(pa_c_ff_e_col), + ]; + + Ok(( UairTrace { binary_poly: binary_poly.into(), int: int.into(), ..Default::default() - } + }, + trace_halves_to_state(h_a, h_e), + )) +} + +// --------------------------------------------------------------------------- +// GenerateRandomTrace for the slice. +// --------------------------------------------------------------------------- + +impl GenerateRandomTrace<32> for Sha256CompressionSliceUair +where + R: ConstSemiring + From + Clone + 'static, +{ + type PolyCoeff = R; + type Int = R; + + fn generate_random_trace( + num_vars: usize, + rng: &mut Rng, + ) -> UairTrace<'static, R, R, 32> { + let initial_state: Sha256State = std::array::from_fn(|_| rng.next_u32()); + let message_blocks: [Sha256MessageBlock; cols::NUM_COMPRESSIONS] = + std::array::from_fn(|_| std::array::from_fn(|_| rng.next_u32())); + synthesize_sha256_compression_chain_trace::(num_vars, initial_state, &message_blocks) + .map(|(trace, _)| trace) + .expect("random SHA-256 trace synthesis failed") } } @@ -1834,8 +2111,116 @@ where #[cfg(test)] mod tests { use super::*; - use crypto_primitives::crypto_bigint_int::Int; - use zinc_uair::degree_counter::{count_effective_max_degree, count_max_degree}; + use crypto_primitives::{crypto_bigint_int::Int, semiring::boolean::Boolean}; + use zinc_uair::{ + Uair, + degree_counter::{count_effective_max_degree, count_max_degree}, + }; + + fn test_initial_state() -> Sha256State { + SHA256_INITIAL_STATE + } + + fn test_message_blocks() -> [Sha256MessageBlock; N] { + 
std::array::from_fn(|i| { + std::array::from_fn(|j| { + ((i as u32) << 24) ^ ((j as u32) << 16) ^ 0xa5a5_0000u32.wrapping_add(j as u32) + }) + }) + } + + fn native_chain( + initial_state: Sha256State, + message_blocks: &[Sha256MessageBlock; N], + ) -> Sha256State { + message_blocks.iter().fold(initial_state, |state, block| { + sha256_compress_native(state, *block) + }) + } + + #[test] + fn padded_message_blocks_encode_short_message() { + let blocks = sha256_padded_message_blocks::<1>(b"abc") + .expect("abc should canonically pad to one SHA-256 block"); + + assert_eq!(blocks[0][0], 0x6162_6380); + assert_eq!(blocks[0][14], 0); + assert_eq!(blocks[0][15], 24); + } + + #[test] + fn padded_message_blocks_reject_wrong_fixed_count() { + let err = sha256_padded_message_blocks::<8>(b"abc") + .expect_err("abc should not canonically pad to eight SHA-256 blocks"); + + assert!(matches!( + err, + Sha256WitnessError::PaddedBlockCountMismatch { + expected: 8, + got: 1 + } + )); + } + + #[test] + fn padded_message_blocks_can_make_eight_block_fixture_from_string() { + let message = vec!["hello world"; 40].join(" "); + let blocks = sha256_padded_message_blocks::<8>(message.as_bytes()) + .expect("test message should canonically pad to eight SHA-256 blocks"); + + assert_eq!(blocks[0][0], u32::from_be_bytes(*b"hell")); + assert_eq!(blocks[7][14], 0); + assert_eq!(blocks[7][15], u32::try_from(message.len() * 8).unwrap()); + } + + fn binary_word_to_u32(word: &BinaryPoly<32>) -> u32 { + let dense: DensePolynomial = word.clone().into(); + dense + .coeffs + .iter() + .enumerate() + .fold(0u32, |acc, (bit, coeff)| { + if coeff.inner() { + let shift = u32::try_from(bit).expect("SHA-256 bit index fits u32"); + acc | 1u32 + .checked_shl(shift) + .expect("SHA-256 bit index is below word width") + } else { + acc + } + }) + } + + fn public_sha_state_at( + trace: &UairTrace<'_, P, I, 32>, + row_start: usize, + ) -> Sha256State { + let h_a = std::array::from_fn(|j| { + 
binary_word_to_u32(&trace.binary_poly[cols::PA_A].evaluations[row_start + j]) + }); + let h_e = std::array::from_fn(|j| { + binary_word_to_u32(&trace.binary_poly[cols::PA_E].evaluations[row_start + j]) + }); + trace_halves_to_state(h_a, h_e) + } + + fn assert_sha_trace_splits_cleanly(trace: &UairTrace<'static, Int<5>, Int<5>, 32>) { + let sig = > as Uair>::signature(); + let public = trace.public(&sig); + let witness = trace.witness(&sig); + + assert_eq!(public.binary_poly.len(), cols::NUM_BIN_PUB); + assert_eq!(public.arbitrary_poly.len(), 0); + assert_eq!(public.int.len(), cols::NUM_INT_PUB); + assert_eq!(witness.binary_poly.len(), cols::NUM_BIN - cols::NUM_BIN_PUB); + assert_eq!(witness.arbitrary_poly.len(), 0); + assert_eq!(witness.int.len(), cols::NUM_INT - cols::NUM_INT_PUB); + > as Uair>::verify_public_structure( + &public, + cols::MIN_NUM_VARS, + ) + .expect("generated SHA public trace should satisfy public structure checks"); + } /// All non-zero-ideal SHA constraints (C1, C2, C4, C6, C7, C8, C9, /// and the new feed-forward C12/C13) must stay degree-1 in the @@ -1852,6 +2237,116 @@ mod tests { assert!(count_max_degree::() >= 2); } + #[test] + fn synthesize_sha256_chain_witnesses_n1_matches_native() { + let initial_state = test_initial_state(); + let message_blocks = test_message_blocks::<1>(); + let expected = sha256_compress_native(initial_state, message_blocks[0]); + + let (witnesses, final_state) = + synthesize_sha256_chain_witnesses::, 1>(initial_state, message_blocks) + .expect("N=1 SHA witness synthesis should succeed"); + + assert_eq!(final_state, expected); + assert_sha_trace_splits_cleanly(&witnesses[0].trace); + } + + #[test] + fn synthesize_sha256_chain_witnesses_n8_matches_native() { + let initial_state = test_initial_state(); + let message_blocks = test_message_blocks::<8>(); + let expected = native_chain(initial_state, &message_blocks); + + let (witnesses, final_state) = + synthesize_sha256_chain_witnesses::, 8>(initial_state, 
message_blocks) + .expect("N=8 SHA witness synthesis should succeed"); + + assert_eq!(final_state, expected); + for witness in &witnesses { + assert_sha_trace_splits_cleanly(&witness.trace); + } + } + + #[test] + fn synthesize_sha256_chain_trace_n8_matches_native() { + let initial_state = test_initial_state(); + let message_blocks = test_message_blocks::<8>(); + let expected = native_chain(initial_state, &message_blocks); + + let (trace, final_state) = + synthesize_sha256_chain_trace::, 8>(10, initial_state, message_blocks) + .expect("monolithic N=8 SHA trace synthesis should succeed"); + let sig = > as Uair>::signature(); + let public = trace.public(&sig); + + assert_eq!(final_state, expected); + assert_eq!( + public_sha_state_at(&trace, 8 * cols::ROWS_PER_COMP), + final_state + ); + > as Uair>::verify_public_structure(&public, 10) + .expect("generated monolithic SHA public trace should satisfy public structure checks"); + } + + #[test] + fn sha256_chain_trace_and_projection_witnesses_expose_same_h8() { + let initial_state = SHA256_INITIAL_STATE; + let message = vec!["hello world"; 40].join(" "); + let message_blocks = sha256_padded_message_blocks::<8>(message.as_bytes()) + .expect("fixture should canonically pad to eight SHA-256 blocks"); + let expected = native_chain(initial_state, &message_blocks); + + let (mono_trace, mono_final_state) = + synthesize_sha256_chain_trace::, 8>(10, initial_state, message_blocks) + .expect("monolithic N=8 SHA trace synthesis should succeed"); + let (pf_witnesses, pf_final_state) = + synthesize_sha256_chain_witnesses::, 8>(initial_state, message_blocks) + .expect("N=8 SHA witness synthesis should succeed"); + + let mono_public_final = public_sha_state_at(&mono_trace, 8 * cols::ROWS_PER_COMP); + let pf_public_final = public_sha_state_at( + &pf_witnesses[pf_witnesses.len() - 1].trace, + cols::ROWS_PER_COMP, + ); + + assert_eq!(mono_final_state, expected); + assert_eq!(pf_final_state, expected); + assert_eq!(mono_final_state, 
pf_final_state); + assert_eq!(mono_public_final, pf_public_final); + assert_eq!(mono_public_final, mono_final_state); + assert_eq!(pf_public_final, pf_final_state); + + let digest_hex = mono_final_state + .iter() + .map(|word| format!("{word:08x}")) + .collect::(); + println!("final chained SHA-256 H_8 = {digest_hex}"); + } + + #[test] + fn synthesize_sha256_chain_witnesses_links_adjacent_states() { + let initial_state = test_initial_state(); + let message_blocks = test_message_blocks::<8>(); + + let (witnesses, final_state) = + synthesize_sha256_chain_witnesses::, 8>(initial_state, message_blocks) + .expect("N=8 SHA witness synthesis should succeed"); + + assert_eq!(public_sha_state_at(&witnesses[0].trace, 0), initial_state); + for (index, pair) in witnesses.windows(2).enumerate() { + let left_output = public_sha_state_at(&pair[0].trace, cols::ROWS_PER_COMP); + let right_input = public_sha_state_at(&pair[1].trace, 0); + assert_eq!( + left_output, right_input, + "SHA witness {index} output must feed the next witness input" + ); + } + assert_eq!( + public_sha_state_at(&witnesses[witnesses.len() - 1].trace, cols::ROWS_PER_COMP), + final_state + ); + } + /// Cross-check the K_CANONICAL table against the canonical SHA-256 /// initial hash values H_0 — running one full compression of the /// empty-padding block (with H_0 as input) must produce the @@ -1861,11 +2356,7 @@ mod tests { /// logic itself). #[test] fn k_canonical_matches_sha256_empty_string_digest() { - // SHA-256 H_0 (FIPS 180-4 §5.3.3). - let h_in: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, - ]; + let h_in = SHA256_INITIAL_STATE; // Empty-string padded block: single 0x80 byte then 63 zero bytes. 
let mut m = [0u32; 16]; m[0] = 0x80000000; @@ -1890,9 +2381,13 @@ mod tests { .wrapping_add(K_CANONICAL[t]) .wrapping_add(w[t]); let t2 = big_sigma0(a).wrapping_add(maj(a, b, c)); - h = g; g = f; f = e; + h = g; + g = f; + f = e; e = d.wrapping_add(t1); - d = c; c = b; b = a; + d = c; + c = b; + b = a; a = t1.wrapping_add(t2); } @@ -1909,9 +2404,12 @@ mod tests { ]; let expected: [u32; 8] = [ - 0xe3b0c442, 0x98fc1c14, 0x9afbf4c8, 0x996fb924, - 0x27ae41e4, 0x649b934c, 0xa495991b, 0x7852b855, + 0xe3b0c442, 0x98fc1c14, 0x9afbf4c8, 0x996fb924, 0x27ae41e4, 0x649b934c, 0xa495991b, + 0x7852b855, ]; - assert_eq!(h_out, expected, "SHA-256(\"\") digest mismatch — K table or round logic drift"); + assert_eq!( + h_out, expected, + "SHA-256(\"\") digest mismatch — K table or round logic drift" + ); } } diff --git a/test-uair/src/sha_ecdsa.rs b/test-uair/src/sha_ecdsa.rs index a0a33d8c..2eb3e11e 100644 --- a/test-uair/src/sha_ecdsa.rs +++ b/test-uair/src/sha_ecdsa.rs @@ -67,10 +67,7 @@ use core::marker::PhantomData; use crypto_primitives::ConstSemiring; use rand::RngCore; -use zinc_poly::{ - mle::DenseMultilinearExtension, - univariate::dense::DensePolynomial, -}; +use zinc_poly::{mle::DenseMultilinearExtension, univariate::dense::DensePolynomial}; use zinc_uair::{ BitOp, BitOpSpec, ConstraintBuilder, LookupColumnSpec, PublicColumnLayout, PublicStructureError, ShiftSpec, ShiftedBitSliceSpec, TotalColumnLayout, TraceRow, Uair, @@ -78,9 +75,10 @@ use zinc_uair::{ ideal::rotation::RotationIdeal, }; +#[cfg(test)] +use crate::ecdsa::FINAL_ROW as ECDSA_FINAL_ROW; use crate::{ - GenerateRandomTrace, - ecdsa::{self, FINAL_ROW as ECDSA_FINAL_ROW, NUM_SHAMIR_ROUNDS}, + GenerateRandomTrace, ecdsa, ecdsa_doubling::{EC_FP_INT_LIMBS, EcdsaFpRing}, sha256::{self, Sha256CompressionSliceUair, Sha256Ideal}, }; @@ -88,7 +86,7 @@ use crate::{ use crypto_primitives::crypto_bigint_int::Int; // Re-export for convenience. 
-pub use crate::ecdsa::FINAL_ROW; +pub use crate::ecdsa::{FINAL_ROW, NUM_SHAMIR_ROUNDS}; // --------------------------------------------------------------------------- // Column layout for the merged trace. @@ -281,11 +279,11 @@ where // for the full mapping (`Rot(c)` ≡ `ROTR^{32-c}` ≡ multiplication // by `X^c mod (X^32 − 1)`). All six specs target FLAT_W_W. let bit_op_specs: Vec = vec![ - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(25)), // σ_0: ROTR^7 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(14)), // σ_0: ROTR^18 - BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(3)), // σ_0: SHR^3 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(15)), // σ_1: ROTR^17 - BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(13)), // σ_1: ROTR^19 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(25)), // σ_0: ROTR^7 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(14)), // σ_0: ROTR^18 + BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(3)), // σ_0: SHR^3 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(15)), // σ_1: ROTR^17 + BitOpSpec::new(cols::FLAT_W_W, BitOp::Rot(13)), // σ_1: ROTR^19 BitOpSpec::new(cols::FLAT_W_W, BitOp::ShiftR(10)), // σ_1: SHR^10 // Bit-op virtuals over W_MU_PACKED for extracting the 5 // chained-comp carries. See sha256.rs cols doc. 
@@ -483,7 +481,7 @@ where let _down_w_e_sh2 = &down.binary_poly[5]; let down_w_e_sh4 = &down.binary_poly[6]; let down_w_sig1_sh3 = &down.binary_poly[7]; - let down_w_w_sh3 = &down.binary_poly[8]; + let _down_w_w_sh3 = &down.binary_poly[8]; let down_w_w_sh9 = &down.binary_poly[9]; let down_w_w_sh16 = &down.binary_poly[10]; let down_w_lsig0_sh1 = &down.binary_poly[11]; @@ -534,8 +532,7 @@ where let const_2_to_34 = const_scalar::(pow_two::(34)); let const_2_to_35 = const_scalar::(pow_two::(35)); - let mu_w_contrib = mbs(w_mu_packed, &const_2_to_32) - .expect("2^32 · w_mu_packed overflow") + let mu_w_contrib = mbs(w_mu_packed, &const_2_to_32).expect("2^32 · w_mu_packed overflow") - &mbs(down_w_mu_packed_shr2, &const_2_to_34) .expect("2^34 · ShiftR(2)(w_mu_packed) overflow"); let mu_a_contrib = mbs(down_w_mu_packed_shr2, &const_2_to_32) @@ -557,14 +554,16 @@ where // C1: Sigma_0 rotation b.assert_in_ideal( - mbs(w_a, &rho_sig0).expect("a · rho_sig0 overflow") - w_sig0 + mbs(w_a, &rho_sig0).expect("a · rho_sig0 overflow") + - w_sig0 - &mbs(pa_ov_sig0, &two_scalar_sha).expect("2 · ov_sig0 overflow"), &ideal_rot_xw1, ); // C2: Sigma_1 rotation b.assert_in_ideal( - mbs(w_e, &rho_sig1).expect("e · rho_sig1 overflow") - w_sig1 + mbs(w_e, &rho_sig1).expect("e · rho_sig1 overflow") + - w_sig1 - &mbs(pa_ov_sig1, &two_scalar_sha).expect("2 · ov_sig1 overflow"), &ideal_rot_xw1, ); @@ -575,26 +574,25 @@ where // and the C3/C5 right-shift decompositions go away. // ROT^25(W) + ROT^14(W) + SHIFTR^3(W) − lsig0 − 2 · pa_ov_lsig0 == 0 b.assert_zero( - down_w_rot25.clone() + down_w_rot14 + down_w_shr3 - w_lsig0 + down_w_rot25.clone() + down_w_rot14 + down_w_shr3 + - w_lsig0 - &mbs(pa_ov_lsig0, &two_scalar_sha).expect("2 · ov_lsig0 overflow"), ); // C6 (was σ_1 (X^32 − 1) ideal-lift): σ_1 analogue of C4. 
// ROT^15(W) + ROT^13(W) + SHIFTR^10(W) − lsig1 − 2 · pa_ov_lsig1 == 0 b.assert_zero( - down_w_rot15.clone() + down_w_rot13 + down_w_shr10 - w_lsig1 + down_w_rot15.clone() + down_w_rot13 + down_w_shr10 + - w_lsig1 - &mbs(pa_ov_lsig1, &two_scalar_sha).expect("2 · ov_lsig1 overflow"), ); // C7: Message-schedule modular sum. mu_W from up.w_mu_packed // bits 0-1 via mu_w_contrib (chained-comp re-anchoring stores // each carry at its constraint's anchor row). - let sched_inner = down_w_w_sh16.clone() - - w_big_w - - down_w_lsig0_sh1 - - down_w_w_sh9 - - down_w_lsig1_sh14 - + &mu_w_contrib; + let sched_inner = + down_w_w_sh16.clone() - w_big_w - down_w_lsig0_sh1 - down_w_w_sh9 - down_w_lsig1_sh14 + + &mu_w_contrib; b.assert_in_ideal(sched_inner + pa_c_c7, &ideal_rot_x2); // C8: Register-update for `a`. mu_a from bits 2-4 of W_MU_PACKED. @@ -604,7 +602,7 @@ where - down_w_u_ef_sh3 - down_w_u_neg_e_g_sh3 - down_pa_k_sh3 - - down_w_w_sh3 + - w_big_w - down_w_sig0_sh3 - down_w_maj_sh3 + &mu_a_contrib; @@ -618,7 +616,7 @@ where - down_w_u_ef_sh3 - down_w_u_neg_e_g_sh3 - down_pa_k_sh3 - - down_w_w_sh3 + - w_big_w + &mu_e_contrib; b.assert_in_ideal(e_update_inner + pa_c_c9, &ideal_rot_x2); @@ -646,16 +644,10 @@ where // down.w_a^↓4 = w_a[k+4] = H_{i+1}, j-th component (pinned by C10) // up.sha_w_mu_junction_a = carry ∈ {0, 1} // mu_ff_a / mu_ff_e from bits 8 / 9 of W_MU_PACKED. - let ff_a_inner = down_w_a_sh4.clone() - - w_a - - pa_a - + &mu_ff_a_contrib; + let ff_a_inner = down_w_a_sh4.clone() - w_a - pa_a + &mu_ff_a_contrib; b.assert_in_ideal(ff_a_inner + pa_c_ff_a, &ideal_rot_x2); - let ff_e_inner = down_w_e_sh4.clone() - - w_e - - pa_e - + &mu_ff_e_contrib; + let ff_e_inner = down_w_e_sh4.clone() - w_e - pa_e + &mu_ff_e_contrib; b.assert_in_ideal(ff_e_inner + pa_c_ff_e, &ideal_rot_x2); // C16: message init (Table 9 row 77). 
Pin w_W to public message @@ -739,15 +731,12 @@ where b.assert_zero(e_s_active.clone() * &d3_inner); let x3_y_sq = x_sq.clone() * &x_y_sq; - let twelve_x3_y_sq = - mbs(&x3_y_sq, &twelve_scalar).expect("12·X³·Y² overflow"); + let twelve_x3_y_sq = mbs(&x3_y_sq, &twelve_scalar).expect("12·X³·Y² overflow"); let x_sq_x_pa = x_sq.clone() * e_x_pa; - let three_x2_xpa = - mbs(&x_sq_x_pa, &three_scalar).expect("3·X²·X_pa overflow"); + let three_x2_xpa = mbs(&x_sq_x_pa, &three_scalar).expect("3·X²·X_pa overflow"); let y_pow4 = e_y_sq.clone() * &e_y_sq; let eight_y_pow4 = mbs(&y_pow4, &eight_scalar).expect("8·Y⁴ overflow"); - let d4_inner = - e_y_pa.clone() - &twelve_x3_y_sq + &three_x2_xpa + &eight_y_pow4; + let d4_inner = e_y_pa.clone() - &twelve_x3_y_sq + &three_x2_xpa + &eight_y_pow4; b.assert_zero(e_s_active.clone() * &d4_inner); // === In-circuit affine addend selection === @@ -797,12 +786,10 @@ where // Y: down.Y − Y_pa − S_ADD·(3·D·X_pa·C² + D·C³ − D³ − Y_pa·C³ − Y_pa) = 0 let d_cube = e_d.clone() * &d_sq; let d_x_pa_c_sq = e_d.clone() * &e_x_pa_c_sq; - let three_d_x_pa_c_sq = - mbs(&d_x_pa_c_sq, &three_scalar).expect("3·D·X_pa·C² overflow"); + let three_d_x_pa_c_sq = mbs(&d_x_pa_c_sq, &three_scalar).expect("3·D·X_pa·C² overflow"); let d_c_cube = e_d.clone() * &e_c_cube; let y_pa_c_cube = e_y_pa.clone() * &e_c_cube; - let y_add_minus_y_pa = - three_d_x_pa_c_sq + &d_c_cube - &d_cube - &y_pa_c_cube - e_y_pa; + let y_add_minus_y_pa = three_d_x_pa_c_sq + &d_c_cube - &d_cube - &y_pa_c_cube - e_y_pa; let s_add_y = e_s_add.clone() * &y_add_minus_y_pa; let o2_inner = down_ecdsa_y_sh1.clone() - e_y_pa - &s_add_y; b.assert_zero(e_s_active.clone() * &o2_inner); @@ -933,10 +920,12 @@ where "ShaEcdsa UAIR needs > {FINAL_ROW} rows; got {n_rows}", ); - let sha_trace = as GenerateRandomTrace<32>>:: - generate_random_trace(num_vars, rng); - let ecdsa_trace = as GenerateRandomTrace<32>>:: - generate_random_trace(num_vars, rng); + let sha_trace = + as 
GenerateRandomTrace<32>>::generate_random_trace( + num_vars, rng, + ); + let ecdsa_trace = + as GenerateRandomTrace<32>>::generate_random_trace(num_vars, rng); // Sanity: column counts match the standalone UAIRs. assert_eq!(sha_trace.binary_poly.len(), sha256::cols::NUM_BIN); @@ -944,8 +933,7 @@ where assert_eq!(ecdsa_trace.int.len(), ecdsa::cols::NUM_INT); // Binary_poly: copy SHA's directly (ECDSA contributes nothing). - let binary_poly: Vec> = - sha_trace.binary_poly.into_owned(); + let binary_poly: Vec> = sha_trace.binary_poly.into_owned(); // Int section: merge per the layout in `cols`. // @@ -1016,8 +1004,14 @@ mod tests { // some deg-2 (boundaries + chaining + the compensator-zero pins), // some deg-1 (SHA C1, C2, C4, C6 — including the new row-local // σ_0/σ_1 equalities). - assert!(degrees.iter().any(|&d| d == 7), "expected deg-7 from ECDSA C-A2"); - assert!(degrees.iter().filter(|&&d| d == 2).count() >= 3, "expected ≥3 deg-2"); + assert!( + degrees.iter().any(|&d| d == 7), + "expected deg-7 from ECDSA C-A2" + ); + assert!( + degrees.iter().filter(|&&d| d == 2).count() >= 3, + "expected ≥3 deg-2" + ); } /// The merged trace builder produces a trace with the right column @@ -1027,8 +1021,10 @@ mod tests { fn merged_trace_shape() { let num_vars = 9; let mut r = rng(); - let trace = > as GenerateRandomTrace<32>>:: - generate_random_trace(num_vars, &mut r); + let trace = + > as GenerateRandomTrace<32>>::generate_random_trace( + num_vars, &mut r, + ); assert_eq!(trace.binary_poly.len(), cols::NUM_BIN); assert_eq!(trace.int.len(), cols::NUM_INT); diff --git a/transcript/src/traits.rs b/transcript/src/traits.rs index 8c29dc1f..cc3f1460 100644 --- a/transcript/src/traits.rs +++ b/transcript/src/traits.rs @@ -340,6 +340,30 @@ pub trait Transcript { F::new_with_cfg(random_inner, cfg) } + fn get_variable_field_challenge( + &mut self, + cfg: &F::Config, + num_bytes: usize, + ) -> F + where + F::Inner: GenTranscribable, + { + let mut buf = vec![0u8; num_bytes]; + for 
chunk in buf.chunks_mut(u64::NUM_BYTES) { + let word = self.get_challenge::(); + chunk.copy_from_slice(&word.to_le_bytes()[..chunk.len()]); + } + F::new_with_cfg(F::Inner::read_transcription_bytes_exact(&buf), cfg) + } + + fn get_transcribable_field_challenge(&mut self, cfg: &F::Config) -> F + where + F::Inner: Transcribable, + { + let zero = F::zero_with_cfg(cfg); + self.get_variable_field_challenge(cfg, zero.inner().get_num_bytes()) + } + /// Generates a pseudorandom transcribable values as challenges based on the /// current transcript state, updating it. // TODO(Alex): `get_field_challenge` is not efficient @@ -353,6 +377,19 @@ pub trait Transcript { (0..n).map(|_| self.get_field_challenge(cfg)).collect() } + fn get_transcribable_field_challenges( + &mut self, + n: usize, + cfg: &F::Config, + ) -> Vec + where + F::Inner: Transcribable, + { + (0..n) + .map(|_| self.get_transcribable_field_challenge(cfg)) + .collect() + } + /// Generates a pseudorandom transcribable values as challenges based on the /// current transcript state, updating it. fn get_challenges(&mut self, n: usize) -> Vec { @@ -412,6 +449,16 @@ pub trait Transcript { self.absorb_inner(&[0x3]) } + fn absorb_random_field_owned(&mut self, v: &F) + where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + let mut buf = vec![0u8; v.inner().get_num_bytes()]; + self.absorb_random_field(v, &mut buf); + } + /// Absorbs a slice of field element into the transcript. /// Delegates to the field element's implementation of /// absorb_into_transcript. 
@@ -423,6 +470,15 @@ pub trait Transcript { { v.iter().for_each(|x| self.absorb_random_field(x, buf)); } + + fn absorb_random_field_slice_owned(&mut self, v: &[F]) + where + F: PrimeField, + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + v.iter().for_each(|x| self.absorb_random_field_owned(x)); + } } // diff --git a/uair/src/lib.rs b/uair/src/lib.rs index bdf6e6fb..91e89bcb 100644 --- a/uair/src/lib.rs +++ b/uair/src/lib.rs @@ -526,9 +526,10 @@ impl UairSignature { spec.witness_col_idx, ); let flat_col = spec.witness_col_idx + num_pub_bin; - let matched = self.shifts.iter().any(|s| { - s.source_col() == flat_col && s.shift_amount() == spec.shift_amount - }); + let matched = self + .shifts + .iter() + .any(|s| s.source_col() == flat_col && s.shift_amount() == spec.shift_amount); assert!( matched, "ShiftedBitSliceSpec(col {}, shift {}) has no matching ShiftSpec", @@ -556,9 +557,7 @@ impl UairSignature { let mut bin_down_idx = 0usize; for s in &self.shifts { if s.source_col() < num_total_bin { - if s.source_col() == flat_col - && s.shift_amount() == spec.shift_amount - { + if s.source_col() == flat_col && s.shift_amount() == spec.shift_amount { return bin_down_idx; } bin_down_idx += 1; @@ -661,10 +660,7 @@ impl UairSignature { /// binary_poly cols. Per-bit closing overrides on the verifier side /// bind each bit's MLE eval to the spec residual. #[must_use] - pub fn with_virtual_binary_poly_cols( - mut self, - cols: Vec, - ) -> Self { + pub fn with_virtual_binary_poly_cols(mut self, cols: Vec) -> Self { let num_wit_bin = self.witness_cols.num_binary_poly_cols(); let num_pub_bin = self.public_cols.num_binary_poly_cols(); let num_shifted = self.shifted_bit_slice_specs.len(); @@ -786,6 +782,16 @@ pub struct UairTrace<'a, PolyCoeff: Clone, Int: Clone, const D: usize> { pub int: Cow<'a, [DenseMultilinearExtension]>, } +/// Prover-private UAIR witness data. 
+/// +/// The wrapped trace may still contain public-prefix columns; callers can split +/// it with [`UairTrace::public`] and [`UairTrace::witness`] using the UAIR +/// signature. +#[derive(Debug, Clone)] +pub struct UairWitness<'a, PolyCoeff: Clone, Int: Clone, const D: usize> { + pub trace: UairTrace<'a, PolyCoeff, Int, D>, +} + impl UairTrace<'static, PolyCoeff, Int, D> { /// Returns a sub-trace containing only public columns. /// Returned trace is borrowed from the full trace. @@ -1002,8 +1008,5 @@ pub enum PublicStructureError { /// expected closed-form value (e.g. the tail-corrector boundary /// formula). #[error("public column '{column}' at row {row} has wrong value")] - WrongValue { - column: &'static str, - row: usize, - }, + WrongValue { column: &'static str, row: usize }, } diff --git a/utils/src/delayed_reduction.rs b/utils/src/delayed_reduction.rs new file mode 100644 index 00000000..a7a3e831 --- /dev/null +++ b/utils/src/delayed_reduction.rs @@ -0,0 +1,921 @@ +//! Delayed modular reduction helpers for fixed 4-limb Montgomery fields. +//! +//! This module is intentionally narrow: it supports summing Montgomery-form +//! field elements into a 5-limb accumulator, then reducing once with Barrett +//! reduction. The limb routines are adapted from Spartan2's MIT-licensed +//! `big_num` helpers. + +use crypto_bigint::modular::{ConstMontyForm, ConstMontyParams, MontyForm}; +use crypto_primitives::{ + PrimeField, crypto_bigint_const_monty::ConstMontyField, crypto_bigint_monty::MontyField, + crypto_bigint_uint::Uint, +}; +use num_traits::Zero; +use std::marker::PhantomData; + +const DEFAULT_DMR_FLUSH_ADDS: usize = 1 << 20; + +/// Barrett reduction parameters modulo a 4-limb prime. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct BarrettReductionParams { + /// The 4-limb prime modulus in little-endian limb order. + pub modulus: [u64; 4], + + /// `floor(2^512 / MODULUS)`, stored in little-endian limb order. 
+ pub mu: [u64; 5], +} + +impl BarrettReductionParams { + #[inline(always)] + pub const fn new(modulus: [u64; 4]) -> Self { + Self { + modulus, + mu: compute_barrett_mu(modulus), + } + } +} + +/// Field types that expose reduced Montgomery-form limbs. +pub trait MontgomeryLimbs: PrimeField> + Sized { + /// Construct a field element from reduced Montgomery-form limbs. + fn from_montgomery_limbs(limbs: [u64; 4], cfg: &Self::Config) -> Self; + + /// Borrow the field element's Montgomery-form limbs. + fn montgomery_limbs(&self) -> &[u64; 4]; + + /// Return Barrett reduction parameters for this field configuration. + fn barrett_reduction_params(cfg: &Self::Config) -> BarrettReductionParams; +} + +/// Algorithm object for delayed modular reduction. +pub trait DelayedModularReductionAlgorithm { + type Value; + type Accumulator; + + fn zero_accumulator(&self) -> Self::Accumulator; + fn add(&self, acc: &mut Self::Accumulator, value: &Self::Value); + fn reduce(&self, acc: Self::Accumulator) -> Self::Value; +} + +/// Algorithm object for delayed field product sums. +pub trait DelayedFieldProductSumAlgorithm { + type Value; + type Accumulator; + + fn zero_accumulator(&self) -> Self::Accumulator; + fn add_product(&self, acc: &mut Self::Accumulator, lhs: &Self::Value, rhs: &Self::Value); + fn reduce_products(&self, acc: Self::Accumulator) -> Self::Value; + fn sum_of_products(&self, lhs: &[Self::Value], rhs: &[Self::Value]) -> Self::Value; + fn sum_of_products_with_seed( + &self, + lhs: &[Self::Value], + rhs: &[Self::Value], + seed: Self::Value, + ) -> Self::Value; +} + +/// Accumulator trait for delayed modular reduction. +pub trait DelayedModularReduction: Zero + Clone + Send + Sync +where + F: PrimeField, +{ + fn add(&mut self, value: &F); + fn reduce(self, cfg: &F::Config, params: &BarrettReductionParams) -> F; +} + +/// Field product-sum backend for delayed modular reduction-aware dot products. 
+pub trait DelayedFieldProductSum: PrimeField + Sized { + /// Compute `zero + sum_i lhs[i] * rhs[i]`. + /// + /// The caller is responsible for enforcing equal slice lengths. + fn delayed_sum_of_products(lhs: &[Self], rhs: &[Self], zero: Self) -> Self; +} + +#[derive(Clone, Debug)] +pub struct BarrettDelayedReduction<'cfg, F> +where + F: MontgomeryLimbs, +{ + cfg: &'cfg F::Config, + params: BarrettReductionParams, + flush_adds: usize, + _field: PhantomData, +} + +impl<'cfg, F> BarrettDelayedReduction<'cfg, F> +where + F: MontgomeryLimbs, +{ + pub fn new(cfg: &'cfg F::Config) -> Self { + let params = F::barrett_reduction_params(cfg); + let flush_adds = if params.modulus[3] == 0 { + 1 + } else { + DEFAULT_DMR_FLUSH_ADDS + }; + Self { + cfg, + params, + flush_adds, + _field: PhantomData, + } + } + + pub fn flush_adds(&self) -> usize { + self.flush_adds + } + + pub fn params(&self) -> &BarrettReductionParams { + &self.params + } +} + +impl DelayedModularReductionAlgorithm for BarrettDelayedReduction<'_, F> +where + F: MontgomeryLimbs + Send + Sync, +{ + type Value = F; + type Accumulator = Uint<5>; + + fn zero_accumulator(&self) -> Self::Accumulator { + Uint::zero() + } + + #[inline(always)] + fn add(&self, acc: &mut Self::Accumulator, value: &Self::Value) { + add_montgomery_limbs_5(acc, value.montgomery_limbs()); + } + + #[inline(always)] + fn reduce(&self, acc: Self::Accumulator) -> Self::Value { + F::from_montgomery_limbs(barrett_reduce_5(acc.as_words(), &self.params), self.cfg) + } +} + +/// Raw accumulator for a delayed sum of 4-limb Montgomery products. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ProductAccumulator4 { + limbs: Uint<9>, + pending_products: usize, +} + +impl ProductAccumulator4 { + pub fn pending_products(&self) -> usize { + self.pending_products + } + + pub fn limbs(&self) -> &Uint<9> { + &self.limbs + } +} + +#[derive(Clone, Debug)] +pub struct MontgomeryProductSum4<'cfg, F> +where + F: MontgomeryLimbs, +{ + cfg: &'cfg F::Config, + reduction_params: BarrettReductionParams, + mod_neg_inv: u64, + flush_products: usize, + _field: PhantomData, +} + +impl<'cfg, F> MontgomeryProductSum4<'cfg, F> +where + F: MontgomeryLimbs, +{ + pub fn new(cfg: &'cfg F::Config) -> Self { + let reduction_params = F::barrett_reduction_params(cfg); + let leading_zeros = clz::<4>(&reduction_params.modulus); + let flush_products = if leading_zeros == 0 { + 1 + } else { + usize::try_from(leading_zeros) + .ok() + .and_then(|shift| 1usize.checked_shl(shift as u32)) + .unwrap_or(usize::MAX) + }; + Self::new_with_flush_products(cfg, flush_products) + } + + pub fn new_with_flush_products(cfg: &'cfg F::Config, flush_products: usize) -> Self { + let reduction_params = F::barrett_reduction_params(cfg); + Self { + cfg, + reduction_params, + mod_neg_inv: mod_neg_inv_u64(reduction_params.modulus[0]), + flush_products: flush_products.max(1), + _field: PhantomData, + } + } + + pub fn flush_products(&self) -> usize { + self.flush_products + } +} + +impl DelayedFieldProductSumAlgorithm for MontgomeryProductSum4<'_, F> +where + F: MontgomeryLimbs + Send + Sync, +{ + type Value = F; + type Accumulator = ProductAccumulator4; + + fn zero_accumulator(&self) -> Self::Accumulator { + ProductAccumulator4 { + limbs: Uint::zero(), + pending_products: 0, + } + } + + #[inline(always)] + fn add_product(&self, acc: &mut Self::Accumulator, lhs: &Self::Value, rhs: &Self::Value) { + debug_assert!( + acc.pending_products < self.flush_products, + "ProductAccumulator4 must be reduced before exceeding its flush threshold" + ); + 
add_montgomery_product_4x4( + &mut acc.limbs, + lhs.montgomery_limbs(), + rhs.montgomery_limbs(), + ); + acc.pending_products = acc.pending_products.saturating_add(1); + } + + fn reduce_products(&self, acc: Self::Accumulator) -> Self::Value { + if acc.pending_products == 0 { + return F::zero_with_cfg(self.cfg); + } + let reduced = montgomery_reduce_9_to_4( + acc.limbs.as_words(), + &self.reduction_params, + self.mod_neg_inv, + ); + F::from_montgomery_limbs(reduced, self.cfg) + } + + fn sum_of_products(&self, lhs: &[Self::Value], rhs: &[Self::Value]) -> Self::Value { + let mut total = F::zero_with_cfg(self.cfg); + let mut acc = self.zero_accumulator(); + for (left, right) in lhs.iter().zip(rhs) { + self.add_product(&mut acc, left, right); + if acc.pending_products >= self.flush_products { + let pending = acc; + total += self.reduce_products(pending); + acc = self.zero_accumulator(); + } + } + if acc.pending_products != 0 { + total += self.reduce_products(acc); + } + total + } + + fn sum_of_products_with_seed( + &self, + lhs: &[Self::Value], + rhs: &[Self::Value], + seed: Self::Value, + ) -> Self::Value { + seed + self.sum_of_products(lhs, rhs) + } +} + +impl DelayedModularReduction for Uint<5> +where + F: MontgomeryLimbs + Send + Sync, +{ + #[inline(always)] + fn add(&mut self, value: &F) { + add_montgomery_limbs_5(self, value.montgomery_limbs()); + } + + #[inline(always)] + fn reduce(self, cfg: &F::Config, params: &BarrettReductionParams) -> F { + let acc = self.as_words(); + F::from_montgomery_limbs(barrett_reduce_5(acc, params), cfg) + } +} + +#[inline(always)] +fn add_montgomery_limbs_5(acc: &mut Uint<5>, rhs: &[u64; 4]) { + let acc = acc.as_mut_words(); + let mut carry = 0u64; + let mut i = 0; + while i < 4 { + let (sum, c0) = acc[i].overflowing_add(rhs[i]); + let (sum, c1) = sum.overflowing_add(carry); + acc[i] = sum; + carry = (c0 as u64) + (c1 as u64); + i += 1; + } + + let old_hi = acc[4]; + acc[4] = acc[4].wrapping_add(carry); + debug_assert!( + acc[4] >= 
old_hi, + "Uint<5> delayed accumulator overflowed high limb" + ); +} + +#[inline(always)] +fn mod_neg_inv_u64(modulus_limb: u64) -> u64 { + debug_assert!(modulus_limb & 1 == 1, "Montgomery modulus must be odd"); + let mut inv = 1u64; + let mut i = 0; + while i < 6 { + inv = inv.wrapping_mul(2u64.wrapping_sub(modulus_limb.wrapping_mul(inv))); + i += 1; + } + inv.wrapping_neg() +} + +#[inline(always)] +fn add_montgomery_product_4x4(acc: &mut Uint<9>, lhs: &[u64; 4], rhs: &[u64; 4]) { + let product = mul_4x4_to_8(lhs, rhs); + let acc = acc.as_mut_words(); + let mut carry = 0u64; + let mut i = 0; + while i < 8 { + let (sum, c0) = acc[i].overflowing_add(product[i]); + let (sum, c1) = sum.overflowing_add(carry); + acc[i] = sum; + carry = (c0 as u64) + (c1 as u64); + i += 1; + } + + let old_hi = acc[8]; + acc[8] = acc[8].wrapping_add(carry); + debug_assert!(acc[8] >= old_hi, "ProductAccumulator4 overflowed high limb"); +} + +#[inline(always)] +fn mul_4x4_to_8(lhs: &[u64; 4], rhs: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + let mut i = 0; + while i < 4 { + let mut carry = 0u128; + let mut j = 0; + while j < 4 { + let idx = i + j; + let prod = (lhs[i] as u128) * (rhs[j] as u128) + (result[idx] as u128) + carry; + result[idx] = prod as u64; + carry = prod >> 64; + j += 1; + } + + let mut idx = i + 4; + let mut carry_u64 = carry as u64; + while carry_u64 != 0 && idx < 8 { + let (sum, overflow) = result[idx].overflowing_add(carry_u64); + result[idx] = sum; + carry_u64 = overflow as u64; + idx += 1; + } + debug_assert!(carry_u64 == 0, "4x4 product exceeded eight limbs"); + i += 1; + } + result +} + +#[inline(always)] +fn montgomery_reduce_9_to_4( + acc: &[u64; 9], + params: &BarrettReductionParams, + mod_neg_inv: u64, +) -> [u64; 4] { + let mut t = *acc; + let mut i = 0; + while i < 4 { + let q = t[i].wrapping_mul(mod_neg_inv); + let mut carry = 0u128; + let mut j = 0; + while j < 4 { + let idx = i + j; + let sum = (q as u128) * (params.modulus[j] as u128) + 
(t[idx] as u128) + carry; + t[idx] = sum as u64; + carry = sum >> 64; + j += 1; + } + + let mut idx = i + 4; + let mut carry_u64 = carry as u64; + while carry_u64 != 0 { + debug_assert!(idx < 9, "Montgomery reduction carry exceeded accumulator"); + let (sum, overflow) = t[idx].overflowing_add(carry_u64); + t[idx] = sum; + carry_u64 = overflow as u64; + idx += 1; + } + debug_assert!(t[i] == 0, "Montgomery reduction did not clear low limb"); + i += 1; + } + + let reduced = [t[4], t[5], t[6], t[7], t[8]]; + barrett_reduce_5(&reduced, params) +} + +impl DelayedFieldProductSum for MontyField { + fn delayed_sum_of_products(lhs: &[Self], rhs: &[Self], zero: Self) -> Self { + if lhs.is_empty() { + return zero; + } + + let leading_zeros = zero.cfg().modulus().as_ref().leading_zeros(); + if !lincomb_has_product_sum_headroom(leading_zeros, lhs.len()) { + return naive_sum_of_products(lhs, rhs, zero); + } + + let lhs_forms: Vec> = + lhs.iter().cloned().map(|value| value.into()).collect(); + let rhs_forms: Vec> = + rhs.iter().cloned().map(|value| value.into()).collect(); + let products: Vec<(&MontyForm, &MontyForm)> = + lhs_forms.iter().zip(&rhs_forms).collect(); + + MontyField::new(MontyForm::lincomb_vartime(&products)) + zero + } +} + +impl DelayedFieldProductSum for ConstMontyField +where + Mod: ConstMontyParams, +{ + fn delayed_sum_of_products(lhs: &[Self], rhs: &[Self], zero: Self) -> Self { + if lhs.is_empty() { + return zero; + } + + let leading_zeros = Mod::PARAMS.modulus().as_ref().leading_zeros(); + if !lincomb_has_product_sum_headroom(leading_zeros, lhs.len()) { + return naive_sum_of_products(lhs, rhs, zero); + } + + let products: Vec<(ConstMontyForm, ConstMontyForm)> = lhs + .iter() + .cloned() + .zip(rhs.iter().cloned()) + .map(|(left, right)| (left.into(), right.into())) + .collect(); + + ConstMontyField::from(ConstMontyForm::lincomb(&products)) + zero + } +} + +#[inline(always)] +fn lincomb_has_product_sum_headroom(leading_zeros: u32, len: usize) -> bool { + len > 
1 && leading_zeros > 0 +} + +#[allow(clippy::arithmetic_side_effects)] +fn naive_sum_of_products(lhs: &[F], rhs: &[F], zero: F) -> F { + lhs.iter() + .zip(rhs) + .fold(zero, |acc, (left, right)| acc + left.clone() * right) +} + +impl MontgomeryLimbs for ConstMontyField +where + Mod: ConstMontyParams<4>, +{ + #[inline(always)] + fn from_montgomery_limbs(limbs: [u64; 4], _cfg: &Self::Config) -> Self { + Self::new_unchecked(Uint::<4>::from_words(limbs)) + } + + #[inline(always)] + fn montgomery_limbs(&self) -> &[u64; 4] { + self.inner().as_words() + } + + #[inline(always)] + fn barrett_reduction_params(_cfg: &Self::Config) -> BarrettReductionParams { + BarrettReductionParams::new(Uint::<4>::new(*Mod::PARAMS.modulus().as_ref()).to_words()) + } +} + +impl MontgomeryLimbs for MontyField<4> { + #[inline(always)] + fn from_montgomery_limbs(limbs: [u64; 4], cfg: &Self::Config) -> Self { + Self::new_unchecked_with_cfg(Uint::<4>::from_words(limbs), cfg) + } + + #[inline(always)] + fn montgomery_limbs(&self) -> &[u64; 4] { + self.inner().as_words() + } + + #[inline(always)] + fn barrett_reduction_params(cfg: &Self::Config) -> BarrettReductionParams { + BarrettReductionParams::new(Uint::<4>::new(cfg.modulus().get()).to_words()) + } +} + +/// Barrett reduction for a 5-limb value modulo a 4-limb modulus. +/// +/// This uses the 5-limb remainder path, which is required for moduli near +/// `2^256` such as the secp256k1 base prime. 
+#[inline(always)] +pub fn barrett_reduce_5(c: &[u64; 5], params: &BarrettReductionParams) -> [u64; 4] { + let q1 = [c[3], c[4]]; + let q2 = mul_2x5_to_7(&q1, ¶ms.mu); + let q3 = [q2[5], q2[6]]; + + let r1 = *c; + let r2 = mul_2x4_lo5(&q3, ¶ms.modulus); + let mut r = sub::<5>(&r1, &r2); + + if r[4] != 0 || gte::<4>(&[r[0], r[1], r[2], r[3]], ¶ms.modulus) { + r = sub_5_4(&r, ¶ms.modulus); + } + + debug_assert!( + r[4] == 0 && !gte::<4>(&[r[0], r[1], r[2], r[3]], ¶ms.modulus), + "Barrett reduction produced non-canonical result" + ); + + [r[0], r[1], r[2], r[3]] +} + +#[inline(always)] +fn mul_2x5_to_7(a: &[u64; 2], b: &[u64; 5]) -> [u64; 7] { + let mut result = [0u64; 7]; + for i in 0..2 { + let mut carry = 0u128; + for j in 0..5 { + let prod = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 5] = carry as u64; + } + result +} + +#[inline(always)] +fn mul_2x4_lo5(a: &[u64; 2], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[0] as u128) * (b[j] as u128) + carry; + result[j] = prod as u64; + carry = prod >> 64; + } + result[4] = carry as u64; + + carry = 0; + for j in 0..4 { + let prod = (a[1] as u128) * (b[j] as u128) + (result[1 + j] as u128) + carry; + result[1 + j] = prod as u64; + carry = prod >> 64; + } + + result +} + +#[inline(always)] +const fn gte(a: &[u64; N], b: &[u64; N]) -> bool { + let mut i = N; + while i > 0 { + i -= 1; + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true +} + +#[inline(always)] +const fn sub(a: &[u64; N], b: &[u64; N]) -> [u64; N] { + let mut result = [0u64; N]; + let mut borrow = 0u64; + let mut i = 0; + while i < N { + let (diff, b1) = a[i].overflowing_sub(b[i]); + let (diff2, b2) = diff.overflowing_sub(borrow); + result[i] = diff2; + borrow = (b1 as u64) + (b2 as u64); + i += 1; + } + result +} + +#[inline(always)] +const fn sub_5_4(a: 
&[u64; 5], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut borrow = 0u64; + let mut i = 0; + while i < 4 { + let (diff, b1) = a[i].overflowing_sub(b[i]); + let (diff2, b2) = diff.overflowing_sub(borrow); + result[i] = diff2; + borrow = (b1 as u64) + (b2 as u64); + i += 1; + } + let (diff, _) = a[4].overflowing_sub(borrow); + result[4] = diff; + result +} + +#[inline(always)] +const fn shl(a: &[u64; N]) -> [u64; N] { + let mut result = [0u64; N]; + let mut carry = 0u64; + let mut i = 0; + while i < N { + let new_carry = a[i] >> 63; + result[i] = (a[i] << 1) | carry; + carry = new_carry; + i += 1; + } + result +} + +#[inline(always)] +const fn shr(a: &[u64; N]) -> [u64; N] { + let mut result = [0u64; N]; + let mut carry = 0u64; + let mut i = N; + while i > 0 { + i -= 1; + let new_carry = a[i] << 63; + result[i] = (a[i] >> 1) | carry; + carry = new_carry; + } + result +} + +#[inline(always)] +const fn clz(a: &[u64; N]) -> u32 { + let mut i = N; + let mut count = 0u32; + while i > 0 { + i -= 1; + if a[i] != 0 { + return count + a[i].leading_zeros(); + } + count += 64; + } + count +} + +pub const fn compute_barrett_mu(p: [u64; 4]) -> [u64; 5] { + let mut dividend: [u64; 9] = [0, 0, 0, 0, 0, 0, 0, 0, 1]; + let divisor: [u64; 9] = [p[0], p[1], p[2], p[3], 0, 0, 0, 0, 0]; + let mut quotient: [u64; 5] = [0; 5]; + + let dividend_clz = clz::<9>(÷nd); + let divisor_clz = clz::<9>(&divisor); + if divisor_clz <= dividend_clz { + return quotient; + } + + let shift_bits = divisor_clz - dividend_clz; + let mut shifted_divisor = divisor; + let whole_limbs = (shift_bits / 64) as usize; + let rem_bits = shift_bits % 64; + + if whole_limbs > 0 { + let mut i = 8; + while i >= whole_limbs { + shifted_divisor[i] = shifted_divisor[i - whole_limbs]; + if i == whole_limbs { + break; + } + i -= 1; + } + let mut j = 0; + while j < whole_limbs { + shifted_divisor[j] = 0; + j += 1; + } + } + + let mut i = 0; + while i < rem_bits { + shifted_divisor = 
shl::<9>(&shifted_divisor); + i += 1; + } + + let mut bit_pos = shift_bits; + loop { + if gte::<9>(÷nd, &shifted_divisor) { + dividend = sub::<9>(÷nd, &shifted_divisor); + let limb_idx = (bit_pos / 64) as usize; + let bit_idx = bit_pos % 64; + if limb_idx < 5 { + quotient[limb_idx] |= 1u64 << bit_idx; + } + } + + if bit_pos == 0 { + break; + } + bit_pos -= 1; + shifted_divisor = shr::<9>(&shifted_divisor); + } + + quotient +} + +#[cfg(test)] +mod tests { + use super::*; + use crypto_primitives::{FromWithConfig, crypto_bigint_monty::MontyField}; + + type F = MontyField<4>; + + fn secp256k1_cfg() -> ::Config { + let modulus = Uint::<4>::from_words([ + 0xFFFF_FFFE_FFFF_FC2F, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + ]); + F::make_cfg(&modulus).expect("secp256k1 base field prime is valid") + } + + fn batched_product_cfg() -> ::Config { + let modulus = Uint::new( + crypto_bigint::Uint::<4>::from_str_radix_vartime( + "00dca94d8a1ecce3b6e8755d8999787d0524d8ca1ea755e7af84fb646fa31f27", + 16, + ) + .expect("valid modulus"), + ); + F::make_cfg(&modulus).expect("valid field config") + } + + #[test] + fn secp256k1_barrett_params_match_expected_modulus() { + let cfg = secp256k1_cfg(); + assert_eq!( + F::barrett_reduction_params(&cfg).modulus, + [ + 0xFFFF_FFFE_FFFF_FC2F, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + ], + ); + } + + #[test] + fn delayed_sum_matches_field_addition() { + let cfg = secp256k1_cfg(); + let reducer = BarrettDelayedReduction::::new(&cfg); + let values: Vec = (0..512) + .map(|i| F::from_with_cfg(i as u64 + 1, &cfg)) + .collect(); + + let mut expected = F::zero_with_cfg(&cfg); + for value in &values { + expected += value; + } + + let mut acc = reducer.zero_accumulator(); + for value in &values { + reducer.add(&mut acc, value); + } + + assert_eq!(reducer.reduce(acc), expected); + } + + #[test] + fn barrett_reduce_5_matches_uint_remainder_for_bounded_sum() { + let cfg = secp256k1_cfg(); + 
let reduction_params = F::barrett_reduction_params(&cfg); + let reducer = BarrettDelayedReduction::::new(&cfg); + let mut acc = Uint::<5>::zero(); + let max = -F::from_with_cfg(1u64, &cfg); + for _ in 0..512 { + reducer.add(&mut acc, &max); + } + + let wide = acc; + let modulus = Uint::<5>::from_words([ + reduction_params.modulus[0], + reduction_params.modulus[1], + reduction_params.modulus[2], + reduction_params.modulus[3], + 0, + ]); + let expected = (wide % &modulus) + .checked_resize::<4>() + .expect("remainder fits in four limbs"); + + let acc = acc.as_words(); + let reduced = barrett_reduce_5(acc, &reduction_params); + assert_eq!(Uint::<4>::from_words(reduced), expected); + } + + #[test] + fn product_accumulator_single_product_matches_field_multiplication() { + let cfg = secp256k1_cfg(); + let reducer = MontgomeryProductSum4::::new(&cfg); + let lhs = F::from_with_cfg(17u64, &cfg); + let rhs = F::from_with_cfg(23u64, &cfg); + + let mut acc = reducer.zero_accumulator(); + reducer.add_product(&mut acc, &lhs, &rhs); + + assert_eq!(reducer.reduce_products(acc), lhs * &rhs); + } + + #[test] + fn product_accumulator_multi_product_matches_naive_sum() { + let cfg = secp256k1_cfg(); + let reducer = MontgomeryProductSum4::::new(&cfg); + let lhs: Vec = (0..32).map(|idx| F::from_with_cfg(idx + 3, &cfg)).collect(); + let rhs: Vec = (0..32) + .map(|idx| F::from_with_cfg(257 - idx, &cfg)) + .collect(); + + let expected = lhs + .iter() + .zip(&rhs) + .fold(F::zero_with_cfg(&cfg), |acc, (left, right)| { + acc + left.clone() * right + }); + + assert_eq!(reducer.sum_of_products(&lhs, &rhs), expected); + } + + #[test] + fn product_accumulator_batches_near_modulus_terms_before_reduction() { + let cfg = batched_product_cfg(); + let reducer = MontgomeryProductSum4::::new(&cfg); + assert!(reducer.flush_products() > 64); + + let lhs: Vec = (0..64) + .map(|idx| -F::from_with_cfg(idx * 17 + 5, &cfg)) + .collect(); + let rhs: Vec = (0..64) + .map(|idx| -F::from_with_cfg(idx * 19 + 7, 
&cfg)) + .collect(); + + let mut acc = reducer.zero_accumulator(); + for (left, right) in lhs.iter().zip(&rhs) { + reducer.add_product(&mut acc, left, right); + } + assert_eq!(acc.pending_products(), lhs.len()); + + let expected = lhs + .iter() + .zip(&rhs) + .fold(F::zero_with_cfg(&cfg), |sum, (left, right)| { + sum + left.clone() * right + }); + + assert_eq!(reducer.reduce_products(acc), expected); + } + + #[test] + fn product_accumulator_seeded_sum_matches_naive_sum() { + let cfg = secp256k1_cfg(); + let reducer = MontgomeryProductSum4::::new(&cfg); + let seed = F::from_with_cfg(99u64, &cfg); + let lhs: Vec = (0..16).map(|idx| F::from_with_cfg(idx + 5, &cfg)).collect(); + let rhs: Vec = (0..16) + .map(|idx| F::from_with_cfg(131 - idx, &cfg)) + .collect(); + + let expected = lhs + .iter() + .zip(&rhs) + .fold(seed.clone(), |acc, (left, right)| { + acc + left.clone() * right + }); + + assert_eq!( + reducer.sum_of_products_with_seed(&lhs, &rhs, seed), + expected + ); + } + + #[test] + fn product_accumulator_forced_flush_matches_naive_sum() { + let cfg = secp256k1_cfg(); + let reducer = MontgomeryProductSum4::::new_with_flush_products(&cfg, 1); + let lhs: Vec = (0..32) + .map(|idx| F::from_with_cfg(idx + 11, &cfg)) + .collect(); + let rhs: Vec = (0..32) + .map(|idx| F::from_with_cfg(409 - idx, &cfg)) + .collect(); + + let expected = lhs + .iter() + .zip(&rhs) + .fold(F::zero_with_cfg(&cfg), |acc, (left, right)| { + acc + left.clone() * right + }); + + assert_eq!(reducer.sum_of_products(&lhs, &rhs), expected); + } +} diff --git a/utils/src/field/boxed_monty.rs b/utils/src/field/boxed_monty.rs index a1a7296d..eb63dd5b 100644 --- a/utils/src/field/boxed_monty.rs +++ b/utils/src/field/boxed_monty.rs @@ -1,11 +1,13 @@ -use crypto_bigint::BoxedUint; +use crypto_bigint::{BoxedUint, modular::BoxedMontyForm}; use crypto_primitives::{ FromWithConfig, IntoWithConfig, PrimeField, crypto_bigint_boxed_monty::BoxedMontyField, crypto_bigint_uint::Uint, }; use crate::{ - 
from_ref::FromRef, mul_by_scalar::MulByScalar, projectable_to_field::ProjectableToField, + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, + inner_transparent_field::InnerTransparentField, mul_by_scalar::MulByScalar, + projectable_to_field::ProjectableToField, }; impl MulByScalar<&Self> for BoxedMontyField { @@ -22,6 +24,33 @@ impl FromRef for BoxedMontyField { } } +impl DelayedFieldProductSum for BoxedMontyField { + #[allow(clippy::arithmetic_side_effects)] + fn delayed_sum_of_products(lhs: &[Self], rhs: &[Self], zero: Self) -> Self { + if lhs.is_empty() { + return zero; + } + + let leading_zeros = zero.cfg().modulus().as_ref().leading_zeros(); + if lhs.len() == 1 || leading_zeros == 0 { + return lhs + .iter() + .zip(rhs) + .fold(zero, |acc, (left, right)| acc + left.clone() * right); + } + + let forms: Vec<(BoxedMontyForm, BoxedMontyForm)> = lhs + .iter() + .zip(rhs) + .map(|(left, right)| (left.clone().into(), right.clone().into())) + .collect(); + let products: Vec<(&BoxedMontyForm, &BoxedMontyForm)> = + forms.iter().map(|(left, right)| (left, right)).collect(); + + Self::from(BoxedMontyForm::lincomb_vartime(&products)) + zero + } +} + impl FromRef> for BoxedUint { #[inline] fn from_ref(value: &Uint) -> Self { @@ -41,6 +70,25 @@ where } } +impl InnerTransparentField for BoxedMontyField { + fn add_inner(lhs: &Self::Inner, rhs: &Self::Inner, config: &Self::Config) -> Self::Inner { + let lhs = BoxedMontyForm::from_montgomery(lhs.clone(), config.clone()); + let rhs = BoxedMontyForm::from_montgomery(rhs.clone(), config.clone()); + (lhs + rhs).to_montgomery() + } + + fn sub_inner(lhs: &Self::Inner, rhs: &Self::Inner, config: &Self::Config) -> Self::Inner { + let lhs = BoxedMontyForm::from_montgomery(lhs.clone(), config.clone()); + let rhs = BoxedMontyForm::from_montgomery(rhs.clone(), config.clone()); + (lhs - rhs).to_montgomery() + } + + fn mul_assign_by_inner(&mut self, rhs: &Self::Inner) { + let rhs = Self::new_unchecked_with_cfg(rhs.clone(), 
self.cfg()); + *self *= rhs; + } +} + #[cfg(test)] #[allow( clippy::arithmetic_side_effects, @@ -48,6 +96,7 @@ where clippy::cast_possible_wrap )] mod prop_tests { + use crate::delayed_reduction::DelayedFieldProductSum; use crypto_bigint::{BoxedUint, U256}; use crypto_primitives::{ FromWithConfig, IntoWithConfig, PrimeField, crypto_bigint_boxed_monty::BoxedMontyField, @@ -73,6 +122,39 @@ mod prop_tests { any::() } + #[test] + fn delayed_sum_of_products_matches_naive() { + let cfg = get_dyn_config(MODULUS); + let seed = F::from_with_cfg(99u64, &cfg); + let empty = ::delayed_sum_of_products(&[], &[], seed.clone()); + assert_eq!(empty, seed); + + let single_lhs = [F::from_with_cfg(17u64, &cfg)]; + let single_rhs = [F::from_with_cfg(23u64, &cfg)]; + let got = ::delayed_sum_of_products( + &single_lhs, + &single_rhs, + seed.clone(), + ); + assert_eq!(got, seed.clone() + single_lhs[0].clone() * &single_rhs[0]); + + let lhs: Vec = (0..128) + .map(|idx| F::from_with_cfg(idx + 3, &cfg)) + .collect(); + let rhs: Vec = (0..128) + .map(|idx| F::from_with_cfg(257 - idx, &cfg)) + .collect(); + let expected = lhs + .iter() + .zip(&rhs) + .fold(seed.clone(), |acc, (left, right)| { + acc + left.clone() * right + }); + let got = ::delayed_sum_of_products(&lhs, &rhs, seed); + + assert_eq!(got, expected); + } + proptest! { #[test] fn prop_from_unsigned_matches_sum_of_bits(x in any_u128()) { diff --git a/utils/src/field/const_monty.rs b/utils/src/field/const_monty.rs index 3c529850..3c785485 100644 --- a/utils/src/field/const_monty.rs +++ b/utils/src/field/const_monty.rs @@ -31,7 +31,7 @@ macro_rules! 
impl_from_primitive_ref { )* }; } -impl_from_primitive_ref!(u8, u16, u32, u64, u128); +impl_from_primitive_ref!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128); impl, const LIMBS: usize> FromRef> for ConstMontyForm @@ -49,14 +49,23 @@ impl, const LIMBS: usize> FromRef } } -impl, const LIMBS: usize, const LIMBS2: usize> - ProjectableToField> for Int +impl, const LIMBS: usize, const LIMBS2: usize> FromRef> + for ConstMontyField +{ + fn from_ref(value: &Int) -> Self { + value.into() + } +} + +impl, const LIMBS: usize> + ProjectableToField> for T +where + ConstMontyField: FromRef, { fn prepare_projection( _sampled_value: &ConstMontyField, ) -> impl Fn(&Self) -> ConstMontyField + Send + Sync + 'static { - // No need to read anything - |value: &Int| value.into() + |value: &T| ConstMontyField::::from_ref(value) } } diff --git a/utils/src/inner_product.rs b/utils/src/inner_product.rs index 53effe36..d9a2655b 100644 --- a/utils/src/inner_product.rs +++ b/utils/src/inner_product.rs @@ -1,4 +1,8 @@ -use crate::{from_ref::FromRef, mul_by_scalar::MulByScalar}; +use crate::{ + delayed_reduction::{DelayedFieldProductSum, DelayedFieldProductSumAlgorithm}, + from_ref::FromRef, + mul_by_scalar::MulByScalar, +}; use crypto_primitives::{FromWithConfig, PrimeField, boolean::Boolean}; use num_traits::CheckedAdd; use thiserror::Error; @@ -86,6 +90,51 @@ impl MBSInnerProduct { } } +/// Field-field inner product backed by a delayed product-sum implementation. 
+#[derive(Clone, Debug)] +pub struct FieldFieldInnerProduct; + +impl FieldFieldInnerProduct { + pub fn inner_product_with_algorithm( + algorithm: &A, + lhs: &[A::Value], + rhs: &[A::Value], + seed: A::Value, + ) -> Result + where + A: DelayedFieldProductSumAlgorithm, + { + if lhs.len() != rhs.len() { + return Err(InnerProductError::LengthMismatch { + lhs: lhs.len(), + rhs: rhs.len(), + }); + } + + Ok(algorithm.sum_of_products_with_seed(lhs, rhs, seed)) + } +} + +impl InnerProduct<[F], F, F> for FieldFieldInnerProduct +where + F: DelayedFieldProductSum, +{ + fn inner_product( + lhs: &[F], + rhs: &[F], + zero: F, + ) -> Result { + if lhs.len() != rhs.len() { + return Err(InnerProductError::LengthMismatch { + lhs: lhs.len(), + rhs: rhs.len(), + }); + } + + Ok(F::delayed_sum_of_products(lhs, rhs, zero)) + } +} + /// The inner product for vectors of length 1 (a.k.a. scalars). /// Uses `mul_by_scalar` to multiply the only components of vectors /// to get the result. @@ -153,9 +202,12 @@ impl + CheckedAdd> InnerProduct<[Boolean], Rhs, Ou #[cfg(test)] mod test { - use crate::{CHECKED, UNCHECKED}; - use crypto_bigint::{U64, const_monty_params}; - use crypto_primitives::crypto_bigint_const_monty::ConstMontyField; + use crate::{CHECKED, UNCHECKED, delayed_reduction::MontgomeryProductSum4}; + use crypto_bigint::{U64, U256, const_monty_params}; + use crypto_primitives::{ + FromWithConfig, PrimeField, crypto_bigint_const_monty::ConstMontyField, + crypto_bigint_monty::MontyField, crypto_bigint_uint::Uint, + }; use num_traits::ConstZero; use super::*; @@ -198,6 +250,29 @@ mod test { } const_monty_params!(Params, U64, "0000000000000007"); + const_monty_params!( + Params256, + U256, + "00dca94d8a1ecce3b6e8755d8999787d0524d8ca1ea755e7af84fb646fa31f27" + ); + + fn dyn_field_cfg() -> as PrimeField>::Config { + let modulus = Uint::new( + crypto_bigint::Uint::<4>::from_str_radix_vartime( + "00dca94d8a1ecce3b6e8755d8999787d0524d8ca1ea755e7af84fb646fa31f27", + 16, + ) + .expect("valid 
modulus"), + ); + MontyField::make_cfg(&modulus).expect("valid field config") + } + + #[allow(clippy::arithmetic_side_effects)] + fn naive_field_inner_product(lhs: &[F], rhs: &[F], zero: F) -> F { + lhs.iter() + .zip(rhs) + .fold(zero, |acc, (left, right)| acc + left.clone() * right) + } #[test] fn boolean_unchecked_eq_boolean_checked() { @@ -219,4 +294,115 @@ mod test { BooleanInnerProductAdd::inner_product::(&lhs, &rhs, ConstMontyField::ZERO) ); } + + #[test] + fn field_field_inner_product_monty_matches_naive() { + type F = MontyField<4>; + let cfg = dyn_field_cfg(); + let lhs = [ + F::from_with_cfg(3u64, &cfg), + F::from_with_cfg(5u64, &cfg), + F::from_with_cfg(8u64, &cfg), + F::from_with_cfg(13u64, &cfg), + ]; + let rhs = [ + F::from_with_cfg(21u64, &cfg), + F::from_with_cfg(34u64, &cfg), + F::from_with_cfg(55u64, &cfg), + F::from_with_cfg(89u64, &cfg), + ]; + let zero = F::zero_with_cfg(&cfg); + + let got = + FieldFieldInnerProduct::inner_product::(&lhs, &rhs, zero.clone()).unwrap(); + let expected = naive_field_inner_product(&lhs, &rhs, zero); + + assert_eq!(got, expected); + } + + #[test] + fn field_field_inner_product_with_algorithm_matches_naive() { + type F = MontyField<4>; + let cfg = dyn_field_cfg(); + let algorithm = MontgomeryProductSum4::::new(&cfg); + let lhs: Vec = (0..24).map(|idx| F::from_with_cfg(idx + 3, &cfg)).collect(); + let rhs: Vec = (0..24) + .map(|idx| F::from_with_cfg(89 - idx, &cfg)) + .collect(); + let seed = F::from_with_cfg(99u64, &cfg); + + let got = FieldFieldInnerProduct::inner_product_with_algorithm::( + &algorithm, + &lhs, + &rhs, + seed.clone(), + ) + .unwrap(); + let expected = naive_field_inner_product(&lhs, &rhs, seed); + + assert_eq!(got, expected); + } + + #[test] + fn field_field_inner_product_const_monty_matches_naive() { + type F = ConstMontyField; + let lhs = [F::from(2u64), F::from(7u64), F::from(19u64), F::from(31u64)]; + let rhs = [ + F::from(43u64), + F::from(59u64), + F::from(61u64), + F::from(71u64), + ]; + + 
let got = FieldFieldInnerProduct::inner_product::(&lhs, &rhs, F::ZERO).unwrap(); + let expected = naive_field_inner_product(&lhs, &rhs, F::ZERO); + + assert_eq!(got, expected); + } + + #[test] + fn field_field_inner_product_empty_returns_zero() { + type F = ConstMontyField; + let zero = F::from(99u64); + + let got = FieldFieldInnerProduct::inner_product::(&[], &[], zero).unwrap(); + + assert_eq!(got, zero); + } + + #[test] + fn field_field_inner_product_single_term_matches_naive() { + type F = ConstMontyField; + let lhs = [F::from(144u64)]; + let rhs = [F::from(233u64)]; + + let got = FieldFieldInnerProduct::inner_product::(&lhs, &rhs, F::ZERO).unwrap(); + let expected = naive_field_inner_product(&lhs, &rhs, F::ZERO); + + assert_eq!(got, expected); + } + + #[test] + fn field_field_inner_product_nonzero_seed_matches_naive() { + type F = ConstMontyField; + let lhs = [F::from(5u64), F::from(8u64), F::from(13u64)]; + let rhs = [F::from(21u64), F::from(34u64), F::from(55u64)]; + let seed = F::from(99u64); + + let got = FieldFieldInnerProduct::inner_product::(&lhs, &rhs, seed).unwrap(); + let expected = naive_field_inner_product(&lhs, &rhs, seed); + + assert_eq!(got, expected); + } + + #[test] + fn field_field_inner_product_length_mismatch() { + type F = ConstMontyField; + let lhs = [F::from(1u64)]; + + assert_eq!( + FieldFieldInnerProduct::inner_product::(&lhs, &[], F::ZERO), + Err(InnerProductError::LengthMismatch { lhs: 1, rhs: 0 }) + ); + } } diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 17ccba30..2d04445b 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,3 +1,4 @@ +pub mod delayed_reduction; pub mod field; pub mod from_ref; pub mod inner_product; diff --git a/zinc+.pdf b/zinc+.pdf new file mode 100644 index 00000000..b1c6ba02 Binary files /dev/null and b/zinc+.pdf differ diff --git a/zip-plus/Cargo.toml b/zip-plus/Cargo.toml index 9d63f86c..817aa31a 100644 --- a/zip-plus/Cargo.toml +++ b/zip-plus/Cargo.toml @@ -17,11 +17,15 @@ zinc-transcript = { 
workspace = true } zinc-utils = { workspace = true } ark-ff = { version = "0.5.0", default-features = false } +ark-ec = { version = "0.5.0", default-features = false } ark-poly = { version = "0.5.0", default-features = false } +ark-serialize = { version = "0.5.0", default-features = false } +ark-std = { version = "0.5.0", default-features = false, features = ["std", "getrandom"] } itertools = { workspace = true } thiserror = { workspace = true } crypto-bigint = { workspace = true } +num-integer = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } rand_core = { workspace = true } @@ -33,6 +37,8 @@ uninit = "0.6.2" zstd = "0.13" [dev-dependencies] +ark-bn254 = { version = "0.5.0", default-features = false, features = ["curve"] } +ark-secp256k1 = { version = "0.5.0", default-features = false } criterion = { workspace = true } proptest = { workspace = true } @@ -40,7 +46,7 @@ proptest = { workspace = true } workspace = true [features] -parallel = ["dep:rayon", "zinc-utils/parallel", "zinc-poly/parallel", "ark-ff/parallel", "ark-poly/parallel"] +parallel = ["dep:rayon", "zinc-utils/parallel", "zinc-poly/parallel", "ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel"] simd = ["zinc-poly/simd"] unchecked = [] @@ -51,3 +57,11 @@ harness = false [[bench]] name = "zip_plus_benches" harness = false + +[[bench]] +name = "msm_commitment_benches" +harness = false + +[[bench]] +name = "hyrax_commit_breakdown" +harness = false diff --git a/zip-plus/benches/hyrax_commit_breakdown.rs b/zip-plus/benches/hyrax_commit_breakdown.rs new file mode 100644 index 00000000..3382c001 --- /dev/null +++ b/zip-plus/benches/hyrax_commit_breakdown.rs @@ -0,0 +1,253 @@ +use ark_bn254::G1Affine; +use ark_ec::{AffineRepr, CurveGroup, PrimeGroup}; +use ark_ff::Zero as ArkZero; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use crypto_primitives::crypto_bigint_monty::MontyField; +use std::hint::black_box; +use 
zinc_poly::{mle::DenseMultilinearExtension, univariate::binary::BinaryPoly}; +use zip_plus::pcs::{ + generic::PCS, + hyrax::{BinaryLanes, HyraxBlindingMode, HyraxCommitmentKey, HyraxPCS}, + msm_commitment::{BoolSubsetMsm, MsmCommitmentEngine, MsmCommitmentKey, RowMsmStrategy}, +}; + +type F = MontyField<4>; + +fn scalar(value: usize) -> C::ScalarField { + C::ScalarField::from(u64::try_from(value).expect("benchmark value must fit into u64")) +} + +fn bases_and_h(width: usize) -> (Vec, C::Group) { + let generator = C::Group::generator(); + let bases = (1..=width) + .map(|idx| (generator * scalar::(idx)).into_affine()) + .collect(); + let h = generator * scalar::(width + 1); + (bases, h) +} + +fn msm_ck(width: usize) -> MsmCommitmentKey { + let (bases, h) = bases_and_h::(width); + MsmCommitmentEngine::::setup_from_bases(width, bases, h) + .expect("benchmark setup must be valid") + .0 +} + +fn hyrax_ck(width: usize) -> HyraxCommitmentKey { + HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-commit-breakdown-bench", + HyraxBlindingMode::Unblinded, + ) + .expect("benchmark setup must be valid") + .0 +} + +fn bool_row(width: usize) -> Vec { + (0..width) + .map(|idx| idx % 3 == 0 || idx % 11 == 1) + .collect() +} + +fn bool_values(num_lanes: usize, width: usize) -> Vec> { + (0..num_lanes) + .map(|lane| { + (0..width) + .map(|idx| (idx + lane) % 3 == 0 || (idx * 7 + lane) % 19 == 2) + .collect() + }) + .collect() +} + +fn bit_mask(bits: &[bool]) -> usize { + bits.iter().enumerate().fold( + 0usize, + |mask, (idx, bit)| { + if *bit { mask | (1usize << idx) } else { mask } + }, + ) +} + +fn subset_tables(bases: &[C], window_bits: usize) -> Vec> { + bases + .chunks(window_bits) + .map(|window| { + let table_len = 1usize << window.len(); + let mut table = vec![C::Group::zero(); table_len]; + for mask in 1..table_len { + let bit = mask.trailing_zeros() as usize; + let previous = mask & !(1usize << bit); + table[mask] = table[previous] + window[bit]; + } + table + }) + .collect() +} 
+ +fn precomputed_bool_row( + tables: &[Vec], + values: &[bool], + window_bits: usize, +) -> C::Group { + let mut acc = C::Group::zero(); + for (window_idx, bits) in values.chunks(window_bits).enumerate() { + acc += tables[window_idx][bit_mask(bits)]; + } + acc +} + +fn binary_polys( + batch_size: usize, + num_vars: usize, +) -> Vec>> { + let n = 1usize << num_vars; + (0..batch_size) + .map(|poly_idx| { + let evals = (0..n) + .map(|row_idx| { + let mut value = (row_idx as u32).wrapping_mul(0x9e37_79b9); + value ^= (poly_idx as u32).wrapping_mul(0x85eb_ca6b); + value = value.rotate_left( + u32::try_from((row_idx + poly_idx) % 32).expect("rotation must fit"), + ); + value ^= value >> 16; + BinaryPoly::<32>::from(value) + }) + .collect(); + DenseMultilinearExtension::from_evaluations_vec(num_vars, evals, BinaryPoly::zero()) + }) + .collect() +} + +fn bench_curve( + c: &mut Criterion, + curve_name: &str, + batch_size: usize, + width: usize, + num_vars: usize, +) { + let mut group = c.benchmark_group("hyrax_commit_breakdown"); + let lanes = batch_size * 32; + let row = bool_row(width); + let lane_rows = bool_values(lanes, width); + let (bases, h) = bases_and_h::(width); + let tables = subset_tables::(&bases, 6); + let precomputed_blinds = (0..lanes) + .map(|idx| scalar::(idx + 17)) + .collect::>(); + let msm_ck = msm_ck::(width); + let blind_one = MsmCommitmentEngine::::blind(&msm_ck, width); + let hyrax_ck = hyrax_ck::(width); + let polys = binary_polys(batch_size, num_vars); + + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/bool_row_msm"), width), + &width, + |b, _| { + b.iter(|| { + as RowMsmStrategy>::msm_row( + black_box(&msm_ck), + black_box(&row), + ) + .expect("row MSM must succeed") + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/precomputed_bool_row_msm"), width), + &width, + |b, _| { + b.iter(|| precomputed_bool_row::(black_box(&tables), black_box(&row), 6)); + }, + ); + + group.bench_with_input( + 
BenchmarkId::new(format!("{curve_name}/commit_one_bool_lane"), width), + &width, + |b, _| { + b.iter(|| { + MsmCommitmentEngine::::commit_with::>( + black_box(&msm_ck), + black_box(&row), + black_box(&blind_one), + ) + .expect("one-lane commitment must succeed") + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new( + format!("{curve_name}/precomputed_commit_352_bool_lanes"), + width, + ), + &width, + |b, _| { + b.iter(|| { + let mut acc = Vec::with_capacity(lane_rows.len()); + for (values, blind) in lane_rows.iter().zip(precomputed_blinds.iter()) { + let mut commitment = + precomputed_bool_row::(black_box(&tables), black_box(values), 6); + commitment += h * blind; + acc.push(commitment); + } + acc + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/commit_352_bool_lanes"), width), + &width, + |b, _| { + b.iter(|| { + let mut acc = Vec::with_capacity(lane_rows.len()); + for values in &lane_rows { + let blind = MsmCommitmentEngine::::blind(&msm_ck, values.len()); + let commitment = + MsmCommitmentEngine::::commit_with::>( + black_box(&msm_ck), + black_box(values), + black_box(&blind), + ) + .expect("lane commitment must succeed"); + acc.push(commitment); + } + acc + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/hyrax_binary_commit_batch11"), width), + &width, + |b, _| { + b.iter(|| { + as PCS, 32>>::commit( + black_box(&hyrax_ck), + black_box(&polys), + ) + .expect("Hyrax binary commit must succeed") + }); + }, + ); + + group.finish(); +} + +fn hyrax_commit_breakdown(c: &mut Criterion) { + let width = 512; + let num_vars = 9; + let batch_size = 11; + + bench_curve::(c, "bn254", batch_size, width, num_vars); + bench_curve::(c, "secp256k1", batch_size, width, num_vars); +} + +criterion_group! 
{ + name = benches; + config = Criterion::default().sample_size(10); + targets = hyrax_commit_breakdown +} +criterion_main!(benches); diff --git a/zip-plus/benches/msm_commitment_benches.rs b/zip-plus/benches/msm_commitment_benches.rs new file mode 100644 index 00000000..a980f60e --- /dev/null +++ b/zip-plus/benches/msm_commitment_benches.rs @@ -0,0 +1,127 @@ +use ark_bn254::G1Affine; +use ark_ec::{AffineRepr, CurveGroup, PrimeGroup}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use std::hint::black_box; +use zip_plus::pcs::msm_commitment::{ + BoolSubsetMsm, MsmCommitmentEngine, MsmCommitmentKey, ScalarPippengerMsm, U8BucketMsm, +}; + +fn scalar(value: usize) -> C::ScalarField { + C::ScalarField::from(u64::try_from(value).expect("benchmark value must fit into u64")) +} + +fn setup(width: usize, n: usize) -> MsmCommitmentKey { + let generator = C::Group::generator(); + let bases = (1..=width) + .map(|idx| (generator * scalar::(idx)).into_affine()) + .collect(); + let h = generator * scalar::(width + 1); + let (ck, _) = MsmCommitmentEngine::::setup_from_bases(width, bases, h) + .expect("benchmark setup must be valid"); + let _blind = MsmCommitmentEngine::::blind(&ck, n); + ck +} + +fn bool_values(n: usize) -> Vec { + (0..n).map(|idx| idx % 3 == 0 || idx % 11 == 1).collect() +} + +fn u8_values(n: usize, modulus: u8) -> Vec { + (0..n) + .map(|idx| { + let value = (idx * 17 + 5) % usize::from(modulus); + u8::try_from(value).expect("benchmark u8 value must fit") + }) + .collect() +} + +fn scalar_values(values: &[u8]) -> Vec { + values + .iter() + .map(|value| C::ScalarField::from(u64::from(*value))) + .collect() +} + +fn bench_curve( + group: &mut criterion::BenchmarkGroup, + curve_name: &str, + width: usize, + n: usize, +) { + let ck = setup::(width, n); + let blind = MsmCommitmentEngine::::blind(&ck, n); + let bools = bool_values(n); + let u8_small = u8_values(n, 32); + let u8_full = u8_values(n, 255); + let scalars = 
scalar_values::(&u8_full); + + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/bool_subset"), width), + &width, + |b, _| { + b.iter(|| { + MsmCommitmentEngine::::commit_with::>( + black_box(&ck), + black_box(&bools), + black_box(&blind), + ) + .expect("bool benchmark commit must succeed") + }); + }, + ); + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/u8_0_32"), width), + &width, + |b, _| { + b.iter(|| { + MsmCommitmentEngine::::commit_with::( + black_box(&ck), + black_box(&u8_small), + black_box(&blind), + ) + .expect("u8 small benchmark commit must succeed") + }); + }, + ); + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/u8_0_255"), width), + &width, + |b, _| { + b.iter(|| { + MsmCommitmentEngine::::commit_with::( + black_box(&ck), + black_box(&u8_full), + black_box(&blind), + ) + .expect("u8 full benchmark commit must succeed") + }); + }, + ); + group.bench_with_input( + BenchmarkId::new(format!("{curve_name}/scalar_pippenger"), width), + &width, + |b, _| { + b.iter(|| { + MsmCommitmentEngine::::commit_with::( + black_box(&ck), + black_box(&scalars), + black_box(&blind), + ) + .expect("scalar benchmark commit must succeed") + }); + }, + ); +} + +fn msm_commitment_benches(c: &mut Criterion) { + let width = 64; + let n = width * 1024; + + let mut group = c.benchmark_group("msm_commitment"); + bench_curve::(&mut group, "bn254", width, n); + bench_curve::(&mut group, "secp256k1", width, n); + group.finish(); +} + +criterion_group!(benches, msm_commitment_benches); +criterion_main!(benches); diff --git a/zip-plus/src/code/iprs.rs b/zip-plus/src/code/iprs.rs index 0ed13ab4..1b2b9107 100644 --- a/zip-plus/src/code/iprs.rs +++ b/zip-plus/src/code/iprs.rs @@ -51,8 +51,7 @@ where let target_base_len = 1 << MAX_BASE_COLS_LOG2; // We want depth to be at least 1. 
- let base_depth = - 1.max(((1.max(row_len / target_base_len)).ilog2() as usize).div_ceil(3)); + let base_depth = 1.max(((1.max(row_len / target_base_len)).ilog2() as usize).div_ceil(3)); let extra = if REP >= 16 { 2 } else if REP >= 8 { diff --git a/zip-plus/src/merkle.rs b/zip-plus/src/merkle.rs index aee86c6c..13ed7448 100644 --- a/zip-plus/src/merkle.rs +++ b/zip-plus/src/merkle.rs @@ -98,11 +98,7 @@ impl MerkleTree { /// from all three groups in fixed order (0, 1, 2). Used by /// [`crate::pcs::multi_zip::MultiZip3`] to commit two or three /// heterogeneous Zip+ instances under a single tree. - pub fn new_combined_3( - rows0: &[&[S0]], - rows1: &[&[S1]], - rows2: &[&[S2]], - ) -> Self + pub fn new_combined_3(rows0: &[&[S0]], rows1: &[&[S1]], rows2: &[&[S2]]) -> Self where S0: ConstTranscribable + Send + Sync, S1: ConstTranscribable + Send + Sync, diff --git a/zip-plus/src/pcs.rs b/zip-plus/src/pcs.rs index da82e61f..24c85f0a 100644 --- a/zip-plus/src/pcs.rs +++ b/zip-plus/src/pcs.rs @@ -5,6 +5,9 @@ mod phase_verify; pub use phase_prove::ZipPlusProveByteBreakdown; pub use phase_verify::{VerifyPreOpen, VerifyPreOpenReads}; pub mod folding; +pub mod generic; +pub mod hyrax; +pub mod msm_commitment; pub mod multi_zip; pub mod structs; #[cfg(test)] diff --git a/zip-plus/src/pcs/folding.rs b/zip-plus/src/pcs/folding.rs index 211a1365..620a1e02 100644 --- a/zip-plus/src/pcs/folding.rs +++ b/zip-plus/src/pcs/folding.rs @@ -338,11 +338,8 @@ mod tests { fn split_preserves_reconstruction() { // v[i](X=2) = u[i](2) + 2^16 * w[i](2) let val: u32 = 0xABCD_1234; - let col = DenseMultilinearExtension::from_evaluations_vec( - 0, - vec![bp32(val)], - BinaryPoly::zero(), - ); + let col = + DenseMultilinearExtension::from_evaluations_vec(0, vec![bp32(val)], BinaryPoly::zero()); let split = split_column::<32, 16>(&col); assert_eq!(split.evaluations.len(), 2); diff --git a/zip-plus/src/pcs/generic.rs b/zip-plus/src/pcs/generic.rs new file mode 100644 index 00000000..f677b795 --- 
/dev/null +++ b/zip-plus/src/pcs/generic.rs @@ -0,0 +1,276 @@ +use std::{fmt::Debug, io::Cursor, marker::PhantomData}; + +use crypto_primitives::{FromPrimitiveWithConfig, FromWithConfig, PrimeField}; +use zinc_poly::{ + mle::DenseMultilinearExtension, univariate::dynamic::over_field::DynamicPolynomialF, +}; +use zinc_transcript::traits::{GenTranscribable, Transcribable, Transcript}; +use zinc_utils::{ + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, mul_by_scalar::MulByScalar, +}; + +use crate::{ + ZipError, + code::LinearCode, + pcs::structs::{ZipPlus, ZipPlusCommitment, ZipPlusHint, ZipPlusParams, ZipTypes}, + pcs_transcript::{PcsProverTranscript, PcsVerifierTranscript}, +}; + +/// Polynomial commitment scheme interface used by the Zinc+ protocol. +/// +/// `Eval` is the unprojected witness cell type committed by the backend. +pub trait PCS: Clone + Debug + Send + Sync +where + F: PrimeField, + Eval: Clone + Debug + Send + Sync, +{ + type CommitmentKey: Clone + Debug + Send + Sync; + type VerifierKey: Clone + Debug + Send + Sync; + type Commitment: Clone + Debug + Send + Sync; + type ProverData: Clone + Debug + Send + Sync; + type OpeningProof: Clone + Debug + Send + Sync + Default; + + fn precompute_ck(_ck: &Self::CommitmentKey) {} + + fn commit( + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + ) -> Result<(Self::ProverData, Self::Commitment), ZipError>; + + fn absorb_commitment(transcript: &mut T, commitment: &Self::Commitment); + + fn commitment_num_bytes(commitment: &Self::Commitment) -> usize; + + fn write_commitment_bytes(commitment: &Self::Commitment, buf: &mut Vec); + + fn batch_size(commitment: &Self::Commitment) -> usize; + + fn prove_open( + transcript: &mut PcsProverTranscript, + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + point: &[F], + prover_data: &Self::ProverData, + field_cfg: &F::Config, + ) -> Result + where + F::Inner: Transcribable, + F::Modulus: Transcribable; + + fn verify_open( + 
transcript: &mut PcsVerifierTranscript, + vk: &Self::VerifierKey, + commitment: &Self::Commitment, + point: &[F], + lifted_evals: &[DynamicPolynomialF], + opening_proof: &Self::OpeningProof, + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F::Inner: Transcribable, + F::Modulus: Transcribable; +} + +/// Homomorphic extension of [`PCS`] used by instance-axis folding protocols. +/// +/// Implementations must satisfy: +/// +/// ```text +/// fold_commitments([Com(w_i; eta_i)], theta) +/// = Com(sum_i theta_i w_i; sum_i theta_i eta_i) +/// ``` +/// +/// Non-homomorphic commitments, such as Merkle roots, must not implement this +/// trait. +pub trait FoldablePCS: PCS +where + F: PrimeField, + Eval: Clone + Debug + Send + Sync, +{ + fn fold_commitments( + commitments: &[Self::Commitment], + theta: &[F], + field_cfg: &F::Config, + ) -> Result; + + fn fold_commitment_refs( + commitments: &[&Self::Commitment], + theta: &[F], + field_cfg: &F::Config, + ) -> Result { + let owned = commitments + .iter() + .map(|commitment| (*commitment).clone()) + .collect::>(); + Self::fold_commitments(&owned, theta, field_cfg) + } + + fn fold_prover_data( + prover_data: &[Self::ProverData], + theta: &[F], + field_cfg: &F::Config, + ) -> Result; +} + +#[derive(Clone, Debug)] +pub struct ZipPlusPCS>(PhantomData<(Zt, Lc)>); + +impl PCS for ZipPlusPCS +where + F: PrimeField + + DelayedFieldProductSum + + FromPrimitiveWithConfig + + for<'a> FromWithConfig<&'a Zt::CombR> + + for<'a> FromWithConfig<&'a Zt::Chal> + + for<'a> MulByScalar<&'a F> + + FromRef, + Zt: ZipTypes, + Zt::Eval: Clone + Debug + Send + Sync, + Lc: LinearCode, + F::Modulus: zinc_utils::from_ref::FromRef, +{ + type CommitmentKey = ZipPlusParams; + type VerifierKey = ZipPlusParams; + type Commitment = ZipPlusCommitment; + type ProverData = Option>; + type OpeningProof = Vec; + + fn commit( + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + ) -> Result<(Self::ProverData, Self::Commitment), ZipError> { + 
if polys.is_empty() { + return Ok((None, ZipPlusCommitment::default())); + } + let (hint, commitment) = ZipPlus::::commit(ck, polys)?; + Ok((Some(hint), commitment)) + } + + fn absorb_commitment(transcript: &mut T, commitment: &Self::Commitment) { + transcript.absorb_slice(&commitment.root); + transcript.absorb_slice(&(commitment.batch_size as u64).to_le_bytes()); + } + + fn commitment_num_bytes(commitment: &Self::Commitment) -> usize { + commitment.get_num_bytes() + } + + fn write_commitment_bytes(commitment: &Self::Commitment, buf: &mut Vec) { + let offset = buf.len(); + buf.resize(offset + commitment.get_num_bytes(), 0); + commitment.write_transcription_bytes_exact(&mut buf[offset..]); + } + + fn batch_size(commitment: &Self::Commitment) -> usize { + commitment.batch_size + } + + fn prove_open( + transcript: &mut PcsProverTranscript, + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + point: &[F], + prover_data: &Self::ProverData, + field_cfg: &F::Config, + ) -> Result + where + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + let start = transcript.stream.position() as usize; + match (polys.is_empty(), prover_data) { + (true, None) => {} + (true, Some(_)) => { + return Err(ZipError::InvalidPcsParam( + "Zip+ prover data must be empty for an empty batch".to_string(), + )); + } + (false, None) => { + return Err(ZipError::InvalidPcsParam( + "Zip+ prover data missing for non-empty batch".to_string(), + )); + } + (false, Some(hint)) => { + let _ = ZipPlus::::prove_f::<_, CHECK_FOR_OVERFLOW>( + transcript, ck, polys, point, hint, field_cfg, + )?; + } + } + let end = transcript.stream.position() as usize; + Ok(transcript.stream.get_ref()[start..end].to_vec()) + } + + fn verify_open( + transcript: &mut PcsVerifierTranscript, + vk: &Self::VerifierKey, + commitment: &Self::Commitment, + point: &[F], + lifted_evals: &[DynamicPolynomialF], + opening_proof: &Self::OpeningProof, + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + 
F::Inner: Transcribable, + F::Modulus: Transcribable, + { + if !opening_proof.is_empty() { + let original_stream = + std::mem::replace(&mut transcript.stream, Cursor::new(opening_proof.clone())); + let result = >::verify_open::( + transcript, + vk, + commitment, + point, + lifted_evals, + &Vec::new(), + field_cfg, + ); + let consumed = transcript.stream.position() == opening_proof.len() as u64; + transcript.stream = original_stream; + result?; + if !consumed { + return Err(ZipError::InvalidPcsOpen( + "PCS opening proof has trailing bytes".to_string(), + )); + } + return Ok(()); + } + + if lifted_evals.len() != commitment.batch_size { + return Err(ZipError::InvalidPcsParam(format!( + "Zip+ verifier expected {} lifted evals, got {}", + commitment.batch_size, + lifted_evals.len() + ))); + } + if commitment.batch_size == 0 { + if commitment.root != Default::default() { + return Err(ZipError::InvalidPcsParam( + "Zip+ empty batch must use the canonical empty commitment".to_string(), + )); + } + return Ok(()); + } + + let per_poly_alphas = + ZipPlus::::sample_alphas(&mut transcript.fs_transcript, commitment.batch_size); + let mut eval_f = F::zero_with_cfg(field_cfg); + for (bar_u, alphas) in lifted_evals.iter().zip(per_poly_alphas.iter()) { + for (coeff, alpha) in bar_u.coeffs.iter().zip(alphas.iter()) { + let mut term = F::from_with_cfg(alpha, field_cfg); + term *= coeff; + eval_f += &term; + } + } + + ZipPlus::::verify_with_alphas::( + transcript, + vk, + commitment, + field_cfg, + point, + &eval_f, + &per_poly_alphas, + ) + } +} diff --git a/zip-plus/src/pcs/hyrax.rs b/zip-plus/src/pcs/hyrax.rs new file mode 100644 index 00000000..695c056b --- /dev/null +++ b/zip-plus/src/pcs/hyrax.rs @@ -0,0 +1,3133 @@ +#![allow(clippy::arithmetic_side_effects)] + +use std::{ + collections::HashSet, + fmt::Debug, + io::{Cursor, Read, Write}, + marker::PhantomData, +}; + +use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM}; +use ark_ff::{AdditiveGroup, BigInteger, PrimeField as 
ArkPrimeField, UniformRand, Zero}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress}; +use crypto_bigint::{BoxedUint, modular::BoxedMontyForm}; +use crypto_primitives::{ + FromWithConfig, IntRing, PrimeField, crypto_bigint_boxed_monty::BoxedMontyField, + crypto_bigint_int::Int, crypto_bigint_monty::MontyField, crypto_bigint_uint::Uint, +}; +use num_integer::Integer; +use zinc_poly::{ + mle::DenseMultilinearExtension, + univariate::{ + binary::BinaryPoly, dense::DensePolynomial, dynamic::over_field::DynamicPolynomialF, + }, +}; +use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript}; +use zinc_utils::{cfg_into_iter, cfg_iter, delayed_reduction::DelayedFieldProductSum}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use crate::{ + ZipError, + pcs::{ + generic::{FoldablePCS, PCS}, + msm_commitment::{ + BoolSubsetMsm, MsmCommitmentEngine, MsmCommitmentKey, MsmError, RowMsmStrategy, + ScalarPippengerMsm, SignedIntPippengerMsm, + }, + }, + pcs_transcript::{PcsProverTranscript, PcsVerifierTranscript}, +}; + +#[derive(Clone, Debug)] +pub struct HyraxPCS(PhantomData<(C, Lanes)>); + +impl HyraxPCS +where + C: AffineRepr, +{ + /// Open a folded Hyrax commitment whose lane values are already scalar + /// field elements. + /// + /// This is needed for instance-axis folds of binary commitments: after + /// folding by transcript-derived weights, each bit lane is a scalar field + /// linear combination of bits, not a `bool`. 
+ #[allow(clippy::arithmetic_side_effects)] + pub fn prove_open_scalar_lanes( + transcript: &mut PcsProverTranscript, + ck: &HyraxCommitmentKey, + scalar_lanes: &[Vec>], + point: &[F], + prover_data: &HyraxProverData, + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F: HyraxFieldBridge, + F::Inner: Transcribable, + F::Modulus: Transcribable, + Lanes: Clone + Debug + Send + Sync, + { + let _ = CHECK_FOR_OVERFLOW; + if scalar_lanes.is_empty() { + return Ok(()); + } + validate_scalar_lanes::(ck, scalar_lanes, point.len(), prover_data)?; + + let n = scalar_lanes[0][0].len(); + let point_scalar = point + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + let row_vars = prover_data.num_rows.ilog2() as usize; + let q0_f = eq_tensor_f::(&point[..row_vars], field_cfg); + let q1_scalar = eq_tensor_scalar::(&point_scalar[row_vars..]); + let alphas = sample_scalars::( + &mut transcript.fs_transcript, + scalar_lanes.len() * prover_data.num_lanes, + ); + + let mut b_scalar = vec![C::ScalarField::zero(); prover_data.num_rows]; + for (poly_idx, lanes) in scalar_lanes.iter().enumerate() { + for (lane, values) in lanes.iter().enumerate() { + let alpha = alphas[alpha_index_dynamic(prover_data.num_lanes, poly_idx, lane)]; + for (row_idx, row) in values.chunks(ck.num_cols).enumerate() { + let mut row_eval = C::ScalarField::zero(); + for (col_idx, value) in row.iter().enumerate() { + if let Some(weight) = q1_scalar.get(col_idx) { + row_eval += *value * weight; + } + } + b_scalar[row_idx] += alpha * row_eval; + } + } + } + + let b_f = b_scalar + .iter() + .map(|value| F::scalar_to_field(value, field_cfg)) + .collect::, _>>()?; + transcript.write_field_elements(&b_f)?; + + let row_coeffs = if prover_data.num_rows == 1 { + vec![C::ScalarField::from(1u64)] + } else { + sample_scalars::(&mut transcript.fs_transcript, prover_data.num_rows) + }; + + let mut combined_row = vec![C::ScalarField::zero(); ck.num_cols]; + let mut rho_star = C::ScalarField::zero(); + for 
(poly_idx, lanes) in scalar_lanes.iter().enumerate() { + for (lane, values) in lanes.iter().enumerate() { + let alpha = alphas[alpha_index_dynamic(prover_data.num_lanes, poly_idx, lane)]; + for (row_idx, row) in values.chunks(ck.num_cols).enumerate() { + let coeff = alpha * row_coeffs[row_idx]; + if ck.blinding_mode.is_blinded() { + let blind_idx = commitment_index_dynamic( + prover_data.num_lanes, + poly_idx, + lane, + row_idx, + prover_data.num_rows, + ); + rho_star += coeff * prover_data.blinds[blind_idx]; + } + for (col_idx, value) in row.iter().enumerate() { + combined_row[col_idx] += coeff * value; + } + } + } + } + + write_scalars::(transcript, &combined_row)?; + if ck.blinding_mode.is_blinded() { + write_scalar::(transcript, &rho_star)?; + } + + if q0_f.len() != b_f.len() || n != (1usize << point.len()) { + return Err(ZipError::InvalidPcsOpen( + "Hyrax folded scalar-lane opening shape mismatch".to_string(), + )); + } + + Ok(()) + } + + /// Open a folded single-row Hyrax commitment from protocol-field lanes. + /// + /// ProjectionFold folds binary witnesses by protocol-field challenges, so + /// folded bit lanes are already field elements rather than booleans. The + /// generic scalar-lane path first converts every lane entry into the curve + /// scalar field and then scans the matrix twice. For the SHA benchmark the + /// Hyrax width is the whole row domain (`num_rows == 1`), so we can compute + /// the transcript's combined row directly in the protocol field, derive the + /// single `b` value from it, and convert only the final row entries that are + /// written to the proof stream. 
+ #[allow(clippy::arithmetic_side_effects)] + pub fn prove_open_field_lanes_single_row( + transcript: &mut PcsProverTranscript, + ck: &HyraxCommitmentKey, + field_lanes: &[Vec<&[F]>], + point: &[F], + prover_data: &HyraxProverData, + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F: HyraxFieldBridge + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, + Lanes: Clone + Debug + Send + Sync, + { + let _ = CHECK_FOR_OVERFLOW; + if field_lanes.is_empty() { + return Ok(()); + } + validate_field_lanes::(ck, field_lanes, point.len(), prover_data)?; + if prover_data.num_rows != 1 { + return Err(ZipError::InvalidPcsParam( + "Hyrax field-lane fast opening requires a single row".to_string(), + )); + } + + let q1 = eq_tensor_f::(point, field_cfg); + let alphas = sample_scalars::( + &mut transcript.fs_transcript, + field_lanes.len() * prover_data.num_lanes, + ); + let alpha_fields = alphas + .iter() + .map(|alpha| F::scalar_to_field(alpha, field_cfg)) + .collect::, _>>()?; + + let mut combined_row = vec![F::zero_with_cfg(field_cfg); ck.num_cols]; + let mut rho_star = C::ScalarField::zero(); + for (poly_idx, lanes) in field_lanes.iter().enumerate() { + for (lane, values) in lanes.iter().enumerate() { + let alpha_idx = alpha_index_dynamic(prover_data.num_lanes, poly_idx, lane); + let alpha = &alpha_fields[alpha_idx]; + for (acc, value) in combined_row.iter_mut().zip(values.iter()) { + *acc += value.clone() * alpha.clone(); + } + if ck.blinding_mode.is_blinded() { + let blind_idx = commitment_index_dynamic( + prover_data.num_lanes, + poly_idx, + lane, + 0, + prover_data.num_rows, + ); + rho_star += alphas[alpha_idx] * prover_data.blinds[blind_idx]; + } + } + } + + let b = F::delayed_sum_of_products(&combined_row, &q1, F::zero_with_cfg(field_cfg)); + transcript.write_field_elements(&[b])?; + + let combined_scalars = combined_row + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + write_scalars::(transcript, &combined_scalars)?; + 
if ck.blinding_mode.is_blinded() { + write_scalar::(transcript, &rho_star)?; + } + + Ok(()) + } + + /// Open two folded Hyrax commitments that share the same row bases as one + /// mixed single-row proof. + #[allow(clippy::arithmetic_side_effects)] + #[allow(clippy::too_many_arguments)] + pub fn prove_open_two_field_lane_groups_single_row( + transcript: &mut PcsProverTranscript, + ck_a: &HyraxCommitmentKey, + field_lanes_a: &[Vec<&[F]>], + prover_data_a: &HyraxProverData, + ck_b: &HyraxCommitmentKey, + field_lanes_b: &[Vec<&[F]>], + prover_data_b: &HyraxProverData, + point: &[F], + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F: HyraxFieldBridge + DelayedFieldProductSum, + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + let _ = CHECK_FOR_OVERFLOW; + if field_lanes_a.is_empty() || field_lanes_b.is_empty() { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed field-lane opening expects two non-empty groups".to_string(), + )); + } + validate_field_lanes::(ck_a, field_lanes_a, point.len(), prover_data_a)?; + validate_field_lanes::(ck_b, field_lanes_b, point.len(), prover_data_b)?; + validate_shared_commitment_keys(ck_a, ck_b)?; + if prover_data_a.num_rows != 1 || prover_data_b.num_rows != 1 { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed field-lane opening requires a single row".to_string(), + )); + } + + let q1 = eq_tensor_f::(point, field_cfg); + let alpha_count_a = field_lanes_a.len() * prover_data_a.num_lanes; + let alpha_count_b = field_lanes_b.len() * prover_data_b.num_lanes; + let alphas = + sample_scalars::(&mut transcript.fs_transcript, alpha_count_a + alpha_count_b); + let alpha_fields = alphas + .iter() + .map(|alpha| F::scalar_to_field(alpha, field_cfg)) + .collect::, _>>()?; + + let mut combined_row = vec![F::zero_with_cfg(field_cfg); ck_a.num_cols]; + let mut rho_star = C::ScalarField::zero(); + for (poly_idx, lanes) in field_lanes_a.iter().enumerate() { + for (lane, values) in lanes.iter().enumerate() { + let 
alpha_idx = alpha_index_dynamic(prover_data_a.num_lanes, poly_idx, lane); + let alpha = &alpha_fields[alpha_idx]; + for (acc, value) in combined_row.iter_mut().zip(values.iter()) { + *acc += value.clone() * alpha.clone(); + } + if ck_a.blinding_mode.is_blinded() { + let blind_idx = commitment_index_dynamic( + prover_data_a.num_lanes, + poly_idx, + lane, + 0, + prover_data_a.num_rows, + ); + rho_star += alphas[alpha_idx] * prover_data_a.blinds[blind_idx]; + } + } + } + for (poly_idx, lanes) in field_lanes_b.iter().enumerate() { + for (lane, values) in lanes.iter().enumerate() { + let local_alpha_idx = alpha_index_dynamic(prover_data_b.num_lanes, poly_idx, lane); + let alpha_idx = alpha_count_a + local_alpha_idx; + let alpha = &alpha_fields[alpha_idx]; + for (acc, value) in combined_row.iter_mut().zip(values.iter()) { + *acc += value.clone() * alpha.clone(); + } + if ck_b.blinding_mode.is_blinded() { + let blind_idx = commitment_index_dynamic( + prover_data_b.num_lanes, + poly_idx, + lane, + 0, + prover_data_b.num_rows, + ); + rho_star += alphas[alpha_idx] * prover_data_b.blinds[blind_idx]; + } + } + } + + let b = F::delayed_sum_of_products(&combined_row, &q1, F::zero_with_cfg(field_cfg)); + transcript.write_field_elements(&[b])?; + + let combined_scalars = combined_row + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + write_scalars::(transcript, &combined_scalars)?; + if ck_a.blinding_mode.is_blinded() { + write_scalar::(transcript, &rho_star)?; + } + + Ok(()) + } + + #[allow(clippy::arithmetic_side_effects)] + #[allow(clippy::too_many_arguments)] + pub fn verify_open_two_field_lane_groups_single_row< + F, + EvalA, + LanesA, + EvalB, + LanesB, + const CHECK_FOR_OVERFLOW: bool, + const D: usize, + >( + transcript: &mut PcsVerifierTranscript, + vk_a: &HyraxVerifierKey, + commitment_a: &HyraxCommitment, + lifted_evals_a: &[DynamicPolynomialF], + vk_b: &HyraxVerifierKey, + commitment_b: &HyraxCommitment, + lifted_evals_b: &[DynamicPolynomialF], + point: &[F], 
+ opening_proof: &[u8], + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F: HyraxFieldBridge, + F::Inner: Transcribable, + F::Modulus: Transcribable, + EvalA: Clone + Debug + Send + Sync, + EvalB: Clone + Debug + Send + Sync, + LanesA: HyraxLanes, + LanesB: HyraxLanes, + { + let _ = CHECK_FOR_OVERFLOW; + let original_stream = + std::mem::replace(&mut transcript.stream, Cursor::new(opening_proof.to_vec())); + let result = (|| { + if commitment_a.blinding_mode != vk_a.blinding_mode + || commitment_b.blinding_mode != vk_b.blinding_mode + { + return Err(ZipError::InvalidPcsParam( + "Hyrax commitment blinding mode mismatch".to_string(), + )); + } + validate_commitment_shape::(commitment_a)?; + validate_commitment_shape::(commitment_b)?; + validate_shared_verifier_keys(vk_a, vk_b)?; + if lifted_evals_a.len() != commitment_a.batch_size { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax verifier expected {} left lifted evals, got {}", + commitment_a.batch_size, + lifted_evals_a.len() + ))); + } + if lifted_evals_b.len() != commitment_b.batch_size { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax verifier expected {} right lifted evals, got {}", + commitment_b.batch_size, + lifted_evals_b.len() + ))); + } + if commitment_a.batch_size == 0 || commitment_b.batch_size == 0 { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed opening expects two non-empty commitment groups".to_string(), + )); + } + + let n = 1usize << point.len(); + let expected_rows = num_rows(n, vk_a.num_cols)?; + if expected_rows != 1 || commitment_a.num_rows != 1 || commitment_b.num_rows != 1 { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed opening verifier requires a single row".to_string(), + )); + } + + let point_scalar = point + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + let q1_scalar = eq_tensor_scalar::(&point_scalar); + let alpha_count_a = commitment_a.batch_size * commitment_a.num_lanes; + let alpha_count_b = commitment_b.batch_size * 
commitment_b.num_lanes; + let alphas = + sample_scalars::(&mut transcript.fs_transcript, alpha_count_a + alpha_count_b); + + let b_f = transcript.read_field_elements::(1)?; + let mut expected_eval = F::zero_with_cfg(field_cfg); + for (poly_idx, lifted_eval) in lifted_evals_a.iter().enumerate() { + for lane in 0..commitment_a.num_lanes { + let alpha_idx = alpha_index_dynamic(commitment_a.num_lanes, poly_idx, lane); + let alpha = F::scalar_to_field(&alphas[alpha_idx], field_cfg)?; + let mut term = LanesA::lifted_eval::(lifted_eval, lane, field_cfg)?; + term *= α + expected_eval += &term; + } + } + for (poly_idx, lifted_eval) in lifted_evals_b.iter().enumerate() { + for lane in 0..commitment_b.num_lanes { + let local_alpha_idx = + alpha_index_dynamic(commitment_b.num_lanes, poly_idx, lane); + let alpha_idx = alpha_count_a + local_alpha_idx; + let alpha = F::scalar_to_field(&alphas[alpha_idx], field_cfg)?; + let mut term = LanesB::lifted_eval::(lifted_eval, lane, field_cfg)?; + term *= α + expected_eval += &term; + } + } + if b_f[0] != expected_eval { + return Err(ZipError::InvalidPcsOpen( + "Hyrax mixed evaluation consistency failure".to_string(), + )); + } + + let b_scalar = F::field_to_scalar(&b_f[0])?; + let combined_row = read_scalars::(transcript, vk_a.num_cols)?; + let rho_star = if vk_a.blinding_mode.is_blinded() { + Some(read_scalar::(transcript)?) 
+ } else { + None + }; + + let mut lhs = C::ScalarField::zero(); + for (value, weight) in combined_row.iter().zip(q1_scalar.iter()) { + lhs += *value * weight; + } + if lhs != b_scalar { + return Err(ZipError::InvalidPcsOpen( + "Hyrax mixed row coherence failure".to_string(), + )); + } + + let mut comm_bases = + Vec::with_capacity(commitment_a.comm_affine.len() + commitment_b.comm_affine.len()); + comm_bases.extend_from_slice(&commitment_a.comm_affine); + comm_bases.extend_from_slice(&commitment_b.comm_affine); + let comm_lc = msm_unchecked::(&comm_bases, &alphas)?; + + let mut expected = + msm_unchecked::(&vk_a.bases[..combined_row.len()], &combined_row)?; + if let Some(rho_star) = rho_star { + expected += vk_a.h * rho_star; + } + + if comm_lc != expected { + return Err(ZipError::InvalidPcsOpen( + "Hyrax mixed commitment opening failure".to_string(), + )); + } + + Ok(()) + })(); + let consumed = transcript.stream.position() == opening_proof.len() as u64; + transcript.stream = original_stream; + result?; + if !consumed { + return Err(ZipError::InvalidPcsOpen( + "PCS mixed opening proof has trailing bytes".to_string(), + )); + } + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum HyraxBlindingMode { + Blinded, + Unblinded, +} + +impl Default for HyraxBlindingMode { + fn default() -> Self { + Self::Unblinded + } +} + +impl HyraxBlindingMode { + fn as_u8(self) -> u8 { + match self { + Self::Blinded => 1, + Self::Unblinded => 0, + } + } + + fn is_blinded(self) -> bool { + matches!(self, Self::Blinded) + } +} + +#[derive(Clone, Debug)] +pub struct HyraxCommitmentKey { + pub(crate) num_cols: usize, + pub(crate) blinding_mode: HyraxBlindingMode, + pub(crate) msm_ck: MsmCommitmentKey, +} + +#[derive(Clone, Debug)] +pub struct HyraxVerifierKey { + pub(crate) num_cols: usize, + pub(crate) bases: Vec, + pub(crate) h: C::Group, + pub(crate) blinding_mode: HyraxBlindingMode, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct HyraxCommitment { + 
pub(crate) batch_size: usize, + pub(crate) num_lanes: usize, + pub(crate) num_rows: usize, + pub(crate) blinding_mode: HyraxBlindingMode, + pub(crate) comm: Vec, + pub(crate) comm_affine: Vec, + pub(crate) comm_bytes: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct HyraxProverData { + pub(crate) batch_size: usize, + pub(crate) num_lanes: usize, + pub(crate) num_rows: usize, + pub(crate) blinding_mode: HyraxBlindingMode, + pub(crate) blinds: Vec, +} + +pub trait HyraxFieldBridge: PrimeField { + fn field_to_scalar(value: &Self) -> Result; + fn scalar_to_field(value: &C::ScalarField, cfg: &Self::Config) -> Result; +} + +impl HyraxFieldBridge for MontyField +where + C: AffineRepr, +{ + fn field_to_scalar(value: &Self) -> Result { + validate_curve_scalar_modulus::(&value.modulus())?; + + let canonical = value.retrieve(); + let mut bytes = vec![0u8; as ConstTranscribable>::NUM_BYTES]; + canonical.write_transcription_bytes_exact(&mut bytes); + Ok(C::ScalarField::from_le_bytes_mod_order(&bytes)) + } + + fn scalar_to_field(value: &C::ScalarField, cfg: &Self::Config) -> Result { + let actual_modulus = Uint::::new(cfg.modulus().get()); + validate_curve_scalar_modulus::(&actual_modulus)?; + + let scalar_bigint: ::BigInt = value.clone().into(); + let scalar_uint = uint_from_le_bytes::(&scalar_bigint.to_bytes_le()); + Ok(MontyField::::from_with_cfg(&scalar_uint, cfg)) + } +} + +impl HyraxFieldBridge for BoxedMontyField +where + C: AffineRepr, +{ + fn field_to_scalar(value: &Self) -> Result { + validate_curve_scalar_modulus_boxed::(&value.modulus())?; + + let canonical = BoxedMontyForm::from(value.clone()).retrieve(); + Ok(C::ScalarField::from_le_bytes_mod_order( + &canonical.to_le_bytes(), + )) + } + + fn scalar_to_field(value: &C::ScalarField, cfg: &Self::Config) -> Result { + let actual_modulus = cfg.modulus().clone().get(); + validate_curve_scalar_modulus_boxed::(&actual_modulus)?; + + let scalar_bigint: ::BigInt = value.clone().into(); + let scalar_uint = 
BoxedUint::from_le_slice( + &scalar_bigint.to_bytes_le(), + actual_modulus.bits_precision(), + ) + .expect("curve scalar must fit protocol field precision"); + Ok(BoxedMontyField::from_with_cfg(&scalar_uint, cfg)) + } +} + +fn validate_fold_inputs(values: &[T], theta_len: usize, label: &str) -> Result<(), ZipError> { + if values.is_empty() { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax cannot fold empty {label}" + ))); + } + if values.len() != theta_len { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax fold {label} count mismatch: got {}, expected {theta_len}", + values.len() + ))); + } + Ok(()) +} + +pub trait HyraxLanes: Clone + Debug + Send + Sync +where + C: AffineRepr, + Eval: Clone + Debug + Send + Sync, +{ + type LaneValue: Copy + Send + Sync; + type Strategy: RowMsmStrategy; + + const NUM_LANES: usize; + + fn lane_value(eval: &Eval, lane: usize) -> Result; + + fn lane_to_scalar(value: Self::LaneValue) -> C::ScalarField; + + fn commit_poly( + _ck: &HyraxCommitmentKey, + _poly: &DenseMultilinearExtension, + _num_rows: usize, + ) -> Option, Vec), ZipError>> { + None + } + + fn accumulate_b( + row: &[Eval], + lane: usize, + q1_scalar: &[C::ScalarField], + ) -> Result { + let mut row_eval = C::ScalarField::zero(); + for (col_idx, eval) in row.iter().enumerate() { + let value = Self::lane_to_scalar(Self::lane_value(eval, lane)?); + if let Some(weight) = q1_scalar.get(col_idx) { + row_eval += value * weight; + } + } + Ok(row_eval) + } + + fn accumulate_combined_row( + row: &[Eval], + lane: usize, + coeff: C::ScalarField, + combined_row: &mut [C::ScalarField], + ) -> Result<(), ZipError> { + for (col_idx, eval) in row.iter().enumerate() { + let value = Self::lane_to_scalar(Self::lane_value(eval, lane)?); + combined_row[col_idx] += coeff * value; + } + Ok(()) + } + + fn accumulate_single_row_opening( + row: &[Eval], + lane: usize, + alpha: C::ScalarField, + q1_scalar: &[C::ScalarField], + b_scalar: &mut C::ScalarField, + combined_row: &mut 
[C::ScalarField], + ) -> Result<(), ZipError> { + for (col_idx, eval) in row.iter().enumerate() { + let value = Self::lane_to_scalar(Self::lane_value(eval, lane)?); + let scaled = alpha * value; + if let Some(weight) = q1_scalar.get(col_idx) { + *b_scalar += scaled * weight; + } + combined_row[col_idx] += scaled; + } + Ok(()) + } + + fn lifted_eval( + lifted_eval: &DynamicPolynomialF, + lane: usize, + field_cfg: &F::Config, + ) -> Result + where + F: PrimeField; +} + +#[derive(Clone, Debug)] +pub struct BinaryLanes; + +#[derive(Clone, Debug)] +pub struct IntScalarLane; + +#[derive(Clone, Debug)] +pub struct DensePolyScalarLanes; + +impl HyraxLanes, D> for BinaryLanes { + type LaneValue = bool; + type Strategy = BoolSubsetMsm<6>; + + const NUM_LANES: usize = D; + + fn lane_value(eval: &BinaryPoly, lane: usize) -> Result { + if lane >= D { + return Err(ZipError::InvalidPcsParam(format!( + "binary lane {lane} out of range" + ))); + } + Ok(eval.coeff(lane)) + } + + fn lane_to_scalar(value: Self::LaneValue) -> C::ScalarField { + if value { + C::ScalarField::from(1u64) + } else { + C::ScalarField::zero() + } + } + + fn commit_poly( + ck: &HyraxCommitmentKey, + poly: &DenseMultilinearExtension>, + num_rows: usize, + ) -> Option, Vec), ZipError>> { + let expected_comm = , D>>::NUM_LANES * num_rows; + , D>>::Strategy::precompute_ck(&ck.msm_ck); + let blinds = if ck.blinding_mode.is_blinded() { + random_scalars::(expected_comm) + } else { + Vec::new() + }; + + Some((|| { + let use_inner_parallelism = use_inner_bool_parallelism(expected_comm); + let per_row = cfg_into_iter!(0..num_rows) + .map(|row_idx| { + let lower = row_idx * ck.num_cols; + let upper = (lower + ck.num_cols).min(poly.evaluations.len()); + let row_len = upper - lower; + let mut row_comms = + BoolSubsetMsm::<6>::msm_bool_rows_from_window_masks::( + &ck.msm_ck, + row_len, + use_inner_parallelism, + |offset, len| { + let mut masks = [0usize; D]; + for bit_idx in 0..len { + let eval = &poly.evaluations[lower + 
offset + bit_idx]; + for (lane, mask) in masks.iter_mut().enumerate() { + if eval.coeff(lane) { + *mask |= 1usize << bit_idx; + } + } + } + masks + }, + ) + .map_err(msm_err)?; + + if ck.blinding_mode.is_blinded() { + for (lane, row_comm) in row_comms.iter_mut().enumerate() { + let blind_idx = lane * num_rows + row_idx; + *row_comm += ck.msm_ck.h * blinds[blind_idx]; + } + } + Ok::<[C::Group; D], ZipError>(row_comms) + }) + .collect::, _>>()?; + + let mut comm = Vec::with_capacity(expected_comm); + for lane in 0..D { + for row_comms in &per_row { + comm.push(row_comms[lane]); + } + } + Ok((comm, blinds)) + })()) + } + + fn accumulate_b( + row: &[BinaryPoly], + lane: usize, + q1_scalar: &[C::ScalarField], + ) -> Result { + let mut row_eval = C::ScalarField::zero(); + for (col_idx, eval) in row.iter().enumerate() { + if , D>>::lane_value(eval, lane)? { + if let Some(weight) = q1_scalar.get(col_idx) { + row_eval += weight; + } + } + } + Ok(row_eval) + } + + fn accumulate_combined_row( + row: &[BinaryPoly], + lane: usize, + coeff: C::ScalarField, + combined_row: &mut [C::ScalarField], + ) -> Result<(), ZipError> { + for (col_idx, eval) in row.iter().enumerate() { + if , D>>::lane_value(eval, lane)? { + combined_row[col_idx] += coeff; + } + } + Ok(()) + } + + fn accumulate_single_row_opening( + row: &[BinaryPoly], + lane: usize, + alpha: C::ScalarField, + q1_scalar: &[C::ScalarField], + b_scalar: &mut C::ScalarField, + combined_row: &mut [C::ScalarField], + ) -> Result<(), ZipError> { + let mut row_eval = C::ScalarField::zero(); + for (col_idx, eval) in row.iter().enumerate() { + if , D>>::lane_value(eval, lane)? 
{ + if let Some(weight) = q1_scalar.get(col_idx) { + row_eval += weight; + } + combined_row[col_idx] += alpha; + } + } + *b_scalar += alpha * row_eval; + Ok(()) + } + + fn lifted_eval( + lifted_eval: &DynamicPolynomialF, + lane: usize, + field_cfg: &F::Config, + ) -> Result + where + F: PrimeField, + { + Ok(lifted_eval + .coeffs + .get(lane) + .cloned() + .unwrap_or_else(|| F::zero_with_cfg(field_cfg))) + } +} + +impl HyraxLanes, D> + for IntScalarLane +{ + type LaneValue = Int; + type Strategy = SignedIntPippengerMsm; + + const NUM_LANES: usize = 1; + + fn lane_value(eval: &Int, lane: usize) -> Result { + if lane != 0 { + return Err(ZipError::InvalidPcsParam(format!( + "int lane {lane} out of range" + ))); + } + Ok(*eval) + } + + fn lane_to_scalar(value: Self::LaneValue) -> C::ScalarField { + int_to_scalar::(&value).expect("int lane value must convert to scalar") + } + + fn lifted_eval( + lifted_eval: &DynamicPolynomialF, + lane: usize, + field_cfg: &F::Config, + ) -> Result + where + F: PrimeField, + { + if lane != 0 { + return Err(ZipError::InvalidPcsParam(format!( + "lifted int lane {lane} out of range" + ))); + } + Ok(lifted_eval + .coeffs + .first() + .cloned() + .unwrap_or_else(|| F::zero_with_cfg(field_cfg))) + } +} + +impl + HyraxLanes, D>, D> for DensePolyScalarLanes +{ + type LaneValue = C::ScalarField; + type Strategy = ScalarPippengerMsm; + + const NUM_LANES: usize = D; + + fn lane_value( + eval: &DensePolynomial, D>, + lane: usize, + ) -> Result { + eval.coeffs + .get(lane) + .ok_or_else(|| ZipError::InvalidPcsParam(format!("dense lane {lane} out of range"))) + .and_then(int_to_scalar::) + } + + fn lane_to_scalar(value: Self::LaneValue) -> C::ScalarField { + value + } + + fn lifted_eval( + lifted_eval: &DynamicPolynomialF, + lane: usize, + field_cfg: &F::Config, + ) -> Result + where + F: PrimeField, + { + Ok(lifted_eval + .coeffs + .get(lane) + .cloned() + .unwrap_or_else(|| F::zero_with_cfg(field_cfg))) + } +} + +impl HyraxPCS { + pub fn setup( + 
width: usize, + domain: impl AsRef<[u8]>, + blinding_mode: HyraxBlindingMode, + ) -> Result<(HyraxCommitmentKey, HyraxVerifierKey), ZipError> { + let domain = domain.as_ref(); + let bases = (0..width) + .map(|idx| hash_to_curve::(domain, b"basis", idx)) + .collect::, _>>()?; + let h = hash_to_curve::(domain, b"blinding", 0)?.into_group(); + Self::setup_from_trusted_bases(width, bases, h, blinding_mode) + } + + pub fn setup_from_bases( + width: usize, + bases: Vec, + h: C::Group, + ) -> Result<(HyraxCommitmentKey, HyraxVerifierKey), ZipError> { + Self::setup_from_bases_with_blinding(width, bases, h, HyraxBlindingMode::Blinded) + } + + pub fn setup_from_bases_with_blinding( + width: usize, + bases: Vec, + h: C::Group, + blinding_mode: HyraxBlindingMode, + ) -> Result<(HyraxCommitmentKey, HyraxVerifierKey), ZipError> { + Self::setup_from_trusted_bases(width, bases, h, blinding_mode) + } + + pub fn setup_from_trusted_bases( + width: usize, + bases: Vec, + h: C::Group, + blinding_mode: HyraxBlindingMode, + ) -> Result<(HyraxCommitmentKey, HyraxVerifierKey), ZipError> { + validate_trusted_bases(width, &bases, &h)?; + let msm_ck = msm_key(width, bases.clone(), h)?; + Ok(( + HyraxCommitmentKey { + num_cols: width, + blinding_mode, + msm_ck, + }, + HyraxVerifierKey { + num_cols: width, + bases, + h, + blinding_mode, + }, + )) + } +} + +impl PCS for HyraxPCS +where + F: HyraxFieldBridge, + C: AffineRepr, + Eval: Clone + Debug + Send + Sync, + Lanes: HyraxLanes, +{ + type CommitmentKey = HyraxCommitmentKey; + type VerifierKey = HyraxVerifierKey; + type Commitment = HyraxCommitment; + type ProverData = HyraxProverData; + type OpeningProof = Vec; + + fn precompute_ck(ck: &Self::CommitmentKey) { + Lanes::Strategy::precompute_ck(&ck.msm_ck); + } + + fn commit( + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + ) -> Result<(Self::ProverData, Self::Commitment), ZipError> { + if polys.is_empty() { + return Ok(( + HyraxProverData { + batch_size: 0, + num_lanes: 
Lanes::NUM_LANES, + num_rows: 0, + blinding_mode: ck.blinding_mode, + blinds: Vec::new(), + }, + HyraxCommitment { + batch_size: 0, + num_lanes: Lanes::NUM_LANES, + num_rows: 0, + blinding_mode: ck.blinding_mode, + comm: Vec::new(), + comm_affine: Vec::new(), + comm_bytes: Vec::new(), + }, + )); + } + + validate_polys(polys)?; + let n = polys[0].evaluations.len(); + let num_rows = num_rows(n, ck.num_cols)?; + Lanes::Strategy::precompute_ck(&ck.msm_ck); + + let per_poly = cfg_iter!(polys) + .map(|poly| commit_hyrax_poly::(ck, poly, num_rows)) + .collect::, _>>()?; + + let expected_comm = polys.len() * Lanes::NUM_LANES * num_rows; + let expected_blinds = if ck.blinding_mode.is_blinded() { + expected_comm + } else { + 0 + }; + let mut all_comm = Vec::with_capacity(expected_comm); + let mut all_blinds = Vec::with_capacity(expected_blinds); + for (comm, blinds) in per_poly { + all_comm.extend(comm); + all_blinds.extend(blinds); + } + + let all_affine = C::Group::normalize_batch(&all_comm); + let all_bytes = affine_points_bytes::(&all_affine)?; + + Ok(( + HyraxProverData { + batch_size: polys.len(), + num_lanes: Lanes::NUM_LANES, + num_rows, + blinding_mode: ck.blinding_mode, + blinds: all_blinds, + }, + HyraxCommitment { + batch_size: polys.len(), + num_lanes: Lanes::NUM_LANES, + num_rows, + blinding_mode: ck.blinding_mode, + comm: all_comm, + comm_affine: all_affine, + comm_bytes: all_bytes, + }, + )) + } + + fn absorb_commitment(transcript: &mut T, commitment: &Self::Commitment) { + transcript.absorb_slice(b"hyrax_commitment_begin"); + transcript.absorb_slice(&(commitment.batch_size as u64).to_le_bytes()); + transcript.absorb_slice(&(commitment.num_lanes as u64).to_le_bytes()); + transcript.absorb_slice(&(commitment.num_rows as u64).to_le_bytes()); + transcript.absorb_slice(&[commitment.blinding_mode.as_u8()]); + transcript.absorb_slice(&commitment.comm_bytes); + transcript.absorb_slice(b"hyrax_commitment_end"); + } + + fn commitment_num_bytes(commitment: 
&Self::Commitment) -> usize { + let group_size = C::zero().serialized_size(Compress::Yes); + 3 * core::mem::size_of::() + 1 + commitment.comm.len() * group_size + } + + fn write_commitment_bytes(commitment: &Self::Commitment, buf: &mut Vec) { + buf.extend_from_slice(&(commitment.batch_size as u64).to_le_bytes()); + buf.extend_from_slice(&(commitment.num_lanes as u64).to_le_bytes()); + buf.extend_from_slice(&(commitment.num_rows as u64).to_le_bytes()); + buf.push(commitment.blinding_mode.as_u8()); + buf.extend_from_slice(&commitment.comm_bytes); + } + + fn batch_size(commitment: &Self::Commitment) -> usize { + commitment.batch_size + } + + fn prove_open( + transcript: &mut PcsProverTranscript, + ck: &Self::CommitmentKey, + polys: &[DenseMultilinearExtension], + point: &[F], + prover_data: &Self::ProverData, + field_cfg: &F::Config, + ) -> Result + where + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + let _ = CHECK_FOR_OVERFLOW; + let start = transcript.stream.position() as usize; + if polys.is_empty() { + if prover_data.batch_size != 0 + || prover_data.num_lanes != Lanes::NUM_LANES + || prover_data.num_rows != 0 + || prover_data.blinding_mode != ck.blinding_mode + || !prover_data.blinds.is_empty() + { + return Err(ZipError::InvalidPcsParam( + "Hyrax prover data must be canonical for an empty batch".to_string(), + )); + } + let end = transcript.stream.position() as usize; + return Ok(transcript.stream.get_ref()[start..end].to_vec()); + } + validate_polys(polys)?; + validate_hyrax_shape::( + ck.num_cols, + ck.blinding_mode, + polys, + prover_data, + )?; + + let n = polys[0].evaluations.len(); + if n != (1usize << point.len()) { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax open expected point for {n} evals, got {} variables", + point.len() + ))); + } + + let point_scalar = point + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + let row_vars = prover_data.num_rows.ilog2() as usize; + let q0_f = eq_tensor_f::(&point[..row_vars], 
field_cfg); + let q1_scalar = eq_tensor_scalar::(&point_scalar[row_vars..]); + let alphas = sample_scalars::( + &mut transcript.fs_transcript, + polys.len() * Lanes::NUM_LANES, + ); + + let mut combined_row = vec![C::ScalarField::zero(); ck.num_cols]; + let mut rho_star = C::ScalarField::zero(); + + let mut b_scalar = vec![C::ScalarField::zero(); prover_data.num_rows]; + if prover_data.num_rows == 1 { + for (poly_idx, poly) in polys.iter().enumerate() { + for lane in 0..Lanes::NUM_LANES { + let alpha = alphas[alpha_index_dynamic(Lanes::NUM_LANES, poly_idx, lane)]; + if ck.blinding_mode.is_blinded() { + let blind_idx = commitment_index_dynamic( + Lanes::NUM_LANES, + poly_idx, + lane, + 0, + prover_data.num_rows, + ); + rho_star += alpha * prover_data.blinds[blind_idx]; + } + Lanes::accumulate_single_row_opening( + &poly.evaluations, + lane, + alpha, + &q1_scalar, + &mut b_scalar[0], + &mut combined_row, + )?; + } + } + } else { + b_scalar = cfg_into_iter!(0..prover_data.num_rows) + .map(|row_idx| { + let lower = row_idx * ck.num_cols; + let mut acc = C::ScalarField::zero(); + for (poly_idx, poly) in polys.iter().enumerate() { + let upper = (lower + ck.num_cols).min(poly.evaluations.len()); + let row = &poly.evaluations[lower..upper]; + for lane in 0..Lanes::NUM_LANES { + let alpha = + alphas[alpha_index_dynamic(Lanes::NUM_LANES, poly_idx, lane)]; + let row_eval = Lanes::accumulate_b(row, lane, &q1_scalar)?; + acc += alpha * row_eval; + } + } + Ok::(acc) + }) + .collect::, _>>()?; + + let b_f = b_scalar + .iter() + .map(|value| F::scalar_to_field(value, field_cfg)) + .collect::, _>>()?; + transcript.write_field_elements(&b_f)?; + + let row_coeffs = + sample_scalars::(&mut transcript.fs_transcript, prover_data.num_rows); + + combined_row = cfg_into_iter!(0..ck.num_cols) + .map(|col_idx| { + let mut acc = C::ScalarField::zero(); + for (poly_idx, poly) in polys.iter().enumerate() { + for lane in 0..Lanes::NUM_LANES { + let alpha = + 
alphas[alpha_index_dynamic(Lanes::NUM_LANES, poly_idx, lane)]; + for (row_idx, row_coeff) in row_coeffs.iter().copied().enumerate() { + let eval_idx = row_idx * ck.num_cols + col_idx; + if let Some(eval) = poly.evaluations.get(eval_idx) { + let value = + Lanes::lane_to_scalar(Lanes::lane_value(eval, lane)?); + acc += alpha * row_coeff * value; + } + } + } + } + Ok::(acc) + }) + .collect::, _>>()?; + + if ck.blinding_mode.is_blinded() { + let total_jobs = polys.len() * Lanes::NUM_LANES * prover_data.num_rows; + let rho_terms = cfg_into_iter!(0..total_jobs) + .map(|job_idx| { + let poly_stride = Lanes::NUM_LANES * prover_data.num_rows; + let poly_idx = job_idx / poly_stride; + let lane_row_idx = job_idx % poly_stride; + let lane = lane_row_idx / prover_data.num_rows; + let row_idx = lane_row_idx % prover_data.num_rows; + let alpha = alphas[alpha_index_dynamic(Lanes::NUM_LANES, poly_idx, lane)]; + alpha * row_coeffs[row_idx] * prover_data.blinds[job_idx] + }) + .collect::>(); + rho_star = rho_terms + .into_iter() + .fold(C::ScalarField::zero(), |acc, term| acc + term); + } + } + + if prover_data.num_rows == 1 { + let b_f = b_scalar + .iter() + .map(|value| F::scalar_to_field(value, field_cfg)) + .collect::, _>>()?; + transcript.write_field_elements(&b_f)?; + } + + write_scalars::(transcript, &combined_row)?; + if ck.blinding_mode.is_blinded() { + write_scalar::(transcript, &rho_star)?; + } + + if q0_f.len() != b_scalar.len() { + return Err(ZipError::InvalidPcsOpen( + "Hyrax b vector shape mismatch".to_string(), + )); + } + + let end = transcript.stream.position() as usize; + Ok(transcript.stream.get_ref()[start..end].to_vec()) + } + + fn verify_open( + transcript: &mut PcsVerifierTranscript, + vk: &Self::VerifierKey, + commitment: &Self::Commitment, + point: &[F], + lifted_evals: &[DynamicPolynomialF], + opening_proof: &Self::OpeningProof, + field_cfg: &F::Config, + ) -> Result<(), ZipError> + where + F::Inner: Transcribable, + F::Modulus: Transcribable, + { + let _ = 
CHECK_FOR_OVERFLOW; + if !opening_proof.is_empty() { + let original_stream = + std::mem::replace(&mut transcript.stream, Cursor::new(opening_proof.clone())); + let result = >::verify_open::( + transcript, + vk, + commitment, + point, + lifted_evals, + &Vec::new(), + field_cfg, + ); + let consumed = transcript.stream.position() == opening_proof.len() as u64; + transcript.stream = original_stream; + result?; + if !consumed { + return Err(ZipError::InvalidPcsOpen( + "PCS opening proof has trailing bytes".to_string(), + )); + } + return Ok(()); + } + + if commitment.blinding_mode != vk.blinding_mode { + return Err(ZipError::InvalidPcsParam( + "Hyrax commitment blinding mode mismatch".to_string(), + )); + } + validate_commitment_shape::(commitment)?; + if lifted_evals.len() != commitment.batch_size { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax verifier expected {} lifted evals, got {}", + commitment.batch_size, + lifted_evals.len() + ))); + } + if commitment.batch_size == 0 { + if commitment.num_rows != 0 { + return Err(ZipError::InvalidPcsParam( + "Hyrax empty batch must use the canonical empty commitment".to_string(), + )); + } + return Ok(()); + } + + let n = 1usize << point.len(); + let expected_rows = num_rows(n, vk.num_cols)?; + if expected_rows != commitment.num_rows { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax verifier expected {expected_rows} rows from point, commitment has {}", + commitment.num_rows + ))); + } + + let row_vars = commitment.num_rows.ilog2() as usize; + let q0_f = eq_tensor_f::(&point[..row_vars], field_cfg); + let point_scalar = point + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + let q1_scalar = eq_tensor_scalar::(&point_scalar[row_vars..]); + let alphas = sample_scalars::( + &mut transcript.fs_transcript, + commitment.batch_size * commitment.num_lanes, + ); + + let b_f = transcript.read_field_elements::(commitment.num_rows)?; + if b_f.len() != q0_f.len() { + return Err(ZipError::InvalidPcsOpen( + 
"Hyrax b vector shape mismatch".to_string(), + )); + } + + let mut expected_eval = F::zero_with_cfg(field_cfg); + for (poly_idx, lifted_eval) in lifted_evals.iter().enumerate() { + for lane in 0..commitment.num_lanes { + let alpha = F::scalar_to_field( + &alphas[alpha_index_dynamic(commitment.num_lanes, poly_idx, lane)], + field_cfg, + )?; + let mut term = Lanes::lifted_eval::(lifted_eval, lane, field_cfg)?; + term *= α + expected_eval += &term; + } + } + + let mut b_eval = F::zero_with_cfg(field_cfg); + for (weight, b) in q0_f.iter().zip(b_f.iter()) { + let mut term = weight.clone(); + term *= b; + b_eval += &term; + } + if b_eval != expected_eval { + return Err(ZipError::InvalidPcsOpen( + "Hyrax evaluation consistency failure".to_string(), + )); + } + + let b_scalar = b_f + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + let row_coeffs = if commitment.num_rows == 1 { + vec![C::ScalarField::from(1u64)] + } else { + sample_scalars::(&mut transcript.fs_transcript, commitment.num_rows) + }; + + let combined_row = read_scalars::(transcript, vk.num_cols)?; + let rho_star = if vk.blinding_mode.is_blinded() { + Some(read_scalar::(transcript)?) + } else { + None + }; + + let mut lhs = C::ScalarField::zero(); + for (value, weight) in combined_row.iter().zip(q1_scalar.iter()) { + lhs += *value * weight; + } + let mut rhs = C::ScalarField::zero(); + for (coeff, b) in row_coeffs.iter().zip(b_scalar.iter()) { + rhs += *coeff * b; + } + if lhs != rhs { + return Err(ZipError::InvalidPcsOpen( + "Hyrax row coherence failure".to_string(), + )); + } + + let comm_lc = if commitment.num_rows == 1 { + msm_unchecked::(&commitment.comm_affine, &alphas)? 
+ } else { + let mut comm_lc_scalars = Vec::with_capacity(commitment.comm.len()); + for poly_idx in 0..commitment.batch_size { + for lane in 0..commitment.num_lanes { + let alpha = alphas[alpha_index_dynamic(commitment.num_lanes, poly_idx, lane)]; + comm_lc_scalars.extend(row_coeffs.iter().map(|row_coeff| alpha * row_coeff)); + } + } + msm_unchecked::(&commitment.comm_affine, &comm_lc_scalars)? + }; + + let mut expected = msm_unchecked::(&vk.bases[..combined_row.len()], &combined_row)?; + if let Some(rho_star) = rho_star { + expected += vk.h * rho_star; + } + + if comm_lc != expected { + return Err(ZipError::InvalidPcsOpen( + "Hyrax commitment opening failure".to_string(), + )); + } + + Ok(()) + } +} + +impl FoldablePCS for HyraxPCS +where + F: HyraxFieldBridge, + C: AffineRepr, + Eval: Clone + Debug + Send + Sync, + Lanes: HyraxLanes, +{ + fn fold_commitments( + commitments: &[Self::Commitment], + theta: &[F], + field_cfg: &F::Config, + ) -> Result { + let refs = commitments.iter().collect::>(); + Self::fold_commitment_refs(&refs, theta, field_cfg) + } + + fn fold_commitment_refs( + commitments: &[&Self::Commitment], + theta: &[F], + field_cfg: &F::Config, + ) -> Result { + let _ = field_cfg; + validate_fold_inputs(commitments, theta.len(), "commitments")?; + let first = commitments[0]; + validate_commitment_shape::(first)?; + + let scalars = theta + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + for &commitment in commitments { + validate_commitment_shape::(commitment)?; + if !same_commitment_shape(first, commitment) { + return Err(ZipError::InvalidPcsParam( + "Hyrax commitment fold shape mismatch".to_string(), + )); + } + } + let folded = msm_shared_weight_commitments_unchecked::(&scalars, commitments)?; + let folded_affine = C::Group::normalize_batch(&folded); + let folded_bytes = affine_points_bytes::(&folded_affine)?; + + Ok(HyraxCommitment { + batch_size: first.batch_size, + num_lanes: first.num_lanes, + num_rows: first.num_rows, + blinding_mode: 
first.blinding_mode, + comm: folded, + comm_affine: folded_affine, + comm_bytes: folded_bytes, + }) + } + + fn fold_prover_data( + prover_data: &[Self::ProverData], + theta: &[F], + field_cfg: &F::Config, + ) -> Result { + let _ = field_cfg; + validate_fold_inputs(prover_data, theta.len(), "prover data")?; + let first = &prover_data[0]; + + let scalars = theta + .iter() + .map(F::field_to_scalar) + .collect::, _>>()?; + for data in prover_data { + if !same_prover_data_shape(first, data) { + return Err(ZipError::InvalidPcsParam( + "Hyrax prover-data fold shape mismatch".to_string(), + )); + } + } + let folded_blinds = cfg_into_iter!(0..first.blinds.len()) + .map(|idx| { + let mut acc = C::ScalarField::zero(); + for (data, scalar) in prover_data.iter().zip(&scalars) { + acc += data.blinds[idx] * scalar; + } + acc + }) + .collect(); + + Ok(HyraxProverData { + batch_size: first.batch_size, + num_lanes: first.num_lanes, + num_rows: first.num_rows, + blinding_mode: first.blinding_mode, + blinds: folded_blinds, + }) + } +} + +fn validate_polys(polys: &[DenseMultilinearExtension]) -> Result<(), ZipError> { + if let Some(first) = polys.first() { + for poly in polys { + if poly.num_vars != first.num_vars || poly.evaluations.len() != first.evaluations.len() + { + return Err(ZipError::InvalidPcsParam( + "Hyrax batch polynomial shape mismatch".to_string(), + )); + } + } + } + Ok(()) +} + +fn validate_scalar_lanes( + ck: &HyraxCommitmentKey, + scalar_lanes: &[Vec>], + point_len: usize, + prover_data: &HyraxProverData, +) -> Result<(), ZipError> +where + C: AffineRepr, +{ + let expected_n = 1usize + .checked_shl(u32::try_from(point_len).map_err(|_| { + ZipError::InvalidPcsParam(format!("Hyrax point length {point_len} is too large")) + })?) 
+ .ok_or_else(|| { + ZipError::InvalidPcsParam(format!("Hyrax point length {point_len} is too large")) + })?; + let expected_rows = num_rows(expected_n, ck.num_cols)?; + if prover_data.batch_size != scalar_lanes.len() + || prover_data.num_rows != expected_rows + || prover_data.blinding_mode != ck.blinding_mode + { + return Err(ZipError::InvalidPcsParam( + "Hyrax scalar-lane prover data shape mismatch".to_string(), + )); + } + let expected_blinds = if ck.blinding_mode.is_blinded() { + prover_data.batch_size * prover_data.num_lanes * prover_data.num_rows + } else { + 0 + }; + if prover_data.blinds.len() != expected_blinds { + return Err(ZipError::InvalidPcsParam( + "Hyrax scalar-lane blind count mismatch".to_string(), + )); + } + for lanes in scalar_lanes { + if lanes.len() != prover_data.num_lanes { + return Err(ZipError::InvalidPcsParam( + "Hyrax scalar-lane count mismatch".to_string(), + )); + } + for values in lanes { + if values.len() != expected_n { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax scalar-lane length mismatch: got {}, expected {expected_n}", + values.len() + ))); + } + } + } + Ok(()) +} + +fn validate_field_lanes<'a, C, F>( + ck: &HyraxCommitmentKey, + field_lanes: &[Vec<&'a [F]>], + point_len: usize, + prover_data: &HyraxProverData, +) -> Result<(), ZipError> +where + C: AffineRepr, + F: PrimeField + 'a, +{ + let expected_n = 1usize + .checked_shl(u32::try_from(point_len).map_err(|_| { + ZipError::InvalidPcsParam(format!("Hyrax point length {point_len} is too large")) + })?) 
+ .ok_or_else(|| { + ZipError::InvalidPcsParam(format!("Hyrax point length {point_len} is too large")) + })?; + let expected_rows = num_rows(expected_n, ck.num_cols)?; + if prover_data.batch_size != field_lanes.len() + || prover_data.num_rows != expected_rows + || prover_data.blinding_mode != ck.blinding_mode + { + return Err(ZipError::InvalidPcsParam( + "Hyrax field-lane prover data shape mismatch".to_string(), + )); + } + let expected_blinds = if ck.blinding_mode.is_blinded() { + prover_data.batch_size * prover_data.num_lanes * prover_data.num_rows + } else { + 0 + }; + if prover_data.blinds.len() != expected_blinds { + return Err(ZipError::InvalidPcsParam( + "Hyrax field-lane blind count mismatch".to_string(), + )); + } + for lanes in field_lanes { + if lanes.len() != prover_data.num_lanes { + return Err(ZipError::InvalidPcsParam( + "Hyrax field-lane count mismatch".to_string(), + )); + } + for values in lanes { + if values.len() != expected_n { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax field-lane length mismatch: got {}, expected {expected_n}", + values.len() + ))); + } + } + } + Ok(()) +} + +fn same_commitment_shape( + lhs: &HyraxCommitment, + rhs: &HyraxCommitment, +) -> bool { + lhs.batch_size == rhs.batch_size + && lhs.num_lanes == rhs.num_lanes + && lhs.num_rows == rhs.num_rows + && lhs.blinding_mode == rhs.blinding_mode + && lhs.comm.len() == rhs.comm.len() + && lhs.comm_affine.len() == rhs.comm_affine.len() + && lhs.comm_bytes.len() == rhs.comm_bytes.len() +} + +fn same_prover_data_shape( + lhs: &HyraxProverData, + rhs: &HyraxProverData, +) -> bool { + lhs.batch_size == rhs.batch_size + && lhs.num_lanes == rhs.num_lanes + && lhs.num_rows == rhs.num_rows + && lhs.blinding_mode == rhs.blinding_mode + && lhs.blinds.len() == rhs.blinds.len() +} + +fn validate_shared_commitment_keys( + lhs: &HyraxCommitmentKey, + rhs: &HyraxCommitmentKey, +) -> Result<(), ZipError> { + if lhs.num_cols != rhs.num_cols + || lhs.blinding_mode != rhs.blinding_mode 
+ || lhs.msm_ck.num_cols != rhs.msm_ck.num_cols + || lhs.msm_ck.bases != rhs.msm_ck.bases + || lhs.msm_ck.h != rhs.msm_ck.h + { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed opening requires shared commitment bases".to_string(), + )); + } + Ok(()) +} + +fn validate_shared_verifier_keys( + lhs: &HyraxVerifierKey, + rhs: &HyraxVerifierKey, +) -> Result<(), ZipError> { + if lhs.num_cols != rhs.num_cols + || lhs.blinding_mode != rhs.blinding_mode + || lhs.bases != rhs.bases + || lhs.h != rhs.h + { + return Err(ZipError::InvalidPcsParam( + "Hyrax mixed opening requires shared verifier bases".to_string(), + )); + } + Ok(()) +} + +fn validate_trusted_bases( + width: usize, + bases: &[C], + h: &C::Group, +) -> Result<(), ZipError> { + if !width.is_power_of_two() { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax row width must be a power of two, got {width}" + ))); + } + if bases.len() != width { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax expected {width} bases, got {}", + bases.len() + ))); + } + + let mut seen = HashSet::with_capacity(bases.len()); + for (idx, base) in bases.iter().copied().enumerate() { + if base.is_zero() { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax base {idx} is the identity" + ))); + } + if !seen.insert(base) { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax base {idx} duplicates an earlier base" + ))); + } + } + + let h_affine = h.clone().into_affine(); + if h_affine.is_zero() { + return Err(ZipError::InvalidPcsParam( + "Hyrax blinding base is the identity".to_string(), + )); + } + if seen.contains(&h_affine) { + return Err(ZipError::InvalidPcsParam( + "Hyrax blinding base duplicates a witness base".to_string(), + )); + } + + Ok(()) +} + +fn validate_hyrax_shape( + width: usize, + blinding_mode: HyraxBlindingMode, + polys: &[DenseMultilinearExtension], + prover_data: &HyraxProverData, +) -> Result<(), ZipError> +where + C: AffineRepr, + Lanes: HyraxLanes, + Eval: Clone + Debug + Send + Sync, +{ + 
let n = polys + .first() + .map(|poly| poly.evaluations.len()) + /* An empty batch has no reference polynomial; report it as an error instead of panicking on polys[0]. */ .ok_or_else(|| { + ZipError::InvalidPcsParam( + "Hyrax prover data shape check requires at least one polynomial".to_string(), + ) + })?; + let num_rows = num_rows(n, width)?; + let expected_blinds = if blinding_mode.is_blinded() { + polys.len() * Lanes::NUM_LANES * num_rows + } else { + 0 + }; + /* Every recorded dimension of the prover data must agree with what the polynomials and lane layout imply. */ if prover_data.batch_size != polys.len() + || prover_data.num_lanes != Lanes::NUM_LANES + || prover_data.num_rows != num_rows + || prover_data.blinding_mode != blinding_mode + || prover_data.blinds.len() != expected_blinds + { + return Err(ZipError::InvalidPcsParam( + "Hyrax prover data shape mismatch".to_string(), + )); + } + Ok(()) +} + +/** Checks that a commitment's header fields are consistent with its payload: the lane count must equal `Lanes::NUM_LANES`, and `comm`, `comm_affine` and `comm_bytes` must each hold exactly `batch_size * num_lanes * num_rows` entries (times the compressed point size for the byte buffer). Returns `ZipError::InvalidPcsParam` describing the first mismatch. */ fn validate_commitment_shape( + commitment: &HyraxCommitment, +) -> Result<(), ZipError> +where + C: AffineRepr, + Lanes: HyraxLanes, + Eval: Clone + Debug + Send + Sync, +{ + if commitment.num_lanes != Lanes::NUM_LANES { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax commitment lane mismatch: expected {}, got {}", + Lanes::NUM_LANES, + commitment.num_lanes + ))); + } + /* The header fields are plain usize values; use checked arithmetic so a wrapped product cannot coincide with the payload length. */ let expected = commitment + .batch_size + .checked_mul(commitment.num_lanes) + .and_then(|lanes| lanes.checked_mul(commitment.num_rows)) + .ok_or_else(|| { + ZipError::InvalidPcsParam("Hyrax commitment shape overflows usize".to_string()) + })?; + if commitment.comm.len() != expected { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax commitment expected {expected} row commitments, got {}", + commitment.comm.len() + ))); + } + if commitment.comm_affine.len() != expected { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax commitment expected {expected} affine row commitments, got {}", + commitment.comm_affine.len() + ))); + } + let expected_bytes = expected * C::zero().serialized_size(Compress::Yes); + if commitment.comm_bytes.len() != expected_bytes { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax commitment expected {expected_bytes} commitment bytes, got {}", + commitment.comm_bytes.len() + ))); + } + Ok(()) +} + +/** Commits a single polynomial, returning its row commitments together with the blinds that were used (empty when the key is unblinded). */ fn commit_hyrax_poly( + ck: &HyraxCommitmentKey, + poly: &DenseMultilinearExtension, + num_rows: usize, +) -> Result<(Vec, Vec), ZipError> +where + C: AffineRepr, + Lanes: HyraxLanes, + Eval: Clone + Debug + Send + Sync, +{ + /* Lane implementations may supply a specialised commit routine; use it when present. */ if let Some(result) = Lanes::commit_poly(ck, poly, num_rows) { 
+ return result; + } + + let per_lane = cfg_into_iter!(0..Lanes::NUM_LANES) + .map(|lane| { + let values = lane_values::(poly, lane)?; + if ck.blinding_mode.is_blinded() { + let blind = MsmCommitmentEngine::::blind(&ck.msm_ck, values.len()); + let commitment = MsmCommitmentEngine::::commit_with::<_, Lanes::Strategy>( + &ck.msm_ck, &values, &blind, + ) + .map_err(msm_err)?; + Ok::<(Vec, Vec), ZipError>((commitment.comm, blind.blind)) + } else { + let commitment = MsmCommitmentEngine::::commit_unblinded_with::< + _, + Lanes::Strategy, + >(&ck.msm_ck, &values) + .map_err(msm_err)?; + Ok::<(Vec, Vec), ZipError>((commitment.comm, Vec::new())) + } + }) + .collect::, _>>()?; + + let mut comm = Vec::with_capacity(Lanes::NUM_LANES * num_rows); + let mut blinds = if ck.blinding_mode.is_blinded() { + Vec::with_capacity(Lanes::NUM_LANES * num_rows) + } else { + Vec::new() + }; + for (lane_comm, lane_blinds) in per_lane { + comm.extend(lane_comm); + blinds.extend(lane_blinds); + } + Ok((comm, blinds)) +} + +fn lane_values( + poly: &DenseMultilinearExtension, + lane: usize, +) -> Result, ZipError> +where + C: AffineRepr, + Lanes: HyraxLanes, + Eval: Clone + Debug + Send + Sync, +{ + poly.evaluations + .iter() + .map(|eval| Lanes::lane_value(eval, lane)) + .collect() +} + +fn random_scalars(n: usize) -> Vec { + let mut rng = ark_std::rand::thread_rng(); + (0..n).map(|_| C::ScalarField::rand(&mut rng)).collect() +} + +fn use_inner_bool_parallelism(outer_jobs: usize) -> bool { + #[cfg(feature = "parallel")] + { + outer_jobs < rayon::current_num_threads() + } + + #[cfg(not(feature = "parallel"))] + { + let _ = outer_jobs; + false + } +} + +fn hash_to_curve(domain: &[u8], label: &[u8], index: usize) -> Result { + let point_bytes = C::zero().serialized_size(Compress::Yes); + let mut counter = 0u64; + loop { + let mut hasher = blake3::Hasher::new(); + absorb_hash_part(&mut hasher, b"zinc-plus-hyrax-setup-v1")?; + absorb_hash_part(&mut hasher, domain)?; + absorb_hash_part(&mut hasher, 
label)?; + /* Bind the base index and the retry counter into the hash input as little-endian u64 bytes. */ hasher.update( + &u64::try_from(index) + .map_err(|_| { + ZipError::InvalidPcsParam("Hyrax setup index does not fit u64".to_string()) + })? + .to_le_bytes(), + ); + hasher.update(&counter.to_le_bytes()); + + /* Squeeze point_bytes bytes of XOF output and try to decode them as a curve point. */ let mut bytes = vec![0u8; point_bytes]; + hasher.finalize_xof().fill(&mut bytes); + /* Accept only a decodable point that is not the identity after cofactor clearing; otherwise fall through and retry. */ if let Some(point) = C::from_random_bytes(&bytes).map(|point| point.clear_cofactor()) { + if !point.is_zero() { + return Ok(point); + } + } + + /* Try the next counter; running out of u64 counters is reported as an error. */ counter = counter.checked_add(1).ok_or_else(|| { + ZipError::InvalidPcsParam("Hyrax hash-to-curve setup exhausted counters".to_string()) + })?; + } +} + +/** Feeds `part` into `hasher`, preceded by its length encoded as a little-endian u64. Returns `InvalidPcsParam` if the length does not fit in u64. */ fn absorb_hash_part(hasher: &mut blake3::Hasher, part: &[u8]) -> Result<(), ZipError> { + hasher.update( + &u64::try_from(part.len()) + .map_err(|_| { + ZipError::InvalidPcsParam( + "Hyrax setup domain component length does not fit u64".to_string(), + ) + })? + .to_le_bytes(), + ); + hasher.update(part); + Ok(()) +} + +/** Converts a signed `Int` into a curve scalar: the absolute value is reduced through `unsigned_int_to_scalar` and the result is negated when the input was negative. Returns `InvalidPcsParam` when `checked_abs` yields no value. */ fn int_to_scalar( + value: &Int, +) -> Result { + let (abs, is_negative) = if value.is_negative() { + ( + value.checked_abs().ok_or_else(|| { + ZipError::InvalidPcsParam("cannot convert minimum Int to scalar".to_string()) + })?, + true, + ) + } else { + (*value, false) + }; + let mut scalar = unsigned_int_to_scalar::(&abs); + /* Negate only non-zero results. */ if is_negative && !scalar.is_zero() { + scalar = -scalar; + } + Ok(scalar) +} + +/** Serialises the words of `value.as_uint()` as little-endian bytes, in word order, and reduces them with `from_le_bytes_mod_order`. */ fn unsigned_int_to_scalar(value: &Int) -> C::ScalarField { + let mut bytes = Vec::with_capacity(LIMBS * core::mem::size_of::()); + for word in value.as_uint().as_words() { + bytes.extend_from_slice(&word.to_le_bytes()); + } + C::ScalarField::from_le_bytes_mod_order(&bytes) +} + +/** Checks that `actual` equals the curve scalar-field `MODULUS`, re-encoded through `uint_from_le_bytes`; returns `InvalidPcsParam` on mismatch. NOTE(review): `uint_from_le_bytes` asserts that the encoding fits the target integer, so a narrower protocol field presumably panics here instead of returning the mismatch error - confirm. */ fn validate_curve_scalar_modulus( + actual: &Uint, +) -> Result<(), ZipError> +where + C: AffineRepr, +{ + let expected = + uint_from_le_bytes::(&::MODULUS.to_bytes_le()); + if actual != &expected { + return Err(ZipError::InvalidPcsParam( + "Hyrax field mismatch: protocol field modulus must equal curve scalar modulus" + .to_string(), + )); + } + Ok(()) +} + +fn 
validate_curve_scalar_modulus_boxed(actual: &BoxedUint) -> Result<(), ZipError> +where + C: AffineRepr, +{ + /* Boxed-precision variant of the modulus check. A curve modulus that cannot be decoded at the protocol field's precision is itself a field mismatch, so it is reported as InvalidPcsParam instead of panicking. */ let expected = BoxedUint::from_le_slice( + &::MODULUS.to_bytes_le(), + actual.bits_precision(), + ) + .map_err(|_| { + ZipError::InvalidPcsParam( + "Hyrax field mismatch: protocol field modulus must equal curve scalar modulus" + .to_string(), + ) + })?; + if actual != &expected { + return Err(ZipError::InvalidPcsParam( + "Hyrax field mismatch: protocol field modulus must equal curve scalar modulus" + .to_string(), + )); + } + Ok(()) +} + +/** Decodes little-endian `bytes` into the target `Uint`, zero-padding up to its `NUM_BYTES`. Panics if `bytes` is longer than `NUM_BYTES`. */ fn uint_from_le_bytes(bytes: &[u8]) -> Uint { + let num_bytes = as ConstTranscribable>::NUM_BYTES; + assert!( + bytes.len() <= num_bytes, + "integer encoding does not fit in target Uint", + ); + let mut padded = vec![0u8; num_bytes]; + padded[..bytes.len()].copy_from_slice(bytes); + Uint::::read_transcription_bytes_exact(&padded) +} + +/** Builds an MSM commitment key from explicit bases and a blinding base via `setup_from_bases`, keeping only the first element of the returned pair and mapping failures through `msm_err`. */ fn msm_key( + width: usize, + bases: Vec, + h: C::Group, +) -> Result, ZipError> { + MsmCommitmentEngine::::setup_from_bases(width, bases, h) + .map(|(ck, _)| ck) + .map_err(msm_err) +} + +/** Multi-scalar multiplication of `bases` by `scalars` that skips zero scalars. Returns `InvalidPcsParam` when the slices differ in length and the group identity when every scalar is zero. */ fn msm_unchecked( + bases: &[C], + scalars: &[C::ScalarField], +) -> Result { + if bases.len() != scalars.len() { + return Err(ZipError::InvalidPcsParam(format!( + "Hyrax MSM expected {} bases, got {}", + scalars.len(), + bases.len() + ))); + } + /* No zero scalars: hand the inputs to the backend unchanged, with no copying. */ if !scalars.iter().any(|scalar| scalar.is_zero()) { + return Ok(::msm_unchecked(bases, scalars)); + } + + /* Otherwise copy out only the (base, scalar) pairs whose scalar is non-zero. */ let non_zero = scalars + .iter() + .enumerate() + .filter(|(_, scalar)| !scalar.is_zero()); + let mut filtered_bases = Vec::new(); + let mut filtered_scalars = Vec::new(); + for (idx, scalar) in non_zero { + filtered_bases.push(bases[idx]); + filtered_scalars.push(*scalar); + } + if filtered_scalars.is_empty() { + return Ok(C::Group::zero()); + } + + Ok(::msm_unchecked( + &filtered_bases, + &filtered_scalars, + )) +} + +/** Combines the affine row commitments of several commitments row by row, weighting commitment `i` by `scalars[i]`. Returns an empty vector when no commitments are given. */ fn msm_shared_weight_commitments_unchecked( + scalars: &[C::ScalarField], + commitments: &[&HyraxCommitment], +) -> Result, ZipError> { + if commitments.is_empty() { + return Ok(Vec::new()); + } + let row_count = 
commitments[0].comm_affine.len(); + /* Debug-only check that every commitment has the same number of affine rows as the first one. */ debug_assert!( + commitments + .iter() + .all(|commitment| commitment.comm_affine.len() == row_count) + ); + + msm_shared_weights_indexed_unchecked::(scalars, row_count, |row_idx, scalar_idx| { + commitments[scalar_idx].comm_affine[row_idx] + }) +} + +/** For each row index in `0..row_count`, computes the sum over `i` of `scalars[i] * base_at(row, i)`. The scalars are shared by all rows, so they are classified and recoded into signed windows once, before the per-row accumulation. Returns an empty vector when `row_count` is zero. */ fn msm_shared_weights_indexed_unchecked( + scalars: &[C::ScalarField], + row_count: usize, + base_at: BaseAt, +) -> Result, ZipError> +where + C: AffineRepr, + BaseAt: Fn(usize, usize) -> C + Sync, +{ + if row_count == 0 { + return Ok(Vec::new()); + } + + /* Classify scalars: zeros are dropped, ones are summed directly, everything else goes through the windowed path as a big integer. */ let one = C::ScalarField::from(1u64); + let mut unit_indices = Vec::new(); + let mut general_indices = Vec::new(); + let mut general_scalars = Vec::new(); + for (idx, scalar) in scalars.iter().enumerate() { + if scalar.is_zero() { + continue; + } + if *scalar == one { + unit_indices.push(idx); + } else { + general_indices.push(idx); + general_scalars.push(scalar.into_bigint()); + } + } + + /* Only unit scalars remain: each row is a plain sum of the selected bases. */ if general_indices.is_empty() { + return Ok(cfg_into_iter!(0..row_count) + .map(|row_idx| { + let mut acc = C::Group::zero(); + for &idx in &unit_indices { + acc += base_at(row_idx, idx); + } + acc + }) + .collect()); + } + + /* Window width is derived from the total number of scalars, zeros and ones included. */ let window_bits = shared_weight_window_bits(scalars.len()); + let half_window = 1usize << (window_bits - 1); + let full_window = 1usize << window_bits; + /* One bucket per digit magnitude 1..=half_window. */ let bucket_len = half_window; + /* One segment beyond ceil(MODULUS_BIT_SIZE / window_bits), leaving room for a carry out of the top window. */ let segments = + ::div_ceil(&(C::ScalarField::MODULUS_BIT_SIZE as usize), &window_bits) + + 1; + let mut carries = vec![0u8; general_scalars.len()]; + let mut signed_windows = Vec::with_capacity(segments); + /* Recode each scalar from the least-significant window upward: a raw window value at or above half_window becomes the negative digit raw - full_window and carries 1 into the next window. */ for segment in 0..segments { + let offset = segment * window_bits; + let mut digits = Vec::with_capacity(general_scalars.len()); + for (idx, scalar) in general_scalars.iter().enumerate() { + let raw = window_value_from_limbs(scalar.as_ref(), offset, window_bits) + + usize::from(carries[idx]); + if raw >= half_window { + /* NOTE(review): digits are stored as i16, so a magnitude of half_window only fits while window_bits <= 15 - confirm shared_weight_window_bits cannot exceed that. */ digits.push(-((full_window - raw) as i16)); + carries[idx] = 1; + } else { + digits.push(raw as i16); + carries[idx] = 
0; + } + } + signed_windows.push(digits); + } + + /* Specialised path for four buckets (window width 3): the buckets live in a fixed-size array instead of a heap-allocated vector. */ if bucket_len == 4 { + return Ok(cfg_into_iter!(0..row_count) + .map(|row_idx| { + let mut unit_sum = C::Group::zero(); + for &idx in &unit_indices { + unit_sum += base_at(row_idx, idx); + } + + let mut buckets: [C::Group; 4] = std::array::from_fn(|_| C::Group::zero()); + let mut acc = C::Group::zero(); + /* Walk the windows from most to least significant: shift the accumulator by one window, then add this window's bucket sum. */ for digits in signed_windows.iter().rev() { + for _ in 0..window_bits { + acc.double_in_place(); + } + for bucket in &mut buckets { + *bucket = C::Group::zero(); + } + /* A digit of magnitude m lands in bucket m - 1; negative digits subtract the base. */ for (general_idx, digit) in digits.iter().enumerate() { + if *digit > 0 { + buckets[*digit as usize - 1] += + base_at(row_idx, general_indices[general_idx]); + } else if *digit < 0 { + buckets[(-*digit) as usize - 1] -= + base_at(row_idx, general_indices[general_idx]); + } + } + acc += bucket_running_sum(&buckets); + } + + unit_sum + acc + }) + .collect()); + } + + /* General path: same algorithm with a heap-allocated bucket vector of bucket_len entries per row. */ Ok(cfg_into_iter!(0..row_count) + .map(|row_idx| { + let mut unit_sum = C::Group::zero(); + for &idx in &unit_indices { + unit_sum += base_at(row_idx, idx); + } + + let mut buckets = vec![C::Group::zero(); bucket_len]; + let mut acc = C::Group::zero(); + for digits in signed_windows.iter().rev() { + for _ in 0..window_bits { + acc.double_in_place(); + } + for bucket in &mut buckets { + *bucket = C::Group::zero(); + } + for (general_idx, digit) in digits.iter().enumerate() { + if *digit > 0 { + buckets[*digit as usize - 1] += + base_at(row_idx, general_indices[general_idx]); + } else if *digit < 0 { + buckets[(-*digit) as usize - 1] -= + base_at(row_idx, general_indices[general_idx]); + } + } + acc += bucket_running_sum(&buckets); + } + + unit_sum + acc + }) + .collect()) +} + +/** Chooses the signed-window width for a shared-weight MSM over `n` scalars: 3 bits below 32 scalars, otherwise the bit length of `n`, capped at 15. */ fn shared_weight_window_bits(n: usize) -> usize { + /* The recoding stores digits as i16 with magnitudes up to 2^(w-1). At w = 16 that magnitude is 32768, which i16 cannot represent or negate, so the width is capped; the window size affects only speed, not the result. */ const MAX_WINDOW_BITS: usize = 15; + if n < 32 { + 3 + } else { + ((usize::BITS - n.leading_zeros()) as usize).min(MAX_WINDOW_BITS) + } +} + +/** Returns the sum over `i` of `(i + 1) * buckets[i]`, computed with running suffix sums so that no scalar multiplications are needed. */ fn bucket_running_sum(buckets: &[G]) -> G { + let mut acc = G::zero(); + let mut running_sum = G::zero(); + for bucket in buckets.iter().rev() { + running_sum += bucket; + acc += 
running_sum; + } + acc +} + +/** Extracts `width` bits starting at bit `start` from a little-endian limb array and returns them as an integer (bit `start` becomes bit 0). Bits beyond the end of `limbs` read as zero. */ fn window_value_from_limbs(limbs: &[u64], start: usize, width: usize) -> usize { + (0..width).fold(0usize, |value, bit_idx| { + let absolute_bit = start + bit_idx; + let limb_idx = absolute_bit / u64::BITS as usize; + let limb_bit = absolute_bit % u64::BITS as usize; + if limbs + .get(limb_idx) + .map(|limb| ((limb >> limb_bit) & 1) == 1) + .unwrap_or(false) + { + value | (1usize << bit_idx) + } else { + value + } + }) +} + +/** Number of rows of `width` entries needed to hold `n` entries (rounded up). Returns `InvalidPcsParam` when `width` is zero. */ fn num_rows(n: usize, width: usize) -> Result { + if width == 0 { + return Err(ZipError::InvalidPcsParam( + "Hyrax row width must be non-zero".to_string(), + )); + } + Ok(::div_ceil(&n, &width)) +} + +/** Flat index of the batching challenge for `(poly_idx, lane)` in poly-major, lane-minor order. */ fn alpha_index_dynamic(num_lanes: usize, poly_idx: usize, lane: usize) -> usize { + poly_idx * num_lanes + lane +} + +/** Flat index of a row commitment in poly-major, then lane, then row order. */ fn commitment_index_dynamic( + num_lanes: usize, + poly_idx: usize, + lane: usize, + row_idx: usize, + num_rows: usize, +) -> usize { + ((poly_idx * num_lanes + lane) * num_rows) + row_idx +} + +/** Builds the tensor of products over `point`: starting from `[1]`, each coordinate `r` replaces the current layer `t` with `t * (1 - r)` followed by `t * r`, so the result has `2^point.len()` entries. */ fn eq_tensor_f(point: &[F], cfg: &F::Config) -> Vec { + let mut tensor = vec![F::one_with_cfg(cfg)]; + for r in point { + let one_minus = { + let mut value = F::one_with_cfg(cfg); + value -= r; + value + }; + /* Move the previous layer out instead of cloning it; the next layer is exactly twice as long. */ let current = core::mem::take(&mut tensor); + tensor.reserve(current.len() * 2); + /* First half: previous entries scaled by (1 - r). */ for value in &current { + let mut lo = value.clone(); + lo *= &one_minus; + tensor.push(lo); + } + /* Second half: previous entries scaled by r, consuming the old layer. */ for value in current { + let mut hi = value; + hi *= r; + tensor.push(hi); + } + } + tensor +} + +/** Curve-scalar counterpart of `eq_tensor_f`: same layer-doubling construction over `C::ScalarField`. */ fn eq_tensor_scalar(point: &[C::ScalarField]) -> Vec { + let mut tensor = vec![C::ScalarField::from(1u64)]; + for r in point { + let one_minus = C::ScalarField::from(1u64) - r; + let current = core::mem::take(&mut tensor); + tensor.reserve(current.len() * 2); + for value in &current { + tensor.push(*value * one_minus); + } + for value in current { + tensor.push(value * r); + } + } + tensor +} + +/** Draws `n` curve scalars from the transcript; each one reduces the little-endian bytes of eight consecutive challenges modulo the scalar-field order. */ fn sample_scalars( + transcript: &mut impl Transcript, + n: usize, +) -> Vec { + (0..n) + .map(|_| { + let mut bytes = Vec::with_capacity(64); + for _ in 0..8 { + let word = transcript.get_challenge::(); + 
bytes.extend_from_slice(&word.to_le_bytes()); + } + C::ScalarField::from_le_bytes_mod_order(&bytes) + }) + .collect() +} + +fn write_scalars( + transcript: &mut PcsProverTranscript, + scalars: &[C::ScalarField], +) -> Result<(), ZipError> { + for scalar in scalars { + write_scalar::(transcript, scalar)?; + } + Ok(()) +} + +fn write_scalar( + transcript: &mut PcsProverTranscript, + scalar: &C::ScalarField, +) -> Result<(), ZipError> { + let bytes = scalar_bytes::(scalar)?; + transcript.fs_transcript.absorb_slice(&bytes); + transcript.stream.write_all(&bytes)?; + Ok(()) +} + +fn read_scalars( + transcript: &mut PcsVerifierTranscript, + n: usize, +) -> Result, ZipError> { + (0..n).map(|_| read_scalar::(transcript)).collect() +} + +fn read_scalar( + transcript: &mut PcsVerifierTranscript, +) -> Result { + let size = C::ScalarField::zero().serialized_size(Compress::Yes); + let mut bytes = vec![0u8; size]; + transcript.stream.read_exact(&mut bytes)?; + transcript.fs_transcript.absorb_slice(&bytes); + C::ScalarField::deserialize_compressed(bytes.as_slice()).map_err(ark_err) +} + +fn scalar_bytes(scalar: &C::ScalarField) -> Result, ZipError> { + let mut bytes = Vec::with_capacity(scalar.serialized_size(Compress::Yes)); + scalar.serialize_compressed(&mut bytes).map_err(ark_err)?; + Ok(bytes) +} + +fn affine_bytes_into(affine: &C, bytes: &mut Vec) -> Result<(), ZipError> { + affine.serialize_compressed(bytes).map_err(ark_err) +} + +fn affine_points_bytes(points: &[C]) -> Result, ZipError> { + let point_size = C::zero().serialized_size(Compress::Yes); + let mut bytes = Vec::with_capacity(points.len() * point_size); + for point in points { + affine_bytes_into::(point, &mut bytes)?; + } + Ok(bytes) +} + +fn msm_err(err: MsmError) -> ZipError { + ZipError::InvalidPcsParam(err.to_string()) +} + +fn ark_err(err: ark_serialize::SerializationError) -> ZipError { + ZipError::Serialization(format!("ark serialization error: {err}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + 
use ark_ec::PrimeGroup; + use ark_ff::Field as ArkField; + use crypto_primitives::FromWithConfig; + + fn cfg_from_curve() -> as PrimeField>::Config { + let modulus = + uint_from_le_bytes::<4>(&::MODULUS.to_bytes_le()); + as PrimeField>::make_cfg(&modulus) + .expect("curve scalar modulus must be prime") + } + + fn assert_bridge_round_trip() -> Result<(), ZipError> { + let cfg = cfg_from_curve::(); + for value in [0u64, 1, 2, 17, 123, 1 << 20] { + let field = MontyField::<4>::from_with_cfg(value, &cfg); + let scalar = as HyraxFieldBridge>::field_to_scalar(&field)?; + assert_eq!(scalar, C::ScalarField::from(value)); + + let field_again = + as HyraxFieldBridge>::scalar_to_field(&scalar, &cfg)?; + assert_eq!(field_again, field); + } + + let large_values = [ + C::ScalarField::from(2u64).inverse().unwrap(), + -C::ScalarField::from(1u64), + C::ScalarField::from_le_bytes_mod_order(&[0xA5; 64]), + ]; + for scalar in large_values { + let field = as HyraxFieldBridge>::scalar_to_field(&scalar, &cfg)?; + let scalar_again = as HyraxFieldBridge>::field_to_scalar(&field)?; + assert_eq!(scalar_again, scalar); + } + Ok(()) + } + + #[test] + fn bridge_round_trips_bn254_scalar_field() { + assert_bridge_round_trip::().unwrap(); + } + + #[test] + fn bridge_round_trips_secp256k1_scalar_field() { + assert_bridge_round_trip::().unwrap(); + } + + #[test] + fn bridge_rejects_mismatched_field_config() { + let bn_cfg = cfg_from_curve::(); + let bn_field = MontyField::<4>::from_with_cfg(1u64, &bn_cfg); + let result = + as HyraxFieldBridge>::field_to_scalar(&bn_field); + assert!(matches!(result, Err(ZipError::InvalidPcsParam(_)))); + } + + #[test] + fn setup_derives_distinct_deterministic_bases() { + type C = ark_bn254::G1Affine; + let width = 32; + let (ck_0, vk_0) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-setup-test", + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + let (ck_1, vk_1) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-setup-test", + HyraxBlindingMode::Unblinded, + ) + 
.unwrap(); + + assert_eq!(ck_0.msm_ck.bases, ck_1.msm_ck.bases); + assert_eq!(vk_0.bases, vk_1.bases); + assert_eq!(ck_0.msm_ck.h, ck_1.msm_ck.h); + assert_eq!(vk_0.h, vk_1.h); + assert_eq!(ck_0.blinding_mode, HyraxBlindingMode::Unblinded); + assert_eq!(vk_0.blinding_mode, HyraxBlindingMode::Unblinded); + assert!(ck_0.msm_ck.bases.iter().all(|base| !base.is_zero())); + assert!(!ck_0.msm_ck.h.is_zero()); + + let seen = ck_0.msm_ck.bases.iter().copied().collect::>(); + assert_eq!(seen.len(), width); + assert!(!seen.contains(&ck_0.msm_ck.h.into_affine())); + } + + #[test] + fn trusted_setup_rejects_bad_bases() { + type C = ark_bn254::G1Affine; + let width = 8; + let generator = ::Group::generator(); + let bases = (1..=width) + .map(|idx| (generator * ::ScalarField::from(idx as u64)).into_affine()) + .collect::>(); + let h = generator * ::ScalarField::from((width + 1) as u64); + + assert!(matches!( + HyraxPCS::::setup_from_trusted_bases( + 0, + Vec::new(), + ::Group::zero(), + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + + assert!(matches!( + HyraxPCS::::setup_from_trusted_bases( + width, + bases[..width - 1].to_vec(), + h, + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + + let mut identity_bases = bases.clone(); + identity_bases[0] = C::zero(); + assert!(matches!( + HyraxPCS::::setup_from_trusted_bases( + width, + identity_bases, + h, + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + + let mut duplicate_bases = bases.clone(); + duplicate_bases[1] = duplicate_bases[0]; + assert!(matches!( + HyraxPCS::::setup_from_trusted_bases( + width, + duplicate_bases, + h, + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + + assert!(matches!( + HyraxPCS::::setup_from_trusted_bases( + width, + bases.clone(), + ::Group::zero(), + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + + assert!(matches!( + 
HyraxPCS::::setup_from_trusted_bases( + width, + bases.clone(), + bases[0].into_group(), + HyraxBlindingMode::Unblinded, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + } + + fn binary_hyrax_open_verify_round_trip_with_modes( + commit_mode: HyraxBlindingMode, + verify_mode: HyraxBlindingMode, + ) -> Result<(), ZipError> { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + fn bp(bits: u32) -> BinaryPoly { + BinaryPoly::::from(bits) + } + + let cfg = cfg_from_curve::(); + let width = 512; + let (ck, _) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-round-trip-test", + commit_mode, + )?; + let (_, vk) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-round-trip-test", + verify_mode, + )?; + + let evals0 = (0..width) + .map(|idx| bp((idx as u32).wrapping_mul(0x9E37_79B1))) + .collect::>(); + let evals1 = (0..width) + .map(|idx| bp(!((idx as u32).wrapping_mul(0x85EB_CA6B)))) + .collect::>(); + let polys = vec![ + DenseMultilinearExtension::from_evaluations_vec(9, evals0, bp(0)), + DenseMultilinearExtension::from_evaluations_vec(9, evals1, bp(0)), + ]; + let (prover_data, commitment) = + as PCS, D>>::commit(&ck, &polys)?; + + let point = [ + [0x11u8; 64], + [0x22u8; 64], + [0x33u8; 64], + [0x44u8; 64], + [0x55u8; 64], + [0x66u8; 64], + [0x77u8; 64], + [0x88u8; 64], + [0xA5u8; 64], + ] + .iter() + .map(|bytes| { + let scalar = ::ScalarField::from_le_bytes_mod_order(bytes); + >::scalar_to_field(&scalar, &cfg) + }) + .collect::, _>>()?; + let eq = eq_tensor_f::(&point, &cfg); + let lifted_evals = polys + .iter() + .map(|poly| { + let mut coeffs = vec![F::zero_with_cfg(&cfg); D]; + for (weight, eval) in eq.iter().zip(poly.evaluations.iter()) { + for (lane, bit) in eval.iter().enumerate() { + if bit.inner() { + coeffs[lane] += weight; + } + } + } + DynamicPolynomialF::new_trimmed(coeffs) + }) + .collect::>(); + + let mut prover_transcript = PcsProverTranscript { + fs_transcript: Default::default(), + stream: Default::default(), + }; + as 
PCS, D>>::absorb_commitment( + &mut prover_transcript.fs_transcript, + &commitment, + ); + let mut transcription_buf = vec![0u8; ::Inner::NUM_BYTES]; + for lifted_eval in &lifted_evals { + prover_transcript + .fs_transcript + .absorb_random_field_slice(&lifted_eval.coeffs, &mut transcription_buf); + } + as PCS, D>>::prove_open::( + &mut prover_transcript, + &ck, + &polys, + &point, + &prover_data, + &cfg, + )?; + + let mut verifier_transcript = prover_transcript.into_verification_transcript(); + as PCS, D>>::absorb_commitment( + &mut verifier_transcript.fs_transcript, + &commitment, + ); + let mut transcription_buf = vec![0u8; ::Inner::NUM_BYTES]; + for lifted_eval in &lifted_evals { + verifier_transcript + .fs_transcript + .absorb_random_field_slice(&lifted_eval.coeffs, &mut transcription_buf); + } + as PCS, D>>::verify_open::( + &mut verifier_transcript, + &vk, + &commitment, + &point, + &lifted_evals, + &Vec::new(), + &cfg, + ) + } + + #[test] + fn binary_hyrax_open_verify_round_trip() { + binary_hyrax_open_verify_round_trip_with_modes( + HyraxBlindingMode::Blinded, + HyraxBlindingMode::Blinded, + ) + .unwrap(); + } + + #[test] + fn unblinded_binary_hyrax_open_verify_round_trip() { + binary_hyrax_open_verify_round_trip_with_modes( + HyraxBlindingMode::Unblinded, + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + } + + #[test] + fn binary_hyrax_commitment_order_is_poly_lane_row() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + fn bp(bits: u32) -> BinaryPoly { + BinaryPoly::::from(bits) + } + + let width = 8; + let (ck, _) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-order-test", + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + let polys = vec![ + DenseMultilinearExtension::from_evaluations_vec( + 4, + (0..16).map(|idx| bp((idx * 13 + 7) as u32)).collect(), + bp(0), + ), + DenseMultilinearExtension::from_evaluations_vec( + 4, + (0..16) + .map(|idx| bp(((idx * 29 + 3) as u32).reverse_bits())) + .collect(), + bp(0), + ), 
+ ]; + + let (prover_data, commitment) = + as PCS, D>>::commit(&ck, &polys).unwrap(); + + let mut expected = Vec::new(); + BoolSubsetMsm::<6>::precompute_ck(&ck.msm_ck); + for poly in &polys { + for lane in 0..D { + for row in poly.evaluations.chunks(width) { + let values = row.iter().map(|eval| eval.coeff(lane)).collect::>(); + let row_comm = if values.iter().copied().any(|bit| bit) { + BoolSubsetMsm::<6>::msm_bool_row(&ck.msm_ck, &values, false).unwrap() + } else { + ::Group::zero() + }; + expected.push(row_comm); + } + } + } + + assert_eq!(prover_data.blinds.len(), 0); + assert_eq!(commitment.num_rows, 2); + assert_eq!(commitment.comm, expected); + } + + #[test] + fn binary_hyrax_commitment_supports_partial_single_row() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + fn bp(bits: u32) -> BinaryPoly { + BinaryPoly::::from(bits) + } + + let width = 32; + let (ck, _) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-partial-row-test", + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + let polys = vec![DenseMultilinearExtension::from_evaluations_vec( + 4, + (0..16).map(|idx| bp((idx * 17 + 11) as u32)).collect(), + bp(0), + )]; + + let (prover_data, commitment) = + as PCS, D>>::commit(&ck, &polys).unwrap(); + + assert_eq!(prover_data.num_rows, 1); + assert_eq!(commitment.num_rows, 1); + assert_eq!(commitment.comm.len(), D); + for (lane, comm) in commitment.comm.iter().enumerate() { + let values = polys[0] + .evaluations + .iter() + .map(|eval| eval.coeff(lane)) + .collect::>(); + let expected = if values.iter().copied().any(|bit| bit) { + BoolSubsetMsm::<6>::msm_bool_row(&ck.msm_ck, &values, false).unwrap() + } else { + ::Group::zero() + }; + assert_eq!(*comm, expected); + } + } + + #[test] + fn hyrax_rejects_blinding_mode_mismatch() { + let result = binary_hyrax_open_verify_round_trip_with_modes( + HyraxBlindingMode::Unblinded, + HyraxBlindingMode::Blinded, + ); + assert!(result.is_err()); + } + + #[test] + fn 
folded_binary_hyrax_commitment_opens_from_scalar_lanes() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + fn bp(bits: u32) -> BinaryPoly { + BinaryPoly::::from(bits) + } + + let cfg = cfg_from_curve::(); + let n = 8; + let width = n; + let generator = ::Group::generator(); + let bases = (1..=width) + .map(|idx| (generator * ::ScalarField::from(idx as u64)).into_affine()) + .collect::>(); + let h = generator * ::ScalarField::from((width + 1) as u64); + let (ck, vk) = HyraxPCS::::setup_from_bases_with_blinding( + width, + bases, + h, + HyraxBlindingMode::Blinded, + ) + .unwrap(); + + let instance_polys = [ + vec![ + DenseMultilinearExtension::from_evaluations_vec( + 3, + (0..n).map(|idx| bp((idx as u32) * 17 + 3)).collect(), + bp(0), + ), + DenseMultilinearExtension::from_evaluations_vec( + 3, + (0..n).map(|idx| bp(!((idx as u32) * 11))).collect(), + bp(0), + ), + ], + vec![ + DenseMultilinearExtension::from_evaluations_vec( + 3, + (0..n).map(|idx| bp((idx as u32) * 23 + 9)).collect(), + bp(0), + ), + DenseMultilinearExtension::from_evaluations_vec( + 3, + (0..n).map(|idx| bp(!((idx as u32) * 5 + 7))).collect(), + bp(0), + ), + ], + ]; + + let mut prover_data = Vec::new(); + let mut commitments = Vec::new(); + for polys in &instance_polys { + let (data, commitment) = + as PCS, D>>::commit(&ck, polys).unwrap(); + prover_data.push(data); + commitments.push(commitment); + } + + let theta = [F::from_with_cfg(3u64, &cfg), F::from_with_cfg(5u64, &cfg)]; + let folded_commitment = + as FoldablePCS, D>>::fold_commitments( + &commitments, + &theta, + &cfg, + ) + .unwrap(); + let folded_data = + as FoldablePCS, D>>::fold_prover_data( + &prover_data, + &theta, + &cfg, + ) + .unwrap(); + + let theta_scalar = theta + .iter() + .map(|theta| >::field_to_scalar(theta).unwrap()) + .collect::>(); + let mut scalar_lanes = + vec![vec![vec![::ScalarField::zero(); n]; D]; instance_polys[0].len()]; + for (instance_idx, polys) in 
instance_polys.iter().enumerate() { + for (poly_idx, poly) in polys.iter().enumerate() { + for (eval_idx, eval) in poly.evaluations.iter().enumerate() { + for (lane, bit) in eval.iter().enumerate() { + if bit.inner() { + scalar_lanes[poly_idx][lane][eval_idx] += theta_scalar[instance_idx]; + } + } + } + } + } + + let point = [[0x11u8; 64], [0x22u8; 64], [0x33u8; 64]] + .iter() + .map(|bytes| { + let scalar = ::ScalarField::from_le_bytes_mod_order(bytes); + >::scalar_to_field(&scalar, &cfg).unwrap() + }) + .collect::>(); + let eq = eq_tensor_f::(&point, &cfg); + let folded_lifted_evals = scalar_lanes + .iter() + .map(|lanes| { + let coeffs = lanes + .iter() + .map(|values| { + values.iter().zip(eq.iter()).fold( + F::zero_with_cfg(&cfg), + |mut acc, (value, weight)| { + acc += >::scalar_to_field(value, &cfg) + .unwrap() + * weight; + acc + }, + ) + }) + .collect::>(); + DynamicPolynomialF::new_trimmed(coeffs) + }) + .collect::>(); + + let mut prover_transcript = PcsProverTranscript { + fs_transcript: Default::default(), + stream: Default::default(), + }; + as PCS, D>>::absorb_commitment( + &mut prover_transcript.fs_transcript, + &folded_commitment, + ); + let mut transcription_buf = vec![0u8; ::Inner::NUM_BYTES]; + for lifted_eval in &folded_lifted_evals { + prover_transcript + .fs_transcript + .absorb_random_field_slice(&lifted_eval.coeffs, &mut transcription_buf); + } + HyraxPCS::::prove_open_scalar_lanes::( + &mut prover_transcript, + &ck, + &scalar_lanes, + &point, + &folded_data, + &cfg, + ) + .unwrap(); + + let mut verifier_transcript = prover_transcript.into_verification_transcript(); + as PCS, D>>::absorb_commitment( + &mut verifier_transcript.fs_transcript, + &folded_commitment, + ); + let mut transcription_buf = vec![0u8; ::Inner::NUM_BYTES]; + for lifted_eval in &folded_lifted_evals { + verifier_transcript + .fs_transcript + .absorb_random_field_slice(&lifted_eval.coeffs, &mut transcription_buf); + } + as PCS, D>>::verify_open::( + &mut 
verifier_transcript, + &vk, + &folded_commitment, + &point, + &folded_lifted_evals, + &Vec::new(), + &cfg, + ) + .unwrap(); + } + + #[test] + fn hyrax_fold_rejects_commitment_shape_mismatch() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + let cfg = cfg_from_curve::(); + let generator = ::Group::generator(); + let bases = (1..=4) + .map(|idx| (generator * ::ScalarField::from(idx as u64)).into_affine()) + .collect::>(); + let h = generator * ::ScalarField::from(5u64); + let (ck, _) = HyraxPCS::::setup_from_bases_with_blinding( + 4, + bases, + h, + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + let polys_one = vec![DenseMultilinearExtension::from_evaluations_vec( + 2, + vec![BinaryPoly::::from(1u32); 4], + BinaryPoly::::from(0u32), + )]; + let polys_two = vec![DenseMultilinearExtension::from_evaluations_vec( + 3, + vec![BinaryPoly::::from(2u32); 8], + BinaryPoly::::from(0u32), + )]; + let (_, c0) = + as PCS, D>>::commit(&ck, &polys_one) + .unwrap(); + let (_, c1) = + as PCS, D>>::commit(&ck, &polys_two) + .unwrap(); + + let theta = [F::from_with_cfg(1u64, &cfg), F::from_with_cfg(2u64, &cfg)]; + assert!(matches!( + as FoldablePCS, D>>::fold_commitments( + &[c0, c1], + &theta, + &cfg, + ), + Err(ZipError::InvalidPcsParam(_)) + )); + } + + #[test] + fn hyrax_rejects_empty_commitment_with_nonempty_lifted_evals() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + let width = 8; + let (_, vk) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-empty-reject-test", + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + + let commitment = HyraxCommitment:: { + batch_size: 0, + num_lanes: D, + num_rows: 0, + blinding_mode: HyraxBlindingMode::Unblinded, + comm: Vec::new(), + }; + let cfg = cfg_from_curve::(); + let lifted_evals = vec![DynamicPolynomialF::new_trimmed(vec![F::zero_with_cfg( + &cfg, + )])]; + let mut verifier_transcript = PcsVerifierTranscript { + fs_transcript: Default::default(), + stream: 
Default::default(), + }; + + let result = as PCS, D>>::verify_open::( + &mut verifier_transcript, + &vk, + &commitment, + &[], + &lifted_evals, + &Vec::new(), + &cfg, + ); + assert!(matches!(result, Err(ZipError::InvalidPcsParam(_)))); + } + + #[test] + fn hyrax_rejects_noncanonical_empty_commitment() { + type C = ark_bn254::G1Affine; + type F = MontyField<4>; + const D: usize = 32; + + let width = 8; + let (_, vk) = HyraxPCS::::setup( + width, + b"zinc-plus-hyrax-empty-reject-test-2", + HyraxBlindingMode::Unblinded, + ) + .unwrap(); + + let commitment = HyraxCommitment:: { + batch_size: 0, + num_lanes: D, + num_rows: 1, + blinding_mode: HyraxBlindingMode::Unblinded, + comm: Vec::new(), + }; + let cfg = cfg_from_curve::(); + let mut verifier_transcript = PcsVerifierTranscript { + fs_transcript: Default::default(), + stream: Default::default(), + }; + + let result = as PCS, D>>::verify_open::( + &mut verifier_transcript, + &vk, + &commitment, + &[], + &[], + &Vec::new(), + &cfg, + ); + assert!(matches!(result, Err(ZipError::InvalidPcsParam(_)))); + } +} diff --git a/zip-plus/src/pcs/msm_commitment.rs b/zip-plus/src/pcs/msm_commitment.rs new file mode 100644 index 00000000..69e9d732 --- /dev/null +++ b/zip-plus/src/pcs/msm_commitment.rs @@ -0,0 +1,1122 @@ +#![allow(clippy::arithmetic_side_effects)] +#![allow(clippy::cast_possible_truncation)] +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::cast_sign_loss)] + +use std::{ + marker::PhantomData, + sync::{Arc, OnceLock}, +}; + +use ark_ec::{AffineRepr, CurveGroup}; +use ark_ff::{AdditiveGroup, One, PrimeField, UniformRand, Zero}; +use crypto_primitives::{IntRing, crypto_bigint_int::Int}; +use num_integer::Integer; +use num_traits::Zero as NumZero; +use thiserror::Error; +use zinc_utils::{cfg_chunks, cfg_iter}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +const DEFAULT_BOOL_WINDOW_BITS: usize = 6; + +#[derive(Clone, Debug)] +pub struct MsmCommitmentKey { + pub(crate) num_cols: usize, + pub(crate) 
bases: Vec, + pub(crate) h: C::Group, + bool_tables_6: Arc>>, +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub struct MsmVerifierKey { + pub(crate) num_cols: usize, + pub(crate) bases: Vec, + pub(crate) h: C::Group, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MsmCommitment { + pub(crate) comm: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MsmBlind { + pub(crate) blind: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MsmCommitmentEngine(PhantomData); + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum MsmError { + #[error("MSM commitment width must be non-zero")] + InvalidWidth, + #[error("MSM commitment expected {expected} bases, got {actual}")] + BaseCountMismatch { expected: usize, actual: usize }, + #[error("MSM commitment expected {expected} blinds, got {actual}")] + BlindCountMismatch { expected: usize, actual: usize }, + #[error("MSM row length must be at most {max}, got {actual}")] + RowLengthMismatch { max: usize, actual: usize }, + #[error("MSM commitment expected {expected} row commitments, got {actual}")] + CommitmentShapeMismatch { expected: usize, actual: usize }, + #[error("MSM window size must be in 1..usize::BITS, got {0}")] + InvalidWindowBits(usize), + #[error("cannot commit minimum signed integer value")] + SignedIntegerMinimum, +} + +pub trait RowMsmStrategy +where + C: AffineRepr, + V: Copy + Send + Sync, +{ + fn precompute_ck(ck: &MsmCommitmentKey); + + fn msm_row(ck: &MsmCommitmentKey, values: &[V]) -> Result; + + fn is_zero(value: V) -> bool; + + fn to_scalar(value: V) -> C::ScalarField; +} + +pub struct BoolSubsetMsm; +pub struct U8BucketMsm; +pub struct ScalarPippengerMsm; +pub struct SignedIntPippengerMsm; + +#[derive(Clone, Debug)] +struct BoolWindowTable { + tables: Vec>, + lens: Vec, +} + +impl BoolWindowTable { + fn new(bases: &[C], window_bits: usize) -> Self { + let built = cfg_chunks!(bases, window_bits) + .map(|window| { + let len = window.len(); + let table_len = 1usize << 
len; + let mut table = vec![C::Group::zero(); table_len]; + for mask in 1..table_len { + let bit = mask.trailing_zeros() as usize; + let previous = mask & !(1usize << bit); + table[mask] = table[previous] + window[bit]; + } + (table, len) + }) + .collect::>(); + + let (tables, lens) = built.into_iter().unzip(); + Self { tables, lens } + } + + fn msm_row( + &self, + values: &[bool], + window_bits: usize, + _use_parallelism_internally: bool, + ) -> C::Group { + #[cfg(feature = "parallel")] + if _use_parallelism_internally && self.lens.len() > 1 { + return self + .lens + .par_iter() + .copied() + .enumerate() + .map(|(window_idx, len)| { + let offset = window_idx * window_bits; + if offset >= values.len() { + return C::Group::zero(); + } + let end = (offset + len).min(values.len()); + self.tables[window_idx][bit_mask(&values[offset..end])] + }) + .reduce(C::Group::zero, |acc, point| acc + point); + } + + let mut acc = C::Group::zero(); + for (window_idx, len) in self.lens.iter().copied().enumerate() { + let offset = window_idx * window_bits; + if offset >= values.len() { + break; + } + let end = (offset + len).min(values.len()); + acc += self.tables[window_idx][bit_mask(&values[offset..end])]; + } + acc + } + + fn msm_rows_from_window_masks( + &self, + value_len: usize, + window_bits: usize, + _use_parallelism_internally: bool, + mask_at: M, + ) -> [C::Group; LANES] + where + M: Fn(usize, usize) -> [usize; LANES] + Sync, + { + #[cfg(feature = "parallel")] + if _use_parallelism_internally && self.lens.len() > 1 { + return self + .lens + .par_iter() + .copied() + .enumerate() + .map(|(window_idx, len)| { + let mut partial = std::array::from_fn(|_| C::Group::zero()); + let offset = window_idx * window_bits; + if offset >= value_len { + return partial; + } + let end = (offset + len).min(value_len); + let masks = mask_at(offset, end - offset); + for lane in 0..LANES { + partial[lane] += self.tables[window_idx][masks[lane]]; + } + partial + }) + .reduce( + || 
std::array::from_fn(|_| C::Group::zero()), + |mut acc, partial| { + for lane in 0..LANES { + acc[lane] += partial[lane]; + } + acc + }, + ); + } + + let mut acc = std::array::from_fn(|_| C::Group::zero()); + for (window_idx, len) in self.lens.iter().copied().enumerate() { + let offset = window_idx * window_bits; + if offset >= value_len { + break; + } + let end = (offset + len).min(value_len); + let masks = mask_at(offset, end - offset); + for lane in 0..LANES { + acc[lane] += self.tables[window_idx][masks[lane]]; + } + } + acc + } +} + +impl MsmCommitmentEngine { + pub fn setup_from_bases( + width: usize, + bases: Vec, + h: C::Group, + ) -> Result<(MsmCommitmentKey, MsmVerifierKey), MsmError> { + if width == 0 { + return Err(MsmError::InvalidWidth); + } + if bases.len() != width { + return Err(MsmError::BaseCountMismatch { + expected: width, + actual: bases.len(), + }); + } + + let vk = MsmVerifierKey { + num_cols: width, + bases: bases.clone(), + h, + }; + let ck = MsmCommitmentKey { + num_cols: width, + bases, + h, + bool_tables_6: Arc::new(OnceLock::new()), + }; + + Ok((ck, vk)) + } + + pub fn precompute_ck(ck: &MsmCommitmentKey) { + as RowMsmStrategy>::precompute_ck(ck); + ScalarPippengerMsm::precompute_ck(ck); + } + + pub fn blind(ck: &MsmCommitmentKey, n: usize) -> MsmBlind { + let num_rows = num_rows(n, ck.num_cols).unwrap_or(0); + let mut rng = ark_std::rand::thread_rng(); + let blind = (0..num_rows) + .map(|_| C::ScalarField::rand(&mut rng)) + .collect(); + MsmBlind { blind } + } + + pub fn commit_with( + ck: &MsmCommitmentKey, + values: &[V], + blind: &MsmBlind, + ) -> Result, MsmError> + where + V: Copy + Send + Sync, + S: RowMsmStrategy, + { + let expected_rows = num_rows(values.len(), ck.num_cols)?; + if blind.blind.len() != expected_rows { + return Err(MsmError::BlindCountMismatch { + expected: expected_rows, + actual: blind.blind.len(), + }); + } + + S::precompute_ck(ck); + let comm = cfg_chunks!(values, ck.num_cols) + .enumerate() + .map(|(row_idx, 
row)| { + let mut row_comm = commit_row::(ck, row)?; + row_comm += ck.h * blind.blind[row_idx]; + Ok(row_comm) + }) + .collect::, _>>()?; + + Ok(MsmCommitment { comm }) + } + + pub fn commit_unblinded_with( + ck: &MsmCommitmentKey, + values: &[V], + ) -> Result, MsmError> + where + V: Copy + Send + Sync, + S: RowMsmStrategy, + { + let expected_rows = num_rows(values.len(), ck.num_cols)?; + S::precompute_ck(ck); + let comm = cfg_chunks!(values, ck.num_cols) + .map(|row| commit_row::(ck, row)) + .collect::, _>>()?; + debug_assert_eq!(comm.len(), expected_rows); + + Ok(MsmCommitment { comm }) + } + + pub fn commit_unblinded( + ck: &MsmCommitmentKey, + values: &[C::ScalarField], + ) -> Result, MsmError> { + Self::commit_unblinded_with::(ck, values) + } + + pub fn commit( + ck: &MsmCommitmentKey, + values: &[C::ScalarField], + blind: &MsmBlind, + ) -> Result, MsmError> { + Self::commit_with::(ck, values, blind) + } + + pub fn commit_zeros( + ck: &MsmCommitmentKey, + n: usize, + blind: &MsmBlind, + ) -> Result, MsmError> { + let expected_rows = num_rows(n, ck.num_cols)?; + if blind.blind.len() != expected_rows { + return Err(MsmError::BlindCountMismatch { + expected: expected_rows, + actual: blind.blind.len(), + }); + } + + let comm = cfg_iter!(blind.blind).map(|r| ck.h * r).collect(); + Ok(MsmCommitment { comm }) + } + + pub fn check_commitment( + comm: &MsmCommitment, + n: usize, + width: usize, + ) -> Result<(), MsmError> { + let expected_rows = num_rows(n, width)?; + if comm.comm.len() != expected_rows { + return Err(MsmError::CommitmentShapeMismatch { + expected: expected_rows, + actual: comm.comm.len(), + }); + } + Ok(()) + } +} + +impl RowMsmStrategy + for BoolSubsetMsm +{ + fn precompute_ck(ck: &MsmCommitmentKey) { + if WINDOW_BITS == DEFAULT_BOOL_WINDOW_BITS { + ck.bool_tables_6 + .get_or_init(|| BoolWindowTable::new(&ck.bases, DEFAULT_BOOL_WINDOW_BITS)); + } + } + + fn msm_row(ck: &MsmCommitmentKey, values: &[bool]) -> Result { + Self::msm_bool_row(ck, values, 
false) + } + + fn is_zero(value: bool) -> bool { + !value + } + + fn to_scalar(value: bool) -> C::ScalarField { + if value { + C::ScalarField::one() + } else { + C::ScalarField::zero() + } + } +} + +impl BoolSubsetMsm { + pub(crate) fn msm_bool_row( + ck: &MsmCommitmentKey, + values: &[bool], + use_parallelism_internally: bool, + ) -> Result { + validate_row_len(ck, values.len())?; + validate_window_bits(WINDOW_BITS)?; + + if WINDOW_BITS == DEFAULT_BOOL_WINDOW_BITS { + return Ok(ck + .bool_tables_6 + .get_or_init(|| BoolWindowTable::new(&ck.bases, DEFAULT_BOOL_WINDOW_BITS)) + .msm_row(values, DEFAULT_BOOL_WINDOW_BITS, use_parallelism_internally)); + } + + let mut acc = C::Group::zero(); + for (window_idx, bits) in values.chunks(WINDOW_BITS).enumerate() { + let start = window_idx * WINDOW_BITS; + let end = start + bits.len(); + let table = subset_table::(&ck.bases[start..end])?; + acc += table[bit_mask(bits)]; + } + Ok(acc) + } + + pub(crate) fn msm_bool_rows_from_window_masks( + ck: &MsmCommitmentKey, + value_len: usize, + use_parallelism_internally: bool, + mask_at: M, + ) -> Result<[C::Group; LANES], MsmError> + where + C: AffineRepr, + M: Fn(usize, usize) -> [usize; LANES] + Sync, + { + validate_row_len(ck, value_len)?; + validate_window_bits(WINDOW_BITS)?; + + if WINDOW_BITS == DEFAULT_BOOL_WINDOW_BITS { + return Ok(ck + .bool_tables_6 + .get_or_init(|| BoolWindowTable::new(&ck.bases, DEFAULT_BOOL_WINDOW_BITS)) + .msm_rows_from_window_masks( + value_len, + DEFAULT_BOOL_WINDOW_BITS, + use_parallelism_internally, + mask_at, + )); + } + + let mut acc = std::array::from_fn(|_| C::Group::zero()); + for (window_idx, window) in ck.bases[..value_len].chunks(WINDOW_BITS).enumerate() { + let offset = window_idx * WINDOW_BITS; + let masks = mask_at(offset, window.len()); + let table = subset_table::(window)?; + for lane in 0..LANES { + acc[lane] += table[masks[lane]]; + } + } + Ok(acc) + } +} + +impl RowMsmStrategy for U8BucketMsm { + fn precompute_ck(_ck: 
&MsmCommitmentKey) {} + + fn msm_row(ck: &MsmCommitmentKey, values: &[u8]) -> Result { + validate_row_len(ck, values.len())?; + + let max_value = values.iter().copied().max().unwrap_or(0); + if max_value == 0 { + return Ok(C::Group::zero()); + } + + let mut buckets = vec![C::Group::zero(); usize::from(max_value)]; + for (&value, base) in values.iter().zip(ck.bases.iter()) { + if value != 0 { + buckets[usize::from(value) - 1] += base; + } + } + + Ok(bucket_running_sum(&buckets)) + } + + fn is_zero(value: u8) -> bool { + value == 0 + } + + fn to_scalar(value: u8) -> C::ScalarField { + C::ScalarField::from(u64::from(value)) + } +} + +impl RowMsmStrategy for ScalarPippengerMsm { + fn precompute_ck(_ck: &MsmCommitmentKey) {} + + fn msm_row(ck: &MsmCommitmentKey, values: &[C::ScalarField]) -> Result { + validate_row_len(ck, values.len())?; + signed_window_pippenger::(values, &ck.bases[..values.len()]) + } + + fn is_zero(value: C::ScalarField) -> bool { + value.is_zero() + } + + fn to_scalar(value: C::ScalarField) -> C::ScalarField { + value + } +} + +impl RowMsmStrategy> for SignedIntPippengerMsm { + fn precompute_ck(_ck: &MsmCommitmentKey) {} + + fn msm_row(ck: &MsmCommitmentKey, values: &[Int]) -> Result { + validate_row_len(ck, values.len())?; + signed_int_window_pippenger::(values, &ck.bases[..values.len()]) + } + + fn is_zero(value: Int) -> bool { + NumZero::is_zero(&value) + } + + fn to_scalar(value: Int) -> C::ScalarField { + signed_int_to_scalar::(&value) + .expect("signed integer lane value must fit scalar conversion") + } +} + +fn num_rows(n: usize, width: usize) -> Result { + if width == 0 { + return Err(MsmError::InvalidWidth); + } + Ok(::div_ceil(&n, &width)) +} + +fn validate_row_len( + ck: &MsmCommitmentKey, + actual: usize, +) -> Result<(), MsmError> { + if actual > ck.num_cols { + return Err(MsmError::RowLengthMismatch { + max: ck.num_cols, + actual, + }); + } + Ok(()) +} + +fn validate_window_bits(window_bits: usize) -> Result<(), MsmError> { + if 
window_bits == 0 || window_bits >= usize::BITS as usize { + return Err(MsmError::InvalidWindowBits(window_bits)); + } + Ok(()) +} + +fn commit_row(ck: &MsmCommitmentKey, row: &[V]) -> Result +where + C: AffineRepr, + V: Copy + Send + Sync, + S: RowMsmStrategy, +{ + let effective_len = row + .iter() + .rposition(|value| !S::is_zero(*value)) + .map_or(0, |pos| pos + 1); + if effective_len == 0 { + Ok(C::Group::zero()) + } else { + S::msm_row(ck, &row[..effective_len]) + } +} + +fn bit_mask(bits: &[bool]) -> usize { + bits.iter().enumerate().fold( + 0usize, + |mask, (idx, bit)| { + if *bit { mask | (1usize << idx) } else { mask } + }, + ) +} + +fn subset_table(bases: &[C]) -> Result, MsmError> { + validate_window_bits(bases.len())?; + let table_len = 1usize << bases.len(); + let mut table = vec![C::Group::zero(); table_len]; + + for mask in 1..table_len { + let bit = mask.trailing_zeros() as usize; + let previous = mask & !(1usize << bit); + table[mask] = table[previous] + bases[bit]; + } + + Ok(table) +} + +fn bucket_running_sum(buckets: &[G]) -> G { + let mut acc = G::zero(); + let mut running_sum = G::zero(); + for bucket in buckets.iter().rev() { + running_sum += bucket; + acc += running_sum; + } + acc +} + +fn signed_window_pippenger( + scalars: &[C::ScalarField], + bases: &[C], +) -> Result { + if scalars.len() != bases.len() { + return Err(MsmError::BaseCountMismatch { + expected: scalars.len(), + actual: bases.len(), + }); + } + if scalars.is_empty() { + return Ok(C::Group::zero()); + } + + let window_bits = scalar_window_bits(scalars.len()); + validate_window_bits(window_bits)?; + + let num_bits = C::ScalarField::MODULUS_BIT_SIZE as usize; + let segments = ::div_ceil(&num_bits, &window_bits); + let bucket_len = (1usize << window_bits) - 1; + let bigints = scalars + .iter() + .map(|scalar| scalar.into_bigint()) + .collect::>(); + let mut buckets = vec![C::Group::zero(); bucket_len]; + + let mut acc = C::Group::zero(); + for segment in (0..segments).rev() { + 
for _ in 0..window_bits { + acc.double_in_place(); + } + + let offset = segment * window_bits; + for bucket in &mut buckets { + *bucket = C::Group::zero(); + } + for (j, scalar) in bigints.iter().enumerate() { + let digit = window_value_from_limbs(scalar.as_ref(), offset, window_bits); + if digit != 0 { + buckets[digit - 1] += bases[j]; + } + } + + acc += bucket_running_sum(&buckets); + } + + Ok(acc) +} + +fn signed_int_window_pippenger( + values: &[Int], + bases: &[C], +) -> Result { + if values.len() != bases.len() { + return Err(MsmError::BaseCountMismatch { + expected: values.len(), + actual: bases.len(), + }); + } + if values.is_empty() { + return Ok(C::Group::zero()); + } + + let mut max_bits = 0usize; + for value in values { + let (abs, _) = signed_int_abs(value)?; + max_bits = max_bits.max(bit_len_from_words(abs.as_uint().as_words())); + } + if max_bits == 0 { + return Ok(C::Group::zero()); + } + + let window_bits = scalar_window_bits(values.len()).min(max_bits).max(1); + validate_window_bits(window_bits)?; + + let segments = ::div_ceil(&max_bits, &window_bits); + let bucket_len = (1usize << window_bits) - 1; + let mut positive_buckets = vec![C::Group::zero(); bucket_len]; + let mut negative_buckets = vec![C::Group::zero(); bucket_len]; + + let mut acc = C::Group::zero(); + for segment in (0..segments).rev() { + for _ in 0..window_bits { + acc.double_in_place(); + } + + for bucket in &mut positive_buckets { + *bucket = C::Group::zero(); + } + for bucket in &mut negative_buckets { + *bucket = C::Group::zero(); + } + + let offset = segment * window_bits; + for (value, base) in values.iter().zip(bases.iter()) { + let (abs, is_negative) = signed_int_abs(value)?; + let digit = window_value_from_words(abs.as_uint().as_words(), offset, window_bits); + if digit != 0 { + if is_negative { + negative_buckets[digit - 1] += base; + } else { + positive_buckets[digit - 1] += base; + } + } + } + + acc += bucket_running_sum(&positive_buckets); + acc -= 
bucket_running_sum(&negative_buckets); + } + + Ok(acc) +} + +fn signed_int_abs(value: &Int) -> Result<(Int, bool), MsmError> { + if value.is_negative() { + let abs = value.checked_abs().ok_or(MsmError::SignedIntegerMinimum)?; + Ok((abs, true)) + } else { + Ok((*value, false)) + } +} + +fn signed_int_to_scalar( + value: &Int, +) -> Result { + let (abs, is_negative) = signed_int_abs(value)?; + let mut bytes = Vec::with_capacity(LIMBS * core::mem::size_of::()); + for word in abs.as_uint().as_words() { + bytes.extend_from_slice(&word.to_le_bytes()); + } + let mut scalar = C::ScalarField::from_le_bytes_mod_order(&bytes); + if is_negative && !scalar.is_zero() { + scalar = -scalar; + } + Ok(scalar) +} + +fn scalar_window_bits(n: usize) -> usize { + if n < 4 { + 1 + } else if n < 32 { + 3 + } else { + (usize::BITS - n.leading_zeros()) as usize + } +} + +fn bit_len_from_words(words: &[crypto_bigint::Word]) -> usize { + let word_bits = core::mem::size_of::() * 8; + for (idx, word) in words.iter().copied().enumerate().rev() { + if word != 0 { + return idx * word_bits + word_bits - word.leading_zeros() as usize; + } + } + 0 +} + +fn window_value_from_limbs(limbs: &[u64], start: usize, width: usize) -> usize { + (0..width).fold(0usize, |value, bit_idx| { + let absolute_bit = start + bit_idx; + let limb_idx = absolute_bit / u64::BITS as usize; + let limb_bit = absolute_bit % u64::BITS as usize; + if limbs + .get(limb_idx) + .map(|limb| ((limb >> limb_bit) & 1) == 1) + .unwrap_or(false) + { + value | (1usize << bit_idx) + } else { + value + } + }) +} + +fn window_value_from_words(words: &[crypto_bigint::Word], start: usize, width: usize) -> usize { + let word_bits = core::mem::size_of::() * 8; + (0..width).fold(0usize, |value, bit_idx| { + let absolute_bit = start + bit_idx; + let word_idx = absolute_bit / word_bits; + let word_bit = absolute_bit % word_bits; + if words + .get(word_idx) + .map(|word| ((word >> word_bit) & 1) == 1) + .unwrap_or(false) + { + value | (1usize << 
bit_idx) + } else { + value + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use ark_bn254::{Fr, G1Affine, G1Projective}; + use ark_ec::PrimeGroup; + use ark_ff::UniformRand; + + type TestCurve = G1Affine; + + fn fr(value: usize) -> Fr { + Fr::from(u64::try_from(value).expect("test value must fit into u64")) + } + + fn setup(width: usize) -> (MsmCommitmentKey, MsmVerifierKey) { + let generator = G1Projective::generator(); + let bases = (1..=width) + .map(|idx| (generator * fr(idx)).into_affine()) + .collect(); + let h = generator * fr(width + 1); + MsmCommitmentEngine::::setup_from_bases(width, bases, h) + .expect("valid test setup") + } + + fn blind(width: usize, n: usize) -> MsmBlind { + let rows = ::div_ceil(&n, &width); + MsmBlind { + blind: (0..rows).map(|idx| fr(idx + 11)).collect(), + } + } + + fn bool_values(n: usize) -> Vec { + (0..n).map(|idx| idx % 3 == 0 || idx % 7 == 1).collect() + } + + fn u8_values(n: usize, modulus: u8) -> Vec { + (0..n) + .map(|idx| { + let value = (idx * 17 + 5) % usize::from(modulus); + u8::try_from(value).expect("test u8 value must fit") + }) + .collect() + } + + fn scalars_from_bool(values: &[bool]) -> Vec { + values + .iter() + .map(|value| if *value { Fr::one() } else { Fr::zero() }) + .collect() + } + + fn scalars_from_u8(values: &[u8]) -> Vec { + values + .iter() + .map(|value| Fr::from(u64::from(*value))) + .collect() + } + + fn scalars_from_int(values: &[Int]) -> Vec { + values + .iter() + .map(|value| signed_int_to_scalar::(value).expect("valid test int")) + .collect() + } + + fn naive_scalar_commit( + ck: &MsmCommitmentKey, + values: &[Fr], + blind: &MsmBlind, + ) -> MsmCommitment { + let comm = values + .chunks(ck.num_cols) + .enumerate() + .map(|(row_idx, row)| { + let mut acc = G1Projective::zero(); + for (scalar, base) in row.iter().zip(ck.bases.iter()) { + acc += *base * scalar; + } + acc += ck.h * blind.blind[row_idx]; + acc + }) + .collect(); + MsmCommitment { comm } + } + + fn 
naive_scalar_commit_unblinded( + ck: &MsmCommitmentKey, + values: &[Fr], + ) -> MsmCommitment { + let comm = values + .chunks(ck.num_cols) + .map(|row| { + let mut acc = G1Projective::zero(); + for (scalar, base) in row.iter().zip(ck.bases.iter()) { + acc += *base * scalar; + } + acc + }) + .collect(); + MsmCommitment { comm } + } + + #[test] + fn bool_commit_matches_scalar_commit_for_configured_widths() { + for width in [8, 32, 64] { + let (ck, _) = setup(width); + let n = width * 3 + 5; + let values = bool_values(n); + let scalars = scalars_from_bool(&values); + let blind = blind(width, n); + + let bool_comm = + MsmCommitmentEngine::::commit_with::>( + &ck, &values, &blind, + ) + .expect("bool commit must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &scalars, &blind) + .expect("scalar commit must succeed"); + + assert_eq!(bool_comm, scalar_comm); + } + } + + #[test] + fn precomputed_bool_commit_matches_scalar_commit_for_wide_rows() { + for width in [8, 32, 64, 512] { + let (ck, _) = setup(width); + let n = width + 5; + let values = bool_values(n); + let scalars = scalars_from_bool(&values); + let blind = blind(width, n); + + let before_precompute = MsmCommitmentEngine::::commit_with::< + bool, + BoolSubsetMsm<6>, + >(&ck, &values, &blind) + .expect("bool commit before precompute must succeed"); + MsmCommitmentEngine::::precompute_ck(&ck); + let after_precompute = MsmCommitmentEngine::::commit_with::< + bool, + BoolSubsetMsm<6>, + >(&ck, &values, &blind) + .expect("bool commit after precompute must succeed"); + let cloned_ck = ck.clone(); + let after_clone = + MsmCommitmentEngine::::commit_with::>( + &cloned_ck, &values, &blind, + ) + .expect("bool commit through cloned ck must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &scalars, &blind) + .expect("scalar commit must succeed"); + + assert_eq!(before_precompute, scalar_comm); + assert_eq!(after_precompute, scalar_comm); + assert_eq!(after_clone, scalar_comm); + } + } + + 
#[test] + fn unblinded_bool_commit_matches_scalar_commit_for_wide_rows() { + for width in [8, 32, 64, 512] { + let (ck, _) = setup(width); + let n = width + 7; + let values = bool_values(n); + let scalars = scalars_from_bool(&values); + + MsmCommitmentEngine::::precompute_ck(&ck); + let bool_comm = MsmCommitmentEngine::::commit_unblinded_with::< + bool, + BoolSubsetMsm<6>, + >(&ck, &values) + .expect("unblinded bool commit must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit_unblinded(&ck, &scalars) + .expect("unblinded scalar commit must succeed"); + let naive_comm = naive_scalar_commit_unblinded(&ck, &scalars); + + assert_eq!(bool_comm, scalar_comm); + assert_eq!(bool_comm, naive_comm); + } + } + + #[test] + fn u8_commit_matches_scalar_commit_for_configured_widths() { + for width in [8, 32, 64] { + let (ck, _) = setup(width); + let n = width * 2 + width / 2 + 1; + let cases = [vec![0; n], vec![1; n], u8_values(n, 32), u8_values(n, 255)]; + + for values in cases { + let scalars = scalars_from_u8(&values); + let blind = blind(width, n); + + let u8_comm = MsmCommitmentEngine::::commit_with::( + &ck, &values, &blind, + ) + .expect("u8 commit must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &scalars, &blind) + .expect("scalar commit must succeed"); + + assert_eq!(u8_comm, scalar_comm); + } + } + } + + #[test] + fn scalar_commit_matches_naive_full_field_commit_for_configured_widths() { + let mut rng = ark_std::test_rng(); + for width in [8, 32, 64] { + let (ck, _) = setup(width); + let n = width * 2 + 3; + let values = (0..n).map(|_| Fr::rand(&mut rng)).collect::>(); + let blind = blind(width, n); + + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &values, &blind) + .expect("scalar commit must succeed"); + let naive_comm = naive_scalar_commit(&ck, &values, &blind); + + assert_eq!(scalar_comm, naive_comm); + } + } + + #[test] + fn signed_int_commit_matches_scalar_commit_for_small_values() { + for width in [8, 32, 64] { + let (ck, 
_) = setup(width); + let n = width * 2 + 5; + let values = (0..n) + .map(|idx| Int::<1>::from((idx as i64 % 31) - 15)) + .collect::>(); + let scalars = scalars_from_int(&values); + let blind = blind(width, n); + + let int_comm = MsmCommitmentEngine::::commit_with::< + Int<1>, + SignedIntPippengerMsm, + >(&ck, &values, &blind) + .expect("signed int commit must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &scalars, &blind) + .expect("scalar commit must succeed"); + + assert_eq!(int_comm, scalar_comm); + } + } + + #[test] + fn commit_zeros_matches_strategy_zero_paths() { + let width = 32; + let n = width * 2 + 9; + let (ck, _) = setup(width); + let blind = blind(width, n); + let zeros_bool = vec![false; n]; + let zeros_u8 = vec![0u8; n]; + let zeros_scalar = vec![Fr::zero(); n]; + + let zero_comm = MsmCommitmentEngine::::commit_zeros(&ck, n, &blind) + .expect("zero commit must succeed"); + let bool_comm = MsmCommitmentEngine::::commit_with::>( + &ck, + &zeros_bool, + &blind, + ) + .expect("bool zero commit must succeed"); + let u8_comm = MsmCommitmentEngine::::commit_with::( + &ck, &zeros_u8, &blind, + ) + .expect("u8 zero commit must succeed"); + let scalar_comm = MsmCommitmentEngine::::commit(&ck, &zeros_scalar, &blind) + .expect("scalar zero commit must succeed"); + + assert_eq!(zero_comm, bool_comm); + assert_eq!(zero_comm, u8_comm); + assert_eq!(zero_comm, scalar_comm); + } + + #[test] + fn changing_one_blind_changes_only_that_row_by_delta_h() { + let width = 16; + let n = width * 3; + let (ck, _) = setup(width); + let values = u8_values(n, 32); + let mut blind_a = blind(width, n); + let mut blind_b = blind_a.clone(); + let delta = fr(5); + blind_b.blind[1] += delta; + + let comm_a = MsmCommitmentEngine::::commit_with::( + &ck, &values, &blind_a, + ) + .expect("first commit must succeed"); + let comm_b = MsmCommitmentEngine::::commit_with::( + &ck, &values, &blind_b, + ) + .expect("second commit must succeed"); + + for row_idx in 
0..comm_a.comm.len() { + let actual_delta = comm_b.comm[row_idx] - comm_a.comm[row_idx]; + let expected_delta = if row_idx == 1 { + ck.h * delta + } else { + G1Projective::zero() + }; + assert_eq!(actual_delta, expected_delta); + } + + blind_a.blind[1] += delta; + assert_eq!(blind_a, blind_b); + } + + #[test] + fn rejects_invalid_shapes() { + let width = 8; + let (ck, _) = setup(width); + let n = 17; + let values = vec![Fr::one(); n]; + let blind = blind(width, n); + + assert!(matches!( + MsmCommitmentEngine::::setup_from_bases(0, Vec::new(), G1Projective::zero()), + Err(MsmError::InvalidWidth) + )); + assert!(matches!( + MsmCommitmentEngine::::setup_from_bases( + width, + vec![G1Affine::generator(); width - 1], + G1Projective::generator(), + ), + Err(MsmError::BaseCountMismatch { .. }) + )); + + let short_blind = MsmBlind { + blind: blind.blind[..1].to_vec(), + }; + assert!(matches!( + MsmCommitmentEngine::::commit(&ck, &values, &short_blind), + Err(MsmError::BlindCountMismatch { .. }) + )); + + let comm = MsmCommitment { comm: Vec::new() }; + assert!(matches!( + MsmCommitmentEngine::::check_commitment(&comm, n, width), + Err(MsmError::CommitmentShapeMismatch { .. 
}) + )); + assert!(matches!( + MsmCommitmentEngine::::check_commitment(&comm, n, 0), + Err(MsmError::InvalidWidth) + )); + } +} diff --git a/zip-plus/src/pcs/multi_zip.rs b/zip-plus/src/pcs/multi_zip.rs index 18d39b34..0ce220d8 100644 --- a/zip-plus/src/pcs/multi_zip.rs +++ b/zip-plus/src/pcs/multi_zip.rs @@ -35,7 +35,8 @@ use std::marker::PhantomData; use zinc_poly::mle::DenseMultilinearExtension; use zinc_transcript::traits::Transcribable; use zinc_utils::{ - cfg_into_iter, cfg_iter, cfg_join, from_ref::FromRef, mul_by_scalar::MulByScalar, + cfg_into_iter, cfg_iter, cfg_join, delayed_reduction::DelayedFieldProductSum, + from_ref::FromRef, mul_by_scalar::MulByScalar, }; /// Full prover-side data for a [`MultiZip3`] commitment: per-instance @@ -50,9 +51,7 @@ pub struct MultiZipHint3 { /// Three-instance Zip+ wrapper sharing a single Merkle tree across three /// independent Zip+ commitments. -pub struct MultiZip3( - PhantomData<(Zt0, Zt1, Zt2, Lc0, Lc1, Lc2)>, -) +pub struct MultiZip3(PhantomData<(Zt0, Zt1, Zt2, Lc0, Lc1, Lc2)>) where Zt0: ZipTypes, Zt1: ZipTypes, @@ -94,8 +93,8 @@ where ), ZipError, > { - let nonempty = (!polys0.is_empty()) as u8 + (!polys1.is_empty()) as u8 - + (!polys2.is_empty()) as u8; + let nonempty = + (!polys0.is_empty()) as u8 + (!polys1.is_empty()) as u8 + (!polys2.is_empty()) as u8; assert!( nonempty >= 2, "MultiZip3::commit requires at least two non-empty batches \ @@ -208,6 +207,7 @@ where ) -> Result<(Option, Option, Option), ZipError> where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt0::CombR> + for<'a> FromWithConfig<&'a Zt1::CombR> + for<'a> FromWithConfig<&'a Zt2::CombR> @@ -239,23 +239,26 @@ where let eval0 = if polys0.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp0, polys0, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp0, polys0, point, field_cfg)?) 
}; let eval1 = if polys1.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp1, polys1, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp1, polys1, point, field_cfg)?) }; let eval2 = if polys2.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp2, polys2, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp2, polys2, point, field_cfg)?) }; for _ in 0..Zt0::NUM_COLUMN_OPENINGS { @@ -310,6 +313,7 @@ where ) -> Result<(Option, Option, Option), ZipError> where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt0::CombR> + for<'a> FromWithConfig<&'a Zt1::CombR> + for<'a> FromWithConfig<&'a Zt2::CombR> @@ -353,9 +357,10 @@ where let eval0 = if polys0.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp0, polys0, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp0, polys0, point, field_cfg)?) }; let p1 = pos(transcript); bd0.combined_row.extend(snapshot(transcript, p0, p1)); @@ -363,9 +368,10 @@ where let eval1 = if polys1.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp1, polys1, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp1, polys1, point, field_cfg)?) }; let p2 = pos(transcript); bd1.combined_row.extend(snapshot(transcript, p1, p2)); @@ -373,9 +379,10 @@ where let eval2 = if polys2.is_empty() { None } else { - Some(ZipPlus::::prove_pre_open_f::( - transcript, pp2, polys2, point, field_cfg, - )?) + Some(ZipPlus::::prove_pre_open_f::< + F, + CHECK_FOR_OVERFLOW, + >(transcript, pp2, polys2, point, field_cfg)?) 
}; let p3 = pos(transcript); bd2.combined_row.extend(snapshot(transcript, p2, p3)); diff --git a/zip-plus/src/pcs/phase_prove.rs b/zip-plus/src/pcs/phase_prove.rs index 30f5232c..1bff475f 100644 --- a/zip-plus/src/pcs/phase_prove.rs +++ b/zip-plus/src/pcs/phase_prove.rs @@ -51,8 +51,9 @@ use zinc_poly::{Polynomial, mle::DenseMultilinearExtension}; use zinc_transcript::traits::{Transcribable, Transcript}; use zinc_utils::{ UNCHECKED, cfg_chunks, cfg_iter, cfg_iter_mut, + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, - inner_product::{InnerProduct, MBSInnerProduct}, + inner_product::{FieldFieldInnerProduct, InnerProduct, MBSInnerProduct}, mul_by_scalar::MulByScalar, }; @@ -127,6 +128,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Pt> + for<'a> MulByScalar<&'a F> @@ -161,6 +163,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> MulByScalar<&'a F> + FromRef, @@ -194,6 +197,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> MulByScalar<&'a F> + FromRef, @@ -223,6 +227,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> MulByScalar<&'a F> + FromRef, @@ -297,7 +302,7 @@ impl> ZipPlus { } // Compute eval = (inner product in field), in paper // It is safe to use inner_product_unchecked because we're in a field. 
- let eval = MBSInnerProduct::inner_product::(&q_0, &b, zero_f.clone())?; + let eval = FieldFieldInnerProduct::inner_product::(&q_0, &b, zero_f.clone())?; // Matrix-vector product over the flat poly_comb_r layout: // Each poly is a row-major (num_rows x row_len) matrix, and coeffs is the @@ -377,6 +382,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Chal> + for<'a> FromWithConfig<&'a Zt::Pt> @@ -413,6 +419,7 @@ impl> ZipPlus { ) -> Result where F: PrimeField + + DelayedFieldProductSum + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> MulByScalar<&'a F> + FromRef, @@ -473,12 +480,14 @@ impl> ZipPlus { }; transcript.write_field_elements(&b)?; - let eval = MBSInnerProduct::inner_product::(&q_0, &b, zero_f.clone())?; + let eval = FieldFieldInnerProduct::inner_product::(&q_0, &b, zero_f.clone())?; let coeffs = if pp.num_rows == 1 { vec![Zt::Chal::ONE] } else { - transcript.fs_transcript.get_challenges::(num_rows) + transcript + .fs_transcript + .get_challenges::(num_rows) }; let combined_row: Vec = { diff --git a/zip-plus/src/pcs/phase_verify.rs b/zip-plus/src/pcs/phase_verify.rs index 73ca3db9..c3cf9ede 100644 --- a/zip-plus/src/pcs/phase_verify.rs +++ b/zip-plus/src/pcs/phase_verify.rs @@ -19,8 +19,9 @@ use zinc_transcript::{ }; use zinc_utils::{ UNCHECKED, add, cfg_into_iter, + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, - inner_product::{InnerProduct, MBSInnerProduct}, + inner_product::{FieldFieldInnerProduct, InnerProduct, MBSInnerProduct}, mul_by_scalar::MulByScalar, }; @@ -126,6 +127,7 @@ impl> ZipPlus { ) -> Result<(), ZipError> where F: FromPrimitiveWithConfig + + DelayedFieldProductSum + FromRef + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Chal> @@ -164,6 +166,7 @@ impl> ZipPlus { ) -> Result<(), ZipError> where F: FromPrimitiveWithConfig + + DelayedFieldProductSum + FromRef + for<'a> FromWithConfig<&'a 
Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Chal> @@ -192,7 +195,8 @@ impl> ZipPlus { let b: Vec = transcript.read_field_elements(num_rows)?; // Check 1: == eval_f - if MBSInnerProduct::inner_product::(&q_0, &b, zero_f.clone())? != *eval_f { + if FieldFieldInnerProduct::inner_product::(&q_0, &b, zero_f.clone())? != *eval_f + { return Err(ZipError::InvalidPcsOpen( "Evaluation consistency failure".into(), )); @@ -273,6 +277,7 @@ impl> ZipPlus { ) -> Result, ZipError> where F: FromPrimitiveWithConfig + + DelayedFieldProductSum + FromRef + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Chal> @@ -348,6 +353,7 @@ impl> ZipPlus { ) -> Result, ZipError> where F: FromPrimitiveWithConfig + + DelayedFieldProductSum + FromRef + for<'a> FromWithConfig<&'a Zt::CombR> + for<'a> FromWithConfig<&'a Zt::Chal> @@ -374,7 +380,8 @@ impl> ZipPlus { let (q_0, q_1) = point_to_tensor(vp.num_rows, point_f, field_cfg)?; let zero_f = F::zero_with_cfg(field_cfg); - if MBSInnerProduct::inner_product::(&q_0, &b, zero_f.clone())? != *eval_f { + if FieldFieldInnerProduct::inner_product::(&q_0, &b, zero_f.clone())? 
!= *eval_f + { return Err(ZipError::InvalidPcsOpen( "Evaluation consistency failure".into(), )); diff --git a/zip-plus/src/pcs/test_utils.rs b/zip-plus/src/pcs/test_utils.rs index ef27b714..012de393 100644 --- a/zip-plus/src/pcs/test_utils.rs +++ b/zip-plus/src/pcs/test_utils.rs @@ -31,6 +31,7 @@ use zinc_primality::MillerRabin; use zinc_transcript::traits::{Transcribable, Transcript}; use zinc_utils::{ CHECKED, + delayed_reduction::DelayedFieldProductSum, from_ref::FromRef, inner_product::{MBSInnerProduct, ScalarProduct}, mul_by_scalar::MulByScalar, @@ -168,6 +169,7 @@ where + for<'a> FromWithConfig<&'a as ZipTypes>::Chal> + for<'a> FromWithConfig<&'a as ZipTypes>::CombR> + for<'a> MulByScalar<&'a F> + + DelayedFieldProductSum + FromRef, F::Inner: Transcribable, F::Modulus: FromRef< as ZipTypes>::Fmod> + Transcribable, @@ -201,6 +203,7 @@ where + for<'a> FromWithConfig<&'a as ZipTypes>::Chal> + for<'a> FromWithConfig<&'a as ZipTypes>::CombR> + for<'a> MulByScalar<&'a F> + + DelayedFieldProductSum + FromRef + 'static, F::Inner: Transcribable, @@ -231,6 +234,7 @@ where + for<'a> FromWithConfig<&'a Zt::Chal> + for<'a> FromWithConfig<&'a Zt::Pt> + for<'a> MulByScalar<&'a F> + + DelayedFieldProductSum + FromRef, F::Inner: Transcribable, F::Modulus: FromRef + Transcribable, diff --git a/zip-plus/src/utils.rs b/zip-plus/src/utils.rs index cef1268a..3ae73c5e 100644 --- a/zip-plus/src/utils.rs +++ b/zip-plus/src/utils.rs @@ -35,9 +35,7 @@ pub const ZSTD_LEVEL: i32 = 3; /// compression step (excluding serialization). Useful for callers /// that want to attribute the compression cost to a step in a /// timings breakdown. 
-pub fn serialize_and_compress( - value: &T, -) -> (Vec, std::time::Duration) { +pub fn serialize_and_compress(value: &T) -> (Vec, std::time::Duration) { let mut buf = vec![0_u8; value.get_num_bytes()]; value.write_transcription_bytes_exact(&mut buf); let t0 = std::time::Instant::now(); @@ -75,7 +73,11 @@ pub fn eprint_bytes_size(label: impl std::fmt::Display, raw: &[u8]) { print_size!(format_args!("zstd-{ZSTD_LEVEL}"), compressed.len()); let decompressed = zstd::decode_all(&compressed[..]).expect("zstd decompression failed"); - assert_eq!(decompressed.len(), raw.len(), "zstd round-trip size mismatch"); + assert_eq!( + decompressed.len(), + raw.len(), + "zstd round-trip size mismatch" + ); } /// Prints a per-part proof size breakdown (raw + zstd-compressed) to stderr.