Skip to content

Commit

Permalink
Generator argument optimization (and more) (#175)
Browse files Browse the repository at this point in the history
* Fix ABI incompatibilities

* Fix codon-jit on macOS

* Fix scoping bugs

* Fix .codon detection

* Handle static arguments in magic methods; Update simd; Fix misc. bugs

* Avoid partial calls with generators

* clang-format

* Add generator-argument optimization

* Fix typo

* Fix omp test

* Make sure sum() does not call __iadd__

* Clarify difference in docs

* Fix any/all generator pass

* Fix  InstantiateExpr simplification; Support .py as module extension

* clang-format

* Bump version

Co-authored-by: Ibrahim Numanagić <[email protected]>
  • Loading branch information
arshajii and inumanag authored Jan 17, 2023
1 parent fc70c83 commit bac6ae5
Show file tree
Hide file tree
Showing 41 changed files with 514 additions and 99 deletions.
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.14)
project(
Codon
VERSION "0.15.3"
VERSION "0.15.4"
HOMEPAGE_URL "https://github.com/exaloop/codon"
DESCRIPTION "high-performance, extensible Python compiler")
set(CODON_JIT_PYTHON_VERSION "0.1.1")
set(CODON_JIT_PYTHON_VERSION "0.1.2")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
Expand Down Expand Up @@ -197,6 +197,7 @@ set(CODON_HPPFILES
codon/cir/transform/parallel/schedule.h
codon/cir/transform/pass.h
codon/cir/transform/pythonic/dict.h
codon/cir/transform/pythonic/generator.h
codon/cir/transform/pythonic/io.h
codon/cir/transform/pythonic/list.h
codon/cir/transform/pythonic/str.h
Expand Down Expand Up @@ -304,6 +305,7 @@ set(CODON_CPPFILES
codon/cir/transform/parallel/schedule.cpp
codon/cir/transform/pass.cpp
codon/cir/transform/pythonic/dict.cpp
codon/cir/transform/pythonic/generator.cpp
codon/cir/transform/pythonic/io.cpp
codon/cir/transform/pythonic/list.cpp
codon/cir/transform/pythonic/str.cpp
Expand Down
2 changes: 1 addition & 1 deletion codon/cir/func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

#include <algorithm>

#include "codon/parser/common.h"
#include "codon/cir/module.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/util/operator.h"
#include "codon/cir/util/visitor.h"
#include "codon/cir/var.h"
#include "codon/parser/common.h"

namespace codon {
namespace ir {
Expand Down
6 changes: 3 additions & 3 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
#include <unistd.h>
#include <utility>

#include "codon/cir/dsl/codegen.h"
#include "codon/cir/llvm/optimize.h"
#include "codon/cir/util/irtools.h"
#include "codon/compiler/debug_listener.h"
#include "codon/compiler/memory_manager.h"
#include "codon/parser/common.h"
#include "codon/runtime/lib.h"
#include "codon/cir/dsl/codegen.h"
#include "codon/cir/llvm/optimize.h"
#include "codon/cir/util/irtools.h"
#include "codon/util/common.h"

namespace codon {
Expand Down
4 changes: 2 additions & 2 deletions codon/cir/llvm/llvisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#pragma once

#include "codon/dsl/plugins.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/cir/cir.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/dsl/plugins.h"
#include "codon/util/common.h"

#include <string>
Expand Down
2 changes: 1 addition & 1 deletion codon/cir/llvm/optimize.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

#include <memory>

#include "codon/dsl/plugins.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/dsl/plugins.h"

namespace codon {
namespace ir {
Expand Down
2 changes: 1 addition & 1 deletion codon/cir/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include <algorithm>
#include <memory>

#include "codon/parser/cache.h"
#include "codon/cir/func.h"
#include "codon/parser/cache.h"

namespace codon {
namespace ir {
Expand Down
2 changes: 2 additions & 0 deletions codon/cir/transform/manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "codon/cir/transform/parallel/openmp.h"
#include "codon/cir/transform/pass.h"
#include "codon/cir/transform/pythonic/dict.h"
#include "codon/cir/transform/pythonic/generator.h"
#include "codon/cir/transform/pythonic/io.h"
#include "codon/cir/transform/pythonic/list.h"
#include "codon/cir/transform/pythonic/str.h"
Expand Down Expand Up @@ -162,6 +163,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<pythonic::DictArithmeticOptimization>());
registerPass(std::make_unique<pythonic::ListAdditionOptimization>());
registerPass(std::make_unique<pythonic::StrAdditionOptimization>());
registerPass(std::make_unique<pythonic::GeneratorArgumentOptimization>());
registerPass(std::make_unique<pythonic::IOCatOptimization>());

// lowering
Expand Down
235 changes: 235 additions & 0 deletions codon/cir/transform/pythonic/generator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

#include "generator.h"

#include <algorithm>

#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/matching.h"

namespace codon {
namespace ir {
namespace transform {
namespace pythonic {
namespace {
bool isSum(Func *f) {
return f && f->getName().rfind("std.internal.builtin.sum:", 0) == 0;
}

bool isAny(Func *f) {
return f && f->getName().rfind("std.internal.builtin.any:", 0) == 0;
}

bool isAll(Func *f) {
return f && f->getName().rfind("std.internal.builtin.all:", 0) == 0;
}

// Replaces yields with updates to the accumulator variable.
struct GeneratorSumTransformer : public util::Operator {
Var *accumulator;
bool valid;

explicit GeneratorSumTransformer(Var *accumulator)
: util::Operator(), accumulator(accumulator), valid(true) {}

void handle(YieldInstr *v) override {
auto *M = v->getModule();
auto *val = v->getValue();
if (!val) {
valid = false;
return;
}

Value *rhs = val;
if (val->getType()->is(M->getBoolType())) {
rhs = M->Nr<TernaryInstr>(rhs, M->getInt(1), M->getInt(0));
}

Value *add = *M->Nr<VarValue>(accumulator) + *rhs;
if (!add || !add->getType()->is(accumulator->getType())) {
valid = false;
return;
}

auto *assign = M->Nr<AssignInstr>(accumulator, add);
v->replaceAll(assign);
}

void handle(ReturnInstr *v) override {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
}

void handle(YieldInInstr *v) override { valid = false; }
};

// Replaces yields with conditional returns of the any/all answer.
struct GeneratorAnyAllTransformer : public util::Operator {
bool any; // true=any, false=all
bool valid;

explicit GeneratorAnyAllTransformer(bool any)
: util::Operator(), any(any), valid(true) {}

void handle(YieldInstr *v) override {
auto *M = v->getModule();
auto *val = v->getValue();
auto *valBool = val ? (*M->getBoolType())(*val) : nullptr;
if (!valBool) {
valid = false;
return;
} else if (!any) {
valBool = M->Nr<TernaryInstr>(valBool, M->getBool(false), M->getBool(true));
}

auto *newReturn = M->Nr<ReturnInstr>(M->getBool(any));
see(newReturn);
auto *rep = M->Nr<IfFlow>(valBool, util::series(newReturn));
v->replaceAll(rep);
}

void handle(ReturnInstr *v) override {
if (saw(v))
return;
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(!any));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
}

void handle(YieldInInstr *v) override { valid = false; }
};

Func *genToSum(BodiedFunc *gen, types::Type *startType, types::Type *outType) {
if (!gen || !gen->isGenerator())
return nullptr;

auto *M = gen->getModule();
auto *fn = M->Nr<BodiedFunc>("__sum_wrapper");
auto *genType = cast<types::FuncType>(gen->getType());
if (!genType)
return nullptr;

std::vector<types::Type *> argTypes(genType->begin(), genType->end());
argTypes.push_back(startType);

std::vector<std::string> names;
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
names.push_back((*it)->getName());
}
names.push_back("start");

auto *fnType = M->getFuncType(outType, argTypes);
fn->realize(fnType, names);

std::unordered_map<id_t, Var *> argRemap;
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
argRemap.emplace((*it1)->getId(), *it2);
}

util::CloneVisitor cv(M);
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
fn->setBody(body);

Value *init = M->Nr<VarValue>(fn->arg_back());
if (startType->is(M->getIntType()) && outType->is(M->getFloatType()))
init = (*M->getFloatType())(*init);

if (!init || !init->getType()->is(outType))
return nullptr;

auto *accumulator = util::makeVar(init, body, fn, /*prepend=*/true)->getVar();
GeneratorSumTransformer xgen(accumulator);
fn->accept(xgen);
body->push_back(M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator)));

if (!xgen.valid)
return nullptr;

return fn;
}

Func *genToAnyAll(BodiedFunc *gen, bool any) {
if (!gen || !gen->isGenerator())
return nullptr;

auto *M = gen->getModule();
auto *fn = M->Nr<BodiedFunc>(any ? "__any_wrapper" : "__all_wrapper");
auto *genType = cast<types::FuncType>(gen->getType());

std::vector<types::Type *> argTypes(genType->begin(), genType->end());
std::vector<std::string> names;
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
names.push_back((*it)->getName());
}

auto *fnType = M->getFuncType(M->getBoolType(), argTypes);
fn->realize(fnType, names);

std::unordered_map<id_t, Var *> argRemap;
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
argRemap.emplace((*it1)->getId(), *it2);
}

util::CloneVisitor cv(M);
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
fn->setBody(body);

GeneratorAnyAllTransformer xgen(any);
fn->accept(xgen);
body->push_back(M->Nr<ReturnInstr>(M->getBool(!any)));

if (!xgen.valid)
return nullptr;

return fn;
}
} // namespace

const std::string GeneratorArgumentOptimization::KEY =
"core-pythonic-generator-argument-opt";

void GeneratorArgumentOptimization::handle(CallInstr *v) {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());

if (isSum(func) && v->numArgs() == 2) {
auto *call = cast<CallInstr>(v->front());
if (!call)
return;

auto *gen = util::getFunc(call->getCallee());
auto *start = v->back();

if (auto *fn = genToSum(cast<BodiedFunc>(gen), start->getType(), v->getType())) {
std::vector<Value *> args(call->begin(), call->end());
args.push_back(start);
v->replaceAll(util::call(fn, args));
}
} else {
bool any = isAny(func), all = isAll(func);
if (!(any || all) || v->numArgs() != 1 || !v->getType()->is(M->getBoolType()))
return;

auto *call = cast<CallInstr>(v->front());
if (!call)
return;

auto *gen = util::getFunc(call->getCallee());

if (auto *fn = genToAnyAll(cast<BodiedFunc>(gen), any)) {
std::vector<Value *> args(call->begin(), call->end());
v->replaceAll(util::call(fn, args));
}
}
}

} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace codon
25 changes: 25 additions & 0 deletions codon/cir/transform/pythonic/generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

#pragma once

#include "codon/cir/transform/pass.h"

namespace codon {
namespace ir {
namespace transform {
namespace pythonic {

/// Pass to optimize passing a generator to some built-in functions
/// like sum(), any() or all(), which will be converted to regular
/// for-loops.
class GeneratorArgumentOptimization : public OperatorPass {
public:
static const std::string KEY;
std::string getKey() const override { return KEY; }
void handle(CallInstr *v) override;
};

} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace codon
2 changes: 1 addition & 1 deletion codon/cir/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
#include <memory>
#include <utility>

#include "codon/parser/cache.h"
#include "codon/cir/module.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/util/visitor.h"
#include "codon/cir/value.h"
#include "codon/parser/cache.h"
#include <fmt/format.h>

namespace codon {
Expand Down
2 changes: 1 addition & 1 deletion codon/cir/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <utility>
#include <vector>

#include "codon/parser/ast.h"
#include "codon/cir/base.h"
#include "codon/cir/util/packs.h"
#include "codon/cir/util/visitor.h"
#include "codon/parser/ast.h"
#include <fmt/format.h>
#include <fmt/ostream.h>

Expand Down
Loading

0 comments on commit bac6ae5

Please sign in to comment.