diff --git a/enderpy/Cargo.toml b/enderpy/Cargo.toml index a4090018..effdc67a 100644 --- a/enderpy/Cargo.toml +++ b/enderpy/Cargo.toml @@ -10,5 +10,6 @@ edition = "2021" [dependencies] enderpy_python_parser = { path = "../parser" , version = "0.1.0" } enderpy_python_type_checker = { path = "../typechecker" , version = "0.1.0" } +corepy_python_translator = { path = "../translator", version = "0.1.0" } clap = { version = "4.5.17", features = ["derive"] } miette.workspace = true diff --git a/enderpy/src/cli.rs b/enderpy/src/cli.rs index 02267048..55aa53be 100644 --- a/enderpy/src/cli.rs +++ b/enderpy/src/cli.rs @@ -23,6 +23,8 @@ pub enum Commands { }, /// Type check Check { path: PathBuf }, + /// Translate to C++ + Translate { path: PathBuf }, /// Symbol table Symbols { path: PathBuf }, diff --git a/enderpy/src/main.rs b/enderpy/src/main.rs index 11fe7b3c..f2d9c59a 100644 --- a/enderpy/src/main.rs +++ b/enderpy/src/main.rs @@ -2,12 +2,14 @@ use std::{ fs::{self, File}, io::{self, Read}, path::{Path, PathBuf}, + sync::Arc, }; use clap::Parser as ClapParser; use cli::{Cli, Commands}; use enderpy_python_parser::{get_row_col_position, parser::parser::Parser, Lexer}; use enderpy_python_type_checker::{build::BuildManager, find_project_root, settings::Settings}; +use corepy_python_translator::translator::CppTranslator; use miette::{bail, IntoDiagnostic, Result}; mod cli; @@ -18,6 +20,7 @@ fn main() -> Result<()> { Commands::Tokenize {} => tokenize(), Commands::Parse { file } => parse(file), Commands::Check { path } => check(path), + Commands::Translate { path } => translate(path), Commands::Watch => watch(), Commands::Symbols { path } => symbols(path), } @@ -134,6 +137,33 @@ fn check(path: &Path) -> Result<()> { Ok(()) } +fn translate(path: &Path) -> Result<()> { + if path.is_dir() { + bail!("Path must be a file"); + } + let root = find_project_root(path); + let python_executable = Some(get_python_executable()?); + let typeshed_path = get_typeshed_path()?; + let settings = Settings { + typeshed_path, + python_executable, + }; + let build_manager = BuildManager::new(settings); + build_manager.build(root); + build_manager.build_one(root, path); + let id = build_manager.paths.get(path).unwrap(); + let file = build_manager.files.get(&id).unwrap(); + let checker = Arc::new(build_manager.type_check(path, &file)); + let mut translator = CppTranslator::new(checker.clone(), &file); + translator.translate(); + println!("{:?}", file.tree); + println!("===="); + println!("{}", translator.output); + println!("===="); + print!("{}", checker.clone().dump_types()); + Ok(()) +} + fn watch() -> Result<()> { todo!() } diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 9feb1129..1ce0ef0a 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -425,7 +425,13 @@ impl Constant { } else { Cow::Borrowed("false") } - } + }, + ConstantValue::Int => Cow::Borrowed( + &source[self.node.start as usize..self.node.end as usize], + ), + ConstantValue::Float => Cow::Borrowed( + &source[self.node.start as usize..self.node.end as usize], + ), _ => todo!("Call the parser and get the value"), } } diff --git a/translator/Cargo.toml b/translator/Cargo.toml new file mode 100644 index 00000000..8f211b0e --- /dev/null +++ b/translator/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "corepy_python_translator" +version = "0.1.0" +edition = "2021" + +[dependencies] +enderpy_python_parser = { path = "../parser", version = "0.1.0" } +enderpy_python_type_checker = {path = "../typechecker", version = "0.1.0" } +log = { version = "0.4.17" } diff --git a/translator/src/lib.rs b/translator/src/lib.rs new file mode 100644 index 00000000..a73a1679 --- /dev/null +++ b/translator/src/lib.rs @@ -0,0 +1 @@ +pub mod translator; diff --git a/translator/src/translator.rs b/translator/src/translator.rs new file mode 100644 index 00000000..87efe458 --- /dev/null +++ b/translator/src/translator.rs @@ -0,0 +1,478 @@ +use std::sync::Arc; +use std::collections::HashMap; + +use enderpy_python_parser::ast::{self, *}; +use enderpy_python_parser::parser::parser::intern_lookup; + +use enderpy_python_type_checker::{types, ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; +use enderpy_python_type_checker::{get_module_name, symbol_table}; + +#[derive(Clone, Debug)] +pub struct CppTranslator<'a> { + pub output: String, + indent_level: usize, + checker: Arc>, + file: &'a EnderpyFile, + current_scope: u32, + prev_scope: u32, + // Member variables of the current class, map from name to type + class_members: HashMap, + // Whether we are currently inside a __init__ method + // and therefore need to record member variables + in_constructor: bool, +} + +impl<'a> CppTranslator<'a> { + pub fn new(checker: Arc>, file: &'a EnderpyFile) -> Self { + CppTranslator { + output: "".to_string(), + indent_level: 0, + checker: checker, + file: file, + current_scope: 0, + prev_scope: 0, + class_members: HashMap::new(), + in_constructor: false, + } + } + + pub fn translate(&mut self) { + for stmt in self.file.tree.body.iter() { + self.visit_stmt(stmt); + } + } + + fn emit>(&mut self, s: S) { + self.output += s.as_ref(); + } + + fn emit_indent(&mut self) { + self.emit(" ".repeat(self.indent_level)); + } + + fn emit_type(&mut self, node: &ast::Node) { + self.emit(self.get_cpp_type(node)); + } + + fn get_cpp_type(&self, node: &ast::Node) -> String { + let typ = self.checker.get_type(node); + return self.python_type_to_cpp(&typ); + } + + fn check_type(&self, node: &Node, typ: &PythonType) { + assert!( + self.checker.get_type(node) == *typ, + "type error at {}, expected {} got {}", + self.file.get_position(node.start, node.end), + typ, self.checker.get_type(node) + ); + } + + fn enter_scope(&mut self, pos: u32) { + let symbol_table = self.checker.get_symbol_table(None); + self.prev_scope = self.current_scope; + self.current_scope = symbol_table.get_scope(pos); + } + + fn leave_scope(&mut self) { + self.current_scope = self.prev_scope; + } + + fn python_type_to_cpp(&self, python_type: &PythonType) -> String { + let details; + match python_type { + PythonType::Class(c) => { + details = &c.details; + }, + PythonType::Instance(i) => { + details = &i.class_type.details; + }, + _ => { + return String::from(format!("", python_type)); + } + }; + // If the current symbol table already contains details.name, + // we do not need to qualify it, otherwise qualify it with the module name, + // unless it is a builtin type in which case we do not qualify it. + let symbol_table = self.checker.get_symbol_table(None); + match symbol_table.lookup_in_scope(&details.name, self.current_scope) { + Some(_) => details.name.to_string(), + None => { + let symbol_table = self.checker.get_symbol_table(Some(details.declaration_path.symbol_table_id)); + if symbol_table.file_path.as_path().ends_with("builtins.pyi") { + details.name.to_string() + } else { + format!("{}::{}", get_module_name(symbol_table.file_path.as_path()), &details.name) + } + } + } + } +} + +impl<'a> TraversalVisitor for CppTranslator<'a> { + fn visit_stmt(&mut self, s: &ast::Statement) { + self.emit_indent(); + match s { + Statement::ExpressionStatement(e) => self.visit_expr(e), + Statement::Import(i) => self.visit_import(i), + Statement::ImportFrom(i) => self.visit_import_from(i), + Statement::AssignStatement(a) => { + self.visit_assign(a); + self.emit(";\n"); + }, + Statement::AnnAssignStatement(a) => self.visit_ann_assign(a), + Statement::AugAssignStatement(a) => self.visit_aug_assign(a), + Statement::Assert(a) => self.visit_assert(a), + Statement::Pass(p) => self.visit_pass(p), + Statement::Delete(d) => self.visit_delete(d), + Statement::ReturnStmt(r) => { + self.visit_return(r); + self.emit(";\n"); + }, + Statement::Raise(r) => self.visit_raise(r), + Statement::BreakStmt(b) => self.visit_break(b), + Statement::ContinueStmt(c) => self.visit_continue(c), + Statement::Global(g) => self.visit_global(g), + Statement::Nonlocal(n) => self.visit_nonlocal(n), + Statement::IfStatement(i) => self.visit_if(i), + Statement::WhileStatement(w) => self.visit_while(w), + Statement::ForStatement(f) => self.visit_for(f), + Statement::WithStatement(w) => self.visit_with(w), + Statement::TryStatement(t) => self.visit_try(t), + Statement::TryStarStatement(t) => self.visit_try_star(t), + Statement::FunctionDef(f) => self.visit_function_def(f), + Statement::ClassDef(c) => self.visit_class_def(c), + Statement::MatchStmt(m) => self.visit_match(m), + Statement::AsyncForStatement(f) => self.visit_async_for(f), + Statement::AsyncWithStatement(w) => self.visit_async_with(w), + Statement::AsyncFunctionDef(f) => self.visit_async_function_def(f), + Statement::TypeAlias(a) => self.visit_type_alias(a), + } + } + + fn visit_expr(&mut self, e: &Expression) { + match e { + Expression::Constant(c) => self.visit_constant(c), + Expression::List(l) => self.visit_list(l), + Expression::Tuple(t) => self.visit_tuple(t), + Expression::Dict(d) => self.visit_dict(d), + Expression::Set(s) => self.visit_set(s), + Expression::Name(n) => self.visit_name(n), + Expression::BoolOp(b) => self.visit_bool_op(b), + Expression::UnaryOp(u) => self.visit_unary_op(u), + Expression::BinOp(b) => self.visit_bin_op(b), + Expression::NamedExpr(n) => self.visit_named_expr(n), + Expression::Yield(y) => self.visit_yield(y), + Expression::YieldFrom(y) => self.visit_yield_from(y), + Expression::Starred(s) => self.visit_starred(s), + Expression::Generator(g) => self.visit_generator(g), + Expression::ListComp(l) => self.visit_list_comp(l), + Expression::SetComp(s) => self.visit_set_comp(s), + Expression::DictComp(d) => self.visit_dict_comp(d), + Expression::Attribute(a) => self.visit_attribute(a), + Expression::Subscript(s) => self.visit_subscript(s), + Expression::Slice(s) => self.visit_slice(s), + Expression::Call(c) => self.visit_call(c), + Expression::Await(a) => self.visit_await(a), + Expression::Compare(c) => self.visit_compare(c), + Expression::Lambda(l) => self.visit_lambda(l), + Expression::IfExp(i) => self.visit_if_exp(i), + Expression::JoinedStr(j) => self.visit_joined_str(j), + Expression::FormattedValue(f) => self.visit_formatted_value(f), + } + } + + fn visit_constant(&mut self, constant: &Constant) { + match constant.value { + ConstantValue::None => self.emit("None"), + ConstantValue::Ellipsis => self.emit("..."), + ConstantValue::Bool(_) => self.emit("bool"), + ConstantValue::Str(_) => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Bytes => self.emit("bytes"), + ConstantValue::Tuple => self.emit("tuple"), + ConstantValue::Int => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Float => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Complex => self.emit("complex"), + /* + Constant::Tuple(elements) => { + let tuple_elements: Vec = elements + .iter() + .map(|elem| self.translate_constant(elem)) + .collect::, _>>()?; + Ok(format!("({})", tuple_elements.join(", "))) + }, + */ + }; + } + + fn visit_import(&mut self, import: &Import) { + for name in import.names.iter() { + if name.name == "torch" { + self.emit("#include \n"); + } + } + } + + fn visit_assign(&mut self, a: &Assign) { + let symbol_table = self.checker.get_symbol_table(None); + for target in &a.targets { + match target { + Expression::Name(n) => { + let node = symbol_table.lookup_in_scope(&n.id, self.current_scope); + match node { + Some(node) => { + let path = node.declarations[0].declaration_path(); + // If this is the place where the name was defined, also emit its type + if path.node == n.node { + self.emit_type(&n.node); + } + }, + None => {}, + }; + self.visit_name(n); + }, + Expression::Attribute(attr) => { + if let Expression::Name(n) = &attr.value { + if n.id == "self" { + let cpp_type = self.get_cpp_type(&a.value.get_node()); + self.class_members.insert(attr.attr.clone(), cpp_type); + } + } + self.visit_expr(target); + } + _ => { + self.visit_expr(target); + } + } + } + self.emit(" = "); + self.visit_expr(&a.value); + } + + fn visit_name(&mut self, name: &Name) { + self.emit(name.id.clone()); + } + + fn visit_bin_op(&mut self, b: &BinOp) { + self.visit_expr(&b.left); + self.emit(b.op.to_string()); + self.visit_expr(&b.right); + } + + fn visit_call(&mut self, c: &Call) { + let mut typ = self.checker.get_type(&c.func.get_node()); + self.visit_expr(&c.func); + self.emit("("); + // In case c.func is a class instance, we need to use the __call__ method + // of that instance instead -- we fix this here. + if let PythonType::Instance(i) = &typ { + let symbol_table = self.checker.get_symbol_table(None); + typ = self.checker.type_evaluator.lookup_on_class(&symbol_table, &i.class_type, "__call__").expect("instance type not callable").clone(); + let PythonType::Callable(old_callable) = typ else { + panic!("XXX"); + }; + let callable_type = types::CallableType::new( + old_callable.name, + old_callable.signature[1..].to_vec(), + old_callable.return_type, + old_callable.is_async, + ); + typ = PythonType::Callable(Box::new(callable_type)); + } + // In case c.func is a class, we need to use the type signature of the + // __init__ method. + if let PythonType::Class(c) = &typ { + let symbol_table = self.checker.get_symbol_table(None); + typ = self.checker.type_evaluator.lookup_on_class(&symbol_table, &c, "__init__").expect("class currently needs an __init__ method").clone(); + let PythonType::Callable(old_callable) = typ else { + panic!("XXX"); + }; + let callable_type = types::CallableType::new( + old_callable.name, + old_callable.signature[1..].to_vec(), + old_callable.return_type, + old_callable.is_async, + ); + typ = PythonType::Callable(Box::new(callable_type)); + } + match typ { + PythonType::Callable(callable) => { + let mut num_pos_args = 0; + // First check all the positional args + for (i, arg) in callable.signature.iter().enumerate() { + match arg { + types::CallableArgs::Positional(t) => { + self.check_type(&c.args[i].get_node(), t); + if i != 0 { + self.emit(", "); + } + self.visit_expr(&c.args[i]); + num_pos_args = num_pos_args + 1; + }, + _ => { + break; + } + } + } + // Then check all the star args if there are any + if c.args.len() > num_pos_args { + self.emit("{"); + for (i, arg) in c.args[num_pos_args..].iter().enumerate() { + self.check_type(&arg.get_node(), callable.signature[num_pos_args].get_type()); + if i != 0 { + self.emit(", "); + } + self.visit_expr(arg); + } + self.emit("}"); + } + }, + _ => { + println!("Shouldn't hit this code path"); + } + } + // for keyword in &c.keywords { + // self.visit_expr(&keyword.value); + // } + self.emit(")"); + } + + fn visit_attribute(&mut self, attribute: &Attribute) { + self.visit_expr(&attribute.value); + match &attribute.value { + Expression::Name(n) => { + let symbol_table = self.checker.get_symbol_table(None); + match symbol_table.lookup_in_scope(&n.id, self.current_scope) { + Some(entry) => { + match entry.last_declaration() { + symbol_table::Declaration::Alias(_a) => { + self.emit(format!("::{}", &attribute.attr)); + return + }, + _ => {} + } + }, + None => {}, + } + }, + _ => {} + } + self.emit(format!(".{}", attribute.attr)); + } + + fn visit_return(&mut self, r: &Return) { + self.emit("return "); + if let Some(value) = &r.value { + self.visit_expr(value); + } + } + + fn visit_function_def(&mut self, f: &Arc) { + self.enter_scope(f.node.start); + let mut name = intern_lookup(f.name).to_string(); + if name == "__init__" { + // In this case, the function is a constructor and in + // C++ needs to be named the same as the class. We achieve + // this by naming it after the type of the "self" argument + // of __init__. + name = self.get_cpp_type(&f.args.args[0].node); + self.class_members = HashMap::new(); + self.in_constructor = true; + } + if let Some(ret) = &f.returns { + let return_type = self.get_cpp_type(&ret.get_node()); + self.emit(format!("{} {}(", return_type, name)); + } else { + if self.in_constructor { + self.emit(format!("{}(", name)); + } else { + self.emit(format!("void {}(", name)); + } + } + // Filter out "self" arg (first arg of a Python method), + // since in C++ the "this" arg is implicit. + // TODO: This will also filter out random args called "self" -- + // instead we should check if we are in a class definition and then + // only filter the first argument called "self". + let args = f.args.args.iter().filter(|arg| arg.arg != "self"); + for (i, arg) in args.enumerate() { + if i != 0 { + self.emit(", "); + } + self.emit_type(&arg.node); + self.emit(format!(" {}", arg.arg)); + } + self.emit(") {\n"); + self.indent_level += 1; + // If this is an instance method, introduce "self" + self.emit_indent(); + self.emit("auto& self = *this;\n"); + for stmt in &f.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + self.emit_indent(); + self.emit("}\n"); + self.in_constructor = false; + self.leave_scope(); + } + + fn visit_class_def(&mut self, c: &Arc) { + let name = intern_lookup(c.name); + self.emit(format!("class {} {{\n", name)); + self.emit_indent(); + self.emit("public:\n"); + self.enter_scope(c.node.start); + self.indent_level += 1; + for stmt in &c.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + // print class member variables + self.emit_indent(); + self.emit("private:\n"); + // TODO: Want to move this out, not clone it + for (key, value) in self.class_members.clone() { + self.emit_indent(); + self.emit(format!(" {} {};\n", value, key)); + } + self.class_members = HashMap::new(); + self.emit_indent(); + self.emit("};\n"); + self.leave_scope(); + } + + fn visit_for(&mut self, f: &For) { + let mut bound = None; + match &f.iter { + Expression::Call(c) => { + match &c.func { + Expression::Name(n) => { + if n.id == "range" { + bound = Some(c.args[0].clone()); + } + } + _ => {} + } + }, + _ => {} + } + self.emit("for(int "); + self.visit_expr(&f.target); + self.emit(" = 0; "); + self.visit_expr(&f.target); + self.emit(" < "); + self.visit_expr(&bound.unwrap()); + self.emit("; ++"); + self.visit_expr(&f.target); + self.emit(") {\n"); + self.indent_level += 1; + for stmt in &f.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + self.emit_indent(); + self.emit("}\n"); + } +} + diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index c58c8a84..e17e861f 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -7,7 +7,7 @@ use enderpy_python_parser::parser::parser::intern_lookup; use super::{type_evaluator::TypeEvaluator, types::PythonType}; use crate::build::BuildManager; -use crate::symbol_table::Id; +use crate::symbol_table::{Id, SymbolTable}; use crate::types::ModuleRef; use crate::{ast_visitor::TraversalVisitor, diagnostic::CharacterSpan}; use rust_lapper::{Interval, Lapper}; @@ -16,7 +16,7 @@ use rust_lapper::{Interval, Lapper}; pub struct TypeChecker<'a> { pub types: Lapper, id: Id, - type_evaluator: TypeEvaluator<'a>, + pub type_evaluator: TypeEvaluator<'a>, build_manager: &'a BuildManager, current_scope: u32, prev_scope: u32, @@ -136,6 +136,20 @@ impl<'a> TypeChecker<'a> { str } + + pub fn get_type(&self, node: &ast::Node) -> PythonType { + for r in self.types.find(node.start, node.end) { + if r.start == node.start && r.stop == node.end { + return r.val.clone(); + } + } + return PythonType::Unknown; + } + + pub fn get_symbol_table(&self, id: Option) -> Arc { + let id = id.unwrap_or(self.id); + return self.build_manager.get_symbol_table_by_id(&id); + } } #[allow(unused)] impl<'a> TraversalVisitor for TypeChecker<'a> { @@ -456,8 +470,10 @@ impl<'a> TraversalVisitor for TypeChecker<'a> { } fn visit_bin_op(&mut self, b: &BinOp) { - let l_type = self.infer_expr_type(&b.left); - let r_type = self.infer_expr_type(&b.right); + // let l_type = self.infer_expr_type(&b.left); + self.visit_expr(&b.left); + // let r_type = self.infer_expr_type(&b.right); + self.visit_expr(&b.right); } fn visit_named_expr(&mut self, _n: &NamedExpression) { diff --git a/typechecker/src/lib.rs b/typechecker/src/lib.rs index c6396fef..79ecc988 100644 --- a/typechecker/src/lib.rs +++ b/typechecker/src/lib.rs @@ -1,9 +1,9 @@ use std::path::Path; -mod ast_visitor; -mod file; +pub mod ast_visitor; +pub mod file; mod ruff_python_import_resolver; -mod symbol_table; +pub mod symbol_table; pub mod build; pub mod checker; @@ -11,7 +11,7 @@ pub mod diagnostic; pub mod semantic_analyzer; pub mod settings; pub mod type_evaluator; -mod types; +pub mod types; pub(crate) mod builtins { pub const LIST_TYPE: &str = "list"; @@ -40,5 +40,16 @@ pub fn find_project_root(path: &Path) -> &Path { } pub fn get_module_name(path: &Path) -> String { - path.to_str().unwrap().replace(['/', '\\'], ".") + // First we strip .pyi and / or __init__.pyi from the end + let mut s = path.to_str().unwrap(); + s = match s.strip_suffix("/__init__.pyi") { + Some(new) => new, + None => s + }; + s = match s.strip_suffix(".pyi") { + Some(new) => new, + None => s + }; + // And then we replace the slashes with . + s.replace(['/', '\\'], ".") } diff --git a/typechecker/src/type_evaluator.rs b/typechecker/src/type_evaluator.rs index 3adda5bf..1a5a0123 100755 --- a/typechecker/src/type_evaluator.rs +++ b/typechecker/src/type_evaluator.rs @@ -22,6 +22,7 @@ use super::{ }, }; use crate::{ + get_module_name, build::BuildManager, semantic_analyzer::get_member_access_info, symbol_table::{self, Class, Declaration, DeclarationPath, Id, SymbolTable, SymbolTableNode}, @@ -59,6 +60,13 @@ bitflags::bitflags! { } } +fn class_type_to_instance_type(class_type: PythonType) -> PythonType { + let PythonType::Class(c) = class_type else { + return PythonType::Unknown; + }; + PythonType::Instance(types::InstanceType::new(c.clone(), [].to_vec())) +} + /// Struct for evaluating the type of an expression impl<'a> TypeEvaluator<'a> { pub fn new(build_manager: &'a BuildManager) -> Self { @@ -84,12 +92,12 @@ impl<'a> TypeEvaluator<'a> { let typ = match &c.value { // Constants are not literals unless they are explicitly // typing.readthedocs.io/en/latest/spec/literal.html#backwards-compatibility - ast::ConstantValue::Int => self.get_builtin_type("int"), - ast::ConstantValue::Float => self.get_builtin_type("float"), - ast::ConstantValue::Str(_) => self.get_builtin_type("str"), - ast::ConstantValue::Bool(_) => self.get_builtin_type("bool"), + ast::ConstantValue::Int => self.get_builtin_type("int").map(class_type_to_instance_type), + ast::ConstantValue::Float => self.get_builtin_type("float").map(class_type_to_instance_type), + ast::ConstantValue::Str(_) => self.get_builtin_type("str").map(class_type_to_instance_type), + ast::ConstantValue::Bool(_) => self.get_builtin_type("bool").map(class_type_to_instance_type), ast::ConstantValue::None => Some(PythonType::None), - ast::ConstantValue::Bytes => self.get_builtin_type("bytes"), + ast::ConstantValue::Bytes => self.get_builtin_type("bytes").map(class_type_to_instance_type), ast::ConstantValue::Ellipsis => Some(PythonType::Any), // TODO: implement ast::ConstantValue::Tuple => Some(PythonType::Unknown), @@ -122,8 +130,21 @@ impl<'a> TypeEvaluator<'a> { scope_id, ); Ok(return_type) + } else if let PythonType::Instance(i) = &called_type { + // This executes the __call__ method of the instance + let Some(PythonType::Callable(c)) = self.lookup_on_class(symbol_table, &i.class_type, "__call__") else { + bail!("If you call an instance, it must have a __call__ method"); + }; + let return_type = self.get_return_type_of_callable( + &c, + &call.args, + symbol_table, + scope_id, + ); + Ok(return_type) } else if let PythonType::Class(c) = &called_type { - Ok(called_type) + // This instantiates the class + Ok(PythonType::Instance(types::InstanceType::new(c.clone(), [].to_vec()))) } else if let PythonType::TypeVar(t) = &called_type { let Some(first_arg) = call.args.first() else { bail!("TypeVar must be called with a name"); @@ -551,7 +572,7 @@ impl<'a> TypeEvaluator<'a> { let expr_type = match type_annotation { Expression::Name(name) => { // TODO: Reject this type if the name refers to a variable. - self.get_name_type(&name.id, Some(name.node.start), symbol_table, scope_id) + return class_type_to_instance_type(self.get_name_type(&name.id, Some(name.node.start), symbol_table, scope_id)); } Expression::Constant(ref c) => match c.value { ast::ConstantValue::None => PythonType::None, @@ -676,7 +697,18 @@ impl<'a> TypeEvaluator<'a> { // TODO: check if other binary operators are allowed _ => todo!(), } - } + }, + Expression::Attribute(a) => { + match &a.value { + Expression::Name(n) => { + let Some(typ) = self.lookup_on_module(symbol_table, scope_id, &n.id, &a.attr) else { + return PythonType::Unknown; + }; + return class_type_to_instance_type(typ); + }, + _ => todo!(), + }; + }, _ => PythonType::Unknown, }; @@ -869,9 +901,39 @@ impl<'a> TypeEvaluator<'a> { &iter_method_type.type_parameters, &iter_method_type.specialized, ) - } + }, + PythonType::Class(class_type) => { + let iter_method = match self.lookup_on_class( + &symbol_table, + &class_type, + "__iter__", + ) { + Some(PythonType::Callable(c)) => c, + Some(other) => panic!("iter method was not callable: {}", other), + None => panic!("next method not found"), + }; + let Some(iter_method_type) = &iter_method.return_type.class() + else { + panic!("iter method return type is not class"); + }; + let next_method = match self.lookup_on_class( + &symbol_table, + &iter_method_type, + "__next__", + ) { + Some(PythonType::Callable(c)) => c, + Some(other) => panic!("next method was not callable: {}", other), + None => panic!("next method not found"), + }; + self.resolve_generics( + &next_method.return_type, + &iter_method_type.type_parameters, + &iter_method_type.specialized, + ) + // PythonType::Unknown + }, _ => { - error!("iterating over a {} is not defined", iter_type); + error!("iterating over a {:?} is not defined", iter_type); PythonType::Unknown } } @@ -1004,15 +1066,20 @@ impl<'a> TypeEvaluator<'a> { PythonType::Unknown } None => { - let Some(ref resolved_import) = a.import_result else { - trace!("import result not found"); - return PythonType::Unknown; - }; - - let module_id = resolved_import.resolved_ids.first().unwrap(); - return PythonType::Module(ModuleRef { - module_id: *module_id, - }); + match &a.import_node { + Some(i) => { + let module_name = &i.names[0].name; + let Some(module_symbol_table) = self.get_symbol_table_for_module(&a, module_name) else { + return PythonType::Unknown; + }; + return PythonType::Module(ModuleRef { + module_id: module_symbol_table.id, + }); + }, + None => { + return PythonType::Unknown; + } + } } } } @@ -1485,7 +1552,7 @@ impl<'a> TypeEvaluator<'a> { ret_type } - fn lookup_on_class( + pub fn lookup_on_class( &self, symbol_table: &SymbolTable, c: &ClassType, @@ -1517,6 +1584,39 @@ impl<'a> TypeEvaluator<'a> { symbol.map(|node| self.get_symbol_type(node, symbol_table, None)) } + /// Find a type inside a Python module + fn lookup_on_module( + &self, + symbol_table: &SymbolTable, + scope_id: u32, + module_name: &str, + attr: &str, + ) -> Option { + // See if the module is in the symbol table + let symbol_table_entry = symbol_table.lookup_in_scope(module_name, scope_id)?; + match symbol_table_entry.last_declaration() { + Declaration::Alias(a) => { + let module_symbol_table = self.get_symbol_table_for_module(&a, module_name)?; + return Some(self.get_name_type(attr, None, &module_symbol_table, 0)); + } + _ => {} + }; + None + } + + fn get_symbol_table_for_module(&self, alias: &symbol_table::Alias, module_name: &str) -> Option> { + let Some(ref resolved_import) = alias.import_result else { + return None; + }; + for id in resolved_import.resolved_ids.iter() { + let module_symbol_table = self.get_symbol_table(id); + if module_name == get_module_name(module_symbol_table.file_path.as_path()) { + return Some(module_symbol_table); + } + } + return None; + } + fn get_function_signature( &self, arguments: &ast::Arguments,