Skip to content

Commit

Permalink
Refactor code generation logic for improved readability and maintaina… (
Browse files Browse the repository at this point in the history
#252)

* Refactor code generation logic for improved readability and maintainability

Replaced repeated shader type checks in `generate()` with a dictionary-based approach, improving clarity and reducing redundancy.
- Simplified the `generate_if()` method to handle `else_if_conditions` and `else_if_bodies` more clearly, enhancing readability.
- Refactored `generate_statement()` by using a dictionary-based mapping of node types to their corresponding handlers, reducing conditional branching and improving code organization.
  
These changes streamline the code generation logic, making the codebase easier to understand and extend in the future.

* Refactor code generation logic for improved readability and maintainability

- Replaced repeated shader type checks in `generate()` with a dictionary-based approach, improving clarity and reducing redundancy.
- Simplified the `generate_if()` method to handle `else_if_conditions` and `else_if_bodies` more clearly, enhancing readability.
- Refactored `generate_statement()` by using a dictionary-based mapping of node types to their corresponding handlers, reducing conditional branching and improving code organization.

These changes streamline the code generation logic, making the codebase easier to understand and extend in the future.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed Imports and the errors.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update directx_codegen.py

* Update directx_codegen.py

* Update directx_codegen.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update directx_codegen.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Nripesh Niketan <[email protected]>
  • Loading branch information
Vruddhi18 and NripeshN authored Jan 3, 2025
1 parent aacfd36 commit 0e42f94
Showing 1 changed file with 35 additions and 48 deletions.
83 changes: 35 additions & 48 deletions crosstl/translator/codegen/directx_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,16 @@ def __init__(self):
}

self.semantic_map = {
# Vertex inputs instance
"gl_VertexID": "SV_VertexID",
"gl_InstanceID": "SV_InstanceID",
"gl_IsFrontFace": "FRONT_FACE",
"gl_PrimitiveID": "PRIMITIVE_ID",
"InstanceID": "INSTANCE_ID",
"VertexID": "VERTEX_ID",
# Vertex outputs
"gl_Position": "SV_POSITION",
"gl_PointSize": "SV_POINTSIZE",
"gl_ClipDistance": "SV_ClipDistance",
"gl_CullDistance": "SV_CullDistance",
# Fragment inputs
"gl_FragColor": "SV_TARGET",
"gl_FragColor0": "SV_TARGET0",
"gl_FragColor1": "SV_TARGET1",
Expand All @@ -66,10 +63,6 @@ def __init__(self):
"gl_FragColor6": "SV_TARGET6",
"gl_FragColor7": "SV_TARGET7",
"gl_FragDepth": "SV_DEPTH",
"gl_FragDepth0": "SV_DEPTH0",
"gl_FragDepth1": "SV_DEPTH1",
"gl_FragDepth2": "SV_DEPTH2",
"gl_FragDepth3": "SV_DEPTH3",
}

def generate(self, ast):
Expand All @@ -83,7 +76,6 @@ def generate(self, ast):
code += "}\n"

# Generate global variables

for i, node in enumerate(ast.global_variables):
if node.vtype in ["sampler2D", "samplerCube"]:
code += "// Texture Samplers\n"
Expand All @@ -93,6 +85,7 @@ def generate(self, ast):
code += f"{self.map_type(node.vtype)} {node.name} :register(s{i});\n"
else:
code += f"{self.map_type(node.vtype)} {node.name};\n"

# Generate cbuffers
if ast.cbuffers:
code += "// Constant Buffers\n"
Expand All @@ -106,7 +99,6 @@ def generate(self, ast):
elif func.qualifier == "fragment":
code += "// Fragment Shader\n"
code += self.generate_function(func, shader_type="fragment")

elif func.qualifier == "compute":
code += "// Compute Shader\n"
code += self.generate_function(func, shader_type="compute")
Expand All @@ -132,38 +124,32 @@ def generate_function(self, func, indent=0, shader_type=None):
f"{self.map_type(p.vtype)} {p.name} {self.map_semantic(p.semantic)}"
for p in func.params
)
if shader_type == "vertex":
code += f"{self.map_type(func.return_type)} VSMain({params}) {self.map_semantic(func.semantic)} {{\n"
elif shader_type == "fragment":
code += f"{self.map_type(func.return_type)} PSMain({params}) {self.map_semantic(func.semantic)} {{\n"
elif shader_type == "compute":
code += f"{self.map_type(func.return_type)} CSMain({params}) {self.map_semantic(func.semantic)} {{\n"
shader_map = {"vertex": "VSMain", "fragment": "PSMain", "compute": "CSMain"}

