diff --git a/src/blas/backends/cublas/cublas_handle.hpp b/src/blas/backends/cublas/cublas_handle.hpp index 83a76c927..8b77282df 100644 --- a/src/blas/backends/cublas/cublas_handle.hpp +++ b/src/blas/backends/cublas/cublas_handle.hpp @@ -18,36 +18,30 @@ **************************************************************************/ #ifndef CUBLAS_HANDLE_HPP #define CUBLAS_HANDLE_HPP -#include #include +#include "cublas_helper.hpp" namespace oneapi { namespace mkl { namespace blas { namespace cublas { -template struct cublas_handle { - using handle_container_t = std::unordered_map*>; + using handle_container_t = std::unordered_map; handle_container_t cublas_handle_mapper_{}; ~cublas_handle() noexcept(false) { + CUresult err; + CUcontext original; + CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original); for (auto& handle_pair : cublas_handle_mapper_) { - cublasStatus_t err; - if (handle_pair.second != nullptr) { - auto handle = handle_pair.second->exchange(nullptr); - if (handle != nullptr) { - CUBLAS_ERROR_FUNC(cublasDestroy, err, handle); - handle = nullptr; - } - else { - // if the handle is nullptr it means the handle was already - // destroyed by the ContextCallback and we're free to delete the - // atomic object. - delete handle_pair.second; - } - - handle_pair.second = nullptr; + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, handle_pair.first); + if (original != desired) { + // Sets the desired context as the active one for the thread in order to destroy its corresponding cublasHandle_t. + CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); } + cublasStatus_t err; + CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second); } cublas_handle_mapper_.clear(); } diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 8bb1145fa..812d89d31 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -17,11 +17,6 @@ * **************************************************************************/ #include "cublas_scope_handle.hpp" -#if __has_include() -#include -#else -#include -#endif namespace oneapi { namespace mkl { @@ -35,108 +30,34 @@ namespace cublas { * takes place if no other element in the container has a key equivalent to * the one being emplaced (keys in a map container are unique). */ -#ifdef ONEMKL_PI_INTERFACE_REMOVED -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; -#else -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; -#endif +thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; -CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) - : ih(ih), - needToRecover_(false) { - placedContext_ = new sycl::context(queue.get_context()); - auto cudaDevice = ih.get_native_device(); - CUresult err; - CUcontext desired; - CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); - if (original_ != desired) { - // Sets the desired context as the active one for the thread - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); - // No context is installed and the suggested context is primary - // This is the most common case. We can activate the context in the - // thread and leave it there until all the PI context referring to the - // same underlying CUDA primary context are destroyed. This emulates - // the behaviour of the CUDA runtime api, and avoids costly context - // switches. No action is required on this side of the if. - needToRecover_ = !(original_ == nullptr); - } -} - -CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) { - if (needToRecover_) { - CUresult err; - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_); - } - delete placedContext_; -} - -void ContextCallback(void* userData) { - auto* ptr = static_cast*>(userData); - if (!ptr) { - return; - } - auto handle = ptr->exchange(nullptr); - if (handle != nullptr) { - cublasStatus_t err1; - CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle); - handle = nullptr; - } - else { - // if the handle is nullptr it means the handle was already destroyed by - // the cublas_handle destructor and we're free to delete the atomic - // object. - delete ptr; - } -} +CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { - auto cudaDevice = ih.get_native_device(); - CUresult cuErr; - CUcontext desired; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); -#ifdef ONEMKL_PI_INTERFACE_REMOVED - auto piPlacedContext_ = reinterpret_cast(desired); -#else - auto piPlacedContext_ = reinterpret_cast(desired); -#endif + CUdevice device = ih.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; - auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_); + + auto it = handle_helper.cublas_handle_mapper_.find(device); if (it != handle_helper.cublas_handle_mapper_.end()) { - if (it->second == nullptr) { - handle_helper.cublas_handle_mapper_.erase(it); - } - else { - auto handle = it->second->load(); - if (handle != nullptr) { - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; - } - else { - handle_helper.cublas_handle_mapper_.erase(it); - } + cublasHandle_t nativeHandle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, ¤tStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); } + return nativeHandle; } - cublasHandle_t handle; - - CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - - auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(piPlacedContext_, new std::atomic(handle))); + cublasHandle_t nativeHandle; + CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle); + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); - sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, - insert_iter.first->second); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle)); - return handle; + return nativeHandle; } CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index d17909cfb..2f6027478 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -23,32 +23,9 @@ #else #include #endif -#if __has_include() -#if __SYCL_COMPILER_VERSION <= 20220930 -#include -#endif -#include -#else -#include -#include -#endif -// After Plugin Interface removal in DPC++ ur.hpp is the new include -#if __has_include() -#include -#ifndef ONEMKL_PI_INTERFACE_REMOVED -#define ONEMKL_PI_INTERFACE_REMOVED -#endif -#elif __has_include() -#include -#else -#include -#endif - -#include #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" @@ -84,22 +61,14 @@ the handle must be destroyed when the context goes out of scope. This will bind **/ class CublasScopedContextHandler { - CUcontext original_; - sycl::context* placedContext_; - bool needToRecover_; sycl::interop_handle& ih; -#ifdef ONEMKL_PI_INTERFACE_REMOVED - static thread_local cublas_handle handle_helper; -#else - static thread_local cublas_handle handle_helper; -#endif + static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); sycl::context get_context(const sycl::queue& queue); public: - CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih); + CublasScopedContextHandler(sycl::interop_handle& ih); - ~CublasScopedContextHandler() noexcept(false); /** * @brief get_handle: creates the handle by implicitly impose the advice * given by nvidia for creating a cublas_handle. (e.g. one cuStream per device diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp index 03c282aed..8822151dd 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp @@ -24,43 +24,33 @@ namespace mkl { namespace blas { namespace cublas { -thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; +thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) : interop_h(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { sycl::device device = queue.get_device(); - int current_device = interop_h.get_native_device(); + CUdevice current_device = interop_h.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(current_device); if (it != handle_helper.cublas_handle_mapper_.end()) { - if (it->second == nullptr) { - handle_helper.cublas_handle_mapper_.erase(it); - } - else { - auto handle = it->second->load(); - if (handle != nullptr) { - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; - } - else { - handle_helper.cublas_handle_mapper_.erase(it); - } + cublasHandle_t handle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, ¤tStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); } + return handle; } cublasHandle_t handle; CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(current_device, new std::atomic(handle))); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle)); return handle; } @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { } // namespace cublas } // namespace blas } // namespace mkl -} // namespace oneapi \ No newline at end of file +} // namespace oneapi diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp index 9e1eb89e5..84b28e0fd 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp @@ -25,7 +25,6 @@ #endif #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" namespace oneapi { @@ -60,7 +59,7 @@ the handle must be destroyed when the context goes out of scope. This will bind class CublasScopedContextHandler { sycl::interop_handle interop_h; - static thread_local cublas_handle handle_helper; + static thread_local cublas_handle handle_helper; sycl::context get_context(const sycl::queue& queue); CUstream get_stream(const sycl::queue& queue); diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index 08d5cf70e..ae95e6eb1 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -35,18 +35,6 @@ #else #include "cublas_scope_handle_hipsycl.hpp" -// After Plugin Interface removal in DPC++ ur.hpp is the new include -#if __has_include() -#include -#ifndef ONEMKL_PI_INTERFACE_REMOVED -#define ONEMKL_PI_INTERFACE_REMOVED -#endif -#elif __has_include() -#include -#else -#include -#endif - namespace sycl { using interop_handler = sycl::interop_handle; } @@ -72,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) { #else cgh.host_task([f, queue](sycl::interop_handle ih) { #endif - auto sc = CublasScopedContextHandler(queue, ih); + auto sc = CublasScopedContextHandler(ih); f(sc); }); }