Skip to content

Commit

Permalink
Implemented parallel policies to for statement (#68 from wesuRage/main)
Browse files Browse the repository at this point in the history
Implemented parallel policies to for statement
  • Loading branch information
wesuRage authored Dec 22, 2024
2 parents 38ec504 + d15953c commit 7fa3595
Show file tree
Hide file tree
Showing 10 changed files with 367 additions and 65 deletions.
16 changes: 10 additions & 6 deletions examples/a.glx
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
extern int writeln( string) ;
extern string strrep( string, int) ;
extern int writeln( string) ;

def main( ) -> int:
writeln( "oi" * 20 ) ;
def main( ) -> int:

return 0;
end;
int N := 100000;

for parallel static ( int i := 0; i < N; ++ i) -> 8:
writeln( "hello") ;
end;

return 0;
end;
16 changes: 16 additions & 0 deletions include/backend/generator/statements/generate_outlined_for.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef GENERATE_OUTLINED_FOR_H
#define GENERATE_OUTLINED_FOR_H

extern "C" {
#include "frontend/ast/definitions.h"
}

#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>

llvm::Function* generate_outlined_for(ForNode *node, llvm::LLVMContext &Context, llvm::Module &Module, llvm::GlobalVariable *ompIdent, char *schedule_policy);

#endif // GENERATE_OUTLINED_FOR_H
12 changes: 12 additions & 0 deletions include/backend/generator/utils/generate_stop.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef GENERATE_STOP_H
#define GENERATE_STOP_H

extern "C" {
#include "frontend/ast/definitions.h"
}
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/IRBuilder.h>

llvm::Value *generate_stop(AstNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module);

#endif // GENERATE_STOP_H
11 changes: 1 addition & 10 deletions include/frontend/ast/definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,10 @@ typedef struct {
AstNode *updater;
AstNode *iterator;
bool is_parallel;
char *schedule_policy; // "static", "dynamic"
char *schedule_policy; // "static", "dynamic", "guided"
AstNode *num_threads;
} ForNode;

typedef struct {
char *type; // "barrier", "atomic"
} SyncNode;

typedef struct {
char *name;
char *memory_type; // "private", "shared", "pgas"
} MemoryNode;

typedef struct {
AstNode *condition;
AstNode **consequent;
Expand Down
1 change: 1 addition & 0 deletions include/frontend/lexer/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ typedef enum {
TOKEN_STATIC,
TOKEN_DYNAMIC,
TOKEN_PARALLEL,
TOKEN_GUIDED,
TOKEN_UNKNOWN,
} TokenType;

Expand Down
195 changes: 147 additions & 48 deletions src/backend/generator/statements/generate_for_stmt.cpp
Original file line number Diff line number Diff line change
@@ -1,74 +1,173 @@
#include "backend/generator/utils/generate_stop.hpp"
#include "backend/generator/expressions/generate_expr.hpp"
#include "backend/generator/statements/generate_stmt.hpp"
#include "backend/generator/statements/generate_for_stmt.hpp"
#include "backend/generator/statements/generate_variable_declaration_stmt.hpp"
#include "backend/generator/statements/generate_outlined_for.hpp"
#include "backend/generator/symbols/identifier_symbol_table.hpp"

llvm::Value* generate_for_stmt(ForNode *node, llvm::LLVMContext &Context, llvm::IRBuilder<> &Builder, llvm::Module &Module) {
llvm::Function *currentFunction = Builder.GetInsertBlock()->getParent();

// Blocos básicos do loop
llvm::BasicBlock *preLoopBB = llvm::BasicBlock::Create(Context, "preloop", currentFunction);
llvm::BasicBlock *condBB = llvm::BasicBlock::Create(Context, "cond", currentFunction);
llvm::BasicBlock *bodyBB = llvm::BasicBlock::Create(Context, "body", currentFunction);
llvm::BasicBlock *updateBB = llvm::BasicBlock::Create(Context, "update", currentFunction);
llvm::BasicBlock *endBB = llvm::BasicBlock::Create(Context, "endloop", currentFunction);

// Cria o salto para o bloco de pré-loop
Builder.CreateBr(preLoopBB);
Builder.SetInsertPoint(preLoopBB);

// Declaração da variável do loop (%i) e inicialização

// Declarar e inicializar a variável do loop
VariableNode iteratorVar;
iteratorVar.name = node->variable;
iteratorVar.varType = node->var_type;
iteratorVar.isPtr = node->var_isPtr;
iteratorVar.isConst = false;
iteratorVar.value = node->start;

generate_variable_declaration_stmt(&iteratorVar, Context, Builder, Module);

if (node->iterator){
generate_variable_declaration_stmt(&iteratorVar, Context, Builder, Module);
}

llvm::Value *startVal = generate_expr(node->start, Context, Builder, Module);

llvm::AllocaInst *loopVar = Builder.CreateAlloca(startVal->getType(), nullptr, node->variable);
Builder.CreateStore(startVal, loopVar); // Inicializa %i

Builder.CreateBr(condBB);
Builder.CreateStore(startVal, loopVar); // Inicializar a variável do loop

// Condição do loop
Builder.SetInsertPoint(condBB);

llvm::Value *loopVarVal = Builder.CreateLoad(loopVar->getAllocatedType(), loopVar, node->variable);
llvm::Value *stopVal = generate_stop(node->stop, Context, Builder, Module);

BinaryExprNode *binaryNode = (BinaryExprNode *)node->stop->data;
llvm::Value *rightOperand = generate_expr(binaryNode->right, Context, Builder, Module);
llvm::Value *stopVal = rightOperand;

if (stopVal->getType()->isPointerTy()) {
const SymbolInfo *id = find_identifier(stopVal->getName().str());
stopVal = id->value;
// Certifique-se de que ambos são do tipo i32
if (loopVarVal->getType() != llvm::Type::getInt32Ty(Context)) {
loopVarVal = Builder.CreateSExt(loopVarVal, llvm::Type::getInt32Ty(Context));
}

llvm::Value *cond = Builder.CreateICmpSLT(loopVarVal, stopVal, "loopcond");

Builder.CreateCondBr(cond, bodyBB, endBB); // Se a condição for verdadeira, vai para o corpo; senão, para o final.

// Corpo do loop
Builder.SetInsertPoint(bodyBB);
for (size_t i = 0; i < node->body_count; ++i) {
generate_stmt(node->body[i], Context, Module, Builder); // Gerar as declarações no corpo do loop
if (stopVal->getType() != llvm::Type::getInt32Ty(Context)) {
stopVal = Builder.CreateSExt(stopVal, llvm::Type::getInt32Ty(Context));
}

Builder.CreateBr(updateBB); // Salta para o bloco de atualização

// Atualiza o valor de %i
Builder.SetInsertPoint(updateBB);
llvm::Value *loopVarValUpdate = Builder.CreateLoad(loopVar->getAllocatedType(), loopVar, node->variable);
llvm::Value *inc = Builder.CreateAdd(loopVarValUpdate, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1), "inc");
Builder.CreateStore(inc, loopVar); // Atualiza %i

Builder.CreateBr(condBB); // Salta de volta para a condição

// Finaliza o loop
Builder.SetInsertPoint(endBB);
return Builder.CreateLoad(loopVar->getAllocatedType(), loopVar, node->variable);
if (node->is_parallel) {
// Definir funções OpenMP
llvm::Constant *ompString = llvm::ConstantDataArray::getString(Context, ";unknown;unknown;0;0;;", true);

llvm::GlobalVariable *ompGlobalString = new llvm::GlobalVariable(
Module,
ompString->getType(),
true,
llvm::GlobalValue::PrivateLinkage,
ompString,
"omp_global_string"
);

llvm::StructType *identTy = llvm::StructType::get(
Context,
{
llvm::Type::getInt32Ty(Context), // Reserved
llvm::Type::getInt32Ty(Context), // Flags
llvm::Type::getInt32Ty(Context), // Reserved
llvm::Type::getInt32Ty(Context), // Source Info
llvm::PointerType::get(llvm::Type::getInt8Ty(Context), 0) // String Pointer
}
);


auto create_omp_ident = [&](int flags, llvm::StringRef name) {
llvm::Constant *ident = llvm::ConstantStruct::get(
identTy,
{
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 0),
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), flags),
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 0),
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 22),
llvm::ConstantExpr::getBitCast(ompGlobalString, llvm::PointerType::get(llvm::Type::getInt8Ty(Context), 0))
}
);
return new llvm::GlobalVariable(
Module,
identTy,
true,
llvm::GlobalValue::PrivateLinkage,
ident,
name
);
};

