Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ec-gpu-gen/src/multiexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use ff::PrimeField;
use group::{prime::PrimeCurveAffine, Group};
use log::{error, info};
use rust_gpu_tools::{program_closures, Device, Program};
#[cfg(not(target_arch = "wasm32"))]
use yastl::Scope;
#[cfg(target_arch = "wasm32")]
use crate::threadpool::WasmScope as Scope;

use crate::{
error::{EcError, EcResult},
Expand Down
327 changes: 218 additions & 109 deletions ec-gpu-gen/src/threadpool.rs
Original file line number Diff line number Diff line change
@@ -1,137 +1,245 @@
//! An interface for dealing with the kinds of parallel computations involved.
use std::env;

use crossbeam_channel::{bounded, Receiver, SendError};
use log::trace;
use once_cell::sync::Lazy;
use yastl::Pool;

/// The number of threads the thread pool should use.
///
/// By default it's equal to the number of CPUs, but it can be changed with the
/// `EC_GPU_NUM_THREADS` environment variable.
static NUM_THREADS: Lazy<usize> = Lazy::new(read_num_threads);

/// The thread pool that is used for the computations.
///
/// By default, it's size is equal to the number of CPUs. It can be set to a different value with
/// the `EC_GPU_NUM_THREADS` environment variable.
pub static THREAD_POOL: Lazy<Pool> = Lazy::new(|| Pool::new(*NUM_THREADS));

/// Returns the number of threads.
///
/// The number can be set with the `EC_GPU_NUM_THREADS` environment variable. If it isn't set, it
/// defaults to the number of CPUs the system has.
fn read_num_threads() -> usize {
env::var("EC_GPU_NUM_THREADS")
.ok()
.and_then(|num| num.parse::<usize>().ok())
.unwrap_or_else(num_cpus::get)
}
//!
//! On native targets, uses a thread pool (yastl) for parallel computation.
//! On `wasm32` targets, executes everything on the main thread since WASM
//! does not support spawning OS threads.

/// A worker operates on a pool of threads.
#[derive(Clone, Default)]
pub struct Worker {}
// ════════════════════════════════════════════════════════════════
// Native implementation (multi-threaded via yastl)
// ════════════════════════════════════════════════════════════════

impl Worker {
/// Returns a new worker.
pub fn new() -> Worker {
Worker {}
}
#[cfg(not(target_arch = "wasm32"))]
mod native {
use std::env;

/// Returns binary logarithm (floored) of the number of threads.
///
/// This means, the number of threads is `2^log_num_threads()`.
pub fn log_num_threads(&self) -> u32 {
log2_floor(*NUM_THREADS)
use crossbeam_channel::{bounded, Receiver, SendError};
use log::trace;
use once_cell::sync::Lazy;
use yastl::Pool;

static NUM_THREADS: Lazy<usize> = Lazy::new(read_num_threads);

pub static THREAD_POOL: Lazy<Pool> = Lazy::new(|| Pool::new(*NUM_THREADS));

fn read_num_threads() -> usize {
env::var("EC_GPU_NUM_THREADS")
.ok()
.and_then(|num| num.parse::<usize>().ok())
.unwrap_or_else(num_cpus::get)
}

/// Executes a function in a thread and returns a [`Waiter`] immediately.
pub fn compute<F, R>(&self, f: F) -> Waiter<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (sender, receiver) = bounded(1);

THREAD_POOL.spawn(move || {
let res = f();
// Best effort. We run it in a separate thread, so the receiver might not exist
// anymore, but that's OK. It only means that we are not interested in the result.
// A message is logged though, as concurrency issues are hard to debug and this might
// help in such cases.
if let Err(SendError(_)) = sender.send(res) {
trace!("Cannot send result");
}
});
#[derive(Clone, Default)]
pub struct Worker {}

impl Worker {
pub fn new() -> Worker {
Worker {}
}

pub fn log_num_threads(&self) -> u32 {
log2_floor(*NUM_THREADS)
}

pub fn compute<F, R>(&self, f: F) -> Waiter<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (sender, receiver) = bounded(1);

Waiter { receiver }
THREAD_POOL.spawn(move || {
let res = f();
if let Err(SendError(_)) = sender.send(res) {
trace!("Cannot send result");
}
});

Waiter { receiver }
}

pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
where
F: FnOnce(&yastl::Scope<'a>, usize) -> R,
{
let chunk_size = if elements < *NUM_THREADS {
1
} else {
elements / *NUM_THREADS
};

THREAD_POOL.scoped(|scope| f(scope, chunk_size))
}

pub fn scoped<'a, F, R>(&self, f: F) -> R
where
F: FnOnce(&yastl::Scope<'a>) -> R,
{
let (sender, receiver) = bounded(1);
THREAD_POOL.scoped(|s| {
let res = f(s);
sender.send(res).unwrap();
});

receiver.recv().unwrap()
}
}

/// Executes a function and returns the result once it is finished.
///
/// The function gets the [`yastl::Scope`] as well as the `chunk_size` as parameters. THe
/// `chunk_size` is number of elements per thread.
pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
where
F: FnOnce(&yastl::Scope<'a>, usize) -> R,
{
let chunk_size = if elements < *NUM_THREADS {
1
} else {
elements / *NUM_THREADS
};

THREAD_POOL.scoped(|scope| f(scope, chunk_size))
pub struct Waiter<T> {
receiver: Receiver<T>,
}

/// Executes the passed in function, and returns the result once it is finished.
pub fn scoped<'a, F, R>(&self, f: F) -> R
where
F: FnOnce(&yastl::Scope<'a>) -> R,
{
let (sender, receiver) = bounded(1);
THREAD_POOL.scoped(|s| {
let res = f(s);
sender.send(res).unwrap();
});
impl<T> Waiter<T> {
pub fn wait(&self) -> T {
self.receiver.recv().unwrap()
}

receiver.recv().unwrap()
pub fn done(val: T) -> Self {
let (sender, receiver) = bounded(1);
sender.send(val).unwrap();
Waiter { receiver }
}
}
}

/// A future that is waiting for a result.
pub struct Waiter<T> {
receiver: Receiver<T>,
pub(crate) fn log2_floor(num: usize) -> u32 {
assert!(num > 0);
let mut pow = 0;
while (1 << (pow + 1)) <= num {
pow += 1;
}
pow
}
}

impl<T> Waiter<T> {
/// Wait for the result.
pub fn wait(&self) -> T {
self.receiver.recv().unwrap()
// ════════════════════════════════════════════════════════════════
// WASM implementation (single-threaded fallback)
// ════════════════════════════════════════════════════════════════

#[cfg(target_arch = "wasm32")]
mod wasm {
/// A worker that executes everything on the main thread.
///
/// WASM does not support spawning OS threads (std::thread::spawn panics).
/// This fallback runs all computations sequentially on the calling thread.
/// Performance is lower than native multi-threaded execution but correctness
/// is preserved — all bellperson operations are safe to run single-threaded.
#[derive(Clone, Default)]
pub struct Worker {}

impl Worker {
pub fn new() -> Worker {
Worker {}
}

/// Returns 0 (single-threaded = 2^0 = 1 thread).
pub fn log_num_threads(&self) -> u32 {
0
}

/// Executes the function immediately on the current thread.
pub fn compute<F, R>(&self, f: F) -> Waiter<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let result = f();
Waiter { value: std::cell::RefCell::new(Some(result)) }
}

/// Executes with chunk_size = elements (single thread processes everything).
pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
where
F: FnOnce(&WasmScope<'a>, usize) -> R,
{
let scope = WasmScope::new();
f(&scope, elements)
}

/// Executes the function immediately.
pub fn scoped<'a, F, R>(&self, f: F) -> R
where
F: FnOnce(&WasmScope<'a>) -> R,
{
let scope = WasmScope::new();
f(&scope)
}
}

/// One off sending.
pub fn done(val: T) -> Self {
let (sender, receiver) = bounded(1);
sender.send(val).unwrap();
/// A minimal scope that executes spawned closures immediately (sequentially).
///
/// Mimics the `yastl::Scope<'a>` API but runs everything on the current thread.
/// The lifetime parameter matches yastl::Scope's signature for API compatibility.
pub struct WasmScope<'a> {
_marker: std::marker::PhantomData<&'a ()>,
}

Waiter { receiver }
impl<'a> WasmScope<'a> {
fn new() -> Self {
WasmScope { _marker: std::marker::PhantomData }
}

/// "Spawns" a closure by executing it immediately on the current thread.
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send,
{
f();
}
}
}

fn log2_floor(num: usize) -> u32 {
assert!(num > 0);
/// A future that already has its value (computed synchronously).
pub struct Waiter<T> {
value: std::cell::RefCell<Option<T>>,
}

let mut pow = 0;
impl<T> Waiter<T> {
/// Takes the value. Panics if called more than once.
pub fn wait(&self) -> T {
self.value.borrow_mut().take()
.expect("Waiter::wait called on already-consumed waiter")
}

while (1 << (pow + 1)) <= num {
pow += 1;
pub fn done(val: T) -> Self {
Waiter { value: std::cell::RefCell::new(Some(val)) }
}
}

pow
/// A single-threaded "pool" for WASM — executes closures inline.
pub struct WasmPool;

impl WasmPool {
pub fn scoped<'a, F, R>(&self, f: F) -> R
where
F: FnOnce(&WasmScope<'a>) -> R,
{
let scope = WasmScope::new();
f(&scope)
}
}

/// Global "thread pool" for WASM — single-threaded.
pub static THREAD_POOL: WasmPool = WasmPool;

pub(crate) fn log2_floor(num: usize) -> u32 {
assert!(num > 0);
let mut pow = 0;
while (1 << (pow + 1)) <= num {
pow += 1;
}
pow
}
}

// ════════════════════════════════════════════════════════════════
// Re-exports — callers use Worker, Waiter, THREAD_POOL without
// knowing which implementation is active.
// ════════════════════════════════════════════════════════════════

#[cfg(not(target_arch = "wasm32"))]
pub use native::*;

#[cfg(target_arch = "wasm32")]
pub use wasm::*;

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -147,20 +255,21 @@ mod tests {
assert_eq!(log2_floor(8), 3);
}

#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_read_num_threads() {
let num_cpus = num_cpus::get();
temp_env::with_var("EC_GPU_NUM_THREADS", None::<&str>, || {
assert_eq!(
read_num_threads(),
native::read_num_threads(),
num_cpus,
"By default the number of threads matches the number of CPUs."
);
});

temp_env::with_var("EC_GPU_NUM_THREADS", Some("1234"), || {
assert_eq!(
read_num_threads(),
native::read_num_threads(),
1234,
"Number of threads matches the environment variable."
);
Expand Down