Skip to content

Commit

Permalink
fix: fixed variables and assignment defination
Browse files Browse the repository at this point in the history
  • Loading branch information
samthakur587 committed Aug 5, 2024
1 parent 3b204dd commit c59ac8a
Show file tree
Hide file tree
Showing 13 changed files with 409 additions and 409 deletions.
63 changes: 57 additions & 6 deletions crosstl/src/backend/DirectX/DirectxCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,33 @@ def __init__(self):
self.fragment_inputs = []
self.fragment_outputs = []
self.type_map = {
"float": "float",
"void": "void",
"float2": "vec2",
"float3": "vec3",
"float4": "vec4",
"float2x2": "mat2",
"float3x3": "mat3",
"float4x4": "mat4",
"int": "int",
"int2": "ivec2",
"int3": "ivec3",
"int4": "ivec4",
"uint": "uint",
"uint2": "uvec2",
"uint3": "uvec3",
"uint4": "uvec4",
"bool": "bool",
"bool2": "bvec2",
"bool3": "bvec3",
"bool4": "bvec4",
"float": "float",
"double": "double",
"Texture2D": "sampler2D",
"TextureCube": "samplerCube",
}

def generate(self, ast):
self.process_structs(ast)

code = "shader main {\n"

# Generate custom functions
Expand Down Expand Up @@ -104,13 +121,41 @@ def generate_function_body(self, body, indent=0, is_main=False):
for stmt in body:
code += " " * indent
if isinstance(stmt, VariableNode):
if not is_main:
code += f"{self.map_type(stmt.vtype)} {stmt.name};\n"
code += f"{self.map_type(stmt.vtype)} {stmt.name};\n"
elif isinstance(stmt, AssignmentNode):
code += self.generate_assignment(stmt, is_main) + ";\n"
elif isinstance(stmt, ReturnNode):
if not is_main:
code += f"return {self.generate_expression(stmt.value, is_main)};\n"
elif isinstance(stmt, ForNode):
code += self.generate_for_loop(stmt, indent, is_main)
elif isinstance(stmt, IfNode):
code += self.generate_if_statement(stmt, indent, is_main)
return code

def generate_for_loop(self, node, indent, is_main):
init = self.generate_expression(node.init, is_main)
condition = self.generate_expression(node.condition, is_main)
update = self.generate_expression(node.update, is_main)

code = f"for ({init}; {condition}; {update}) {{\n"
code += self.generate_function_body(node.body, indent + 1, is_main)
code += " " * indent + "}\n"
return code

def generate_if_statement(self, node, indent, is_main):
condition = self.generate_expression(node.condition, is_main)

code = f"if ({condition}) {{\n"
code += self.generate_function_body(node.if_body, indent + 1, is_main)
code += " " * indent + "}"

if node.else_body:
code += " else {\n"
code += self.generate_function_body(node.else_body, indent + 1, is_main)
code += " " * indent + "}"

code += "\n"
return code

def generate_assignment(self, node, is_main):
Expand All @@ -134,10 +179,16 @@ def generate_expression(self, expr, is_main=False):
elif isinstance(expr, BinaryOpNode):
left = self.generate_expression(expr.left, is_main)
right = self.generate_expression(expr.right, is_main)
return f"({left} {expr.op} {right})"
return f"{left} {expr.op} {right}"

elif isinstance(expr, AssignmentNode):
left = self.generate_expression(expr.left, is_main)
right = self.generate_expression(expr.right, is_main)
return f"{left} {expr.operator} {right}"

elif isinstance(expr, UnaryOpNode):
operand = self.generate_expression(expr.operand, is_main)
return f"({expr.op}{operand})"
return f"{expr.op}{operand}"
elif isinstance(expr, FunctionCallNode):
args = ", ".join(
self.generate_expression(arg, is_main) for arg in expr.args
Expand Down
21 changes: 13 additions & 8 deletions crosstl/src/backend/DirectX/DirectxParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,27 @@ def parse_variable_declaration_or_assignment(self):
"BOOL",
"IDENTIFIER",
]:
# This could be a type name or a variable name
first_token = self.current_token
self.eat(self.current_token[0])

