From 31a4de6e9b8d81314b5a21e09154c867171810f6 Mon Sep 17 00:00:00 2001 From: Alexey Spiridonov Date: Sat, 17 Aug 2024 11:59:35 -0700 Subject: [PATCH] Reduce duplication between `TaskPromise` specializations Summary: Make the `TaskPromise<>` leaf classes as simple as they can be. This shortens the code by 30 lines, and makes it easier to follow. Reviewed By: yfeldblum Differential Revision: D61249849 fbshipit-source-id: c12ca1d96c2c99fd85d53df9a77146450d525ca6 --- folly/coro/Task.h | 144 ++++++++++++++++++---------------------------- 1 file changed, 57 insertions(+), 87 deletions(-) diff --git a/folly/coro/Task.h b/folly/coro/Task.h index 75785e0b6a1..45905183305 100644 --- a/folly/coro/Task.h +++ b/folly/coro/Task.h @@ -168,10 +168,10 @@ class TaskPromiseBase { } private: - template + template friend class folly::coro::TaskWithExecutor; - template + template friend class folly::coro::Task; friend coroutine_handle tag_invoke( @@ -196,45 +196,19 @@ class TaskPromiseBase { } bypassExceptionThrowing_{BypassExceptionThrowing::INACTIVE}; }; -template -class TaskPromise final : public TaskPromiseBase, - public ExtendedCoroutinePromiseImpl> { +// Separate from `TaskPromiseBase` so the compiler has less to specialize. +template +class TaskPromiseCrtpBase : public TaskPromiseBase, + public ExtendedCoroutinePromiseImpl { public: - static_assert( - !std::is_rvalue_reference_v, - "Task is not supported. " - "Consider using Task or Task> instead."); - friend class TaskPromiseBase; - using StorageType = detail::lift_lvalue_reference_t; - TaskPromise() noexcept = default; - Task get_return_object() noexcept; void unhandled_exception() noexcept { result_.emplaceException(exception_wrapper{current_exception()}); } - template - void return_value(U&& value) { - if constexpr (std::is_same_v, Try>) { - DCHECK(value.hasValue() || (value.hasException() && value.exception())); - result_ = static_cast(value); - } else if constexpr ( - std::is_same_v, Try> && - std::is_same_v, Unit>) { - // special-case to make task -> semifuture -> task preserve void type - DCHECK(value.hasValue() || (value.hasException() && value.exception())); - result_ = static_cast>(static_cast(value)); - } else { - static_assert( - std::is_convertible::value, - "cannot convert return value to type T"); - result_.emplace(static_cast(value)); - } - } - Try& result() { return result_; } auto yield_value(co_error ex) { @@ -253,81 +227,80 @@ class TaskPromise final : public TaskPromiseBase, return do_safe_point(*this); } + protected: + TaskPromiseCrtpBase() noexcept = default; + ~TaskPromiseCrtpBase() = default; + std::pair getErrorHandle( exception_wrapper& ex) override { + auto& me = *static_cast(this); if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) { auto finalAwaiter = yield_value(co_error(std::move(ex))); DCHECK(!finalAwaiter.await_ready()); return { finalAwaiter.await_suspend( - coroutine_handle::from_promise(*this)), + coroutine_handle::from_promise(me)), // finalAwaiter.await_suspend pops a frame getAsyncFrame().getParentFrame()}; } - return {coroutine_handle::from_promise(*this), nullptr}; + return {coroutine_handle::from_promise(me), nullptr}; } - private: Try result_; }; -template <> -class TaskPromise final - : public TaskPromiseBase, - public ExtendedCoroutinePromiseImpl> { +template +class TaskPromise final : public TaskPromiseCrtpBase, T> { public: + static_assert( + !std::is_rvalue_reference_v, + "Task is not supported. " + "Consider using Task or Task> instead."); friend class TaskPromiseBase; - using StorageType = void; + using StorageType = + typename TaskPromiseCrtpBase, T>::StorageType; TaskPromise() noexcept = default; - Task get_return_object() noexcept; - - void unhandled_exception() noexcept { - result_.emplaceException(exception_wrapper{current_exception()}); + template + void return_value(U&& value) { + if constexpr (std::is_same_v, Try>) { + DCHECK(value.hasValue() || (value.hasException() && value.exception())); + this->result_ = static_cast(value); + } else if constexpr ( + std::is_same_v, Try> && + std::is_same_v, Unit>) { + // special-case to make task -> semifuture -> task preserve void type + DCHECK(value.hasValue() || (value.hasException() && value.exception())); + this->result_ = static_cast>(static_cast(value)); + } else { + static_assert( + std::is_convertible::value, + "cannot convert return value to type T"); + this->result_.emplace(static_cast(value)); + } } +}; - void return_void() noexcept { result_.emplace(); } - - Try& result() { return result_; } +template <> +class TaskPromise final + : public TaskPromiseCrtpBase, void> { + public: + friend class TaskPromiseBase; - auto yield_value(co_error ex) { - result_.emplaceException(std::move(ex.exception())); - return final_suspend(); - } + using StorageType = void; - auto yield_value(co_result&& result) { - result_ = std::move(result.result()); - return final_suspend(); - } - auto yield_value(co_result&& result) { - result_ = std::move(result.result()); - return final_suspend(); - } + TaskPromise() noexcept = default; - using TaskPromiseBase::await_transform; + void return_void() noexcept { this->result_.emplace(); } - auto await_transform(co_safe_point_t) noexcept { - return do_safe_point(*this); - } + using TaskPromiseCrtpBase, void>::yield_value; - std::pair getErrorHandle( - exception_wrapper& ex) override { - if (bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) { - auto finalAwaiter = yield_value(co_error(std::move(ex))); - DCHECK(!finalAwaiter.await_ready()); - return { - finalAwaiter.await_suspend( - coroutine_handle::from_promise(*this)), - // finalAwaiter.await_suspend pops a frame - getAsyncFrame().getParentFrame()}; - } - return {coroutine_handle::from_promise(*this), nullptr}; + auto yield_value(co_result&& result) { + this->result_ = std::move(result.result()); + return final_suspend(); } - - private: - Try result_; }; } // namespace detail @@ -786,7 +759,7 @@ class FOLLY_NODISCARD Task { private: friend class detail::TaskPromiseBase; - friend class detail::TaskPromise; + friend class detail::TaskPromiseCrtpBase, T>; friend class TaskWithExecutor; class Awaiter { @@ -890,14 +863,11 @@ Task> makeResultTask(Try t) { co_yield co_result(std::move(t)); } -template -Task detail::TaskPromise::get_return_object() noexcept { - return Task{coroutine_handle>::from_promise(*this)}; -} - -inline Task detail::TaskPromise::get_return_object() noexcept { - return Task{ - coroutine_handle>::from_promise(*this)}; +template +inline Task +detail::TaskPromiseCrtpBase::get_return_object() noexcept { + return Task{ + coroutine_handle::from_promise(*static_cast(this))}; } } // namespace coro