Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
805 changes: 406 additions & 399 deletions build.zig

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion packages/zolt-arith/src/curves/montgomery_field.zig
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,6 @@ pub fn MontgomeryField(
};
}


// =========================================================================
// Tests
// =========================================================================
Expand Down
17 changes: 12 additions & 5 deletions packages/zolt-arith/src/msm/mod.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
//! We use the short Weierstrass curve y^2 = x^3 + b (a = 0 for BN254)

const std = @import("std");
const builtin = @import("builtin");
const Allocator = std.mem.Allocator;
const ThreadPool = @import("zolt_pool").ThreadPool;
const is_wasm = @import("zolt_pool").is_wasm;

pub const glv = @import("glv.zig");

Expand Down Expand Up @@ -672,7 +674,7 @@ pub fn MSM(comptime F: type, comptime G: type) type {
max_scalar_bits_hint: usize,
) Affine {
const c = optimalWindowSize(bases.len);
const num_buckets = (@as(usize, 1) << @as(u6, @intCast(c))) / 2; // 2^(c-1) buckets for wNAF
const num_buckets = (@as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(c))) / 2; // 2^(c-1) buckets for wNAF

// Pre-convert all scalars and compute wNAF digits.
// Also detect effective max bit-width to skip unnecessary windows.
Expand Down Expand Up @@ -1250,7 +1252,7 @@ pub fn MSM(comptime F: type, comptime G: type) type {
const effective_bits: usize = if (max_abs == 0) 1 else @as(usize, std.math.log2_int(u128, max_abs)) + 1;
const num_scalar_windows = @min((effective_bits + c - 1) / c, (SCALAR_BITS_I128 + c - 1) / c);
const num_windows = num_scalar_windows + 1;
const num_buckets = (@as(usize, 1) << @as(u6, @intCast(c))) / 2;
const num_buckets = (@as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(c))) / 2;

