Skip to content

Commit

Permalink
fix: added support for more complex statement decleration
Browse files Browse the repository at this point in the history
  • Loading branch information
samthakur587 committed Aug 16, 2024
1 parent 54f164a commit 949b041
Show file tree
Hide file tree
Showing 8 changed files with 814 additions and 300 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ __pycache__/
/venv
/compiler/codegen/__pycache__
/compiler/__pycache__
*:Z*
*:Z*
*workspace
15 changes: 7 additions & 8 deletions crosstl/src/translator/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,26 @@ def __repr__(self):


class VERTEXShaderNode:
def __init__(self, inputs, outputs, functions):
def __init__(self, inputs, outputs, functions,intermidiate):
self.inputs = inputs
self.outputs = outputs
self.functions = functions
self.intermidiate = intermidiate

def __repr__(self):
return f"VERTEXShaderNode({self.inputs!r}) {self.outputs!r} {self.functions!r}"
return f"VERTEXShaderNode({self.inputs!r}) {self.outputs!r} {self.functions!r} {self.intermidiate!r}"


class FRAGMENTShaderNode:
def __init__(self, inputs, outputs, functions):
def __init__(self, inputs, outputs, functions,intermidiate):
self.inputs = inputs
self.outputs = outputs
self.functions = functions
self.intermidiate = intermidiate

def __repr__(self):
return (
f"FRAGMENTShaderNode({self.inputs!r}) {self.outputs!r} {self.functions!r}"
f"FRAGMENTShaderNode({self.inputs!r}) {self.outputs!r} {self.functions!r} {self.intermidiate!r}"
)


Expand Down Expand Up @@ -159,7 +161,4 @@ def __init__(self, op, operand):
self.operand = operand

def __repr__(self):
return f"UnaryOpNode(operator={self.op}, operand={self.operand})"

def __str__(self):
return f"({self.op}{self.operand})"
return f"UnaryOpNode(op={self.op}, operand={self.operand})"
145 changes: 85 additions & 60 deletions crosstl/src/translator/codegen/directx_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self):
self.vertex_item = None
self.fragment_item = None
self.gl_position = False
self.lhs = None
self.type_mapping = {
"void": "void",
"vec2": "float2",
Expand Down Expand Up @@ -50,6 +49,7 @@ def __init__(self):
}

def generate(self, ast):
print(ast)
if isinstance(ast, ShaderNode):
self.current_shader = ast
return self.generate_shader(ast)
Expand Down Expand Up @@ -81,7 +81,7 @@ def generate_shader(self, node):

# Generate vertex shader section
self.vertex_item = node.vertex_section
if self.vertex_item:
if isinstance(self.vertex_item, VERTEXShaderNode):
shader_type = "vertex"
self.check_gl_position(self.vertex_item)
if self.vertex_item.inputs:
Expand All @@ -100,13 +100,24 @@ def generate_shader(self, node):
self.gl_position = False
code += f" {self.map_type(vtype)} {name} : TEXCOORD{i};\n"
code += "};\n\n"
code += (
f"{self.generate_function(self.vertex_item.functions, shader_type)}\n"
)

if self.vertex_item.functions:
code += (
f"{self.generate_function(self.vertex_item.functions, shader_type)}\n"
)

if self.vertex_item.intermidiate:
code += (
f"{self.generate_intermidiate(self.vertex_item.intermidiate, shader_type)}\n"
)
if self.vertex_item.functions:
for function_node in self.vertex_item.functions:
if function_node.name == "main":
code += f"{self.generate_main(function_node, shader_type)}\n"

# Generate fragment shader section
self.fragment_item = node.fragment_section
if self.fragment_item:
if isinstance(self.fragment_item, FRAGMENTShaderNode):
shader_type = "fragment"
if self.fragment_item.inputs:
code += "struct PSInput {\n"
Expand All @@ -118,10 +129,19 @@ def generate_shader(self, node):
for i, (vtype, name) in enumerate(self.fragment_item.outputs):
code += f" {self.map_type(vtype)} {name} : SV_TARGET{i};\n"
code += "};\n\n"
code += (
f"{self.generate_function(self.fragment_item.functions, shader_type)}\n"
)


