Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
matmanc committed Jan 9, 2025
1 parent ad13d8f commit 7dea131
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions include/cudawrappers/cu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,81 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {
size_t _size;
};

constexpr unsigned int CU_GRAPH_DEFAULT = 0;

class GraphNode : public Wrapper<CUgraphNode> {
public:
GraphNode() = default;
GraphNode(CUgraphNode &node) : Wrapper(node) {};

CUgraphNode *getNode() { return &_obj; };
};

class GraphKernelNodeParams : public Wrapper<CUDA_KERNEL_NODE_PARAMS> {
public:
GraphKernelNodeParams(const Function &function, unsigned gridDimX,
unsigned gridDimY, unsigned gridDimZ,
unsigned blockDimX, unsigned blockDimY,
unsigned blockDimZ, unsigned sharedMemBytes,
const std::vector<const void *> &params) {
_obj.func = function;
_obj.blockDimX = blockDimX;
_obj.blockDimY = blockDimY;
_obj.blockDimZ = blockDimZ;
_obj.gridDimX = gridDimX;
_obj.gridDimY = gridDimY;
_obj.gridDimZ = gridDimZ;
_obj.sharedMemBytes = sharedMemBytes;
_obj.kernelParams = const_cast<void **>(params.data());
_obj.extra = nullptr;
}
};

class Graph : public Wrapper<CUgraph> {
public:
explicit Graph(unsigned int flags = CU_GRAPH_DEFAULT) {
checkCudaCall(cuGraphCreate(&_obj, flags));
manager = std::shared_ptr<CUgraph>(new CUgraph(_obj), [](CUgraph *ptr) {
checkCudaCall(cuGraphDestroy(*ptr));
delete ptr;
});
}

void addKernelNode(GraphNode &node,
const std::vector<CUgraphNode> &dependencies,
const GraphKernelNodeParams &params) {
checkCudaCall(cuGraphAddKernelNode(node.getNode(), _obj,
dependencies.data(), dependencies.size(),
(CUDA_KERNEL_NODE_PARAMS *)(&params)));
}

CUgraphExec Instantiate(unsigned int flags = CU_GRAPH_DEFAULT) {
CUgraphExec graph_instance;
cu::checkCudaCall(
cuGraphInstantiateWithFlags(&graph_instance, _obj, flags));
return graph_instance;
}
};

class WhileNode : public Wrapper<CUgraphNodeParams> {
public:
WhileNode() {
_obj.conditional = _obj.type =
CUgraphNodeType::CU_GRAPH_NODE_TYPE_CONDITIONAL;
}
};

class GraphExec : public Wrapper<CUgraphExec> {
public:
explicit GraphExec(CUgraphExec &graph_exec) : Wrapper(graph_exec) {}
explicit GraphExec(GraphExec &graph_exec) = default;

explicit GraphExec(const Graph &graph,
unsigned int flags = CU_GRAPH_DEFAULT) {
checkCudaCall(cuGraphInstantiate(&_obj, graph, flags));
}
};

class Stream : public Wrapper<CUstream> {
friend class Event;

Expand Down Expand Up @@ -886,6 +961,14 @@ class Stream : public Wrapper<CUstream> {
}
#endif

void lunchGraph(CUgraphExec &graph) {
checkCudaCall(cuGraphLaunch(graph, _obj));
}

void lunchGraph(GraphExec &graph) {
checkCudaCall(cuGraphLaunch(graph, _obj));
}

void query() {
checkCudaCall(cuStreamQuery(_obj)); // unsuccessful result throws cu::Error
}
Expand Down Expand Up @@ -920,6 +1003,7 @@ class Stream : public Wrapper<CUstream> {
inline void Event::record(Stream &stream) {
checkCudaCall(cuEventRecord(_obj, stream._obj));
}

} // namespace cu

#endif

0 comments on commit 7dea131

Please sign in to comment.