Skip to content

Commit

Permalink
Various bug fixes (#185)
Browse files Browse the repository at this point in the history
* Fix #183

* Fix #162; Fix #135

* Fix #155

* Fix #191

* Fix #187

* Fix #189

* Fix vtable init; Fix failing tests on Linux

* Fix #190

* Fix #156

* Fix union routing

* Format

---------

Co-authored-by: A. R. Shajii <[email protected]>
  • Loading branch information
inumanag and arshajii authored Feb 5, 2023
1 parent 28ebb2e commit 5f13644
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 54 deletions.
14 changes: 13 additions & 1 deletion codon/parser/visitors/simplify/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,20 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Expression to be used if function binding is modified by captures or decorators
ExprPtr finalExpr = nullptr;
// If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)`
if (!captures.empty())
if (!captures.empty()) {
finalExpr = N<CallExpr>(N<IdExpr>(stmt->name), partialArgs);
// Add updated self reference in case function is recursive!
auto pa = partialArgs;
for (auto &a : pa) {
if (!a.name.empty())
a.value = N<IdExpr>(a.name);
else
a.value = clone(a.value);
}
f->suite = N<SuiteStmt>(
N<AssignStmt>(N<IdExpr>(rootName), N<CallExpr>(N<IdExpr>(rootName), pa)),
suite);
}

// Parse remaining decorators
for (auto i = stmt->decorators.size(); i-- > 0;) {
Expand Down
5 changes: 4 additions & 1 deletion codon/parser/visitors/translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ void TranslateVisitor::visit(StringExpr *expr) {
void TranslateVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value);
seqassert(val, "cannot find '{}'", expr->value);
if (auto *v = val->getVar())
if (expr->value == "__vtable_size__")
result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2,
getType(expr->getType()));
else if (auto *v = val->getVar())
result = make<ir::VarValue>(expr, v);
else if (auto *f = val->getFunc())
result = make<ir::VarValue>(expr, f);
Expand Down
21 changes: 18 additions & 3 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,16 +783,31 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
.type->getStatic()
->evaluate()
.getString();
std::vector<TypePtr> args{typ};
std::vector<std::pair<std::string, TypePtr>> args{{"", typ}};
if (expr->expr->isId("hasattr:0")) {
// Case: the first hasattr overload allows passing argument types via *args
auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple");
for (auto &a : tup->items) {
transformType(a);
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back(a->getType());
args.push_back({"", a->getType()});
}
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall();
auto kwCls =
in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name);
seqassert(kwCls, "cannot find {}",
expr->args[2].value->getType()->getClass()->name);
for (size_t i = 0; i < kw->args.size(); i++) {
auto &a = kw->args[i].value;
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back({kwCls->fields[i].name, a->getType()});
}
}

Expand Down
80 changes: 53 additions & 27 deletions codon/parser/visitors/typecheck/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,11 @@ StmtPtr TypecheckVisitor::prepareVTables() {
// def class_init_vtables():
// return __internal__.class_make_n_vtables(<NUM_REALIZATIONS> + 1)
auto &initAllVT = ctx->cache->functions[rep];
auto suite = N<SuiteStmt>(
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.class_make_n_vtables:0"),
N<IntExpr>(ctx->cache->classRealizationCnt + 1))));
auto suite = N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_make_n_vtables:0"), N<IdExpr>("__vtable_size__"))));
initAllVT.ast->suite = suite;
auto typ = initAllVT.realizations.begin()->second->type;
LOG_REALIZE("[poly] {} : {}", typ, *suite);
typ->ast = initAllVT.ast.get();
auto fx = realizeFunc(typ.get(), true);

