use std::fmt::{self, Debug, Formatter};
use std::future::{Future, Ready};
use std::marker::PhantomData;
use std::ops::Deref;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{any, mem, thread};
use pin_project::{pin_project, pinned_drop};
use super::r#impl::JoinHandle;
use super::{r#impl, Builder, Thread};
#[cfg(feature = "message")]
use crate::web::message::MessageSend;
#[track_caller]
pub fn scope<'env, F, T>(#[allow(clippy::min_ident_chars)] f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
this: r#impl::Scope::new(),
_scope: PhantomData,
_env: PhantomData,
};
let result = f(&scope);
scope.this.finish();
result
}
/// Async counterpart of [`scope`].
///
/// Builds a heap-allocated, pinned [`Scope`] and hands `task` a `'scope`
/// reference to it. Returns a [`ScopeFuture`] that first polls the future
/// produced by `task`, then drives the scope's `finish_async` before yielding
/// the output.
pub(crate) fn scope_async<'scope, 'env: 'scope, F1, F2, T>(
task: F1,
) -> ScopeFuture<'scope, 'env, F2, T>
where
F1: FnOnce(&'scope Scope<'scope, 'env>) -> F2,
F2: Future<Output = T>,
{
let scope = Box::pin(Scope {
this: r#impl::Scope::new(),
_scope: PhantomData,
_env: PhantomData,
});
// SAFETY: the transmute only extends the lifetime of the borrow of the boxed
// `Scope` to `'scope`; the type is otherwise unchanged. The `Scope` lives in a
// pinned heap allocation whose address stays stable when the `Box` is moved
// into the returned `ScopeFuture`, next to `task`. Every path that releases
// the box (`Future::poll`, `poll_into_wait`, `PinnedDrop`) first drops `task`
// in place via `project_replace` before the `scope` box is dropped.
// NOTE(review): leaking the `ScopeFuture` (e.g. `mem::forget`) keeps the box
// alive, so the reference stays valid. However, spawned threads could then
// outlive `'env` data without `finish` ever running. Confirm that the public
// wrappers of this function account for that.
let task = task(unsafe { mem::transmute::<&Scope<'_, '_>, &Scope<'_, '_>>(scope.deref()) });
ScopeFuture::new(task, scope)
}
/// A scope for spawning threads that may borrow data from outside the scope.
///
/// Created by [`scope`] (or internally by `scope_async`). The underlying
/// platform scope is finished (`finish` / `finish_async`) before the scope
/// ends.
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
/// Platform-specific scope state from the `r#impl` module.
pub(super) this: r#impl::Scope,
/// Marker making `Scope` invariant in `'scope` (`&mut &'scope ()` is invariant).
#[allow(clippy::struct_field_names)]
pub(super) _scope: PhantomData<&'scope mut &'scope ()>,
/// Marker making `Scope` invariant in `'env`, the lifetime of borrowed outside data.
pub(super) _env: PhantomData<&'env mut &'env ()>,
}
impl<'scope, #[allow(single_use_lifetimes)] 'env> Scope<'scope, 'env> {
    /// Spawns a thread inside this scope using a default [`Builder`].
    ///
    /// The closure may borrow anything that lives at least as long as
    /// `'scope`. The returned [`ScopedJoinHandle`] is tied to the same lifetime.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to spawn thread"` if the builder reports an error.
    pub fn spawn<F, T>(
        &'scope self,
        #[allow(clippy::min_ident_chars)] f: F,
    ) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        let spawned = Builder::new().spawn_scoped(self, f);
        spawned.expect("failed to spawn thread")
    }

    /// Spawns a scoped thread that runs the future built by `task`.
    ///
    /// Only `task` must be `Send`; the future it returns has no `Send` bound.
    /// This is presumably because the future is constructed on the new
    /// thread (confirm in `Builder::spawn_scoped_async_internal`).
    ///
    /// # Panics
    ///
    /// Panics with `"failed to spawn thread"` if the builder reports an error.
    pub(crate) fn spawn_async_internal<F1, F2, T>(
        &'scope self,
        task: F1,
    ) -> ScopedJoinHandle<'scope, T>
    where
        F1: 'scope + FnOnce() -> F2 + Send,
        F2: 'scope + Future<Output = T>,
        T: 'scope + Send,
    {
        let spawned = Builder::new().spawn_scoped_async_internal(self, task);
        spawned.expect("failed to spawn thread")
    }

    /// Spawns a scoped thread that runs the future built by `task(message)`.
    ///
    /// `message` is handed to `task` as its argument.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to spawn thread"` if the builder reports an error.
    #[cfg(feature = "message")]
    pub(crate) fn spawn_with_message_internal<F1, F2, T, M>(
        &'scope self,
        task: F1,
        message: M,
    ) -> ScopedJoinHandle<'scope, T>
    where
        F1: 'scope + FnOnce(M) -> F2 + Send,
        F2: 'scope + Future<Output = T>,
        T: 'scope + Send,
        M: 'scope + MessageSend,
    {
        let spawned = Builder::new().spawn_scoped_with_message_internal(self, task, message);
        spawned.expect("failed to spawn thread")
    }
}
/// An owned handle to a thread spawned inside a [`Scope`].
///
/// Wraps the platform [`JoinHandle`] and carries the `'scope` lifetime, so the
/// handle cannot outlive the scope that created it.
pub struct ScopedJoinHandle<'scope, T> {
/// The underlying platform join handle.
handle: JoinHandle<T>,
/// Ties the handle to `'scope` without storing a reference.
_scope: PhantomData<&'scope ()>,
}
impl<T> Debug for ScopedJoinHandle<'_, T> {
    /// Prints the wrapped handle and the lifetime marker as struct fields.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            handle,
            _scope: scope,
        } = self;
        let mut fields = formatter.debug_struct("ScopedJoinHandle");
        fields.field("handle", handle);
        fields.field("_scope", scope);
        fields.finish()
    }
}
impl<#[allow(single_use_lifetimes)] 'scope, T> ScopedJoinHandle<'scope, T> {
    /// Wraps a platform [`JoinHandle`] into a handle bound to `'scope`.
    #[cfg(target_feature = "atomics")]
    pub(super) const fn new(handle: JoinHandle<T>) -> Self {
        Self {
            _scope: PhantomData,
            handle,
        }
    }

    /// Returns the [`Thread`] this handle refers to.
    #[must_use]
    pub fn thread(&self) -> &Thread {
        let Self { handle, .. } = self;
        handle.thread()
    }

    /// Waits for the thread and returns what it produced.
    ///
    /// Forwards the underlying handle's `join`. Following the
    /// [`std::thread::Result`] convention, an `Err` presumably carries the
    /// thread's panic payload.
    #[allow(clippy::missing_errors_doc)]
    pub fn join(self) -> thread::Result<T> {
        let Self { handle, .. } = self;
        handle.join()
    }

    /// Reports whether the underlying handle's thread has finished.
    #[allow(clippy::must_use_candidate)]
    pub fn is_finished(&self) -> bool {
        let Self { handle, .. } = self;
        handle.is_finished()
    }

    /// Polls the underlying [`JoinHandle`] as a future.
    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<thread::Result<T>> {
        Future::poll(Pin::new(&mut self.handle), cx)
    }
}
/// Future returned by `scope_async`.
///
/// It first polls the user's task, then polls the scope's `finish_async`
/// before yielding the task's output. It implements `PinnedDrop` so the scope
/// is still finished if the future is dropped early.
#[pin_project(PinnedDrop)]
pub(crate) struct ScopeFuture<'scope, 'env, F, T>(#[pin] State<'scope, 'env, F, T>);
/// State machine behind [`ScopeFuture`]: `Task` -> `Wait` -> `None`.
///
/// Both live variants own the pinned, boxed [`Scope`]. The `task` future may
/// hold a reference into that box, handed out by `scope_async`.
#[pin_project(project = ScopeFutureProj, project_replace = ScopeFutureReplace)]
enum State<'scope, 'env, F, T> {
/// The user's future is still being polled.
Task {
/// Future produced by the user's closure; structurally pinned.
#[pin]
task: F,
/// Heap-pinned scope whose address was handed out to `task`.
scope: Pin<Box<Scope<'scope, 'env>>>,
},
/// The user's future completed; the scope still has to be finished.
Wait {
/// Output of the completed task, returned once the scope is finished.
result: T,
/// The scope that still has to be finished.
scope: Pin<Box<Scope<'scope, 'env>>>,
},
/// Terminal state after the contents were moved out.
None,
}
impl<F, T> Debug for ScopeFuture<'_, '_, F, T> {
fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.debug_tuple("ScopeFuture").field(&self.0).finish()
}
}
impl<F, T> Debug for State<'_, '_, F, T> {
    /// Prints the variant with its scope.
    ///
    /// The task or result is shown only by its type name, because `F` and `T`
    /// are not required to implement `Debug`.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let (variant, label, type_name, scope) = match self {
            Self::None => return formatter.write_str("None"),
            Self::Task { scope, .. } => ("Task", "task", any::type_name::<F>(), scope),
            Self::Wait { scope, .. } => ("Wait", "result", any::type_name::<T>(), scope),
        };
        formatter
            .debug_struct(variant)
            .field(label, &type_name)
            .field("scope", scope)
            .finish()
    }
}
// Finishes the scope if the future is dropped before it completed.
//
// `project_replace` drops the pinned `task` in place (if present) and returns
// the owned `scope`. The task, which may reference the scope, is therefore
// gone before `finish` is called.
// NOTE(review): this calls the same `finish` as the synchronous `scope`. If
// that call blocks, dropping an unfinished `ScopeFuture` blocks the current
// thread. Confirm this is acceptable in async contexts.
#[pinned_drop]
impl<F, T> PinnedDrop for ScopeFuture<'_, '_, F, T> {
fn drop(self: Pin<&mut Self>) {
let this = self.project();
// Leave `None` behind; a future that already completed has nothing to finish.
if let ScopeFutureReplace::Task { scope, .. } | ScopeFutureReplace::Wait { scope, .. } =
this.0.project_replace(State::None)
{
scope.this.finish();
}
}
}
/// Polls the user's task to completion, then polls the scope's
/// `finish_async`, and finally yields the task's output.
impl<F, T> Future for ScopeFuture<'_, '_, F, T>
where
F: Future<Output = T>,
{
type Output = T;
/// # Panics
///
/// Panics if polled again after returning `Poll::Ready`.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
// Loop so a task that completes during this call moves straight on to the
// `Wait` state within the same poll.
loop {
match this.0.as_mut().project() {
ScopeFutureProj::Task { task, .. } => {
let result = ready!(task.poll(cx));
// Drop the finished task in place and take the scope out, then
// re-enter as `Wait` holding the task's output.
let ScopeFutureReplace::Task { scope, .. } =
this.0.as_mut().project_replace(State::None)
else {
unreachable!("found wrong state")
};
this.0
.as_mut()
.project_replace(State::Wait { result, scope });
}
ScopeFutureProj::Wait { scope, .. } => {
ready!(scope.this.finish_async(cx));
// Scope finished: move the output out and leave `None`, so the
// drop handler has nothing left to do.
let ScopeFutureReplace::Wait { result, .. } =
this.0.project_replace(State::None)
else {
unreachable!("found wrong state")
};
return Poll::Ready(result);
}
ScopeFutureProj::None => panic!("`ScopeFuture` polled after completion"),
}
}
}
}
impl<'scope, 'env, F, T> ScopeFuture<'scope, 'env, F, T>
where
    F: Future<Output = T>,
{
    /// Creates a future in the `Task` state from the user's task and its boxed scope.
    pub(super) const fn new(task: F, scope: Pin<Box<Scope<'scope, 'env>>>) -> Self {
        Self(State::Task { task, scope })
    }

    /// Drives only the user's task and hands back a future that merely
    /// finishes the scope.
    ///
    /// Once the task's output is available, `self` is left in the `None`
    /// state. A new [`ScopeFuture`] in the `Wait` state, owning the output and
    /// the scope, is returned. Its task type is [`Ready<T>`], which is never
    /// polled because the task has already completed.
    ///
    /// # Panics
    ///
    /// Panics if `self` has already completed or was already converted.
    pub(crate) fn poll_into_wait(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<ScopeFuture<'scope, 'env, Ready<T>, T>> {
        let mut this = self.project();
        match this.0.as_mut().project() {
            ScopeFutureProj::Task { task, .. } => {
                let result = ready!(task.poll(cx));
                // Drop the finished task in place and take ownership of the scope.
                let ScopeFutureReplace::Task { scope, .. } = this.0.project_replace(State::None)
                else {
                    unreachable!("found wrong state")
                };
                Poll::Ready(ScopeFuture(State::Wait { result, scope }))
            }
            ScopeFutureProj::Wait { .. } => {
                // Already waiting: just move the state into the new future.
                let ScopeFutureReplace::Wait { result, scope } =
                    this.0.project_replace(State::None)
                else {
                    unreachable!("found wrong state")
                };
                Poll::Ready(ScopeFuture(State::Wait { result, scope }))
            }
            ScopeFutureProj::None => panic!("`ScopeFuture` polled after completion"),
        }
    }

    /// Returns whether there is nothing left to wait for.
    ///
    /// `false` while the task is still running. In the `Wait` state, `true`
    /// once the scope reports zero threads. Always `true` after the state was
    /// taken.
    pub(crate) fn is_finished(&self) -> bool {
        match &self.0 {
            State::Task { .. } => false,
            State::Wait { scope, .. } => scope.this.thread_count() == 0,
            State::None => true,
        }
    }

    /// Finishes the scope on the current thread and returns the task's output.
    ///
    /// Must only be called in the `Wait` state.
    ///
    /// # Panics
    ///
    /// - If the current thread type cannot block (`has_block_support`).
    /// - If the future was already polled to completion.
    /// - In the `Task` state (treated as unreachable).
    pub(crate) fn join_all(mut self) -> T {
        match mem::replace(&mut self.0, State::None) {
            State::Wait { result, scope } => {
                // NOTE(review): this check runs after the state was taken. If it
                // fails, `scope` is dropped without `finish` and `PinnedDrop` only
                // sees `None`. Presumably this avoids blocking inside drop on a
                // thread that cannot block; confirm it is intended.
                assert!(
                    super::has_block_support(),
                    "current thread type cannot be blocked"
                );
                scope.this.finish();
                result
            }
            State::None => {
                panic!("called after `ScopeJoinFuture` was polled to completion")
            }
            State::Task { .. } => {
                unreachable!("should only be called from `ScopeJoinFuture`")
            }
        }
    }
}