Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add known value type #203

Merged
merged 7 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions typechecker/src/symbol_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,16 @@ pub struct Function {
impl Function {
pub fn is_abstract(&self) -> bool {
if !self.is_method {
return false
return false;
}
for decorator in self.function_node.decorator_list.iter() {
match &decorator {
ast::Expression::Name(n) => {
if &n.id == "abstractmethod" {
return true
}
return true;
}
}
_ => {},
_ => {}
}
}
false
Expand Down
4 changes: 4 additions & 0 deletions typechecker/src/type_check/test_data/inputs/literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Literal


a: Literal["foo"] = "foo"
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
---
source: typechecker/src/type_check/type_evaluator.rs
description: "from typing import Literal\n\n\na: Literal[\"foo\"] = \"foo\"\n"
expression: result
input_file: typechecker/src/type_check/test_data/inputs/literal.py
---
[
(
"(line: 3, character: 20):(line: 3, character: 25)",
Str,
),
(
"(line: 3, character: 3):(line: 3, character: 17)",
KnownValue(
KnownValue {
literal_value: Str(
"foo",
),
},
),
),
]
56 changes: 37 additions & 19 deletions typechecker/src/type_check/type_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use crate::{
nodes::EnderpyFile,
state::State,
symbol_table::{Declaration, LookupSymbolRequest, SymbolTable, SymbolTableNode},
type_check::{
type_inference::get_type_from_annotation,
},
};

use super::{
Expand Down Expand Up @@ -59,7 +62,9 @@ impl TypeEvaluator {
log::debug!("get_type: {:?}", expr);
match expr {
ast::Expression::Constant(c) => {
let typ = match c.value {
let typ = match &c.value {
// We should consider constants are not literals unless they are explicitly declared as such
// https://peps.python.org/pep-0586/#type-inference
ast::ConstantValue::Int(_) => PythonType::Int,
ast::ConstantValue::Float(_) => PythonType::Float,
ast::ConstantValue::Str(_) => PythonType::Str,
Expand Down Expand Up @@ -219,13 +224,14 @@ impl TypeEvaluator {
}
}
Declaration::Function(f) => {
let annotated_return_type = if let Some(type_annotation) = f.function_node.returns.clone() {
type_inference::get_type_from_annotation(&type_annotation)
} else {
let inferred_return_type = self.infer_function_return_type(f);
log::debug!("infered_return_type: {:?}", inferred_return_type);
inferred_return_type
};
let annotated_return_type =
if let Some(type_annotation) = f.function_node.returns.clone() {
type_inference::get_type_from_annotation(&type_annotation)
} else {
let inferred_return_type = self.infer_function_return_type(f);
log::debug!("infered_return_type: {:?}", inferred_return_type);
inferred_return_type
};

let arguments = f.function_node.args.clone();
let name = f.function_node.name.clone();
Expand Down Expand Up @@ -305,7 +311,7 @@ impl TypeEvaluator {
}
}
if f.return_statements.is_empty() {
return PythonType::None;
PythonType::None
} else {
let mut return_types = vec![];
for return_statement in &f.return_statements {
Expand All @@ -314,10 +320,10 @@ impl TypeEvaluator {
}
}
if return_types.len() == 1 {
return return_types[0].clone();
return_types[0].clone()
} else {
// TODO: Union type
return PythonType::Unknown;
PythonType::Unknown
}
}
}
Expand Down Expand Up @@ -682,10 +688,16 @@ impl TypeEvalVisitor {

/// This function is called on every expression in the ast
pub fn save_type(&mut self, expr: &ast::Expression) {
let typ = self
.type_eval
.get_type(expr)
.unwrap_or(PythonType::Unknown);
let typ = self.type_eval.get_type(expr).unwrap_or(PythonType::Unknown);
log::debug!("save_type: {:?} => {:?}", expr, typ);
let start_pos = self.enderpy_file().get_position(expr.get_node().start);
let end_pos = self.enderpy_file().get_position(expr.get_node().end);
self.types.insert(format!("{}:{}", start_pos, end_pos), typ);
}

// TODO: move type annotation tests to its own file
pub fn save_type_annotation(&mut self, expr: &ast::Expression) {
let typ = get_type_from_annotation(expr);
log::debug!("save_type: {:?} => {:?}", expr, typ);
let start_pos = self.enderpy_file().get_position(expr.get_node().start);
let end_pos = self.enderpy_file().get_position(expr.get_node().end);
Expand All @@ -706,12 +718,18 @@ impl TraversalVisitor for TypeEvalVisitor {
// map all statements and call visit
match s {
ast::Statement::ExpressionStatement(e) => self.visit_expr(e),
ast::Statement::Import(i) => {},
ast::Statement::ImportFrom(i) => {},
ast::Statement::Import(i) => {}
ast::Statement::ImportFrom(i) => {}
ast::Statement::AssignStatement(a) => {
self.save_type(&a.value);
}
ast::Statement::AnnAssignStatement(a) => self.visit_ann_assign(a),
ast::Statement::AnnAssignStatement(a) => {
match a.value.as_ref() {
Some(v) => self.save_type(v),
None => {}
}
self.save_type_annotation(&a.annotation)
}
ast::Statement::AugAssignStatement(a) => self.visit_aug_assign(a),
ast::Statement::Assert(a) => self.visit_assert(a),
ast::Statement::Pass(p) => self.visit_pass(p),
Expand Down Expand Up @@ -744,7 +762,7 @@ impl TraversalVisitor for TypeEvalVisitor {
for stmt in &f.body {
self.visit_stmt(stmt);
}
},
}
ast::Statement::ClassDef(c) => self.visit_class_def(c),
ast::Statement::Match(m) => self.visit_match(m),
Statement::AsyncForStatement(f) => self.visit_async_for(f),
Expand Down
132 changes: 119 additions & 13 deletions typechecker/src/type_check/type_inference.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
#![allow(dead_code)]
#![allow(unused_variables)]
use core::panic;

use enderpy_python_parser::ast::{self, BinaryOperator, Expression};
/// This module is resonsible for ineferring type from annotations or python expressions.
use enderpy_python_parser::ast::{self, BinaryOperator, Expression, Subscript};

use super::{builtins, types::PythonType};
use super::{
builtins,
types::{LiteralValue, PythonType},
};

const LITERAL_TYPE_PARAMETER_MSG: &str = "Type arguments for 'Literal' must be None, a literal value (int, bool, str, or bytes), or an enum value";

pub fn get_type_from_annotation(type_annotation: &ast::Expression) -> PythonType {
log::debug!("Getting type from annotation: {:?}", type_annotation);
let expr_type = match type_annotation {
ast::Expression::Name(name) => match name.id.as_str() {
"int" => PythonType::Int,
Expand All @@ -15,16 +23,8 @@ pub fn get_type_from_annotation(type_annotation: &ast::Expression) -> PythonType
"None" => PythonType::None,
_ => PythonType::Unknown,
},
Expression::Constant(c) => match c.value.clone() {
ast::ConstantValue::Int(_) => PythonType::Int,
ast::ConstantValue::Float(_) => PythonType::Float,
ast::ConstantValue::Str(_) => PythonType::Str,
ast::ConstantValue::Bool(_) => PythonType::Bool,
ast::ConstantValue::None | ast::ConstantValue::Ellipsis => PythonType::None,
ast::ConstantValue::Bytes(_) => todo!(),
ast::ConstantValue::Tuple(_) => todo!(),
ast::ConstantValue::Complex { real, imaginary } => todo!(),
},
// Illegal type annotation
Expression::Constant(c) => PythonType::Unknown,
Expression::Subscript(s) => {
// This is a generic type
let name = match *s.value.clone() {
Expand All @@ -33,7 +33,13 @@ pub fn get_type_from_annotation(type_annotation: &ast::Expression) -> PythonType
Expression::Tuple(_) => todo!(),
Expression::Dict(_) => todo!(),
Expression::Set(_) => todo!(),
Expression::Name(n) => get_builtin_type(n.id),
Expression::Name(n) => {
// TODO: handle builtins with enum
if is_literal(n.id.clone()) {
return handle_literal_type(s);
}
get_builtin_type(n.id)
}
Expression::BoolOp(_) => todo!(),
Expression::UnaryOp(_) => todo!(),
Expression::BinOp(_) => todo!(),
Expand Down Expand Up @@ -69,6 +75,99 @@ pub fn get_type_from_annotation(type_annotation: &ast::Expression) -> PythonType
expr_type
}

fn handle_literal_type(s: &Subscript) -> PythonType {
// Only simple parameters are allowed for literal type:
// https://peps.python.org/pep-0586/#legal-and-illegal-parameterizations
let value = get_literal_value_from_param(&s.slice.clone());
if value.len() > 1 {
todo!("MultiValue literal type is not supported yet")
}

PythonType::KnownValue(super::types::KnownValue {
literal_value: value.last().unwrap().clone(),
})
}

/// Write a function that takes in an expression which is a parameter to a literal type and returns
/// the LiteralValue of the parameter.
/// Literal values might contain a tuple, that's why the return type is a vector.
pub fn get_literal_value_from_param(expr: &Expression) -> Vec<LiteralValue> {
log::debug!("Getting literal value from param: {:?}", expr);
let val = match expr {
Expression::Constant(c) => {
match c.value.clone() {
ast::ConstantValue::Bool(b) => LiteralValue::Bool(b),
ast::ConstantValue::Int(i) => LiteralValue::Int(i),
ast::ConstantValue::Float(f) => LiteralValue::Float(f),
ast::ConstantValue::Str(s) => LiteralValue::Str(s),
ast::ConstantValue::Bytes(b) => LiteralValue::Bytes(b),
ast::ConstantValue::None => LiteralValue::None,
// Tuple is illegal if it has parantheses, otherwise it's allowed and the output a multiValued type
// Currently even mypy does not supoort this, who am I to do it?
// https://mypy-play.net/?mypy=latest&python=3.10&gist=0df0421d5c85f3b75f65a51cae8616ce
ast::ConstantValue::Tuple(t) => {
if t.len() == 1 {
match t[0].value.clone() {
ast::ConstantValue::Bool(b) => LiteralValue::Bool(b),
ast::ConstantValue::Int(i) => LiteralValue::Int(i),
ast::ConstantValue::Float(f) => LiteralValue::Float(f),
ast::ConstantValue::Str(s) => LiteralValue::Str(s),
ast::ConstantValue::Bytes(b) => LiteralValue::Bytes(b),
ast::ConstantValue::None => LiteralValue::None,
_ => panic!("Tuple type with illegal parameter"),
}
} else {
let literal_values = t
.iter()
.map(|c| match c.value.clone() {
ast::ConstantValue::Bool(b) => LiteralValue::Bool(b),
ast::ConstantValue::Int(i) => LiteralValue::Int(i),
ast::ConstantValue::Float(f) => LiteralValue::Float(f),
ast::ConstantValue::Str(s) => LiteralValue::Str(s),
ast::ConstantValue::Bytes(b) => LiteralValue::Bytes(b),
ast::ConstantValue::None => LiteralValue::None,
_ => panic!("Tuple type with illegal parameter"),
})
.collect();
return literal_values;
}
}
// Illegal parameter
ast::ConstantValue::Ellipsis => {
panic!("Literal type with ellipsis value is not supported")
}
ast::ConstantValue::Complex { real, imaginary } => {
panic!("Literal type with complex value is not supported")
}
}
}
// Only can be enum values
Expression::Attribute(a) => {
let value = match *a.value.clone() {
Expression::Name(n) => n.id,
_ => panic!("Literal type with attribute value can only be a name"),
};
LiteralValue::Str(value)
}
Expression::Subscript(s) => {
match *s.value.clone() {
Expression::Name(n) => {
if !is_literal(n.id.clone()) {
panic!("{}", LITERAL_TYPE_PARAMETER_MSG)
}
// When there is a literal inside a literal we flatten it
return get_literal_value_from_param(&s.slice);
}
_ => panic!("{}", LITERAL_TYPE_PARAMETER_MSG),
};
}
// Illegal parameter
_ => panic!("Literal type with illegal parameter, can only be a constant value or enum"),
};

vec![val]
}

pub fn type_equal(t1: &PythonType, t2: &PythonType) -> bool {
match (t1, t2) {
(PythonType::Int, PythonType::Int) => true,
Expand Down Expand Up @@ -180,6 +279,13 @@ pub fn get_builtin_type(name: String) -> String {
}
}

pub fn is_literal(name: String) -> bool {
match name.as_str() {
"Literal" => true,
_ => false,
}
}

pub fn is_subscriptable(t: &PythonType) -> bool {
match t {
PythonType::Class(c) => match c.name.as_str() {
Expand Down
Loading