diff --git a/typechecker/src/semantic_analyzer.rs b/typechecker/src/semantic_analyzer.rs index 511252de..43f2550b 100644 --- a/typechecker/src/semantic_analyzer.rs +++ b/typechecker/src/semantic_analyzer.rs @@ -503,9 +503,37 @@ impl TraversalVisitor for SemanticAnalyzer { ); } let mut methods = vec![]; + let mut attributes = HashMap::new(); for stmt in &c.body { match stmt { parser::ast::Statement::FunctionDef(f) => { + if f.name == "__init__" { + for stmt in &f.body { + match stmt { + parser::ast::Statement::AssignStatement(assign) => { + for target in &assign.targets { + match target { + parser::ast::Expression::Attribute(attr) => { + match *attr.value { + parser::ast::Expression::Name(ref n) => { + if n.id == "self" { + attributes.insert( + attr.attr.clone(), + assign.value.clone(), + ); + } + } + _ => (), + } + } + _ => (), + } + } + } + _ => (), + } + } + } methods.push(f.name.clone()); } _ => (), @@ -517,6 +545,7 @@ impl TraversalVisitor for SemanticAnalyzer { let class_declaration = Declaration::Class(Class { declaration_path, + attributes, methods, }); self.create_symbol(c.name.clone(), class_declaration); diff --git a/typechecker/src/symbol_table.rs b/typechecker/src/symbol_table.rs index 643cc0b5..71f2cdcc 100644 --- a/typechecker/src/symbol_table.rs +++ b/typechecker/src/symbol_table.rs @@ -115,6 +115,10 @@ pub struct Class { // Method names, can be used to look up the function in the symbol table // of the class pub methods: Vec, + // instance attibutes that are defined in the __init__ method + // if the attribute is referencing another symbol we need to look up that symbol in the + // __init__ method + pub attributes: HashMap, } #[derive(Debug, Clone)] diff --git a/typechecker/test_data/inputs/symbol_table/class_definition.py b/typechecker/test_data/inputs/symbol_table/class_definition.py index 7b12b386..e50e5d29 100644 --- a/typechecker/test_data/inputs/symbol_table/class_definition.py +++ b/typechecker/test_data/inputs/symbol_table/class_definition.py @@ -1,3 +1,5 @@ class c: def __init__(self): a = 1 + b = a + self.c = b diff --git a/typechecker/testdata/output/enderpy_python_type_checker__build__tests__symbol_table@class_definition.py.snap b/typechecker/testdata/output/enderpy_python_type_checker__build__tests__symbol_table@class_definition.py.snap index cf0957f3..b646eb17 100644 --- a/typechecker/testdata/output/enderpy_python_type_checker__build__tests__symbol_table@class_definition.py.snap +++ b/typechecker/testdata/output/enderpy_python_type_checker__build__tests__symbol_table@class_definition.py.snap @@ -1,6 +1,6 @@ --- source: typechecker/src/build.rs -description: "class c:\n def __init__(self):\n a = 1\n" +description: "class c:\n def __init__(self):\n a = 1\n b = a\n self.c = b\n" expression: result input_file: typechecker/test_data/inputs/symbol_table/class_definition.py --- @@ -14,12 +14,23 @@ c module_name: [REDACTED]", node: Node { start: 0, - end: 47, + end: 80, }, }, methods: [ "__init__", ], + attributes: { + "c": Name( + Name { + node: Node { + start: 78, + end: 79, + }, + id: "b", + }, + ), + }, } all scopes: @@ -49,6 +60,31 @@ a ), is_constant: false, } +b +- Declarations: +--: Variable { + declaration_path: DeclarationPath { + module_name: [REDACTED]", + node: Node { + start: 55, + end: 60, + }, + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Name( + Name { + node: Node { + start: 59, + end: 60, + }, + id: "a", + }, + ), + ), + is_constant: false, +} self - Declarations: --: Paramter { @@ -79,13 +115,13 @@ __init__ module_name: [REDACTED]", node: Node { start: 13, - end: 47, + end: 80, }, }, function_node: FunctionDef { node: Node { start: 13, - end: 47, + end: 80, }, name: "__init__", args: Arguments { @@ -139,6 +175,71 @@ __init__ ), }, ), + AssignStatement( + Assign { + node: Node { + start: 55, + end: 60, + }, + targets: [ + Name( + Name { + node: Node { + start: 55, + end: 56, + }, + id: "b", + }, + ), + ], + value: Name( + Name { + node: Node { + start: 59, + end: 60, + }, + id: "a", + }, + ), + }, + ), + AssignStatement( + Assign { + node: Node { + start: 69, + end: 79, + }, + targets: [ + Attribute( + Attribute { + node: Node { + start: 69, + end: 75, + }, + value: Name( + Name { + node: Node { + start: 69, + end: 73, + }, + id: "self", + }, + ), + attr: "c", + }, + ), + ], + value: Name( + Name { + node: Node { + start: 78, + end: 79, + }, + id: "b", + }, + ), + }, + ), ], decorator_list: [], returns: None,