diff --git a/.github/workflows/backend-tests.yml b/.github/workflows/backend-tests.yml index 48c805925..48065fd8a 100644 --- a/.github/workflows/backend-tests.yml +++ b/.github/workflows/backend-tests.yml @@ -79,7 +79,7 @@ jobs: - backend: zkcrypto support_wasm: true support_ckzg: true - clippy-flag: --all-features + clippy-flag: --features=default,std,rand,parallel - backend: arkworks4 support_wasm: true support_ckzg: true diff --git a/Cargo.lock b/Cargo.lock index fa82b8a07..fd2e14183 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1874,6 +1874,7 @@ dependencies = [ "once_cell", "rand 0.8.5", "rayon", + "rust-kzg-blst", "smallvec", ] diff --git a/arkworks4/Cargo.toml b/arkworks4/Cargo.toml index 6ba3cc082..14adfd7b4 100644 --- a/arkworks4/Cargo.toml +++ b/arkworks4/Cargo.toml @@ -51,6 +51,9 @@ wbits = [ arkmsm = [ "kzg/arkmsm" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] diskcache = [ "kzg/diskcache" diff --git a/arkworks5/Cargo.toml b/arkworks5/Cargo.toml index 9dc2d1058..4134a388c 100644 --- a/arkworks5/Cargo.toml +++ b/arkworks5/Cargo.toml @@ -51,6 +51,9 @@ arkmsm = [ wbits = [ "kzg/wbits" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] diskcache = [ "kzg/diskcache" diff --git a/blst/Cargo.toml b/blst/Cargo.toml index b41b6ba20..04b3dd61c 100644 --- a/blst/Cargo.toml +++ b/blst/Cargo.toml @@ -54,6 +54,9 @@ sppark = [ wbits = [ "kzg/wbits" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] diskcache = [ "kzg/diskcache" diff --git a/constantine/Cargo.toml b/constantine/Cargo.toml index 32c74339d..97202b989 100644 --- a/constantine/Cargo.toml +++ b/constantine/Cargo.toml @@ -21,6 +21,7 @@ arbitrary = "1.4.2" criterion = "0.5.1" kzg-bench = { path = "../kzg-bench" } rand = "0.8.5" +rust-kzg-blst = { path = "../blst", default-features = false, features = ["std", "rand"] } [features] default = [ @@ -52,7 +53,13 @@ arkmsm = [ wbits = [ "kzg/wbits" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] +diskcache = [ + "kzg/diskcache" +] [[bench]] name = "das" diff --git a/constantine/src/types/g1.rs b/constantine/src/types/g1.rs index d31fde2a9..b4fd0f9f2 100644 --- a/constantine/src/types/g1.rs +++ b/constantine/src/types/g1.rs @@ -458,11 +458,91 @@ impl G1Affine for CtG1Affine { } fn to_bytes_uncompressed(&self) -> [u8; 96] { - todo!() + let mut out = [0u8; 96]; + + // Check if point is infinity + if self.is_infinity() { + // Set infinity flag (bit 6) in first byte + out[0] = 0x40; + return out; + } + + // Serialize: 48 bytes x (big-endian) || 48 bytes y (big-endian) + // limbs are stored in little-endian, so limbs[5] is most significant + for i in 0..6 { + let bytes = self.0.x.limbs[5 - i].to_be_bytes(); + out[i * 8..(i + 1) * 8].copy_from_slice(&bytes); + } + for i in 0..6 { + let bytes = self.0.y.limbs[5 - i].to_be_bytes(); + out[48 + i * 8..48 + (i + 1) * 8].copy_from_slice(&bytes); + } + + out } - fn from_bytes_uncompressed(_bytes: [u8; 96]) -> Result { - todo!() + fn from_bytes_uncompressed(bytes: [u8; 96]) -> Result { + // Check flags in first byte + let compression_flag = bytes[0] & 0x80; // most-significant bit + let infinity_flag = bytes[0] & 0x40; // second most-significant bit + let sort_flag = bytes[0] & 0x20; // third most-significant bit + + // For uncompressed, compression bit must be 0 + if compression_flag != 0 { + return Err("Compression flag set for uncompressed encoding".to_string()); + } + + // Sort flag must be 0 for uncompressed + if sort_flag != 0 { + return Err("Sort flag must be 0 for uncompressed encoding".to_string()); + } + + // Handle infinity point + if infinity_flag != 0 { + // All other bits (except flags) must be zero for infinity + if bytes[0] & 0x1f != 0 || bytes[1..].iter().any(|&b| b != 0) { + return Err("Invalid infinity encoding".to_string()); + } + return Ok(Self::zero()); + } + + let mut x_limbs: [usize; 6] = [0; 6]; + let mut y_limbs: [usize; 6] = [0; 6]; + + // Deserialize: bytes come in big-endian + // We need to store them in little-endian limbs array + // First limb needs to have flag bits cleared + for i in 0..6 { + let mut limb_bytes = [0u8; 8]; + limb_bytes.copy_from_slice(&bytes[i * 8..(i + 1) * 8]); + let mut limb_value = usize::from_be_bytes(limb_bytes); + // Clear top 3 flag bits from the first limb (most significant) + if i == 0 { + limb_value &= 0x1fffffffffffffff; // Clear bits 63, 62, 61 + } + x_limbs[5 - i] = limb_value; + } + for i in 0..6 { + let mut limb_bytes = [0u8; 8]; + limb_bytes.copy_from_slice(&bytes[48 + i * 8..48 + (i + 1) * 8]); + y_limbs[5 - i] = usize::from_be_bytes(limb_bytes); + } + + let tmp = bls12_381_g1_aff { + x: bls12_381_fp { limbs: x_limbs }, + y: bls12_381_fp { limbs: y_limbs }, + }; + + // Validate point is on curve + unsafe { + match constantine::ctt_bls12_381_validate_g1(&tmp) { + ctt_codec_ecc_status::cttCodecEcc_Success => Ok(CtG1Affine(tmp)), + ctt_codec_ecc_status::cttCodecEcc_PointAtInfinity => { + Err("Point at infinity should have infinity flag set".to_string()) + } + _ => Err("Point is not on the curve".to_string()), + } + } } } diff --git a/constantine/tests/mod.rs b/constantine/tests/mod.rs index d1b7273bd..9a100c52f 100644 --- a/constantine/tests/mod.rs +++ b/constantine/tests/mod.rs @@ -1 +1,3 @@ pub mod local_tests; + +mod serialization; diff --git a/constantine/tests/serialization.rs b/constantine/tests/serialization.rs new file mode 100644 index 000000000..6ca6b0660 --- /dev/null +++ b/constantine/tests/serialization.rs @@ -0,0 +1,270 @@ +#[cfg(test)] +mod tests { + use kzg::{Fr, G1Affine, G1Mul, G1}; + use rust_kzg_constantine::types::fr::CtFr; + use rust_kzg_constantine::types::g1::{CtG1, CtG1Affine}; + + #[test] + fn test_uncompressed_serialization_roundtrip() { + // Test with generator + let point = CtG1::generator(); + let affine = CtG1Affine::into_affine(&point); + + let bytes = affine.to_bytes_uncompressed(); + let recovered = CtG1Affine::from_bytes_uncompressed(bytes).expect("Failed to deserialize"); + + assert_eq!(affine, recovered, "Generator roundtrip failed"); + } + + #[test] + fn test_uncompressed_serialization_infinity() { + let point = CtG1Affine::zero(); + + let bytes = point.to_bytes_uncompressed(); + + // Check that infinity flag is set (bit 6 of first byte) + assert_eq!(bytes[0], 0x40, "Infinity flag not set correctly"); + + // All other bytes should be zero + for &byte in &bytes[1..] { + assert_eq!(byte, 0, "Non-flag bytes should be zero for infinity"); + } + + let recovered = + CtG1Affine::from_bytes_uncompressed(bytes).expect("Failed to deserialize infinity"); + assert!( + recovered.is_infinity(), + "Deserialized point should be infinity" + ); + } + + #[test] + fn test_uncompressed_serialization_random_points() { + // Test multiple random points + for i in 1..10 { + let scalar = CtFr::from_u64(i * 12345 + 67890); + let point = CtG1::generator().mul(&scalar); + let affine = CtG1Affine::into_affine(&point); + + let bytes = affine.to_bytes_uncompressed(); + let recovered = CtG1Affine::from_bytes_uncompressed(bytes) + .unwrap_or_else(|e| panic!("Failed to deserialize point {}: {}", i, e)); + + assert_eq!(affine, recovered, "Roundtrip failed for point {}", i); + } + } + + #[test] + fn test_uncompressed_with_high_bit_coordinates() { + // Create points where coordinates might have high bits set + // Use large scalar multipliers to get varied coordinate values + let large_scalar = CtFr::from_u64(u64::MAX - 1); + let point = CtG1::generator().mul(&large_scalar); + let affine = CtG1Affine::into_affine(&point); + + let bytes = affine.to_bytes_uncompressed(); + + // Verify that flag bits are NOT set in serialized form for non-infinity point + // (compression bit should be 0, infinity bit should be 0) + assert_eq!( + bytes[0] & 0xC0, + 0, + "Unexpected flag bits set for regular point" + ); + + let recovered = CtG1Affine::from_bytes_uncompressed(bytes) + .expect("Failed to deserialize point with high bit coordinates"); + + assert_eq!( + affine, recovered, + "Roundtrip failed for point with high bit coordinates" + ); + } + + #[test] + fn test_known_generator_bytes() { + // Get the actual serialization from constantine for the generator + // This serves as a regression test - the bytes should remain consistent + let generator = CtG1::generator(); + let affine = CtG1Affine::into_affine(&generator); + let bytes = affine.to_bytes_uncompressed(); + + // Verify it deserializes correctly + let recovered = CtG1Affine::from_bytes_uncompressed(bytes) + .expect("Failed to deserialize generator bytes"); + assert_eq!(affine, recovered, "Generator round-trip failed"); + + // Verify no flags are set (should be 0 for uncompressed, non-infinity point) + assert_eq!( + bytes[0] & 0xE0, + 0, + "Unexpected flags set in generator serialization" + ); + + // Print bytes for cross-backend verification (useful for manual testing) + #[cfg(feature = "std")] + { + println!("Generator bytes (hex): {}", hex::encode(&bytes[..])); + println!("First 8 bytes: {:02x?}", &bytes[0..8]); + println!("Bytes 48-56: {:02x?}", &bytes[48..56]); + } + } + + #[test] + fn test_known_infinity_bytes() { + // Infinity point should serialize to all zeros except the infinity flag + let expected_bytes = { + let mut bytes = [0u8; 96]; + bytes[0] = 0x40; // Only infinity flag set + bytes + }; + + let infinity = CtG1Affine::zero(); + let bytes = infinity.to_bytes_uncompressed(); + + assert_eq!( + bytes, expected_bytes, + "Infinity serialization doesn't match expected bytes" + ); + + // Verify round-trip + let recovered = CtG1Affine::from_bytes_uncompressed(expected_bytes) + .expect("Failed to deserialize known infinity bytes"); + assert!( + recovered.is_infinity(), + "Recovered point should be infinity" + ); + } + + // // generate points to replace hex::decode in test_uncompressed_known_points() + // #[test] + // fn print_known_points() { + // for k in [5u64, 7, 1235, 9999] { + // let point = CtG1::generator().mul(&CtFr::from_u64(k)); + // let affine = CtG1Affine::into_affine(&point); + // let bytes = affine.to_bytes_uncompressed(); + // println!("g * {} = {}", k, hex::encode(bytes)); + // } + // } + + #[test] + fn test_uncompressed_known_points() { + // g * 5 + let point = CtG1::generator().mul(&CtFr::from_u64(5)); + let affine = CtG1Affine::into_affine(&point); + let bytes = affine.to_bytes_uncompressed(); + let expected = hex::decode("0befb962052d5be4fa0cd24153cb8d710593bb36cbf5d7d289f3afc44b4fd7f2b031cca2d46d20db6c3aea956edc0d65149a008e9c0217f87906a800fc343bbb83af773c28f5fce6f978a25d58dd410239ff5eb5d2a4f6b9afc43ba27a23863c").unwrap(); + assert_eq!(bytes.to_vec(), expected); + assert!(affine.eq(&CtG1Affine::from_bytes_uncompressed(bytes).unwrap())); + + //g * 7 + let point = CtG1::generator().mul(&CtFr::from_u64(7)); + let affine = CtG1Affine::into_affine(&point); + let bytes = affine.to_bytes_uncompressed(); + let expected = hex::decode("0d3056fc0db4365ff29c6756cc2dd13a5508d390e218ff49a8588f9235e2e40d018298254a48192dbf6f80fad9849c75092ab1a757dc00ed54b15b24176ee9d24869c5dda756bdfea158f28b5eff937fed6b86ac4a437cb894ceeaaf25173e97").unwrap(); + assert_eq!(bytes.to_vec(), expected); + assert!(affine.eq(&CtG1Affine::from_bytes_uncompressed(bytes).unwrap())); + + // g * 1235 + let point = CtG1::generator().mul(&CtFr::from_u64(1235)); + let affine = CtG1Affine::into_affine(&point); + let bytes = affine.to_bytes_uncompressed(); + let expected = hex::decode("18b5fee2a1cdf7726124be509790e12d6169df13c2fdf7bd127a0872fe117575426f9fd0464c15b53f766c1861d510f312c51fdeaff5ec6ae20025033bcb4d6ed6ed064abfbbad523187ec727c6f858cfda248ff03da74e292523519a1a4d56e").unwrap(); + assert_eq!(bytes.to_vec(), expected); + assert!(affine.eq(&CtG1Affine::from_bytes_uncompressed(bytes).unwrap())); + + // g * 9999 + let point = CtG1::generator().mul(&CtFr::from_u64(9999)); + let affine = CtG1Affine::into_affine(&point); + let bytes = affine.to_bytes_uncompressed(); + let expected = hex::decode("0da245189a3242ca54c7397d77285fc0252722b8a53366efd1080f5752128206194d7d882b2a3becd6a52cf1faebf0e8198ee528c3360a55cfe7ca3e92c75546dd5800bb46f204e8ddc8491dfefb4b37e57c024a45e58c75aba1797c439ae799").unwrap(); + assert_eq!(bytes.to_vec(), expected); + assert!(affine.eq(&CtG1Affine::from_bytes_uncompressed(bytes).unwrap())); + } + + #[test] + fn test_msb_handling() { + // BLS12-381 field prime is 381 bits, so top 3 bits of 384-bit representation are unused + // The field modulus is: 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab + // This means valid field elements can have bits set up to bit 380 + + use kzg::{Fr, G1Mul}; + + // Create points with large scalar multipliers to get varied coordinates + for scalar_val in [u64::MAX, u64::MAX - 1, u64::MAX / 2, 1u64 << 63] { + let scalar = CtFr::from_u64(scalar_val); + let point = CtG1::generator().mul(&scalar); + let affine = CtG1Affine::into_affine(&point); + + let bytes = affine.to_bytes_uncompressed(); + + // Verify top 3 bits are clear (flag bits) + assert_eq!( + bytes[0] & 0xE0, + 0, + "Top 3 bits should be clear for scalar {}", + scalar_val + ); + + // Verify round-trip + let recovered = CtG1Affine::from_bytes_uncompressed(bytes).unwrap_or_else(|e| { + panic!( + "Failed to deserialize point with scalar {}: {}", + scalar_val, e + ) + }); + + assert_eq!( + affine, recovered, + "Round-trip failed for scalar {}", + scalar_val + ); + + // Verify the point is still valid by checking compressed form matches + let compressed1 = point.to_bytes(); + let compressed2 = recovered.to_proj().to_bytes(); + assert_eq!( + compressed1, compressed2, + "Compressed forms don't match for scalar {}", + scalar_val + ); + } + } + + #[test] + fn test_invalid_compression_flag() { + let mut bytes = [0u8; 96]; + bytes[0] = 0x80; // Set compression flag + + let result = CtG1Affine::from_bytes_uncompressed(bytes); + assert!( + result.is_err(), + "Should reject compressed flag in uncompressed format" + ); + } + + #[test] + fn test_invalid_sort_flag() { + let mut bytes = [0u8; 96]; + bytes[0] = 0x20; // Set sort flag + + let result = CtG1Affine::from_bytes_uncompressed(bytes); + assert!( + result.is_err(), + "Should reject sort flag in uncompressed format" + ); + } + + #[test] + fn test_invalid_infinity_encoding() { + let mut bytes = [0u8; 96]; + bytes[0] = 0x40; // Set infinity flag + bytes[1] = 0x01; // But has non-zero data + + let result = CtG1Affine::from_bytes_uncompressed(bytes); + assert!( + result.is_err(), + "Should reject infinity point with non-zero data" + ); + } +} diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index f43ebce5c..73a3b310e 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -62,6 +62,15 @@ parallel = [ "rust-kzg-mcl/parallel", ] +strauss = [ + "rust-kzg-blst/strauss", + "rust-kzg-arkworks4/strauss", + "rust-kzg-arkworks5/strauss", + "rust-kzg-constantine/strauss", + "rust-kzg-mcl/strauss", + "rust-kzg-zkcrypto/strauss", +] + # backends arkworks3=["dep:rust-kzg-arkworks3"] arkworks4=["dep:rust-kzg-arkworks4"] diff --git a/kzg/Cargo.toml b/kzg/Cargo.toml index 1879429a9..4ebbea422 100644 --- a/kzg/Cargo.toml +++ b/kzg/Cargo.toml @@ -34,6 +34,7 @@ arkmsm = [] bgmw = [] sppark = [] wbits = [] +strauss = [] diskcache = [ "std", "dep:dirs" diff --git a/kzg/src/msm/diskcache.rs b/kzg/src/msm/diskcache.rs index accb11434..115d31d6a 100644 --- a/kzg/src/msm/diskcache.rs +++ b/kzg/src/msm/diskcache.rs @@ -31,7 +31,7 @@ fn compute_content_hash>( } for row in matrix { for point in row { - let affine = TG1Affine::into_affine(&point); + let affine = TG1Affine::into_affine(point); hasher .write_all(&affine.to_bytes_uncompressed()) .map_err(|e| format!("{e:?}"))?; @@ -85,7 +85,7 @@ impl> DiskCache> DiskCache(points, matrix))?; writer @@ -217,7 +217,7 @@ impl> DiskCache = pub type PrecomputationTable = super::wbits::WbitsTable; -#[cfg(all(not(feature = "bgmw"), not(feature = "sppark"), not(feature = "wbits")))] +#[cfg(feature = "strauss")] +pub type PrecomputationTable = + super::strauss::StraussTable; + +#[cfg(all( + not(feature = "bgmw"), + not(feature = "sppark"), + not(feature = "wbits"), + not(feature = "strauss") +))] #[derive(Debug, Clone)] pub struct EmptyTable where @@ -45,7 +58,12 @@ where g1_affine_add_marker: core::marker::PhantomData, } -#[cfg(all(not(feature = "bgmw"), not(feature = "sppark"), not(feature = "wbits")))] +#[cfg(all( + not(feature = "bgmw"), + not(feature = "sppark"), + not(feature = "wbits"), + not(feature = "strauss") +))] impl EmptyTable where @@ -73,7 +91,12 @@ where } } -#[cfg(all(not(feature = "bgmw"), not(feature = "sppark"), not(feature = "wbits")))] +#[cfg(all( + not(feature = "bgmw"), + not(feature = "sppark"), + not(feature = "wbits"), + not(feature = "strauss") +))] pub type PrecomputationTable = EmptyTable; diff --git a/kzg/src/msm/strauss.rs b/kzg/src/msm/strauss.rs new file mode 100644 index 000000000..3e4e9a4bc --- /dev/null +++ b/kzg/src/msm/strauss.rs @@ -0,0 +1,403 @@ +use crate::msm::pippenger_utils::get_wval_limb; +use crate::{Fr, G1Affine, G1Fp, G1GetFp, G1Mul, G1ProjAddAffine, G1}; +use alloc::vec::Vec; +use core::marker::PhantomData; + +#[cfg(feature = "diskcache")] +use crate::msm::diskcache::DiskCache; + +// Strauss chunk size: process this many points at a time. +// Table size = 2^CHUNK_SIZE. For CHUNK_SIZE=7: table has 128 entries. + +fn get_window_size() -> usize { + option_env!("WINDOW_SIZE") + .map(|v| { + v.parse() + .expect("WINDOW_SIZE environment variable must be valid number") + }) + .unwrap_or(8) +} + +#[derive(Debug, Clone)] +pub struct StraussTable +where + TFr: Fr, + TG1: G1 + G1Mul + G1GetFp, + TG1Fp: G1Fp, + TG1Affine: G1Affine, + TG1ProjAddAffine: G1ProjAddAffine, +{ + chunk_tables: Vec>, + numpoints: usize, + + batch_numpoints: usize, + batch_chunk_tables: Vec>>, // precomputed tables per row + + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData, + g1_affine_add_marker: PhantomData, +} + +impl + StraussTable +where + TFr: Fr, + TG1: G1 + G1Mul + G1GetFp + Clone, + TG1Fp: G1Fp, + TG1Affine: G1Affine + Clone, + TG1ProjAddAffine: G1ProjAddAffine, +{ + fn try_read_cache(points: &[TG1], matrix: &[Vec]) -> Result> { + #[cfg(feature = "diskcache")] + { + DiskCache::::load("strauss", get_window_size(), points, matrix) + .map_err(|(err, contenthash)| { + println!("Failed to load cache: {err}"); + contenthash + }) + .map(|cache| { + // Reconstruct chunk_tables from cache + let chunk_size = get_window_size(); + let n = cache.numpoints; + let num_chunks = n.div_ceil(chunk_size); + + let mut chunk_tables = Vec::new(); + let mut offset = 0; + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_size; + let end = core::cmp::min(start + chunk_size, n); + let chunk_len = end - start; + let table_size = (1usize << chunk_len) - 1; + + // Store directly as affine + let chunk: Vec = + cache.table[offset..offset + table_size].to_vec(); + chunk_tables.push(chunk); + offset += table_size; + } + + // Rebuild batch_chunk_tables from cache + // DiskCache stores batch_table as Vec> - each row is already flattened + let batch_chunk_tables: Vec>> = cache + .batch_table + .iter() + .map(|flat_row| { + Self::unflatten_row_chunk_tables( + flat_row, + cache.batch_numpoints, + chunk_size, + ) + }) + .collect(); + + Self { + chunk_tables, + numpoints: cache.numpoints, + batch_numpoints: cache.batch_numpoints, + batch_chunk_tables, + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData, + g1_affine_add_marker: PhantomData, + } + }) + } + + #[cfg(not(feature = "diskcache"))] + Err(None) + } + + fn try_write_cache( + points: &[TG1], + matrix: &[Vec], + chunk_tables: &[Vec], + numpoints: usize, + batch_chunk_tables: &[Vec>], + batch_numpoints: usize, + contenthash: Option<[u8; 32]>, + ) -> Result<(), String> { + #[cfg(feature = "diskcache")] + { + // Flatten chunk_tables + let table_affine: Vec = chunk_tables + .iter() + .flat_map(|chunk| chunk.iter()) + .cloned() + .collect(); + + // Flatten each row's chunk_tables for DiskCache's 2D structure + let batch_table_affine: Vec> = batch_chunk_tables + .iter() + .map(|row_chunks| { + row_chunks + .iter() + .flat_map(|chunk| chunk.iter()) + .cloned() + .collect() + }) + .collect(); + + DiskCache::::save( + "strauss", + get_window_size(), + points, + matrix, + &table_affine, + numpoints, + &batch_table_affine, + batch_numpoints, + contenthash, + ) + .inspect_err(|err| println!("Failed to save cache: {err}")) + } + + #[cfg(not(feature = "diskcache"))] + Ok(()) + } + + fn unflatten_row_chunk_tables( + flat_row: &[TG1Affine], + batch_numpoints: usize, + chunk_size: usize, + ) -> Vec> { + if flat_row.is_empty() { + return Vec::new(); + } + + let num_chunks = batch_numpoints.div_ceil(chunk_size); + let mut row_chunk_tables = Vec::new(); + let mut offset = 0; + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_size; + let end = core::cmp::min(start + chunk_size, batch_numpoints); + let chunk_len = end - start; + let table_size = (1usize << chunk_len) - 1; + + let chunk: Vec = flat_row[offset..offset + table_size].to_vec(); + row_chunk_tables.push(chunk); + offset += table_size; + } + + row_chunk_tables + } + + /// Build a StraussTable that precomputes chunk tables for all chunks. + /// This mirrors the style of other algorithms that precompute and store tables. + pub fn new(points: &[TG1], matrix: &[Vec]) -> Result, String> { + let contenthash = match Self::try_read_cache(points, matrix) { + Ok(v) => return Ok(Some(v)), + Err(e) => e, + }; + + let strauss_chunk_size: usize = get_window_size(); + + // If matrix is empty, build single-point-set precomputation + if matrix.is_empty() { + let n = points.len(); + if n == 0 { + let table = StraussTable { + chunk_tables: Vec::new(), + numpoints: 0, + batch_numpoints: 0, + batch_chunk_tables: Vec::new(), + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData, + g1_affine_add_marker: PhantomData, + }; + return Ok(Some(table)); + } + + // Build chunk tables as affine + let chunk_tables = Self::build_chunk_tables(points, strauss_chunk_size); + + Self::try_write_cache(points, matrix, &chunk_tables, n, &[], 0, contenthash)?; + + let table = StraussTable { + chunk_tables, + numpoints: n, + batch_numpoints: 0, + batch_chunk_tables: Vec::new(), + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData, + g1_affine_add_marker: PhantomData, + }; + return Ok(Some(table)); + } + + let batch_numpoints = matrix[0].len(); + + // Build chunk tables for each row + let batch_chunk_tables: Vec>> = matrix + .iter() + .map(|point_row| Self::build_chunk_tables(point_row, strauss_chunk_size)) + .collect(); + + // Build main chunk_tables if needed + let n = points.len(); + let chunk_tables = if n > 0 { + Self::build_chunk_tables(points, strauss_chunk_size) + } else { + Vec::new() + }; + + Self::try_write_cache( + points, + matrix, + &chunk_tables, + n, + &batch_chunk_tables, + batch_numpoints, + contenthash, + )?; + + let table = StraussTable { + chunk_tables, + numpoints: n, + batch_numpoints, + batch_chunk_tables, + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData, + g1_affine_add_marker: PhantomData, + }; + Ok(Some(table)) + } + + /// Build chunk tables - returns AFFINE for storage efficiency + fn build_chunk_tables(points: &[TG1], chunk_size: usize) -> Vec> { + let n = points.len(); + let mut chunk_tables: Vec> = Vec::new(); + + let num_chunks = n.div_ceil(chunk_size); + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_size; + let end = core::cmp::min(start + chunk_size, n); + let chunk_len = end - start; + + // size of table for this chunk: 2^chunk_len entries, but we skip index 0 (identity) + let table_size = (1usize << chunk_len) - 1; + + // Build incremental table in projective space using the lowest-bit trick. + // faster additions in projective + let mut table_proj: Vec = Vec::with_capacity(table_size); + + for mask in 1..=table_size { + let lb = mask.trailing_zeros() as usize; + let prev = mask ^ (1 << lb); + if prev == 0 { + table_proj.push(points[start + lb].clone()); + } else { + let mut new_val = table_proj[prev - 1].clone(); + new_val.add_or_dbl_assign(&points[start + lb]); + table_proj.push(new_val); + } + } + + // Convert to affine once for storage + let table_affine: Vec = table_proj + .iter() + .map(|proj| TG1Affine::into_affine(proj)) + .collect(); + + chunk_tables.push(table_affine); + } + + chunk_tables + } + + /// Multiply using the precomputed chunk tables (sequential) + pub fn multiply_sequential(&self, scalars: &[TFr]) -> TG1 { + Self::multiply_with_tables(scalars, &self.chunk_tables) + } + + /// Core multiplication logic using provided tables + fn multiply_with_tables(scalars: &[TFr], chunk_tables: &[Vec]) -> TG1 { + let n = scalars.len(); + if n == 0 || chunk_tables.is_empty() { + return TG1::zero(); + } + + // Convert scalars to scalar limbs for bit access + let scalar_values = scalars.iter().map(TFr::to_scalar).collect::>(); + + // Single accumulator processing all chunks together + let mut accumulator = TG1::zero(); + + // Process all 255 bits (BLS12-381 scalar bit length) + for bit in (0..255).rev() { + // Double accumulator unconditionally + accumulator.dbl_assign(); + + // Process each chunk at this bit position + let mut pt_idx = 0usize; + for table in chunk_tables.iter() { + let table_size = table.len(); + // Derive chunk_len from table size: table_size = 2^chunk_len - 1 + // So 2^chunk_len = table_size + 1, thus log2(table_size + 1) + let chunk_len = + (usize::BITS - 1) as usize - (table_size + 1).leading_zeros() as usize; + + // Only process this chunk if we have scalars for it + // This handles the case where tables were built for more points than we're using + if pt_idx >= scalar_values.len() { + break; + } + + // Build table_index for this bit across chunk scalars + let mut table_index = 0usize; + let actual_chunk_len = core::cmp::min(chunk_len, scalar_values.len() - pt_idx); + + for i in 0..actual_chunk_len { + let scalar_idx = pt_idx + i; + + let s = &scalar_values[scalar_idx]; + // Extract single bit at position 'bit' from scalar + if (get_wval_limb(s, bit, 1) & 1) != 0 { + table_index |= 1 << i; + } + } + + if table_index != 0 { + // Mixed addition - Projective + Affine (should be faster than Proj + Proj) + let affine_pt = &table[table_index - 1]; + TG1ProjAddAffine::add_or_double_assign_affine(&mut accumulator, affine_pt); + } + + pt_idx += chunk_len; + } + } + + accumulator + } + + pub fn multiply_batch(&self, scalars: &[Vec]) -> Vec { + // Use precomputed batch_chunk_tables + assert!( + scalars.len() == self.batch_chunk_tables.len(), + "Scalars length {} != batch_chunk_tables length {}", + scalars.len(), + self.batch_chunk_tables.len() + ); + + scalars + .iter() + .zip(self.batch_chunk_tables.iter()) + .map(|(scalar_row, chunk_tables)| Self::multiply_with_tables(scalar_row, chunk_tables)) + .collect() + } + + pub fn multiply_parallel(&self, scalars: &[TFr]) -> TG1 { + self.multiply_sequential(scalars) + } +} diff --git a/mcl/Cargo.toml b/mcl/Cargo.toml index 583f4622a..a16aeb350 100644 --- a/mcl/Cargo.toml +++ b/mcl/Cargo.toml @@ -47,6 +47,9 @@ bgmw = [ arkmsm = [ "kzg/arkmsm" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] diskcache = [ "kzg/diskcache" diff --git a/mcl/src/types/g1.rs b/mcl/src/types/g1.rs index 003b9810e..5725422b1 100644 --- a/mcl/src/types/g1.rs +++ b/mcl/src/types/g1.rs @@ -491,6 +491,9 @@ impl G1ProjAddAffine for MclG1ProjAddAffine { } fn add_or_double_assign_affine(_proj: &mut MclG1, _aff: &MclG1Affine) { - todo!() + try_init_mcl(); + + let tmp = _aff.to_proj(); + _proj.0 = _proj.0.add(&tmp.0); } } diff --git a/msm-benches/Cargo.toml b/msm-benches/Cargo.toml index bcc281186..5178e2f9b 100644 --- a/msm-benches/Cargo.toml +++ b/msm-benches/Cargo.toml @@ -52,6 +52,23 @@ parallel = [ "rust-kzg-zkcrypto/parallel", "rust-kzg-mcl/parallel", ] +strauss = [ + "rust-kzg-blst/strauss", + "rust-kzg-arkworks4/strauss", + "rust-kzg-arkworks5/strauss", + "rust-kzg-constantine/strauss", + "rust-kzg-zkcrypto/strauss", +] + +# Re-export diskcache for backends so benches can enable it with a single flag +# Enables the `diskcache` feature in `kzg` and the common backend crates +diskcache = [ + "rust-kzg-blst/diskcache", + "rust-kzg-arkworks3/diskcache", + "rust-kzg-arkworks4/diskcache", + "rust-kzg-arkworks5/diskcache", + "rust-kzg-constantine/diskcache", +] [[bench]] name = "g1_fixed_base_msm" diff --git a/zkcrypto/Cargo.toml b/zkcrypto/Cargo.toml index 7c16e69e8..b609c9afc 100644 --- a/zkcrypto/Cargo.toml +++ b/zkcrypto/Cargo.toml @@ -25,7 +25,6 @@ default = [ "std", "rand", "diskcache", - "bgmw", ] std = [ "kzg/std", @@ -41,6 +40,9 @@ rand = [ bgmw = [ "kzg/bgmw" ] +strauss = [ + "kzg/strauss" +] c_bindings = [] diskcache = [ "kzg/diskcache"