Skip to content

Commit

Permalink
fix(type): resolve builtin impor when user file has no imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Glyphack committed Apr 1, 2024
1 parent 522f2e8 commit 68fe518
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 91 deletions.
34 changes: 14 additions & 20 deletions typechecker/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,19 @@ impl BuildManager {
log::debug!("Initialized build manager");
log::debug!("build sources: {:?}", sources);
log::debug!("options: {:?}", options);
let builtins_file = options
.import_discovery
.typeshed_path
.clone()
.unwrap()
.join("stdlib/builtins.pyi");
let builtins = BuildSource::from_path(builtins_file, true).expect("cannot read builtins");
let mut sources_with_builtins = vec![builtins.clone()];
sources_with_builtins.extend(sources.clone());

BuildManager {
errors: vec![],
build_sources: sources,
build_sources: sources_with_builtins,
modules,
options,
diagnostics: HashMap::new(),
Expand Down Expand Up @@ -79,21 +88,6 @@ impl BuildManager {
let state: EnderpyFile = build_source.into();
self.modules.insert(state.module_name(), state);
}
let builtins_file = self
.options
.import_discovery
.typeshed_path
.clone()
.unwrap()
.join("stdlib/builtins.pyi");
let builtins = BuildSource::from_path(builtins_file, true);
match builtins {
Ok(b) => {
let file: EnderpyFile = b.into();
self.modules.insert(file.module_name(), file)
}
Err(e) => panic!("error loading builtins file: {}", e),
};
let (new_files, imports) = match self.options.follow_imports {
crate::settings::FollowImports::All => {
self.gather_files(self.build_sources.clone(), true)
Expand All @@ -108,7 +102,7 @@ impl BuildManager {
}
for module in self.modules.values_mut() {
info!("file: {:#?}", module.module_name());
module.populate_symbol_table(&imports);
module.populate_symbol_table(imports.clone());
}
}

Expand Down Expand Up @@ -388,15 +382,15 @@ impl BuildManager {
continue;
}
log::debug!(
"resolved import: {} -> {:?}",
"{:?} resolved import: {} -> {:?}",
file.path(),
import_desc.name(),
resolved.resolved_paths
);
imports.insert(import_desc, resolved.clone());
}
}

imports.clone()
imports
}
}

Expand Down
1 change: 0 additions & 1 deletion typechecker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod ast_visitor;
mod ast_visitor_generic;
mod nodes;
mod ruff_python_import_resolver;
mod semanal_utils;
mod symbol_table;
mod type_check;

