Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call from jit operators #840

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -11205,6 +11205,14 @@ cpp_jit_compile_get_function <- function(cu, name) {
.Call('_torch_cpp_jit_compile_get_function', PACKAGE = 'torchpkg', cu, name)
}

# Generated Rcpp glue: returns the qualified names of all registered JIT
# operators as a character vector (see src/jit-compile.cpp).
cpp_jit_get_all_operators_names <- function() {
.Call('_torch_cpp_jit_get_all_operators_names', PACKAGE = 'torchpkg')
}

# Generated Rcpp glue: looks up a JIT operator by qualified name `x` and
# returns the name stored in its FunctionSchema.
cpp_jit_get_operator_from_name <- function(x) {
.Call('_torch_cpp_jit_get_operator_from_name', PACKAGE = 'torchpkg', x)
}

# Generated Rcpp glue: forwards the logging configuration flag to lantern;
# called for its side effect only (result discarded via invisible()).
cpp_lantern_configure <- function(log) {
invisible(.Call('_torch_cpp_lantern_configure', PACKAGE = 'torchpkg', log))
}
Expand Down
30 changes: 30 additions & 0 deletions inst/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,33 @@ HOST_API int lantern_string_size (void* self)
return ret;
}

// Host-side wrapper for the dynamically loaded lantern symbol.
// Returns a raw pointer to a string vector of all registered JIT operator
// names (see _lantern_jit_get_all_operators_names in Compile.cpp).
LANTERN_API void* (LANTERN_PTR _lantern_jit_get_all_operators_names) ();
HOST_API void* lantern_jit_get_all_operators_names ()
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_get_all_operators_names();
LANTERN_HOST_HANDLER;
return ret;
}

// Host-side wrapper: resolves an operator by its qualified name (a raw
// string pointer) and returns a raw pointer to its FunctionSchema.
LANTERN_API void* (LANTERN_PTR _lantern_jit_get_operation_schema) (void* name);
HOST_API void* lantern_jit_get_operation_schema (void* name)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_get_operation_schema(name);
LANTERN_HOST_HANDLER;
return ret;
}

// Host-side wrapper: extracts the name from a FunctionSchema raw pointer
// and returns it as a raw string pointer.
LANTERN_API void* (LANTERN_PTR _lantern_jit_FunctionSchema_name) (void* schema);
HOST_API void* lantern_jit_FunctionSchema_name (void* schema)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_FunctionSchema_name(schema);
LANTERN_HOST_HANDLER;
return ret;
}

/* Autogen Headers -- Start */
LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking);
HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; }
Expand Down Expand Up @@ -8321,6 +8348,9 @@ LOAD_SYMBOL(_lantern_cuda_device_stats);
LOAD_SYMBOL(_lantern_cuda_get_runtime_version);
LOAD_SYMBOL(_set_delete_lambda_fun);
LOAD_SYMBOL(_lantern_string_size);
LOAD_SYMBOL(_lantern_jit_get_all_operators_names);
LOAD_SYMBOL(_lantern_jit_get_operation_schema);
LOAD_SYMBOL(_lantern_jit_FunctionSchema_name);
/* Autogen Symbols -- Start */
LOAD_SYMBOL(_lantern__cast_byte_tensor_bool)
LOAD_SYMBOL(_lantern__cast_char_tensor_bool)
Expand Down
5 changes: 5 additions & 0 deletions inst/include/lantern/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ void* bool_t(const bool& x);
void* double_t(const double& x);
void* Stream(const at::Stream& x);
void* IValue(const torch::IValue& x);
void* FunctionSchema (const c10::FunctionSchema& x);

