diff --git a/include/cudawrappers/cu.hpp b/include/cudawrappers/cu.hpp
index 06cc69a..ee68420 100644
--- a/include/cudawrappers/cu.hpp
+++ b/include/cudawrappers/cu.hpp
@@ -668,6 +668,115 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {
   size_t _size;
 };
 
+/// Default (zero) flags value for the graph wrappers below.
+constexpr unsigned int CU_GRAPH_DEFAULT = 0;
+
+/// Wrapper around a CUgraphNode handle; no cleanup is registered for it.
+class GraphNode : public Wrapper<CUgraphNode> {
+ public:
+  GraphNode() = default;
+  GraphNode(CUgraphNode &node) : Wrapper(node) {}
+
+  /// Address of the wrapped handle, for driver calls that write a node.
+  CUgraphNode *getNode() { return &_obj; }
+};
+
+/// Launch description of a kernel node (wraps CUDA_KERNEL_NODE_PARAMS).
+///
+/// Only the pointer returned by params.data() is stored, not a copy of the
+/// argument array, so `params` must stay alive and unmodified at least until
+/// the node has been added to a graph.
+class GraphKernelNodeParams : public Wrapper<CUDA_KERNEL_NODE_PARAMS> {
+ public:
+  GraphKernelNodeParams(const Function &function, unsigned gridDimX,
+                        unsigned gridDimY, unsigned gridDimZ,
+                        unsigned blockDimX, unsigned blockDimY,
+                        unsigned blockDimZ, unsigned sharedMemBytes,
+                        const std::vector<const void *> &params) {
+    _obj.func = function;
+    _obj.gridDimX = gridDimX;
+    _obj.gridDimY = gridDimY;
+    _obj.gridDimZ = gridDimZ;
+    _obj.blockDimX = blockDimX;
+    _obj.blockDimY = blockDimY;
+    _obj.blockDimZ = blockDimZ;
+    _obj.sharedMemBytes = sharedMemBytes;
+    // The driver struct wants a non-const void **; only constness is dropped.
+    _obj.kernelParams = const_cast<void **>(params.data());
+    _obj.extra = nullptr;
+  }
+};
+
+/// Wrapper around a CUgraph that is created in the constructor and destroyed
+/// by the shared_ptr deleter installed there.
+class Graph : public Wrapper<CUgraph> {
+ public:
+  explicit Graph(unsigned int flags = CU_GRAPH_DEFAULT) {
+    checkCudaCall(cuGraphCreate(&_obj, flags));
+    manager = std::shared_ptr<CUgraph>(new CUgraph(_obj), [](CUgraph *ptr) {
+      checkCudaCall(cuGraphDestroy(*ptr));
+      delete ptr;
+    });
+  }
+
+  /// Adds a kernel node described by `params` to this graph and stores the
+  /// new node handle in `node`.
+  void addKernelNode(GraphNode &node,
+                     const std::vector<CUgraphNode> &dependencies,
+                     const GraphKernelNodeParams &params) {
+    // Go through Wrapper's conversion operator instead of reinterpreting the
+    // wrapper's address, which silently depends on the wrapper's layout.
+    const CUDA_KERNEL_NODE_PARAMS node_params = params;
+    checkCudaCall(cuGraphAddKernelNode(node.getNode(), _obj,
+                                       dependencies.data(),
+                                       dependencies.size(), &node_params));
+  }
+
+  /// Instantiates the graph and returns the raw executable handle; nothing
+  /// in this class destroys it, so the caller must.
+  CUgraphExec Instantiate(unsigned int flags = CU_GRAPH_DEFAULT) {
+    CUgraphExec graph_instance = nullptr;
+    checkCudaCall(cuGraphInstantiateWithFlags(&graph_instance, _obj, flags));
+    return graph_instance;
+  }
+};
+
+/// Parameters of a conditional graph node of the while-loop kind.
+///
+/// NOTE(review): the base is assumed to be Wrapper<CUgraphNodeParams>, the
+/// driver struct that has both `type` and `conditional` members - confirm.
+/// The condition handle and context are not set here.
+class WhileNode : public Wrapper<CUgraphNodeParams> {
+ public:
+  WhileNode() {
+    _obj.type = CUgraphNodeType::CU_GRAPH_NODE_TYPE_CONDITIONAL;
+    // `conditional` is a parameter struct and cannot hold the node-type
+    // enumerator; select the while flavour of the conditional instead.
+    _obj.conditional.type = CU_GRAPH_COND_TYPE_WHILE;
+    _obj.conditional.size = 1;
+  }
+};
+
+/// Wrapper around an executable graph handle (CUgraphExec).
+class GraphExec : public Wrapper<CUgraphExec> {
+ public:
+  /// Wraps an existing handle; no cleanup is registered for it.
+  explicit GraphExec(CUgraphExec &graph_exec) : Wrapper(graph_exec) {}
+  GraphExec(const GraphExec &graph_exec) = default;
+
+  /// Instantiates `graph`; the executable graph is destroyed by the
+  /// shared_ptr deleter installed here.
+  explicit GraphExec(const Graph &graph,
+                     unsigned int flags = CU_GRAPH_DEFAULT) {
+    checkCudaCall(cuGraphInstantiateWithFlags(&_obj, graph, flags));
+    manager = std::shared_ptr<CUgraphExec>(
+        new CUgraphExec(_obj), [](CUgraphExec *ptr) {
+          checkCudaCall(cuGraphExecDestroy(*ptr));
+          delete ptr;
+        });
+  }
+};
+
 class Stream : public Wrapper<CUstream> {
   friend class Event;
 
@@ -886,6 +995,20 @@ class Stream : public Wrapper<CUstream> {
   }
 #endif
 
+  /// Launches an instantiated graph into this stream.
+  void launchGraph(CUgraphExec &graph) {
+    checkCudaCall(cuGraphLaunch(graph, _obj));
+  }
+
+  /// Launches an instantiated graph into this stream.
+  void launchGraph(GraphExec &graph) {
+    checkCudaCall(cuGraphLaunch(graph, _obj));
+  }
+
+  /// Misspelled aliases of launchGraph(), kept for existing callers.
+  void lunchGraph(CUgraphExec &graph) { launchGraph(graph); }
+  void lunchGraph(GraphExec &graph) { launchGraph(graph); }
+
   void query() {
     checkCudaCall(cuStreamQuery(_obj));  // unsuccessful result throws cu::Error
   }
@@ -920,6 +1043,7 @@ class Stream : public Wrapper<CUstream> {
 inline void Event::record(Stream &stream) {
   checkCudaCall(cuEventRecord(_obj, stream._obj));
 }
+
 }  // namespace cu
 
 #endif
\ No newline at end of file