Expand Down
2 changes: 1 addition & 1 deletion typechecker/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl EnderpyFile {
/// entry point to fill up the symbol table from the global definitions
pub fn populate_symbol_table(
&mut self,
imports: &HashMap<ImportModuleDescriptor, ImportResult>,
imports: HashMap<ImportModuleDescriptor, ImportResult>,
) {
let mut sem_anal = SemanticAnalyzer::new(self.clone(), imports.clone());
for stmt in &self.body {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ impl From<&Alias> for ImportModuleDescriptor {
leading_dots: 0,
name_parts: alias
.name
.chars()
.skip_while(|c| *c == '.')
.collect::<String>()
.split('.')
.map(|s| s.trim())
.map(std::string::ToString::to_string)
.collect(),
imported_symbols: vec![],
Expand All @@ -44,10 +42,8 @@ impl From<&ImportFrom> for ImportModuleDescriptor {
leading_dots: import_from.level,
name_parts: import_from
.module
.chars()
.skip_while(|c| *c == '.')
.collect::<String>()
.split('.')
.map(|s| s.trim())
.map(std::string::ToString::to_string)
.collect(),
imported_symbols: import_from.names.iter().map(|x| x.name.clone()).collect(),
Expand Down
1 change: 0 additions & 1 deletion typechecker/src/semanal_utils.rs

This file was deleted.

17 changes: 11 additions & 6 deletions typechecker/src/semantic_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;

use enderpy_python_parser as parser;
use enderpy_python_parser::ast::Expression;
use log::info;
use parser::ast::{FunctionDef, GetNode, Name, Statement};

use crate::{
Expand Down Expand Up @@ -37,7 +38,11 @@ pub struct SemanticAnalyzer {
#[allow(unused)]
impl SemanticAnalyzer {
pub fn new(file: EnderpyFile, imports: HashMap<ImportModuleDescriptor, ImportResult>) -> Self {
log::debug!("Creating semantic analyzer for {}", file.module_name());
log::debug!(
"Creating semantic analyzer for {} with import count {}",
file.module_name(),
imports.len()
);
let symbols = SymbolTable::new(file.module_name(), file.path());
let is_pyi = file.path().ends_with(".pyi");
SemanticAnalyzer {
Expand Down Expand Up @@ -394,22 +399,22 @@ impl TraversalVisitor for SemanticAnalyzer {
}

fn visit_import_from(&mut self, _i: &parser::ast::ImportFrom) {
let module_import_result = match self.imports.get(&ImportModuleDescriptor::from(_i)) {
Some(result) => result.clone(),
None => ImportResult::not_found(),
};
for alias in &_i.names {
let declaration_path = DeclarationPath::new(
self.file.path(),
alias.node,
self.symbol_table.current_scope_id,
);
let module_import_result = match self.imports.get(&ImportModuleDescriptor::from(_i)) {
Some(result) => result.clone(),
None => ImportResult::not_found(),
};
let declaration = Declaration::Alias(Alias {
declaration_path,
import_from_node: Some(_i.clone()),
import_node: None,
symbol_name: Some(alias.name()),
import_result: module_import_result,
import_result: module_import_result.clone(),
});

let flags = SymbolFlags::empty();
Expand Down
19 changes: 9 additions & 10 deletions typechecker/src/symbol_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,17 @@ impl Class {

/// Class node refers to SpecialForm in typeshed
/// TODO: needs improvements mostly set the correct values
pub fn new_special(name: String) -> Self {
pub fn new_special(name: String, declaration_path: DeclarationPath) -> Self {
Class {
name,
declaration_path: DeclarationPath::new(
PathBuf::from("../../typeshed/stdlib/builtins.pyi"),
Node::new(0, 0),
0,
),
declaration_path,
methods: vec![],
special: true,
class_node: None,
}
}

pub fn get_qualname(&self) -> String {
log::debug!("Getting qualname for class: {}", self.name);
let scope = self.declaration_path.module_name.clone();
let mut qualname = scope
.file_stem()
Expand All @@ -259,7 +254,6 @@ impl Class {
.to_string();
qualname.push('.');
qualname.push_str(&self.name);
log::debug!("Qualname: {}", qualname);
qualname
}
}
Expand Down Expand Up @@ -473,7 +467,7 @@ impl SymbolTable {
}

pub fn add_symbol(&mut self, mut symbol_node: SymbolTableNode) {
log::debug!("Adding symbol: {}", symbol_node);
let file = self.file_path.clone();
let scope = if symbol_node.flags.contains(SymbolFlags::CLASS_MEMBER)
|| symbol_node.flags.contains(SymbolFlags::INSTANCE_MEMBER)
{
Expand All @@ -490,7 +484,12 @@ impl SymbolTable {
self.current_scope_mut()
};

log::debug!("Adding symbol {} to scope: {}", symbol_node, scope.name);
log::debug!(
"Adding symbol {:?} to scope: {} in file {:?}",
symbol_node,
scope.name,
file
);
if let Some(existing_symbol) = scope.symbols.get(&symbol_node.name) {
symbol_node
.declarations
Expand Down
96 changes: 50 additions & 46 deletions typechecker/src/type_check/type_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ impl TypeEvaluator {
.as_name()
.is_some_and(|name| name.id == SPECIAL_FORM)
{
let class_symbol = Class::new_special(symbol.name.to_string());
let class_symbol =
Class::new_special(symbol.name.to_string(), v.declaration_path.clone());
Ok(PythonType::Class(ClassType::new(class_symbol, vec![])))
} else {
Ok(var_type)
Expand Down Expand Up @@ -585,49 +586,49 @@ impl TypeEvaluator {
let mut class_def_type_parameters = vec![];
for base in bases {
let base_type = self.get_type(&base, Some(symbol_table), None);
// TODO: stack overflow here
log::debug!("base is {:?} base type: {:?}", base, base_type);
if let Ok(PythonType::Class(c)) = base_type {
log::debug!("qualname: {}", c.details.get_qualname());
if c.details.get_qualname() == "builtins.Generic"
|| c.details.get_qualname() == "typing.Protocol"
{
let Some(type_parameter) = base.as_subscript() else {
bail!("Generic class must have subscript");
};
log::debug!("base class is generic: {:?}", class_def_type_parameters);
match *type_parameter.slice {
ast::Expression::Name(ref type_parameter_name) => {
let type_parameter_type = self
.infer_type_from_symbol_table(
&type_parameter_name.id,
Some(type_parameter_name.node.start),
symbol_table,
Some(decl_scope),
)
.context(
"Getting type of the type parameter {type_parameter_name}",
)?;
class_def_type_parameters.push(type_parameter_type);
}
ast::Expression::Tuple(ref type_parameters) => {
for type_parameter in type_parameters.elements.iter() {
let type_parameter_type = self
.get_type(type_parameter, None, Some(decl_scope))
.context("Getting type of the type parameter")?;
if class_def_type_parameters.contains(&type_parameter_type) {
bail!("Duplicate type parameter");
}
class_def_type_parameters.push(type_parameter_type);
}
}
_ => bail!("Type parameter must be a name"),
};
}

c.type_parameters.iter().for_each(|t| {
class_def_type_parameters.push(t.clone());
});
}
// if let Ok(PythonType::Class(c)) = base_type {
// let Some(type_parameter) = base.as_subscript() else {
// continue;
// };
// log::debug!("base class is generic: {:?}", class_def_type_parameters);
// match type_parameter.slice.as_ref() {
// ast::Expression::Name(ref type_parameter_name) => {
// let type_parameter_type = self
// .infer_type_from_symbol_table(
// &type_parameter_name.id,
// Some(type_parameter_name.node.start),
// symbol_table,
// Some(decl_scope),
// )
// .context(
// "Getting type of the type parameter {type_parameter_name}",
// )?;
// if !class_def_type_parameters.contains(&type_parameter_type) {
// class_def_type_parameters.push(type_parameter_type);
// }
// }
// ast::Expression::Tuple(ref type_parameters) => {
// let mut tuple_type_parameters = vec![];
// for type_parameter in type_parameters.elements.iter() {
// let type_parameter_type = self
// .get_type(type_parameter, None, Some(decl_scope))
// .context("Getting type of the type parameter")?;
// if tuple_type_parameters.contains(&type_parameter_type) {
// bail!("Duplicate type parameter");
// }
// tuple_type_parameters.push(type_parameter_type);
// }
// for type_parameter in tuple_type_parameters {
// if !class_def_type_parameters.contains(&type_parameter) {
// class_def_type_parameters.push(type_parameter);
// }
// }
// }
// _ => bail!("Type parameter must be a name"),
// };
// }
}
Ok(PythonType::Class(ClassType::new(
c.clone(),
Expand Down Expand Up @@ -911,7 +912,7 @@ impl TypeEvaluator {
LiteralValue::Str(value)
}
Expression::Subscript(s) => {
match *s.value.clone() {
match s.value.as_ref() {
Expression::Name(n) => {
if !self.is_literal(n.id.clone()) {
panic!("{}", LITERAL_TYPE_PARAMETER_MSG)
Expand Down Expand Up @@ -1033,7 +1034,10 @@ impl TypeEvaluator {
let class = self
.get_class_declaration(annotation, found_in_symbol_table)?;
if class.name == SPECIAL_FORM {
return Some(Class::new_special(n.id.clone()));
return Some(Class::new_special(
n.id.clone(),
v.declaration_path.clone(),
));
}
Some(class)
}
Expand Down Expand Up @@ -1077,7 +1081,7 @@ impl TypeEvaluator {

let resolved_path = match a.import_result.resolved_paths.last() {
Some(path) => path,
None => panic!("Alias {:?} has no resolved path", a.import_node),
None => panic!("Alias {:?} has no resolved path", a),
};

// TODO: This is a hack to resolve Iterator alias in sys/__init__.pyi
Expand Down

0 comments on commit 68fe518

Please sign in to comment.