if func.qualifier in shader_map:
code += f"// {func.qualifier.capitalize()} Shader\n"
code += f"{self.map_type(func.return_type)} {shader_map[func.qualifier]}({params}) {{\n"
else:
code += f"{self.map_type(func.return_type)} {func.name}({params}) {self.map_semantic(func.semantic)} {{\n"
code += f"{self.map_type(func.return_type)} {func.name}({params}) {{\n"

for stmt in func.body:
code += self.generate_statement(stmt, 1)
code += "}\n\n"

code += self.generate_statement(stmt, indent + 1)
code += " " * indent + "}\n\n"
return code

def generate_statement(self, stmt, indent=0):
indent_str = " " * indent
if isinstance(stmt, VariableNode):
return f"{indent_str}{self.map_type(stmt.vtype)} {stmt.name};\n"
elif isinstance(stmt, AssignmentNode):
return f"{indent_str}{self.generate_assignment(stmt)};\n"
elif isinstance(stmt, IfNode):
return self.generate_if(stmt, indent)
elif isinstance(stmt, ForNode):
return self.generate_for(stmt, indent)
elif isinstance(stmt, ReturnNode):
code = ""
for i, return_stmt in enumerate(stmt.value):
code += f"{self.generate_expression(return_stmt)}"
if i < len(stmt.value) - 1:
code += ", "
return f"{indent_str}return {code};\n"
statement_handlers = {
VariableNode: lambda stmt: f"{indent_str}{self.map_type(stmt.vtype)} {stmt.name};\n",
AssignmentNode: lambda stmt: f"{indent_str}{self.generate_assignment(stmt)};\n",
IfNode: lambda stmt: self.generate_if(stmt, indent),
ForNode: lambda stmt: self.generate_for(stmt, indent),
ReturnNode: lambda stmt: self.generate_return(stmt, indent),
}

handler = statement_handlers.get(type(stmt))
if handler:
return handler(stmt)
else:
return f"{indent_str}{self.generate_expression(stmt)};\n"

Expand Down Expand Up @@ -199,22 +185,26 @@ def generate_if(self, node, indent):
def generate_for(self, node, indent):
indent_str = " " * indent

init = self.generate_statement(node.init, 0).strip()[
:-1
] # Remove trailing semicolon

condition = self.generate_statement(node.condition, 0).strip()[
:-1
] # Remove trailing semicolon

update = self.generate_statement(node.update, 0).strip()[:-1]
# Extract and remove the trailing semicolon from init, condition, and update expressions
init = self.generate_statement(node.init, 0).strip().rstrip(";")
condition = self.generate_statement(node.condition, 0).strip().rstrip(";")
update = self.generate_statement(node.update, 0).strip().rstrip(";")

code = f"{indent_str}for ({init}; {condition}; {update}) {{\n"
for stmt in node.body:
code += self.generate_statement(stmt, indent + 1)
code += f"{indent_str}}}\n"
return code

def generate_return(self, node, indent):
indent_str = " " * indent
code = ""
for i, return_stmt in enumerate(node.value):
code += f"{self.generate_expression(return_stmt)}"
if i < len(node.value) - 1:
code += ", "
return f"{indent_str}return {code};\n"

def generate_expression(self, expr):
if isinstance(expr, str):
return expr
Expand All @@ -225,12 +215,10 @@ def generate_expression(self, expr):
left = self.generate_expression(expr.left)
right = self.generate_expression(expr.right)
return f"{left} {self.map_operator(expr.op)} {right}"

elif isinstance(expr, AssignmentNode):
left = self.generate_expression(expr.left)
right = self.generate_expression(expr.right)
return f"{left} {self.map_operator(expr.operator)} {right}"

elif isinstance(expr, UnaryOpNode):
operand = self.generate_expression(expr.operand)
return f"{self.map_operator(expr.op)}{operand}"
Expand All @@ -240,7 +228,6 @@ def generate_expression(self, expr):
elif isinstance(expr, MemberAccessNode):
obj = self.generate_expression(expr.object)
return f"{obj}.{expr.member}"

elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition)} ? {self.generate_expression(expr.true_expr)} : {self.generate_expression(expr.false_expr)}"
else:
Expand Down Expand Up @@ -285,7 +272,7 @@ def map_operator(self, op):
return op_map.get(op, op)

def map_semantic(self, semantic):
if semantic is not None:
if semantic:
return f": {self.semantic_map.get(semantic, semantic)}"
else:
return ""
return "" # Handle None by returning an empty string

0 comments on commit 0e42f94

Please sign in to comment.