// Pre-compute wNAF digits and sign flags
const stack_threshold = 256;
Expand Down Expand Up @@ -1531,12 +1533,14 @@ pub fn ParallelMSM(comptime F: type, comptime G: type) type {

// For small inputs, single-threaded is faster due to thread overhead
const min_points_per_thread: usize = 1024;
const actual_threads = @min(num_threads, @max(1, bases.len / min_points_per_thread));
const actual_threads = if (comptime is_wasm) 1 else @min(num_threads, @max(1, bases.len / min_points_per_thread));

if (actual_threads <= 1) {
return SingleMSM.compute(bases, scalars);
}

if (comptime is_wasm) unreachable; // WASM always takes the sequential path above

// Divide work among threads
const chunk_size = (bases.len + actual_threads - 1) / actual_threads;

Expand Down Expand Up @@ -1603,6 +1607,7 @@ pub fn ParallelMSM(comptime F: type, comptime G: type) type {

/// Detect optimal number of threads
pub fn detectOptimalThreads() usize {
if (comptime is_wasm) return 1;
// Try to get CPU count, default to 4
const cpu_count = std.Thread.getCpuCount() catch return 4;
// Use at most 8 threads to avoid diminishing returns
Expand Down Expand Up @@ -1639,14 +1644,16 @@ pub fn ParallelBatchMSM(comptime F: type, comptime G: type) type {
const results = try allocator.alloc(Affine, scalar_batches.len);
errdefer allocator.free(results);

// For small batch counts, run sequentially
if (scalar_batches.len <= 2) {
// For small batch counts or WASM, run sequentially
if (scalar_batches.len <= 2 or comptime is_wasm) {
for (scalar_batches, 0..) |scalars, i| {
results[i] = SingleMSM.compute(bases, scalars);
}
return results;
}

if (comptime is_wasm) unreachable;

// Allocate thread infrastructure
const contexts = try allocator.alloc(BatchThreadContext, scalar_batches.len);
defer allocator.free(contexts);
Expand Down
116 changes: 116 additions & 0 deletions packages/zolt-arith/src/poly/commitment/dory.zig
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const field = @import("../../field/mod.zig");
const msm = @import("../../msm/mod.zig");
const glv = msm.glv;
const ThreadPool = @import("zolt_pool").ThreadPool;
const is_wasm = @import("zolt_pool").is_wasm;

const gpu_mod = @import("../../gpu/mod.zig");
const GpuMsmOps = gpu_mod.GpuMsmOps;
Expand Down Expand Up @@ -1086,6 +1087,113 @@ pub const DorySRS = struct {
};
}

/// Load a pre-serialized SRS from a byte slice (e.g. passed from JS into WASM memory).
/// Same format as saveToCache/loadFromCache:
///
///   magic(4) | version(4) | n(8) | sigma(4) | nu(4)      -- 24-byte header, little-endian
///   n * @sizeOf(G1Point) bytes
///   n * @sizeOf(G2Point) bytes
///   prepared flag(1), followed by n * @sizeOf(G2Prepared) bytes when the flag is 1
///   [affine flag(1), followed by n * @sizeOf(G2PreparedAffine) bytes when the flag is 1]
///
/// The trailing affine section is optional: if the data ends right after the
/// prepared section, `g2_prepared_affine` is left null.
///
/// Element bytes are copied verbatim into freshly allocated slices; no
/// per-point validation is performed here.
///
/// All slices are allocated with `allocator`, which is also stored in the
/// returned SRS. Returns null if the data is malformed or truncated, if any
/// header field does not fit the target's `usize`, or if an allocation fails.
/// Nothing stays allocated when null is returned.
pub fn loadFromBytes(allocator: Allocator, data: []const u8) ?DorySRS {
    // Fixed-size header: magic(4) + version(4) + n(8) + sigma(4) + nu(4).
    const header_len: usize = 4 + 4 + 8 + 4 + 4;
    if (data.len < header_len) return null;

    var pos: usize = 0;

    // Validate header
    if (!std.mem.eql(u8, data[pos..][0..4], CACHE_MAGIC)) return null;
    pos += 4;
    if (std.mem.readInt(u32, data[pos..][0..4], .little) != CACHE_VERSION) return null;
    pos += 4;
    // `n` is stored as u64; on 32-bit targets (wasm32) it may not fit in usize.
    const n = std.math.cast(usize, std.mem.readInt(u64, data[pos..][0..8], .little)) orelse return null;
    pos += 8;
    const sigma: u32 = std.mem.readInt(u32, data[pos..][0..4], .little);
    pos += 4;
    const nu: u32 = std.mem.readInt(u32, data[pos..][0..4], .little);
    pos += 4;

    // sigma/nu are used as shift amounts below; reject values that would
    // shift by the full width of usize or more.
    const ShiftInt = std.math.Log2Int(usize);
    const sigma_shift = std.math.cast(ShiftInt, sigma) orelse return null;
    const nu_shift = std.math.cast(ShiftInt, nu) orelse return null;

    // Set to true only once every section has been read; until then the
    // deferred cleanups below release whatever has been allocated so far.
    var loaded = false;

    // G1 points
    // Invariant from here on: pos <= data.len, so `data.len - pos` cannot wrap.
    const g1_byte_len = std.math.mul(usize, n, @sizeOf(G1Point)) catch return null;
    if (g1_byte_len > data.len - pos) return null;
    const g1_vec = allocator.alloc(G1Point, n) catch return null;
    defer if (!loaded) allocator.free(g1_vec);
    @memcpy(std.mem.sliceAsBytes(g1_vec), data[pos..][0..g1_byte_len]);
    pos += g1_byte_len;

    // G2 points
    const g2_byte_len = std.math.mul(usize, n, @sizeOf(G2Point)) catch return null;
    if (g2_byte_len > data.len - pos) return null;
    const g2_vec = allocator.alloc(G2Point, n) catch return null;
    defer if (!loaded) allocator.free(g2_vec);
    @memcpy(std.mem.sliceAsBytes(g2_vec), data[pos..][0..g2_byte_len]);
    pos += g2_byte_len;

    // G2Prepared (optional payload, but the flag byte itself is mandatory)
    var g2_prepared: ?[]G2Prepared = null;
    defer {
        if (!loaded) {
            if (g2_prepared) |p| allocator.free(p);
        }
    }
    if (pos >= data.len) return null;
    const has_prepared = data[pos] == 1;
    pos += 1;
    if (has_prepared) {
        const prep_byte_len = std.math.mul(usize, n, @sizeOf(G2Prepared)) catch return null;
        if (prep_byte_len > data.len - pos) return null;
        const prep = allocator.alloc(G2Prepared, n) catch return null;
        g2_prepared = prep;
        @memcpy(std.mem.sliceAsBytes(prep), data[pos..][0..prep_byte_len]);
        pos += prep_byte_len;
    }

    // G2PreparedAffine (optional; both the flag byte and the payload may be absent)
    var g2_prepared_affine: ?[]G2PreparedAffine = null;
    if (pos < data.len and data[pos] == 1) {
        pos += 1;
        const affine_byte_len = std.math.mul(usize, n, @sizeOf(G2PreparedAffine)) catch return null;
        if (affine_byte_len > data.len - pos) return null;
        const affine = allocator.alloc(G2PreparedAffine, n) catch return null;
        // Nothing can fail after this allocation, so it needs no cleanup guard.
        @memcpy(std.mem.sliceAsBytes(affine), data[pos..][0..affine_byte_len]);
        g2_prepared_affine = affine;
    }

    const num_columns: usize = @as(usize, 1) << sigma_shift;
    const num_rows: usize = @as(usize, 1) << nu_shift;

    // Ownership of every buffer moves to the returned SRS; disarm the cleanups.
    loaded = true;
    return DorySRS{
        .g1_vec = g1_vec,
        .g2_vec = g2_vec,
        .g2_prepared = g2_prepared,
        .g2_prepared_affine = g2_prepared_affine,
        .num_columns = num_columns,
        .num_rows = num_rows,
        .sigma = sigma,
        .nu = nu,
        .h1 = G1Point.generator(),
        .h2 = G2Point.generator(),
        .allocator = allocator,
    };
}

pub fn deinit(self: *DorySRS) void {
if (self.g1_vec.len > 0) {
self.allocator.free(self.g1_vec);
Expand Down Expand Up @@ -1372,7 +1480,15 @@ pub fn DoryCommitmentScheme(comptime F: type) type {

/// Setup with disk caching. Tries to load from ~/.cache/zolt/srs_v1_{log_size}.bin.
/// On cache miss, generates SRS + prepared caches and writes to disk.
/// On WASM, skips caching and generates SRS from scratch.
pub fn setupCached(allocator: Allocator, max_num_vars: usize, tp: ?*ThreadPool) !SetupParams {
if (comptime is_wasm) {
// No filesystem on WASM — generate from scratch
var srs = try setup(allocator, max_num_vars);
srs.initPreparedCache(tp);
return srs;
}

// Build cache path: ~/.cache/zolt/srs_v1_<log_size>.bin
var path_buf: [256]u8 = undefined;
const home = std.posix.getenv("HOME") orelse "/tmp";
Expand Down
4 changes: 2 additions & 2 deletions packages/zolt-arith/src/poly/commitment/g2_msm.zig
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn pippengerMsmG2(comptime F: type, bases: []const G2Point, scalars: []const F)
const c = g2OptimalWindowSize(bases.len);
const num_scalar_windows = (SCALAR_BITS + c - 1) / c;
const num_windows = num_scalar_windows + 1; // +1 for wNAF carry
const num_buckets = (@as(usize, 1) << @as(u6, @intCast(c))) / 2; // wNAF: 2^(c-1) buckets
const num_buckets = (@as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(c))) / 2; // wNAF: 2^(c-1) buckets

// Compute wNAF digits for all scalars
const stack_threshold = 128;
Expand Down Expand Up @@ -150,7 +150,7 @@ fn pippengerMsmG2Parallel(comptime F: type, bases: []const G2Point, scalars: []c
const c = g2OptimalWindowSize(bases.len);
const num_scalar_windows = (SCALAR_BITS + c - 1) / c;
const num_windows = num_scalar_windows + 1;
const num_buckets = (@as(usize, 1) << @as(u6, @intCast(c))) / 2;
const num_buckets = (@as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(c))) / 2;

// Compute wNAF digits
const heap_digits = std.heap.page_allocator.alloc([MAX_DIGITS]i32, scalars.len) catch
Expand Down
26 changes: 13 additions & 13 deletions packages/zolt-arith/src/poly/mod.zig
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn DensePolynomial(comptime F: type) type {
for (self.evaluations, 0..) |eval, i| {
var term = eval;
for (0..self.num_vars) |j| {
const shift_amount: u6 = @intCast(j);
const shift_amount: std.math.Log2Int(usize) = @intCast(j);
const bit = (i >> shift_amount) & 1;
if (bit == 1) {
term = term.mul(point[j]);
Expand Down Expand Up @@ -307,7 +307,7 @@ pub fn EqPolynomial(comptime F: type) type {
/// - For each i in active region: result[i+size] = result[i] * r[j], result[i] -= result[i+size]
pub fn evalsSliceWithScaling(comptime FieldType: type, allocator: Allocator, r: []const FieldType, scaling_factor: ?FieldType) ![]FieldType {
const n = r.len;
const final_size = @as(usize, 1) << @as(u6, @intCast(n));
const final_size = @as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(n));
const result = try allocator.alloc(FieldType, final_size);
buildEqTableInPlace(r, result, scaling_factor);
return result;
Expand All @@ -323,7 +323,7 @@ pub fn EqPolynomial(comptime F: type) type {
return result;
}

const final_size = @as(usize, 1) << @as(u6, @intCast(n));
const final_size = @as(usize, 1) << @as(std.math.Log2Int(usize), @intCast(n));
const result = try allocator.alloc(FieldType, final_size);

// No memset needed: every entry is written before read by the expansion loop below.
Expand Down Expand Up @@ -1009,7 +1009,7 @@ test "EqPolynomial buildEqTableInPlace matches per-point mle" {
defer allocator.free(j_bits);
for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected = EqPolynomial(F).mle(r, j_bits);
Expand All @@ -1024,7 +1024,7 @@ test "EqPolynomial buildEqTableInPlace matches per-point mle" {

for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected_scaled = scale.mul(EqPolynomial(F).mle(r, j_bits));
Expand Down Expand Up @@ -1063,7 +1063,7 @@ test "EqPolynomial buildEqPlusOneTableInPlace matches per-point mle" {
defer allocator.free(j_bits);
for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected = EqPlusOnePolynomial(F).mle(r, j_bits);
Expand Down Expand Up @@ -1102,7 +1102,7 @@ test "EqPolynomial buildEqAndEqPlusOneInPlace matches per-point mle" {
defer allocator.free(j_bits);
for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
// eq table matches EqPolynomial.mle
Expand Down Expand Up @@ -1143,7 +1143,7 @@ test "EqPlusOnePrefixSuffixPoly batch construction matches per-point mle" {
defer allocator.free(j_bits);
for (0..size_lo) |j| {
for (0..r_lo.len) |k| {
const bit_pos: u6 = @intCast(r_lo.len - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(r_lo.len - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected = EqPlusOnePolynomial(F).mle(r_lo, j_bits[0..r_lo.len]);
Expand All @@ -1164,7 +1164,7 @@ test "EqPlusOnePrefixSuffixPoly batch construction matches per-point mle" {
// Verify suffix_0 = eq(r_hi, j) for all j
for (0..size_hi) |j| {
for (0..r_hi.len) |k| {
const bit_pos: u6 = @intCast(r_hi.len - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(r_hi.len - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected = EqPolynomial(F).mle(r_hi, j_bits[0..r_hi.len]);
Expand All @@ -1174,7 +1174,7 @@ test "EqPlusOnePrefixSuffixPoly batch construction matches per-point mle" {
// Verify suffix_1 = eq+1(r_hi, j) for all j
for (0..size_hi) |j| {
for (0..r_hi.len) |k| {
const bit_pos: u6 = @intCast(r_hi.len - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(r_hi.len - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected = EqPlusOnePolynomial(F).mle(r_hi, j_bits[0..r_hi.len]);
Expand All @@ -1192,7 +1192,7 @@ test "EqPlusOnePrefixSuffixPoly batch construction matches per-point mle" {
const actual = poly.prefix_0[j_lo].mul(poly.suffix_0[j_hi])
.add(poly.prefix_1[j_lo].mul(poly.suffix_1[j_hi]));
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
full_j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
const expected_full = EqPlusOnePolynomial(F).mle(r, full_j_bits);
Expand Down Expand Up @@ -1227,7 +1227,7 @@ test "EqPolynomial batch builders with large field elements" {
EqPolynomial(F).buildEqTableInPlace(&large_r, eq_out, null);
for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
try std.testing.expect(eq_out[j].eql(EqPolynomial(F).mle(&large_r, j_bits)));
Expand All @@ -1248,7 +1248,7 @@ test "EqPolynomial batch builders with large field elements" {
EqPolynomial(F).buildEqPlusOneTableInPlace(&large_r, eqp1_out);
for (0..size) |j| {
for (0..n) |k| {
const bit_pos: u6 = @intCast(n - 1 - k);
const bit_pos: std.math.Log2Int(usize) = @intCast(n - 1 - k);
j_bits[k] = if ((j >> bit_pos) & 1 == 1) F.one() else F.zero();
}
try std.testing.expect(eqp1_out[j].eql(EqPlusOnePolynomial(F).mle(&large_r, j_bits)));
Expand Down
1 change: 1 addition & 0 deletions packages/zolt-pool/src/root.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

pub const thread_pool = @import("thread_pool.zig");
pub const ThreadPool = thread_pool.ThreadPool;
pub const is_wasm = thread_pool.is_wasm;

pub const parallel_sort = @import("parallel_sort.zig");
pub const parallelSort = parallel_sort.parallelSort;
Expand Down
Loading
Loading