Skip to content

Commit

Permalink
Parse binary operations with correct operator precedence (#271)
Browse files Browse the repository at this point in the history
* Parse binary operations with correct operator precedence

* Add another test case
  • Loading branch information
Glyphack authored Dec 12, 2024
1 parent 822399e commit dc53f04
Show file tree
Hide file tree
Showing 5 changed files with 470 additions and 112 deletions.
26 changes: 19 additions & 7 deletions parser/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ impl<'a> Parser<'a> {
// https://docs.python.org/3/reference/compound_stmts.html#literal-patterns
fn parse_literal_pattern(&mut self) -> Result<MatchPattern, ParsingError> {
let node = self.start_node();
let value = self.parse_binary_arithmetic_operation()?;
let value = self.parse_binary_arithmetic_operation(0)?;
Ok(MatchPattern::MatchValue(MatchValue {
node: self.finish_node(node),
value,
Expand Down Expand Up @@ -2146,15 +2146,15 @@ impl<'a> Parser<'a> {
// https://docs.python.org/3/reference/expressions.html#shifting-operations
fn parse_shift_expr(&mut self) -> Result<Expression, ParsingError> {
let node = self.start_node();
let mut arith_expr = self.parse_binary_arithmetic_operation()?;
let mut arith_expr = self.parse_binary_arithmetic_operation(0)?;
if self.at(Kind::LeftShift) || self.at(Kind::RightShift) {
let op = if self.eat(Kind::LeftShift) {
BinaryOperator::LShift
} else {
self.bump(Kind::RightShift);
BinaryOperator::RShift
};
let lhs = self.parse_binary_arithmetic_operation()?;
let lhs = self.parse_binary_arithmetic_operation(0)?;
arith_expr = Expression::BinOp(Box::new(BinOp {
node: self.finish_node(node),
op,
Expand All @@ -2166,12 +2166,24 @@ impl<'a> Parser<'a> {
}

// https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations
fn parse_binary_arithmetic_operation(&mut self) -> Result<Expression, ParsingError> {
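// Precedence climbing: an operator whose precedence is below `min_precedence` is not
// consumed here and is left for the caller to handle.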
fn parse_binary_arithmetic_operation(
&mut self,
min_precedence: u8,
) -> Result<Expression, ParsingError> {
let node = self.start_node();
let mut lhs = self.parse_unary_arithmetic_operation()?;
while self.cur_kind().is_bin_arithmetic_op() {
let op = self.parse_bin_arithmetic_op()?;
let rhs = self.parse_unary_arithmetic_operation()?;
while let Some((op, precedence, associativity)) = self.cur_kind().bin_op_precedence() {
if precedence < min_precedence {
break;
}
self.bump_any();
let next_precedence = match associativity {
0 => precedence + 1,
1 => precedence,
_ => unreachable!(),
};
let rhs = self.parse_binary_arithmetic_operation(next_precedence)?;
lhs = Expression::BinOp(Box::new(BinOp {
node: self.finish_node(node),
op,
Expand Down
26 changes: 14 additions & 12 deletions parser/src/token.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Display;

use crate::ast::BinaryOperator;

#[derive(Debug, Clone, PartialEq)]
pub struct Token {
pub kind: Kind,
Expand Down Expand Up @@ -214,18 +216,18 @@ impl Kind {
matches!(self, Kind::Not | Kind::BitNot | Kind::Minus | Kind::Plus)
}

pub fn is_bin_arithmetic_op(&self) -> bool {
matches!(
self,
Kind::Plus
| Kind::Minus
| Kind::Mul
| Kind::MatrixMul
| Kind::Div
| Kind::Mod
| Kind::Pow
| Kind::IntDiv
)
/// Maps a token kind to its binary arithmetic operator and binding information.
///
/// Returns `Some((operator, precedence, associativity))` for `+`, `-`, `*`, `@`, `/`, `%`,
/// `**` and `//`, and `None` for every other token kind.
///
/// * `precedence`: a larger value binds tighter (`*` at 10 binds tighter than `+` at 9).
/// * `associativity`: `0` is left-associative and `1` is right-associative; the
///   precedence-climbing loop in the parser parses the right operand with
///   `precedence + 1` for `0` and with `precedence` for `1`.
pub fn bin_op_precedence(&self) -> Option<(BinaryOperator, u8, u8)> {
    // Associativity codes understood by the parser's precedence-climbing loop.
    const LEFT: u8 = 0;
    const RIGHT: u8 = 1;

    match self {
        Kind::Plus => Some((BinaryOperator::Add, 9, LEFT)),
        Kind::Minus => Some((BinaryOperator::Sub, 9, LEFT)),
        Kind::Mul => Some((BinaryOperator::Mult, 10, LEFT)),
        Kind::MatrixMul => Some((BinaryOperator::MatMult, 10, LEFT)),
        Kind::Div => Some((BinaryOperator::Div, 10, LEFT)),
        Kind::Mod => Some((BinaryOperator::Mod, 10, LEFT)),
        Kind::IntDiv => Some((BinaryOperator::FloorDiv, 10, LEFT)),
        // In Python `**` binds tighter than the multiplicative operators and groups
        // right-to-left (`2 ** 3 ** 2` is `2 ** (3 ** 2)`), so it needs its own, higher
        // level and right associativity rather than sharing level 10.
        // NOTE(review): if `**` is already consumed by a dedicated power-expression
        // parser before the binary-operator loop runs, this arm is never reached; confirm.
        Kind::Pow => Some((BinaryOperator::Pow, 11, RIGHT)),
        _ => None,
    }
}

pub fn is_comparison_operator(&self) -> bool {
Expand Down
12 changes: 11 additions & 1 deletion parser/test_data/inputs/binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

1 % 2

1 ** 2
1**2

1 << 2

Expand All @@ -25,3 +25,13 @@
1 | 2 | 3

1 @ 2

1 + 2 * 3

1 * 2 + 3

1 ^ 2 + 3

3 + (1 + 2) * 3

(3 + 1) * 2**3 + 1
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: parser/src/lexer/mod.rs
description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1 ** 2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n"
description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1**2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n\n1 + 2 * 3\n\n1 * 2 + 3\n\n1 ^ 2 + 3\n\n3 + (1 + 2) * 3\n\n(3 + 1) * 2**3 + 1\n"
input_file: parser/test_data/inputs/binary_op.py
---
0,1: Integer 1
Expand Down Expand Up @@ -34,43 +34,88 @@ input_file: parser/test_data/inputs/binary_op.py
41,42: NewLine
42,43: NL
43,44: Integer 1
45,47: **
48,49: Integer 2
49,50: NewLine
50,51: NL
51,52: Integer 1
53,55: <<
56,57: Integer 2
57,58: NewLine
58,59: NL
59,60: Integer 1
61,63: >>
64,65: Integer 2
65,66: NewLine
66,67: NL
67,68: Integer 1
69,70: &
71,72: Integer 2
72,73: NewLine
73,74: NL
74,75: Integer 1
76,77: ^
78,79: Integer 2
79,80: NewLine
80,81: NL
81,82: Integer 1
83,84: |
85,86: Integer 2
86,87: NewLine
87,88: NL
88,89: Integer 1
90,91: |
92,93: Integer 2
94,95: |
96,97: Integer 3
97,98: NewLine
98,99: NL
99,100: Integer 1
101,102: @
103,104: Integer 2
104,105: NewLine
44,46: **
46,47: Integer 2
47,48: NewLine
48,49: NL
49,50: Integer 1
51,53: <<
54,55: Integer 2
55,56: NewLine
56,57: NL
57,58: Integer 1
59,61: >>
62,63: Integer 2
63,64: NewLine
64,65: NL
65,66: Integer 1
67,68: &
69,70: Integer 2
70,71: NewLine
71,72: NL
72,73: Integer 1
74,75: ^
76,77: Integer 2
77,78: NewLine
78,79: NL
79,80: Integer 1
81,82: |
83,84: Integer 2
84,85: NewLine
85,86: NL
86,87: Integer 1
88,89: |
90,91: Integer 2
92,93: |
94,95: Integer 3
95,96: NewLine
96,97: NL
97,98: Integer 1
99,100: @
101,102: Integer 2
102,103: NewLine
103,104: NL
104,105: Integer 1
106,107: +
108,109: Integer 2
110,111: *
112,113: Integer 3
113,114: NewLine
114,115: NL
115,116: Integer 1
117,118: *
119,120: Integer 2
121,122: +
123,124: Integer 3
124,125: NewLine
125,126: NL
126,127: Integer 1
128,129: ^
130,131: Integer 2
132,133: +
134,135: Integer 3
135,136: NewLine
136,137: NL
137,138: Integer 3
139,140: +
141,142: (
142,143: Integer 1
144,145: +
146,147: Integer 2
147,148: )
149,150: *
151,152: Integer 3
152,153: NewLine
153,154: NL
154,155: (
155,156: Integer 3
157,158: +
159,160: Integer 1
160,161: )
162,163: *
164,165: Integer 2
165,167: **
167,168: Integer 3
169,170: +
171,172: Integer 1
172,173: NewLine
Loading

0 comments on commit dc53f04

Please sign in to comment.