Skip to content

Commit

Permalink
feat: added a save_shader flag in crosstl
Browse files Browse the repository at this point in the history
  • Loading branch information
samthakur587 committed Aug 18, 2024
1 parent f047863 commit 5adec21
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 113 deletions.
15 changes: 10 additions & 5 deletions crosstl/_crosstl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .src.backend.Opengl import *


def translate(file_path: str, backend: str = "crossgl") -> str:
def translate(file_path: str, backend: str = "cgl", save_shader: str = None) -> str:
backend = backend.lower()

with open(file_path, "r") as file:
Expand All @@ -35,13 +35,10 @@ def translate(file_path: str, backend: str = "crossgl") -> str:
if file_path.endswith(".cgl"):
if backend == "metal":
codegen = metal_codegen.MetalCodeGen()
return codegen.generate(ast)
elif backend == "directx":
codegen = directx_codegen.HLSLCodeGen()
return codegen.generate(ast)
elif backend == "opengl":
codegen = opengl_codegen.GLSLCodeGen()
return codegen.generate(ast)
else:
raise ValueError(f"Unsupported backend for CrossGL file: {backend}")
else:
Expand All @@ -54,8 +51,16 @@ def translate(file_path: str, backend: str = "crossgl") -> str:
codegen = GLSLToCrossGLConverter()
else:
raise ValueError(f"Reverse translation not supported for: {file_path}")
return codegen.generate(ast)
else:
raise ValueError(
f"Unsupported translation scenario: {file_path} to {backend}"
)

# Generate the code and write to the file
generated_code = codegen.generate(ast)

if save_shader is not None:
with open(save_shader, "w") as file:
file.write(generated_code)

return generated_code
1 change: 0 additions & 1 deletion crosstl/src/translator/codegen/directx_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __init__(self):
}

def generate(self, ast):
print(ast)
if isinstance(ast, ShaderNode):
self.current_shader = ast
return self.generate_shader(ast)
Expand Down
13 changes: 5 additions & 8 deletions crosstl/src/translator/codegen/metal_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,20 +346,17 @@ def generate_expression(self, expr, shader_type=None):
elif isinstance(expr, BinaryOpNode):
return f"{self.generate_expression(expr.left, shader_type)} {self.map_operator(expr.op)} {self.generate_expression(expr.right, shader_type)}"
elif isinstance(expr, FunctionCallNode):
if expr.name in ["vec2", "vec3", "vec4"]:
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
if expr.name in self.type_mapping.keys():
return f"{self.map_type(expr.name)}({args})"
else:
args = ", ".join(
self.generate_expression(arg, shader_type) for arg in expr.args
)
func_name = self.translate_expression(expr.name, shader_type)
return f"{func_name}({args})"

elif isinstance(expr, UnaryOpNode):
return f"{self.generate_expression(expr.operand, shader_type)}{self.map_operator(expr.op)}"
return f"{self.map_operator(expr.op)}{self.generate_expression(expr.operand, shader_type)}"

elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition, shader_type)} ? {self.generate_expression(expr.true_expr, shader_type)} : {self.generate_expression(expr.false_expr, shader_type)}"
Expand Down
9 changes: 1 addition & 8 deletions crosstl/src/translator/codegen/opengl_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,6 @@ def generate_for(self, node, indent, shader_type=None):

update = self.generate_statement(node.update, 0, shader_type).strip()[:-1]

if isinstance(node.update, AssignmentNode) and isinstance(
node.update.value, UnaryOpNode
):
update = f"{node.update.value.operand.name}++"
else:
update = self.generate_statement(node.update, 0, shader_type).strip()[:-1]

code = f"{indent_str}for ({init}; {condition}; {update}) {{\n"
for stmt in node.body:
code += self.generate_statement(stmt, indent + 1, shader_type)
Expand All @@ -229,7 +222,7 @@ def generate_expression(self, expr, shader_type=None):
func_name = self.translate_expression(expr.name, shader_type)
return f"{func_name}({args})"
elif isinstance(expr, UnaryOpNode):
return f"{self.generate_expression(expr.operand, shader_type)}{self.map_operator(expr.op)}"
return f"{self.map_operator(expr.op)}{self.generate_expression(expr.operand, shader_type)}"

elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition, shader_type)} ? {self.generate_expression(expr.true_expr, shader_type)} : {self.generate_expression(expr.false_expr, shader_type)}"
Expand Down
Loading

0 comments on commit 5adec21

Please sign in to comment.