diff --git a/include/cudawrappers/cu.hpp b/include/cudawrappers/cu.hpp index 6ee3be7..c601aab 100644 --- a/include/cudawrappers/cu.hpp +++ b/include/cudawrappers/cu.hpp @@ -686,12 +686,18 @@ class GraphKernelNodeParams : public Wrapper { unsigned blockDimZ, unsigned sharedMemBytes, const std::vector ¶ms) { _obj.func = function; +#if defined(__HIP__) + _obj.blockDim = {blockDimX, blockDimY, blockDimZ}; + _obj.gridDim = {gridDimX, gridDimY, gridDimZ}; + +#else _obj.blockDimX = blockDimX; _obj.blockDimY = blockDimY; _obj.blockDimZ = blockDimZ; _obj.gridDimX = gridDimX; _obj.gridDimY = gridDimY; _obj.gridDimZ = gridDimZ; +#endif _obj.sharedMemBytes = sharedMemBytes; _obj.kernelParams = const_cast(params.data()); _obj.extra = nullptr; @@ -763,7 +769,7 @@ class GraphExec : public Wrapper { explicit GraphExec(const Graph &graph, unsigned int flags = CU_GRAPH_DEFAULT) { - checkCudaCall(cuGraphInstantiate(&_obj, graph, flags)); + checkCudaCall(cuGraphInstantiateWithFlags(&_obj, graph, flags)); } }; diff --git a/include/cudawrappers/macros.hpp b/include/cudawrappers/macros.hpp index 9c6c63b..0a1d1f7 100644 --- a/include/cudawrappers/macros.hpp +++ b/include/cudawrappers/macros.hpp @@ -661,6 +661,12 @@ typedef uint32_t cuuint32_t; #define cuGetErrorName hipDrvGetErrorName #define cuGetErrorString hipDrvGetErrorString #define cuGetProcAddress hipGetProcAddress +#define cuGraphAddKernelNode hipGraphAddKernelNode +#define cuGraphCreate hipGraphCreate +#define cuGraphDestroy hipGraphDestroy +#define cuGraphInstantiateWithFlags hipGraphInstantiateWithFlags +#define cuGraphInstantiate hipGraphInstantiate +#define cuGraphLaunch hipGraphLaunch #define cuInit hipInit #define cuIpcCloseMemHandle hipIpcCloseMemHandle #define cuIpcGetEventHandle hipIpcGetEventHandle