[Refactor] Apply latest version of flux and merge all branches from ai-compiler-study/flux #1
base: main
Conversation
Co-authored-by: Neil Movva <[email protected]>
* Remove unused import
* Remove extraneous `f` prefix
Co-authored-by: Emil Sadek <[email protected]>
* apply ruff
* rename
* specify ruff version for CI
* also check imports
* check formatting
if not CHECK_NSFW or nsfw_score < NSFW_THRESHOLD:
    buffer = BytesIO()
I have switched to using CHECK_NSFW instead of the previously commented-out code.
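For context, a minimal sketch of the gating logic that condition implements. Only the `if` condition itself comes from the diff; `maybe_save_image`, the threshold value, and the PNG serialization are illustrative assumptions.

```python
from io import BytesIO

# Assumed module-level flags; the names come from the diff above, the values are illustrative.
CHECK_NSFW = True
NSFW_THRESHOLD = 0.85


def maybe_save_image(image, nsfw_score: float):
    """Serialize the image unless the NSFW filter is enabled and rejects it."""
    if not CHECK_NSFW or nsfw_score < NSFW_THRESHOLD:
        buffer = BytesIO()
        image.save(buffer, format="PNG")  # `image` is assumed to be a PIL.Image
        return buffer.getvalue()
    return None  # rejected by the NSFW filter
```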
class CudaTimer:
    """
    A context manager class for measuring execution time of PyTorch code
    using CUDA events. It synchronizes GPU operations to ensure accurate time measurements.
    """

    def __init__(self, name="", precision=5, display=False):
        self.name = name
        self.precision = precision
        self.display = display

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *exc):
        self.end_event.record()
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds; convert to seconds
        self.elapsed_time = self.start_event.elapsed_time(self.end_event) * 1e-3

        if self.display:
            print(f"{self.name}: {self.elapsed_time:.{self.precision}f} s")

    def get_elapsed_time(self):
        """Returns the elapsed time in seconds."""
        return self.elapsed_time
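For reference, the removed timer was used roughly like this (a hypothetical example; the matmul workload is purely illustrative):

```python
import torch

# Illustrative workload; any CUDA computation can go inside the `with` block.
a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

with CudaTimer("matmul", precision=4, display=True) as timer:
    c = a @ b

print(f"matmul took {timer.get_elapsed_time():.4f} s")
```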
Removed the benchmarking functionality from cli.py, since equivalent capabilities already exist in benchmark/benchmark_flux.py. This keeps cli.py focused on its primary purpose, which is not benchmarking.
if xformers_flash3:
    if torch_sdpa or triton_attention:
        print(
            "Warning: xformers_flash3 is enabled, but torch_sdpa or triton_attention is also enabled. "
            "Please enable only one of them."
        )

    q = q.permute(0, 2, 1, 3)  # B, H, S, D
    k = k.permute(0, 2, 1, 3)  # B, H, S, D
    v = v.permute(0, 2, 1, 3)  # B, H, S, D

    x = _compiled_xformers_flash_hopper(q, k, v).permute(0, 2, 1, 3)
elif torch_sdpa:
    if triton_attention:
        print(
            "Warning: torch_sdpa is enabled, but triton_attention is also enabled. "
            "Please enable only one of them."
        )

    x = scaled_dot_product_attention(q, k, v)
elif triton_attention:
    from triton.ops import attention as attention_triton

    softmax_scale = q.size(-1) ** -0.5
    x = attention_triton(q, k, v, True, softmax_scale)
else:
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
Refactored attention mechanism implementation to improve code maintainability and provide clearer error handling when multiple attention methods are enabled simultaneously. Major changes include:
- Lazy importing of optional dependencies (xformers, triton)
- Added warning messages for conflicting attention method selections
- Made the default attention method explicit (torch.nn.functional.scaled_dot_product_attention)
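To illustrate the lazy-import idea outside the diff, here is a rough sketch; the `attention` helper, its keyword flags, and the `importlib.util.find_spec` guards are assumptions for illustration rather than code from the PR (the triton call mirrors the one in the diff above):

```python
import importlib.util

import torch
import torch.nn.functional as F


def attention(q, k, v, *, use_xformers=False, use_triton=False):
    """Pick one attention backend; q, k, v are (B, H, S, D) tensors.

    Optional backends are imported lazily, so neither xformers nor triton
    has to be installed for the default path to work.
    """
    if use_xformers and importlib.util.find_spec("xformers") is not None:
        import xformers.ops as xops  # lazy import, only on this path

        # xformers expects (B, S, H, D), hence the permutes around the call
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
        return xops.memory_efficient_attention(q, k, v).permute(0, 2, 1, 3)
    if use_triton and importlib.util.find_spec("triton") is not None:
        from triton.ops import attention as attention_triton  # lazy import

        return attention_triton(q, k, v, True, q.size(-1) ** -0.5)
    # Explicit default: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v)
```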
try:
    import triton_kernels
    from triton_kernels import SingleStreamBlock, DoubleStreamBlock
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so a missing
    # triton_kernels package is also handled here.
    print("Triton kernels not found, using flux native implementation.")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
except Exception as e:
    print(f"Error loading triton kernels: {e}; using flux native implementation.")
    from flux.modules.layers import SingleStreamBlock, DoubleStreamBlock
Added a graceful fallback for the triton kernel imports: if the triton kernels are unavailable or fail to load, the code automatically falls back to the native FLUX implementation and prints an appropriate message. This ensures smoother execution across different environments and clearer debugging information.
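If more optional kernels are added later, the same fallback could be factored into a small helper. A sketch, where the helper name and interface are hypothetical and not part of the PR:

```python
import importlib


def import_with_fallback(primary: str, fallback: str, names: list[str]):
    """Return the requested attributes from `primary` if it imports cleanly,
    otherwise from `fallback`, printing why the fallback was taken."""
    try:
        module = importlib.import_module(primary)
    except ImportError:
        print(f"{primary} not found, using {fallback} instead.")
        module = importlib.import_module(fallback)
    except Exception as e:  # any other load failure (e.g. a broken build)
        print(f"Error loading {primary}: {e}; using {fallback} instead.")
        module = importlib.import_module(fallback)
    return tuple(getattr(module, name) for name in names)


# Equivalent to the try/except in the diff above:
SingleStreamBlock, DoubleStreamBlock = import_with_fallback(
    "triton_kernels", "flux.modules.layers", ["SingleStreamBlock", "DoubleStreamBlock"]
)
```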