
FP8 GEMM Kernels #1391

Open
xiaoxiao26 opened this issue Jan 6, 2025 · 0 comments

Comments


xiaoxiao26 commented Jan 6, 2025

After enabling Transformer Engine's FP8 features for PyTorch on H100, the forward pass of my linear layers launches GEMM kernels like sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize128x128x128_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas instead of sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas.
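For context, this is roughly the kind of setup I mean (a minimal sketch, not my actual model; the layer sizes, recipe settings, and profiling code here are just illustrative assumptions):

```python
# Minimal sketch: a Transformer Engine te.Linear run under fp8_autocast on H100,
# profiled with torch.profiler so the cuBLAS GEMM kernel names show up.
# Sizes and recipe settings are arbitrary/illustrative, not from my real model.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from torch.profiler import profile, ProfilerActivity

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

layer = te.Linear(4096, 4096, bias=True).cuda()
x = torch.randn(8192, 4096, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)
    torch.cuda.synchronize()

# Kernel names like sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_... appear in this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```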

The main difference seems to be bf16bf16_bf16f32_f32 -> e4m3bf16_e4m3f32_f32

I'm curious how to interpret this. I thought the pattern was [input_types]_[accumulator_type]_[output_type], but that would imply that one of the weights or activations is in bf16 rather than FP8, whereas my understanding is that both are cast to FP8. I would appreciate it if anyone could help correct my understanding here. Thank you!

Note that I am also using AMP autocast with bf16, so maybe that is affecting things.
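Concretely, the two contexts are nested roughly like this (again an illustrative sketch, not my exact code; the nesting order shown is an assumption):

```python
# Illustrative only: combining AMP bf16 autocast with TE's FP8 autocast.
# Non-TE ops run in bf16 under AMP, while the te.Linear GEMM runs under FP8.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(4096, 4096).cuda()
x = torch.randn(8192, 4096, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True):
        y = layer(x)  # this is where the sm90_xmma_gemm_e4m3... kernel shows up
```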
