diff --git a/parser/src/parser/ast.rs b/parser/src/parser/ast.rs index 7df1c569..7696ce71 100644 --- a/parser/src/parser/ast.rs +++ b/parser/src/parser/ast.rs @@ -1,5 +1,6 @@ use is_macro::Is; use std::fmt::{self}; +use std::sync::Arc; use miette::{SourceOffset, SourceSpan}; @@ -79,9 +80,9 @@ pub enum Statement { AsyncWithStatement(Box), TryStatement(Box), TryStarStatement(Box), - FunctionDef(Box), + FunctionDef(Arc), AsyncFunctionDef(Box), - ClassDef(Box), + ClassDef(Arc), Match(Box), TypeAlias(Box), } diff --git a/parser/src/parser/parser.rs b/parser/src/parser/parser.rs index e4a7cc5d..e19088d8 100644 --- a/parser/src/parser/parser.rs +++ b/parser/src/parser/parser.rs @@ -4,7 +4,7 @@ use core::panic; /// For example star expressions are defined slightly differently in python grammar and references. /// So there might be duplicates of both. Try to migrate the wrong names to how they are called in: /// https://docs.python.org/3/reference/grammar.html -use std::vec; +use std::{sync::Arc, vec}; use miette::Result; @@ -643,8 +643,8 @@ impl<'a> Parser<'a> { type_params, }))) } else { - Ok(Statement::FunctionDef(Box::new(FunctionDef { - node: self.finish_node_chomped(node), + Ok(Statement::FunctionDef(Arc::new(FunctionDef { + node: self.finish_node(node), name, args, body, @@ -707,7 +707,7 @@ impl<'a> Parser<'a> { self.expect(Kind::Colon)?; let body = self.parse_suite()?; - Ok(Statement::ClassDef(Box::new(ClassDef { + Ok(Statement::ClassDef(Arc::new(ClassDef { node: self.finish_node(node), name, bases, diff --git a/typechecker/src/build.rs b/typechecker/src/build.rs index 96acbfeb..eba24919 100755 --- a/typechecker/src/build.rs +++ b/typechecker/src/build.rs @@ -1,6 +1,7 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, path::{Path, PathBuf}, + sync::Arc, }; use dashmap::DashMap; @@ -34,8 +35,14 @@ pub struct BuildManager<'a> { impl<'a> BuildManager<'a> { pub fn new(settings: Settings) -> Self { let mut builder = Builder::new(); + + let log_level = match std::env::var("DEBUG") { + Ok(_) => log::LevelFilter::Debug, + _ => log::LevelFilter::Info, + }; + builder - .filter(None, log::LevelFilter::Info) + .filter(None, log_level) .format_timestamp(None) .try_init(); @@ -73,11 +80,9 @@ impl<'a> BuildManager<'a> { let (imports, mut new_modules) = gather_files(vec![builtins], root, &self.import_config, &self.host); log::debug!("Imports resolved"); - for mut module in new_modules.iter_mut() { + for mut module in new_modules { let sym_table = module.populate_symbol_table(&imports); - self.symbol_tables.insert(sym_table.id, sym_table); - } - for module in new_modules { + self.symbol_tables.insert(module.id, sym_table); self.module_ids.insert(module.path.to_path_buf(), module.id); self.modules.insert(module.id, module); } @@ -91,11 +96,9 @@ impl<'a> BuildManager<'a> { let (imports, mut new_modules) = gather_files(vec![enderpy_file], root, &self.import_config, &self.host); log::debug!("Imports resolved"); - for mut module in new_modules.iter_mut() { + for mut module in new_modules { let sym_table = module.populate_symbol_table(&imports); self.symbol_tables.insert(module.id, sym_table); - } - for module in new_modules { self.module_ids.insert(module.path.to_path_buf(), module.id); self.modules.insert(module.id, module); } @@ -106,6 +109,7 @@ impl<'a> BuildManager<'a> { // This step happens after the binding phase pub fn type_check(&self, path: &Path) -> TypeChecker { let mut module_to_check = self.get_state(path); + let mut checker = TypeChecker::new( self.get_symbol_table(path), &self.symbol_tables, @@ -128,7 +132,6 @@ impl<'a> BuildManager<'a> { } pub fn get_hover_information(&self, path: &Path, line: u32, column: u32) -> String { - return "".to_string(); let module = self.get_state(path); let checker = self.type_check(path); let symbol_table = self.get_symbol_table(path); @@ -153,59 +156,101 @@ impl<'a> BuildManager<'a> { } } +#[derive(Debug, Clone)] +pub struct ResolvedImport { + pub resolved_ids: Vec, + result: ImportResult, +} + +pub type ResolvedImports = HashMap>; + fn gather_files<'a>( mut initial_files: Vec>, root: &Path, import_config: &ruff_python_resolver::config::Config, host: &ruff_python_resolver::host::StaticHost, -) -> ( - HashMap, - Vec>, -) { +) -> (ResolvedImports, HashSet>) { let execution_environment = &execution_environment::ExecutionEnvironment { root: root.to_path_buf(), python_version: ruff_python_resolver::python_version::PythonVersion::Py312, python_platform: ruff_python_resolver::python_platform::PythonPlatform::Darwin, extra_paths: vec![], }; - let mut new_modules = Vec::with_capacity(initial_files.len() * 5); + let mut path_to_id: HashMap<&Path, Id> = HashMap::with_capacity(initial_files.len() * 5); + let mut new_modules = HashSet::with_capacity(initial_files.len() * 5); let mut import_results = HashMap::new(); - - let cache: &mut HashMap> = - &mut HashMap::new(); + let mut seen = HashSet::new(); while let Some(module) = initial_files.pop() { - if cache.get(&module.path).is_some() { + if seen.contains(&module.path) { continue; } + seen.insert(module.path.clone()); let resolved_imports = resolve_file_imports( &module, execution_environment, import_config, host, - &import_results, + // &import_results, ); - new_modules.push(module); + new_modules.insert(module); for (import_desc, mut resolved) in resolved_imports { - if resolved.is_import_found { - for resolved_path in resolved.resolved_paths.iter_mut() { - if !initial_files.iter().any(|m| m.path == *resolved_path) { - let e = EnderpyFile::new(resolved_path.clone(), true); - initial_files.push(e); - } + if !resolved.is_import_found { + continue; + } + let mut resolved_ids = Vec::with_capacity(resolved.resolved_paths.len()); + for resolved_path in resolved.resolved_paths.iter_mut() { + if let Some(found) = new_modules.iter().find(|m| *m.path == *resolved_path) { + resolved_ids.push(found.id); + } else if let Some(found) = initial_files.iter().find(|m| *m.path == *resolved_path) + { + resolved_ids.push(found.id); + } else { + let e = EnderpyFile::new(std::mem::take(resolved_path), true); + resolved_ids.push(e.id); + initial_files.push(e); } + } - for (_, implicit_import) in resolved.implicit_imports.iter_mut() { - let e = EnderpyFile::new(implicit_import.path.clone(), true); + // TODO: don't know if the implicit imports should be in the resolved list or not + // For imports like from os import path it points to the path.py file which is in the + // implicit imports so without this we cannot resolved that. + for (_, implicit_import) in resolved.implicit_imports.iter_mut() { + let resolved_path = &mut implicit_import.path; + if let Some(found) = new_modules.iter().find(|m| *m.path == *resolved_path) { + resolved_ids.push(found.id); + } else if let Some(found) = initial_files.iter().find(|m| *m.path == *resolved_path) + { + resolved_ids.push(found.id); + } else { + let e = EnderpyFile::new(std::mem::take(resolved_path), true); + resolved_ids.push(e.id); initial_files.push(e); } } - import_results.insert(import_desc, resolved); + import_results.insert( + import_desc, + Arc::new(ResolvedImport { + resolved_ids, + result: resolved, + }), + ); } } new_modules.extend(initial_files); - (import_results, new_modules) + + for import in import_results.iter() { + for resolved in import.1.resolved_ids.iter() { + if !new_modules.iter().any(|m| m.id == *resolved) { + for module in new_modules.iter() { + println!("{:?} - {:?}", module.path, module.id); + } + panic!("symbol table not found {resolved:?}"); + } + } + } + (import_results, new_modules.into()) } fn resolve_file_imports( @@ -213,7 +258,6 @@ fn resolve_file_imports( execution_environment: &ruff_python_resolver::execution_environment::ExecutionEnvironment, import_config: &ruff_python_resolver::config::Config, host: &ruff_python_resolver::host::StaticHost, - cached_imports: &HashMap, ) -> HashMap { let mut imports = HashMap::new(); debug!("resolving imports for file {:?}", file.path); @@ -230,7 +274,8 @@ fn resolve_file_imports( }; for import_desc in import_descriptions { - let resolved = match cached_imports.contains_key(&import_desc) { + // TODO: Cache non relative imports + let resolved = match false { true => continue, false => resolver::resolve_import( &file.path, @@ -279,6 +324,7 @@ mod tests { r"module_name: .*.typechecker.test_data.inputs.symbol_table..*.py", "module_name: [REDACTED]", ); + settings.add_filter(r"Id\(\d+\)", "Id(REDACTED)"); settings.add_filter(r"\(id: .*\)", "(id: [REDACTED])"); settings.bind(|| { insta::assert_snapshot!(result); diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index ab6df67a..01b56151 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -8,10 +8,7 @@ use enderpy_python_parser::ast::{self, *}; use super::{type_evaluator::TypeEvaluator, types::PythonType}; use crate::symbol_table::Id; use crate::types::ModuleRef; -use crate::{ - ast_visitor::TraversalVisitor, diagnostic::CharacterSpan, file::EnderpyFile, - symbol_table::SymbolTable, -}; +use crate::{ast_visitor::TraversalVisitor, diagnostic::CharacterSpan, symbol_table::SymbolTable}; use rust_lapper::{Interval, Lapper}; #[derive(Clone, Debug)] @@ -179,9 +176,7 @@ impl<'a> TraversalVisitor for TypeChecker<'a> { self.types.insert(Interval { start, stop, - val: PythonType::Module(ModuleRef { - module_path: PathBuf::new(), - }), + val: PythonType::Module(ModuleRef { module_id: Id(0) }), }); } diff --git a/typechecker/src/file.rs b/typechecker/src/file.rs index 8b94d327..d6b97128 100755 --- a/typechecker/src/file.rs +++ b/typechecker/src/file.rs @@ -1,24 +1,15 @@ use core::panic; -use std::path::Path; +use std::path::PathBuf; use std::sync::atomic::AtomicUsize; -use std::{collections::HashMap, path::PathBuf}; +use std::sync::Arc; -use dashmap::DashMap; use enderpy_python_parser as parser; use enderpy_python_parser::ast::*; use parser::{ast, Parser}; use std::sync::atomic::Ordering; -use crate::checker::TypeChecker; -use crate::{ - ast_visitor::TraversalVisitor, - diagnostic::Position, - ruff_python_import_resolver::{ - import_result::ImportResult, module_descriptor::ImportModuleDescriptor, - }, - semantic_analyzer::SemanticAnalyzer, - symbol_table::SymbolTable, -}; +use crate::build::ResolvedImports; +use crate::{diagnostic::Position, semantic_analyzer::SemanticAnalyzer, symbol_table::SymbolTable}; use crate::{get_module_name, symbol_table}; #[derive(Clone, Debug)] @@ -35,15 +26,30 @@ pub struct EnderpyFile<'a> { pub module: String, // if this source is found by following an import pub followed: bool, - pub path: PathBuf, + pub path: Arc, pub source: String, pub offset_line_number: Vec, pub tree: ast::Module, dummy: &'a str, } -static COUNTER: AtomicUsize = AtomicUsize::new(1); + +impl<'a> Eq for EnderpyFile<'a> {} + +impl<'a> PartialEq for EnderpyFile<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.path == other.path + } +} + +impl<'a> std::hash::Hash for EnderpyFile<'a> { + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.path.hash(state); + } +} fn get_id() -> u32 { + static COUNTER: AtomicUsize = AtomicUsize::new(1); COUNTER.fetch_add(1, Ordering::SeqCst) as u32 } @@ -68,7 +74,7 @@ impl<'a> EnderpyFile<'a> { followed, module, tree, - path, + path: Arc::new(path), dummy: "sdfsd", } } @@ -125,10 +131,7 @@ impl<'a> EnderpyFile<'a> { } /// entry point to fill up the symbol table from the global definitions - pub fn populate_symbol_table( - &mut self, - imports: &HashMap, - ) -> SymbolTable { + pub fn populate_symbol_table(&mut self, imports: &ResolvedImports) -> SymbolTable { let mut sem_anal = SemanticAnalyzer::new(self, imports); for stmt in &self.tree.body { sem_anal.visit_stmt(stmt) diff --git a/typechecker/src/semantic_analyzer.rs b/typechecker/src/semantic_analyzer.rs index 176c892f..4dce9fc0 100644 --- a/typechecker/src/semantic_analyzer.rs +++ b/typechecker/src/semantic_analyzer.rs @@ -1,12 +1,13 @@ -use std::{collections::HashMap, path::Path}; +use std::{collections::HashMap, sync::Arc}; use enderpy_python_parser as parser; use enderpy_python_parser::ast::Expression; -use parser::ast::{FunctionDef, GetNode, Name, Statement}; +use parser::ast::{GetNode, Name, Statement}; use crate::{ ast_visitor::TraversalVisitor, + build::ResolvedImports, file::EnderpyFile, ruff_python_import_resolver::{ import_result::ImportResult, module_descriptor::ImportModuleDescriptor, @@ -27,15 +28,12 @@ pub struct SemanticAnalyzer<'a> { /// if we have a file with the following imports this is how we use the map /// import os -> imports.get("os") /// from os import path -> imports.get("os") - pub imports: &'a HashMap, + pub imports: &'a ResolvedImports, } #[allow(unused)] impl<'a> SemanticAnalyzer<'a> { - pub fn new( - file: &'a EnderpyFile<'a>, - imports: &'a HashMap, - ) -> Self { + pub fn new(file: &'a EnderpyFile<'a>, imports: &'a ResolvedImports) -> Self { let symbols = SymbolTable::new(&file.path, file.id); SemanticAnalyzer { symbol_table: symbols, @@ -238,18 +236,8 @@ impl<'a> SemanticAnalyzer<'a> { /// Returns true if the current function assigns an attribute to an object /// Functions like __init__ and __new__ are considered to assign attributes fn function_assigns_attribute(&self, symbol_table: &SymbolTable) -> bool { - if let Some(FunctionDef { - node, - name: fname, - args, - body, - decorator_list, - returns, - type_comment, - type_params, - }) = symbol_table.current_scope().kind.as_function() - { - if fname == "__init__" || fname == "__new__" { + if let Some(function_def) = symbol_table.current_scope().kind.as_function() { + if function_def.name == "__init__" || function_def.name == "__new__" { return true; } } @@ -257,8 +245,8 @@ impl<'a> SemanticAnalyzer<'a> { } } -impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { - fn visit_stmt(&mut self, s: &parser::ast::Statement) { +impl<'a> SemanticAnalyzer<'a> { + pub fn visit_stmt(&mut self, s: &parser::ast::Statement) { match s { parser::ast::Statement::ExpressionStatement(e) => self.visit_expr(e), parser::ast::Statement::Import(i) => self.visit_import(i), @@ -325,11 +313,10 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { fn visit_import(&mut self, i: &parser::ast::Import) { for alias in &i.names { - let import_result = match self.imports.get(&ImportModuleDescriptor::from(alias)) { - Some(result) => result.clone(), - None => ImportResult::not_found(), - }; - // TODO: Report unresolved import if import_result is None + let import_result = self + .imports + .get(&ImportModuleDescriptor::from(alias)) + .cloned(); let declaration_path = DeclarationPath::new( self.symbol_table.id, alias.node, @@ -351,15 +338,14 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { } fn visit_import_from(&mut self, _i: &parser::ast::ImportFrom) { - let module_import_result = match self.imports.get(&ImportModuleDescriptor::from(_i)) { - Some(result) => result.clone(), - None => ImportResult::not_found(), - }; + let module_import_result = self.imports.get(&ImportModuleDescriptor::from(_i)); for alias in &_i.names { if alias.name == "*" { - self.symbol_table - .star_imports - .push(module_import_result.clone()); + if let Some(module_import_result) = module_import_result { + self.symbol_table + .star_imports + .push(module_import_result.clone()); + } continue; } let declaration_path = DeclarationPath::new( @@ -373,7 +359,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { import_node: None, symbol_name: Some(alias.name.clone()), module_name: None, - import_result: module_import_result.clone(), + import_result: module_import_result.cloned(), }); let flags = SymbolFlags::empty(); @@ -473,7 +459,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { } } - fn visit_function_def(&mut self, f: &parser::ast::FunctionDef) { + fn visit_function_def(&mut self, f: &Arc) { let declaration_path = DeclarationPath::new( self.symbol_table.id, f.node, @@ -500,7 +486,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { // } } self.symbol_table.push_scope(SymbolTableScope::new( - crate::symbol_table::SymbolTableType::Function(f.clone()), + crate::symbol_table::SymbolTableType::Function(Arc::clone(f)), f.name.clone(), f.node.start, self.symbol_table.current_scope_id, @@ -578,7 +564,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { fn visit_async_function_def(&mut self, _f: &parser::ast::AsyncFunctionDef) { self.symbol_table.push_scope(SymbolTableScope::new( - SymbolTableType::Function(_f.to_function_def()), + SymbolTableType::Function(Arc::new(_f.to_function_def())), _f.name.clone(), _f.node.start, self.symbol_table.current_scope_id, @@ -586,7 +572,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { self.symbol_table.exit_scope(); } - fn visit_class_def(&mut self, c: &parser::ast::ClassDef) { + fn visit_class_def(&mut self, c: &Arc) { self.symbol_table.push_scope(SymbolTableScope::new( SymbolTableType::Class(c.clone()), c.name.clone(), @@ -631,7 +617,7 @@ impl<'a> TraversalVisitor for SemanticAnalyzer<'a> { .to_str() .unwrap() .to_string(), - c.clone(), + Arc::clone(c), class_declaration_path, class_body_scope_id, )); @@ -777,17 +763,12 @@ pub fn get_member_access_info( let value_name = &name.id; let current_scope = symbol_table.current_scope(); - let FunctionDef { - args, - decorator_list, - .. - } = current_scope.kind.as_function()?; - + let function_def = current_scope.kind.as_function()?; let parent_scope = symbol_table.parent_scope(symbol_table.current_scope())?; let enclosing_class = parent_scope.kind.as_class()?; - let first_arg = args.args.first()?; + let first_arg = function_def.args.args.first()?; let is_value_equal_to_first_arg = value_name == first_arg.arg.as_str(); @@ -797,7 +778,7 @@ pub fn get_member_access_info( // Check if one of the decorators is a classmethod or staticmethod let mut is_class_member = false; - for decorator in decorator_list { + for decorator in function_def.decorator_list.iter() { if let parser::ast::Expression::Call(call) = decorator { if let Some(name) = call.func.as_name() { if name.id == "classmethod" { diff --git a/typechecker/src/symbol_table.rs b/typechecker/src/symbol_table.rs index 1909e2e4..3c3bcb9c 100644 --- a/typechecker/src/symbol_table.rs +++ b/typechecker/src/symbol_table.rs @@ -3,11 +3,12 @@ use rust_lapper::{Interval, Lapper}; use std::fs; use std::path::Path; +use std::sync::Arc; use std::{collections::HashMap, fmt::Display, path::PathBuf}; use enderpy_python_parser::ast::{self, ClassDef, FunctionDef, Node}; -use crate::ruff_python_import_resolver::import_result::ImportResult; +use crate::build::ResolvedImport; #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] pub struct Id(pub u32); @@ -23,7 +24,7 @@ pub struct SymbolTable { pub file_path: PathBuf, pub scope_starts: Lapper, - pub star_imports: Vec, + pub star_imports: Vec>, pub id: Id, } @@ -137,6 +138,11 @@ impl SymbolTable { &self, lookup_request: &LookupSymbolRequest, ) -> Option<&SymbolTableNode> { + log::debug!( + "looking for symbol {:?} in symbol table with scopes: {:?}", + lookup_request, + self.file_path + ); let mut scope = match lookup_request.scope { Some(scope_id) => self.get_scope_by_id(scope_id).expect("no scope found"), None => self.current_scope(), @@ -290,8 +296,8 @@ pub enum SymbolTableType { /// BUILTIN scope is used for builtins like len, print, etc. BUILTIN, Module, - Class(ClassDef), - Function(FunctionDef), + Class(Arc), + Function(Arc), } bitflags! { @@ -377,13 +383,6 @@ impl Declaration { Declaration::TypeAlias(t) => &t.declaration_path, } } - - // pub fn get_symbol_table<'a>(&self, symbol_tables: &'a [SymbolTable]) -> &'a SymbolTable { - // let symbol_table = symbol_tables - // .iter() - // .find(|symbol_table| symbol_table.id == self.declaration_path().symbol_table_id); - // symbol_table.expect("Symbol table not found for this symbol node: {self:?}") - // } } #[derive(Debug, Clone)] @@ -397,7 +396,7 @@ pub struct Variable { #[derive(Debug, Clone)] pub struct Function { pub declaration_path: DeclarationPath, - pub function_node: ast::FunctionDef, + pub function_node: Arc, pub is_method: bool, pub is_generator: bool, /// return statements that are reachable in the top level function body @@ -432,7 +431,7 @@ pub struct Class { // These classes have their behavior defined in PEPs so we need to handle them differently pub special: bool, /// Special classes have a generic class node. So this node is null for special classes - pub class_node: Option, + pub class_node: Option>, pub class_scope_id: u32, pub qual_name: String, } @@ -440,7 +439,7 @@ pub struct Class { impl Class { pub fn new( mut module_name: String, - class_node: ast::ClassDef, + class_node: Arc, declaration_path: DeclarationPath, class_scope_id: u32, ) -> Self { @@ -503,7 +502,7 @@ pub struct Alias { /// e.g. import os.path -> os.path is the module name pub module_name: Option, /// The result of the import - pub import_result: ImportResult, + pub import_result: Option>, } #[derive(Debug, Clone)] @@ -552,7 +551,10 @@ impl SymbolTableNode { impl Display for SymbolTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if !self.star_imports.is_empty() { - writeln!(f, "{:?}", self.star_imports)?; + writeln!(f, "Star imports:")?; + for import in self.star_imports.iter() { + writeln!(f, "{:?}", import.resolved_ids)?; + } } let mut sorted_scopes = self.scopes.iter().collect::>(); sorted_scopes.sort_by(|a, b| a.name.cmp(&b.name)); diff --git a/typechecker/src/type_evaluator.rs b/typechecker/src/type_evaluator.rs index 11475970..6e5baa3c 100755 --- a/typechecker/src/type_evaluator.rs +++ b/typechecker/src/type_evaluator.rs @@ -2,7 +2,7 @@ #![allow(unused_variables)] use core::panic; -use std::{path::PathBuf, str::FromStr}; +use std::path::PathBuf; use dashmap::DashMap; use enderpy_python_parser as parser; @@ -19,7 +19,7 @@ use super::{ use crate::{ semantic_analyzer::get_member_access_info, symbol_table::{Class, Declaration, Id, LookupSymbolRequest, SymbolTable, SymbolTableNode}, - types::{ClassType, TypeVar}, + types::{ClassType, ModuleRef, TypeVar}, }; const LITERAL_TYPE_PARAMETER_MSG: &str = "Type arguments for 'Literal' must be None, a literal value (int, bool, str, or bytes), or an enum value"; @@ -328,8 +328,9 @@ impl<'a> TypeEvaluator<'a> { } } PythonType::Module(module) => { - let id = self.ids.get(&module.module_path).unwrap(); - let module_sym_table = self.imported_symbol_tables.get(&id).unwrap(); + log::debug!("module: {:?}", module); + let module_sym_table = + self.imported_symbol_tables.get(&module.module_id).unwrap(); self.get_name_type(&a.attr, None, &module_sym_table, Some(0)) } _ => Ok(PythonType::Unknown), @@ -359,7 +360,7 @@ impl<'a> TypeEvaluator<'a> { // This function tries to find the python type from an annotation expression // If the annotation is invalid it returns unknown type - pub fn get_type_from_annotation( + pub fn get_annotation_type( &self, type_annotation: &ast::Expression, symbol_table: &SymbolTable, @@ -410,18 +411,15 @@ impl<'a> TypeEvaluator<'a> { self.handle_union_type(union_parameters.to_vec()) } "Optional" => { - let inner_value = self.get_type_from_annotation( - &s.slice, - symbol_table, - scope_id, - ); + let inner_value = + self.get_annotation_type(&s.slice, symbol_table, scope_id); PythonType::Optional(Box::new(inner_value)) } _ => PythonType::Any, }; } let type_parameters = - vec![self.get_type_from_annotation(&s.slice, symbol_table, None)]; + vec![self.get_annotation_type(&s.slice, symbol_table, None)]; PythonType::Class(ClassType { details: class_type.details.clone(), type_parameters, @@ -479,12 +477,10 @@ impl<'a> TypeEvaluator<'a> { // Check if there's any import * and try to find the symbol in those files for star_import in symbol_table.star_imports.iter() { log::debug!("checking star imports {:?}", star_import); - for resolved in star_import.resolved_paths.iter() { - log::debug!("checking path {:?}", resolved); - let id = self.ids.get(resolved).expect("module not found"); - let star_import_sym_table = self.imported_symbol_tables.get(&id); + for id in star_import.resolved_ids.iter() { + let star_import_sym_table = self.imported_symbol_tables.get(id); let Some(sym_table) = star_import_sym_table else { - panic!("symbol table of star import not found at {:?}", resolved); + panic!("symbol table of star import not found at {:?}", id); }; let res = sym_table.lookup_in_scope(&lookup_request); match res { @@ -511,11 +507,8 @@ impl<'a> TypeEvaluator<'a> { let result = match decl { Declaration::Variable(v) => { if let Some(type_annotation) = &v.type_annotation { - let var_type = self.get_type_from_annotation( - type_annotation, - symbol_table, - Some(decl_scope), - ); + let var_type = + self.get_annotation_type(type_annotation, symbol_table, Some(decl_scope)); if type_annotation .as_name() @@ -556,15 +549,14 @@ impl<'a> TypeEvaluator<'a> { } } Declaration::Function(f) => { - let annotated_return_type = if let Some(ref type_annotation) = - f.function_node.returns - { - self.get_type_from_annotation(type_annotation, symbol_table, Some(decl_scope)) - } else { - // TODO: infer return type of function disabled because of recursive types - // self.infer_function_return_type(f) - PythonType::Any - }; + let annotated_return_type = + if let Some(ref type_annotation) = f.function_node.returns { + self.get_annotation_type(type_annotation, symbol_table, Some(decl_scope)) + } else { + // TODO: infer return type of function disabled because of recursive types + // self.infer_function_return_type(f) + PythonType::Any + }; let arguments = f.function_node.args.clone(); let name = f.function_node.name.clone(); @@ -577,11 +569,7 @@ impl<'a> TypeEvaluator<'a> { } Declaration::Parameter(p) => { if let Some(type_annotation) = &p.type_annotation { - Ok(self.get_type_from_annotation( - type_annotation, - symbol_table, - Some(decl_scope), - )) + Ok(self.get_annotation_type(type_annotation, symbol_table, Some(decl_scope))) } else { // TODO: Implement self and cls parameter types Ok(PythonType::Unknown) @@ -595,19 +583,41 @@ impl<'a> TypeEvaluator<'a> { match &a.symbol_name { Some(name) => { log::debug!("finding alias with name {name:?}"); - let import_result = &a.import_result; - for resolved_path in import_result.resolved_paths.iter() { - log::debug!("checking path {:?}", resolved_path); - let symbol_table_with_alias_def = { - let id = self.ids.get(resolved_path).unwrap(); - self.imported_symbol_tables.get(&id).unwrap() + let import_result = + a.import_result.clone().expect("import result not found"); + log::debug!("import result {:?}", import_result); + for id in import_result.resolved_ids.iter() { + log::debug!("checking path {:?}", id); + let Some(symbol_table_with_alias_def) = + self.imported_symbol_tables.get(id) + else { + panic!( + " symbol table id {:?} with not found in import {:?}", + id, import_result + ); + }; + + if let Some(symbol_table_file_name) = + symbol_table_with_alias_def.file_path.file_stem() + { + if symbol_table_file_name + .to_str() + .is_some_and(|s| s == name.as_str()) + { + return Ok(PythonType::Module(ModuleRef { + module_id: symbol_table_with_alias_def.id, + })); + } }; // sys/__init__.pyi imports sys itself don't know why // If the resolved path is same as current symbol file path // then it's cyclic and do not resolve - if symbol_table.file_path.as_path() == resolved_path { - log::debug!("alias resolution skipped"); + if symbol_table.id == *id { + log::debug!( + "alias resolution skipped the import {:?}", + import_result + ); continue; } @@ -619,18 +629,13 @@ impl<'a> TypeEvaluator<'a> { return self.get_symbol_type(current_symbol_lookup); }; - for star_import in symbol_table.star_imports.iter() { + for star_import in symbol_table_with_alias_def.star_imports.iter() { log::debug!("checking star imports {:?}", star_import); - for resolved in star_import.resolved_paths.iter() { - log::debug!("checking path {:?}", resolved); - let id = self.ids.get(resolved).expect("module not found"); - let star_import_sym_table = - self.imported_symbol_tables.get(&id); + for id in star_import.resolved_ids.iter() { + log::debug!("checking path {:?}", id); + let star_import_sym_table = self.imported_symbol_tables.get(id); let Some(sym_table) = star_import_sym_table else { - panic!( - "symbol table of star import not found at {:?}", - resolved - ); + panic!("symbol table of star import not found at {:?}", id); }; let res = sym_table.lookup_in_scope(lookup); match res { @@ -643,90 +648,17 @@ impl<'a> TypeEvaluator<'a> { } } - for (_, implicit_import) in import_result.implicit_imports.iter() { - let resolved_path = &implicit_import.path; - log::debug!("checking path {:?}", resolved_path); - let id = self.ids.get(resolved_path).unwrap(); - let Some(symbol_table_with_alias_def) = - self.imported_symbol_tables.get(&id) - else { - panic!("Symbol table not found for alias: {:?}", resolved_path); - }; - // let Some(symbol_table_with_alias_def) = self.get_symbol_table_of(resolved_path) else { - // panic!("Symbol table not found for alias: {:?}", resolved_path); - // }; - - let lookup_request = LookupSymbolRequest { name, scope: None }; - let find_in_current_symbol_table = - symbol_table_with_alias_def.lookup_in_scope(&lookup_request); - - if let Some(res) = find_in_current_symbol_table { - log::debug!("alias resolved to {:?}", res); - return self.get_symbol_type(res); - }; - - log::debug!( - "did not find symbol {} in symbol table, checking star imports", - lookup_request.name - ); - // Check if there's any import * and try to find the symbol in those files - for star_import in symbol_table.star_imports.iter() { - log::debug!("checking star imports {:?}", star_import); - for resolved in star_import.resolved_paths.iter() { - log::debug!("checking path {:?}", resolved); - let id = self.ids.get(resolved).unwrap(); - let star_import_sym_table = - self.imported_symbol_tables.get(&id); - let Some(sym_table) = star_import_sym_table else { - panic!( - "symbol table of star import not found at {:?}", - resolved - ); - }; - let res = sym_table.lookup_in_scope(&lookup_request); - match res { - Some(res) => { - log::debug!("alias resolved to {:?}", res); - return self.get_symbol_type(res); - } - None => continue, - }; - } - } - } - log::debug!("import not found checking if it's a module"); - let mut res = Ok(PythonType::Unknown); - for (_, implicit_import) in a.import_result.implicit_imports.iter() { - let resolved_path = &implicit_import.path; - let Some(file_stem) = resolved_path.file_stem() else { - continue; - }; - if name == file_stem.to_str().unwrap() { - res = Ok(PythonType::Module(crate::types::ModuleRef { - module_path: resolved_path.to_path_buf(), - })) - } - } - - res + Ok(PythonType::Unknown) } None => { - let mut found_module: Option = None; - for resolved_path in a.import_result.resolved_paths.iter() { - let id = self.ids.get(resolved_path).unwrap(); - let Some(sym_table_alias_pointing_to) = - self.imported_symbol_tables.get(&id) - else { - break; - }; - found_module = Some(PythonType::Module(crate::types::ModuleRef { - module_path: sym_table_alias_pointing_to.file_path.clone(), - })); - } - match found_module { - Some(s) => Ok(s), - None => Ok(PythonType::Unknown), - } + let Some(ref resolved_import) = a.import_result else { + return Ok(PythonType::Unknown); + }; + + let module_id = resolved_import.resolved_ids.first().unwrap(); + return Ok(PythonType::Module(ModuleRef { + module_id: *module_id, + })); } } } @@ -932,11 +864,7 @@ impl<'a> TypeEvaluator<'a> { return_type: f.function_node.returns.clone().map_or( PythonType::Unknown, |type_annotation| { - self.get_type_from_annotation( - &type_annotation, - bulitins_symbol_table, - None, - ) + self.get_annotation_type(&type_annotation, bulitins_symbol_table, None) }, ), })) @@ -988,7 +916,7 @@ impl<'a> TypeEvaluator<'a> { fn handle_union_type(&self, expressions: Vec) -> PythonType { let mut types = vec![]; for expr in expressions { - let t = self.get_type_from_annotation(&expr, &self.symbol_table, None); + let t = self.get_annotation_type(&expr, &self.symbol_table, None); if self.is_valid_union_parameter(&t) { types.push(t); } diff --git a/typechecker/src/types.rs b/typechecker/src/types.rs index 91ef9e68..8797dd33 100644 --- a/typechecker/src/types.rs +++ b/typechecker/src/types.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use enderpy_python_parser::ast; -use crate::symbol_table; +use crate::symbol_table::{self, Id}; #[derive(Debug, Clone, PartialEq, Eq, Is)] pub enum PythonType { @@ -194,7 +194,7 @@ impl Display for LiteralValue { #[derive(Debug, Clone, Eq, PartialEq)] pub struct ModuleRef { - pub module_path: PathBuf, + pub module_id: Id, } impl Display for PythonType { diff --git a/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_import_star.snap b/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_import_star.snap index e32cc306..2b2c2715 100644 --- a/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_import_star.snap +++ b/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_import_star.snap @@ -1,9 +1,10 @@ --- source: typechecker/src/build.rs -description: "from .b import *\nimport os\n\nprint(in_b)\n\nos.path.dirname(\"\")\n" +description: "from .b import *\nimport os\n\nprint(in_b)\n\nos.path\nos.path.dirname(\"\")\n" expression: result --- -[ImportResult { is_relative: true, is_import_found: true, is_partly_resolved: false, is_namespace_package: false, is_init_file_present: false, is_stub_package: false, import_type: Local, resolved_paths: ["test_data/inputs/import_star_test/b.py"], search_path: Some("test_data/inputs/import_star_test"), is_stub_file: false, is_native_lib: false, is_stdlib_typeshed_file: false, is_third_party_typeshed_file: false, is_local_typings_file: false, implicit_imports: ImplicitImports({}), filtered_implicit_imports: ImplicitImports({}), non_stub_import_result: None, py_typed_info: None, package_directory: None }] +Star imports: +[Id(REDACTED)] Symbols in global os - declaration: Alias - properties: SymbolFlags(0x0) - Declarations: diff --git a/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_imports.snap b/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_imports.snap index 9631f6f9..3cc15c19 100644 --- a/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_imports.snap +++ b/typechecker/test_data/output/enderpy_python_type_checker__build__tests__symbols_imports.snap @@ -3,7 +3,9 @@ source: typechecker/src/build.rs description: "import variables\nimport import_test\n\nfrom variables import a\nfrom variables import *\n\nimport os.path\n\nfrom os import *\n\nfrom os.path import join\n" expression: result --- -[ImportResult { is_relative: false, is_import_found: true, is_partly_resolved: false, is_namespace_package: false, is_init_file_present: false, is_stub_package: false, import_type: Local, resolved_paths: ["test_data[TYPESHED]/stdlib/os") }] +Star imports: +[Id(REDACTED)] +[Id(REDACTED), Id(REDACTED)] Symbols in global a - declaration: Alias - properties: SymbolFlags(0x0) - Declarations: diff --git a/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_generics.snap b/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_generics.snap index ecadc88d..be52b40c 100644 --- a/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_generics.snap +++ b/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_generics.snap @@ -655,15 +655,15 @@ Line 161: from collections.abc import Sized, Container Expr types in the line --->: collections.abc => Module - Sized => Unknown - Container => Unknown + Sized => (class) Sized + Container => (class) typing.Container[TypeVar[_T_co, ]] --- Line 164: class LinkedList(Sized, Generic[T]): ... Expr types in the line --->: class LinkedList => (class) basic_generics.LinkedList[TypeVar[T, ]] - Sized => Unknown + Sized => (class) Sized Generic => (class) Generic Generic[T] => (class) Generic T => TypeVar[T, ] @@ -681,8 +681,8 @@ Expr types in the line --->: K => TypeVar[K, ] K, V] => (class) builtins.tuple[Unknown] V => TypeVar[V, ] - Container => Unknown - Container[tuple[K, V]] => Unknown + Container => (class) typing.Container[TypeVar[_T_co, ]] + Container[tuple[K, V]] => (class) typing.Container[TypeVar[_T_co, ]] tuple => (class) builtins.tuple[TypeVar[_T_co, ]] tuple[K, V] => (class) builtins.tuple[TypeVar[_T_co, ]] K => TypeVar[K, ] diff --git a/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_types.snap b/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_types.snap index ef0629e3..097d9ec1 100644 --- a/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_types.snap +++ b/typechecker/test_data/output/enderpy_python_type_checker__checker__tests__basic_types.snap @@ -8,7 +8,7 @@ Line 1: from typing import Dict, Set, List Expr types in the line --->: typing => Module Dict => (class) builtins.dict[TypeVar[_KT, ], TypeVar[_VT, ]] - Set => (class) set + Set => (class) builtins.set[TypeVar[_T, ]] List => (class) builtins.list[TypeVar[_T, ]] --- @@ -155,8 +155,8 @@ Line 28: def get_attr(self) -> Set[int]: Expr types in the line --->: def get_attr => (function) get_attr self => Unknown - Set => (class) set - Set[int] => (class) set + Set => (class) builtins.set[TypeVar[_T, ]] + Set[int] => (class) builtins.set[TypeVar[_T, ]] int => (class) int ---