[Refactor] Apply latest version of flux and merge all branches from ai-compiler-study/flux #1

Draft
wants to merge 17 commits into main
Conversation

@cmpark0126 commented on Nov 30, 2024

Features

  • Apply the latest version of black-forest-labs/flux
  • Merge ai-compiler-study:main and ai-compiler-study:triton
  • Refactor the code to address the following:
    • Make custom kernel implementations optional while keeping backward compatibility with the original FLUX implementation, so users can choose between kernel versions (see the sketch after this list).
      • e.g., the Triton kernel implemented by @sjjeong94 in src/flux/model.py
      • e.g., XFORMERS_FLASH3 in src/flux/math.py
    • Keep cli.py as a CLI program rather than a benchmarking tool
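
A minimal sketch of what "optional kernel implementations" can look like in practice. The flag names mirror the ones mentioned above, but the env-var wiring is an assumption, not the PR's exact code:

```python
import os

# Hypothetical configuration switches; the PR exposes similar flags such as
# XFORMERS_FLASH3 in src/flux/math.py, but reading them from env vars here is illustrative.
XFORMERS_FLASH3 = os.environ.get("XFORMERS_FLASH3", "0") == "1"
TORCH_SDPA = os.environ.get("TORCH_SDPA", "0") == "1"
TRITON_ATTENTION = os.environ.get("TRITON_ATTENTION", "0") == "1"

# With every switch off, execution falls back to the original FLUX code path,
# which is what preserves backward compatibility.
```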

Comment on lines +246 to +247
if not CHECK_NSFW or nsfw_score < NSFW_THRESHOLD:
    buffer = BytesIO()
Author
I have switched to using CHECK_NSFW instead of the previously commented-out code.
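
A minimal sketch of this flag-based gating; how CHECK_NSFW is actually defined in the PR is an assumption (an env-var toggle is just one option):

```python
import os
from io import BytesIO

CHECK_NSFW = os.environ.get("CHECK_NSFW", "1") == "1"  # hypothetical wiring
NSFW_THRESHOLD = 0.85  # illustrative value

def encode_if_safe(image_bytes: bytes, nsfw_score: float) -> BytesIO | None:
    # Skip the check entirely when CHECK_NSFW is off; otherwise gate on the score.
    if not CHECK_NSFW or nsfw_score < NSFW_THRESHOLD:
        buffer = BytesIO()
        buffer.write(image_bytes)
        return buffer
    return None
```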

Comment on lines -28 to -58
class CudaTimer:
    """
    A context manager class for measuring the execution time of PyTorch code
    using CUDA events. It synchronizes GPU operations to ensure accurate time measurements.
    """

    def __init__(self, name="", precision=5, display=False):
        self.name = name
        self.precision = precision
        self.display = display

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *exc):
        self.end_event.record()
        torch.cuda.synchronize()
        # Convert from ms to s
        self.elapsed_time = self.start_event.elapsed_time(self.end_event) * 1e-3

        if self.display:
            print(f"{self.name}: {self.elapsed_time:.{self.precision}f} s")

    def get_elapsed_time(self):
        """Returns the elapsed time in seconds."""
        return self.elapsed_time

Author

Removed the benchmarking functionality from cli.py, since similar benchmarking capabilities are available in benchmark/benchmark_flux.py. This keeps cli.py aligned with its primary purpose, which is not benchmarking.
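
For reference, the removed timer was used as a context manager; a minimal usage sketch (assumes a CUDA device and the CudaTimer class shown above):

```python
import torch

x = torch.randn(4096, 4096, device="cuda")

# Events are recorded around the block and synchronized on exit, so the
# reported time covers the GPU work rather than just the kernel launch.
with CudaTimer(name="matmul", display=True) as timer:
    y = x @ x

print(f"elapsed: {timer.get_elapsed_time():.5f} s")
```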

Comment on lines 39 to +65
if xformers_flash3:
    if torch_sdpa or triton_attention:
        print(
            "Warning: xformers_flash3 is enabled, but torch_sdpa or triton_attention is also enabled. "
            "Please keep only one of them enabled."
        )

    q = q.permute(0, 2, 1, 3)  # B, H, S, D
    k = k.permute(0, 2, 1, 3)  # B, H, S, D
    v = v.permute(0, 2, 1, 3)  # B, H, S, D

    x = _compiled_xformers_flash_hopper(q, k, v).permute(0, 2, 1, 3)
elif torch_sdpa:
    if triton_attention:
        print(
            "Warning: torch_sdpa is enabled, but triton_attention is also enabled. "
            "Please keep only one of them enabled."
        )

    x = scaled_dot_product_attention(q, k, v)
elif triton_attention:
    from triton.ops import attention as attention_triton

    softmax_scale = q.size(-1) ** -0.5
    x = attention_triton(q, k, v, True, softmax_scale)
else:
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
Author

Refactored the attention implementation to improve maintainability and to provide clearer warnings when multiple attention methods are enabled at the same time. Major changes include:

  • Lazy importing of optional dependencies (xformers, triton); see the sketch after this list
  • Warning messages for conflicting attention method selections
  • An explicit default attention method (torch.nn.functional.scaled_dot_product_attention)
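
A sketch of the lazy-import dispatch pattern described above. This is not the PR's exact code; in particular, the xformers call and the assumed (B, H, S, D) input layout in that branch are assumptions:

```python
import warnings
import torch
import torch.nn.functional as F

def dispatch_attention(q, k, v, *, xformers_flash3=False, torch_sdpa=False, triton_attention=False):
    """Pick one attention backend, warn on conflicts, default to PyTorch SDPA."""
    enabled = [n for n, f in (("xformers_flash3", xformers_flash3),
                              ("torch_sdpa", torch_sdpa),
                              ("triton_attention", triton_attention)) if f]
    if len(enabled) > 1:
        warnings.warn(f"Multiple attention backends enabled ({enabled}); using {enabled[0]}.")

    if xformers_flash3:
        # Lazy import: xformers is only required when this branch runs.
        from xformers.ops import memory_efficient_attention
        # Assumes q/k/v arrive as (B, H, S, D); xformers expects (B, S, H, D).
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
        return memory_efficient_attention(q, k, v).permute(0, 2, 1, 3)
    if triton_attention:
        # Lazy import: only older Triton releases ship triton.ops.attention.
        from triton.ops import attention as attention_triton
        return attention_triton(q, k, v, True, q.size(-1) ** -0.5)
    # torch_sdpa and the default path share the same implementation.
    return F.scaled_dot_product_attention(q, k, v)
```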

Comment on lines +14 to +25
try:
    import triton_kernels
    from triton_kernels import SingleStreamBlock, DoubleStreamBlock
except ImportError:
    print("Triton kernels not found, using flux native implementation.")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
except ModuleNotFoundError:
    print("Triton kernels not found, using flux native implementation.")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
except Exception as e:
    print(f"Error: {e}")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
Author

Added a graceful fallback for the triton kernel imports: if the triton kernels are unavailable or fail to load, the code automatically falls back to the native FLUX implementation and prints an appropriate message. This ensures smoother execution across different environments and clearer debugging information.
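
Note that ModuleNotFoundError is a subclass of ImportError, so the second except clause above never fires. A minimal sketch of the same fallback with that redundancy removed (module names follow the diff; the broad Exception handler is kept for import-time failures inside triton_kernels):

```python
try:
    # Prefer the optional Triton kernels when the package is installed.
    from triton_kernels import SingleStreamBlock, DoubleStreamBlock
except ImportError:
    print("Triton kernels not found, using flux native implementation.")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
except Exception as e:  # e.g. a failure while triton_kernels itself imports
    print(f"Error: {e}")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
```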