if self.current_token[0] == "IDENTIFIER":
# This is a variable declaration
name = self.current_token[1]
self.eat("IDENTIFIER")
if self.current_token[0] == "SEMICOLON":
# Variable declaration without initialization
self.eat("SEMICOLON")
return VariableNode(first_token[1], name)
elif self.current_token[0] == "EQUALS":
# Variable declaration with initialization
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(VariableNode(first_token[1], name), value)
else:
# This is an assignment or a more complex expression
elif self.current_token[0] == "EQUALS":
# This handles cases like "test = float3(1.0, 1.0, 1.0);"
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(VariableNode("", first_token[1]), value)
elif self.current_token[0] == "DOT":
left = self.parse_member_access(first_token[1])
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
Expand All @@ -181,8 +182,11 @@ def parse_variable_declaration_or_assignment(self):
else:
self.eat("SEMICOLON")
return left
else:
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
else:
# This is an expression statement
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
Expand Down Expand Up @@ -211,7 +215,8 @@ def parse_for_statement(self):
self.eat("IDENTIFIER")
self.eat("EQUALS")
init_value = self.parse_expression()
init = VariableNode(type_name, var_name, init_value)
init = VariableNode(type_name, var_name)
init = AssignmentNode(init, init_value)
else:
init = self.parse_expression()
self.eat("SEMICOLON")
Expand Down
66 changes: 62 additions & 4 deletions crosstl/src/backend/Metal/MetalCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,37 @@ def __init__(self):
self.fragment_inputs = []
self.fragment_outputs = []
self.type_map = {
# Scalar Types
"void": "void",
"int": "short",
"uint": "unsigned short",
"int64_t": "long",
"uint64_t": "unsigned long",
"float": "float",
"half": "half",
"bool": "bool",
# Vector Types
"float2": "vec2",
"float3": "vec3",
"float4": "vec4",
"int": "int",
"int2": "short2",
"int3": "short3",
"int4": "short4",
"uint2": "ushort2",
"uint3": "ushort3",
"uint4": "ushort4",
"bool2": "bvec2",
"bool3": "bvec3",
"bool4": "bvec4",
"Texture2D": "sampler2D",
"TextureCube": "samplerCube",
# Matrix Types
"float2x2": "mat2",
"float3x3": "mat3",
"float4x4": "mat4",
"half2x2": "half2x2",
"half3x3": "half3x3",
"half4x4": "half4x4",
}

def generate(self, ast):
Expand Down Expand Up @@ -111,17 +137,43 @@ def generate_main_function(self, func):
def generate_function_body(self, body, indent=0, is_main=False):
code = ""
for stmt in body:
if isinstance(stmt, VariableNode) and stmt.name in ["output", "input"]:
continue
code += " " * indent
if isinstance(stmt, VariableNode):

code += f"{self.map_type(stmt.vtype)} {stmt.name};\n"
elif isinstance(stmt, AssignmentNode):
code += self.generate_assignment(stmt, is_main) + ";\n"
elif isinstance(stmt, ReturnNode):
if not is_main:
code += f"return {self.generate_expression(stmt.value, is_main)};\n"
elif isinstance(stmt, ForNode):
code += self.generate_for_loop(stmt, indent, is_main)
elif isinstance(stmt, IfNode):
code += self.generate_if_statement(stmt, indent, is_main)
return code

def generate_for_loop(self, node, indent, is_main):
init = self.generate_expression(node.init, is_main)
condition = self.generate_expression(node.condition, is_main)
update = self.generate_expression(node.update, is_main)

code = f"for ({init}; {condition}; {update}) {{\n"
code += self.generate_function_body(node.body, indent + 1, is_main)
code += " " * indent + "}\n"
return code

def generate_if_statement(self, node, indent, is_main):
condition = self.generate_expression(node.condition, is_main)

code = f"if ({condition}) {{\n"
code += self.generate_function_body(node.if_body, indent + 1, is_main)
code += " " * indent + "}"

if node.else_body:
code += " else {\n"
code += self.generate_function_body(node.else_body, indent + 1, is_main)
code += " " * indent + "}"

code += "\n"
return code

def generate_assignment(self, node, is_main):
Expand Down Expand Up @@ -163,6 +215,12 @@ def generate_expression(self, expr, is_main=False):
if obj == "output" or obj == "input":
return expr.member
return f"{obj}.{expr.member}"

elif isinstance(expr, AssignmentNode):
left = self.generate_expression(expr.left, is_main)
right = self.generate_expression(expr.right, is_main)
return f"({left} {expr.operator} {right})"