if self.fragment_item.functions:
code += (
f"{self.generate_function(self.fragment_item.functions, shader_type)}\n"
)
if self.fragment_item.intermidiate:
code += (
f"{self.generate_intermidiate(self.fragment_item.intermidiate, shader_type)}\n"
)
if self.fragment_item.functions:
for function_node in self.fragment_item.functions:
if function_node.name == "main":
code += f"{self.generate_main(function_node, shader_type)}\n"
return code

def check_gl_position(self, node):
Expand All @@ -131,50 +151,42 @@ def check_gl_position(self, node):
vb_left = vb_name.split("=")[0].strip()
if vb_left == "output.position":
self.gl_position = True


def generate_intermidiate(self, node, shader_type):
code = ""
for stmt in node:
code += self.generate_statement(stmt, 0, shader_type=shader_type)
return code

def generate_function(self, node, shader_type):
code = ""
if shader_type == "vertex":
for function_node in node:
if function_node.name == "main":
params = "VSInput input"
return_type = "VSOutput"
code += f"{return_type} VSMain({params}) {{\n"
else:
if function_node.name != "main":
params = ", ".join(
f"{self.map_type(param[0])} {param[1]}"
for param in function_node.params
)
return_type = self.map_type(function_node.return_type)
code += f"{return_type} {function_node.name}({params}) {{\n"

if function_node.name == "main":
code += " VSOutput output;\n"
for stmt in function_node.body:
code += self.generate_statement(stmt, 1, shader_type=shader_type)
if function_node.name == "main":
code += " return output;\n"
code += "}\n"
for stmt in function_node.body:
code += self.generate_statement(stmt, 1, shader_type=shader_type)

code += "}\n"
elif shader_type == "fragment":
for function_node in node:
if function_node.name == "main":
params = "PSInput input"
return_type = "PSOutput"
code += f"{return_type} PSMain({params}) {{\n"
else:
if function_node.name != "main":
params = ", ".join(
f"{self.map_type(param[0])} {param[1]}"
for param in function_node.params
)
return_type = self.map_type(function_node.return_type)
code += f"{return_type} {function_node.name}({params}) {{\n"
if function_node.name == "main":
code += " PSOutput output;\n"
for stmt in function_node.body:
code += self.generate_statement(stmt, 1, shader_type=shader_type)
if function_node.name == "main":
code += " return output;\n"
code += "}\n"

for stmt in function_node.body:
code += self.generate_statement(stmt, 1, shader_type=shader_type)
code += "}\n"
elif shader_type == "global":
if node.name == "main":
params = "Global_INPUT input"
Expand All @@ -195,18 +207,39 @@ def generate_function(self, node, shader_type):
code += "}\n"
return code


def generate_main(self, node, shader_type):
if shader_type == "vertex":
code = "VSOutput VSMain(VSInput input) {\n"
code += " VSOutput output;\n"
elif shader_type == "fragment":
code = "PSOutput PSMain(PSInput input) {\n"
code += " PSOutput output;\n"
for stmt in node.body:
code += self.generate_statement(stmt, 1, shader_type)

code += " return output;\n"
code += "}\n"
return code


def generate_statement(self, stmt, indent=0, shader_type=None):
indent_str = " " * indent
if isinstance(stmt, VariableNode):
return f"{indent_str}{self.map_type(stmt.vtype)} {stmt.name};\n"
elif isinstance(stmt, AssignmentNode):
return f"{indent_str}{self.generate_assignment(stmt, shader_type)};\n"
elif isinstance(stmt, IfNode):
return self.generate_if(stmt, indent, shader_type)
return self.generate_if(stmt, indent, shader_type)
elif isinstance(stmt, ForNode):
return self.generate_for(stmt, indent, shader_type)
elif isinstance(stmt, ReturnNode):
return f"{indent_str}return {self.generate_expression(stmt.value, shader_type)};\n"
code = ""
for i,return_stmt in enumerate(stmt.value):
code += f"{self.generate_expression(return_stmt, shader_type)}"
if i < len(stmt.value) - 1:
code += ", "
return f"{indent_str}return {code};\n"
else:
return f"{indent_str}{self.generate_expression(stmt, shader_type)};\n"