llvm::FunctionCallee kmpcGlobalThreadNum = Module.getOrInsertFunction(
"__kmpc_global_thread_num",
llvm::FunctionType::get(
llvm::Type::getInt32Ty(Context),
{ llvm::PointerType::get(identTy, 0) },
false
)
);

llvm::FunctionCallee kmpcPushNumThreads = Module.getOrInsertFunction(
"__kmpc_push_num_threads",
llvm::FunctionType::get(
llvm::Type::getVoidTy(Context),
{
llvm::PointerType::get(identTy, 0),
llvm::Type::getInt32Ty(Context),
llvm::Type::getInt32Ty(Context)
},
false
)
);

llvm::FunctionCallee kmpcForkCall = Module.getOrInsertFunction(
"__kmpc_fork_call",
llvm::FunctionType::get(
llvm::Type::getVoidTy(Context),
{
llvm::PointerType::get(identTy, 0),
llvm::Type::getInt32Ty(Context),
llvm::PointerType::get(llvm::Type::getVoidTy(Context), 0)
},
true
)
);


llvm::GlobalVariable *ompIdent0 = create_omp_ident(514, "omp_ident0");
llvm::GlobalVariable *ompIdent1 = create_omp_ident(2, "omp_ident1");

llvm::GlobalVariable *ompIdent = create_omp_ident(514, "omp_ident");