elif isinstance(expr, UnaryOpNode):
operand = self.generate_expression(expr.operand, is_main)
return f"({expr.op}{operand})"
Expand Down
18 changes: 14 additions & 4 deletions crosstl/src/backend/Metal/MetalParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,10 @@ def parse_statement(self):
def parse_variable_declaration_or_assignment(self):
if self.current_token[0] in [
"FLOAT",
"HALF",
"FVECTOR",
"INT",
"UINT",
"BOOL",
"VECTOR",
"IDENTIFIER",
]:
first_token = self.current_token
Expand All @@ -266,7 +265,13 @@ def parse_variable_declaration_or_assignment(self):
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(VariableNode(first_token[1], name), value)
else:
elif self.current_token[0] == "EQUALS":
# This handles cases like "test = float3(1.0, 1.0, 1.0);"
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(VariableNode("", first_token[1]), value)
elif self.current_token[0] == "DOT":
left = self.parse_member_access(first_token[1])
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
Expand All @@ -276,6 +281,10 @@ def parse_variable_declaration_or_assignment(self):
else:
self.eat("SEMICOLON")
return left
else:
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
else:
expr = self.parse_expression()
self.eat("SEMICOLON")
Expand Down Expand Up @@ -304,7 +313,8 @@ def parse_for_statement(self):
self.eat("IDENTIFIER")
self.eat("EQUALS")
init_value = self.parse_expression()
init = VariableNode(type_name, var_name, init_value)
init = VariableNode(type_name, var_name)
init = AssignmentNode(init, init_value)
else:
init = self.parse_expression()
self.eat("SEMICOLON")
Expand Down
27 changes: 15 additions & 12 deletions crosstl/src/translator/codegen/directx_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self):
self.vertex_item = None
self.fragment_item = None
self.gl_position = False
self.lhs = None
self.type_mapping = {
"void": "void",
"vec2": "float2",
Expand Down Expand Up @@ -58,7 +59,6 @@ def generate_shader(self, node):
self.shader_inputs = node.global_inputs
self.shader_outputs = node.global_outputs
code = "\n"

# Generate global inputs and outputs
if self.shader_inputs:
code += "struct VSINPUT {\n"
Expand Down Expand Up @@ -219,7 +219,9 @@ def generate_assignment(self, node, shader_type=None):
if isinstance(node.name, VariableNode) and node.name.vtype:
return f"{self.map_type(node.name.vtype)} {node.name.name} = {self.generate_expression(node.value, shader_type)}"
else:
self.lhs = True
lhs = self.generate_expression(node.name, shader_type)
self.lhs = False
if lhs == "gl_Position" or lhs == "gl_position":
return f"output.position = {self.generate_expression(node.value, shader_type)}"
return f"{lhs} = {self.generate_expression(node.value, shader_type)}"
Expand Down Expand Up @@ -269,29 +271,30 @@ def generate_expression(self, expr, shader_type=None):
if isinstance(expr, str):
return self.translate_expression(expr, shader_type)
elif isinstance(expr, VariableNode):
return self.translate_expression(expr.name, shader_type)
if isinstance(expr.name, str):
return self.translate_expression(expr.name, shader_type)
else:
return self.generate_expression(expr.name, shader_type)
elif isinstance(expr, BinaryOpNode):
return f"({self.generate_expression(expr.left, shader_type)} {self.map_operator(expr.op)} {self.generate_expression(expr.right, shader_type)})"
return f"{self.generate_expression(expr.left, shader_type)} {self.map_operator(expr.op)} {self.generate_expression(expr.right, shader_type)}"
elif isinstance(expr, FunctionCallNode):
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
if expr.name in self.type_mapping.keys():
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
return f"{self.map_type(expr.name)}({args})"
else:
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
func_name = self.translate_expression(expr.name, shader_type)
return f"{func_name}({args})"
elif isinstance(expr, UnaryOpNode):
return f"{self.map_operator(expr.op)}{self.generate_expression(expr.operand, shader_type)}"

elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition, shader_type)} ? {self.generate_expression(expr.true_expr, shader_type)} : {self.generate_expression(expr.false_expr, shader_type)}"

elif isinstance(expr, MemberAccessNode):
return f"{self.generate_expression(expr.object, shader_type)}.{expr.member}"
if self.lhs:
return f"{expr.member}.{self.generate_expression(expr.object, shader_type)}"
else:
return f"{self.generate_expression(expr.object, shader_type)}.{expr.member}"
else:
return str(expr)

Expand Down
Loading

0 comments on commit c59ac8a

Please sign in to comment.