Skip to content

Commit

Permalink
Add async functiion type
Browse files Browse the repository at this point in the history
Signed-off-by: Shaygan <[email protected]>
  • Loading branch information
Glyphack committed Aug 21, 2024
1 parent e0aa68e commit 7ef34fb
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 50 deletions.
2 changes: 1 addition & 1 deletion parser/src/parser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub enum Statement {
TryStatement(Box<Try>),
TryStarStatement(Box<TryStar>),
FunctionDef(Arc<FunctionDef>),
AsyncFunctionDef(Box<AsyncFunctionDef>),
AsyncFunctionDef(Arc<AsyncFunctionDef>),
ClassDef(Arc<ClassDef>),
Match(Box<Match>),
TypeAlias(Box<TypeAlias>),
Expand Down
5 changes: 1 addition & 4 deletions parser/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::{
token::{Kind, Token, TokenValue},
};

#[allow(unused)]
#[derive(Debug, Clone)]
pub struct Parser<'a> {
pub identifiers_start_offset: Vec<(u32, u32, String)>,
Expand All @@ -32,7 +31,6 @@ pub struct Parser<'a> {
// This is incremented when we see an opening bracket and decremented when we
// see a closing bracket.
nested_expression_list: u32,
curr_line_string: String,
path: &'a str,
}

