Skip to content

Commit

Permalink
Implement limited type inference for generic params in fn calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bfbachmann committed Dec 30, 2024
1 parent 5494654 commit d0cac1b
Show file tree
Hide file tree
Showing 26 changed files with 769 additions and 300 deletions.
8 changes: 4 additions & 4 deletions src/analyzer/ast/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ impl AArrayType {

// At this point we know both arrays are non-empty and have some element type key, so
// make sure the type keys match.
let elem_type1 = ctx.must_get_type(self.maybe_element_type_key.unwrap());
let elem_type2 = ctx.must_get_type(other.maybe_element_type_key.unwrap());
let elem_type1 = ctx.get_type(self.maybe_element_type_key.unwrap());
let elem_type2 = ctx.get_type(other.maybe_element_type_key.unwrap());
elem_type1.is_same_as(ctx, elem_type2, false)
}
}
Expand Down Expand Up @@ -168,14 +168,14 @@ impl AArrayInit {
// Make sure all the values are of the same type.
let maybe_element_type_key = if !contained_values.is_empty() {
let expected_type_key = contained_values.first().unwrap().type_key;
let expected_type = ctx.must_get_type(expected_type_key);
let expected_type = ctx.get_type(expected_type_key);

for value in &contained_values {
if value.type_key == expected_type_key {
continue;
}

let value_type = ctx.must_get_type(value.type_key);
let value_type = ctx.get_type(value.type_key);
if !value_type.is_same_as(ctx, expected_type, false) {
ctx.insert_err(AnalyzeError::new(
ErrorKind::MismatchedTypes,
Expand Down
2 changes: 1 addition & 1 deletion src/analyzer/ast/const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl AConst {
));

// Just return a dummy value if the expression already failed analysis.
if ctx.must_get_type(value.type_key).is_unknown() {
if ctx.get_type(value.type_key).is_unknown() {
return AConst::new_zero_value(ctx, const_decl.name.as_str());
}

Expand Down
2 changes: 1 addition & 1 deletion src/analyzer/ast/enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl AEnumVariantInit {
pub fn from(ctx: &mut ProgramContext, enum_init: &EnumVariantInit) -> Self {
// Make sure the enum type exists.
let enum_type_key = ctx.resolve_type(&Type::Unresolved(enum_init.typ.clone()));
let enum_type = match ctx.must_get_type(enum_type_key) {
let enum_type = match ctx.get_type(enum_type_key) {
AType::Unknown(_) => {
// The enum type has already failed semantic analysis, so we should avoid
// analyzing its initialization and just return some zero-value placeholder instead.
Expand Down
85 changes: 51 additions & 34 deletions src/analyzer/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::parser::ast::array::ArrayInit;
use crate::parser::ast::expr::Expression;
use crate::parser::ast::from::From;
use crate::parser::ast::func::Function;
use crate::parser::ast::func_call::FuncCall;
use crate::parser::ast::func_call::FnCall;
use crate::parser::ast::op::Operator;
use crate::parser::ast::r#type::Type;
use crate::parser::ast::symbol::Symbol;
Expand Down Expand Up @@ -386,6 +386,7 @@ impl AExpr {
expr,
maybe_expected_type_key,
allow_type,
false,
ignore_mutability,
false,
)
Expand All @@ -399,6 +400,7 @@ impl AExpr {
expr: Expression,
maybe_expected_type_key: Option<TypeKey>,
allow_type: bool,
allow_polymorph: bool,
ignore_mutability: bool,
prefer_method: bool,
) -> AExpr {
Expand All @@ -408,6 +410,7 @@ impl AExpr {
maybe_expected_type_key,
allow_type,
prefer_method,
allow_polymorph,
Span {
start_pos: expr.start_pos().clone(),
end_pos: expr.end_pos().clone(),
Expand Down Expand Up @@ -446,8 +449,8 @@ impl AExpr {

// Skip the type check if either type is unknown, as this implies that semantic analysis
// has already failed somewhere else in this expression or wherever it's being used.
let actual_type = ctx.must_get_type(self.type_key);
let expected_type = ctx.must_get_type(expected_tk);
let actual_type = ctx.get_type(self.type_key);
let expected_type = ctx.get_type(expected_tk);
if actual_type.is_unknown() || expected_type.is_unknown() {
return self;
}
Expand Down Expand Up @@ -495,7 +498,7 @@ impl AExpr {
.as_str(),
),
);
} else if !scoped_symbol.is_mut && !ctx.must_get_type(scoped_symbol.type_key).is_mut_ptr() {
} else if !scoped_symbol.is_mut && !ctx.get_type(scoped_symbol.type_key).is_mut_ptr() {
ctx.insert_err(
AnalyzeError::new(
ErrorKind::InvalidMutRef,
Expand All @@ -510,15 +513,15 @@ impl AExpr {
/// Tries to coerce this expression to the target type. If coercion is successful, returns
/// the coerced expression, otherwise just returns the expression as-is.
pub fn try_coerce_to(mut self, ctx: &mut ProgramContext, target_type_key: TypeKey) -> Self {
let target_type = ctx.must_get_type(target_type_key);
let target_type = ctx.get_type(target_type_key);
if target_type.is_unknown() {
return self;
}

// If both types are pointers, make sure their pointee types are the same
// and immutability is not violated. If so, we can allow coercion.
if let (AType::Pointer(self_ptr_type), AType::Pointer(other_ptr_type)) =
(ctx.must_get_type(self.type_key), target_type)
(ctx.get_type(self.type_key), target_type)
{
let pointee_type_match =
self_ptr_type.pointee_type_key == other_ptr_type.pointee_type_key;
Expand Down Expand Up @@ -650,7 +653,7 @@ impl AExpr {
return self;
}

match ctx.must_get_type(symbol.type_key) {
match ctx.get_type(symbol.type_key) {
AType::Function(sig) if sig.is_parameterized() => {
todo!()
}
Expand Down Expand Up @@ -752,7 +755,7 @@ impl AExpr {

if !expr.kind.is_const()
|| !ctx
.must_get_type(expr.type_key)
.get_type(expr.type_key)
.is_same_as(ctx, &AType::Uint, true)
{
return None;
Expand Down Expand Up @@ -821,8 +824,8 @@ pub fn check_operand_types(
op: &Operator,
right_expr: &AExpr,
) -> Result<Option<TypeKey>, Vec<AnalyzeError>> {
let left_type = ctx.must_get_type(left_expr.type_key);
let right_type = ctx.must_get_type(right_expr.type_key);
let left_type = ctx.get_type(left_expr.type_key);
let right_type = ctx.get_type(right_expr.type_key);

let mut left_type_key = None;
let mut right_type_key = None;
Expand Down Expand Up @@ -878,7 +881,7 @@ pub fn check_operand_types(
(Some(ltk), Some(rtk)) => {
// In the case of pointer arithmetic, if either of the pointers is
// mutable, we'll make the result mutable as well.
if ctx.must_get_type(rtk).is_mut_ptr() {
if ctx.get_type(rtk).is_mut_ptr() {
Some(rtk)
} else {
Some(ltk)
Expand Down Expand Up @@ -1060,8 +1063,8 @@ fn analyze_type_cast(
) -> AExpr {
let left_expr = AExpr::from(ctx, expr, None, false, false);
let target_type_key = ctx.resolve_type(&target_type);
let left_type = ctx.must_get_type(left_expr.type_key);
let a_target_type = ctx.must_get_type(target_type_key);
let left_type = ctx.get_type(left_expr.type_key);
let a_target_type = ctx.get_type(target_type_key);

// Skip the check if the left expression already failed analysis.
// Also make sure the type keys are actually different. If not, this is a
Expand Down Expand Up @@ -1112,7 +1115,7 @@ fn analyze_tuple_init(
) -> AExpr {
let maybe_expected_field_type_keys = match maybe_expected_type_key {
Some(tk) => {
if let AType::Tuple(tuple_type) = ctx.must_get_type(tk) {
if let AType::Tuple(tuple_type) = ctx.get_type(tk) {
let mut field_type_keys = Vec::with_capacity(tuple_type.fields.len());
for i in 0..tuple_type.fields.len() {
field_type_keys.insert(i, tuple_type.get_field_type_key(i).unwrap());
Expand Down Expand Up @@ -1143,7 +1146,7 @@ fn analyze_array_init(
) -> AExpr {
let maybe_element_type_key = match maybe_expected_type_key {
Some(tk) => {
if let AType::Array(array_type) = ctx.must_get_type(tk) {
if let AType::Array(array_type) = ctx.get_type(tk) {
array_type.maybe_element_type_key
} else {
None
Expand All @@ -1161,8 +1164,13 @@ fn analyze_array_init(
}
}

fn analyze_fn_call(ctx: &mut ProgramContext, fn_call: FuncCall, span: Span) -> AExpr {
let a_call = AFnCall::from(ctx, &fn_call);
fn analyze_fn_call(
ctx: &mut ProgramContext,
fn_call: FnCall,
maybe_expected_ret_tk: Option<TypeKey>,
span: Span,
) -> AExpr {
let a_call = AFnCall::from(ctx, &fn_call, maybe_expected_ret_tk);
match a_call.maybe_ret_type_key.clone() {
Some(type_key) => AExpr {
kind: AExprKind::FunctionCall(Box::new(a_call)),
Expand Down Expand Up @@ -1238,7 +1246,7 @@ fn analyze_unary_op(
);

// Make sure the expression has type bool.
let typ = ctx.must_get_type(a_expr.type_key);
let typ = ctx.get_type(a_expr.type_key);
if !typ.is_unknown() && !typ.is_bool() {
ctx.insert_err(AnalyzeError::new(
ErrorKind::MismatchedTypes,
Expand Down Expand Up @@ -1291,7 +1299,7 @@ fn analyze_unary_op(

Operator::Defererence => {
let operand_expr = AExpr::from(ctx, right_expr.clone(), None, false, false);
let operand_expr_type = ctx.must_get_type(operand_expr.type_key);
let operand_expr_type = ctx.get_type(operand_expr.type_key);

// Make sure the operand expression is of a pointer type.
match operand_expr_type {
Expand Down Expand Up @@ -1323,7 +1331,7 @@ fn analyze_unary_op(

Operator::Subtract => {
let operand_expr = AExpr::from(ctx, right_expr.clone(), None, false, false);
let operand_expr_type = ctx.must_get_type(operand_expr.type_key);
let operand_expr_type = ctx.get_type(operand_expr.type_key);

// Make sure the operand expression is of a signed numeric type since we'll
// have to flip its sign.
Expand Down Expand Up @@ -1354,7 +1362,7 @@ fn analyze_unary_op(

Operator::BitwiseNot => {
let operand_expr = AExpr::from(ctx, right_expr.clone(), None, false, false);
let operand_expr_type = ctx.must_get_type(operand_expr.type_key);
let operand_expr_type = ctx.get_type(operand_expr.type_key);

// Make sure the operand is of an integer type.
if operand_expr_type.is_integer() {
Expand Down Expand Up @@ -1411,8 +1419,8 @@ fn analyze_binary_op(

// If we couldn't resolve both of the operand types, we'll skip any further
// type checks by returning early.
let left_type = ctx.must_get_type(a_left.type_key);
let right_type = ctx.must_get_type(a_right.type_key);
let left_type = ctx.get_type(a_left.type_key);
let right_type = ctx.get_type(a_right.type_key);
if left_type.is_unknown() || right_type.is_unknown() {
return AExpr {
kind: AExprKind::BinaryOperation(
Expand Down Expand Up @@ -1445,8 +1453,14 @@ fn analyze_binary_op(
}
}

fn analyze_symbol(ctx: &mut ProgramContext, symbol: Symbol, allow_type: bool, span: Span) -> AExpr {
let a_symbol = ASymbol::from(ctx, &symbol, true, allow_type, false);
fn analyze_symbol(
ctx: &mut ProgramContext,
symbol: Symbol,
allow_type: bool,
allow_polymorph: bool,
span: Span,
) -> AExpr {
let a_symbol = ASymbol::from(ctx, &symbol, true, allow_type, false, allow_polymorph);
AExpr {
type_key: a_symbol.type_key,
kind: AExprKind::Symbol(a_symbol),
Expand Down Expand Up @@ -1532,13 +1546,16 @@ fn analyze_expr_with_pref(
expr: Expression,
maybe_expected_type_key: Option<TypeKey>,
allow_type: bool,
allow_polymorph: bool,
prefer_method: bool,
span: Span,
) -> AExpr {
match expr {
Expression::TypeCast(expr, target_type) => analyze_type_cast(ctx, *expr, target_type, span),

Expression::Symbol(symbol) => analyze_symbol(ctx, symbol, allow_type, span),
Expression::Symbol(symbol) => {
analyze_symbol(ctx, symbol, allow_type, allow_polymorph, span)
}

Expression::BoolLiteral(b) => AExpr {
kind: AExprKind::BoolLiteral(b.value),
Expand Down Expand Up @@ -1669,7 +1686,7 @@ fn analyze_expr_with_pref(

Expression::FunctionCall(fn_call) => {
// Analyze the function call and ensure it has a return type.
analyze_fn_call(ctx, *fn_call, span)
analyze_fn_call(ctx, *fn_call, maybe_expected_type_key, span)
}

Expression::Index(index) => {
Expand All @@ -1685,11 +1702,11 @@ fn analyze_expr_with_pref(
// Prefer methods if the expected type is a function.
let prefer_method = prefer_method
|| match maybe_expected_type_key {
Some(tk) => ctx.must_get_type(tk).is_fn(),
Some(tk) => ctx.get_type(tk).is_fn(),
None => false,
};

let access = AMemberAccess::from(ctx, &member_access, prefer_method);
let access = AMemberAccess::from(ctx, &member_access, prefer_method, allow_polymorph);
AExpr {
type_key: access.member_type_key,
kind: AExprKind::MemberAccess(Box::new(access)),
Expand Down Expand Up @@ -1735,7 +1752,7 @@ mod tests {
use crate::parser::ast::arg::Argument;
use crate::parser::ast::bool_lit::BoolLit;
use crate::parser::ast::expr::Expression;
use crate::parser::ast::func_call::FuncCall;
use crate::parser::ast::func_call::FnCall;
use crate::parser::ast::func_sig::FunctionSignature;
use crate::parser::ast::i64_lit::I64Lit;
use crate::parser::ast::op::Operator;
Expand Down Expand Up @@ -1885,7 +1902,7 @@ mod tests {
ctx.insert_type(AType::from_fn_sig(a_fn.signature.clone()));

// Analyze the function call expression.
let fn_call = FuncCall::new_with_default_pos(
let fn_call = FnCall::new_with_default_pos(
Expression::Symbol(Symbol::new_with_default_pos("do_thing")),
vec![Expression::BoolLiteral(BoolLit::new_with_default_pos(true))],
);
Expand Down Expand Up @@ -1948,7 +1965,7 @@ mod tests {
Box::new(Expression::I64Literal(I64Lit::new_with_default_pos(1))),
Operator::Add,
Box::new(Expression::FunctionCall(Box::new(
FuncCall::new_with_default_pos(
FnCall::new_with_default_pos(
Expression::Symbol(Symbol::new_with_default_pos("do_thing")),
vec![],
),
Expand Down Expand Up @@ -2043,7 +2060,7 @@ mod tests {
// Analyze the function call expression.
let result = AExpr::from(
&mut ctx,
Expression::FunctionCall(Box::new(FuncCall::new_with_default_pos(
Expression::FunctionCall(Box::new(FnCall::new_with_default_pos(
Expression::Symbol(Symbol::new_with_default_pos("do_thing")),
vec![Expression::BoolLiteral(BoolLit::new_with_default_pos(true))],
))),
Expand Down Expand Up @@ -2140,7 +2157,7 @@ mod tests {
// Analyze the function call expression.
let result = AExpr::from(
&mut ctx,
Expression::FunctionCall(Box::new(FuncCall::new_with_default_pos(
Expression::FunctionCall(Box::new(FnCall::new_with_default_pos(
Expression::Symbol(Symbol::new_with_default_pos("do_thing")),
vec![Expression::I64Literal(I64Lit::new_with_default_pos(1))],
))),
Expand Down
Loading

0 comments on commit d0cac1b

Please sign in to comment.