llvm::Value *threadNum = Builder.CreateCall(kmpcGlobalThreadNum, { ompIdent1 });

Builder.CreateCall(kmpcPushNumThreads, { ompIdent1, threadNum, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1) });

llvm::Function* ompOutlined = generate_outlined_for(node, Context, Module, ompIdent, node->schedule_policy);

Builder.CreateCall(kmpcForkCall, { ompIdent1, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 0), ompOutlined });
} else {
llvm::BasicBlock *preLoopBB = llvm::BasicBlock::Create(Context, "preloop", currentFunction);
llvm::BasicBlock *condBB = llvm::BasicBlock::Create(Context, "cond", currentFunction);
llvm::BasicBlock *bodyBB = llvm::BasicBlock::Create(Context, "body", currentFunction);
llvm::BasicBlock *updateBB = llvm::BasicBlock::Create(Context, "update", currentFunction);
llvm::BasicBlock *endBB = llvm::BasicBlock::Create(Context, "endloop", currentFunction);

Builder.CreateBr(preLoopBB);
Builder.SetInsertPoint(preLoopBB);

Builder.CreateBr(condBB);

Builder.SetInsertPoint(condBB);
llvm::Value *loopVarValLoad = Builder.CreateLoad(loopVarVal->getType(), loopVar, node->variable);
llvm::Value *cond = Builder.CreateICmpSLT(loopVarValLoad, stopVal, "loopcond");

Builder.CreateCondBr(cond, bodyBB, endBB);

Builder.SetInsertPoint(bodyBB);

for (size_t i = 0; i < node->body_count; i++) {
generate_stmt(node->body[i], Context, Module, Builder);
}

Builder.CreateBr(updateBB); // Jump to the update block

// Update loop variable
Builder.SetInsertPoint(updateBB);
llvm::Value *inc = Builder.CreateAdd(loopVarValLoad, llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1), "inc");
Builder.CreateStore(inc, loopVar); // Update %i

Builder.CreateBr(condBB); // Jump back to the condition block

// Finalize the loop
Builder.SetInsertPoint(endBB);
}
return nullptr;
}
Loading

0 comments on commit 7fa3595

Please sign in to comment.