Expand All @@ -402,30 +402,36 @@ StmtPtr TypecheckVisitor::prepareVTables() {
suite = N<SuiteStmt>();
for (auto &[_, cls] : ctx->cache->classes) {
for (auto &[r, real] : cls.realizations) {
size_t vtSz = 0;
for (auto &[base, vtable] : real->vtables) {
if (!vtable.ir)
vtSz += vtable.table.size();
}
auto var = initFn.ast->args[0].name;
// p.__setitem__(real.ID) = Ptr[cobj](real.vtables.size() + 2)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(var), "__setitem__"), N<IntExpr>(real->id),
N<CallExpr>(NT<InstantiateExpr>(NT<IdExpr>("Ptr"),
std::vector<ExprPtr>{NT<IdExpr>("cobj")}),
N<IntExpr>(vtSz + 2)))));
// __internal__.class_set_typeinfo(p[real.ID], real.ID)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)), N<IntExpr>(real->id))));
vtSz = 0;
for (auto &[base, vtable] : real->vtables) {
if (!vtable.ir) {
auto var = initFn.ast->args[0].name;
// p.__setitem__(real.ID) = Ptr[cobj](real.vtables.size() + 2)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(var), "__setitem__"), N<IntExpr>(real->id),
N<CallExpr>(NT<InstantiateExpr>(NT<IdExpr>("Ptr"),
std::vector<ExprPtr>{NT<IdExpr>("cobj")}),
N<IntExpr>(vtable.table.size() + 2)))));
// __internal__.class_set_typeinfo(p[real.ID], real.ID)
suite->stmts.push_back(N<ExprStmt>(
N<CallExpr>(N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)),
N<IntExpr>(real->id))));
for (auto &[k, v] : vtable.table) {
auto &[fn, id] = v;
std::vector<ExprPtr> ids;
for (auto &t : fn->getArgTypes())
ids.push_back(NT<IdExpr>(t->realizedName()));
// p[real.ID].__setitem__(f.ID, Function[<TYPE_F>](f).__raw__())
LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, fn);
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)),
"__setitem__"),
N<IntExpr>(id),
N<IntExpr>(vtSz + id),
N<CallExpr>(N<DotExpr>(
N<CallExpr>(
NT<InstantiateExpr>(
Expand All @@ -438,12 +444,14 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<IdExpr>(fn->realizedName())),
"__raw__")))));
}
vtSz += vtable.table.size();
}
}
}
}
initFn.ast->suite = suite;
typ = initFn.realizations.begin()->second->type;
LOG_REALIZE("[poly] {} : {}", typ, suite->toString(2));
typ->ast = initFn.ast.get();
realizeFunc(typ.get(), true);

Expand All @@ -469,6 +477,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<DotExpr>(N<IdExpr>(clsTyp->realizedName()), "__vtable_id__"))));
}

LOG_REALIZE("[poly] {} : {}", t, *suite);
initObjFns.ast->suite = suite;
t->ast = initObjFns.ast.get();
realizeFunc(t.get(), true);
Expand Down Expand Up @@ -502,6 +511,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<DotExpr>(NT<InstantiateExpr>(
NT<IdExpr>(format("{}{}", TYPE_TUPLE, types.size())), types),
"__elemsize__"));
LOG_REALIZE("[poly] {} : {}", t, *suite);
initDist.ast->suite = suite;
t->ast = initDist.ast.get();
realizeFunc(t.get(), true);
Expand Down Expand Up @@ -802,8 +812,8 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
type->getArgTypes()[0]->getHeterogenousTuple()) {
// Special case: do not realize auto-generated heterogenous __getitem__
E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
} else if (startswith(ast->name, "Function.__call__")) {
// Special case: Function.__call__
} else if (startswith(ast->name, "Function.__call_internal__")) {
// Special case: Function.__call_internal__
/// TODO: move to IR one day
std::vector<StmtPtr> items;
items.push_back(nullptr);
Expand All @@ -826,6 +836,14 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
ll.push_back(format("ret {{}} %{}", as.size()));
items[0] = N<ExprStmt>(N<StringExpr>(combine2(ll, "\n")));
ast->suite = N<SuiteStmt>(items);
} else if (startswith(ast->name, "Union.__new__:0")) {
auto unionType = type->funcParent->getUnion();
seqassert(unionType, "expected union, got {}", type->funcParent);

StmtPtr suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.new_union:0"), N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) {
// Special case: __internal__.new_union
// def __internal__.new_union(value, U[T0, ..., TN]):
Expand Down Expand Up @@ -910,21 +928,29 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto &t : unionTypes) {
auto callee =
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName);
auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1)));
auto kwargs = N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)));
std::vector<CallExpr::Arg> callArgs;
ExprPtr check =
N<CallExpr>(N<IdExpr>("hasattr"), NT<IdExpr>(t->realizedName()),
N<StringExpr>(fnName), args->clone(), kwargs->clone());
suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag)),
N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName),
N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1))),
N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)))))));
N<BinaryExpr>(
check, "&&",
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag))),
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(callee, args, kwargs)))));
tag++;
}
suite->stmts.push_back(
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
Expand Down
41 changes: 27 additions & 14 deletions codon/parser/visitors/typecheck/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,12 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
if (!lt->is("pyobj") && rt->is("pyobj")) {
// Special case: `obj op pyobj` -> `rhs.__rmagic__(lhs)` on lhs
// Assumes that pyobj implements all left and right magics
return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, format("__{}__", rightMagic)),
expr->lexpr));
auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r");
return transform(
N<StmtExpr>(N<AssignStmt>(N<IdExpr>(l), expr->lexpr),
N<AssignStmt>(N<IdExpr>(r), expr->rexpr),
N<CallExpr>(N<DotExpr>(N<IdExpr>(r), format("__{}__", rightMagic)),
N<IdExpr>(l))));
}
if (lt->getUnion()) {
// Special case: `union op obj` -> `union.__magic__(rhs)`
Expand All @@ -667,19 +671,24 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
}

