From f99d8cbe56d87a0e8ab2f2e93ffe2b31d4fa5676 Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> Date: Thu, 19 Dec 2024 05:05:43 +0530 Subject: [PATCH] Add Vulkan backend (#125) * Vulkan Backend implementation * very small fix * parsing for and if statements * added while and switch * working parser * Added Layout support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove debug code --------- Co-authored-by: Nripesh Niketan --- crosstl/src/backend/Vulkan/VulkanAst.py | 258 ++++++++ crosstl/src/backend/Vulkan/VulkanLexer.py | 177 ++++++ crosstl/src/backend/Vulkan/VulkanParser.py | 665 +++++++++++++++++++++ 3 files changed, 1100 insertions(+) create mode 100644 crosstl/src/backend/Vulkan/VulkanAst.py create mode 100644 crosstl/src/backend/Vulkan/VulkanLexer.py create mode 100644 crosstl/src/backend/Vulkan/VulkanParser.py diff --git a/crosstl/src/backend/Vulkan/VulkanAst.py b/crosstl/src/backend/Vulkan/VulkanAst.py new file mode 100644 index 00000000..67e9948e --- /dev/null +++ b/crosstl/src/backend/Vulkan/VulkanAst.py @@ -0,0 +1,258 @@ +class ASTNode: + pass + + +class TernaryOpNode: + def __init__(self, condition, true_expr, false_expr): + self.condition = condition + self.true_expr = true_expr + self.false_expr = false_expr + + def __repr__(self): + return f"TernaryOpNode(condition={self.condition}, true_expr={self.true_expr}, false_expr={self.false_expr})" + + +class ShaderNode: + def __init__( + self, + spirv_version, + descriptor_sets, + shader_stages, + functions, + ): + self.spirv_version = spirv_version + self.descriptor_sets = descriptor_sets + self.shader_stages = shader_stages + self.functions = functions + + def __repr__(self): + return f"ShaderNode(spirv_version={self.spirv_version}, descriptor_sets={self.descriptor_sets}, shader_stages={self.shader_stages}, functions={self.functions})" + + +class IfNode(ASTNode): + def __init__( + self, + if_condition, + if_body, + else_if_conditions=[], + else_if_bodies=[], + else_body=None, + ): + self.if_condition = if_condition + self.if_body = if_body + self.else_if_conditions = else_if_conditions + self.else_if_bodies = else_if_bodies + self.else_body = else_body + + def __repr__(self): + return f"IfNode(if_condition={self.if_condition}, if_body={self.if_body}, else_if_conditions={self.else_if_conditions}, else_if_bodies={self.else_if_bodies}, else_body={self.else_body})" + + +class ForNode(ASTNode): + def __init__(self, init, condition, update, body): + self.init = init + self.condition = condition + self.update = update + self.body = body + + def __repr__(self): + return f"ForNode(init={self.init}, condition={self.condition}, update={self.update}, body={self.body})" + + +class ReturnNode(ASTNode): + def __init__(self, value): + self.value = value + + def __repr__(self): + return f"ReturnNode(value={self.value})" + + +class FunctionCallNode(ASTNode): + def __init__(self, name, args): + self.name = name + self.args = args + + def __repr__(self): + return f"FunctionCallNode(name={self.name}, args={self.args})" + + +class BinaryOpNode(ASTNode): + def __init__(self, left, op, right): + self.left = left + self.op = op + self.right = right + + def __repr__(self): + return f"BinaryOpNode(left={self.left}, op={self.op}, right={self.right})" + + +class UnaryOpNode(ASTNode): + def __init__(self, op, operand): + self.op = op + self.operand = operand + + def __repr__(self): + return f"UnaryOpNode(operator={self.op}, operand={self.operand})" + + +class DescriptorSetNode(ASTNode): + def __init__(self, set_number, bindings): + self.set_number = set_number + self.bindings = bindings + + def __repr__(self): + return ( + f"DescriptorSetNode(set_number={self.set_number}, bindings={self.bindings})" + ) + + +class LayoutNode(ASTNode): + def __init__( + self, + bindings, + push_constant, + layout_type, + data_type, + variable_name, + struct_fields, + ): + self.bindings = bindings + self.push_constant = push_constant + self.layout_type = layout_type + self.data_type = data_type + self.variable_name = variable_name + self.struct_fields = struct_fields + + def __repr__(self): + return ( + f"LayoutNode(bindings={self.bindings}, push_constant={self.push_constant}, " + f"layout_type={self.layout_type}, data_type={self.data_type}, " + f"variable_name={self.variable_name}, struct_fields={self.struct_fields})" + ) + + +class UniformNode(ASTNode): + def __init__(self, name, var_type, value=None): + self.name = name + self.var_type = var_type + self.value = value + + def __repr__(self): + return f"UniformNode(name={self.name}, var_type={self.var_type}, value={self.value})" + + +class ShaderStageNode(ASTNode): + def __init__(self, stage, entry_point): + self.stage = stage + self.entry_point = entry_point + + def __repr__(self): + return f"ShaderStageNode(stage={self.stage}, entry_point={self.entry_point})" + + +class PushConstantNode(ASTNode): + def __init__(self, size, values): + self.size = size + self.values = values + + def __repr__(self): + return f"PushConstantNode(size={self.size}, values={self.values})" + + +class StructNode(ASTNode): + def __init__(self, name, members): + self.name = name + self.members = members + + def __repr__(self): + return f"StructNode(name={self.name}, members={self.members})" + + +class FunctionNode(ASTNode): + def __init__(self, name, return_type, parameters, body): + self.name = name + self.return_type = return_type + self.parameters = parameters + self.body = body + + def __repr__(self): + return f"FunctionNode(name={self.name}, return_type={self.return_type}, parameters={self.parameters}, body={self.body})" + + +class MemberAccessNode(ASTNode): + def __init__(self, object, member): + self.object = object + self.member = member + + def __repr__(self): + return f"MemberAccessNode(object={self.object}, member={self.member})" + + +class VariableNode(ASTNode): + def __init__(self, name, var_type): + self.name = name + self.var_type = var_type + + def __repr__(self): + return f"VariableNode(name={self.name}, var_type={self.var_type})" + + +class SwitchNode(ASTNode): + def __init__(self, expression, cases): + self.expression = expression + self.cases = cases + + def __repr__(self): + return f"SwitchNode(expression={self.expression}, cases={self.cases})" + + +class CaseNode(ASTNode): + def __init__(self, value, body): + self.value = value + self.body = body + + def __repr__(self): + return f"CaseNode(value={self.value}, body={self.body})" + + +class DefaultNode(ASTNode): + def __init__(self, statements): + self.statements = statements + + def __repr__(self): + return f"DefaultNode(statements={self.statements})" + + +class WhileNode(ASTNode): + def __init__(self, condition, body): + self.condition = condition + self.body = body + + def __repr__(self): + return f"WhileNode(condition={self.condition}, body={self.body})" + + +class DoWhileNode(ASTNode): + def __init__(self, body, condition): + self.body = body + self.condition = condition + + def __repr__(self): + return f"DoWhileNode(body={self.body}, condition={self.condition})" + + +class AssignmentNode(ASTNode): + def __init__(self, name, value): + self.name = name + self.value = value + + def __repr__(self): + return f"AssignmentNode(name={self.name}, value={self.value})" + + +class BreakNode(ASTNode): + def __init__(self): + pass + + def __repr__(self): + return f"BreakNode()" diff --git a/crosstl/src/backend/Vulkan/VulkanLexer.py b/crosstl/src/backend/Vulkan/VulkanLexer.py new file mode 100644 index 00000000..471ab723 --- /dev/null +++ b/crosstl/src/backend/Vulkan/VulkanLexer.py @@ -0,0 +1,177 @@ +import re + +TOKENS = [ + ("COMMENT_SINGLE", r"//.*"), + ("COMMENT_MULTI", r"/\*[\s\S]*?\*/"), + ("WHITESPACE", r"\s+"), + ("SEMANTIC", r":\w+"), + ("PRE_INCREMENT", r"\+\+(?=\w)"), + ("PRE_DECREMENT", r"--(?=\w)"), + ("POST_INCREMENT", r"(?<=\w)\+\+"), + ("POST_DECREMENT", r"(?<=\w)--"), + ("IDENTIFIER", r"[a-zA-Z_][a-zA-Z0-9_]*"), + ("NUMBER", r"\d+(\.\d*)?|\.\d+"), + ("SEMICOLON", r";"), + ("LBRACE", r"\{"), + ("RBRACE", r"\}"), + ("LPAREN", r"\("), + ("RPAREN", r"\)"), + ("COMMA", r","), + ("DOT", r"\."), + ("EQUAL", r"=="), + ("ASSIGN_AND", r"&="), + ("ASSIGN_OR", r"\|="), + ("ASSIGN_XOR", r"\^="), + ("PLUS_EQUALS", r"\+="), + ("MINUS_EQUALS", r"-="), + ("MULTIPLY_EQUALS", r"\*="), + ("DIVIDE_EQUALS", r"/="), + ("ASSIGN_MOD", r"%="), + ("ASSIGN_SHIFT_LEFT", r"<<="), + ("ASSIGN_SHIFT_RIGHT", r">>="), + ("BITWISE_SHIFT_LEFT", r"<<"), + ("BITWISE_SHIFT_RIGHT", r">>"), + ("EQUALS", r"="), + ("PLUS", r"\+"), + ("MINUS", r"-"), + ("MULTIPLY", r"\*"), + ("DIVIDE", r"/"), + ("MODULUS", r"%"), + ("LESS_EQUAL", r"<="), + ("GREATER_EQUAL", r">="), + ("NOT_EQUAL", r"!="), + ("LESS_THAN", r"<"), + ("GREATER_THAN", r">"), + ("AND", r"&&"), + ("OR", r"\|\|"), + ("BINARY_AND", r"&"), + ("BINARY_OR", r"\|"), + ("BINARY_XOR", r"\^"), + ("BINARY_NOT", r"~"), + ("QUESTION", r"\?"), + ("COLON", r":"), +] + +KEYWORDS = { + "struct": "STRUCT", + "layout": "LAYOUT", + "uniform": "UNIFORM", + "sampler2D": "SAMPLER2D", + "samplerCube": "SAMPLERCUBE", + "vec2": "VEC2", + "vec3": "VEC3", + "vec4": "VEC4", + "ivec2": "IVEC2", + "ivec3": "IVEC3", + "ivec4": "IVEC4", + "uvec2": "UVEC2", + "uvec3": "UVEC3", + "uvec4": "UVEC4", + "bvec2": "BVEC2", + "bvec3": "BVEC3", + "bvec4": "BVEC4", + "int": "INT", + "uint": "UINT", + "bool": "BOOL", + "float": "FLOAT", + "double": "DOUBLE", + "void": "VOID", + "return": "RETURN", + "if": "IF", + "else": "ELSE", + "for": "FOR", + "while": "WHILE", + "do": "DO", + "switch": "SWITCH", + "case": "CASE", + "default": "DEFAULT", + "break": "BREAK", + "continue": "CONTINUE", + "discard": "DISCARD", + "in": "IN", + "out": "OUT", + "inout": "INOUT", + "attribute": "ATTRIBUTE", + "varying": "VARYING", + "const": "CONST", + "precision": "PRECISION", + "highp": "HIGHP", + "mediump": "MEDIUMP", + "lowp": "LOWP", + "subpassInput": "SUBPASSINPUT", + "subpassInputMS": "SUBPASSINPUTMS", + "sampler2DArray": "SAMPLER2DARRAY", + "sampler2DMS": "SAMPLER2DMS", + "sampler2DMSArray": "SAMPLER2DMSARRAY", + "sampler3D": "SAMPLER3D", + "samplerCubeArray": "SAMPLERCUBEARRAY", + "image2D": "IMAGE2D", + "image3D": "IMAGE3D", + "imageCube": "IMAGECUBE", + "imageBuffer": "IMAGEBUFFER", + "image2DArray": "IMAGE2DARRAY", + "imageCubeArray": "IMAGECUBEARRAY", + "image1D": "IMAGE1D", + "image1DArray": "IMAGE1DARRAY", + "image2DMS": "IMAGE2DMS", + "image2DMSArray": "IMAGE2DMSARRAY", + "atomic_uint": "ATOMICUINT", + "mat2": "MAT2", + "mat3": "MAT3", + "mat4": "MAT4", +} + +VALID_DATA_TYPES = [ + "int", + "float", + "double", + "vec2", + "vec3", + "vec4", + "mat2", + "mat3", + "mat4", + "uint", + "bool", + "void", +] + + +class VulkanLexer: + def __init__(self, code): + self.code = code + self.tokens = [] + self.tokenize() + + def tokenize(self): + pos = 0 + while pos < len(self.code): + match = None + for token_type, pattern in TOKENS: + regex = re.compile(pattern) + match = regex.match(self.code, pos) + if match: + text = match.group(0) + if token_type == "IDENTIFIER" and text in KEYWORDS: + token_type = KEYWORDS[text] + if token_type == "VERSION": + self.tokens.append((token_type, text)) + elif token_type == "VERSION_NUMBER": + self.tokens.append((token_type, text)) + elif token_type == "CORE": + self.tokens.append((token_type, text)) + elif token_type != "WHITESPACE": # Ignore whitespace tokens + token = (token_type, text) + self.tokens.append(token) + pos = match.end(0) + break + if not match: + unmatched_char = self.code[pos] + highlighted_code = ( + self.code[:pos] + "[" + self.code[pos] + "]" + self.code[pos + 1 :] + ) + raise SyntaxError( + f"Illegal character '{unmatched_char}' at position {pos}\n{highlighted_code}" + ) + + self.tokens.append(("EOF", None)) diff --git a/crosstl/src/backend/Vulkan/VulkanParser.py b/crosstl/src/backend/Vulkan/VulkanParser.py new file mode 100644 index 00000000..6e88f608 --- /dev/null +++ b/crosstl/src/backend/Vulkan/VulkanParser.py @@ -0,0 +1,665 @@ +from VulkanLexer import * +from VulkanAst import * + + +class VulkanParser: + def __init__(self, tokens): + self.tokens = tokens + self.pos = 0 + self.current_token = self.tokens[self.pos] + self.skip_comments() + + def skip_comments(self): + while self.current_token[0] in ["COMMENT_SINGLE", "COMMENT_MULTI"]: + self.eat(self.current_token[0]) + + def peek(self, offset): + """Look ahead by offset tokens without consuming them.""" + peek_index = self.pos + offset + if peek_index < len(self.tokens): + return self.tokens[peek_index][ + 0 + ] # Return the type of the token at the peeked index + return None + + def eat(self, token_type): + if self.current_token[0] == token_type: + self.pos += 1 + self.current_token = ( + self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None) + ) + self.skip_comments() + else: + raise SyntaxError(f"Expected {token_type}, got {self.current_token[0]}") + + def parse(self): + module = self.parse_module() + self.eat("EOF") + return module + + def parse_module(self): + statements = [] + while self.current_token[0] != "EOF": + if self.current_token[0] == "LAYOUT": + statements.append(self.parse_layout()) + elif self.current_token[0] == "STRUCT": + statements.append(self.parse_struct()) + elif self.current_token[0] == "UNIFORM": + statements.append(self.parse_uniform()) + elif ( + self.current_token[1] in VALID_DATA_TYPES + and self.peek(1) == "IDENTIFIER" + and self.peek(2) == "LPAREN" + ): + statements.append(self.parse_function()) + elif ( + self.current_token[0] == "IDENTIFIER" + or self.current_token[1] in VALID_DATA_TYPES + ): + statements.append(self.parse_variable()) + else: + self.eat(self.current_token[0]) + return ShaderNode(None, None, None, statements) + + def parse_layout(self): + self.eat("LAYOUT") + self.eat("LPAREN") + bindings = [] # Stores pairs like ('location', '0'), ('binding', '1'), etc. + push_constant = False + if self.current_token[0] == "PUSH_CONSTANT": + push_constant = True + self.eat("PUSH_CONSTANT") + if self.current_token[0] == "COMMA": + self.eat("COMMA") + + # Parse layout bindings + while self.current_token[0] != "RPAREN": + binding_name = self.current_token[1] + self.eat("IDENTIFIER") + + # Handle assignment with EQUALS and a number + if self.current_token[0] == "EQUALS": + self.eat("EQUALS") + binding_value = self.current_token[1] + self.eat("NUMBER") + bindings.append((binding_name, binding_value)) + else: + bindings.append((binding_name, None)) + + if self.current_token[0] == "COMMA": + self.eat("COMMA") + + self.eat("RPAREN") + + layout_type = None + if self.current_token[0] in ["IN", "OUT", "UNIFORM", "BUFFER"]: + layout_type = self.current_token[0] + self.eat(layout_type) + if self.current_token[0] == "IDENTIFIER": + self.eat(self.current_token[0]) + + data_type = None + struct_fields = None + if layout_type in ["UNIFORM", "BUFFER"]: + # If a curly brace follows, we have a structured data block + if self.current_token[0] == "LBRACE": + self.eat("LBRACE") + struct_fields = [] + + # Parse structured fields within the uniform/push_constant/buffer block + while self.current_token[0] != "RBRACE": + if self.current_token[1] in VALID_DATA_TYPES: + field_type = self.current_token[1] # Field type (e.g., mat4) + self.eat(self.current_token[0]) + else: + raise SyntaxError( + "Expected some data type before an identifier" + ) + field_name = self.current_token[1] # Field name + self.eat("IDENTIFIER") + self.eat("SEMICOLON") + struct_fields.append((field_type, field_name)) + + self.eat("RBRACE") + data_type = "struct" # Use 'struct' as data_type placeholder for uniform/push_constant/buffer + else: + raise SyntaxError( + "Expected structured data block after 'uniform' or 'buffer'" + ) + else: + # For `in` and `out`, expect a data type and variable name + if self.current_token[1] in VALID_DATA_TYPES: + data_type = self.current_token[1] + self.eat(self.current_token[0]) + else: + raise SyntaxError(f"Unexpected type: {self.current_token[1]}") + + # Parse variable name + variable_name = None + if self.current_token[0] == "IDENTIFIER": + variable_name = self.current_token[1] + self.eat("IDENTIFIER") + + self.eat("SEMICOLON") + return LayoutNode( + bindings, + push_constant, + layout_type, + data_type, + variable_name, + struct_fields, + ) + + def parse_push_constant(self): + self.eat("PUSH_CONSTANT") + self.eat("LBRACE") + members = [] + while self.current_token[0] != "RBRACE": + members.append(self.parse_variable()) + self.eat("RBRACE") + return PushConstantNode(members) + + def parse_descriptor_set(self): + self.eat("DESCRIPTOR_SET") + set_number = self.current_token[1] + self.eat("NUMBER") + self.eat("LBRACE") + bindings = [] + while self.current_token[0] != "RBRACE": + bindings.append(self.parse_variable()) + self.eat("RBRACE") + return DescriptorSetNode(set_number, bindings) + + def parse_struct(self): + self.eat("STRUCT") + name = self.current_token[1] + self.eat("IDENTIFIER") + self.eat("LBRACE") + members = [] + type_name = None + while self.current_token[0] != "RBRACE": + if self.current_token[1] in VALID_DATA_TYPES: + type_name = self.current_token[1] + self.eat(self.current_token[0]) + if self.current_token[0] == "IDENTIFIER": + members.append(self.parse_variable(type_name)) + self.eat("RBRACE") + return StructNode(name, members) + + def parse_function(self): + return_type = self.current_token[1] + if self.current_token[1] in VALID_DATA_TYPES: + self.eat(self.current_token[0]) + else: + raise SyntaxError(f"Unexpected type: {self.current_token[1]}") + func_name = self.current_token[1] + self.eat("IDENTIFIER") + self.eat("LPAREN") + params = self.parse_parameters() + self.eat("RPAREN") + body = self.parse_block() + return FunctionNode(func_name, return_type, params, body) + + def parse_parameters(self): + params = [] + while self.current_token[0] != "RPAREN": + vtype = self.current_token[1] + self.eat(self.current_token[0]) + name = self.current_token[1] + self.eat("IDENTIFIER") + params.append(VariableNode(vtype, name)) + if self.current_token[0] == "COMMA": + self.eat("COMMA") + return params + + def parse_block(self): + self.eat("LBRACE") + statements = [] + while self.current_token[0] != "RBRACE": + statements.append(self.parse_body()) + self.eat("RBRACE") + return statements + + def parse_body(self): + token_type = self.current_token[0] + + if token_type == "IDENTIFIER" or self.current_token[1] in VALID_DATA_TYPES: + return self.parse_assignment_or_function_call() + elif token_type == "IF": + return self.parse_if_statement() + elif token_type == "FOR": + return self.parse_for_statement() + elif token_type == "WHILE": + return self.parse_while_statement() + elif token_type == "DO": + return self.parse_do_while_statement() + elif token_type == "SWITCH": + return self.parse_switch_statement() + elif token_type == "BREAK": + self.eat("BREAK") + self.eat("SEMICOLON") + return BreakNode() + else: + return self.parse_expression_statement() + + def parse_update(self): + if self.current_token[0] == "IDENTIFIER": + name = self.current_token[1] + self.eat("IDENTIFIER") + if self.current_token[0] == "POST_INCREMENT": + self.eat("POST_INCREMENT") + return UnaryOpNode("POST_INCREMENT", VariableNode(name, "")) + elif self.current_token[0] == "POST_DECREMENT": + self.eat("POST_DECREMENT") + return UnaryOpNode("POST_DECREMENT", VariableNode(name, "")) + elif self.current_token[0] in [ + "EQUALS", + "ASSIGN_ADD", + "ASSIGN_SUB", + "ASSIGN_MUL", + "ASSIGN_DIV", + ]: + op = self.current_token[0] + self.eat(op) + value = self.parse_expression() + if op == "EQUALS": + return AssignmentNode(name, value) + elif op == "ASSIGN_ADD": + return AssignmentNode( + name, BinaryOpNode(VariableNode(name, ""), "+", value) + ) + elif op == "ASSIGN_SUB": + return AssignmentNode( + name, BinaryOpNode(VariableNode(name, ""), "-", value) + ) + elif op == "ASSIGN_MUL": + return AssignmentNode( + name, BinaryOpNode(VariableNode(name, ""), "*", value) + ) + elif op == "ASSIGN_DIV": + return AssignmentNode( + name, BinaryOpNode(VariableNode(name, ""), "/", value) + ) + else: + raise SyntaxError( + f"Expected INCREMENT or DECREMENT, got {self.current_token[0]}" + ) + elif self.current_token[0] == "PRE_INCREMENT": + self.eat("PRE_INCREMENT") + if self.current_token[0] == "IDENTIFIER": + name = self.current_token[1] + self.eat("IDENTIFIER") + return UnaryOpNode("PRE_INCREMENT", VariableNode(name, "")) + else: + raise SyntaxError( + f"Expected IDENTIFIER after PRE_INCREMENT, got {self.current_token[0]}" + ) + elif self.current_token[0] == "PRE_DECREMENT": + self.eat("PRE_DECREMENT") + if self.current_token[0] == "IDENTIFIER": + name = self.current_token[1] + self.eat("IDENTIFIER") + return UnaryOpNode("PRE_DECREMENT", VariableNode(name, "")) + else: + raise SyntaxError( + f"Expected IDENTIFIER after PRE_DECREMENT, got {self.current_token[0]}" + ) + else: + raise SyntaxError(f"Unexpected token in update: {self.current_token[0]}") + + def parse_if_statement(self): + self.eat("IF") + self.eat("LPAREN") + if_condition = self.parse_expression() + self.eat("RPAREN") + if_body = self.parse_block() + else_body = None + else_if_condition = [] + else_if_body = [] + while self.current_token[0] == "ELSE" and self.peek(1) == "IF": + self.eat("ELSE") + self.eat("IF") + self.eat("LPAREN") + else_if_condition.append(self.parse_expression()) + self.eat("RPAREN") + self.eat("LBRACE") + else_if_body.append(self.parse_body()) + self.eat("RBRACE") + if self.current_token[0] == "ELSE": + self.eat("ELSE") + else_body = self.parse_block() + return IfNode(if_condition, if_body, else_if_condition, else_if_body, else_body) + + def parse_for_statement(self): + self.eat("FOR") + self.eat("LPAREN") + initialization = self.parse_assignment_or_function_call() + condition = self.parse_expression() + self.eat("SEMICOLON") + increment = self.parse_update() + self.eat("RPAREN") + body = self.parse_block() + return ForNode(initialization, condition, increment, body) + + def parse_variable(self, type_name): + name = self.current_token[1] + self.eat("IDENTIFIER") + + while self.current_token[0] == "DOT": + self.eat("DOT") + member_name = self.current_token[1] + self.eat("IDENTIFIER") + name += "." + member_name + + if self.current_token[0] == "SEMICOLON": + self.eat("SEMICOLON") + return VariableNode(name, type_name) + + elif self.current_token[0] == "EQUALS": + self.eat("EQUALS") + value = self.parse_expression() + + if self.current_token[0] == "SEMICOLON": + self.eat("SEMICOLON") + return AssignmentNode(VariableNode(name, type_name), value) + else: + raise SyntaxError( + f"Expected ';' after variable assignment, found: {self.current_token[0]}" + ) + + elif self.current_token[0] in ( + "EQUALS", + "PLUS_EQUALS", + "MINUS_EQUALS", + "MULTIPLY_EQUALS", + "DIVIDE_EQUALS", + "EQUAL", + "LESS_THAN", + "GREATER_THAN", + "LESS_EQUAL", + "GREATER_EQUAL", + "ASSIGN_AND", + "ASSIGN_OR", + "ASSIGN_XOR", + "ASSIGN_MOD", + "BITWISE_SHIFT_RIGHT", + "BITWISE_SHIFT_LEFT", + "BITWISE_XOR", + "ASSIGN_SHIFT_LEFT", + "ASSIGN_SHIFT_RIGHT", + ): + op = self.current_token[0] + op_name = self.current_token[1] + self.eat(op) + value = self.parse_expression() + if self.current_token[0] == "SEMICOLON": + self.eat("SEMICOLON") + return BinaryOpNode(VariableNode(name, type_name), op_name, value) + else: + raise SyntaxError( + f"Expected ';' after compound assignment, found: {self.current_token[0]}" + ) + else: + raise SyntaxError( + f"Unexpected token in variable declaration: {self.current_token[0]}" + ) + + def parse_member_access(self, object): + self.eat("DOT") + if self.current_token[0] != "IDENTIFIER": + raise SyntaxError( + f"Expected identifier after dot, got {self.current_token[0]}" + ) + member = self.current_token[1] + self.eat("IDENTIFIER") + + if self.current_token[0] == "DOT": + return self.parse_member_access(MemberAccessNode(object, member)) + + return MemberAccessNode(object, member) + + def parse_function_call(self, name): + self.eat("LPAREN") + args = [] + if self.current_token[0] != "RPAREN": + args.append(self.parse_expression()) + while self.current_token[0] == "COMMA": + self.eat("COMMA") + args.append(self.parse_expression()) + self.eat("RPAREN") + return FunctionCallNode(name, args) + + def parse_function_call_or_identifier(self): + func_name = self.current_token[1] + self.eat(self.current_token[0]) + + if self.current_token[0] == "LPAREN": + return self.parse_function_call(func_name) + elif self.current_token[0] == "DOT": + return self.parse_member_access(func_name) + return VariableNode(func_name, "") + + def parse_primary(self): + if self.current_token[0] == "MINUS": + self.eat("MINUS") + value = self.parse_primary() + return UnaryOpNode("-", value) + + if ( + self.current_token[0] == "IDENTIFIER" + or self.current_token[1] in VALID_DATA_TYPES + ): + return self.parse_function_call_or_identifier() + elif self.current_token[0] == "NUMBER": + value = self.current_token[1] + self.eat("NUMBER") + return value + elif self.current_token[0] == "LPAREN": + self.eat("LPAREN") + expr = self.parse_expression() + self.eat("RPAREN") + return expr + else: + raise SyntaxError( + f"Unexpected token in expression: {self.current_token[0]}" + ) + + def parse_multiplicative(self): + left = self.parse_primary() + while self.current_token[0] in ["MULTIPLY", "DIVIDE"]: + op = self.current_token[0] + self.eat(op) + right = self.parse_primary() + left = BinaryOpNode(left, op, right) + return left + + def parse_additive(self): + left = self.parse_multiplicative() + while self.current_token[0] in ["PLUS", "MINUS"]: + op = self.current_token[0] + self.eat(op) + right = self.parse_multiplicative() + left = BinaryOpNode(left, op, right) + return left + + def parse_assignment(self, name): + if self.current_token[0] in [ + "EQUALS", + "PLUS_EQUALS", + "MINUS_EQUALS", + "MULTIPLY_EQUALS", + "DIVIDE_EQUALS", + "LESS_THAN", + "GREATER_THAN", + "LESS_EQUAL", + "GREATER_EQUAL", + "ASSIGN_AND", + "ASSIGN_OR", + "ASSIGN_XOR", + "ASSIGN_MOD", + "BITWISE_SHIFT_RIGHT", + "BITWISE_SHIFT_LEFT", + "BITWISE_XOR", + "ASSIGN_SHIFT_LEFT", + "ASSIGN_SHIFT_RIGHT", + ]: + op = self.current_token[0] + op_name = self.current_token[1] + self.eat(op) + value = self.parse_expression() + if self.current_token[0] == "SEMICOLON": + self.eat("SEMICOLON") + return BinaryOpNode(name, op_name, value) + else: + raise SyntaxError( + f"Expected assignment operator, found: {self.current_token[0]}" + ) + + def parse_assignment_or_function_call(self): + type_name = "" + if self.current_token[0] == "IDENTIFIER" and self.peek(1) in [ + "POST_INCREMENT", + "POST_DECREMENT", + ]: + name = self.current_token[1] + self.eat("IDENTIFIER") + + if self.current_token[0] in [ + "EQUALS", + "PLUS_EQUALS", + "MINUS_EQUALS", + "MULTIPLY_EQUALS", + "DIVIDE_EQUALS", + "LESS_THAN", + "GREATER_THAN", + "LESS_EQUAL", + "GREATER_EQUAL", + "ASSIGN_AND", + "ASSIGN_OR", + "ASSIGN_XOR", + "ASSIGN_MOD", + "BITWISE_SHIFT_RIGHT", + "BITWISE_SHIFT_LEFT", + "BITWISE_XOR", + "ASSIGN_SHIFT_LEFT", + "ASSIGN_SHIFT_RIGHT", + ]: + return self.parse_assignment(name) # todo + elif self.current_token[0] == "POST_INCREMENT": + self.eat("POST_INCREMENT") + self.eat("SEMICOLON") + return AssignmentNode( + name, UnaryOpNode("POST_INCREMENT", VariableNode("", name)) + ) + elif self.current_token[0] == "POST_DECREMENT": + self.eat("POST_DECREMENT") + self.eat("SEMICOLON") + return AssignmentNode( + name, UnaryOpNode("POST_DECREMENT", VariableNode("", name)) + ) + elif self.current_token[0] == "LPAREN": + return self.parse_function_call(name) + else: + raise SyntaxError( + f"Unexpected token after identifier: {self.current_token[0]}" + ) + if self.current_token[1] in VALID_DATA_TYPES: + type_name = self.current_token[1] + self.eat(self.current_token[0]) + if self.current_token[0] == "IDENTIFIER": + return self.parse_variable(type_name) + + def parse_expression(self): + left = self.parse_additive() + while self.current_token[0] in [ + "LESS_THAN", + "GREATER_THAN", + "LESS_EQUAL", + "GREATER_EQUAL", + "EQUAL", + "NOT_EQUAL", + "AND", + "OR", + ]: + op = self.current_token[0] + self.eat(op) + right = self.parse_additive() + left = BinaryOpNode(left, op, right) + + if self.current_token[0] == "QUESTION": + self.eat("QUESTION") + true_expr = self.parse_expression() + self.eat("COLON") + false_expr = self.parse_expression() + left = TernaryOpNode(left, true_expr, false_expr) + + return left + + def parse_expression_statement(self): + expr = self.parse_expression() + # self.eat("SEMICOLON") + return expr + + def parse_while_statement(self): + self.eat("WHILE") + self.eat("LPAREN") + condition = self.parse_expression() + self.eat("RPAREN") + body = self.parse_block() + return WhileNode(condition, body) + + def parse_do_while_statement(self): + self.eat("DO") + body = self.parse_block() + self.eat("WHILE") + self.eat("LPAREN") + condition = self.parse_expression() + self.eat("RPAREN") + self.eat("SEMICOLON") + return DoWhileNode(condition, body) + + def parse_switch_statement(self): + self.eat("SWITCH") + self.eat("LPAREN") + expr = self.parse_expression() + self.eat("RPAREN") + self.eat("LBRACE") + cases = [] + while self.current_token[0] != "RBRACE": + cases.append(self.parse_case_statement()) + self.eat("RBRACE") + return SwitchNode(expr, cases) + + def parse_case_statement(self): + if self.current_token[0] == "CASE": + self.eat("CASE") + value = self.parse_expression() + self.eat("COLON") + elif self.current_token[0] == "DEFAULT": + self.eat("DEFAULT") + value = None + self.eat("COLON") + statements = [] + while self.current_token[0] not in ["CASE", "DEFAULT", "RBRACE"]: + statements.append(self.parse_body()) + return CaseNode(value, statements) + + def parse_default_statement(self): + self.eat("DEFAULT") + self.eat("COLON") + statements = [] + while self.current_token[0] not in ["CASE", "RBRACE"]: + statements.append(self.parse_body()) + return DefaultNode(statements) + + def parse_uniform(self): + self.eat("UNIFORM") + var_type = self.current_token[1] + if self.current_token[1] in VALID_DATA_TYPES: + self.eat(self.current_token[0]) + else: + raise SyntaxError(f"Unexpected type: {self.current_token[1]}") + name = self.current_token[1] + self.eat("IDENTIFIER") + self.eat("SEMICOLON") + return UniformNode(name, var_type)