From dc53f04f8223653b272bc8e806a1d59d33667b4e Mon Sep 17 00:00:00 2001 From: Shaygan Hooshyari Date: Thu, 12 Dec 2024 21:51:21 +0100 Subject: [PATCH] Parse binary operations with correct operator precedence (#271) * Parse binary operations with correct operator precedence * Add another test case --- parser/src/parser/parser.rs | 26 +- parser/src/token.rs | 26 +- parser/test_data/inputs/binary_op.py | 12 +- ...ot_test_lexer_and_errors@binary_op.py.snap | 127 ++++-- ...ser__parser__parser__tests__binary_op.snap | 391 +++++++++++++++--- 5 files changed, 470 insertions(+), 112 deletions(-) diff --git a/parser/src/parser/parser.rs b/parser/src/parser/parser.rs index 808ad452..6b500c8d 100644 --- a/parser/src/parser/parser.rs +++ b/parser/src/parser/parser.rs @@ -889,7 +889,7 @@ impl<'a> Parser<'a> { // https://docs.python.org/3/reference/compound_stmts.html#literal-patterns fn parse_literal_pattern(&mut self) -> Result { let node = self.start_node(); - let value = self.parse_binary_arithmetic_operation()?; + let value = self.parse_binary_arithmetic_operation(0)?; Ok(MatchPattern::MatchValue(MatchValue { node: self.finish_node(node), value, @@ -2146,7 +2146,7 @@ impl<'a> Parser<'a> { // https://docs.python.org/3/reference/expressions.html#shifting-operations fn parse_shift_expr(&mut self) -> Result { let node = self.start_node(); - let mut arith_expr = self.parse_binary_arithmetic_operation()?; + let mut arith_expr = self.parse_binary_arithmetic_operation(0)?; if self.at(Kind::LeftShift) || self.at(Kind::RightShift) { let op = if self.eat(Kind::LeftShift) { BinaryOperator::LShift @@ -2154,7 +2154,7 @@ impl<'a> Parser<'a> { self.bump(Kind::RightShift); BinaryOperator::RShift }; - let lhs = self.parse_binary_arithmetic_operation()?; + let lhs = self.parse_binary_arithmetic_operation(0)?; arith_expr = Expression::BinOp(Box::new(BinOp { node: self.finish_node(node), op, @@ -2166,12 +2166,24 @@ impl<'a> Parser<'a> { } // https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations - fn parse_binary_arithmetic_operation(&mut self) -> Result { + // + fn parse_binary_arithmetic_operation( + &mut self, + min_precedence: u8, + ) -> Result { let node = self.start_node(); let mut lhs = self.parse_unary_arithmetic_operation()?; - while self.cur_kind().is_bin_arithmetic_op() { - let op = self.parse_bin_arithmetic_op()?; - let rhs = self.parse_unary_arithmetic_operation()?; + while let Some((op, precedence, associativity)) = self.cur_kind().bin_op_precedence() { + if precedence < min_precedence { + break; + } + self.bump_any(); + let next_precedence = match associativity { + 0 => precedence + 1, + 1 => precedence, + _ => unreachable!(), + }; + let rhs = self.parse_binary_arithmetic_operation(next_precedence)?; lhs = Expression::BinOp(Box::new(BinOp { node: self.finish_node(node), op, diff --git a/parser/src/token.rs b/parser/src/token.rs index 4970d305..cec86690 100644 --- a/parser/src/token.rs +++ b/parser/src/token.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use crate::ast::BinaryOperator; + #[derive(Debug, Clone, PartialEq)] pub struct Token { pub kind: Kind, @@ -214,18 +216,18 @@ impl Kind { matches!(self, Kind::Not | Kind::BitNot | Kind::Minus | Kind::Plus) } - pub fn is_bin_arithmetic_op(&self) -> bool { - matches!( - self, - Kind::Plus - | Kind::Minus - | Kind::Mul - | Kind::MatrixMul - | Kind::Div - | Kind::Mod - | Kind::Pow - | Kind::IntDiv - ) + pub fn bin_op_precedence(&self) -> Option<(BinaryOperator, u8, u8)> { + match self { + Kind::Plus => Some((BinaryOperator::Add, 9, 0)), + Kind::Minus => Some((BinaryOperator::Sub, 9, 0)), + Kind::Mul => Some((BinaryOperator::Mult, 10, 0)), + Kind::MatrixMul => Some((BinaryOperator::MatMult, 10, 0)), + Kind::Div => Some((BinaryOperator::Div, 10, 0)), + Kind::Mod => Some((BinaryOperator::Mod, 10, 0)), + Kind::Pow => Some((BinaryOperator::Pow, 10, 0)), + Kind::IntDiv => Some((BinaryOperator::FloorDiv, 10, 0)), + _ => None, + } } pub fn is_comparison_operator(&self) -> bool { diff --git a/parser/test_data/inputs/binary_op.py b/parser/test_data/inputs/binary_op.py index 6fe8a2ac..f223b9e6 100644 --- a/parser/test_data/inputs/binary_op.py +++ b/parser/test_data/inputs/binary_op.py @@ -10,7 +10,7 @@ 1 % 2 -1 ** 2 +1**2 1 << 2 @@ -25,3 +25,13 @@ 1 | 2 | 3 1 @ 2 + +1 + 2 * 3 + +1 * 2 + 3 + +1 ^ 2 + 3 + +3 + (1 + 2) * 3 + +(3 + 1) * 2**3 + 1 diff --git a/parser/test_data/output/enderpy_python_parser__lexer__tests__snapshot_test_lexer_and_errors@binary_op.py.snap b/parser/test_data/output/enderpy_python_parser__lexer__tests__snapshot_test_lexer_and_errors@binary_op.py.snap index d17340e8..4fa3dd9c 100644 --- a/parser/test_data/output/enderpy_python_parser__lexer__tests__snapshot_test_lexer_and_errors@binary_op.py.snap +++ b/parser/test_data/output/enderpy_python_parser__lexer__tests__snapshot_test_lexer_and_errors@binary_op.py.snap @@ -1,6 +1,6 @@ --- source: parser/src/lexer/mod.rs -description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1 ** 2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n" +description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1**2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n\n1 + 2 * 3\n\n1 * 2 + 3\n\n1 ^ 2 + 3\n\n3 + (1 + 2) * 3\n\n(3 + 1) * 2**3 + 1\n" input_file: parser/test_data/inputs/binary_op.py --- 0,1: Integer 1 @@ -34,43 +34,88 @@ input_file: parser/test_data/inputs/binary_op.py 41,42: NewLine 42,43: NL 43,44: Integer 1 -45,47: ** -48,49: Integer 2 -49,50: NewLine -50,51: NL -51,52: Integer 1 -53,55: << -56,57: Integer 2 -57,58: NewLine -58,59: NL -59,60: Integer 1 -61,63: >> -64,65: Integer 2 -65,66: NewLine -66,67: NL -67,68: Integer 1 -69,70: & -71,72: Integer 2 -72,73: NewLine -73,74: NL -74,75: Integer 1 -76,77: ^ -78,79: Integer 2 -79,80: NewLine -80,81: NL -81,82: Integer 1 -83,84: | -85,86: Integer 2 -86,87: NewLine -87,88: NL -88,89: Integer 1 -90,91: | -92,93: Integer 2 -94,95: | -96,97: Integer 3 -97,98: NewLine -98,99: NL -99,100: Integer 1 -101,102: @ -103,104: Integer 2 -104,105: NewLine +44,46: ** +46,47: Integer 2 +47,48: NewLine +48,49: NL +49,50: Integer 1 +51,53: << +54,55: Integer 2 +55,56: NewLine +56,57: NL +57,58: Integer 1 +59,61: >> +62,63: Integer 2 +63,64: NewLine +64,65: NL +65,66: Integer 1 +67,68: & +69,70: Integer 2 +70,71: NewLine +71,72: NL +72,73: Integer 1 +74,75: ^ +76,77: Integer 2 +77,78: NewLine +78,79: NL +79,80: Integer 1 +81,82: | +83,84: Integer 2 +84,85: NewLine +85,86: NL +86,87: Integer 1 +88,89: | +90,91: Integer 2 +92,93: | +94,95: Integer 3 +95,96: NewLine +96,97: NL +97,98: Integer 1 +99,100: @ +101,102: Integer 2 +102,103: NewLine +103,104: NL +104,105: Integer 1 +106,107: + +108,109: Integer 2 +110,111: * +112,113: Integer 3 +113,114: NewLine +114,115: NL +115,116: Integer 1 +117,118: * +119,120: Integer 2 +121,122: + +123,124: Integer 3 +124,125: NewLine +125,126: NL +126,127: Integer 1 +128,129: ^ +130,131: Integer 2 +132,133: + +134,135: Integer 3 +135,136: NewLine +136,137: NL +137,138: Integer 3 +139,140: + +141,142: ( +142,143: Integer 1 +144,145: + +146,147: Integer 2 +147,148: ) +149,150: * +151,152: Integer 3 +152,153: NewLine +153,154: NL +154,155: ( +155,156: Integer 3 +157,158: + +159,160: Integer 1 +160,161: ) +162,163: * +164,165: Integer 2 +165,167: ** +167,168: Integer 3 +169,170: + +171,172: Integer 1 +172,173: NewLine diff --git a/parser/test_data/output/enderpy_python_parser__parser__parser__tests__binary_op.snap b/parser/test_data/output/enderpy_python_parser__parser__parser__tests__binary_op.snap index d43a6a74..20d09191 100644 --- a/parser/test_data/output/enderpy_python_parser__parser__parser__tests__binary_op.snap +++ b/parser/test_data/output/enderpy_python_parser__parser__parser__tests__binary_op.snap @@ -1,11 +1,11 @@ --- source: parser/src/parser/parser.rs -description: "test file: test_data/inputs/binary_op.py\n1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1 ** 2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n" +description: "test file: test_data/inputs/binary_op.py\n1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1**2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n\n1 + 2 * 3\n\n1 * 2 + 3\n\n1 ^ 2 + 3\n\n3 + (1 + 2) * 3\n\n(3 + 1) * 2**3 + 1\n" --- Module { node: Node { start: 0, - end: 105, + end: 173, }, body: [ ExpressionStatement( @@ -187,7 +187,7 @@ Module { BinOp { node: Node { start: 43, - end: 49, + end: 47, }, op: Pow, left: Constant( @@ -202,8 +202,8 @@ Module { right: Constant( Constant { node: Node { - start: 48, - end: 49, + start: 46, + end: 47, }, value: Int, }, @@ -215,15 +215,15 @@ Module { BinOp( BinOp { node: Node { - start: 51, - end: 57, + start: 49, + end: 55, }, op: LShift, left: Constant( Constant { node: Node { - start: 51, - end: 52, + start: 49, + end: 50, }, value: Int, }, @@ -231,8 +231,8 @@ Module { right: Constant( Constant { node: Node { - start: 56, - end: 57, + start: 54, + end: 55, }, value: Int, }, @@ -244,15 +244,15 @@ Module { BinOp( BinOp { node: Node { - start: 59, - end: 65, + start: 57, + end: 63, }, op: RShift, left: Constant( Constant { node: Node { - start: 59, - end: 60, + start: 57, + end: 58, }, value: Int, }, @@ -260,8 +260,8 @@ Module { right: Constant( Constant { node: Node { - start: 64, - end: 65, + start: 62, + end: 63, }, value: Int, }, @@ -273,15 +273,15 @@ Module { BinOp( BinOp { node: Node { - start: 67, - end: 72, + start: 65, + end: 70, }, op: BitAnd, left: Constant( Constant { node: Node { - start: 67, - end: 68, + start: 65, + end: 66, }, value: Int, }, @@ -289,8 +289,8 @@ Module { right: Constant( Constant { node: Node { - start: 71, - end: 72, + start: 69, + end: 70, }, value: Int, }, @@ -302,15 +302,15 @@ Module { BinOp( BinOp { node: Node { - start: 74, - end: 79, + start: 72, + end: 77, }, op: BitXor, left: Constant( Constant { node: Node { - start: 74, - end: 75, + start: 72, + end: 73, }, value: Int, }, @@ -318,8 +318,8 @@ Module { right: Constant( Constant { node: Node { - start: 78, - end: 79, + start: 76, + end: 77, }, value: Int, }, @@ -331,15 +331,15 @@ Module { BinOp( BinOp { node: Node { - start: 81, - end: 86, + start: 79, + end: 84, }, op: BitOr, left: Constant( Constant { node: Node { - start: 81, - end: 82, + start: 79, + end: 80, }, value: Int, }, @@ -347,8 +347,8 @@ Module { right: Constant( Constant { node: Node { - start: 85, - end: 86, + start: 83, + end: 84, }, value: Int, }, @@ -360,22 +360,22 @@ Module { BinOp( BinOp { node: Node { - start: 88, - end: 97, + start: 86, + end: 95, }, op: BitOr, left: BinOp( BinOp { node: Node { - start: 88, - end: 93, + start: 86, + end: 91, }, op: BitOr, left: Constant( Constant { node: Node { - start: 88, - end: 89, + start: 86, + end: 87, }, value: Int, }, @@ -383,8 +383,8 @@ Module { right: Constant( Constant { node: Node { - start: 92, - end: 93, + start: 90, + end: 91, }, value: Int, }, @@ -394,8 +394,8 @@ Module { right: Constant( Constant { node: Node { - start: 96, - end: 97, + start: 94, + end: 95, }, value: Int, }, @@ -407,15 +407,15 @@ Module { BinOp( BinOp { node: Node { - start: 99, - end: 104, + start: 97, + end: 102, }, op: MatMult, left: Constant( Constant { node: Node { - start: 99, - end: 100, + start: 97, + end: 98, }, value: Int, }, @@ -423,8 +423,297 @@ Module { right: Constant( Constant { node: Node { - start: 103, - end: 104, + start: 101, + end: 102, + }, + value: Int, + }, + ), + }, + ), + ), + ExpressionStatement( + BinOp( + BinOp { + node: Node { + start: 104, + end: 113, + }, + op: Add, + left: Constant( + Constant { + node: Node { + start: 104, + end: 105, + }, + value: Int, + }, + ), + right: BinOp( + BinOp { + node: Node { + start: 108, + end: 113, + }, + op: Mult, + left: Constant( + Constant { + node: Node { + start: 108, + end: 109, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 112, + end: 113, + }, + value: Int, + }, + ), + }, + ), + }, + ), + ), + ExpressionStatement( + BinOp( + BinOp { + node: Node { + start: 115, + end: 124, + }, + op: Add, + left: BinOp( + BinOp { + node: Node { + start: 115, + end: 120, + }, + op: Mult, + left: Constant( + Constant { + node: Node { + start: 115, + end: 116, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 119, + end: 120, + }, + value: Int, + }, + ), + }, + ), + right: Constant( + Constant { + node: Node { + start: 123, + end: 124, + }, + value: Int, + }, + ), + }, + ), + ), + ExpressionStatement( + BinOp( + BinOp { + node: Node { + start: 126, + end: 135, + }, + op: BitXor, + left: Constant( + Constant { + node: Node { + start: 126, + end: 127, + }, + value: Int, + }, + ), + right: BinOp( + BinOp { + node: Node { + start: 130, + end: 135, + }, + op: Add, + left: Constant( + Constant { + node: Node { + start: 130, + end: 131, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 134, + end: 135, + }, + value: Int, + }, + ), + }, + ), + }, + ), + ), + ExpressionStatement( + BinOp( + BinOp { + node: Node { + start: 137, + end: 152, + }, + op: Add, + left: Constant( + Constant { + node: Node { + start: 137, + end: 138, + }, + value: Int, + }, + ), + right: BinOp( + BinOp { + node: Node { + start: 141, + end: 152, + }, + op: Mult, + left: BinOp( + BinOp { + node: Node { + start: 142, + end: 147, + }, + op: Add, + left: Constant( + Constant { + node: Node { + start: 142, + end: 143, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 146, + end: 147, + }, + value: Int, + }, + ), + }, + ), + right: Constant( + Constant { + node: Node { + start: 151, + end: 152, + }, + value: Int, + }, + ), + }, + ), + }, + ), + ), + ExpressionStatement( + BinOp( + BinOp { + node: Node { + start: 154, + end: 172, + }, + op: Add, + left: BinOp( + BinOp { + node: Node { + start: 154, + end: 168, + }, + op: Mult, + left: BinOp( + BinOp { + node: Node { + start: 155, + end: 160, + }, + op: Add, + left: Constant( + Constant { + node: Node { + start: 155, + end: 156, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 159, + end: 160, + }, + value: Int, + }, + ), + }, + ), + right: BinOp( + BinOp { + node: Node { + start: 164, + end: 168, + }, + op: Pow, + left: Constant( + Constant { + node: Node { + start: 164, + end: 165, + }, + value: Int, + }, + ), + right: Constant( + Constant { + node: Node { + start: 167, + end: 168, + }, + value: Int, + }, + ), + }, + ), + }, + ), + right: Constant( + Constant { + node: Node { + start: 171, + end: 172, }, value: Int, },