use rayon::ThreadPool;
use std::future::Future;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;
use tokio::sync::oneshot;
/// Runs `func` via `rayon::spawn` and returns an [`AsyncRayonHandle`] that resolves to its
/// return value when awaited.
///
/// `rayon::spawn` targets the global rayon pool, or the current pool when called from a rayon
/// worker; see the rayon documentation for details.
///
/// # Panics
///
/// A panic inside `func` is caught on the rayon thread and re-raised with
/// `std::panic::resume_unwind` in the task that awaits the handle.
///
/// Dropping the handle does not cancel `func`. Its result, or its panic payload, is discarded.
pub fn spawn_compute<F, R>(func: F) -> AsyncRayonHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let (sender, receiver) = oneshot::channel();
    rayon::spawn(move || {
        // A failed send only means the handle was dropped, so nobody is waiting for the value.
        let _ = sender.send(catch_unwind(AssertUnwindSafe(func)));
    });
    AsyncRayonHandle { rx: receiver }
}
/// Future resolving to the value produced by a closure that runs on a rayon thread pool.
///
/// Created by [`spawn_compute`] and by the [`AsyncThreadPool`] methods. When awaited, it
/// yields the closure's return value or re-raises the panic the closure produced.
///
/// Dropping the handle does not stop the closure. The spawned job ignores a failed send, so
/// the result is simply discarded.
#[derive(Debug)]
#[must_use]
pub struct AsyncRayonHandle<T> {
    /// Receives the closure's outcome: `Ok(value)`, or `Err(panic_payload)` if it panicked.
    pub(crate) rx: oneshot::Receiver<thread::Result<T>>,
}
impl<T> Future for AsyncRayonHandle<T> {
    type Output = T;

    /// Polls the underlying oneshot receiver. Once the outcome arrives, it returns the value
    /// or resumes the panic captured on the rayon thread.
    ///
    /// # Panics
    ///
    /// - If the sending side is dropped without sending.
    /// - With the original payload, if the spawned closure panicked.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The receiver is `Unpin`, so the handle can be accessed through a plain `&mut`.
        let handle = self.get_mut();
        let received = match Pin::new(&mut handle.rx).poll(cx) {
            Poll::Ready(received) => received,
            Poll::Pending => return Poll::Pending,
        };
        let outcome = received.expect("Unreachable error: Tokio channel closed");
        match outcome {
            Ok(value) => Poll::Ready(value),
            Err(payload) => resume_unwind(payload),
        }
    }
}
/// Extension methods for running a closure on a specific rayon thread pool and awaiting its
/// result from async code.
///
/// The trait is sealed through `private::Sealed`, so it cannot be implemented outside this
/// module.
pub trait AsyncThreadPool: private::Sealed {
    /// Queues `func` on this pool. Returns a handle that yields `func`'s result when awaited,
    /// or re-raises its panic.
    fn spawn_compute<F: FnOnce() -> R + Send + 'static, R: Send + 'static>(
        &self,
        func: F,
    ) -> AsyncRayonHandle<R>;

    /// Like `spawn_compute`, but takes the pool by `Arc` so the queued job can hold its own
    /// reference to the pool while running `func`.
    fn spawn_install_compute<F: FnOnce() -> R + Send + 'static, R: Send + 'static>(
        self: Arc<Self>,
        func: F,
    ) -> AsyncRayonHandle<R>;
}
impl AsyncThreadPool for ThreadPool {
    /// Queues `func` with `ThreadPool::spawn`. Any panic is caught so it can be re-raised in
    /// the awaiting task.
    fn spawn_compute<F, R>(&self, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        self.spawn(move || {
            // A failed send only means the handle was dropped; the outcome is discarded.
            let _ = sender.send(catch_unwind(AssertUnwindSafe(func)));
        });
        AsyncRayonHandle { rx: receiver }
    }

    /// Queues a job on this pool that calls `ThreadPool::install` on the same pool and runs
    /// `func` inside it. The queued job owns its own `Arc` clone of the pool.
    fn spawn_install_compute<F, R>(self: Arc<Self>, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        let pool = Arc::clone(&self);
        self.spawn(move || {
            pool.install(move || {
                // A failed send only means the handle was dropped; the outcome is discarded.
                let _ = sender.send(catch_unwind(AssertUnwindSafe(func)));
            });
        });
        AsyncRayonHandle { rx: receiver }
    }
}
mod private {
use rayon::ThreadPool;
pub trait Sealed {}
impl Sealed for ThreadPool {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::ThreadPoolBuilder;

    /// Builds a single-threaded pool, so code running on it always sees worker index `Some(0)`.
    fn build_thread_pool() -> ThreadPool {
        ThreadPoolBuilder::new().num_threads(1).build().unwrap()
    }

    /// A panic payload delivered through the channel is resumed by `poll`.
    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_poll_propagates_panic() {
        let panic_err = catch_unwind(|| {
            panic!("Task failed successfully");
        })
        .unwrap_err();
        let (tx, rx) = oneshot::channel::<thread::Result<()>>();
        let handle = AsyncRayonHandle { rx };
        tx.send(Err(panic_err)).unwrap();
        handle.await;
    }

    /// A sender dropped without sending makes `poll` panic with the channel-closed message.
    #[tokio::test]
    #[should_panic(expected = "Unreachable error: Tokio channel closed")]
    async fn test_unreachable_channel_closed() {
        let (_, rx) = oneshot::channel::<thread::Result<()>>();
        let handle = AsyncRayonHandle { rx };
        handle.await;
    }

    #[tokio::test]
    async fn test_spawn_compute_works() {
        let pool = build_thread_pool();
        let result = pool
            .spawn_compute(|| {
                let thread_index = rayon::current_thread_index();
                assert_eq!(thread_index, Some(0));
                1337_usize
            })
            .await;
        assert_eq!(result, 1337);
        let thread_index = rayon::current_thread_index();
        assert_eq!(thread_index, None);
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_compute_propagates_panic() {
        let pool = build_thread_pool();
        let handle = pool.spawn_compute(|| {
            panic!("Task failed successfully");
        });
        handle.await;
    }

    /// The free function (no explicit pool) also delivers the closure's result.
    #[tokio::test]
    async fn test_free_spawn_compute_works() {
        let result = super::spawn_compute(|| 40_usize + 2).await;
        assert_eq!(result, 42);
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_free_spawn_compute_propagates_panic() {
        let handle = super::spawn_compute(|| {
            panic!("Task failed successfully");
        });
        handle.await;
    }

    /// `spawn_install_compute` runs the closure on the given pool and returns its value.
    #[tokio::test]
    async fn test_spawn_install_compute_works() {
        let pool = Arc::new(build_thread_pool());
        let result = Arc::clone(&pool)
            .spawn_install_compute(|| {
                let thread_index = rayon::current_thread_index();
                assert_eq!(thread_index, Some(0));
                1337_usize
            })
            .await;
        assert_eq!(result, 1337);
        let thread_index = rayon::current_thread_index();
        assert_eq!(thread_index, None);
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_install_compute_propagates_panic() {
        let pool = Arc::new(build_thread_pool());
        let handle = pool.spawn_install_compute(|| {
            panic!("Task failed successfully");
        });
        handle.await;
    }
}