namespace vector {
void* string(const std::vector<std::string>& x);
Expand Down Expand Up @@ -147,6 +148,7 @@ LANTERN_FROM_RAW_DECL(bool_t, bool)
LANTERN_FROM_RAW_DECL(double_t, double)
LANTERN_FROM_RAW_DECL(Stream, at::Stream)
LANTERN_FROM_RAW_DECL(IValue, torch::IValue)
LANTERN_FROM_RAW_DECL(FunctionSchema, c10::FunctionSchema)

namespace optional {
LANTERN_FROM_RAW_DECL(DimnameList, c10::optional<torch::DimnameList>)
Expand Down Expand Up @@ -398,6 +400,8 @@ void* double_t(const double& x) { return make_ptr<double>(x); }
void* bool_t(const bool& x) { return make_ptr<bool>(x); }
void* Stream(const at::Stream& x) { return make_ptr<at::Stream>(x); }
void* IValue(const at::IValue& x) { return make_ptr<at::IValue>(x); }
void* FunctionSchema (const c10::FunctionSchema& x) { return make_ptr<c10::FunctionSchema>(x); }


namespace vector {

Expand Down Expand Up @@ -527,6 +531,7 @@ LANTERN_FROM_RAW(bool_t, bool)
LANTERN_FROM_RAW(double_t, double)
LANTERN_FROM_RAW(Stream, at::Stream)
LANTERN_FROM_RAW(IValue, torch::IValue)
LANTERN_FROM_RAW(FunctionSchema, c10::FunctionSchema)

namespace optional {
LANTERN_FROM_RAW_WRAPPED(DimnameList, self_contained::optional::DimnameList,
Expand Down
30 changes: 30 additions & 0 deletions lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,33 @@ HOST_API int lantern_string_size (void* self)
return ret;
}

// Host-side wrapper for the dynamically loaded lantern symbol.
// Returns a raw pointer to a string vector of all registered JIT operator
// names (see _lantern_jit_get_all_operators_names in Compile.cpp).
LANTERN_API void* (LANTERN_PTR _lantern_jit_get_all_operators_names) ();
HOST_API void* lantern_jit_get_all_operators_names ()
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_get_all_operators_names();
LANTERN_HOST_HANDLER;
return ret;
}

// Host-side wrapper: resolves an operator by its qualified name (a raw
// string pointer) and returns a raw pointer to its FunctionSchema.
LANTERN_API void* (LANTERN_PTR _lantern_jit_get_operation_schema) (void* name);
HOST_API void* lantern_jit_get_operation_schema (void* name)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_get_operation_schema(name);
LANTERN_HOST_HANDLER;
return ret;
}

// Host-side wrapper: extracts the name from a FunctionSchema raw pointer
// and returns it as a raw string pointer.
LANTERN_API void* (LANTERN_PTR _lantern_jit_FunctionSchema_name) (void* schema);
HOST_API void* lantern_jit_FunctionSchema_name (void* schema)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_jit_FunctionSchema_name(schema);
LANTERN_HOST_HANDLER;
return ret;
}

/* Autogen Headers -- Start */
LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking);
HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; }
Expand Down Expand Up @@ -8321,6 +8348,9 @@ LOAD_SYMBOL(_lantern_cuda_device_stats);
LOAD_SYMBOL(_lantern_cuda_get_runtime_version);
LOAD_SYMBOL(_set_delete_lambda_fun);
LOAD_SYMBOL(_lantern_string_size);
LOAD_SYMBOL(_lantern_jit_get_all_operators_names);
LOAD_SYMBOL(_lantern_jit_get_operation_schema);
LOAD_SYMBOL(_lantern_jit_FunctionSchema_name);
/* Autogen Symbols -- Start */
LOAD_SYMBOL(_lantern__cast_byte_tensor_bool)
LOAD_SYMBOL(_lantern__cast_char_tensor_bool)
Expand Down
5 changes: 5 additions & 0 deletions lantern/include/lantern/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ void* bool_t(const bool& x);
void* double_t(const double& x);
void* Stream(const at::Stream& x);
void* IValue(const torch::IValue& x);
void* FunctionSchema (const c10::FunctionSchema& x);

