From f23bd0d87ea5e93714c862ce562e71427507f8b9 Mon Sep 17 00:00:00 2001 From: Mattia Mancini Date: Tue, 14 Jan 2025 14:00:17 +0100 Subject: [PATCH] Fix hip usage --- include/cudawrappers/cu.hpp | 8 +++++++- include/cudawrappers/macros.hpp | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/include/cudawrappers/cu.hpp b/include/cudawrappers/cu.hpp index 6ee3be7..c601aab 100644 --- a/include/cudawrappers/cu.hpp +++ b/include/cudawrappers/cu.hpp @@ -686,12 +686,18 @@ class GraphKernelNodeParams : public Wrapper { unsigned blockDimZ, unsigned sharedMemBytes, const std::vector ¶ms) { _obj.func = function; +#if defined(__HIP__) + _obj.blockDim = {blockDimX, blockDimY, blockDimZ}; + _obj.gridDim = {gridDimX, gridDimY, gridDimZ}; + +#else _obj.blockDimX = blockDimX; _obj.blockDimY = blockDimY; _obj.blockDimZ = blockDimZ; _obj.gridDimX = gridDimX; _obj.gridDimY = gridDimY; _obj.gridDimZ = gridDimZ; +#endif _obj.sharedMemBytes = sharedMemBytes; _obj.kernelParams = const_cast(params.data()); _obj.extra = nullptr; @@ -763,7 +769,7 @@ class GraphExec : public Wrapper { explicit GraphExec(const Graph &graph, unsigned int flags = CU_GRAPH_DEFAULT) { - checkCudaCall(cuGraphInstantiate(&_obj, graph, flags)); + checkCudaCall(cuGraphInstantiateWithFlags(&_obj, graph, flags)); } }; diff --git a/include/cudawrappers/macros.hpp b/include/cudawrappers/macros.hpp index 9c6c63b..0a1d1f7 100644 --- a/include/cudawrappers/macros.hpp +++ b/include/cudawrappers/macros.hpp @@ -661,6 +661,12 @@ typedef uint32_t cuuint32_t; #define cuGetErrorName hipDrvGetErrorName #define cuGetErrorString hipDrvGetErrorString #define cuGetProcAddress hipGetProcAddress +#define cuGraphAddKernelNode hipGraphAddKernelNode +#define cuGraphCreate hipGraphCreate +#define cuGraphDestroy hipGraphDestroy +#define cuGraphInstantiateWithFlags hipGraphInstantiateWithFlags +#define cuGraphInstantiate hipGraphInstantiate +#define cuGraphLaunch hipGraphLaunch #define cuInit hipInit #define cuIpcCloseMemHandle hipIpcCloseMemHandle #define cuIpcGetEventHandle hipIpcGetEventHandle