// Normal operations: check if `lhs.__magic__(lhs, rhs)` exists
auto method = findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr});

// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
if (!method && (method = findBestMethod(rt, format("__{}__", rightMagic),
{expr->rexpr, expr->lexpr}))) {
swap(expr->lexpr, expr->rexpr);
}

if (method) {
if (auto method =
findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr})) {
// Normal case: `__magic__(lhs, rhs)`
return transform(
N<CallExpr>(N<IdExpr>(method->ast->name), expr->lexpr, expr->rexpr));
}

// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
if (auto method = findBestMethod(rt, format("__{}__", rightMagic),
{expr->rexpr, expr->lexpr})) {
auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r");
return transform(N<StmtExpr>(
N<AssignStmt>(N<IdExpr>(l), expr->lexpr),
N<AssignStmt>(N<IdExpr>(r), expr->rexpr),
N<CallExpr>(N<IdExpr>(method->ast->name), N<IdExpr>(r), N<IdExpr>(l))));
}
// 145

return nullptr;
}

Expand Down Expand Up @@ -745,14 +754,18 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
sliceAdjustIndices(sz, &start, &stop, step);

// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(expr), classItem->fields[i].name));
te.push_back(N<DotExpr>(clone(var), classItem->fields[i].name));
}
return {true, transform(N<CallExpr>(
N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te))};
ExprPtr e = transform(N<StmtExpr>(
std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te)));
return {true, e};
}

return {false, nullptr};
Expand Down
15 changes: 15 additions & 0 deletions codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,21 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
return m.empty() ? nullptr : m[0];
}

/// Select the best method indicated of an object that matches the given argument
/// types. See @c findMatchingMethods for details.
types::FuncTypePtr TypecheckVisitor::findBestMethod(
const ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args) {
std::vector<CallExpr::Arg> callArgs;
for (auto &[n, a] : args) {
callArgs.push_back({n, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
auto methods = ctx->findMethod(typ->name, member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}

/// Select the best method among the provided methods given the list of arguments.
/// See @c reorderNamedArgs for details.
std::vector<types::FuncTypePtr>
Expand Down
3 changes: 3 additions & 0 deletions codon/parser/visitors/typecheck/typecheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ class TypecheckVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member,
const std::vector<ExprPtr> &args);
types::FuncTypePtr
findBestMethod(const types::ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods,
Expand Down
2 changes: 1 addition & 1 deletion extra/python/codon/jit.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# distutils: language=c++
# cython: language_level=3
# cython: c_string_type=unicode
# cython: c_string_encoding=ascii
# cython: c_string_encoding=utf8

from libcpp.string cimport string
from libcpp.vector cimport vector
Expand Down
6 changes: 4 additions & 2 deletions stdlib/internal/core.codon
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ class Ref[T]:
@__internal__
@tuple
class Union[TU]:
pass
# compiler-generated
def __new__(val):
TU

# dummy
@__internal__
Expand Down Expand Up @@ -153,7 +155,7 @@ def isinstance(obj, what):
def overload():
pass

def hasattr(obj, attr: Static[str], *args):
def hasattr(obj, attr: Static[str], *args, **kwargs):
"""Special handling"""
pass

Expand Down
7 changes: 6 additions & 1 deletion stdlib/internal/internal.codon
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ from C import seq_print(str)
from C import exit(int)
from C import malloc(int) -> cobj as c_malloc

__vtable_size__ = 0

@extend
class __internal__:
@pure
Expand Down Expand Up @@ -438,8 +440,11 @@ class Function:
return __internal__.raw_type_str(self.__raw__(), "function")

@llvm
def __call__(self, *args) -> TR:
def __call_internal__(self: Function[T, TR], args: T) -> TR:
noop # compiler will populate this one

def __call__(self, *args) -> TR:
return Function.__call_internal__(self, args)

__vtables__ = __internal__.class_init_vtables()
def _____(): __vtables__ # make it global!
Loading

0 comments on commit 5f13644

Please sign in to comment.