namespace vector {
void* string(const std::vector<std::string>& x);
Expand Down Expand Up @@ -147,6 +148,7 @@ LANTERN_FROM_RAW_DECL(bool_t, bool)
LANTERN_FROM_RAW_DECL(double_t, double)
LANTERN_FROM_RAW_DECL(Stream, at::Stream)
LANTERN_FROM_RAW_DECL(IValue, torch::IValue)
LANTERN_FROM_RAW_DECL(FunctionSchema, c10::FunctionSchema)

namespace optional {
LANTERN_FROM_RAW_DECL(DimnameList, c10::optional<torch::DimnameList>)
Expand Down Expand Up @@ -398,6 +400,8 @@ void* double_t(const double& x) { return make_ptr<double>(x); }
void* bool_t(const bool& x) { return make_ptr<bool>(x); }
void* Stream(const at::Stream& x) { return make_ptr<at::Stream>(x); }
void* IValue(const at::IValue& x) { return make_ptr<at::IValue>(x); }
void* FunctionSchema (const c10::FunctionSchema& x) { return make_ptr<c10::FunctionSchema>(x); }


namespace vector {

Expand Down Expand Up @@ -527,6 +531,7 @@ LANTERN_FROM_RAW(bool_t, bool)
LANTERN_FROM_RAW(double_t, double)
LANTERN_FROM_RAW(Stream, at::Stream)
LANTERN_FROM_RAW(IValue, torch::IValue)
LANTERN_FROM_RAW(FunctionSchema, c10::FunctionSchema)

namespace optional {
LANTERN_FROM_RAW_WRAPPED(DimnameList, self_contained::optional::DimnameList,
Expand Down
30 changes: 29 additions & 1 deletion lantern/src/Compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,32 @@ void* _lantern_jit_compile_get_method(void* cu, void* name) {
auto name_ = from_raw::string(name);
return (void*)from_raw::CompilationUnit(cu).find_function(name_);
LANTERN_FUNCTION_END
}
}

void * _lantern_jit_get_all_operators_names () {
LANTERN_FUNCTION_START
auto ops = torch::jit::getAllOperators();
std::vector<std::string> names;
for (const auto& op : ops) {
names.push_back(op->schema().name());
}
return make_raw::vector::string(names);
LANTERN_FUNCTION_END
}

// Resolves a JIT operator by its qualified name (e.g. "aten::add") and
// returns a lantern-owned copy of its FunctionSchema (caller takes
// ownership of the returned raw pointer).
//
// If several overloads are registered for the name, the first one returned
// by the registry is used, matching the original behavior.
void* _lantern_jit_get_operation_schema (void* name) {
LANTERN_FUNCTION_START
auto name_ = from_raw::string(name);
auto op_name = c10::Symbol::fromQualString(name_);
auto ops = torch::jit::getAllOperatorsFor(op_name);
// Guard against unknown operator names: indexing an empty result was
// undefined behavior. Throwing here lets LANTERN_FUNCTION_END surface
// the error to the host instead of crashing the process.
if (ops.empty()) {
throw std::runtime_error("No JIT operator found for name: " + name_);
}
return make_raw::FunctionSchema(ops.front()->schema());
LANTERN_FUNCTION_END
}

void* _lantern_jit_FunctionSchema_name (void* schema) {
auto schema_ = from_raw::FunctionSchema(schema);
return make_raw::string(schema_.name());
}

// https://cs.github.com/pytorch/pytorch/blob/47834679ba2f869e66450a74e2add4c04f0006e9/torch/csrc/jit/python/pybind_utils.h#L874
// https://cs.github.com/pytorch/pytorch/blob/47834679ba2f869e66450a74e2add4c04f0006e9/torch/csrc/jit/python/pybind_utils.h#L1137
23 changes: 23 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36539,6 +36539,27 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// cpp_jit_get_all_operators_names
// Auto-generated Rcpp glue: calls the exported C++ function and wraps the
// resulting string vector for R.
torch::vector::string cpp_jit_get_all_operators_names();
RcppExport SEXP _torch_cpp_jit_get_all_operators_names() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(cpp_jit_get_all_operators_names());
return rcpp_result_gen;
END_RCPP
}
// cpp_jit_get_operator_from_name
// Auto-generated Rcpp glue: converts the R string argument, calls the
// exported C++ function, and wraps the returned schema name for R.
torch::string cpp_jit_get_operator_from_name(torch::string x);
RcppExport SEXP _torch_cpp_jit_get_operator_from_name(SEXP xSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< torch::string >::type x(xSEXP);
rcpp_result_gen = Rcpp::wrap(cpp_jit_get_operator_from_name(x));
return rcpp_result_gen;
END_RCPP
}
// cpp_lantern_configure
void cpp_lantern_configure(int log);
RcppExport SEXP _torch_cpp_lantern_configure(SEXP logSEXP) {
Expand Down Expand Up @@ -40597,6 +40618,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_jit_compile", (DL_FUNC) &_torch_cpp_jit_compile, 1},
{"_torch_cpp_jit_compile_list_methods", (DL_FUNC) &_torch_cpp_jit_compile_list_methods, 1},
{"_torch_cpp_jit_compile_get_function", (DL_FUNC) &_torch_cpp_jit_compile_get_function, 2},
{"_torch_cpp_jit_get_all_operators_names", (DL_FUNC) &_torch_cpp_jit_get_all_operators_names, 0},
{"_torch_cpp_jit_get_operator_from_name", (DL_FUNC) &_torch_cpp_jit_get_operator_from_name, 1},
{"_torch_cpp_lantern_configure", (DL_FUNC) &_torch_cpp_lantern_configure, 1},
{"_torch_cpp_lantern_version", (DL_FUNC) &_torch_cpp_lantern_version, 0},
{"_torch_cpp_lantern_init", (DL_FUNC) &_torch_cpp_lantern_init, 1},
Expand Down
10 changes: 10 additions & 0 deletions src/jit-compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ SEXP cpp_jit_compile_get_function(SEXP cu, XPtrTorchstring name) {
return R_NilValue;
}
}

// Returns the qualified names of every operator registered with the
// PyTorch JIT, as an R character vector.
// [[Rcpp::export]]
torch::vector::string cpp_jit_get_all_operators_names () {
return lantern_jit_get_all_operators_names();
}

// Looks up a JIT operator by its qualified name and returns the name
// recorded in its FunctionSchema.
// NOTE(review): the intermediate raw pointer returned by
// lantern_jit_get_operation_schema() does not appear to be freed anywhere
// visible here — presumably a small per-call leak; confirm whether a
// lantern deleter for FunctionSchema exists and should be called.
// [[Rcpp::export]]
torch::string cpp_jit_get_operator_from_name (torch::string x) {
return lantern_jit_FunctionSchema_name(lantern_jit_get_operation_schema(x.get()));
}
6 changes: 6 additions & 0 deletions src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ torch::Tensor torch_tensor_cpp(SEXP x, Rcpp::Nullable<torch::Dtype> dtype,
break;
}
}
case NILSXP: {
cdtype = lantern_Dtype_bool();
final_type = dtype.isNull() ? torch::Dtype(lantern_Dtype_bool())
: Rcpp::as<torch::Dtype>(dtype);
break;
}
default: {
Rcpp::stop("R type not handled");
}
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-indexing.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,15 @@ test_that("regression test for #695", {
as.array(a)[c(1, 3), , c(1, 3)]
)
})

# Regression test for the NILSXP branch added in src/tensor.cpp:
# torch_tensor(NULL) should yield an empty boolean tensor, and indexing
# into it should raise an out-of-bounds error rather than crash.
test_that("NULL tensor", {

x <- torch_tensor(NULL)
expect_true(x$dtype == torch_bool())
expect_equal(x$shape, 0)

# subsetting shouldn't crash
expect_error(x[1], regexp = "out of bounds")
expect_error(torch_tensor(as.integer(NULL))[1], regexp = "out of bounds")

})
4 changes: 2 additions & 2 deletions tools/create-decls.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ make_load_symbols <- function(decls) {

# C prototypes to generate host wrappers / LOAD_SYMBOL entries for, one per
# line; consumed by the helper functions defined above in this script.
decls <- readr::read_lines(
"
void* _lantern_jit_get_operation_schema (void* name)
void* _lantern_jit_FunctionSchema_name (void* schema)
"
)

Expand Down