Expand All @@ -219,9 +252,7 @@ def generate_assignment(self, node, shader_type=None):
if isinstance(node.name, VariableNode) and node.name.vtype:
return f"{self.map_type(node.name.vtype)} {node.name.name} = {self.generate_expression(node.value, shader_type)}"
else:
self.lhs = True
lhs = self.generate_expression(node.name, shader_type)
self.lhs = False
if lhs == "gl_Position" or lhs == "gl_position":
return f"output.position = {self.generate_expression(node.value, shader_type)}"
return f"{lhs} = {self.generate_expression(node.value, shader_type)}"
Expand All @@ -243,23 +274,16 @@ def generate_if(self, node, indent, shader_type=None):
def generate_for(self, node, indent, shader_type=None):
indent_str = " " * indent

if isinstance(node.init, AssignmentNode) and isinstance(
node.init.name, VariableNode
):
init = f"{self.map_type(node.init.name.vtype)} {node.init.name.name} = {self.generate_expression(node.init.value, shader_type)}"
else:
init = self.generate_statement(node.init, 0, shader_type).strip()[

init = self.generate_statement(node.init, 0, shader_type).strip()[
:-1
] # Remove trailing semicolon

condition = self.generate_expression(node.condition, shader_type)
condition = self.generate_statement(node.condition,0, shader_type).strip()[
:-1
] # Remove trailing semicolon

if isinstance(node.update, AssignmentNode) and isinstance(
node.update.value, UnaryOpNode
):
update = f"{node.update.value.operand.name}++"
else:
update = self.generate_statement(node.update, 0, shader_type).strip()[:-1]
update = self.generate_statement(node.update, 0, shader_type).strip()[:-1]

code = f"{indent_str}for ({init}; {condition}; {update}) {{\n"
for stmt in node.body:
Expand All @@ -272,9 +296,9 @@ def generate_expression(self, expr, shader_type=None):
return self.translate_expression(expr, shader_type)
elif isinstance(expr, VariableNode):
if isinstance(expr.name, str):
return self.translate_expression(expr.name, shader_type)
return f"{self.map_type(expr.vtype)} {self.translate_expression(expr.name, shader_type)}"
else:
return self.generate_expression(expr.name, shader_type)
return f"{self.map_type(expr.vtype)} {self.generate_expression(expr.name, shader_type)}"
elif isinstance(expr, BinaryOpNode):
return f"{self.generate_expression(expr.left, shader_type)} {self.map_operator(expr.op)} {self.generate_expression(expr.right, shader_type)}"
elif isinstance(expr, FunctionCallNode):
Expand All @@ -287,14 +311,11 @@ def generate_expression(self, expr, shader_type=None):
func_name = self.translate_expression(expr.name, shader_type)
return f"{func_name}({args})"
elif isinstance(expr, UnaryOpNode):
return f"{self.map_operator(expr.op)}{self.generate_expression(expr.operand, shader_type)}"
return f"{self.map_operator(expr.op)} {self.generate_expression(expr.operand, shader_type)}"
elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition, shader_type)} ? {self.generate_expression(expr.true_expr, shader_type)} : {self.generate_expression(expr.false_expr, shader_type)}"
elif isinstance(expr, MemberAccessNode):
if self.lhs:
return f"{expr.member}.{self.generate_expression(expr.object, shader_type)}"
else:
return f"{self.generate_expression(expr.object, shader_type)}.{expr.member}"
return f"{self.generate_expression(expr.object, shader_type)}.{expr.member}"
else:
return str(expr)

Expand All @@ -321,7 +342,10 @@ def translate_expression(self, expr, shader_type):
return self.type_mapping.get(expr, expr)

def map_type(self, vtype):
return self.type_mapping.get(vtype, vtype)
if vtype == "":
return ""
else:
return self.type_mapping.get(vtype, vtype)

def map_operator(self, op):
op_map = {
Expand All @@ -341,5 +365,6 @@ def map_operator(self, op):
"NOT_EQUAL": "!=",
"AND": "&&",
"OR": "||",
"EQUALS": "=",
}
return op_map.get(op, op)
Loading

0 comments on commit 949b041

Please sign in to comment.