Skip to content

Commit

Permalink
Update cudawrappers to support COBALT (#240)
Browse files Browse the repository at this point in the history
* Make Function::getAttribute const
* Add Function::name
* Add HostMemory::size
* Add DeviceMemory::size
* Add Module constructor with CUjit_option map
* Update CHANGELOG
  • Loading branch information
csbnw authored Nov 7, 2023
1 parent a1fe550 commit 041d317
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Added
- Added `cu::Context::getDevice()`
- Added `cu::Module` constructor with `CUjit_option` map argument
- Added `DeviceMemory::size`
- Added `HostMemory::size`
- Added `Function::name`

### Changed
- Fixed the `cu::Module(CUmodule&)` constructor
- Added `Function::getAttribute` is now const

### Removed

## [0.6.0] - 2023-10-06
Expand Down
46 changes: 41 additions & 5 deletions include/cudawrappers/cu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <array>
#include <cstddef>
#include <exception>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -250,15 +251,16 @@ class Context : public Wrapper<CUcontext> {

class HostMemory : public Wrapper<void *> {
public:
explicit HostMemory(size_t size, unsigned int flags = 0) {
explicit HostMemory(size_t size, unsigned int flags = 0) : _size(size) {
checkCudaCall(cuMemHostAlloc(&_obj, size, flags));
manager = std::shared_ptr<void *>(new (void *)(_obj), [](void **ptr) {
cuMemFreeHost(*ptr);
delete ptr;
});
}

explicit HostMemory(void *ptr, size_t size, unsigned int flags = 0) {
explicit HostMemory(void *ptr, size_t size, unsigned int flags = 0)
: _size(size) {
_obj = ptr;
checkCudaCall(cuMemHostRegister(&_obj, size, flags));
manager = std::shared_ptr<void *>(
Expand All @@ -269,6 +271,11 @@ class HostMemory : public Wrapper<void *> {
operator T *() {
return static_cast<T *>(_obj);
}

size_t size() const { return _size; }

private:
size_t _size;
};

class Array : public Wrapper<CUarray> {
Expand Down Expand Up @@ -342,6 +349,24 @@ class Module : public Wrapper<CUmodule> {
});
}

typedef std::map<CUjit_option, void *> optionmap_t;
explicit Module(const void *image, Module::optionmap_t &options) {
std::vector<CUjit_option> keys;
std::vector<void *> values;

for (const std::pair<CUjit_option, void *> &i : options) {
keys.push_back(i.first);
values.push_back(i.second);
}

checkCudaCall(cuModuleLoadDataEx(&_obj, image, options.size(), keys.data(),
values.data()));

for (size_t i = 0; i < keys.size(); ++i) {
options[keys[i]] = values[i];
}
}

explicit Module(CUmodule &module) : Wrapper(module) {}

CUdeviceptr getGlobal(const char *name) const {
Expand All @@ -353,13 +378,13 @@ class Module : public Wrapper<CUmodule> {

class Function : public Wrapper<CUfunction> {
public:
Function(const Module &module, const char *name) {
Function(const Module &module, const char *name) : _name(name) {
checkCudaCall(cuModuleGetFunction(&_obj, module, name));
}

explicit Function(CUfunction &function) : Wrapper(function) {}

int getAttribute(CUfunction_attribute attribute) {
int getAttribute(CUfunction_attribute attribute) const {
int value{};
checkCudaCall(cuFuncGetAttribute(&value, attribute, _obj));
return value;
Expand All @@ -368,6 +393,11 @@ class Function : public Wrapper<CUfunction> {
void setCacheConfig(CUfunc_cache config) {
checkCudaCall(cuFuncSetCacheConfig(_obj, config));
}

const char *name() const { return _name; }

private:
const char *_name;
};

class Event : public Wrapper<CUevent> {
Expand Down Expand Up @@ -402,7 +432,8 @@ class Event : public Wrapper<CUevent> {
class DeviceMemory : public Wrapper<CUdeviceptr> {
public:
explicit DeviceMemory(size_t size, CUmemorytype type = CU_MEMORYTYPE_DEVICE,
unsigned int flags = 0) {
unsigned int flags = 0)
: _size(size) {
if (type == CU_MEMORYTYPE_DEVICE and !flags) {
checkCudaCall(cuMemAlloc(&_obj, size));
} else if (type == CU_MEMORYTYPE_UNIFIED) {
Expand Down Expand Up @@ -443,6 +474,11 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {
"Cannot return memory of type CU_MEMORYTYPE_DEVICE as pointer.");
}
}

size_t size() const { return _size; }

private:
size_t _size;
};

class Stream : public Wrapper<CUstream> {
Expand Down

0 comments on commit 041d317

Please sign in to comment.