Expand Down Expand Up @@ -62,7 +60,6 @@ impl<'a> Parser<'a> {
prev_token_end,
prev_nonwhitespace_token_end: prev_token_end,
nested_expression_list,
curr_line_string: String::new(),
path,
identifiers_start_offset: identifiers_offset,
}
Expand Down Expand Up @@ -634,7 +631,7 @@ impl<'a> Parser<'a> {
self.expect(Kind::Colon)?;
let body = self.parse_suite()?;
if is_async {
Ok(Statement::AsyncFunctionDef(Box::new(AsyncFunctionDef {
Ok(Statement::AsyncFunctionDef(Arc::new(AsyncFunctionDef {
node: self.finish_node_chomped(node),
name,
args,
Expand Down
11 changes: 7 additions & 4 deletions typechecker/src/ast_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::sync::Arc;

use enderpy_python_parser as parser;
use enderpy_python_parser::ast::*;

/// A visitor that traverses the AST and calls the visit method for each node
/// This is useful for visitors that only need to visit a few nodes
/// and don't want to implement all the methods.
/// The overridden methods must make sure to continue the traversal.
#[allow(dead_code)]
pub trait TraversalVisitor {
fn visit_stmt(&mut self, s: &Statement) {
// map all statements and call visit
Expand Down Expand Up @@ -40,6 +41,7 @@ pub trait TraversalVisitor {
Statement::TypeAlias(t) => self.visit_type_alias(t),
}
}

fn visit_expr(&mut self, e: &Expression) {
match e {
Expression::Constant(c) => self.visit_constant(c),
Expand Down Expand Up @@ -71,6 +73,7 @@ pub trait TraversalVisitor {
Expression::FormattedValue(f) => self.visit_formatted_value(f),
}
}

fn visit_import(&mut self, _i: &Import) {
todo!();
}
Expand Down Expand Up @@ -170,19 +173,19 @@ pub trait TraversalVisitor {
}
}

fn visit_function_def(&mut self, f: &parser::ast::FunctionDef) {
fn visit_function_def(&mut self, f: &Arc<parser::ast::FunctionDef>) {
for stmt in &f.body {
self.visit_stmt(stmt);
}
}

fn visit_async_function_def(&mut self, f: &parser::ast::AsyncFunctionDef) {
fn visit_async_function_def(&mut self, f: &Arc<parser::ast::AsyncFunctionDef>) {
for stmt in &f.body {
self.visit_stmt(stmt);
}
}

fn visit_class_def(&mut self, c: &parser::ast::ClassDef) {
fn visit_class_def(&mut self, c: &Arc<parser::ast::ClassDef>) {
for stmt in &c.body {
self.visit_stmt(stmt);
}
Expand Down
7 changes: 4 additions & 3 deletions typechecker/src/checker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::path::PathBuf;
use std::sync::Arc;

use ast::{Expression, Statement};
use dashmap::DashMap;
Expand Down Expand Up @@ -261,7 +262,7 @@ impl<'a> TraversalVisitor for TypeChecker<'a> {
}
}

fn visit_function_def(&mut self, f: &parser::ast::FunctionDef) {
fn visit_function_def(&mut self, f: &Arc<parser::ast::FunctionDef>) {
self.infer_name_type(
&f.name,
f.node.start + 4,
Expand All @@ -283,7 +284,7 @@ impl<'a> TraversalVisitor for TypeChecker<'a> {
self.type_evaluator.symbol_table.revert_scope();
}

fn visit_async_function_def(&mut self, f: &parser::ast::AsyncFunctionDef) {
fn visit_async_function_def(&mut self, f: &Arc<parser::ast::AsyncFunctionDef>) {
self.infer_name_type(
&f.name,
f.node.start + 9,
Expand All @@ -296,7 +297,7 @@ impl<'a> TraversalVisitor for TypeChecker<'a> {
self.type_evaluator.symbol_table.revert_scope();
}

fn visit_class_def(&mut self, c: &parser::ast::ClassDef) {
fn visit_class_def(&mut self, c: &Arc<parser::ast::ClassDef>) {
self.infer_name_type(
&c.name,
c.node.start + 6,
Expand Down
3 changes: 2 additions & 1 deletion typechecker/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::ast_visitor::TraversalVisitor;
use enderpy_python_parser as parser;
use enderpy_python_parser::ast::*;
use parser::{ast, get_row_col_position, Parser};
Expand Down Expand Up @@ -113,7 +114,7 @@ impl<'a> EnderpyFile<'a> {
}

pub fn get_position(&self, start: u32, end: u32) -> Position {
let (start_line_num, start_line_column, end_line_num, end_line_column) =
let (start_line_num, start_line_column, _end_line_num, _end_line_column) =
get_row_col_position(start, end, &self.line_starts);
Position {
line: start_line_num,
Expand Down
130 changes: 93 additions & 37 deletions typechecker/src/semantic_analyzer.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;

use enderpy_python_parser as parser;
use enderpy_python_parser::ast::Expression;

use parser::ast::{GetNode, Name, Statement};
use parser::ast::{self, GetNode, Name, Statement};

use crate::{
ast_visitor::TraversalVisitor,
build::ResolvedImports,
file::EnderpyFile,
ruff_python_import_resolver::{
import_result::ImportResult, module_descriptor::ImportModuleDescriptor,
},
ruff_python_import_resolver::module_descriptor::ImportModuleDescriptor,
symbol_table::{
Alias, Class, Declaration, DeclarationPath, Function, Parameter, SymbolFlags, SymbolTable,
SymbolTableNode, SymbolTableScope, SymbolTableType, TypeAlias, Variable,
Alias, AsyncFunction, Class, Declaration, DeclarationPath, Function, Parameter,
SymbolFlags, SymbolTable, SymbolTableNode, SymbolTableScope, SymbolTableType, TypeAlias,
Variable,
},
};

Expand All @@ -29,6 +28,13 @@ pub struct SemanticAnalyzer<'a> {
/// import os -> imports.get("os")
/// from os import path -> imports.get("os")
pub imports: &'a ResolvedImports,
pub function_information: FunctionInformation,
}

#[derive(Debug, Clone)]
pub struct FunctionInformation {
pub return_statements: Vec<ast::Return>,
pub yield_statements: Vec<ast::Yield>,
}

#[allow(unused)]
Expand All @@ -38,6 +44,10 @@ impl<'a> SemanticAnalyzer<'a> {
SemanticAnalyzer {
symbol_table: symbols,
imports,
function_information: FunctionInformation {
return_statements: Vec::new(),
yield_statements: Vec::new(),
},
}
}

Expand Down Expand Up @@ -250,8 +260,8 @@ impl<'a> SemanticAnalyzer<'a> {
}
}

impl<'a> SemanticAnalyzer<'a> {
pub fn visit_stmt(&mut self, s: &parser::ast::Statement) {
impl<'a> TraversalVisitor for SemanticAnalyzer<'a> {
fn visit_stmt(&mut self, s: &parser::ast::Statement) {
match s {
parser::ast::Statement::ExpressionStatement(e) => self.visit_expr(e),
parser::ast::Statement::Import(i) => self.visit_import(i),
Expand Down Expand Up @@ -499,25 +509,17 @@ impl<'a> SemanticAnalyzer<'a> {

self.add_arguments_definitions(&f.args);

let mut return_statements = vec![];
let mut yield_statements = vec![];
let mut raise_statements = vec![];
// TODO: clone
let prev_function_information = self.function_information.clone();

for stmt in f.body.iter() {
self.visit_stmt(stmt);
match stmt {
parser::ast::Statement::Raise(r) => raise_statements.push(*r.clone()),
parser::ast::Statement::Return(r) => return_statements.push(*r.clone()),
parser::ast::Statement::ExpressionStatement(e) => {
let Some(yield_expr) = e.as_yield_expr() else {
continue;
};

yield_statements.push(*yield_expr.clone())
}
_ => (),
}
}

let return_statements = std::mem::take(&mut self.function_information.return_statements);
let yield_statements = std::mem::take(&mut self.function_information.yield_statements);
self.function_information = prev_function_information;

for type_parameter in &f.type_params {
let declaration_path = DeclarationPath::new(
self.symbol_table.id,
Expand All @@ -544,7 +546,65 @@ impl<'a> SemanticAnalyzer<'a> {
is_generator: !yield_statements.is_empty(),
return_statements,
yield_statements,
raise_statements,
raise_statements: vec![],
});
let flags = SymbolFlags::empty();
self.create_symbol(f.name.clone(), function_declaration, flags);
}

fn visit_async_function_def(&mut self, f: &Arc<parser::ast::AsyncFunctionDef>) {
let declaration_path = DeclarationPath::new(
self.symbol_table.id,
f.node,
self.symbol_table.current_scope_id,
);

self.symbol_table.push_scope(SymbolTableScope::new(
SymbolTableType::Function(Arc::new(f.to_function_def())),
f.name.clone(),
f.node.start,
self.symbol_table.current_scope_id,
));

self.add_arguments_definitions(&f.args);

// TODO: clone
let prev_function_information = self.function_information.clone();

for stmt in f.body.iter() {
self.visit_stmt(stmt);
}

let return_statements = std::mem::take(&mut self.function_information.return_statements);
let yield_statements = std::mem::take(&mut self.function_information.yield_statements);
self.function_information = prev_function_information;

for type_parameter in &f.type_params {
let declaration_path = DeclarationPath::new(
self.symbol_table.id,
type_parameter.get_node(),
self.symbol_table.current_scope_id,
);
let flags = SymbolFlags::empty();
self.create_symbol(
type_parameter.get_name(),
Declaration::TypeParameter(crate::symbol_table::TypeParameter {
declaration_path,
type_parameter_node: type_parameter.clone(),
}),
flags,
);
}
self.symbol_table.exit_scope();

let function_declaration = Declaration::AsyncFunction(AsyncFunction {
declaration_path,
function_node: f.clone(),
is_method: self.is_inside_class(),
is_generator: !yield_statements.is_empty(),
return_statements,
yield_statements,
raise_statements: vec![],
});
let flags = SymbolFlags::empty();
self.create_symbol(f.name.clone(), function_declaration, flags);
Expand All @@ -567,16 +627,6 @@ impl<'a> SemanticAnalyzer<'a> {
);
}

fn visit_async_function_def(&mut self, _f: &parser::ast::AsyncFunctionDef) {
self.symbol_table.push_scope(SymbolTableScope::new(
SymbolTableType::Function(Arc::new(_f.to_function_def())),
_f.name.clone(),
_f.node.start,
self.symbol_table.current_scope_id,
));
self.symbol_table.exit_scope();
}

fn visit_class_def(&mut self, c: &Arc<parser::ast::ClassDef>) {
self.symbol_table.push_scope(SymbolTableScope::new(
SymbolTableType::Class(c.clone()),
Expand Down Expand Up @@ -662,7 +712,10 @@ impl<'a> SemanticAnalyzer<'a> {

fn visit_named_expr(&mut self, _n: &parser::ast::NamedExpression) {}

fn visit_yield(&mut self, _y: &parser::ast::Yield) {}
// TODO: clone
fn visit_yield(&mut self, y: &parser::ast::Yield) {
self.function_information.yield_statements.push(y.clone());
}

fn visit_yield_from(&mut self, _y: &parser::ast::YieldFrom) {}

Expand Down Expand Up @@ -739,7 +792,10 @@ impl<'a> SemanticAnalyzer<'a> {

fn visit_delete(&mut self, _d: &parser::ast::Delete) {}

fn visit_return(&mut self, _r: &parser::ast::Return) {}
// TODO: clone
fn visit_return(&mut self, r: &parser::ast::Return) {
self.function_information.return_statements.push(r.clone());
}

fn visit_raise(&mut self, _r: &parser::ast::Raise) {}

Expand Down
Loading

0 comments on commit 7ef34fb

Please sign in to comment.