diff --git a/.gitmodules b/.gitmodules index a8e8349e1..8d501cb19 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "csrc/flash_attn/cutlass"] - path = csrc/flash_attn/cutlass +[submodule "csrc/cutlass"] + path = csrc/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/AUTHORS b/AUTHORS index bb78ee50a..e35a78166 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,2 +1 @@ -Tri Dao, trid@stanford.edu -Dan Fu, danfu@cs.stanford.edu \ No newline at end of file +Tri Dao, trid@cs.stanford.edu \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 885bb8b9a..021b4d0f7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,8 +2,10 @@ recursive-include csrc *.cu recursive-include csrc *.h recursive-include csrc *.cuh recursive-include csrc *.cpp +recursive-include csrc *.hpp recursive-include flash_attn *.cu recursive-include flash_attn *.h recursive-include flash_attn *.cuh recursive-include flash_attn *.cpp +recursive-include flash_attn *.hpp diff --git a/README.md b/README.md index 31fc62a6e..829230cd0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # FlashAttention -This repository provides the official implementation of FlashAttention from the -following paper. +This repository provides the official implementation of FlashAttention and +FlashAttention-2 from the +following papers. **FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher RĂ© @@ -8,39 +9,22 @@ Paper: https://arxiv.org/abs/2205.14135 IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. ![FlashAttention](assets/flashattn_banner.jpg) -## Usage +**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** +Tri Dao -We've been very happy to see FlashAttention being widely adopted in such a short -time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md) -contains a partial list of places where FlashAttention is being used. +Paper: https://tridao.me/publications/flash2/flash2.pdf -## Full model code and training script +![FlashAttention-2](assets/flashattention_logo.png) -We have released the full GPT model -[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py). -We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, -cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x -compared to the baseline implementation from Huggingface, reaching up to 189 -TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need -any activation checkpointing). -We also include a training -[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to -train GPT2 on Openwebtext and GPT3 on The Pile. - -## Triton implementation of FlashAttention - -Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -As Triton is a higher-level language than CUDA, it might be easier to understand -and experiment with. The notations in the Triton implementation are also closer -to what's used in our paper. +## Usage -We also have an experimental implementation in Triton that support attention -bias (e.g. 
ALiBi): -https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py +We've been very happy to see FlashAttention being widely adopted in such a short +time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) +contains a partial list of places where FlashAttention is being used. +FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). +Please cite and credit FlashAttention if you use it. ## Installation and features @@ -53,125 +37,116 @@ We recommend the container from Nvidia, which has all the required tools to install FlashAttention. To install: +1. Make sure that PyTorch is installed. +2. Make sure that `packaging` is installed (`pip install packaging`) +3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja +--version` then `echo $?` should return exit code 0). If not (sometimes `ninja +--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall +`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja` +compiling can take a very long time (2h) since it does not use multiple CPU +cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. +4. Then: ```sh -pip install flash-attn +pip install flash-attn --no-build-isolation ``` - Alternatively you can compile from source: ``` python setup.py install ``` -Interface: `src/flash_attention.py` - -To run the benchmark against PyTorch standard attention: -``` -PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py -``` +Interface: `src/flash_attention_interface.py` -FlashAttention currently supports: -1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080). -2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). -3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., - 128). Head dim > 64 backward requires A100 or H100. - -Our tentative roadmap: -1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains]. -2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done]. -3. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done]. -4. ~~[Jun 2022] Support bf16~~[Done]. -5. ~~[Jul 2022] Implement cross-attention~~[Done]. -6. ~~[Jul 2022] Support head dimension 128~~[Done]. -7. ~~[Aug 2022] Fuse rotary embedding~~[Done]. -8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done]. +FlashAttention-2 currently supports: +1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing + GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing + GPUs for now. +2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). +3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. ## How to use FlashAttention -Here's a simple example: -```python -import torch -from flash_attn.flash_attention import FlashMHA - -# Replace this with your correct GPU device -device = "cuda:0" - -# Create attention layer. 
This is similar to torch.nn.MultiheadAttention, -# and it includes the input and output linear layers -flash_mha = FlashMHA( - embed_dim=128, # total channels (= num_heads * head_dim) - num_heads=8, # number of heads - device=device, - dtype=torch.float16, -) - -# Run forward pass with dummy data -x = torch.randn( - (64, 256, 128), # (batch, seqlen, embed_dim) - device=device, - dtype=torch.float16 -) - -output = flash_mha(x)[0] +The main functions implement scaled dot product attention (softmax(Q @ K^T * +softmax_scale) @ V): +``` +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func ``` -Alternatively, you can import the inner attention layer only (so that the input -and output linear layers are not included): -```python -from flash_attn.flash_attention import FlashAttention +``` +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +If Q, K, V are already stacked into 1 tensor, this function will be faster than +calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation +of the gradients of Q, K, V. +Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). +``` -# Create the nn.Module -flash_attention = FlashAttention() +``` +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads +than Q. Note that the number of heads in KV must be divisible by the number of heads in Q. +For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head +0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + +Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). ``` -Or, if you need more fine-grained control, you can import one of the lower-level -functions (this is more similar to the `torch.nn.functional` style): -```python -from flash_attn.flash_attn_interface import flash_attn_unpadded_func +To see how these functions are used in a multi-head attention layer (which +includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). -# or +## Upgrading from FlashAttention (1.x) to FlashAttention-2 -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func +These functions have been renamed: +- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` +- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` +- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` -# etc. 
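As a concrete illustration of the rename, here is a minimal before/after sketch. It is not taken from the repository docs: the `flash_attn.flash_attn_interface` import path and the varlen calling convention follow the benchmark scripts later in this diff, and the sizes are placeholder values.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
device, dtype = "cuda", torch.float16

# Varlen layout: all sequences are concatenated along the first dimension and
# cu_seqlens holds the cumulative sequence-length boundaries (int32, length batch_size + 1).
qkv_unpad = torch.randn(batch_size * seqlen, 3, nheads, headdim, device=device, dtype=dtype)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=device)

# FlashAttention 1.x (old name):
#   out = flash_attn_unpadded_qkvpacked_func(qkv_unpad, cu_seqlens, seqlen, 0.0, causal=True)
# FlashAttention-2 (new name, same argument order as in the benchmark usage):
out = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, seqlen, 0.0, causal=True)
```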
+If the inputs have the same sequence lengths in the same batch, it is simpler +and faster to use these functions: +``` +flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False) ``` - -There are also separate Python files with various FlashAttention extensions: -```python -# Import the triton implementation (torch.nn.functional version only) -from flash_attn.flash_attn_triton import flash_attn_func - -# Import block sparse attention (nn.Module version) -from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention - -# Import block sparse attention (torch.nn.functional version) -from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func +``` +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) ``` -## Speedup and Memory Savings +## Performance We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). We currently have benchmarks for these GPUs: * [A100](#a100) -* [RTX 3090](#rtx-3090) -* [T4](#t4) +* [H100](#h100) + + ### A100 -We display FlashAttention speedup using these parameters (similar to BERT-base): -* Batch size 8 -* Head dimension 64 -* 12 attention heads - -Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K. +We display FlashAttention speedup using these parameters: +* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). +* Sequence length 512, 1k, 2k, 4k, 8k, 16k. +* Batch size set to 16k / seqlen. #### Speedup -![FlashAttention speedup](assets/flashattn_speedup.jpg) - -We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels. -At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking. +![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) #### Memory @@ -182,38 +157,37 @@ Memory savings are proportional to sequence length -- since standard attention h We see 10X memory savings at sequence length 2K, and 20X at 4K. As a result, FlashAttention can scale to much longer sequence lengths. -#### Head Dimension 128 - -![FlashAttention speedup, head dimension 128](assets/flashattn_speedup_a100_d128.jpg) - -We show speedup with head dimension 128. -Here we show batch size 16 with 12 heads. -Speedup is less than with the smaller head sizes, since we have to make the block size smaller in the tiling. -But speedup is still significant, especially with a causal mask. - -### RTX 3090 +### H100 -For the RTX 3090, we use batch size 12 with 12 attention heads. -Memory savings are the same as on an A100, so we'll only show speedup here. +![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) -![FlashAttention speedup GTX 3090](assets/flashattn_speedup_3090.jpg) - -We see slightly higher speedups (between 2.5-4.5x) on the GTX 3090, since memory bandwidth on the GDDR6X is lower than A100 HBM (~900 GB/s vs. ~1.5 TB/s). 
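To sanity-check the speedup plots above, it can help to work out the raw FLOP count for each benchmark setting. The sketch below does only that: the `4 * batch * seqlen**2 * nheads * headdim` formula and the ~312 TFLOPs/s A100 fp16/bf16 peak are the ones used in `benchmarks/benchmark_causal.py` further down in this diff, while the loop over configurations is purely illustrative.

```python
# Forward-pass attention FLOPs: two matmuls (Q @ K^T and P @ V),
# each costing 2 * batch * nheads * seqlen^2 * headdim FLOPs.
def attn_flops(batch_size, seqlen, nheads, headdim):
    return 4 * batch_size * seqlen**2 * nheads * headdim

dim = 2048  # hidden dimension used in the A100/H100 benchmarks above
for headdim in (64, 128):
    nheads = dim // headdim
    for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
        batch_size = 16384 // seqlen  # "batch size set to 16k / seqlen"
        flops = attn_flops(batch_size, seqlen, nheads, headdim)
        ideal_a100_fwd_ms = flops / 312e12 * 1e3  # at the ~312 TFLOPs/s A100 peak
        # A causal mask roughly halves the FLOPs; the benchmark script estimates
        # the backward pass at about 2.5x the forward cost.
        print(f"headdim={headdim} seqlen={seqlen} batch={batch_size}: "
              f"{flops / 1e12:.2f} TFLOPs, ideal A100 fwd ~{ideal_a100_fwd_ms:.3f} ms")
```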
+## Full model code and training script -### T4 +We have released the full GPT model +[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). +We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, +cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x +compared to the baseline implementation from Huggingface, reaching up to 225 +TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need +any activation checkpointing). -We again use batch size 12 with 12 attention heads. +We also include a training +[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to +train GPT2 on Openwebtext and GPT3 on The Pile. -![Flashattention speedup T4](assets/flashattn_speedup_t4.jpg) +## Triton implementation of FlashAttention -T4 SRAM is smaller than the newer GPUs (64 KB), so we see less speedup (we need to make the block sizes smaller, so we end up doing more R/W). -This matches the IO complexity analysis from section 3.2 of [our paper](https://arxiv.org/abs/2205.14135). +Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py -T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass only (note that these are not directly comparable to the graphs above): +As Triton is a higher-level language than CUDA, it might be easier to understand +and experiment with. The notations in the Triton implementation are also closer +to what's used in our paper. -![FlashAttention speedup T4 fwd](assets/flashattn_speedup_t4_fwd.jpg) +We also have an experimental implementation in Triton that support attention +bias (e.g. ALiBi): +https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py -We see speedups between 2.5x-4.5x on the forward pass. ## Tests We test that FlashAttention produces the same output and gradient as a reference @@ -228,21 +202,10 @@ pytest -q -s tests/test_flash_attn.py ``` ## When you encounter issues -This alpha release of FlashAttention contains code written for a research -project to validate ideas on speeding up attention. -We have tested it on several models (BERT, GPT2, ViT). -However, there might still be bugs in the implementation that we hope to iron -out in the next few months. +This new release of FlashAttention-2 have been tested on several GPT-style +models, mostly on A100 GPUs. -If you encounter any of these bugs, please open a respective GitHub Issue! - -## Acknowledgments -Our implementation uses Apex's -[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code -as a starting point. - -We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation -and for his thoughtful answers to our questions about CUDA. +If you encounter any of bugs, please open a respective GitHub Issue! 
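When reporting a problem, a small numerical-parity repro in the spirit of the tests above is usually the most useful thing to attach. Below is a minimal sketch; the fp32 reference is written here purely for comparison (it is not the repo's own reference implementation), and the shapes, dtype, and tolerance check are placeholder choices.

```python
import math
import torch
from flash_attn import flash_attn_func

def attention_ref(q, k, v, causal=False):
    # Plain softmax(Q K^T / sqrt(d)) V in fp32, (batch, seqlen, nheads, headdim) layout.
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1])
    if causal:
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.einsum("bhts,bshd->bthd", torch.softmax(scores, dim=-1), v)

torch.manual_seed(0)
q, k, v = [torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.float16) for _ in range(3)]
out = flash_attn_func(q, k, v, causal=True)
out_ref = attention_ref(q, k, v, causal=True).to(torch.float16)
print((out - out_ref).abs().max().item())  # should stay within fp16-level numerical error
```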
## Citation If you use this codebase, or otherwise found our work valuable, please cite: @@ -253,4 +216,9 @@ If you use this codebase, or otherwise found our work valuable, please cite: booktitle={Advances in Neural Information Processing Systems}, year={2022} } +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning, + author={Dao, Tri}, + year={2023} +} ``` diff --git a/assets/flash2_a100_fwd_bwd_benchmark.png b/assets/flash2_a100_fwd_bwd_benchmark.png new file mode 100644 index 000000000..f529197be Binary files /dev/null and b/assets/flash2_a100_fwd_bwd_benchmark.png differ diff --git a/assets/flash2_h100_fwd_bwd_benchmark.png b/assets/flash2_h100_fwd_bwd_benchmark.png new file mode 100644 index 000000000..41779e0af Binary files /dev/null and b/assets/flash2_h100_fwd_bwd_benchmark.png differ diff --git a/assets/flashattention_logo.png b/assets/flashattention_logo.png new file mode 100644 index 000000000..85e7e5744 Binary files /dev/null and b/assets/flashattention_logo.png differ diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 8226c883a..26f16e3d3 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -6,11 +6,21 @@ from einops import rearrange, repeat -from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func -# from flash_attn.triton.fused_attention import attention as attention -from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func -from flash_attn.flash_attn_triton_og import attention as attention_og +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +# # from flash_attn.triton.fused_attention import attention as attention +# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func +# from flash_attn.flash_attn_triton_og import attention as attention_og + +# from triton.ops.flash_attention import attention as attention_triton + +try: + from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func + from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func +except ImportError: + fav2_qkvpacked_func = None + fav2_kvpacked_func = None try: from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax @@ -71,16 +81,18 @@ def attention_megatron(qkv): torch.manual_seed(0) repeats = 30 batch_size = 2 -seqlen = 4096 +seqlen = 8192 nheads = 12 headdim = 128 +# nheads = 24 +# headdim = 64 # batch_size = 64 # seqlen = 512 # nheads = 8 # headdim = 128 -dropout_p = 0.0 -causal = True -dtype = torch.bfloat16 +dropout_p = 0.1 +causal = False +dtype = torch.float16 device = 'cuda' qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, @@ -88,18 +100,130 @@ def attention_megatron(qkv): cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) -benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... 
-> (b s) ...'), - cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention') -benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, - repeats=repeats, desc='PyTorch Attention') +# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True) +# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad, +# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention') +# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad, +# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True) +# if fav2_qkvpacked_func is not None: + # benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') + # pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True) + +# for dropout_p in [0.1, 0.0]: +# for causal in [False, True]: +# print(f"### {dropout_p = }, {causal = } ###") +# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True) + +# nheads_k = 2 +# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) +# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype, +# requires_grad=True) +# if fav2_kvpacked_func is not None: +# benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') +# pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True) + +# dropout_p = 0.0 +# causal = False +# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, +# repeats=repeats, desc='PyTorch Attention') + +# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton') +# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True) + +# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, +# requires_grad=True) for _ in range(3)] +# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') +# # pytorch_profiler(attention, q, k, v, 1.0, backward=True) + +# if scaled_upper_triang_masked_softmax is not None: +# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') + +# from src.ops.fftconv import fftconv_func + +# dim = nheads * headdim +# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True) +# k = torch.randn(dim, seqlen, device=device, requires_grad=True) +# D = torch.randn(dim, device=device, requires_grad=True) +# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv') +# pytorch_profiler(fftconv_func, u, k, D, backward=True) +# pytorch_profiler(torch.fft.rfft, u.float()) + +flops = 4 * batch_size * seqlen ** 2 * nheads * headdim +ideal_a100_time = flops / 312 / 1e9 +print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms") + + +def time_fwd_bwd(func, *args, **kwargs): + time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) + return time_f[1].mean, time_b[1].mean + +bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +causal_vals = [False, True] +headdim_vals = [64, 128] +dim = 2048 +dropout_p = 0.0 + +time_f = {} +time_b = {} +for causal in causal_vals: + for headdim in headdim_vals: + for batch_size, seqlen in bs_seqlen_vals: + nheads = dim // headdim + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, 
step=seqlen, dtype=torch.int32, + device=qkv.device) + qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True) + f, b = time_fwd_bwd( + flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p, + causal=causal, repeats=repeats, verbose=False + ) + time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f + time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b + + qkv = qkv.detach().requires_grad_(True) + f, b = time_fwd_bwd( + fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f + time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b + + # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, + # requires_grad=True) for _ in range(3)] + # # Try both values of sequence_parallel and pick the faster one + # f, b = time_fwd_bwd( + # attention_triton, q, k, v, causal, headdim**(-0.5), + # False, repeats=repeats, verbose=False + # ) + # _, b0 = time_fwd_bwd( + # attention_triton, q, k, v, causal, headdim**(-0.5), + # True, repeats=repeats, verbose=False + # ) + # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f + # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0) + + if seqlen <= 8 * 1024: + qkv = qkv.detach().requires_grad_(True) + f, b = time_fwd_bwd( + attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + else: + f, b = float('nan'), float('nan') + time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f + time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b -benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton') -pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True) + # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + # requires_grad=True) for _ in range(3)] + # import xformers.ops as xops + # f, b = time_fwd_bwd( + # xops.memory_efficient_attention, q, k, v, + # attn_bias=xops.LowerTriangularMask() if causal else None, + # op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) + # ) + # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f + # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b -q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] -benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') -# pytorch_profiler(attention, q, k, v, 1.0, backward=True) -if scaled_upper_triang_masked_softmax is not None: - benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') +import pickle +with open('flash2_attn_time_h100.plk', 'wb') as fp: + pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 03ae29d1f..e0b63881d 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -8,7 +8,7 @@ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined from flash_attn.bert_padding import unpad_input, pad_input -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False): @@ -62,7 +62,7 @@ def 
attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False): h=nheads).detach().requires_grad_() qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_() -fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func( +fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func( qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal ) benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention') diff --git a/csrc/cutlass b/csrc/cutlass new file mode 160000 index 000000000..c4f6b8c6b --- /dev/null +++ b/csrc/cutlass @@ -0,0 +1 @@ +Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933 diff --git a/csrc/flash_attn/cutlass b/csrc/flash_attn/cutlass deleted file mode 160000 index 319a389f4..000000000 --- a/csrc/flash_attn/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25 diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp new file mode 100644 index 000000000..cc2b54160 --- /dev/null +++ b/csrc/flash_attn/flash_api.cpp @@ -0,0 +1,912 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.o_batch_stride = out.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. 
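+    // scale_softmax multiplies Q K^T before the softmax; scale_softmax_log2 folds in an
+    // extra factor of M_LOG2E = log2(e), so the softmax exponentials can be evaluated with
+    // exp2 instead of exp (e^x == 2^(x * log2(e))), which maps onto the GPU's base-2
+    // exponential instruction.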
+ params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + + params.is_causal = is_causal; +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor out, + const at::Tensor dout, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + bool is_causal) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + nullptr, + softmax_lse_d, + p_dropout, + softmax_scale, + is_causal); + + // Set the pointers and strides. 
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_(params, stream); + }); + }); +} + +std::vector +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, + const float softmax_scale, + const bool is_causal, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, 
seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + if (p_dropout > 0.0) { + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
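+    // philox_cuda_state(counter_offset) hands this launch the generator's current
+    // (seed, offset) pair and advances the stored offset by counter_offset, so the next
+    // kernel that uses dropout draws fresh random numbers; the lock below is needed
+    // because the generator state is shared (see the note referenced in the code).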
+ int64_t counter_offset = params.b * params.h * 32; + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; +} + +std::vector +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); + TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); + TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, 
"batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {p.zero_();} + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + if (p_dropout > 0.0) { + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
+ int64_t counter_offset = params.b * params.h * 32; + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + FP16_SWITCH(!params.is_bf16, [&] { + if (params.d <= 32) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 64) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 96) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 128) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 160) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 192) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 224) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 256) { + run_mha_bwd_(params, stream, configure); + } + }); +} + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + c10::optional gen_) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), 
"Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device"); + TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device"); + TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + if (head_size > 192) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); + } + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device"); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device"); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device"); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(k); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + 
dout_padded = dout; + } + + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + if (loop) { + dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + nullptr, + nullptr, + loop ? dq_accum.data_ptr() : nullptr, + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + auto launch = &run_mha_bwd; + // launch(params, stream, /*configure=*/true); + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
+ int64_t counter_offset = params.b * params.h * 32; + + if (is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + launch(params, stream, /*configure=*/false); + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + c10::optional gen_ +) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + 
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device"); + TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device"); + TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device"); + TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); + TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); + TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + if (head_size > 192) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); + } + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device"); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device"); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device"); + 
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(k); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // bool loop = max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + if (loop) { + dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if( zero_tensors ) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + loop ? dq_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + auto launch = &run_mha_bwd; + // launch(params, stream, /*configure=*/true); + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
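// (Illustrative note: the constant below implies that each (batch, head) slice consumes at
//  most 32 Philox values per thread per call, so advancing the generator offset by
//  b * h * 32 keeps successive dropout calls on disjoint counter ranges. For example, with
//  batch_size = 4 and num_heads = 16 the offset advances by 4 * 16 * 32 = 2048 per call.
//  The "at most 32" reading is inferred from the constant, not verified against the kernels.)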
+ int64_t counter_offset = params.b * params.h * 32; + + if (is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + launch(params, stream, /*configure=*/false); + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); + m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); + m.def("bwd", &mha_bwd, "Backward pass"); + m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)"); +} diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp deleted file mode 100644 index 73d9cdb58..000000000 --- a/csrc/flash_attn/fmha_api.cpp +++ /dev/null @@ -1,796 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include -#include -#include - -#include "fmha.h" - -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - - -void set_params_fprop(FMHA_fprop_params ¶ms, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t h, - const size_t d, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - at::Tensor out, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *o_tmp_d, - void *s_d, - void *softmax_lse_d, - float p_dropout, - float softmax_scale, - bool is_causal, - int num_splits) { - - Data_type acc_type = DATA_TYPE_FP32; - Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16; - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - params.q_row_stride_in_elts = q.stride(0); - params.k_row_stride_in_elts = k.stride(0); - params.v_row_stride_in_elts = v.stride(0); - params.q_head_stride_in_elts = q.stride(1); - params.k_head_stride_in_elts = k.stride(1); - params.v_head_stride_in_elts = v.stride(1); - params.o_ptr = out.data_ptr(); - params.o_row_stride_in_elts = out.stride(0); - params.o_head_stride_in_elts = out.stride(1); - params.o_tmp_ptr = o_tmp_d; - params.o_tmp_row_stride_in_elts = h * d; - params.o_tmp_head_stride_in_elts = d; - - params.cu_seqlens_q = static_cast(cu_seqlens_q_d); - params.cu_seqlens_k = static_cast(cu_seqlens_k_d); - - // S = softmax(P) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type); - - // Softmax sum - params.softmax_lse_ptr = softmax_lse_d; - - // Set the dimensions. - params.b = b; - params.h = h; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.d = d; - - // Set the different scale values. - // const float scale_bmm1 = 1.f / sqrtf(d); - const float scale_bmm1 = softmax_scale; - - params.scale_bmm1f = scale_bmm1; - set_alpha(params.scale_bmm1, scale_bmm1, data_type); - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. 
- // [Minor] We want to round down since when we do the comparison we use <= instead of < - params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f; - TORCH_CHECK(p_dropout < 1.f); - set_alpha(params.scale_dropout, params.rp_dropout, data_type); - - params.is_causal = is_causal; - params.num_splits = num_splits; -} - -void set_params_dgrad(FMHA_dgrad_params ¶ms, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t h, - const size_t d, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor out, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *dq_tmp_d, - void *do_packed_d, - void *softmax_lse_d, - void *dsoftmax_sum_d, - float p_dropout, - float softmax_scale, - bool is_causal, - int num_splits) { - - set_params_fprop(params, - b, seqlen_q, seqlen_k, h, d, - q, k, v, out, - cu_seqlens_q_d, - cu_seqlens_k_d, - dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp - nullptr, - softmax_lse_d, - p_dropout, - softmax_scale, - is_causal, - num_splits); - - // Set the pointers and strides. - params.dq_ptr = dq.data_ptr(); - params.dk_ptr = dk.data_ptr(); - params.dv_ptr = dv.data_ptr(); - params.dq_row_stride_in_elts = dq.stride(0); - params.dk_row_stride_in_elts = dk.stride(0); - params.dv_row_stride_in_elts = dv.stride(0); - params.dq_head_stride_in_elts = dq.stride(1); - params.dk_head_stride_in_elts = dk.stride(1); - params.dv_head_stride_in_elts = dv.stride(1); - params.do_ptr = do_packed_d; - - // Softmax sum - params.dsoftmax_sum = dsoftmax_sum_d; -} - -void run_fmha_fwd(Launch_params &launch_params) { - if (launch_params.params.d <= 32) { - run_fmha_fwd_hdim32(launch_params); - } else if (launch_params.params.d <= 64) { - run_fmha_fwd_hdim64(launch_params); - } else if (launch_params.params.d <= 128) { - run_fmha_fwd_hdim128(launch_params); - } -} - -std::vector -mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - at::Tensor &out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q_, - const int max_seqlen_k_, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const bool return_softmax, - const int num_splits, - c10::optional gen_) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90 || is_sm8x || is_sm75); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - bool is_dropout = p_dropout > 0.0; - Launch_params launch_params(dprops, stream, is_dropout, return_softmax); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16)); - TORCH_CHECK(k.dtype() == q_dtype); - TORCH_CHECK(v.dtype() 
== q_dtype); - TORCH_CHECK(out.dtype() == q_dtype); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32); - - TORCH_CHECK(q.is_cuda()); - TORCH_CHECK(k.is_cuda()); - TORCH_CHECK(v.is_cuda()); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(cu_seqlens_q.is_cuda()); - TORCH_CHECK(cu_seqlens_k.is_cuda()); - - TORCH_CHECK(q.stride(-1) == 1); - TORCH_CHECK(k.stride(-1) == 1); - TORCH_CHECK(v.stride(-1) == 1); - TORCH_CHECK(out.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_q.is_contiguous()); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); - - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - const int total_q = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - const int total_k = k.size(TOTAL_DIM); - TORCH_CHECK(batch_size > 0); - TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128)); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads, head_size); - CHECK_SHAPE(v, total_k, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - int blocksize_c = head_size > 64 ? 128 : 256; - // Need to round max_seqlen_k to multiples of blocksize_c - int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; - if( max_seqlen_k_ <= 128 ) { - max_seqlen_k = 128; - } else if( max_seqlen_k_ <= 256 ) { - max_seqlen_k = 256; - } - int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; - bool loop = max_seqlen_k > blocksize_c; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - - // auto o = torch::empty({ total_q, num_heads, head_size }, opts); - - at::Tensor o_tmp; - if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); } - - auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits::infinity(), opts.dtype(at::kFloat)); - - at::Tensor s; - if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); } - - if( zero_tensors ) { - out.zero_(); - softmax_lse.fill_(-std::numeric_limits::infinity()); - if (return_softmax) {s.zero_();} - } - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - set_params_fprop(launch_params.params, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - head_size, - q, k, v, out, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? o_tmp.data_ptr() : nullptr, - return_softmax ? s.data_ptr() : nullptr, - softmax_lse.data_ptr(), - p_dropout, - softmax_scale, - is_causal, - num_splits); - - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. 
- launch_params.params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); - } - - run_fmha_fwd(launch_params); - - std::vector result = {softmax_lse}; - result.push_back(rng_state); - if (return_softmax) {result.push_back(s);} - return result; -} - -void run_fmha_bwd(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - if (params.d <= 32) { - run_fmha_bwd_hdim32(params, stream, configure); - } else if (params.d <= 64) { - run_fmha_bwd_hdim64(params, stream, configure); - } else if (params.d <= 128) { - run_fmha_bwd_hdim128(params, stream, configure); - } -} - -std::vector -mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp - at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q_, - const int max_seqlen_k_, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const int num_splits, - c10::optional gen_, - c10::optional &rng_state -) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90 || is_sm8x || is_sm75); - auto launch = &run_fmha_bwd; - - bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16)); - TORCH_CHECK(k.dtype() == q_dtype); - TORCH_CHECK(v.dtype() == q_dtype); - TORCH_CHECK(out.dtype() == q_dtype); - TORCH_CHECK(dout.dtype() == q_dtype); - TORCH_CHECK(dq.dtype() == q_dtype); - TORCH_CHECK(dk.dtype() == q_dtype); - TORCH_CHECK(dv.dtype() == q_dtype); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32); - - TORCH_CHECK(q.is_cuda()); - TORCH_CHECK(k.is_cuda()); - TORCH_CHECK(v.is_cuda()); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - TORCH_CHECK(softmax_lse_.is_cuda()); - TORCH_CHECK(cu_seqlens_q.is_cuda()); - TORCH_CHECK(cu_seqlens_k.is_cuda()); - - TORCH_CHECK(q.stride(-1) == 1); - TORCH_CHECK(k.stride(-1) == 1); - TORCH_CHECK(v.stride(-1) == 1); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(dout.is_contiguous()); - TORCH_CHECK(dq.stride(-1) == 1); - TORCH_CHECK(dk.stride(-1) == 1); - TORCH_CHECK(dv.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_q.is_contiguous()); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); - - const auto sizes = q.sizes(); - - const int 
batch_size = cu_seqlens_q.numel() - 1; - const int total_q = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - const int total_k = k.size(TOTAL_DIM); - TORCH_CHECK(batch_size > 0); - TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128)); - if (head_size > 64) { - TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 64 requires A100 or H100 GPUs as the implementation needs a large amount of shared memory."); - } - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads, head_size); - CHECK_SHAPE(v, total_k, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); - CHECK_SHAPE(dq, total_q, num_heads, head_size); - CHECK_SHAPE(dk, total_k, num_heads, head_size); - CHECK_SHAPE(dv, total_k, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256; - int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; - if( max_seqlen_k_ <= 128 ) { - max_seqlen_k = 128; - } else if( max_seqlen_k_ <= 256 ) { - max_seqlen_k = 256; - } - int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; - bool loop = max_seqlen_k > blocksize_c; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different. - auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous(); - - auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor dq_tmp; - if (loop) { dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); } - - if( zero_tensors ) { - dq.zero_(); - dk.zero_(); - dv.zero_(); - softmax_d.zero_(); - } - - FMHA_dgrad_params params; - - set_params_dgrad(params, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - head_size, - q, k, v, out, - dq, dk, dv, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? dq_tmp.data_ptr() : nullptr, - dout.data_ptr(), - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - p_dropout, - softmax_scale, - is_causal, - num_splits); - - launch(params, stream, /*configure=*/true); - - if (params.num_splits > 1) { - if (!dq_tmp.defined()) { - dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); - params.o_tmp_ptr = dq_tmp.data_ptr(); // o_tmp stores dq_tmp in the backward pass - } else { - dq_tmp.zero_(); - } - } - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
- int64_t counter_offset = params.b * params.h * 32; - if ( rng_state.has_value() ) { - params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); - } else if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - auto seeds = at::cuda::philox::unpack(params.philox_args); - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); - } - - launch(params, stream, /*configure=*/false); - - if (params.num_splits > 1) { - dq.copy_(dq_tmp); - } - - return { dq, dk, dv, softmax_d }; -} - -std::vector -mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16) - const int max_seqlen_q_, - const int max_seqlen_k_, - const float p_dropout, - const float softmax_scale, - const bool is_causal, - const bool return_softmax, - c10::optional gen_) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm8x || is_sm90); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - bool is_dropout = p_dropout > 0.0; - Launch_params launch_params(dprops, stream, is_dropout, return_softmax); - - TORCH_CHECK(q.dtype() == torch::kFloat16); - TORCH_CHECK(k.dtype() == torch::kFloat16); - TORCH_CHECK(v.dtype() == torch::kFloat16); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32); - TORCH_CHECK(blockmask.dtype() == torch::kInt32); - - TORCH_CHECK(q.is_cuda()); - TORCH_CHECK(k.is_cuda()); - TORCH_CHECK(v.is_cuda()); - TORCH_CHECK(cu_seqlens_q.is_cuda()); - TORCH_CHECK(cu_seqlens_k.is_cuda()); - TORCH_CHECK(blockmask.is_cuda()) - - TORCH_CHECK(q.stride(-1) == 1); - TORCH_CHECK(k.stride(-1) == 1); - TORCH_CHECK(v.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); - TORCH_CHECK(blockmask.is_contiguous()) - - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - const int total_q = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - const int total_k = k.size(TOTAL_DIM); - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads, head_size); - CHECK_SHAPE(v, total_k, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; - if( max_seqlen_k <= 256 ) { - max_seqlen_k = 256; - } - int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; - bool loop = max_seqlen_k > 256; - CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16); - - auto opts = q.options(); - - auto o = torch::zeros({ total_q, num_heads, head_size }, opts); - - at::Tensor o_tmp; - if (loop) { - // o_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat)); - o_tmp = 
torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); - } - - // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits::infinity(), opts.dtype(at::kFloat)); - auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - - at::Tensor s; - if (return_softmax) { - s = torch::zeros({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); - } - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - set_params_fprop(launch_params.params, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - head_size, - q, k, v, o, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? o_tmp.data_ptr() : nullptr, - return_softmax ? s.data_ptr() : nullptr, - softmax_lse.data_ptr(), - p_dropout, - softmax_scale, - is_causal, - /*num_splits=*/1); - launch_params.params.blockmask = static_cast(blockmask.data_ptr()); - - run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true); - // number of times random will be generated per thread, to offset philox counter in thc random - // state - int64_t counter_offset = launch_params.elts_per_thread; - - if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); - } - - run_fmha_block_fp16_sm80(launch_params, /*configure=*/false); - - std::vector result = {o, softmax_lse}; - if (return_softmax) {result.push_back(s);} - return result; -} - -std::vector -mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp - at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16) - const int max_seqlen_q_, - const int max_seqlen_k_, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool is_causal, - c10::optional gen_ -) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm8x || is_sm90); - auto launch = &run_fmha_block_dgrad_fp16_sm80; - - bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK(q.dtype() == torch::kFloat16); - TORCH_CHECK(k.dtype() == torch::kFloat16); - TORCH_CHECK(v.dtype() == torch::kFloat16); - TORCH_CHECK(out.dtype() == torch::kFloat16); - TORCH_CHECK(dout.dtype() == torch::kFloat16); - TORCH_CHECK(dq.dtype() == torch::kFloat16); - TORCH_CHECK(dk.dtype() == torch::kFloat16); - TORCH_CHECK(dv.dtype() == torch::kFloat16); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32); - TORCH_CHECK(cu_seqlens_k.dtype() == 
torch::kInt32); - TORCH_CHECK(blockmask.dtype() == torch::kInt32); - - TORCH_CHECK(q.is_cuda()); - TORCH_CHECK(k.is_cuda()); - TORCH_CHECK(v.is_cuda()); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - TORCH_CHECK(softmax_lse_.is_cuda()); - TORCH_CHECK(cu_seqlens_q.is_cuda()); - TORCH_CHECK(cu_seqlens_k.is_cuda()); - TORCH_CHECK(blockmask.is_cuda()); - - TORCH_CHECK(q.stride(-1) == 1); - TORCH_CHECK(k.stride(-1) == 1); - TORCH_CHECK(v.stride(-1) == 1); - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(dout.is_contiguous()); - TORCH_CHECK(dq.stride(-1) == 1); - TORCH_CHECK(dk.stride(-1) == 1); - TORCH_CHECK(dv.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_q.is_contiguous()); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); - TORCH_CHECK(blockmask.is_contiguous()); - - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - const int total_q = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - const int total_k = k.size(TOTAL_DIM); - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); - if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well - TORCH_CHECK(is_sm80 || is_sm90); - } - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads, head_size); - CHECK_SHAPE(v, total_k, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); - CHECK_SHAPE(dq, total_q, num_heads, head_size); - CHECK_SHAPE(dk, total_k, num_heads, head_size); - CHECK_SHAPE(dv, total_k, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; - if( max_seqlen_k <= 256 ) { - max_seqlen_k = 256; - } - int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; - bool loop = max_seqlen_k > 256; - CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16); - - // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different. - auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous(); - - auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor dq_tmp; - if (loop) { - // dq_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat)); - dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); - } - - FMHA_dgrad_params params; - - set_params_dgrad(params, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - head_size, - q, k, v, out, - dq, dk, dv, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? dq_tmp.data_ptr() : nullptr, - dout.data_ptr(), - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - p_dropout, - softmax_scale, - is_causal, - /*num_splits=*/1); - params.blockmask = static_cast(blockmask.data_ptr()); - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - // We're gonna reset the rng state in Python after this kernel, so the counter offset - // here doesn't matter at all. 
We just choose an arbitrary number; - int64_t counter_offset = 4; - - if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - - launch(params, stream); - return { dq, dk, dv, softmax_d }; -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention"; - m.def("fwd", &mha_fwd, "Forward pass"); - m.def("bwd", &mha_bwd, "Backward pass"); - m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)"); - m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)"); -} diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h new file mode 100644 index 000000000..94251a41e --- /dev/null +++ b/csrc/flash_attn/src/block_info.h @@ -0,0 +1,41 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const uint32_t actual_seqlen_q; + const uint32_t actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h new file mode 100644 index 000000000..cb0a57dff --- /dev/null +++ b/csrc/flash_attn/src/flash.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include + + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. 
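// (Worked example: for grouped-query attention with 32 query heads sharing 4 key/value
//  heads, h = 32, h_k = 4, and h_h_k_ratio = h / h_k = 8, i.e. each key/value head serves
//  8 query heads; ordinary multi-head attention has h == h_k and a ratio of 1.)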
+ int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + int *__restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Random state. + at::PhiloxCudaState philox_args; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); + +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..bec30a8d5 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
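The per-head-dimension translation units that follow each provide one explicit specialization of `run_mha_bwd_`, which keeps per-file compile times short. As a rough sketch only (the switch thresholds and the `<element type, head dim>` template-argument order are assumptions for illustration, not taken from this patch), a caller-side dispatcher over these specializations could look like:

#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include "flash.h"

// Hypothetical dispatcher (illustration only): picks the explicit instantiation that
// matches the runtime head dimension and element type of the backward pass.
inline void run_mha_bwd_sketch(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (!params.is_bf16) {
        if (params.d <= 32)       { run_mha_bwd_<cutlass::half_t, 32>(params, stream, configure); }
        else if (params.d <= 64)  { run_mha_bwd_<cutlass::half_t, 64>(params, stream, configure); }
        else if (params.d <= 128) { run_mha_bwd_<cutlass::half_t, 128>(params, stream, configure); }
        else                      { run_mha_bwd_<cutlass::half_t, 256>(params, stream, configure); }
        // ... 96/160/192/224 would slot in analogously.
    } else {
        if (params.d <= 128) { run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream, configure); }
        else                 { run_mha_bwd_<cutlass::bfloat16_t, 256>(params, stream, configure); }
        // ... same ladder for the remaining bf16 head dims.
    }
}

Whatever the real entry point does, the effect is the same: each supported head dimension and dtype resolves to exactly one of the specializations defined in the files below.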
+ +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::bfloat16_t; +// if (params.h == params.h_k) { +// run_flash_bwd>(params, stream, configure); +// } else { +// run_flash_bwd_seqq_parallel>(params, stream, configure); +// } +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim128(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..1de5b16a0 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::half_t; +// if (params.h == params.h_k) { +// // run_flash_bwd>(params, stream, configure); +// // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). +// // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. +// // run_flash_bwd>(params, stream, configure); +// run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// } else { +// run_flash_bwd_seqq_parallel>(params, stream, configure); +// } +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim128(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..de9007ded --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim160(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..9adf32b61 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim160(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..d9859500c --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim192(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..8179d75f1 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim192(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..0ce28e3f4 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim224(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..ab1eb6c96 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim224(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..2bc48001a --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim256(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..9f76c58bc --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim256(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..4b6ecb40a --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::bfloat16_t; +// run_flash_bwd>(params, stream, configure); +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim32(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..f4ac6c582 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::half_t; +// run_flash_bwd>(params, stream, configure); +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim32(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..7307344eb --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::bfloat16_t; +// run_flash_bwd>(params, stream, configure); +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim64(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..7a2f8ecc3 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu @@ -0,0 +1,35 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::half_t; +// // Changing AtomLayoutMdQ from 2 to 4 takes the same time +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
+// run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); + +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// // run_flash_bwd>(params, stream, configure); +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim64(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..2f5d9aaeb --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::bfloat16_t; +// if (params.h == params.h_k) { +// run_flash_bwd>(params, stream, configure); +// } else { +// run_flash_bwd_seqq_parallel>(params, stream, configure); +// } +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim96(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..9a1d88a68 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +// template<> +// void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +// using elem_type = cutlass::half_t; +// if (params.h == params.h_k) { +// // run_flash_bwd>(params, stream, configure); +// // This is very slightly faster +// run_flash_bwd>(params, stream, configure); +// } else { +// run_flash_bwd_seqq_parallel>(params, stream, configure); +// } +// } + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + run_mha_bwd_hdim96(params, stream, configure); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h new file mode 100644 index 000000000..74f25ba45 --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -0,0 +1,1519 @@ +/*************************************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "philox.cuh" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; + // Divide by 2 because right now we always use 2 for the ValLayout + constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + // This gives the correct layout, idk why. + // auto t = make_tile(Layout, _2>, + // Stride, _8> >{}, + // auto t = make_tile(Layout, + // Stride<_1, _64, _8> >{}, + auto t = make_tile(Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) + Stride<_1, Int, _8> >{}, // (1, 64, 8) or (1, 32, 8) + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; + // Divide by 2 because right now we always use 2 for the ValLayout + constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + auto t = make_tile(make_layout(size<0>(TileShape_MNK{})), + Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) + Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, + Tensor &dP_sum, Tensor &sdPsum, + const int gdP_col_stride, const float scale) { + static_assert(Layout0::rank == 3, "Only support 3D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); + // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) + // The last coordinate is the "page". 
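// (Put differently: after this re-view, the mi loop below walks the rows of the block owned
//  by this thread and the ni loop walks this thread's share of each row's elements, so every
//  thread accumulates a partial row-wise dot product of dO and O. Allreduce<THREADS_PER_ROW>
//  then sums the partials across the threads sharing a row, the result is multiplied by
//  `scale`, and one thread per row writes it to dP_sum.)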
+ Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), + make_layout(get<0>(do_.layout()), + get<2>(do_.layout())))); + Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); + Tensor do_fp32 = flash::convert_type(do_reshaped); + Tensor o_fp32 = flash::convert_type(o_reshaped); + #pragma unroll + for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(do_reshaped); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + flash::SumOp sum_op; + dP_sum_cur = flash::Allreduce::run(dP_sum_cur, sum_op) * scale; + if (threadIdx.x % THREADS_PER_ROW == 0) { + dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; + // recast(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx); + // TODO: careful, we're zeroing out dQaccum with type float4, but when + // we do atomicAdds, we use type float. The layouts are different. Check this. 
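+    // Note on the TODO above: the clear only writes zero bits, and an all-zero bit pattern is
+    // +0.0f for every float packed inside a float4, so the element width used for the clear
+    // cannot change what the later float atomicAdds observe. What does have to line up is the
+    // set of elements covered, i.e. both tilings should span the same kBlockM x d_rounded tile
+    // of dQaccum. Illustration (plain CUDA, hypothetical buffer):
+    //     reinterpret_cast<float4 *>(dq_accum)[i] = make_float4(0.f, 0.f, 0.f, 0.f);
+    //     // ...later, possibly from another thread block:
+    //     atomicAdd(&dq_accum[4 * i + j], partial_dq);   // sees exactly 0.0f before the add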
+ auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx); + + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum); + + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); + + // Allocate predicate tensors for k + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); + // Set predicates for k bounds + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + flash::copy( + gmem_thr_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + flash::copy( + gmem_thr_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final + // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, + // so that (dP - dP_sum) is on the same scale. + dot_do_o(tdOrdO, tdOrO, dP_sum, dP_sum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + if (Clear_dQaccum) { + Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + copy(gmem_thr_copy_dQ_accum, zero, tdQgdQaccum); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear_dKVaccum(const Params ¶ms) { + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + + auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx); + Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_D(gdVaccum); + Tensor zero = make_fragment_like(tdKgdKaccum); + clear(zero); + copy(gmem_thr_copy_dKV_accum, zero, tdKgdKaccum); + copy(gmem_thr_copy_dKV_accum, zero, tdVgdVaccum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dQ from dQaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void convert_dQ(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.d_rounded; + + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + Stride, _1>{}); + + Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQ{}); + + auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx); + auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_S(gdQaccum); + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + + Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); + copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { + acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout; + } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = flash::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ); + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_q. 
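+// In that mode each query-block kernel accumulates its partial dK/dV contributions into the
+// fp32 buffers dk_accum_ptr / dv_accum_ptr with atomicAdd (see compute_dq_dk_dv_1rowblock
+// below), and this kernel runs afterwards to rescale those accumulators and cast them to the
+// output dtype. Rough two-pass shape of the scheme (pseudocode, not the actual launch code):
+//     // pass 1: one thread block per query tile
+//     //     atomicAdd(&dK_accum[n][k], partial_dk);  atomicAdd(&dV_accum[n][k], partial_dv);
+//     // pass 2: one thread block per key tile (this kernel)
+//     //     dK[n][k] = (Element)(dK_accum[n][k] * scale_dk);  dV[n][k] = (Element)(dV_accum[n][k] * scale_dv);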
+template +inline __device__ void convert_dKV(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + + n_block * kBlockN) * params.d_rounded; + + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + + Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdKV{}); + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx); + auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_S(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_S(gdVaccum); + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); + CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); + + Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); + Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); + copy(gmem_thr_copy_dKV_accum, tdKgdKaccum, tdKrdKaccum); + copy(gmem_thr_copy_dKV_accum, tdVgdVaccum, tdVrdVaccum); 
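+    // The fp32 accumulators are rescaled before the cast below. Judging by the field names,
+    // rp_dropout is 1 / (1 - p_dropout) and scale_softmax_rp_dropout additionally folds in the
+    // softmax scale; a scalar sketch of the intent (not the exact definitions in the params struct):
+    //     float rp_dropout               = 1.f / (1.f - p_dropout);
+    //     float scale_softmax_rp_dropout = softmax_scale * rp_dropout;
+    //     dk = dk_accum * scale_softmax_rp_dropout;   // dK also carries the softmax scale
+    //     dv = dv_accum * rp_dropout;                 // dV only undoes the dropout scaling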
+ #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { + acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; + } + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { + acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; + } + // Convert acc_dk from fp32 to fp16 + Tensor rdK = flash::convert_type(acc_dk); + Tensor rdV = flash::convert_type(acc_dv); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK); + copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV); + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK); + copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV); + + Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return; + + int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + (m_block_max - 1) * kBlockM) * params.d_rounded; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + + (m_block_max - 1) * kBlockM; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + + (m_block_max - 1) * kBlockM; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + 
typename Kernel_traits::SmemLayoutQdO{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + // Double buffer for sQ + Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); + Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); + Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK), + typename Kernel_traits::SmemLayoutPdS{}); + Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + // sP and sdQ share the same memory so be careful + Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); + Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast((sP.data() + cute::max(size(sP), size(sdQ))).get())), + Shape>{}); + + auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); + using GmemTiledCopydO = std::conditional_t< + Is_first, + typename Kernel_traits::GmemTiledCopydO, + typename Kernel_traits::GmemTiledCopyQKV + >; + auto gmem_thr_copy_dO = GmemTiledCopydO{}.get_thread_slice(tidx); + auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx); + using GmemLayoutAtomdQaccum = std::conditional_t< + !Seq_parallel, + typename Kernel_traits::GmemTiledCopydQaccum, + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd + >; + auto gmem_thr_copy_dQ_accum = GmemLayoutAtomdQaccum{}.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum); + // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); } + // __syncthreads(); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) { + // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data()); + // } + + typename Kernel_traits::TiledMmaSdP 
tiled_mma_sdp; + auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); + Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) + Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) + Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); + Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); + Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) + Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + + // + // Copy Atom retiling + // + + auto smem_thr_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_thr_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + // Partition sP and sdS to match the accumulator partitioning + // This has to be tiled_mma_sdp, not tiled_mma_dkv + // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_thr_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); } + // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) { + // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data()); + // } + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + auto smem_thr_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_thr_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto 
smem_thr_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_thr_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // + // PREDICATES + // + + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We'll advance gdQ and gdQaccum before the 1st read/write. + tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; + tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded; + + int m_block = m_block_max - 1; + int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM; + + // We might need to exit early and write 0 to dK and dV. + // Otherwise we get wrong result for the case where we don't enter the for loop. + // And we might read OOB elements from gQ and gdO. + // TODO: what if we're not parallelizing, do we need to compute dot_do_o? 
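+    // For the causal case, m_block_min = (n_block * kBlockN) / kBlockM is the first query block
+    // that can attend to (and hence send gradient to) this key block: query row i only sees key
+    // rows j <= i, so key rows starting at n_block * kBlockN get nothing from earlier query rows.
+    // Worked example (hypothetical tile sizes): n_block = 3, kBlockN = 128, kBlockM = 64 gives
+    // m_block_min = 384 / 64 = 6, i.e. query rows 384..447 are the first that can attend into
+    // keys 384..511. If even the last query block sits below m_block_min, every dK/dV entry of
+    // this key block is exactly zero, which is what the early-exit branch below writes out.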
+ if (Is_causal && m_block < m_block_min) { + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx); + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + clear(tdKrdK); + clear(tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + return; + } + + if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ + tQsQ.data() = tQsQ.data() + size(sQ); + tSsQ.data() = tSsQ.data() + size(sQ); + tdKsQt.data() = tdKsQt.data() + size(sQ); + } + + if (!Is_first && !Seq_parallel) { __syncthreads(); } + + if (Kernel_traits::Is_V_in_regs) { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::cp_async_fence(); + } + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + if (!Is_first) { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } else { + flash::copy( + gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + flash::copy( + gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } + flash::copy( + gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0>(taccScS))::value == 4); + // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. + Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + // Using uint32_t row makes it 10us slower on d=128, not sure why. + const int row = get<0>(taccScS_row(mi)); + lse(mi) = Is_even_M || row < binfo.actual_seqlen_q - m_block * kBlockM ? 
gLSE(row) : 0; + } + + // Tensor tKrK = make_fragment_like(tKsK); + // // copy(gmem_thr_copy_QKV, tKgK(_, _, _, 0), tKrK); + // copy(gmem_thr_copy_QKV, tKgK, tKrK); + // // if (cute::thread(1, 0)) { print(tKrK); } + + flash::copy( + gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + if (!Kernel_traits::Is_V_in_regs) { + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + flash::cp_async_fence(); + + // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } + if (Is_first) { + copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, sdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Kernel_traits::Is_V_in_regs) { + cute::cp_async_wait<1>(); + __syncthreads(); + Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); + CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M + copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view); + } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_dv); + clear(acc_dk); + + for (; m_block >= m_block_min; --m_block) { + Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + clear(acc_s); + cute::cp_async_wait<0>(); + __syncthreads(); + + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } + + // if (cute::thread0()) { print(sK); } + // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); + // #pragma unroll + // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { + // copy(smem_thr_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); + // } + // if (cute::thread0()) { print(tSrK); } + flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread(32, 0)) { print(scores); } + // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would + // be some finite value for those indices. In the end when we multiply with K to get dQ, + // the corresponding values of K would be 0, so the result would still be correct. + // Putting this causal masking right after acc_s is *much* slower for some reason. + if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) { + flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16); + } + // if (cute::thread(32, 0)) { print(scores); } + // Compute the exponential value. 
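+        // P is recomputed from the log-sum-exp saved by the forward pass rather than by redoing
+        // the softmax: P_ij = exp(S_ij * softmax_scale - LSE_i), evaluated with exp2, with
+        // scale_softmax_log2 presumably equal to softmax_scale * log2(e). Scalar sketch
+        // (illustrative only):
+        //     float recompute_p(float s, float lse, float scale_softmax_log2) {
+        //         return exp2f(s * scale_softmax_log2 - lse * (float)M_LOG2E);
+        //     }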
+ flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + if (Is_dropout) { + uint32_t warp_id = tidx / 32; + uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 + static_assert(MMA_N_SdP % 2 == 0); + uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); + flash::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS + ); + } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = !Is_dropout + ? flash::convert_type(scores) + : flash::convert_type_relu(scores); + // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. + Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_PdS, tPaP, tPsP); + // if (cute::thread0()) { print(tPaP); } + // __syncthreads(); + // if (cute::thread0()) { print(sP); } + + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + + clear(acc_dp); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), flash::convert_layout_acc_rowcol(acc_dp.layout())); + // #pragma unroll + // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { + // #pragma unroll + // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { + // acc_dp_reshaped(mi, ni) = -dP_sum(mi); + // } + // } + + // if (cute::thread0()) { print(dP_sum); } + + flash::gemm( + acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV + ); + + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } + } + // if (cute::thread0()) { print(dS); } + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.d_rounded)); + if (Is_first || Seq_parallel) { + clear(acc_dq); + } else { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, acc_dq_reshaped); + } + + if (Double_buffer && m_block > m_block_min) { + // Double buffer for sQ + const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ); + tQsQ.data() = tQsQ.data() + sQ_offset; + tSsQ.data() = tSsQ.data() + sQ_offset; + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + flash::cp_async_fence(); + } + + Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = flash::convert_type(dS_reshaped); + // if (cute::thread0()) { print(tPrP); } + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + + // Layout p_l = tPrP.layout(); + // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); + // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); + // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } + // if (cute::thread0()) { print(acc_dv); } + + __syncthreads(); // Need syncthreads since we're writing to the same sdO location + + if (m_block > m_block_min) { + // Advance gdO + tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); + if (Is_first) { + tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); + flash::copy(gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); + flash::copy(gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); + } else { + flash::copy(gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); + flash::cp_async_fence(); + } + } + + flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_thr_copy_dS, smem_thr_copy_Kt); + // if (cute::thread0()) { print(acc_dq); } + + if (m_block > m_block_min) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + // if (!Is_first && tidx < kBlockM / 2) { + // sdPsum(tidx) = recast(gdPsum)(tidx); + // if (!Is_first && tidx < kBlockM) { + // recast(sdPsum)(tidx) = gdPsum(tidx); + // } + } + + if (!Is_last) { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + if (!Seq_parallel) { + copy(gmem_thr_copy_dQ_accum, acc_dq_reshaped, tdQgdQaccum); + } else { + // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + } + } else { + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = flash::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ); + } + + flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (cute::thread0()) { print(acc_dk); } + if (Double_buffer) { // Double buffer for sQ + 
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); + } + if (!Double_buffer && m_block > m_block_min) { + __syncthreads(); + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + flash::cp_async_fence(); + } + + if (Is_first && m_block > m_block_min) { + copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, sdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Is_last) { + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ); + tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + #pragma unroll + for (int m = 0; m < size<1>(tdQgdQ); ++m) { + if (Is_even_M || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { + copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); + } + } + } + + } + + // Epilogue + + if (Is_dropout) { + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; } + } + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; } + + // Convert acc_dv from fp32 to fp16 + Tensor rdK = flash::convert_type(acc_dk); + Tensor rdV = flash::convert_type(acc_dv); + + Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + // Partition sdV and sdK to match the accumulator partitioning + auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // If we don't need syncthreads here since we're writing to the same location as sK and sV. + // Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop. 
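+    // sdK is laid directly over sK's shared memory and sdV right after it, so the copies below
+    // reuse the K/V smem to stage dK/dV before the global write; the barrier logic here and at
+    // the end of the loop has to guarantee that no thread is still reading that smem (e.g. for
+    // the dQ or dK gemm) when it gets overwritten.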
+ if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); } + + copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK); + copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV); + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + + auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx); + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. 
Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.o_head_stride; + // We'll advance gdKaccum and gdVaccum before the first write. + const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded + + n_block_max * kBlockN) * params.d_rounded; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + // We assume that params.d == kHeadDim for now + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQdO{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sdO = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutQdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); + // Double buffer for sK + Tensor sV = make_tensor(sK.data() + 2 * size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); + Tensor 
sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); + Tensor sdS = make_tensor(sV.data() + size(sV), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast(sdS.data().get())), + Shape>{}); + + auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); + auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx); + auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_D(gdVaccum); + + typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; + auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); + Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) + Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) + Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); + Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); + Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) + Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M_SdP, MMA_K + + // + // Copy Atom retiling + // + + auto smem_thr_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + auto smem_thr_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + // Partition sP and sdS to match the 
accumulator partitioning + // This has to be tiled_mma_sdp, not tiled_mma_dkv + auto smem_thr_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + auto smem_thr_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_thr_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_thr_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_thr_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + + // TODO: Might need to exit early and write 0 to gdQ. + + flash::copy( + gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + flash::copy( + gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + Tensor tQrQ = make_fragment_like(tQgQ); + flash::copy( + gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + int n_block = n_block_max - 1; + if (n_block % 2 == 1) { + tKsK.data() = tKsK.data() + size(sK); + tSsK.data() = tSsK.data() + size(sK); + tdQsKt.data() = tdQsKt.data() + size(sK); + } + + flash::copy( + gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + + Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0>(taccScS))::value == 4); + // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. 
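+    // Each thread of the SdP MMA owns 4 accumulator values per tile, in a 2 x 2 pattern of rows
+    // and columns; splitting the first mode by 2 and keeping coordinate (0, _) therefore leaves
+    // one entry per row owned by this thread (this is what the static_assert above relies on),
+    // so LSE and dP_sum can be gathered once per row, as in
+    //     for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }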
+ Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + // Using uint32_t row makes it 10us slower on d=128, not sure why. + const int row = get<0>(taccScS_row(mi)); + lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; + } + + cute::cp_async_fence(); + + Tensor dP_sum = make_fragment_like(lse); + copy(tdOrdO, tdOsdO); + dot_do_o( + tdOrdO, tdOrO, sdPsum, sdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout + ); + __syncthreads(); + #pragma unroll + for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_dq); + + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M_SdP, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would + // be some finite value for those indices. In the end when we multiply with K to get dQ, + // the corresponding values of K would be 0, so the result would still be correct. + if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) { + flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16); + } + // Compute the exponential value. + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + if (Is_dropout) { + uint32_t warp_id = tidx / 32; + uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 + static_assert(MMA_N_SdP % 2 == 0); + uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); + flash::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS + ); + } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = !Is_dropout + ? flash::convert_type(scores) + : flash::convert_type_relu(scores); + // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. 
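+    // Rough intuition for the reshape (assuming m16n8k16): the A operand of a 16x16-K MMA holds
+    // 8 values per thread, i.e. two adjacent 8-wide column chunks, so pairs of accumulator N
+    // tiles get fused and the per-thread mode grows from (2, 2) to (2, 2, 2) while MMA_N is
+    // halved. With m16n8k8 the A operand is only 16x8 (4 values per thread), so the shape stays
+    // ((2, 2, 1), MMA_N, MMA_N).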
+ Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_PdS, tPaP, tPsP); + + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + + clear(acc_dp); + flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } + } + + Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = flash::convert_type(dS_reshaped); + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) + copy(smem_thr_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + + if (n_block > 0) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tSsK.data() = tSsK.data() + sK_offset; + // Advance gK, gV + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + clear(acc_dv); + flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); } + tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded)); + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { atomicAdd(&tdVgdVaccum(i), acc_dv(i)); } + + __syncthreads(); + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + clear(acc_dk); + flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded)); + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); } + + flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_thr_copy_dS, smem_thr_copy_Kt); + // Double buffer for sK + tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? 
size(sK) : -size(sK)); + + } + + // Epilogue + + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = flash::convert_type(acc_dq); + + Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{}); + + // Partition sdV and sdK to match the accumulator partitioning + auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ); + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + + auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + + __syncthreads(); + + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv(const Params ¶ms) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.y; + // const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + if (n_block_max == 1) { + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + } else { + // Iterating backward from n_block_max - 1 to 0 might save 1 register + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + for (int n_block = n_block_max - 2; n_block > 0; n_block--) { + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + } + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. 
+ const int bidh = blockIdx.z; + + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params ¶ms) { + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + compute_dq_dk_dv_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h new file mode 100644 index 000000000..bf70ac19e --- /dev/null +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -0,0 +1,355 @@ +// Copyright (c) 2023, Tri Dao. + +#pragma once + +#include + +#include "static_switch.h" +#include "flash.h" +#include "flash_bwd_kernel.h" + +template +__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) { + flash::compute_dot_do_o(params); +} + +template +__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) { + flash::clear_dKVaccum(params); +} + +template +__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) { + flash::compute_dq_dk_dv(params); +} + +template +__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { + flash::compute_dq_dk_dv_seqk_parallel(params); +} + +template +__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) { + flash::compute_dq_dk_dv_seqq_parallel(params); +} + +template +__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) { + flash::convert_dQ(params); +} + +template +__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) { + flash::convert_dKV(params); +} + +template +void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid_m(num_m_block, params.b, params.h); + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + dim3 grid_n(num_n_block, params.b, params.h); + + flash_bwd_dot_do_o_kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_q as well. 
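+    // is_even_M / is_even_K are runtime properties of the problem; BOOL_SWITCH (from
+    // static_switch.h) lifts each of them into a constexpr template parameter so a fully
+    // specialized kernel is launched for every combination. Rough sketch of the pattern
+    // (not the exact macro body):
+    //   BOOL_SWITCH(cond, ConstName, [&] { /* use ConstName as a constexpr bool here */ });
+    //   // ~ if (cond) { constexpr bool ConstName = true;  body(); }
+    //   //   else      { constexpr bool ConstName = false; body(); }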
+ const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; + // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); + BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + BOOL_SWITCH(is_even_M, IsEvenMConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + + auto kernel_dq = &flash_bwd_convert_dq_kernel; + if (Kernel_traits::kSmemdQSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + } + kernel_dq<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + dim3 grid_n(num_n_block, params.b, params.h_k); + flash_bwd_clear_dkvaccum_kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid_m(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_k as well. + const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock; + // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); + BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + + auto kernel_dkv = &flash_bwd_convert_dkv_kernel; + if (Kernel_traits::kSmemKVSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize)); + } + kernel_dkv<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} +// + +template +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + if (configure) return; + // dim3 grid(params.b, params.h); + // const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + // dim3 grid_m(num_m_block, params.b, params.h); + + // if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA) + run_flash_bwd_seqk_parallel(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel(params, stream, configure); + // } + + // // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check + // // for 
cu_seqlens_q as well. + // const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0; + // const bool is_even_K = params.d == Kernel_traits::kHeadDim; + // constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize; + // BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + // BOOL_SWITCH(is_even_M, IsEvenMConst, [&] { + // BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + // // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // if (smem_size_dq_dk_dv >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + // } + // kernel<<>>(params); + // C10_CUDA_KERNEL_LAUNCH_CHECK(); + // }); + // }); + // }); + // }); + + // auto kernel_dq = &flash_bwd_convert_dq_kernel; + // if (Kernel_traits::kSmemdQSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + // } + // kernel_dq<<>>(params); + // C10_CUDA_KERNEL_LAUNCH_CHECK(); +} +// + +template +void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 32; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB + if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + run_flash_bwd, Is_dropout>(params, stream, configure); + } + } else { // 96 KB + run_flash_bwd, Is_dropout>(params, stream, configure); + } + }); +} + +template +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 64; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // Changing AtomLayoutMdQ from 2 to 4 takes the same time + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
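+    // The 144 KB check is a shared-memory budget test: cudaDevAttrMaxSharedMemoryPerBlockOptin is
+    // roughly 163 KB on A100 (more on H100) but only about 99 KB on consumer Ampere/Ada parts, so
+    // the larger tile configuration is only picked where it fits and the smaller-smem fallback is
+    // used otherwise (rough reasoning; the per-config smem sizes follow the same kind of
+    // arithmetic spelled out for hdim32 above).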
+ if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout>(params, stream, configure); + // This has a lot of register spilling + // run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + // if (params.h == params.h_k) { + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); + // } + } + }); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + + // run_flash_bwd>(params, stream, configure); +} + +template +void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 96; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // if (params.h == params.h_k) { + if (max_smem_per_block >= 116 * 1024) { + if constexpr(!Is_dropout) { // 92KB + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { // 116 KB + // This is faster for dropout since we don't have many registers to spare + run_flash_bwd, Is_dropout>(params, stream, configure); + } + } else { + run_flash_bwd, Is_dropout>(params, stream, configure); + } + // } else { + // run_flash_bwd_seqq_parallel>(params, stream, configure); + // } + }); +} + +template +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 128; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // if (params.h == params.h_k) { + // run_flash_bwd>(params, stream, configure); + // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). + // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. 
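+    // Reading these tuning notes: the Kernel_traits argument of run_flash_bwd bundles (roughly)
+    // the head dim, the kBlockM x kBlockN tile shape, the number of warps, and the AtomLayout*
+    // factors that decide how warps are partitioned across the S/dP, dK/dV and dQ GEMMs (cf. the
+    // AtomLayoutMdQ note in the hdim64 variant above). The commented-out calls are alternative
+    // tilings that were benchmarked and kept for reference.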
+ // run_flash_bwd>(params, stream, configure); + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream, configure); + } + // run_flash_bwd>(params, stream, configure); + + // run_flash_bwd>(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel>(params, stream, configure); + // } + }); +} + +template +void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 160; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 116 * 1024) { + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + run_flash_bwd, Is_dropout>(params, stream, configure); + } + }); +} + +template +void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 192; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 136 * 1024) { + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + run_flash_bwd, Is_dropout>(params, stream, configure); + } + }); +} + +template +void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 224; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_bwd, Is_dropout>(params, stream, configure); + }); +} + +template +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { + constexpr int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 176 * 1024) { // H100 + run_flash_bwd, Is_dropout>(params, stream, configure); + } else { // A100, we don't do double buffering to save smem + run_flash_bwd, Is_dropout>(params, stream, configure); + } + }); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..654400a74 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
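+// Each (head dim, data type) pair lives in its own translation unit and provides exactly one
+// explicit specialization of run_mha_fwd_, so nvcc can compile the instantiations in parallel
+// and a tuning change to one configuration only rebuilds one file. The commented-out block
+// below is an earlier hand-tuned variant kept for reference.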
+ +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..5b7254a91 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k +// run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// // 1st ones are good for H100, A100 +// // 2nd one is good for A6000 bc we get slightly better occupancy +// } else { +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// // 1st one is good for H100, A100, A6000 +// } +// } + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..6a9d60c39 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..6c40a164d --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. +// // For A100, H100, 1st is fastest. +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..d2f4cba71 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..2875c9266 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // This one is slightly faster for causal? +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// }); +// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout +// // For A6000, 1st is faster when causal, 3rd is faster when not causal +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..982fe7ead --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..4c083f7b6 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..cb074a95e --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..ddf5e1322 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..81e359e16 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..91e6331e9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// // For dropout there might be a lot of register spilling? 
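+// // (A plausible explanation, not verified here: dropout keeps the Philox counters and the
+// // sampled random bits in registers on top of the P fragment, so the larger configs tried
+// // below likely run out of registers.)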
+// // These two are very slow due to register spilling +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // This one is slightly slower +// // run_flash_fwd>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..fffcbebb5 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..01bd17167 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower +// // Using block size (64 x 256) is 27% slower for seqlen=2k +// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..b0b27db59 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..820b63cbb --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // This 3rd one is good for H100, and A100, A6000 +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // These two are always slower +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// }); +// } +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h new file mode 100644 index 000000000..2eba4ef12 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -0,0 +1,576 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "philox.cuh" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + // TODO: Shouldn't this be size<1>? 
+ make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P +) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + Layout l = tOrP.layout(); + Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + #pragma unroll + for (int mi = 0; mi < size<1>(tPrP); ++mi) { + copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); + auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = 
make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // Copy rmem to smem + // // copy(tQrQ, tQsQ); + // flash::cp_async_wait<0>(); + // __syncthreads(); + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + constexpr int n_masking_steps = Is_causal ? 
cute::ceil_div(kBlockM, kBlockN) : 1; + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout())); + // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. + // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
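+    // This register layout lets rP be fed straight into the P * V GEMM below: gemm_A_in_regs
+    // keeps the A operand in registers while V^T is streamed from smem, so the attention matrix
+    // never round-trips through shared or global memory; the optional Return_softmax path is the
+    // only place P is written out.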
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + // if (cute::thread0()) { print(tOrP); } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 
1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + copy(smem_thr_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + copy(gmem_thr_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
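+    // Only the threads that own column 0 of the accumulator tile write the per-row LSE.
+    // lse(mi) computed above is the standard log-sum-exp value,
+    //   lse = rowmax(S) * scale_softmax + log(rowsum(exp(S * scale_softmax - rowmax(S) * scale_softmax))),
+    // and the store is guarded by actual_seqlen_q so rows past the end of the sequence are skipped.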
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h new file mode 100644 index 000000000..f48186aeb --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -0,0 +1,251 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + flash::compute_attn(params); +} + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_q as well. 
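+    // Each of the booleans below is lifted to a compile-time template parameter through
+    // BOOL_SWITCH, so the fully aligned case (seqlen_k a multiple of kBlockN, head dim equal
+    // to kHeadDim, no softmax to return) is compiled without the extra masking / predication paths.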
+ const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 32; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 64; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 96; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 128; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 
32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 160; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 224; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, 
stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h deleted file mode 100644 index 2905e6dce..000000000 --- a/csrc/flash_attn/src/fmha.h +++ /dev/null @@ -1,211 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - ******************************************************************************/ - -#pragma once - -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include -#include - -#include - - -constexpr int TOTAL_DIM = 0; -constexpr int H_DIM = 1; -constexpr int D_DIM = 2; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Qkv_params { - // The QKV matrices. - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - - // The stride between rows of the Q, K and V matrices. - // size_t qkv_stride_in_elts; - // size_t qkv_stride_in_bytes; - // TD [2022-04-16]: We're using 32-bit indexing to save registers. - // The code probably won't work for arrays larger than 2GB. - uint32_t q_row_stride_in_elts; - uint32_t k_row_stride_in_elts; - uint32_t v_row_stride_in_elts; - uint32_t q_head_stride_in_elts; - uint32_t k_head_stride_in_elts; - uint32_t v_head_stride_in_elts; - - // The number of heads. - int h; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FMHA_fprop_params : public Qkv_params { - - // The O matrix (output). - void * __restrict__ o_ptr; - - // The stride between rows of O. - // size_t o_stride_in_elts; - // size_t o_stride_in_bytes; - uint32_t o_row_stride_in_elts; - uint32_t o_head_stride_in_elts; - uint32_t o_tmp_row_stride_in_elts; - uint32_t o_tmp_head_stride_in_elts; - - // The pointer to the O_tmp matrix, which holds O intermediate value during - // the loop; - void *__restrict__ o_tmp_ptr; - - // The pointer to the S matrix. - void * __restrict__ s_ptr; - // The stride between rows of the S matrix. - // int64_t s_stride_in_bytes; - uint32_t s_stride_in_bytes; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, d; - - // The scaling factors for the kernel. - float scale_bmm1f; - uint32_t scale_bmm1; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - - int *__restrict__ blockmask; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - uint32_t p_dropout_in_uint; - uint16_t p_dropout_in_uint16_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - float scale_bmm1_rp_dropout; - - // Scale factor of 1 / (1 - p_dropout), in half2. - uint32_t scale_dropout; - - // Random state. - at::PhiloxCudaState philox_args; - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; - - bool is_bf16; - bool is_causal; - - int num_splits; // How many SMs per attention matrix. -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FMHA_dgrad_params : public FMHA_fprop_params { - - // The dQKV matrices. - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - - // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension - // void *__restrict__ dk_accum_ptr; - // void *__restrict__ dv_accum_ptr; - - // The stride between rows of the dQ, dK and dV matrices. - // TD [2022-04-16]: We're using 32-bit indexing to save registers. - // The code probably won't work for arrays larger than 2GB. 
- uint32_t dq_row_stride_in_elts; - uint32_t dk_row_stride_in_elts; - uint32_t dv_row_stride_in_elts; - uint32_t dq_head_stride_in_elts; - uint32_t dk_head_stride_in_elts; - uint32_t dv_head_stride_in_elts; - - // The dO matrix. We assume it is contiguous. - void * __restrict__ do_ptr; - - // The pointer to the softmax d sum. - void * __restrict__ dsoftmax_sum; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Launch_params{ - Launch_params(cudaDeviceProp * props_, - cudaStream_t stream_, - bool is_dropout_, - bool return_softmax_) - : elts_per_thread(0) - , props(props_) - , stream(stream_) - , is_dropout(is_dropout_) - , return_softmax(return_softmax_) { - } - - size_t elts_per_thread; - - cudaDeviceProp * props; - - cudaStream_t stream; - - bool is_dropout; - bool return_softmax; - - Kernel_params params; - int num_full_heads; - int num_main_groups; - int heads_last_wave; - int main_steps; - int rest_steps; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_fmha_fwd_hdim32(Launch_params &launch_params); -void run_fmha_fwd_hdim64(Launch_params &launch_params); -void run_fmha_fwd_hdim128(Launch_params &launch_params); - -void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure); -void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure); -void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure); - -void run_fmha_block_fp16_sm80(Launch_params &launch_params, const bool configure); - -void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h deleted file mode 100644 index a142f0bf2..000000000 --- a/csrc/flash_attn/src/fmha/gemm.h +++ /dev/null @@ -1,451 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/layout/layout.h" -#include -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ > -struct Fragment_base_ { - - // The data type. - using Data_type = Data_type_; - // default input type - using Input_type_ = Data_type_; - // Does it store the array of elements. - static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8; - // The number of elements. - static constexpr int NUM_ELTS = NUM_ELTS_; - // The size of element in bits. - static constexpr int BITS_PER_ELT = BITS_PER_ELT_; - // The size of byte of a single register. - static constexpr int BYTES_PER_REG = 4; - // The size in bits. - static constexpr int BITS_PER_REG = BYTES_PER_REG * 8; - // The number of registers needed to store the fragment. - static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG); - // The size in bytes (as returned by sizeof(Fragment_base<>). - static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG; - // The alignment. - static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the elements. - typename Data_type_, - // The number of elements. - int NUM_ELTS_, - // The alignment if you want to force a value -- use 0 otherwise. - int ALIGNMENT_ = 0, - // The base class. - typename Base_ = Fragment_base_ -> -struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { - - // The size of a load/store. - static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t); - - // Clear the fragment. Using PTX in that code seems to produce better SASS... - inline __device__ void clear() { - #pragma unroll - for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : ); - } - } - - // Immutable access to a register. - inline __device__ const uint32_t& reg(int ii) const { - return this->regs_[ii]; - } - - // Mutable access to a register. - inline __device__ uint32_t& reg(int ii) { - return this->regs_[ii]; - } - - uint32_t regs_[Base_::NUM_REGS]; - - // Immutable access to the elements. - inline __device__ const Data_type_& elt(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Mutable access to the elements. - inline __device__ Data_type_& elt(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Immutable access to the elements with a cast. 
- template< typename Cast_type > - inline __device__ const Cast_type& elt_as(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Mutable access to the elements. - template< typename Cast_type > - inline __device__ Cast_type& elt_as(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Add another fragment. - inline __device__ void add(const Fragment &other) { - // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS? - // Also are we doing int addition or __half2 addition? - #pragma unroll - for( int ii = 0; ii < NUM_ELTS_; ++ii ) { - this->elt(ii) += other.elt(ii); - } - } - - // Multiply by another fragment. - inline __device__ void hmul(const Fragment &other) { - #pragma unroll - for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii)); - } - } - - template - inline __device__ void hrelu_() { - #pragma unroll - for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - this->reg(ii) = fmha::hrelu2(this->reg(ii)); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Layout > -struct Fragment_a : public Fragment { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Layout > -struct Fragment_b : public Fragment { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Fragment_accumulator : public Fragment { - - // The base class. - using Base = Fragment; - - // Add two fragments. - template< typename Other_fragment_ > - inline __device__ void add(const Other_fragment_ &other) { - for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - this->elt(ii) = this->elt(ii) + other.elt(ii); - } - } - - inline __device__ void mul_(const float other) { - for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - this->elt(ii) *= other; - } - } - - // Do the HMMA. 
- template< typename Layout_a, typename Layout_b > - inline __device__ void mma(const Fragment_a &a, - const Fragment_b &b) { - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(0)), "r"(b.reg(1))); - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(2)), "r"(b.reg(3))); - } - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Fragment, int M, int N > -inline __device__ void clear(Fragment (&frag)[M][N]) { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - frag[mi][ni].clear(); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Accumulator_type, int WARPS_K > -struct Clear_accumulator { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int WARPS_K > -struct Clear_accumulator { - template< typename Acc, int M, int N > - static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { - fmha::clear(acc); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { - - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - acc[mi][ni].mma(a[mi], b[ni]); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////// -/// Statically maps half types => cutlass data types -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct HalfTypeToCutlassType { using Type = Type_; }; - -/// Statically maps __half => cutlass::half_t -template <> struct HalfTypeToCutlassType<__half> { - using Type = cutlass::half_t; -}; - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -template <> struct HalfTypeToCutlassType<__nv_bfloat16> { - using Type = cutlass::bfloat16_t; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { - using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -#else - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - // TD [2022-06-02] We don't support Volta (SM70) yet. 
- assert(0); -#endif - using Element = typename HalfTypeToCutlassType::Type; - using ElementC = float; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - - using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp< - Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type; - - constexpr int kIters = Shape::kK / InstructionShape::kK; - // using FragmentA = typename WarpMma::FragmentA; - // using FragmentB = typename WarpMma::FragmentB; - using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA; - using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB; - using FragmentC = typename WarpMma::FragmentC; - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) { - // printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements); - // printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements); - // printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements); - // printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements); - // printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements); - // printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements); - // } - - // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS); - // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS); - static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS); - static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS); - static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS); - // const FragmentA a_cl = reinterpret_cast(a); - // const FragmentB b_cl = reinterpret_cast(b); - FragmentC c_cl = reinterpret_cast(acc); - FragmentA a_cl[kIters][M]; - FragmentA b_cl[kIters][N]; - constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2; - #pragma unroll - for (int iter = 0; iter < kIters; iter++) { - #pragma unroll - for (int mi = 0; mi < M; mi++) { - uint32_t *a_ptr = a_cl[iter][mi].raw_data(); - #pragma unroll - for (int ki = 0; ki < kRegs; ki++) { - a_ptr[ki] = a[mi].regs_[iter * kRegs + ki]; - } - } - } - #pragma unroll - for (int iter = 0; iter < kIters; iter++) { - #pragma unroll - for (int ni = 0; ni < N; ni++) { - uint32_t *b_ptr = b_cl[iter][ni].raw_data(); - #pragma unroll - for (int ki = 0; ki < kRegs; ki++) { - // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki]; - // TD [2022-06-02] For some reason the order for frag_b is different. - b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter]; - } - } - } - - WarpMma mma_op; - // mma_op(c_cl, a_cl, b_cl, c_cl); - #pragma unroll - for (int iter = 0; iter < kIters; iter++) { - mma_op(c_cl, reinterpret_cast(a_cl[iter]), - reinterpret_cast(b_cl[iter]), c_cl); - } - - // The modified c_cl is not copied back into acc, idk why - #pragma unroll - for (int mi = 0; mi < M; mi++) { - #pragma unroll - for (int ni = 0; ni < N; ni++) { - #pragma unroll - for (int i =0; i < 8; i++) { - acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i]; - } - } - } - -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The number of rows in the CTA tile. - int M_, - // The number of cols in the CTA tile. 
- int N_, - // The number of elements in the the K dimension of the GEMM loop. - int K_, - // The number of rows of warps. - int WARPS_M_, - // The number of cols of warps. - int WARPS_N_, - // The number of warps in the K dimension of the GEMM loop. - int WARPS_K_> -struct Cta_tile_ { - - static constexpr int M = M_, N = N_, K = K_; - // The number of warps. - static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_; - // The number of warps per CTA. - static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K; - // The number of threads per warp. - static constexpr int THREADS_PER_WARP = 32; - // The number of threads per CTA. - static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Hmma_tile { - // The number of elements computed with a single warp-MMA. - static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16; - - // The number of elements computed with a single CTA-MMA. - static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, - N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, - K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K; - - // The number of MMAs needed to compute the GEMM. - static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA), - MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA), - MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA); - - // // The number of elements computed per warp. - // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA, - // N_PER_WARP = MMAS_N * N_PER_MMA, - // K_PER_WARP = MMAS_K * K_PER_MMA; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using A_type = uint16_t; -using B_type = uint16_t; -using C_type = uint16_t; -using Accumulator_type = float; -using Epilogue_type = float; - -constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; -constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8; -constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Cta_tile_extd = Cta_tile_; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, - Cta_tile_::WARPS_M, - Cta_tile_::WARPS_N, - Cta_tile_::WARPS_K>; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h deleted file mode 100644 index e0bd24c3c..000000000 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ /dev/null @@ -1,555 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include - -#include - -namespace fmha { - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile_, - // The number of bits per element. - int BITS_PER_ELEMENT, - // The number of rows of Q, K or V loaded by this tile. - int ROWS_, - // The number of columns. - int COLS, - int BYTES_PER_LDGS_ = 16 -> -struct Gmem_tile_qkv { - - using Cta_tile = Cta_tile_; - - static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8; - // The size of each LDG. - static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_; - // The size of a row in bytes. - static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8; - - // The number of threads to load a "row" of the matrix. - static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG; - - static constexpr int ROWS = ROWS_; - // The number of "rows" loaded per LDG. - static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; - // The number of LDGs needed to load a chunk of the Q matrix. - static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG); - - // Ctor. - template< typename BInfo > - inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts, - const uint32_t head_stride_in_elts, const int headdim, - const BInfo &binfo, const int tidx, bool use_seqlen_q) - : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) - , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k) - , ptr(reinterpret_cast(ptr_)) - , tidx_(tidx) - , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) { - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable the loads. - // TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it - // row_ = row; - - // The row offset in the batched GEMM. For each seq element, we store QKV in that order. - // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; - uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes); - // Add the block index. 
- // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; - row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); - - // Assemble the final pointer. - ptr += row_offset + col * BYTES_PER_LDG; - } - - // Store data to shared memory. - template< typename Smem_tile > - inline __device__ void commit(Smem_tile &smem_tile) { - smem_tile.store(fetch_); - } - - inline __device__ void load() { - int row_ = tidx_ / THREADS_PER_ROW; - const void *ptrs[LDGS]; - uint32_t preds[LDGS]; - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); - fetch_[ii] = make_uint4(0, 0, 0, 0); - } - - // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) - Ldg_functor fct(fetch_, ptrs); - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - fct.load(ii, preds[ii]); - } - } - - // Store data to memory. - inline __device__ void store(const uint4 (&data)[LDGS]) { - int row_ = tidx_ / THREADS_PER_ROW; - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) { - fmha::stg(ptr_, data[ii]); - } - } - } - - inline __device__ void move(const int steps = 1) { - // ptr += (int64_t)ROWS * row_stride_in_bytes * steps; - ptr += (uint32_t)ROWS * row_stride_in_bytes * steps; - actual_seqlen -= ROWS * steps; - } - - // The stride between rows for the QKV matrice. - // int64_t row_stride_in_bytes; - const uint32_t row_stride_in_bytes; - // The pointer. - char *ptr; - // The fetch registers. - uint4 fetch_[LDGS]; - // Keep track of the row the thread is processing as we move the tile. - // int row_; - const int tidx_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen; - const bool col_predicate; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Cta_tile, - int BYTES_PER_ELEMENT = 2 -> -struct Gmem_tile_o { - - static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4); - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The size of each element. - // static constexpr int BYTES_PER_ELEMENT = 2; - // The size of each STG. - static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4; - static constexpr int COLS = Cta_tile::N; - // The size of a row in bytes. - static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; - - // The number of threads to store a "row" of the matrix. - static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG; - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - static constexpr int ROWS = Cta_tile::M; - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; - // The number of outter loop for the stores. - static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; - - // The number of "rows" stored per STG. - static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; - // Do we have to guard against partial writes/reads. 
- static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0; - // The number of STGs needed to store a chunk of the Q matrix. - static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG); - // The number of STGs needed to store a chunk of the Q matrix in total. - static constexpr int STGS = STGS_PER_LOOP * LOOPS; - - // Ctor. - template - // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx) - inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts, - const uint32_t head_stride_in_elts, const int headdim, - const BInfo &binfo, const int tidx) - : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) - , actual_seqlen_q(binfo.actual_seqlen_q) - , ptr_(reinterpret_cast(ptr)) - , tidx_(tidx) - , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) { - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable loads. - // row_ = row; - - // The row offset in the batched GEMM. - // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; - uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes); - row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); - // Assemble the final pointer. - ptr_ += row_offset + col * BYTES_PER_STG; - - // Is that thread active on the last STG? - if( HAS_INCOMPLETE_STG ) { - is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; - } - } - - // Store data to global memory. - template - inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { - int row_ = tidx_ / THREADS_PER_ROW; - #pragma unroll - for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { - int jj = mi * STGS_PER_LOOP + ii; - if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { - break; - } - - if (BYTES_PER_ELEMENT == 4) { - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]); - } - } else if (BYTES_PER_ELEMENT == 2) { - float x = reinterpret_cast(src[ii].x); - float y = reinterpret_cast(src[ii].y); - float z = reinterpret_cast(src[ii].z); - float w = reinterpret_cast(src[ii].w); - uint2 out = fmha::float4_pack(x, y, z, w); - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out); - } - } - } - } - - // Store data to global memory with atomicAdd. - inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) { - static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats - int row_ = tidx_ / THREADS_PER_ROW; - #pragma unroll - for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { - int jj = mi * STGS_PER_LOOP + ii; - if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { - break; - } - - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - float *ptr_ = reinterpret_cast(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); - #pragma unroll - for (int jj = 0; jj < 4; ++jj) { - atomicAdd(ptr_ + jj, reinterpret_cast(src[ii])[jj]); - } - } - } - } - - // Load data from global memory. 
- inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { - static_assert(BYTES_PER_ELEMENT == 4); - int row_ = tidx_ / THREADS_PER_ROW; - #pragma unroll - for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { - int jj = mi * STGS_PER_LOOP + ii; - if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) { - break; - } - - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); - } - } - } - - inline __device__ void move(const int steps = 1) { - // row_ += ROWS * steps; - // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps; - ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; - actual_seqlen_q -= ROWS * steps; - } - - // The stride between rows for the QKV matrice. - // int64_t row_stride_in_bytes; - const uint32_t row_stride_in_bytes; - // The pointer. - char *ptr_; - // Is the thread active for the last STG? - int is_active_for_last_stg_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen_q; - const int tidx_; - const bool col_predicate; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, int BYTES_PER_ELEMENT > -struct Gmem_tile_mma_sd { - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // Each STG stores 8 elements. - static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8; - // The number of MMAs in the M dimension. - static constexpr int MMAS_M = Mma_tile::MMAS_M; - // The number of MMAs in the N dimension. - static constexpr int MMAS_N = Mma_tile::MMAS_N; - // The number of rows computed per MMA per thread block. - static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA; - // The number of cols computed per MMA per thread block. - static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA; - // The number of threads per block. - static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA; - // The size of each row in bytes. I.e. how many bytes are stored per STG. - static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG; - // The distance between elements stored per loop (in bytes). - static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW; - - // The type of elements stored per STG. - using Type = typename fmha::Uint_from_size_in_bytes::Type; - - // Ctor. - template - inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) - : ptr_(static_cast(ptr)) { - - // The block index. - // size_t bidx = bidb * params.h + bidh; - uint32_t bidx = bidb * params.h + bidh; - - // The distance between two blocks (in bytes). - // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; - const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; - // Set store location for each thread at the beginning of the loop - ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG; - } - - // Store to global memory. - inline __device__ void store(const Type &data, const int mi, const int ni) { - // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::stg(ptr_ + offset, data); - } - - // Load from global memory. 
- inline __device__ void load(Type &data, const int mi, const int ni) { - // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::ldg(data, ptr_ + offset); - } - - // Move to the next tile. - inline __device__ void move(const int steps = 1) { - ptr_ += LOOP_STRIDE_BYTES * steps; - } - - // The pointer in global memory. - char *ptr_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > -struct Gmem_tile_mma_s : public Base { - - // The number of mmas in the vertical dimension. - static constexpr int M = Base::MMAS_M; - // The number of mmas in the horizontal dimension. - static constexpr int N = Base::MMAS_N; - // The type of the vectors stored by each STG. - using Type = typename Base::Type; - - // Ctor. - template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) - : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { - } - - // Store to global memory. - template - inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 dst; - dst.x = frag[ni][mi].reg(0); - dst.y = frag[ni][mi].reg(2); - dst.z = frag[ni][mi].reg(1); - dst.w = frag[ni][mi].reg(3); - if( mask.any_valid(mi, ni) ) { - Base::store(dst, mi, ni); - } - } - } - } - - // Load from global memory. - template - inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - regs[mi][ni] = make_uint4(0, 0, 0, 0); - if( mask.any_valid(mi, ni) ) { - Base::load(regs[mi][ni], mi, ni); - } - } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile -> -struct Gmem_summary_stats { - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - static constexpr int MMAS_M = Mma_tile::MMAS_M; - - // The size of each element. - static constexpr int BYTES_PER_ELEMENT = 4; - static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT; - static constexpr int ROWS = Cta_tile::M; - - // Ctor. - template - inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx) - : ptr_(reinterpret_cast(ptr)), tidx_(tidx) { - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The block index. - // size_t bidx = bidb * params.h + bidh; - uint32_t bidx = bidb * params.h + bidh; - - // Extract the position in the warp. - int warp = tidx / Cta_tile::THREADS_PER_WARP; - int lane = tidx % Cta_tile::THREADS_PER_WARP; - - // The distance between two blocks (in bytes). - // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; - uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; - - // Set store location for each thread at the beginning of the loop - ptr_row_ = ptr_ + bidx * block_stride_bytes; - ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT; - } - - // Store data to global memory. 
- inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) { - int warp = tidx_ / Cta_tile::THREADS_PER_WARP; - int lane = tidx_ % Cta_tile::THREADS_PER_WARP; - if ((warp == 0) && (lane % 4 == 0)) { - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - // TODO: Not sure if it's right for MMAS_M > 1 - fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]); - fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]); - } - } - } - - // Store data to global memory. - inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) { - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - // TODO: Not sure if it's right for MMAS_M > 1 - fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]); - } - } - - // Load from global memory. - inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) { - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - // TODO: Not sure if it's right for MMAS_M > 1 - fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); - fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); - } - } - - // Load from global memory. - inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) { - char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT; - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - // TODO: Not sure if it's right for MMAS_M > 1 - fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); - fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); - } - } - - // Store data to global memory. - template - inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) { - #pragma unroll - for (int ni = 0; ni < N; ++ni) { - fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT); - } - } - - // Move the pointer to the next location. - inline __device__ void move() { - ptr_ += ROWS * BYTES_PER_ELEMENT; - ptr_row_ += ROWS * BYTES_PER_ELEMENT; - } - - // Move the pointer to the next location. - inline __device__ void move(const int steps) { - ptr_ += ROWS * BYTES_PER_ELEMENT * steps; - ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps; - } - - // The pointer. - char *ptr_; - char *ptr_row_; - const int tidx_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha - diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h deleted file mode 100644 index 63f07aee8..000000000 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ /dev/null @@ -1,116 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FMHA_kernel_traits { - - // The CTA description for the 1st GEMM. - using Cta_tile_p = fmha::Cta_tile_extd; - // The CTA description for the 2nd GEMM. - using Cta_tile_o = fmha::Cta_tile_extd; - - // Do we use one buffer for K and V. - static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u; - // Do we keep K in registers. - static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u; - // Do we keep V in registers. - static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u; - - // The global memory tile to load Q. - using Gmem_tile_q = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle Q. - // using Smem_tile_q = fmha::Smem_tile_a; - using Smem_tile_q = fmha::Smem_tile_a; - - // The global memory tile to load K. - using Gmem_tile_k = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle K. - using Smem_tile_k = fmha::Smem_tile_b; - - // The global memory tile to load V. - using Gmem_tile_v = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle V. - using Smem_tile_v = fmha::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = fmha::Gmem_tile_o; - // The shared memory tile for O. - using Smem_tile_o = fmha::Smem_tile_o;; - - // The global memory tile to load/store S. - using Gmem_tile_s = fmha::Gmem_tile_mma_s; - - // The shared memory tile to transpose S. - using Smem_tile_st = fmha::Smem_tile_mma_transposed; - - using Gmem_tile_do = fmha::Gmem_tile_qkv; - - // // The global memory tile to store the accumulated dK and dV - // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV - // // where there are 16 bits per lements and 16 bytes per load. In reality we won't - // // be issue any load or store of size 32 bytes. - // using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv; - - // The global memory tile to store the softmax sum. - using Gmem_softmax_sum = fmha::Gmem_summary_stats; - - // The shared memory tile to store dp sum. - using Smem_dp_sum = fmha::Smem_tile_dp_sum; - - using elem_type = elem_type_; - - // Make sure the number of threads match. - static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); - - // The number of threads. - static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA; - // Make sure the number of threads matches both CTAs. - static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, ""); - - // The amount of shared memory needed to load Q and K. 
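// The three FLAGS bits above control the shared-memory/register trade-off:
// bit 0x08 shares one smem buffer between K and V, a cleared 0x10 bit keeps K in
// registers, and a cleared 0x100 bit keeps V in registers. A minimal sketch of the
// same decoding (Flags is an illustrative stand-in for the FLAGS template parameter):
#include <cstdint>
template <std::uint32_t Flags>
struct Fmha_flag_decode {
    static constexpr bool share_smem_for_k_and_v = (Flags & 0x08u) != 0u;
    static constexpr bool k_in_regs              = (Flags & 0x10u) == 0u;
    static constexpr bool v_in_regs              = (Flags & 0x100u) == 0u;
};
// Example: FLAGS = 0x08 shares the K/V buffer and keeps both K and V in registers.
static_assert(Fmha_flag_decode<0x08u>::share_smem_for_k_and_v, "");
static_assert(Fmha_flag_decode<0x08u>::k_in_regs && Fmha_flag_decode<0x08u>::v_in_regs, "");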
- static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE; - // The extra amount of shared memory needed to load V. - static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE; - // The amount of shared memory needed for Q, K and V.. - static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V; - // The amount of shared memory needed to load Q and store O. - static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE; - - // The amount of shared memory needed for Q, K, V and O. - static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO); - // Make sure we have enough shared memory. - static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/fmha/mask.h b/csrc/flash_attn/src/fmha/mask.h deleted file mode 100644 index 6c8092983..000000000 --- a/csrc/flash_attn/src/fmha/mask.h +++ /dev/null @@ -1,90 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - ******************************************************************************/ - -#pragma once - -namespace fmha { - - -template -struct Mask { - using Mma_tile = fmha::Hmma_tile; - - template - __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0) - : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N) - , loop_step_idx(loop_step_idx_) { - - const int warp = tidx / Cta_tile::THREADS_PER_WARP; - const int lane = tidx % Cta_tile::THREADS_PER_WARP; - - static_assert(Cta_tile::WARPS_K == 1, ""); - - // find the warp in the Cta tile - const int warp_n = (warp / Cta_tile::WARPS_M); - const int warp_m = (warp % Cta_tile::WARPS_M); - // decompose warp into 8x4 tile - const int quad = lane / 4; - const int tid = (lane % 4) * 2; - row = warp_m * 16 + quad; - col = warp_n * 16 + tid; - } - - inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { - - // ii and jj iterate over the 2x4 fragment - // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); - const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); - const int current_row = row_offset + ii * 8; - const bool col_valid = current_col < actual_seqlen_k; - // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k; - //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k; - // bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) { - // printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid); - // } - return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid; - // return row_valid && col_valid; - } - - //BERT Mask: if upper left is invalid, none are valid - inline __device__ bool any_valid(const int mi, const int ni) const { - return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0); - } - - inline __device__ void load(const int it) { - row_offset = it * Cta_tile::M + row; - } - int row_offset; - - int row; - int col; - const int loop_step_idx; - const int actual_seqlen_k; -}; - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h deleted file mode 100644 index 491253bb9..000000000 --- a/csrc/flash_attn/src/fmha/smem_tile.h +++ /dev/null @@ -1,1703 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. 
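// The is_valid() check in the Mask above reduces to two conditions on the absolute
// element coordinates: the key column must lie inside the key sequence, and for
// causal attention it must not be to the right of the query row. A minimal
// standalone sketch of the same predicate (global_row / global_col are illustrative
// names for the absolute query-row and key-column indices):
__host__ __device__ inline bool attention_element_valid(int global_row, int global_col,
                                                        int seqlen_k, bool is_causal) {
    const bool col_valid = global_col < seqlen_k;
    return is_causal ? (col_valid && global_col <= global_row) : col_valid;
}
// In the struct above, global_col corresponds to current_col + loop_step_idx * Cta_tile::N
// and seqlen_k to binfo.actual_seqlen_k before the per-loop-step subtraction.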
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "utils.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The description of the tile computed by this CTA. - typename Cta_tile, - // The number of rows in the 2D shared memory buffer. - int M_, - // The number of cols. - int N_, - // The size in bits of each element. - int BITS_PER_ELEMENT_, - // The number of bytes per STS. - int BYTES_PER_STS_ = 16, - // The number of buffers. (Used in multistage and double buffer cases.) - int BUFFERS_PER_TILE_ = 1, - // Do we enable the fast path for LDS.128 and friends. - int ENABLE_LDS_FAST_PATH_ = 0, - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. - int ROWS_PER_XOR_PATTERN_ = 8, - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. - int COLS_PER_XOR_PATTERN_ = 1, - // Use or not predicates - bool USE_PREDICATES_ = true -> -struct Smem_tile_without_skews { - - // The size in bits of each element. - enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; - // The size in bytes of a single STS. - enum { BYTES_PER_STS = BYTES_PER_STS_ }; - // The number of elements per STS. - enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; - // To support arbitrary N, we pad some values to a power-of-2. - enum { N_WITH_PADDING = Next_power_of_two::VALUE }; - // The number of bytes per row without packing of rows. - enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; - // The number of bytes per row -- we want at least 128B per row. - enum { BYTES_PER_ROW = Max::VALUE }; - // The number of rows in shared memory (two rows may be packed into a single one). - enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; - - // The number of threads per row. - enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; - // The number of threads per row. - enum { THREADS_PER_ROW = Min::VALUE }; - - // The number of STS per row. - enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; - // It must be at least one. - static_assert(STS_PER_ROW >= 1, ""); - // The number of rows written with a single STS. - enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) - static_assert(ROWS_PER_STS >= 1, ""); - // The number of STS needed to store all rows. - enum { STS_PER_COL = Div_up::VALUE }; - // The number of STS in total. 
- enum { STS = STS_PER_COL * STS_PER_ROW }; - - // TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads, - // we only need to store 16 * 64 * 2 = 2KB instead of 4KB. - static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS; - static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA; - - // The size of one buffer in bytes in shared memory. - // enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; - enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS }; - // The number of buffers. - enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; - // The size in bytes of total buffers. - enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; - // The boundary for smem_read_offset and smem_write_offset increment. - enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; - - // Do we enable the LDS.128 fast path? - enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; - static_assert(ENABLE_LDS_FAST_PATH == 0); - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. - enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. - enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; - // Use or not predicates - enum { USE_PREDICATES = USE_PREDICATES_ }; - - // The type of elements that are stored in shared memory by each thread. - using Store_type = typename Uint_from_size_in_bytes::Type; - - // Ctor. - inline __device__ Smem_tile_without_skews(void *smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) { - - // The row written by a thread. See doc/mma_smem_layout.xlsx. - int smem_write_row = tidx / THREADS_PER_ROW; - - // The XOR pattern. - int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; - // Compute the column and apply the XOR pattern. - int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; - - // The offset. - this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS; - - // TODO: Why not merge it with the read offset? - // this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); - // this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); - } - - // Compute the store pointers. - template< int N > - inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - // Decompose the STS into row/col. - int row = ii / STS_PER_ROW; - int col = ii % STS_PER_ROW; - - // Assemble the offset. - int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW; - - // Take the column into account. - if( STS_PER_ROW > 1 ) { - offset += col*THREADS_PER_ROW*BYTES_PER_STS; - } - - // Apply the XOR pattern if needed. 
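// A worked instance of the sizing math above, for the 16 x 64 fp16 Q tile with 16B
// stores and a 256-thread CTA mentioned in the TD [2022-06-02] note: each row holds
// 64 * 16 / 8 = 128 bytes, so nothing is padded or packed, ROWS = 16,
// THREADS_PER_ROW = 8, and each thread issues a single 16B STS. Only 16 * 8 = 128 of
// the 256 threads store, giving 128 * 16 B = 2 KB per buffer, which is the
// "2KB instead of 4KB" the note refers to. The same arithmetic as constexpr C++
// (names are illustrative):
constexpr int kRowsExample    = 16;
constexpr int kColsExample    = 64;
constexpr int kBitsPerElement = 16;
constexpr int kBytesPerSts    = 16;
constexpr int kBytesPerRow    = kColsExample * kBitsPerElement / 8;   // 128 B, no padding needed
constexpr int kThreadsPerRow  = kBytesPerRow / kBytesPerSts;          // 8 threads cover one row
constexpr int kStoringThreads = kRowsExample * kThreadsPerRow;        // 128 of the 256 threads
constexpr int kBytesPerBuffer = kStoringThreads * kBytesPerSts;       // one STS per storing thread
static_assert(kBytesPerBuffer == 2048, "matches the 2KB figure in the note above");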
- if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) { - const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; - offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; - } - - // Assemble the final pointer :) - // ptrs[ii] = smem_ + offset + smem_write_buffer_; - // smem_write_buffer_ is already merged with smem_write_offset_ - ptrs[ii] = smem_ + offset; - } - } - - inline __device__ void debug_reset() { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val = 0x0; - sts(val, smem_ + row*BYTES_PER_ROW + col + buffer); - } - } - } - } - } - - // Print the content of the tile (only for debug ;)). - inline __device__ void debug_print() const { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val; - lds(val, smem_ + row*BYTES_PER_ROW + col + buffer); - printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", - blockIdx.x, - blockIdx.y, - blockIdx.z, - smem_, - buffer, - row, - col, - val); - } - } - } - } - } - - // Move the read offset to next buffer. - inline __device__ void move_to_next_read_buffer() { - // if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - // this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - // } else if( BUFFERS_PER_TILE > 1 ) { - // this->smem_read_buffer_ += BYTES_PER_BUFFER; - // } - if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_offset_ += BYTES_PER_BUFFER; - } - } - - // Move the read offset to next buffer. TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer() { - this->move_to_next_read_buffer(); - } - - // Move the read offset to next N buffer (circular-buffer). - inline __device__ void move_to_next_read_buffer(int N) { - if( BUFFERS_PER_TILE > 1 ) { - // this->smem_read_buffer_ += N * BYTES_PER_BUFFER; - // this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; - this->smem_read_offset_ += N * BYTES_PER_BUFFER; - this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; - } - } - - // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer(int N) { - this->move_to_next_read_buffer(N); - } - - // Move the write offset to next buffer. - inline __device__ void move_to_next_write_buffer() { - // if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - // this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - // } else if( BUFFERS_PER_TILE > 1 ) { - // this->smem_write_buffer_ += BYTES_PER_BUFFER; - // } - if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_write_offset_ += BYTES_PER_BUFFER; - } - } - - // Move the write offset to next buffer. TODO: Remove that member function! - inline __device__ void move_next_write_buffer() { - this->move_to_next_write_buffer(); - } - - // Move the read offset. 
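// The write offset above XOR-swizzles the column index with a per-row pattern so
// that the 16B stores of consecutive rows land in different shared-memory banks.
// A minimal standalone sketch of the same computation (all parameters are
// illustrative stand-ins for the tile constants above):
inline int swizzled_store_offset(int tidx, int threads_per_row, int bytes_per_row,
                                 int bytes_per_sts, int rows_per_xor_pattern,
                                 int cols_per_xor_pattern) {
    const int row = tidx / threads_per_row;
    const int xor_pattern = (row % rows_per_xor_pattern) * cols_per_xor_pattern;
    const int col = (tidx % threads_per_row) ^ xor_pattern;   // swizzle within the row
    return row * bytes_per_row + col * bytes_per_sts;         // byte offset into the tile
}
// E.g. with 8 threads per row, 128B rows, 16B stores and an 8-row XOR pattern,
// thread 0 of row 1 writes at column 1 rather than column 0, so it does not hit the
// same banks as thread 0 of row 0.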
- inline __device__ void move_read_offset(int delta) { - this->smem_read_offset_ += delta; - } - - // Move the write offset. - inline __device__ void move_write_offset(int delta) { - this->smem_write_offset_ += delta; - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - // Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer. - if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) { - sts(smem_ptrs, data); - } - } - - // Store to the tile in shared memory. - template< int N, int M > - inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { - this->store(data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { - uint32_t tmp[1] = { preds }; - this->store(gmem_ptrs, tmp); - } - - // The shared memory pointer. - const uint32_t smem_; - // The read offset. Reserve 4 offsets if needed. - int smem_read_offset_; - // The write offset. - int smem_write_offset_; - // The buffer base offset for read. - // int smem_read_buffer_; - // The buffer base offset for write. - // int smem_write_buffer_; - const int tidx_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The layout of the tile. - typename Layout, - // The size of the STS. - int BYTES_PER_STS = 16, - // The number of buffers per tile. - int BUFFERS_PER_TILE = 1, - // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_a { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K, int MMAS_K_WITH_PADDING > -struct Compute_reset_mask { - // The potential mask. - enum { HALF = MMAS_K_WITH_PADDING / 2 }; - // The remainder. - enum { MOD = MMAS_K % HALF }; - // The final value. - enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K_WITH_PADDING > -struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { - enum { VALUE = 0 }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K > -struct Compute_reset_mask { - enum { VALUE = MMAS_K - 1 }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_a { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 
4 : 8) }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE -> -struct Smem_tile_row_a : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_a; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); - - // The row and column read by the thread. - int smem_read_row = (tidx & 0x0f); - constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; - int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; - smem_read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { - #pragma unroll - for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); - - // Store the value into the fragment. - a[mi].reg(0) = tmp.x; - a[mi].reg(1) = tmp.y; - a[mi].reg(2) = tmp.z; - a[mi].reg(3) = tmp.w; - } - - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. 
- static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } - } - - // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. - enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_a - : public Smem_tile_row_a { - // The base class. - using Base = Smem_tile_row_a; - - // Ctor. - inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The layout of the tile. - typename Layout, - // The size of the STS. - int BYTES_PER_STS = 16, - // The number of buffers per tile. - int BUFFERS_PER_TILE = 1, - // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_b { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_b { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE -> -struct Smem_tile_col_b : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b< Col>; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. 
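// One way to read the cascade of XOR updates in load() above: after consuming step
// ki, the read offset is advanced by an XOR of m * BYTES_PER_LDS * 2, where
// m = min(2^(t+1) - 1, MMAS_K_WITH_PADDING - 1) and t is the number of trailing
// 1-bits of ki. A small host-side self-check of that reading, assuming the branches
// behave exactly as written above:
#include <cassert>
inline int cascade_multiplier(int ki, int mmas_k_padded) {
    if (mmas_k_padded >= 32 && ki % 16 == 15) return 31;
    if (mmas_k_padded >= 16 && ki % 8  == 7)  return 15;
    if (mmas_k_padded >= 8  && ki % 4  == 3)  return 7;
    if (mmas_k_padded >= 4  && ki % 2  == 1)  return 3;
    if (mmas_k_padded >= 2)                   return 1;
    return 0;
}
inline int trailing_ones_multiplier(int ki, int mmas_k_padded) {
    int t = 0;
    while (ki & (1 << t)) ++t;                     // count trailing 1-bits of ki
    const int m = (1 << (t + 1)) - 1;
    return mmas_k_padded < 2 ? 0
                             : (m < mmas_k_padded - 1 ? m : mmas_k_padded - 1);
}
inline void check_cascade_reading() {
    for (int padded = 2; padded <= 32; padded *= 2)
        for (int ki = 0; ki < padded; ++ki)
            assert(cascade_multiplier(ki, padded) == trailing_ones_multiplier(ki, padded));
}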
- enum { BYTES_PER_LDS = 16 }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - - // The row and column read by the thread. - int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + - (tidx & 0x07) + - (tidx & 0x10) / 2; - constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; - int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; - smem_read_col ^= (tidx & 0x08) / 8; - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset); - - // Store the value into the fragment. - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - } - - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. - static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } - } - - // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. 
- enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE > - : public Smem_tile_col_b { - - // The base class. - using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>; - - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE, - // How many cols to use for the XOR pattern to avoid bank conflicts? - int COLS_PER_XOR_PATTERN_ = 1 -> -struct Smem_tile_row_b : public Smem_tile_without_skews { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b; - - // Can we use LDSM? No if the data type is 32-bit large. - enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; - // The number of elements per LDS. - enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) { - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(WARPS_K == 1); - static_assert(WARPS_M == 4 || WARPS_M == 8); - static_assert(WARPS_N == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - const int WARP_MASK_K = Warp_masks::K; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; - - - static_assert(USE_LDSMT); - static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8); - - // The row/col read by the thread. 
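// The Warp_masks / WARP_DIV arithmetic above decomposes a thread index into warp
// coordinates, with warps laid out M-fastest: warp = warp_m + WARPS_M * warp_n
// + WARPS_M * WARPS_N * warp_k. A minimal sketch of the same decomposition using
// plain division and modulo (warps_m / warps_n are illustrative parameters):
struct WarpCoords { int lane, warp_m, warp_n, warp_k; };
inline WarpCoords decompose_tidx(int tidx, int warps_m, int warps_n) {
    constexpr int kThreadsPerWarp = 32;
    const int warp = tidx / kThreadsPerWarp;
    return { tidx % kThreadsPerWarp,            // lane within the warp
             warp % warps_m,                    // warp_m (fastest-varying)
             (warp / warps_m) % warps_n,        // warp_n
             warp / (warps_m * warps_n) };      // warp_k (slowest-varying)
}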
- int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + - (tidx & 0x07) + (tidx & 0x08); - constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; - int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; - smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS; - - // Fill zeroes for group conv - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - // uint32_t smem_read_og = this->smem_ + this->smem_read_offset_; - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Prepare the offset. - int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING; - if ( BYTES_PER_MMA_PER_CTA == 32 ) { - offset += this->smem_read_offset_; - } else if ( BYTES_PER_MMA_PER_CTA == 64 ) { - offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2; - } else { - offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA; - } - - // Load the data using LDSM.MT88.2. - // uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; - uint32_t ptr = this->smem_ + offset; - uint4 tmp; - if( USE_LDSMT ) { - ldsmt(tmp, ptr); - } else { - lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING); - lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING); - lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING); - lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING); - } - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og); - // } - // Store those values in the fragment. 
- b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_b - : public Smem_tile_row_b { - - // The base class. - using Base = Smem_tile_row_b; - - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_v : public fmha::Smem_tile_without_skews::VALUE, 1> { - - // The base class. - using Base = Smem_tile_without_skews::VALUE, 1>; - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The fragment. - using Fragment = Fragment_b< fmha::Col>; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) { - - // The row/col read by the thread. - int read_row, read_col; - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - - read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); - constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING; - read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN; - read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS; - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { -#pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by 16 * #warps row. - int row = ki * 16 * Cta_tile::WARPS_K; - - // Load the data using LDSM.MT88.2. - uint4 tmp; - fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING); - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( Mma_tile::MMAS_N == 1 ) { - // noop - } else if( Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } else if( Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6); - } else if (Mma_tile::MMAS_N == 8) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2)); - } else { - assert(false); // Not implemented! - } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_o { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - // The accumulators. - using Data_type = typename Accumulator::Data_type; - - // The size of each element. - static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type); - // The size of each STS. - static constexpr int BYTES_PER_STS = 8; - // The size of each row in shared memory. - static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT; - - // The size of each LDS. - static constexpr int BYTES_PER_LDS = 16; - static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS; - - // The number of rows. - static constexpr int ROWS = Cta_tile::M; - // The number of "rows" to process per loop iteration (in the "epilogue"). - static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA; - // The number of outer loops. - static constexpr int LOOPS = ROWS / ROWS_PER_LOOP; - // Make sure it matches our expectations. - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // The number of rows loaded per LDS. - static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW; - // Do we have to guard against partial writes/reads. - static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0; - // The total number of LDS per loop. - static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS); - - // The amount of shared memory. - static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW; - - // The write pointer. - uint32_t smem_write_, smem_read_; - // Is the thread active for the last LDS of the series? - int is_active_for_last_lds_; - - // static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // Ctor. - inline __device__ Smem_tile_o(void *smem, int tidx) { - - // Get a 32-bit value for the shared memory address. - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128); - - int write_row = (tidx & 0x1c) / 4; - - const int lane = tidx % 32; - const int warp = tidx / 32; - - constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT; - constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS; - int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP; - - // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("write_row = %d, write_col = %d\n", write_row, write_col); - // } - - // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) { - // printf("threadIdx.x = %d\n", threadIdx.x); - // } - - // Assemble the write pointer. - smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - - // The element read by each thread. - int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - - // Take the XOR pattern into account for the column. 
- read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8))); - // read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8)))); - - // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("read_row = %d, read_col = %d\n", read_row, read_col); - // } - // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) { - // printf("threadIdx.x = %d\n", threadIdx.x); - // } - // Assemble the read pointer. - this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - - // Is that thread active on the last LDS? - if( HAS_INCOMPLETE_LDS ) { - this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; - } - } - - // Load the output fragments. - template - inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { - #pragma unroll - for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) { - - // Load the elements before the reduction (split-K). - uint4 tmp[Cta_tile::WARPS_K]; - #pragma unroll - for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) { - int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; - uint32_t smem_read = this->smem_read_ + imm; - // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way. - if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) { - smem_read ^= 8 * BYTES_PER_LDS; - } - // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("imm diff = %d\n", smem_read - this->smem_read_); - // } - if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) { - // fmha::lds(tmp[jj], this->smem_read_ + imm); - fmha::lds(tmp[jj], smem_read); - } - } - - // Perform the reduction. - out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]); - // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("out reduction: out = %.6f\n", reinterpret_cast(out[ii])[0]); - // } - #pragma unroll - for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) { - out[ii] = fmha::fadd4(out[ii], tmp[jj]); - // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast(tmp[jj])[0], reinterpret_cast(out[ii])[0]); - // } - } - } - } - - // Store the accumulators. - template - inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { - // uint32_t smem_write_og = this->smem_write_; - static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA; - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - - // The number of MMAs that are stored per loop iteration. - static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS; - - // Store 1st column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); - - // Store. 
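// The load() above implements the split-K reduction for the O tile: each of the
// Cta_tile::WARPS_K warps has written its partial fp32 output into its own
// Cta_tile::N-wide slice of the row, and every thread sums the slices element-wise
// (fmha::fadd4 adds four packed floats). A minimal sketch of that reduction for one
// 16-byte vector, using plain float4 instead of the packed uint4 helpers:
__device__ inline float4 add_float4(float4 a, float4 b) {
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
// partials[k] holds the 16B this thread reads from warp k's slice of the row.
__device__ inline float4 reduce_split_k(const float4 *partials, int warps_k) {
    float4 out = partials[0];
    for (int k = 1; k < warps_k; ++k) {
        out = add_float4(out, partials[k]);
    }
    return out;
}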
- fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og); - // } - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // uint4 read_tmp; - // fmha::lds(read_tmp, this->smem_read_); - // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); - // } - // Swizzle the write pointer using a XOR of 16B. - this->smem_write_ ^= 32; - - // Store 2nd column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); - // Store. - fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - - // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og); - // } - - // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. - static_assert(Mma_tile::MMAS_N <= 8, "Not implemented"); - if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) { - this->smem_write_ ^= 15 * 32; - } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) { - this->smem_write_ ^= 7 * 32; - } else if( Mma_tile::MMAS_N >= 2 ) { - this->smem_write_ ^= 3 * 32; - } else { - this->smem_write_ ^= 3 * 32; - } - // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // uint4 read_tmp; - // fmha::lds(read_tmp, this->smem_read_); - // printf("smem_o = %.6f\n", reinterpret_cast(read_tmp)[0]); - // } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_mma { - - using Mma_tile = fmha::Hmma_tile; - using Fragment = fmha::Fragment_a; - - enum { COLS = Cta_tile::N }; - enum { BYTES_PER_ELT = 2 }; - enum { BYTES_PER_STS = 4 }; - enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO - enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - enum { WARPS_K = Cta_tile::WARPS_K }; - - static_assert(WARPS_K == 1); - inline __device__ Smem_tile_mma(char *smem, int tidx) { - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - - int write_col, write_row; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1); - if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { - write_row = (tidx & 0x1c) / 4; - write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); - write_col ^= (write_row & 0x07) * 4; - } else { - write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; - write_col = (tidx & 0x03); - // write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4; - write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 
0x07 : 0x07)))) * 4; - } - - // write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - static_assert(COLS == Cta_tile::N); - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - // offset ^= 4 * BYTES_PER_STS; - // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } - } - - template - inline __device__ void store(const Fragment (&frag)[N][M]) { - static_assert(COLS == Cta_tile::N); - uint4 regs[M][N]; - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // Need to transpose ref(1) and reg(2) here since when we load it we transpose again. - regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2), - frag[ni][mi].reg(1), frag[ni][mi].reg(3)); - } - } - this->store(regs); - } - - // uint32_t smem_; - // uint32_t write_offset_; - uint32_t smem_write_; -}; - -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> -struct Smem_tile_mma_transposed : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - using Fragment = typename Base::Fragment; - inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) { - - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - int read_row, read_col; - read_row = (tidx & 0x0f); - read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; - - // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 
0x07 : 0x0f)))); - read_col ^= (read_row & 0x07); - // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - template - inline __device__ void load(Fragment (&frag)[M][N]) { - static_assert(Base::COLS == Cta_tile::N); - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - // fmha::ldsmt(dst, this->smem_ + offset); - // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::ldsmt(dst, offset); - frag[mi][ni].reg(0) = dst.x; - frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! - frag[mi][ni].reg(2) = dst.y; - frag[mi][ni].reg(3) = dst.w; - } - } - } - - // uint32_t read_offset_; - uint32_t smem_read_; -}; - -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> -struct Smem_tile_mma_epilogue : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; - static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); - enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; - static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; - static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - - using Acc = fmha::Fragment_accumulator; - - inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - const int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07))); - static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256); - read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07)))); - // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - inline __device__ void load(uint4 (&data)[NUM_LDS]) { - for( int ii = 0; ii < NUM_LDS; ii++ ) { - // size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - // fmha::lds(data[ii], this->smem_ + offset); - // size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - fmha::lds(data[ii], offset); - } - } - - template - inline __device__ void store(const Acc (&acc)[M][N]){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // 1st row - 4 elements per row. - float tmp00 = acc[mi][ni].elt(0); - float tmp01 = acc[mi][ni].elt(1); - float tmp02 = acc[mi][ni].elt(4); - float tmp03 = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. 
- float tmp10 = acc[mi][ni].elt(2); - float tmp11 = acc[mi][ni].elt(3); - float tmp12 = acc[mi][ni].elt(6); - float tmp13 = acc[mi][ni].elt(7); - - uint32_t x = fmha::float2_pack(tmp00, tmp01); - uint32_t y = fmha::float2_pack(tmp02, tmp03); - uint32_t z = fmha::float2_pack(tmp10, tmp11); - uint32_t w = fmha::float2_pack(tmp12, tmp13); - - // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); - // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); - // offset ^= 4 * Base::BYTES_PER_STS; - // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); - // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); - // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); - // } - fmha::sts(offset + 0 * BYTES_PER_ROW, x); - fmha::sts(offset + 8 * BYTES_PER_ROW, z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(offset + 0 * BYTES_PER_ROW, y); - fmha::sts(offset + 8 * BYTES_PER_ROW, w); - } - } - } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } - } - - // uint32_t read_offset_; - uint32_t smem_read_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_transpose { - - using Mma_tile = fmha::Hmma_tile; - using Fragment_write = fmha::Fragment_b; - using Fragment_read = fmha::Fragment_b; - - enum { COLS = Cta_tile::N }; - enum { BYTES_PER_ELT = 2 }; - enum { BYTES_PER_STS = 4 }; - enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO - enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; - - enum { BYTES_PER_LDS = 16 }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - enum { WARPS_K = Cta_tile::WARPS_K }; - - static_assert(WARPS_K == 1); - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - - inline __device__ Smem_tile_transpose(char *smem, int tidx) { - smem_ = __nvvm_get_smem_pointer(smem); - // uint32_t smem_ = __nvvm_get_smem_pointer(smem); - - int write_col, write_row; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { - write_row = (tidx & 0x1c) / 4; - write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); - } else { - write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; - write_col = (tidx & 0x03); - } - write_col ^= (write_row & 0x07) * 4; - - write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - - int read_row, read_col; - read_row = (tidx & 0x0f); - read_col = (tidx 
& 0xe0) / 16 + (tidx & 0x1c) / 16; - - read_col ^= (read_row & 0x07); - read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - template - inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); - } - } - - template - inline __device__ void load(Fragment_read (&frag_r)[N]) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); - frag_r[ni].reg(0) = dst.x; - frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! - frag_r[ni].reg(2) = dst.z; - frag_r[ni].reg(3) = dst.w; - } - } - - template - inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) { - static_assert(COLS == Cta_tile::N); - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); - } - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); - frag_r[ni].reg(0) = dst.x; - frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! - frag_r[ni].reg(2) = dst.z; - frag_r[ni].reg(3) = dst.w; - } - } - - uint32_t smem_; - uint32_t write_offset_; - uint32_t read_offset_; - // uint32_t smem_write_; - // uint32_t smem_read_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Gmem_tile, - // The number of buffers. (Used in multistage and double buffer cases.) - int BUFFERS_PER_TILE_ = 1 -> -struct Smem_tile_dp_sum { - - using Cta_tile = typename Gmem_tile::Cta_tile; - using Mma_tile = fmha::Hmma_tile; - - // The size of each element. - static constexpr int BYTES_PER_ELEMENT = 4; - static constexpr int ROWS = Gmem_tile::ROWS; - static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW; - static constexpr int MMAS_M = Mma_tile::MMAS_M; - - static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG; - static constexpr int LDGS = Gmem_tile::LDGS; - - static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA; - - // The size of one buffer in bytes in shared memory. - static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT; - // The number of buffers. 
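An aside on the `write_col ^= (write_row & 0x07)`-style swizzles used by the tiles above: they XOR the column slot with the low bits of the row so that threads touching the same logical column of consecutive rows land in different shared-memory banks. A standalone host-side sketch; the 128-byte row, 16-byte access size, and helper names are assumptions chosen for illustration, not code from these headers.

```cuda
// Compare bank usage with and without the XOR swizzle for eight threads that
// access column 0 of eight consecutive rows. Shared memory has 32 four-byte banks.
#include <cstdio>

constexpr int bank_of(int byte_offset) { return (byte_offset / 4) % 32; }

int main() {
    constexpr int BYTES_PER_ROW = 128;  // assumed row size for this illustration
    constexpr int BYTES_PER_LDS = 16;   // one 128-bit access per thread
    for (int row = 0; row < 8; ++row) {
        int col          = 0;
        int col_swizzled = col ^ (row & 0x07);           // the swizzle used above
        int plain    = row * BYTES_PER_ROW + col          * BYTES_PER_LDS;
        int swizzled = row * BYTES_PER_ROW + col_swizzled * BYTES_PER_LDS;
        printf("row %d: banks %2d..%2d (plain) vs %2d..%2d (swizzled)\n",
               row, bank_of(plain), bank_of(plain + 12),
               bank_of(swizzled), bank_of(swizzled + 12));
    }
    return 0;
}
```

With the plain layout every row maps to banks 0..3 (an 8-way conflict); with the swizzle the eight rows spread over banks 0..31.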
- static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_; - // The size in bytes of total buffers. - static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE; - // The boundary for smem_read_offset and smem_write_offset increment. - static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS; - - inline __device__ Smem_tile_dp_sum(float *smem, const int tidx) - : smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem), tidx_(tidx) { - } - - // Move the read offset to next buffer. - inline __device__ void move_to_next_read_buffer() { - if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) { - this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_buffer_ += ROWS; - } - } - - // Move the write offset to next buffer. - inline __device__ void move_to_next_write_buffer() { - if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) { - this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_write_buffer_ += ROWS; - } - } - - inline __device__ void store(const float (&sum)[LDGS]) { - if (tidx_ % THREADS_PER_ROW == 0) { - int row = tidx_ / THREADS_PER_ROW; - #pragma unroll - for (int i = 0; i < LDGS; ++i) { - if (row + i * ROWS_PER_LDG < ROWS) { - smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i]; - } - } - } - } - - inline __device__ void store(const float sum, const int buffer_idx) { - float *smem_write = smem_ + buffer_idx * ROWS; - int row = tidx_ / THREADS_PER_ROW; - if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) { - smem_write[row] = sum; - } - } - - inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) { - float *smem_write = smem_ + buffer_idx * ROWS; - if (tidx_ % THREADS_PER_ROW == 0) { - int row = tidx_ / THREADS_PER_ROW; - #pragma unroll - for (int i = 0; i < LDGS; ++i) { - if (row + i * ROWS_PER_LDG < ROWS) { - smem_write[row + i * ROWS_PER_LDG] = sum[i]; - } - } - } - } - - inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) { - float *smem_write = smem_; - // Extract the position in the warp. - int warp = tidx_ / Cta_tile::THREADS_PER_WARP; - int lane = tidx_ % Cta_tile::THREADS_PER_WARP; - int row = lane / 4; - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0]; - smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1]; - } - } - - inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) { - float *smem_write = smem_ + buffer_idx * ROWS; - // Extract the position in the warp. 
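The `reduce_warp` helper a few lines below, like the `quad_reduce`/`quad_allreduce` calls later in softmax.h, boils down to a butterfly reduction over a power-of-two group of lanes. A minimal standalone sketch of that pattern; it mirrors, but is not, `fmha::Allreduce`, and it assumes all 32 lanes of the warp are converged.

```cuda
// Butterfly all-reduce sketch: after log2(THREADS) XOR-shuffles, every lane
// holds the sum over its aligned group of THREADS lanes. THREADS must be a
// power of two no larger than 32; names are illustrative only.
template <int THREADS>
__device__ __forceinline__ float warp_allreduce_sum(float x) {
    #pragma unroll
    for (int offset = THREADS / 2; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;
}
```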
- int warp = tidx_ / Cta_tile::THREADS_PER_WARP; - int lane = tidx_ % Cta_tile::THREADS_PER_WARP; - int row = lane / 4; - #pragma unroll - for (int mi = 0; mi < MMAS_M; ++mi) { - smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0]; - smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1]; - } - } - - template - inline __device__ void load(float (&sum)[N], const int (&row)[N]) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - sum[ni] = smem_read_buffer_[row[ni]]; - } - } - - template - inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) { - float *smem_read = smem_ + buffer_idx * ROWS; - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - sum[ni] = smem_read[row[ni]]; - } - } - - static inline __device__ float reduce_warp(float sum) { - fmha::SumOp sum_op; - return fmha::Allreduce::run(sum, sum_op); - } - - const int tidx_; - float * const smem_; - float *smem_read_buffer_; - float *smem_write_buffer_; -}; - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h deleted file mode 100644 index bd874375e..000000000 --- a/csrc/flash_attn/src/fmha/softmax.h +++ /dev/null @@ -1,607 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - ******************************************************************************/ - -#pragma once - -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float apply_exp_(float x, float max) { - return __expf(x - max); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float apply_exp2_(float x, float max) { - return exp2f(x - max); - // With fast-math, this produces the same PTX instruction as the assembly below - // float diff = x - max; - // float res; - // asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff)); - // return res; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct ReadType {}; -template<> struct ReadType<4> { using T = float;}; -template<> struct ReadType<8> { using T = float2;}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_reduce { - // Helper class to distribute MMA tiles reduced over rows per warp over quads. - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - static constexpr int MMAS_M = Mma_tile::MMAS_M; - static constexpr int MMAS_N = Mma_tile::MMAS_N; - - static constexpr int WARPS_M = Cta_tile::WARPS_M; - static constexpr int WARPS_N = Cta_tile::WARPS_N; - - - static constexpr int ROWS = WARPS_M * MMAS_M * 16; - static constexpr int COLS = WARPS_N; - static_assert(COLS == 4 || COLS == 8); - static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; - static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); - static constexpr int ELTS_PER_TILE = ROWS * COLS; - - static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; - // TD [2022-05-02]: No longer true if head_dim != 64 - // static_assert(THREADS_PER_GROUP == 16); // DEBUG - static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; - static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; - static_assert(LOOPS == 1); - - using read_t = typename ReadType::T; - - __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { - - int lane = tidx % 32; - int warp = tidx / 32; - - int warp_m = warp % WARPS_M; - int warp_n = warp / WARPS_M; - - qid_ = lane % 4; - int qp = lane / 4; - - // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. - // This won't affect reading as we assume commutative reduction ops. 
- const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); - smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; - smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; - smem_read_row_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qid_]; - - } - - __device__ inline void store(float (&frag)[2 * MMAS_M]) { - if( qid_ == 0 ) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * WARPS_N; - smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; - smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; - } - } - } - - __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * 4; - frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; - frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; - } - } - - __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * 4; - frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4]; - } - } - - int qid_; - float *smem_write_; - read_t *smem_read_; - read_t *smem_read_row_; - -}; - - -template -struct Softmax_base { - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - static constexpr int MMAS_M = Mma_tile::MMAS_M; - static constexpr int MMAS_N = Mma_tile::MMAS_N; - - // The number of groups of warp such that we have at most 4 warps writing consecutive elements. - static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4); - // The number of elements that we are going to store per row. - static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS; - // The number of rows. - static constexpr int ROWS = Cta_tile::M * GROUPS; - // The total number of elements. - static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW; - - // Ctor. - template - inline __device__ Softmax_base(const Params ¶ms, void *smem, int tidx) - : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), - smem_(reinterpret_cast(smem)), tidx_(tidx) { - - // Move to the 1st mask loaded by the thread+ tidx; - // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); - - // Extract the position in the warp. - int warp = tidx / Cta_tile::THREADS_PER_WARP; - int lane = tidx % Cta_tile::THREADS_PER_WARP; - - // Decompose the warp index into M and N. - int warp_m = warp % Cta_tile::WARPS_M; - int warp_n = warp / Cta_tile::WARPS_M; - - // Decompose the warp-n index into group/position-inside-the-group. - int warp_g = warp_n / ELEMENTS_PER_ROW; - int warp_i = warp_n % ELEMENTS_PER_ROW; - - // The location written by the threads. - int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; - int write_col = warp_i; - - // Assemble the write pointer. - smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; - - // Assemble the read pointer. - smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; - } - - template - inline __device__ void apply_mask(const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ii = 0; ii < 2; ++ii ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - #pragma unroll - for( int jj = 0; jj < 4; ++jj ) { - if( !mask.is_valid(mi, ni, ii, jj) ) { - elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; - } - } - } - } - } - } - - // Apply the exp to all the elements. 
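The `apply_exp` defined next rewrites exp(x - max) as exp2(x * log2(e) - max * log2(e)), so the scale and subtract fuse into a single FFMA feeding the ex2 instruction. A standalone sketch of that identity; the helper name is hypothetical.

```cuda
// exp(x - max) == exp2(x*log2(e) - max*log2(e)). Pre-multiplying by log2(e)
// lets the compiler emit one FFMA plus ex2.approx instead of a separate
// subtract, multiply, and exp. Illustrative helper, not from these headers.
__device__ __forceinline__ float softmax_exp(float x, float row_max) {
    constexpr float kLog2e = 1.4426950408889634f;   // log2(e)
    return exp2f(fmaf(x, kLog2e, -row_max * kLog2e));
}
```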
- template - inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - constexpr float kLog2e = M_LOG2E; - const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e; - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); - elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e, - max_base2); - } - } - } - - // Apply the exp to all the elements. - template - inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) { - const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E; - const float scale = scale_ * M_LOG2E; - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - const float max_scaled = max[mi] * max_scale; - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled); - } - } - } - - // Apply the exp to all the elements. - template - inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - constexpr float kLog2e = M_LOG2E; - const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e; - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); - } - } - } - // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) { - // constexpr float kLog2e = M_LOG2E; - // #pragma unroll - // for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - // float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e; - // max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8); - // #pragma unroll - // for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - // elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2); - // } - // } - // } - - template - inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) { - // We encode the dropout pattern in the sign bit of the non-negative - // softmax to distinguish from pre-existing zeros - auto encode_dropout = [](bool keep, float val) { - return keep ? val : (encode_dropout_in_sign_bit ? 
-val : float(0)); - }; - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ni++ ) { - uint16_t tmp[8]; - // fmha::uint4_to_ushort8(ph(), tmp); - uint4 tmp_32 = ph(); - fmha::uint4_to_ushort8(tmp_32, tmp); - // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w); - // } - #pragma unroll - for (int ii = 0; ii < 2; ++ii) { - #pragma unroll - for (int jj = 0; jj < 4; ++jj) { - elt_[mi * 2 + ii][4 * ni + jj] = - encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); - } - } - } - } - } - - template - inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t, - unsigned long long philox_subsequence) { - // We encode the dropout pattern in the sign bit of the non-negative - // softmax to distinguish from pre-existing zeros - auto encode_dropout = [](bool keep, float val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); - }; - static_assert(MMAS_M == 1); // We're assuming 16x16 blocks. - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ni++ ) { - uint16_t tmp[8]; - // fmha::uint4_to_ushort8(ph(), tmp); - fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp); - // uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N); - // fmha::uint4_to_ushort8(tmp_32, tmp); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w); - // } - #pragma unroll - for (int ii = 0; ii < 2; ++ii) { - #pragma unroll - for (int jj = 0; jj < 4; ++jj) { - elt_[mi * 2 + ii][4 * ni + jj] = - encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); - } - } - } - } - } - - template - inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) { - // We encode the dropout pattern in the sign bit of the non-negative - // softmax to distinguish from pre-existing zeros - auto encode_dropout = [](bool keep, float val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); - }; - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - static_assert(MMAS_N % 2 == 0); - #pragma unroll - for( int ni = 0; ni < MMAS_N; ni += 2 ) { - uint16_t tmp[8]; - fmha::uint4_to_ushort8(ph0(), tmp); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w); - // } - #pragma unroll - for (int ii = 0; ii < 2; ++ii) { - #pragma unroll - for (int jj = 0; jj < 4; ++jj) { - elt_[mi * 2 + ii][4 * ni + jj] = - encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); - } - } - fmha::uint4_to_ushort8(ph1(), tmp); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w); - // } - #pragma unroll - for (int ii = 0; ii < 2; ++ii) { - #pragma unroll - for (int jj = 0; jj < 4; ++jj) { - elt_[mi * 2 + ii][4 * (ni + 1) + jj] = - encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]); - } - } - } - } - } - - // Scale all the elements. 
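The `apply_dropout_16bits` overloads above encode the dropout decision in the sign bit of the (non-negative) softmax values rather than zeroing them, so later stages can recover both the mask and the magnitude. A minimal sketch of that convention; the names are hypothetical, and the real code derives `keep` by comparing Philox-generated uint16 values against a probability threshold.

```cuda
// Sign-bit dropout convention: softmax outputs are >= 0, so the sign bit is
// free to carry "dropped". +0.0f vs -0.0f preserves the flag even for zero
// probabilities. Illustrative helpers only.
__device__ __forceinline__ float encode_dropout(bool keep, float p) {
    return keep ? p : -p;
}

__device__ __forceinline__ bool was_dropped(float encoded) {
    return (__float_as_uint(encoded) >> 31) != 0u;   // sign bit set?
}
```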
- inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { - // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. - float inv_sum[MMAS_M * 2]; - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; - } - - // Update the values. - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] *= inv_sum[mi]; - } - } - } - - // Subtract all elements by dp_sum - inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] -= dp_sum[mi]; - } - } - } - - // The pointer to the mask. - const char *packed_mask_ptr_; - // Shared memory for the CTA-wide reduction. - float *smem_, *smem_write_, *smem_read_; - // The current thread index. - int tidx_; - // The elements. - float elt_[MMAS_M * 2][MMAS_N * 4]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax : public Softmax_base { - - // The base class. - using Base = Softmax_base; - // The fragment. - using Fragment_a = fmha::Fragment_a; - - static_assert(Fragment_a::NUM_REGS == 4); - - static constexpr int WARPS_M = Cta_tile::WARPS_M; - static constexpr int WARPS_N = Cta_tile::WARPS_N; - // The MMAs. - static constexpr int MMAS_M = Base::MMAS_M; - static constexpr int MMAS_N = Base::MMAS_N; - - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - using Accumulator_out = Fragment; - static_assert(Accumulator_out::NUM_REGS == 4); - - static_assert(std::is_same::value); - - using Smem_tile_red = Smem_tile_reduce; - static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); - // Ctor. - template - inline __device__ Softmax(const Params ¶ms, void *smem, int tidx) - : Base(params, smem, tidx) - , params_scale_bmm1_(params.scale_bmm1) - , smem_sum_(static_cast(smem), tidx) - , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { - } - - // Pack the data to a fragment for the next GEMM. - template - inline __device__ void pack(Fragment_a (&dst)[K][M]) const { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < K; ++ki ) { - - // 1st row - 4 elements per row. - float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; - - // Pack to 4 registers. - dst[ki][mi].reg(0) = fmha::float2_pack(tmp_00, tmp_01); - dst[ki][mi].reg(1) = fmha::float2_pack(tmp_10, tmp_11); - dst[ki][mi].reg(2) = fmha::float2_pack(tmp_02, tmp_03); - dst[ki][mi].reg(3) = fmha::float2_pack(tmp_12, tmp_13); - } - } - } - - // Scale FP32 fragments - inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - const float scalef = reinterpret_cast(this->params_scale_bmm1_); - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. 
- this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; - // 2nd row - 4 elements per row. - this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; - } - } - } - - // Scale FP32 fragments - inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. - this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. - this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); - } - } - } - - template - __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) { - #pragma unroll - for( int mi = 0; mi < 2 * MMAS_M; mi++ ) { - frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]); - #pragma unroll - for( int ni = 1; ni < 4 * MMAS_N; ni++ ) { - frag[mi] = op(frag[mi], this->elt_[mi][ni]); - } - } - } - - template - __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) { - thread_reduce_(frag, op); - quad_reduce(frag, frag, op); - smem_red.store(frag); - __syncthreads(); - typename Smem_tile_red::read_t tmp[2 * MMAS_M]; - smem_red.load(tmp); - quad_allreduce(frag, tmp, op); - } - - template - __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ - MaxOp max; - reduce_(frag, max, smem_max_); - } - - __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ - SumOp sum; - reduce_(frag, sum, smem_sum_); - } - - template - __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){ - SumOp sum; - thread_reduce_(frag, sum); - quad_reduce(frag, frag, sum); - smem_sum_.store(frag); - } - - template - __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M], - const int (&rows)[NROWS], - Operator &op, Smem_tile_red & smem_red) { - #pragma unroll - for (int ii = 0; ii < NROWS; ii++) { - typename Smem_tile_red::read_t tmp[MMAS_M]; - smem_red.load_row(tmp, rows[ii]); - quad_allreduce(frag[ii], tmp, op); - } - } - - template - __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M], - const int (&rows)[NROWS]){ - SumOp sum; - reduce_after_sync_(frag, rows, sum, smem_sum_); - } - - template - __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M], - const int (&rows)[NROWS]){ - MaxOp max; - reduce_after_sync_(frag, rows, max, smem_max_); - } - - const uint32_t params_scale_bmm1_; - Smem_tile_red smem_max_; - Smem_tile_red smem_sum_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h deleted file mode 100644 index 
ecb8aef7f..000000000 --- a/csrc/flash_attn/src/fmha/utils.h +++ /dev/null @@ -1,1215 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -#endif - -extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Row {}; -struct Col {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, bool = (M & (M-1)) == 0 > -struct Next_power_of_two { -}; - -template< int M > -struct Next_power_of_two< M, true > { enum { VALUE = M }; }; -template<> -struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; }; -template<> -struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; }; -template<> -struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; }; -template<> -struct 
Next_power_of_two< 80, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, bool = (N & (N-1)) == 0 > -struct Prev_power_of_two { -}; - -template< int N > -struct Prev_power_of_two< N, true > { enum { VALUE = N }; }; -template<> -struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; }; -template<> -struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, int N > -struct Div_up { - enum { VALUE = (M + N-1) / N }; -}; - -constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B > -struct Max { - enum { VALUE = A >= B ? A : B }; -}; - -constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B, int C > -struct Max_3 { - enum { VALUE = Max::VALUE, C>::VALUE }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B > -struct Min { - enum { VALUE = A <= B ? A : B }; -}; - -constexpr int MinConstexpr(int A, int B) { return A <= B ? 
A : B; } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int SIZE_IN_BYTES > -struct Uint_from_size_in_bytes { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<1> { - using Type = uint8_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<2> { - using Type = uint16_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<4> { - using Type = uint32_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<8> { - using Type = uint2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<16> { - using Type = uint4; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int WARPS_M, int WARPS_N, int WARPS_K > -struct Warp_masks { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; }; -template<> -struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; -template<> -struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; -template<> -struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; -template<> -struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; -template<> -struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; -template<> -struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; -template<> -struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; -template<> -struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; -template<> -struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; -template<> -struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; -template<> -struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; -template<> -struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; -template<> -struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename T > -inline __device__ __host__ T div_up(T m, T n) { - return (m + n-1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline int clz(int x) { - for( int i = 31; i >= 0; --i ) { - if( (1 << i) & x ) { - return 31 - i; - } - } - return 32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline int find_log_2(int x, bool round_up = false) { - int a = 31 - clz(x); - if( round_up ) { - a += (x & (x-1)) ? 
1 : 0; - } - return a; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) { - // uint32_t c; - // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - // return c; - __half2 result = __hmul2(reinterpret_cast(a), - reinterpret_cast(b)); - return reinterpret_cast(result); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hmul4(uint2 a, uint2 b) { - uint2 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hmul8(uint4 a, uint4 b) { - uint4 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - c.z = hmul2(a.z, b.z); - c.w = hmul2(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { - uint4 c; - c.x = hmul2(a, b.x); - c.y = hmul2(a, b.y); - c.z = hmul2(a, b.z); - c.w = hmul2(a, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ uint32_t hrelu2(uint32_t x); - -template<> -inline __device__ uint32_t hrelu2<__half>(uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); -#else - asm volatile( \ - "{\n" \ - "\t .reg .f16x2 sela;\n" \ - "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" - "}\n" : "=r"(res) : "r"(x), "r"(zero)); -#endif - return res; -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> -inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; - asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); - return res; -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t habs2(uint32_t x) { - uint32_t res; - asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); - return res; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename T > -static inline __device__ T clamp(T x, T lb, T ub) { - return x < lb ? lb : (x > ub ? 
ub : x); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t clamp_to_zero(uint16_t x) { - uint16_t mask; - asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); - return mask & x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t float_to_half(float f) { - uint16_t h; - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); - return h; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float2_to_half2(float a, float b) { - uint32_t c; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); -#else - uint16_t lo = float_to_half(a); - uint16_t hi = float_to_half(b); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); -#endif - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ uint32_t float2_pack(float a, float b); - -template <> -inline __device__ uint32_t float2_pack<__half>(float a, float b) { - __half2 result = __floats2half2_rn(a, b); - return reinterpret_cast(result); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) { - __nv_bfloat162 result = __floats2bfloat162_rn(a, b); - return reinterpret_cast(result); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float_to_half2(float a) { - return float2_to_half2(a,a); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float2_to_half2(const float2 &f) { - return float2_to_half2(f.x, f.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { - uint2 d; - d.x = float2_to_half2(x, y); - d.y = float2_to_half2(z, w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ uint2 float4_pack(float x, float y, float z, float w) { - uint2 d; - d.x = float2_pack(x, y); - d.y = float2_pack(z, w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); -#else - d = hrelu2<__half>(hfma2(a, b, c)); -#endif - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t h0_h0(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" - : 
"=r"(y) : "r"(x)); - return y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float h0_to_float(uint32_t h2) { - float f; - asm volatile("{\n" \ - ".reg .f16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %1;\n" \ - "cvt.f32.f16 %0, lo;\n" \ - "}\n" : "=f"(f) : "r"(h2)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t h1_h1(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" - : "=r"(y) : "r"(x)); - return y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { - return hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hadd4(uint2 a, uint2 b) { - uint2 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hadd(uint2 a, uint2 b) { - return hadd4(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hadd8(uint4 a, uint4 b) { - uint4 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - c.z = hadd2(a.z, b.z); - c.w = hadd2(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2 (&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = fmha::half2_unpack(a); - float2 bf = fmha::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. 
-template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = fmha::hfma2_to_float(a.x, b.x); - sum += fmha::hfma2_to_float(a.y, b.y); - sum += fmha::hfma2_to_float(a.z, b.z); - sum += fmha::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 fadd4(uint4 a, uint4 b) { - float4 c; - c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); - c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); - c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); - c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); - return reinterpret_cast(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 fmul4(uint4 a, float b) { - float4 c; - c.x = reinterpret_cast(a.x) * b; - c.y = reinterpret_cast(a.y) * b; - c.z = reinterpret_cast(a.z) * b; - c.w = reinterpret_cast(a.w) * b; - return reinterpret_cast(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hadd(uint4 a, uint4 b) { - return hadd8(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float2 half2_to_float2(uint32_t x) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) { - float2 tmp = half2_to_float2(h); - x = tmp.x; - y = tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { - uint16_t d; - asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void uint4_to_ushort8(const uint4 a, uint16_t (&b)[8]) { - uint32_t *b_tmp = reinterpret_cast(&b[0]); - b_tmp[0] = a.x; - b_tmp[1] = a.y; - b_tmp[2] = a.z; - b_tmp[3] = a.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float sigmoid(float x) { - return 1.f / (1.f + expf(-x)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint16_t &dst) { - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint32_t &dst) { - dst = 0u; -} - 
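To make the packing helpers above concrete, here is a tiny self-contained kernel that does for one register pair what `float2_pack<__half>`, `half2_unpack<__half>`, and `hfma2_to_float` do, using the public cuda_fp16 intrinsics. It mirrors those helpers; it is not the header code itself.

```cuda
// Pack two float pairs into f16x2 registers, unpack them back to fp32, and
// take the dot product. 1.5*4 + 2*(-3) is exactly representable in fp16, so
// the kernel prints 0.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void dot_demo() {
    __half2 a = __floats2half2_rn(1.5f, 2.0f);
    __half2 b = __floats2half2_rn(4.0f, -3.0f);
    uint32_t a_u = *reinterpret_cast<uint32_t*>(&a);   // register view, as in the headers
    uint32_t b_u = *reinterpret_cast<uint32_t*>(&b);

    float2 af = __half22float2(*reinterpret_cast<__half2*>(&a_u));
    float2 bf = __half22float2(*reinterpret_cast<__half2*>(&b_u));
    float dot = af.x * bf.x + af.y * bf.y;
    printf("dot = %f\n", dot);
}

int main() {
    dot_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```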
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint2 &dst) { - dst = make_uint2(0u, 0u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint4 &dst) { - dst = make_uint4(0u, 0u, 0u, 0u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// P R E D I C A T E P A C K I N G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// -enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// G E N E R I C P R E D I C A T E D L D G S T S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M, typename Functor > -inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) { - - // The number of complete bytes (where we use all the predicates in a byte). - enum { COMPLETE = N / PREDS_PER_BYTE }; - // Make sure we did allocate enough predicates. - static_assert(Div_up::VALUE <= M, ""); - // The remainder. - enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; - // Make sure we got the math right and the remainder is between 0 and 3. - static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); - // The mask to extract the predicates. - enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; - - // Clear the fetch registers. - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - fct.clear(ii); - } - - // Run complete steps. - bool p[PREDS_PER_BYTE]; - #pragma unroll - for( int ii = 0; ii < COMPLETE; ++ii ) { - - // The predicate. - uint32_t reg = preds[ii / BYTES_PER_REG]; - - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } - - // Issue the loads. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); - } - } - - // Skip the rest of the code if we do not have a remainder. - if( REMAINDER > 0 ) { - - // The mask to extract the predicates. - enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; - - // The predicate register. - uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; - - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } - - // Issue the loads. 
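The predicate extraction in `load_` above assumes four predicates per byte and four bytes per 32-bit register (`PREDS_PER_BYTE`, `BYTES_PER_REG`). A standalone sketch of that bit layout, with hypothetical pack/unpack helpers that only mirror the masks used above.

```cuda
// One 32-bit register holds 16 predicates: predicate k lives in byte k/4 at
// bit position (k/4)*8 + (k%4), matching the mask computed in load_ above.
// Names are illustrative only.
__device__ __forceinline__ uint32_t pack_preds(const bool (&p)[16]) {
    uint32_t reg = 0;
    #pragma unroll
    for (int k = 0; k < 16; ++k) {
        reg |= uint32_t(p[k]) << ((k / 4) * 8 + (k % 4));
    }
    return reg;
}

__device__ __forceinline__ bool unpack_pred(uint32_t reg, int k) {
    return (reg & (1u << ((k / 4) * 8 + (k % 4)))) != 0u;
}
```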
- #pragma unroll - for( int ii = 0; ii < REMAINDER; ++ii ) { - fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, typename Functor > -inline __device__ void load_(Functor &fct, uint32_t preds) { - uint32_t tmp[1] = { preds }; - load_(fct, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint8_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint16_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint32_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint2 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint4 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type, int N > -struct Ldg_functor { - // Ctor. - inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) - : fetch_(fetch), ptrs_(ptrs) { - } - - // Clear the element. - inline __device__ void clear(int ii) { - fmha::clear(fetch_[ii]); - } - - // Trigger the loads. - inline __device__ void load(int ii, bool p) { - if( p ) { - ldg(fetch_[ii], ptrs_[ii]); - } - } - - // The fetch registers. - Data_type (&fetch_)[N]; - // The pointers. 
- const void* (&ptrs_)[N]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type, int N, int M > -inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - Ldg_functor fct(fetch, ptrs); - load_(fct, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint16_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint32_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint2 &dst, uint32_t ptr) { - asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint4 &dst, uint32_t ptr) { - asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x) - , "=r"(dst.y) - , "=r"(dst.z) - , "=r"(dst.w) - : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D S M -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : 
"r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// S T G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint8_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint16_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint32_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint2 val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint4 val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// S T S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint16_t val) { - asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint32_t val) { - asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint2 val) { - asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" - : - : "r"(ptr) - , "r"(val.x) - , "r"(val.y)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ 
void sts(uint32_t ptr, uint4 val) {
-    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
-        :
-        : "r"(ptr)
-        , "r"(val.x)
-        , "r"(val.y)
-        , "r"(val.z)
-        , "r"(val.w));
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< typename Data_type, int N >
-inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
-    #pragma unroll
-    for( int ii = 0; ii < N; ++ii ) {
-        sts(ptrs[ii], data[ii]);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< int N >
-inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
-    sts_(ptrs, data);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< int N >
-inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
-    sts_(ptrs, data);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< int N >
-inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
-    sts_(ptrs, data);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< int N >
-inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
-    sts_(ptrs, data);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T>
-struct MaxOp {
-__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
-};
-
-template <>
-struct MaxOp<float> {
-// This is slightly faster
-__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T>
-struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int THREADS>
-struct Allreduce {
-    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
-    template<typename T, typename Operator>
-    static __device__ inline T run(T x, Operator &op) {
-        constexpr int OFFSET = THREADS / 2;
-        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
-        return Allreduce<OFFSET>::run(x, op);
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<>
-struct Allreduce<2> {
-template<typename T, typename Operator>
-static __device__ inline T run(T x, Operator &op) {
-    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
-    return x;
-}
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        dst[mi] = src[mi];
-        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
-        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_reduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        dst[mi] = src[mi];
-        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
-        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
-    }
-}
-
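The MaxOp/SumOp/Allreduce helpers above implement a warp-level butterfly (XOR-shuffle) allreduce; with `Allreduce<4>` the offsets 2 and 1 reduce over each group of four consecutive lanes. The standalone sketch below is an illustration only, not part of the diff: the kernel, buffers, and `quad_max_kernel` name are invented for this example, and the templates are simplified copies of the pattern shown above.

```cuda
// Minimal sketch of the XOR-shuffle allreduce pattern used by the deleted helpers above.
// All names here (quad_max_kernel, d_in, d_out) are hypothetical, for illustration only.
#include <cstdio>
#include <cuda_runtime.h>

template<typename T>
struct MaxOp {
    __device__ T operator()(T const &x, T const &y) { return x > y ? x : y; }
};

template<int THREADS>
struct Allreduce {
    template<typename T, typename Operator>
    static __device__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        // Exchange with the lane OFFSET away, then recurse on half the group size.
        x = op(x, __shfl_xor_sync(0xffffffffu, x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(0xffffffffu, x, 1));
        return x;
    }
};

// After the call, every lane in a group of 4 consecutive lanes holds that group's max.
__global__ void quad_max_kernel(const float *in, float *out) {
    MaxOp<float> max_op;
    float v = in[threadIdx.x];
    v = Allreduce<4>::run(v, max_op);
    out[threadIdx.x] = v;
}

int main() {
    const int n = 32;  // one full warp so the 0xffffffff shuffle mask is valid
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>((i * 7) % 13);
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    quad_max_kernel<<<1, n>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; i += 4) printf("quad %d: max = %g\n", i / 4, h_out[i]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```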
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
-    float tmp[M];
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        tmp[mi] = op(src[mi].x, src[mi].y);
-    }
-    quad_reduce(dst, tmp, op);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_reduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
-    __half2 tmp[M];
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
-                     reinterpret_cast<const __half2 &>(src[mi].y));
-    }
-    quad_reduce(dst, tmp, op);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        dst[mi] = src[mi];
-        dst[mi] = Allreduce<4>::run(dst[mi], op);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_allreduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        dst[mi] = src[mi];
-        dst[mi] = Allreduce<4>::run(dst[mi], op);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
-    float tmp[M];
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        tmp[mi] = op(src[mi].x, src[mi].y);
-    }
-    quad_allreduce(dst, tmp, op);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int M, typename Operator>
-__device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
-    __half2 tmp[M];
-    #pragma unroll
-    for(int mi=0; mi < M; mi++){
-        tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
-                     reinterpret_cast<const __half2 &>(src[mi].y));
-    }
-    quad_allreduce(dst, tmp, op);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-} // namespace fmha
diff --git a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu
deleted file mode 100644
index bfafa20ea..000000000
--- a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright (c) 2022, Tri Dao.
- */ - -#include "fmha.h" -#include "fmha_block_dgrad_kernel_1xN_loop.h" - -template -__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { - fmha::compute_block_dq_dk_dv_1xN(params); -} - -template -void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); - static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - static_assert(smem_size_dp_sum == 16 * 4 * 2); - - constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum; - - bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" - bool is_causal = params.is_causal; - auto kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } - - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - dim3 grid(params.b, params.h); - kernel<<>>(params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} - -void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - if (params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h deleted file mode 100644 index ce5410fc8..000000000 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ /dev/null @@ -1,772 +0,0 @@ -/* Copyright (c) 2022, Tri Dao. 
- */ - -#pragma once - -#include "fmha_fprop_kernel_1xN.h" -#include "fmha_kernel.h" -#include "fmha_blockmask.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M], - Smem_dp_sum smem, const int buffer_idx) { - #pragma unroll - for (int mi = 0; mi < M; ++mi) { - sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi])); - } - static_assert(M == 1); - smem.store(sum[0], buffer_idx); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph, - const int loop_step_idx) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dq = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 3rd batched GEMM. - using Cta_tile_dkv = - fmha::Cta_tile_extd; - - static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128); - static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64); - static_assert(Cta_tile_dkv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dq = fmha::Hmma_tile; - // The MMA tile for the 3rd GEMM. - using Mma_tile_dkv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to reload Q transposed. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K^T. Treat K^T as V - using Smem_tile_kt = typename Kernel_traits::Smem_tile_v; - - // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load dO. - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - // The shared memory tile to load dO. - // Treating dO as Q. - using Smem_tile_do = typename Kernel_traits::Smem_tile_q; - // The shared memory tile to reload dO transposed. - using Smem_tile_dot = fmha::Smem_tile_b; - - // The global memory tile to load O.Loading O here is similar to loading dO. - using Gmem_tile_o = Gmem_tile_do; - - // The global memory tile to store dQ. - using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_dq_tmp = fmha::Gmem_tile_o; - // The shared memory tile to swizzle dQ. - using Smem_tile_dq = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dV. - using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dV. - using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - - // The global memory tile to store dK. - using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dK. 
- using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - - using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; - - using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum; - - // using Gemm1 = Gemm_Q_K; - using Gemm1 = Gemm_Q_K; - - using Softmax = fmha::Softmax; - - // Shared memory. - extern __shared__ char smem_[]; - // Shared memory layout if we keep V in registers: - // dO | Q | K / V | dQ | S | dP | dP_sum - // dV | dK - // Shared memory layout if we keep V shared memory: - // dO | Q | K | V | dQ | S | dP | dP_sum - // dV | dK - - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - // if( binfo.stop_early() ) return; - if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; - - Blockmask blockmask(params, loop_step_idx); - int block_row_idx = 0; - int mask_val = blockmask.mask_val(0); - if (mask_val == -1) return; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("mask_val = %d.\n", mask_val); - // } - - Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the global memory tile loader for dQ. - Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, - params.d, binfo, tidx); - Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - - fmha::Mask mask(binfo, tidx, loop_step_idx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, - params.d, binfo, tidx, false); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, - params.d, binfo, tidx, false); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! - Smem_tile_v smem_v(smem_v_, tidx); - // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!! - Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for dO. - Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the shared memory tile loader for dO. - Smem_tile_do smem_do(&smem_[0], tidx); - Smem_tile_dot smem_dot(&smem_[0], tidx); - // Allocate the shared memory tile loader for Q^T. 
- // TODO: assert that this points to the same memory as gemm_q_k.smem_q - Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); - - Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx); - Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx); - - Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); - Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx); - - static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); - const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M; - - // Wind gmem tiles to the correct position. - int block_row_idx_next = mask_val / 4; - int block_row_idx_to_move = block_row_idx_next - block_row_idx; - block_row_idx = block_row_idx_next; - gmem_q.move(block_row_idx_to_move); - gmem_do.move(block_row_idx_to_move); - gmem_o.move(block_row_idx_to_move); - gmem_dq.move(block_row_idx_to_move); - gmem_dq_tmp.move(block_row_idx_to_move); - // TODO: need to move gmem_s if we want the intermediate result for debugging - gmem_softmax_lse.move(block_row_idx_to_move); - gmem_softmax_d.move(block_row_idx_to_move); - block_row_idx = block_row_idx_next; - - if (!Is_first) { - gmem_k.move(loop_step_idx); - gmem_v.move(loop_step_idx); - } - - // Trigger the loads for K. - gmem_k.load(); - // Trigger the loads for Q. - gmem_q.load(); - // Trigger the loads for V. - gmem_v.load(); - // Trigger the loads for dO. - gmem_do.load(); - // Trigger the loads for O. - // if (Is_first) { gmem_o.load(); } - // if (true) { gmem_o.load(); } - if (Is_first || mask_val % 2 == 1) { gmem_o.load(); } - - float p_lse[Mma_tile_p::MMAS_M * 2]; - gmem_softmax_lse.load(reinterpret_cast(p_lse)); - - float dp_sum[Mma_tile_p::MMAS_M * 2]; - // if (!Is_first) { - // if (false) { - if (!(Is_first || mask_val % 2 == 1)) { - gmem_softmax_d.load(reinterpret_cast(dp_sum)); - } - - float dp_sum_regs[Gmem_tile_do::LDGS]; - Smem_dp_sum smem_dp_sum(reinterpret_cast(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx); - - if (!Is_first) { __syncthreads(); } - // Commit the data for Q, dO, and V to shared memory. 
- gmem_q.commit(gemm_q_k.smem_q); - gmem_do.commit(smem_do); - // if (Is_first) { - // if (true) { - if (Is_first || mask_val % 2 == 1) { - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0); - const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW; - if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { - gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row); - } - } - - // Instead of scaling dP by rp_dropout, we scale V instead - if (Is_dropout) { - const uint32_t scale_dropout = params.scale_dropout; - #pragma unroll - for(int it=0; it < Gmem_tile_v::LDGS; it++){ - gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); - } - } - - gmem_v.commit(smem_v); - - // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - // #pragma unroll - // for(int it=0; it < Gmem_tile_k::LDGS; it++){ - // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - // } - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); - } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N]; - if (Kernel_traits::V_IN_REGS) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); - } - } - - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. - gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - } - - // Load the fragments for K. - gemm_q_k.load_k(); - // Load the fragments for K^T. - // typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; - // smem_kt.load(frag_kt[0], 0); - // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N]; - // #pragma unroll - // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) { - // smem_kt.load(frag_kt[ki], ki); - // } - - // Create the object to do the softmax. - // We won't be using the shared memory for this softmax at all - Softmax softmax(params, smem_, tidx); - - // Declare the accumulators for the 3rd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for( int l = 0; l < steps; l++ ) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("block_row_idx = %d\n", block_row_idx); - // } - if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break; - - int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next); - // } - - // Load the fragments for V. - // typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N]; - if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); } - - // Load the fragments for dO. 
- typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M]; - smem_do.load(frag_do[0], 0); - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P^T = (Q * K^T)^T. - gemm_q_k(acc_p); - - // Load the mask for that iteration. - mask.load(block_row_idx); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - // Apply the mask. - softmax.apply_mask(mask); - // Scale by log-sum-exp of the softmax - // softmax.apply_exp(p_lse); - softmax.template scale_apply_exp(p_lse, params.scale_bmm1f); - if (Is_dropout) { - // softmax.apply_dropout(ph, params.p_dropout_in_uint); - // softmax.template apply_dropout(ph, params.p_dropout_in_uint); - softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t); - } - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); - static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); - - // Store s * dmask to smem for transpose - smem_s.store(frag_p); - - // Trigger the load for the next Q values. - bool not_last_iter = (l < steps - 1) && (mask_val_next != -1); - block_row_idx_next = mask_val_next / 4; - int block_row_idx_to_move = block_row_idx_next - block_row_idx; - if (not_last_iter) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(block_row_idx_to_move); - gmem_q.load(); - } - - // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { - // // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - // __syncthreads(); - // } - - bool is_first_read = Is_first || mask_val % 2 == 1; - // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by - // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract - // later. This is because loading dp_sum earlier uses more registers. - fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - // if (Is_first) { - // if (true) { - if (is_first_read) { - fmha::Clear_accumulator::apply(acc_dp); - } else { - #pragma unroll - for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) { - #pragma unroll - for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) { - #pragma unroll - for (int ii = 0; ii < 8; ++ii) { - acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)]; - } - } - } - } - - // Do this part of dP^T = (dO * V^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of dO values. - smem_do.load(frag_do[ki & 1], ki); - if (!Kernel_traits::V_IN_REGS) { - smem_v.load(frag_v[ki & 1], ki); - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); - } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { - // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); - // printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y); - // tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1])); - // printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y); - // } - } - - // Do the final stage of math. 
- { - int ki = Mma_tile_p::MMAS_K; - if (!Kernel_traits::V_IN_REGS) { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); - } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); - } - } - - // Load the fragments for K^T. - typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; - smem_kt.load(frag_kt[0], 0); - - // if (Is_first) { - // if (true) { - if (is_first_read) { - const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4; - const int row[2] = {quad, quad + 8}; - smem_dp_sum.load(dp_sum, row, l % 2); - } - - // Trigger the load for the next dO values. - if (not_last_iter) { - smem_do.move_to_next_write_buffer(); - gmem_do.move(block_row_idx_to_move); - gmem_do.load(); - gmem_o.move(block_row_idx_to_move); - // if (Is_first) { - // if (true) { - if (Is_first || mask_val_next % 2 == 1) { - gmem_o.load(); - } - } - - softmax.unpack_noscale(acc_dp); - // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax - // // will be zero. - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; } - // if (Is_first) { softmax.subtract_dp_sum(dp_sum); } - // if (true) { softmax.subtract_dp_sum(dp_sum); } - if (is_first_read) { softmax.subtract_dp_sum(dp_sum); } - - Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - softmax.template pack<__half>(frag_dp); - - if (!Is_dropout) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - frag_p[mi][ni].hmul(frag_dp[mi][ni]); - } - } - } else { - __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2]; - for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { - dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]); - } - const __half zero_h = __half(0.f); - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - #pragma unroll - for (int ii = 0; ii < 4; ++ii) { - const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii); - const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii)); - // If this element is dropped, then frag_p stores -p instead of p. - // So pd holds -p * dp_sum in that case. - const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]); - const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd); - const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd); - frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high); - } - } - } - } - - // Store dp to smem for transpose - smem_dp.store(frag_p); - - // gmem_s.store(frag_p, mask); - // gmem_s.move(); - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dq); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_kt.load(frag_kt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); - } - // Do the final stage of math. 
- { - int ki = Mma_tile_dq::MMAS_K; - fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); - } - - static_assert(Gmem_tile_dq::LOOPS == 1); - - // Swizzle the elements and do the final reduction. - smem_dq.store(acc_dq, 0); - - typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N]; - static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dkv::MMAS_K == 1); - smem_dot.load(frag_dot[0], 0); - - // Threads in a warp is communicating via shared memory (smem_s and smem_dp) - __syncwarp(); - typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; - smem_s.load(frag_s); - - if (Is_dropout) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - frag_s[ki][mi].template hrelu_<__half>(); - } - } - } - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_dot.load(frag_dot[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); - } - - // __syncthreads(); - // Commit the values for Q and dO into shared memory. - if (not_last_iter) { - gmem_q.commit(gemm_q_k.smem_q); - } - - uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP]; - // if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); } - if (!is_first_read) { gmem_dq_tmp.load(dq_out, 0); } - - // __syncthreads(); - // Commit the values for Q and dO into shared memory. - if (not_last_iter) { - gmem_do.commit(smem_do); - // if (Is_first) { - // if (true) { - gmem_softmax_d.move(block_row_idx_to_move); - if (Is_first || mask_val_next % 2 == 1) { - // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum); - // smem_dp_sum.move_to_next_write_buffer(); - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2); - const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW; - if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { - gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row_1); - } - } - gmem_softmax_lse.move(block_row_idx_to_move); - gmem_softmax_lse.load(reinterpret_cast(p_lse)); - // if (!Is_first) { - if (!(Is_first || mask_val_next % 2 == 1)) { - gmem_softmax_d.load(reinterpret_cast(dp_sum)); - } - } - - typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; - smem_dp.load(frag_dpt); - - gemm_q_k.reload_k(); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dkv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Make sure dQ is in shared memory. - __syncthreads(); - - // Load from shared memory. - is_first_read ? 
smem_dq.template load(dq_out) : smem_dq.template load(dq_out); - - const bool is_final_write = - Is_last - || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) - || ((mask_val & 0x2) != 0) - || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); - if (is_final_write) { - // if (Is_dropout) { - // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); - // } - dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f); - // Output the values. - gmem_dq.template store<__half>(dq_out, 0); - } else { - // Output the values. - gmem_dq_tmp.store(dq_out, 0); - } - - // Move to the next part of the output. - gmem_dq.move(block_row_idx_to_move); - if (!(Is_first && Is_last)) { gmem_dq_tmp.move(block_row_idx_to_move); } - - // // Make sure the data is in shared memory. - // __syncthreads(); - - // Commit the values for Q and dO into shared memory. - if (not_last_iter) { - gemm_q_k.smem_q.move_to_next_read_buffer(); - gemm_q_k.reload_q(); - smem_qt.move_to_next_read_buffer(); - // smem_qt.load(frag_qt[0], 0); - smem_do.move_to_next_read_buffer(); - smem_dot.move_to_next_read_buffer(); - // smem_dot.load(frag_dot[0], 0); - } - - if (mask_val_next == -1) break; - mask_val = mask_val_next; - block_row_idx += block_row_idx_to_move; - - } // Outer loop over the sequence length. - - if (Is_dropout) { - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { - acc_dv[mi][ni].mul_(params.rp_dropout); - } - } - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1)); - // } - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { - // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f); - acc_dk[mi][ni].mul_(params.scale_bmm1f); - } - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1)); - // } - - __syncthreads(); - // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than - // the total amount of shared mem? - // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.template store<__half>(acc_dv); - - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.template store<__half>(acc_dk); - - __syncthreads(); - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { - gmem_dv.move(loop_step_idx); - } - gmem_dv.store(dv_out); - - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) { - // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f); - // } - Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { - gmem_dk.move(loop_step_idx); - } - gmem_dk.store(dk_out); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N. -// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2. 
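The comment above describes the dispatch performed by the function that follows: `loop_steps == -1` derives the step count from `seqlen_k` and the K block size, while the first and last steps are specialized at compile time (`Is_first`/`Is_last`) so their extra loads and stores drop out of the steady-state loop. The host-side sketch below mirrors only that control flow; `process_block`, `run`, and the printed messages are invented for illustration and are not part of the kernel.

```cuda
// Illustrative host-side sketch of the first/middle/last specialization pattern below.
// Hypothetical names only; this is not the real kernel dispatch.
#include <cstdio>

template<bool Is_first, bool Is_last>
void process_block(int step) {
    if (Is_first) { std::printf("step %d: no partial results to read\n", step); }
    std::printf("step %d: accumulate\n", step);
    if (Is_last)  { std::printf("step %d: write final output\n", step); }
}

void run(int seqlen_k, int blocksize_c) {
    // loop_steps == -1 case: derive the step count from the key sequence length.
    const int max_loop_steps = (seqlen_k + blocksize_c - 1) / blocksize_c;
    if (max_loop_steps == 1) {
        process_block<true, true>(0);
    } else {
        process_block<true, false>(0);
        for (int s = 1; s < max_loop_steps - 1; ++s) { process_block<false, false>(s); }
        process_block<false, true>(max_loop_steps - 1);
    }
}

int main() {
    run(512, 256);  // e.g. seqlen_k = 512 with a 256-column K block -> 2 steps
    return 0;
}
```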
-template -inline __device__ void compute_block_dq_dk_dv_1xN(const Params ¶ms) { - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - - if (loop_steps == 1) { - compute_block_dq_dk_dv_1xN_one_iter(params, ph, 0); - } else if (loop_steps == 2) { - compute_block_dq_dk_dv_1xN_one_iter(params, ph, 0); - compute_block_dq_dk_dv_1xN_one_iter(params, ph, 1); - } else { - if (params.seqlen_k == blocksize_c) { - compute_block_dq_dk_dv_1xN_one_iter(params, ph, 0); - } else { - const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - compute_block_dq_dk_dv_1xN_one_iter(params, ph, 0); - for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - compute_block_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); - } - compute_block_dq_dk_dv_1xN_one_iter(params, ph, max_loop_steps - 1); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu deleted file mode 100644 index d1a90633e..000000000 --- a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu +++ /dev/null @@ -1,90 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_block_fprop_kernel_1xN.h" - -template -__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) { - fmha::device_block_1xN_loop(params); -} - -template -void run_fmha_block_fp16_sm80_loop_(Launch_params &launch_params, - const bool configure) { - bool is_causal = launch_params.params.is_causal; - // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? - auto kernel = launch_params.is_dropout - ? (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)) - : (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)); - - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; - constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; - // Don't need smem_size_softmax_lse if we're not looping - const int smem_size = fmha::get_dynamic_smem_size() - + (loop_steps > 1 ? smem_size_softmax_lse : 0); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - if (configure) { - using Mma_tile_p = fmha::Hmma_tile; - constexpr int M = Kernel_traits::Cta_tile_p::M; - size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; - launch_params.elts_per_thread = elts_per_head; - return; - } - - dim3 grid(launch_params.params.b, launch_params.params.h); - kernel<<>>( - launch_params.params); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} - -void run_fmha_block_fp16_sm80(Launch_params &launch_params, - const bool configure) { - if (launch_params.params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h deleted file mode 100644 index 15f865ecf..000000000 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ /dev/null @@ -1,533 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "fmha_fprop_kernel_1xN.h" -#include "fmha_kernel.h" -#include "fmha_blockmask.h" -#include -#include - -namespace fmha { - -template -inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { - - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_o = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_o_tmp = fmha::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; - - using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; - - using Gemm1 = Gemm_Q_K; - - using Softmax = fmha::Softmax; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - // if( binfo.stop_early() ) return; - if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; - - Blockmask blockmask(params, loop_step_idx); - int block_row_idx = 0; - int mask_val = blockmask.mask_val(0); - if (mask_val == -1) return; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("mask_val = %d.\n", mask_val); - // } - - Gemm1 gemm_q_k(smem_, tidx); - // Allocate the global memory tile loader for Q. 
- Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx); - Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); - - // Wind gmem tiles to the correct position. - static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); - int block_row_idx_next = mask_val / 4; - int block_row_idx_to_move = block_row_idx_next - block_row_idx; - gmem_q.move(block_row_idx_to_move); - gmem_o.move(block_row_idx_to_move); - gmem_o_tmp.move(block_row_idx_to_move); - if (Return_softmax) { gmem_s.move(block_row_idx_to_move); } - gmem_softmax_lse.move(block_row_idx_to_move); - block_row_idx = block_row_idx_next; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("begin = %d, steps = %d\n", begin, steps); - // } - - fmha::Mask mask(binfo, tidx, loop_step_idx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, - params.d, binfo, tidx, false); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, - params.d, binfo, tidx, false); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! - Smem_tile_v smem_v(smem_v_, tidx); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); - - if (!Is_first) { - gmem_k.move(loop_step_idx); - gmem_v.move(loop_step_idx); - if (Return_softmax) { gmem_s.move(loop_step_idx * steps); } - } - - // Trigger the loads for K. - gmem_k.load(); - // Trigger the loads for Q. - gmem_q.load(); - // Trigger the loads for V. - gmem_v.load(); - - if (!Is_first) { __syncthreads(); } - - float p_prev_lse[Mma_tile_p::MMAS_M * 2]; - if (!(Is_first || mask_val % 2 == 1)) { - gmem_softmax_lse.load(reinterpret_cast(p_prev_lse)); - } - - // Commit the data for Q and V to shared memory. - gmem_q.commit(gemm_q_k.smem_q); - gmem_v.commit(smem_v); - - // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - // #pragma unroll - // for(int it=0;it < Gmem_tile_k::LDGS;it++){ - // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - // } - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); - } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); - } - - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. 
- gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - } - - // Load the fragments for K. - gemm_q_k.load_k(); - - // Create the object to do the softmax. - Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); - - Smem_softmax_sum smem_softmax_lse(reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); - - // Load over the entire sequence length. - for( int l = 0; l < steps; l++ ) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("block_row_idx = %d\n", block_row_idx); - // } - if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break; - - int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next); - // } - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P = Q * K^T. - gemm_q_k(acc_p); - - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - bool is_first_read = Is_first || mask_val % 2 == 1; - // if (!Is_first) { gmem_o_tmp.load(out, 0); } - if (!is_first_read) { gmem_o_tmp.load(out, 0); } - - // Trigger the load for the next Q values. - bool not_last_iter = (l < steps - 1) && (mask_val_next != -1); - block_row_idx_next = mask_val_next / 4; - int block_row_idx_to_move = block_row_idx_next - block_row_idx; - if (not_last_iter) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(block_row_idx_to_move); - gmem_q.load(); - } - - // Load the mask for that iteration. - mask.load(block_row_idx); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - - // Apply the mask. - softmax.apply_mask(mask); - - // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask); - - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { - // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - __syncthreads(); - } - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); - // } - // } - // Compute the max. - float p_max[Mma_tile_p::MMAS_M * 2]; - // if (!Is_first) { - if (!is_first_read) { - smem_softmax_lse.store_pair(p_prev_lse, l % 2); - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; } - for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; } - } - - // Trigger the load for the next LSE values. - if (not_last_iter) { - // if (!Is_first) { - if (!(Is_first || mask_val_next % 2 == 1)) { - gmem_softmax_lse.load_next(reinterpret_cast(p_prev_lse), - block_row_idx_to_move); - } - } - - // __half2 p_max[Mma_tile_p::MMAS_M]; - // softmax.template reduce_max(p_max); - is_first_read ? softmax.template reduce_max(p_max) : softmax.template reduce_max(p_max); - - // if ((threadIdx.x == 0) && (l == 38)) { - // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); - // } - - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); - // } - // } - - // Compute the exponential value. 
- // softmax.apply_exp(p_max); - softmax.scale_apply_exp(p_max, params.scale_bmm1f); - - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); - // } - // } - - // Compute the sum. - float p_sum[Mma_tile_p::MMAS_M * 2]; - // if (!Is_first) { - // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; - // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { - // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0; - // } - // } - // softmax.reduce_sum(p_sum); - softmax.reduce_sum_before_sync_(p_sum); - // softmax.template reduce_sum_before_sync_(p_sum); - - // float p_sum_log[Mma_tile_p::MMAS_M * 2]; - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { - // float sum = p_sum[mi]; - // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum); - // constexpr float kLog2e = M_LOG2E; - // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); - // } - // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); - // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); - // gmem_softmax_lse.move(); - - // // Finalize softmax on the accumulators of P^T. - // softmax.scale(p_sum); - - constexpr bool encode_dropout_in_sign_bit = Return_softmax; - if (Is_dropout) { - // softmax.template apply_dropout(ph0, params.p_dropout_in_uint); - // softmax.template apply_dropout(ph0, ph1, params.p_dropout_in_uint); - softmax.template apply_dropout_16bits(ph0, ph1, params.p_dropout_in_uint16_t); - } - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); - static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); - if (Return_softmax) { - gmem_s.store(frag_p, mask); - if (not_last_iter) { - gmem_s.move(block_row_idx_to_move); - } - } - - // Commit the values for Q into shared memory. - if (not_last_iter) { - gmem_q.commit(gemm_q_k.smem_q); - } - - if (Is_dropout && encode_dropout_in_sign_bit) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].template hrelu_<__half>(); - } - } - } - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]); - } - - // The mapping from tidx to rows changes between the softmax and the O-reduction. - // So we recalculate the max. 
- float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - int rows[Gmem_tile_o::STGS_PER_LOOP]; - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; - } - softmax.reduce_max_after_sync_(p_max_o, rows); - static_assert(Mma_tile_o::MMAS_M == 1); - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - p_max_o[jj][0] *= params.scale_bmm1f; - } - float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; - // if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); } - if (!is_first_read) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); } - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); - // } - // } - - static_assert(Gmem_tile_o::LOOPS == 1); - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, 0); - - // Make sure the data is in shared memory. - __syncthreads(); - - static_assert(Mma_tile_o::MMAS_M == 1); - float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - softmax.reduce_sum_after_sync_(p_sum_o, rows); - // if (!Is_first) { - if (!is_first_read) { - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); - p_sum_o[jj][0] += p_prev_scale_o[jj]; - } - } - - float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - #pragma unroll - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - float sum = p_sum_o[jj][0]; - p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); - // if (sum == 0.f || sum != sum) { - // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); - // } - // if (Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); - // } - // } - if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) { - gmem_softmax_lse.store_row( - reinterpret_cast(p_sum_log[jj]), rows[jj]); - } - } - if (not_last_iter) { - gmem_softmax_lse.move(block_row_idx_to_move); - } - - // Load from shared memory. - // if (!Is_first) { - if (!is_first_read) { - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); - } - } - // smem_o.template load(out); - is_first_read ? smem_o.template load(out) : smem_o.template load(out); - - const bool is_final_write = - Is_last - || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) - || ((mask_val & 0x2) != 0) - || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("is_final_write = %d\n", is_final_write); - // } - #pragma unroll - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - float sum = p_sum_o[jj][0]; - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - if (Is_dropout && is_final_write) { - inv_sum *= params.rp_dropout; - } - out[jj] = fmha::fmul4(out[jj], inv_sum); - } - - // if (Is_dropout && Is_last) { - // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); - // } - // } - - // Output the values. 
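The stretch of code above is the streaming-softmax merge across key blocks: the previous partial output in `o_tmp` is rescaled by `exp(prev_lse - new_max)`, the rescaled previous sum and the new block's exponentials are combined, and the row LSE is rewritten as `max + log(sum)`. A scalar, host-side sketch of the same per-row recurrence (plain C++ with made-up scores and values, ignoring the `scale_bmm1f` factor and dropout):

```c++
// One query row processed over two key blocks, carrying exactly what the kernel carries:
// the running LSE (softmax_lse) and the normalized partial output (o_tmp).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct RowState { float lse; float out; };   // per-row contents of softmax_lse and o_tmp

RowState process_block(RowState prev, bool is_first,
                       const std::vector<float>& s, const std::vector<float>& v) {
    float m = is_first ? -INFINITY : prev.lse;     // max reduction seeded with the previous LSE
    for (float sj : s) m = std::max(m, sj);
    float sum = 0.f, out = 0.f;
    for (size_t j = 0; j < s.size(); ++j) {
        float p = std::exp(s[j] - m);
        sum += p;
        out += p * v[j];
    }
    if (!is_first) {                               // fold in the previous partial result
        float prev_scale = std::exp(prev.lse - m); // the p_prev_scale_o factor
        sum += prev_scale;
        out += prev.out * prev_scale;
    }
    return { m + std::log(sum), out / sum };       // new LSE, new normalized partial output
}

int main() {
    std::vector<float> s1 = {1.f, 2.f}, v1 = {0.5f, -1.f};   // key block 1 (scores, values)
    std::vector<float> s2 = {3.f, 0.f}, v2 = {2.f,  0.3f};   // key block 2
    RowState r = process_block({0.f, 0.f}, /*is_first=*/true,  s1, v1);
    r          = process_block(r,          /*is_first=*/false, s2, v2);
    std::printf("lse = %.6f, out = %.6f\n", r.lse, r.out);   // matches a one-shot softmax
}
```

The store that follows then writes this normalized partial either to the real output `o` (when `is_final_write`) or back to `o_tmp` for the next key block.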
- if (is_final_write) { - gmem_o.template store<__half>(out, 0); - } else { - gmem_o_tmp.store(out, 0); - } - - // Move to the next part of the output. - gmem_o.move(block_row_idx_to_move); - if (!(Is_first && Is_last)) { gmem_o_tmp.move(block_row_idx_to_move); } - gemm_q_k.reload_k(); - - // Make sure we are reading from the correct buffer. - gemm_q_k.smem_q.move_to_next_read_buffer(); - // Trigger the load from shared memory for the next series of Q values. - if (not_last_iter) { - gemm_q_k.reload_q(); - } - - if (mask_val_next == -1) break; - mask_val = mask_val_next; - block_row_idx += block_row_idx_to_move; - - } // Outer loop over the sequence length. -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void device_block_1xN_loop(const Params ¶ms) { - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); - constexpr int M = Kernel_traits::Cta_tile_p::M; - const int STEPS = (params.seqlen_q + M - 1) / M; - - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - fmha::device_block_1xN_(params, bidb, bidh, STEPS, ph0, ph1, 0); - } else { - const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - fmha::device_block_1xN_(params, bidb, bidh, STEPS, ph0, ph1, 0); - for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - fmha::device_block_1xN_(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx); - } - fmha::device_block_1xN_(params, bidb, bidh, STEPS, ph0, ph1, max_loop_steps - 1); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha - diff --git a/csrc/flash_attn/src/fmha_blockmask.h b/csrc/flash_attn/src/fmha_blockmask.h deleted file mode 100644 index bbd33d62a..000000000 --- a/csrc/flash_attn/src/fmha_blockmask.h +++ /dev/null @@ -1,57 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Blockmask { - - template - __device__ Blockmask(const Params ¶ms, int loop_step_idx) : - blockmask_ptr(params.blockmask + loop_step_idx * params.seqlen_q / 16) { - } - - __device__ int mask_val(int block_row_idx) const { - return blockmask_ptr[block_row_idx]; - } - - const int *blockmask_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha_bwd_hdim128.cu b/csrc/flash_attn/src/fmha_bwd_hdim128.cu deleted file mode 100644 index 138dcaafa..000000000 --- a/csrc/flash_attn/src/fmha_bwd_hdim128.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "fmha_bwd_launch_template.h" - -void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ([&] { - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - })); -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_hdim32.cu deleted file mode 100644 index a09ebac2b..000000000 --- a/csrc/flash_attn/src/fmha_bwd_hdim32.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "fmha_bwd_launch_template.h" - -void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ([&] { - if (params.seqlen_k == 128) { - using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } else if (params.seqlen_k >= 256) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } - })); -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim64.cu b/csrc/flash_attn/src/fmha_bwd_hdim64.cu deleted file mode 100644 index 3091605ba..000000000 --- a/csrc/flash_attn/src/fmha_bwd_hdim64.cu +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. 
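The comment that opens each of these deleted launchers states the build layout: one `run_fmha_bwd_hdim*` function per `.cu` file, so the heavy `FMHA_kernel_traits` instantiations compile in parallel, with each launcher switching on `seqlen_k` (and, for hdim64 below, on compute capability) to pick a tile configuration. A compile-and-run sketch of that one-translation-unit-per-head-dim pattern; the dispatcher name, the `Params` struct, and the `d <= 32/64/128` rounding are illustrative stand-ins, not the library's API:

```c++
// Minimal sketch: per-head-dim launchers that would normally live in separate .cu files,
// plus the thin runtime switch that picks one. Defined inline here so the sketch compiles.
#include <cstdio>
#include <stdexcept>

struct Params { int d; };                                   // stand-in for FMHA_dgrad_params

void run_bwd_hdim32 (Params&) { std::printf("hdim32 traits\n");  }   // fmha_bwd_hdim32.cu
void run_bwd_hdim64 (Params&) { std::printf("hdim64 traits\n");  }   // fmha_bwd_hdim64.cu
void run_bwd_hdim128(Params&) { std::printf("hdim128 traits\n"); }   // fmha_bwd_hdim128.cu

void run_bwd(Params& p) {                    // the only place that needs to see every head dim
    if      (p.d <= 32)  run_bwd_hdim32(p);
    else if (p.d <= 64)  run_bwd_hdim64(p);
    else if (p.d <= 128) run_bwd_hdim128(p);
    else throw std::invalid_argument("head dim > 128 not supported by these kernels");
}

int main() { Params p{64}; run_bwd(p); }
```

The hdim64 launcher that follows is the most involved of the three, since it additionally branches on compute capability to trade shared memory against register pressure, as its own comments explain.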
- -#include "fmha_bwd_launch_template.h" - -void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ([&] { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (params.seqlen_k == 128) { - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } else if (params.seqlen_k >= 256) { - if ((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0)) { - // Don't share smem for K & V, and don't keep V in registers - // This speeds things up by 2-3% by avoiding register spills, but it - // uses more shared memory, which is fine on A100 and H100 but not other GPUs. - // For other GPUs, we keep V in registers. - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } else if (dprops->major == 8 && dprops->minor > 0) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } else if (dprops->major == 7 && dprops->minor == 5) { - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; - run_fmha_bwd_loop(params, stream, configure); - } - } - })); -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_launch_template.h b/csrc/flash_attn/src/fmha_bwd_launch_template.h deleted file mode 100644 index 032c4a11d..000000000 --- a/csrc/flash_attn/src/fmha_bwd_launch_template.h +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -#pragma once - -#include "static_switch.h" -#include "fmha.h" -#include "fmha_dgrad_kernel_1xN_loop.h" - -// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1). -// Parallelizing will have better occupancy, but has some overhead due to having to zero out -// dq_tmp and having to copy dq_tmp to dq. -inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen, - int blocksize, bool is_causal) { - float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm); - float eff_1 = n_waves_1 / ceil(n_waves_1); - int num_splits_parallel = seqlen / blocksize; - float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm); - float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel); - float discount_factor; - if (!is_causal) { - discount_factor = 1.f + float(blocksize) / seqlen; - } else { // For causal, parallelizing seems to help with load-balancing as well - // For example, if headdim=128, seqlen >= 1280 always prefers parallel - if (seqlen / blocksize >= 10) return num_splits_parallel; - discount_factor = 1.f + 0.5 * float(blocksize) / seqlen; - } - float eff_parallel = eff_parallel_raw / discount_factor; - return eff_1 >= eff_parallel ? 
1 : num_splits_parallel; -} - -template -__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) { - fmha::compute_dot_do_o(params); -} - -template -__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { - fmha::compute_dq_dk_dv_1xN(params); -} - -template -__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) { - fmha::compute_dq_dk_dv_seqparallel(params); -} - -template -void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); - static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2; - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv); - - bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] { - auto kernel = params.is_causal - ? &fmha_bwd_dq_dk_dv_loop_kernel - : &fmha_bwd_dq_dk_dv_loop_kernel; - if (params.seqlen_k == blocksize_c) { - kernel = params.is_causal - ? &fmha_bwd_dq_dk_dv_loop_kernel - : &fmha_bwd_dq_dk_dv_loop_kernel; - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = params.is_causal - ? &fmha_bwd_dq_dk_dv_loop_kernel - : &fmha_bwd_dq_dk_dv_loop_kernel; - } - auto kernel_seqparallel = params.is_causal - ? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel - : &fmha_bwd_q_dk_dv_loop_seqparallel_kernel; - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - // Automatically set num_splits to maximize occupancy - if (params.num_splits <= 0) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv); - auto dprops = at::cuda::getCurrentDeviceProperties(); - // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount); - constexpr int M = Kernel_traits::Cta_tile_p::M; - // We don't want more than 10 splits due to numerical error. - // Numerical error on dk/dv scales as sqrt(num_splits). 
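`num_splits_heuristic_bwd` above compares the wave-quantization efficiency of launching one CTA per (batch, head) against launching one CTA per (batch, head, key block), and discounts the parallel option for the overhead of zeroing `dq_tmp` and copying it back into `dq`. A host-side copy of the heuristic run on made-up workloads makes the trade-off concrete; the SM count, occupancy, and shapes below are hypothetical:

```c++
// The same arithmetic as num_splits_heuristic_bwd above, exercised on two toy workloads.
#include <cmath>
#include <cstdio>

int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                             int blocksize, bool is_causal) {
    float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
    float eff_1 = n_waves_1 / std::ceil(n_waves_1);
    int num_splits_parallel = seqlen / blocksize;
    float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
    float eff_parallel_raw = n_waves_parallel / std::ceil(n_waves_parallel);
    float discount_factor;
    if (!is_causal) {
        discount_factor = 1.f + float(blocksize) / seqlen;
    } else {
        if (seqlen / blocksize >= 10) return num_splits_parallel;
        discount_factor = 1.f + 0.5f * float(blocksize) / seqlen;
    }
    float eff_parallel = eff_parallel_raw / discount_factor;
    return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
}

int main() {
    // Small batch: 2 x 16 heads gives only 32 CTAs on a hypothetical 108-SM GPU, so the
    // heuristic chooses seqlen_k / blocksize = 8 splits despite the dq_tmp overhead.
    std::printf("num_splits = %d\n",
                num_splits_heuristic_bwd(32, 108, 1, 1024, 128, /*is_causal=*/false));   // 8
    // Large batch: 32 x 16 heads already fills the machine, so it stays at 1 split.
    std::printf("num_splits = %d\n",
                num_splits_heuristic_bwd(512, 108, 1, 1024, 128, /*is_causal=*/false));  // 1
}
```

In `run_fmha_bwd_loop` this heuristic is only consulted when `params.num_splits <= 0`, using the per-SM occupancy returned by `cudaOccupancyMaxActiveBlocksPerMultiprocessor`, as the call just below shows.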
- params.num_splits = num_splits_heuristic_bwd( - params.b * params.h, dprops->multiProcessorCount, - ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal - ); - } - if (configure) return; - if (params.num_splits == 1) { - dim3 grid(params.b, params.h, params.num_splits); - kernel<<>>(params); - } else { - dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128); - fmha_bwd_dot_do_o_kernel<<>>(params); - int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c - dim3 grid(params.b, params.h, num_splits); - kernel_seqparallel<<>>(params); - } - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - })); -} diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h deleted file mode 100644 index d5ac579a3..000000000 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ /dev/null @@ -1,841 +0,0 @@ -/* Copyright (c) 2022, Tri Dao. - */ - -#pragma once - -#include "fmha_fprop_kernel_1xN.h" -#include "fmha_kernel.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale, - Gmem_softmax_sum gmem_softmax_d, int tidx) { - float sum[M]; - fmha::SumOp sum_op; - #pragma unroll - for (int mi = 0; mi < M; ++mi) { - sum[mi] = fmha::Allreduce::run( - fmha::hmulsum8(do_[mi], o[mi]), sum_op - ) * scale; - } - const int dp_sum_row = tidx / THREADS_PER_ROW; - if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) { - gmem_softmax_d.store_row(reinterpret_cast(sum), dp_sum_row); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void compute_dot_do_o(const Params ¶ms) { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using elem_type = typename Kernel_traits::elem_type; -#else - constexpr bool is_fp16_type = std::is_same::value; - assert(is_fp16_type); - using elem_type = __half; -#endif - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 3rd batched GEMM. - using Cta_tile_dkv = - fmha::Cta_tile_extd; - - static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); - static_assert(Cta_tile_dkv::K == 16); - - // The global memory tile to load dO. - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - - // The global memory tile to load O.Loading O here is similar to loading dO. - using Gmem_tile_o = Gmem_tile_do; - - using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - // How many steps to jump per iteration. - const int step_stride = gridDim.z; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) return; - - // Allocate the global memory tile loader for dO. 
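`dot_do_o` above reduces each row of dO against the matching row of O and stores the result as `dsoftmax_sum` (scaled here by `p_dropout`). The reason one number per row is enough: since O = P V, the dot product dot(dO_i, O_i) equals rowsum(P_i * dP_i), which is exactly the term the backward later subtracts from dP when forming dS = P * (dP - D) elementwise (the dP accumulators further down in this file are initialized to `-dp_sum`). A small host-side check of that identity on made-up numbers:

```c++
// Verify on a toy row that dot(dO, O) == rowsum(P * dP) with dP = dO * V^T.
#include <cstdio>

int main() {
    const int kv = 3, d = 2;                          // toy sizes: 3 keys, head dim 2
    float P[kv]    = {0.2f, 0.5f, 0.3f};              // one softmax row (sums to 1)
    float V[kv][d] = {{1.f, 0.f}, {0.f, 2.f}, {1.f, 1.f}};
    float dO[d]    = {0.4f, -0.1f};                   // upstream gradient for this row of O

    float O[d] = {0.f, 0.f};                          // forward: O = P * V
    for (int j = 0; j < kv; ++j)
        for (int k = 0; k < d; ++k) O[k] += P[j] * V[j][k];

    float D_from_o = 0.f;                             // what dot_do_o computes
    for (int k = 0; k < d; ++k) D_from_o += dO[k] * O[k];

    float D_from_p = 0.f;                             // rowsum(P * dP), never materialized
    for (int j = 0; j < kv; ++j) {
        float dP_j = 0.f;
        for (int k = 0; k < d; ++k) dP_j += dO[k] * V[j][k];
        D_from_p += P[j] * dP_j;
    }
    std::printf("dot(dO, O) = %.6f, rowsum(P*dP) = %.6f\n", D_from_o, D_from_p);  // both 0.07
}
```

The dO/O loader allocations that follow feed exactly this reduction, one `Cta_tile_p::M`-row block per iteration of the loop further down.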
- Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - - Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx); - - static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); - const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M; - // Wind gmem tiles to the correct position. - gmem_do.move(blockIdx.z); - gmem_o.move(blockIdx.z); - gmem_softmax_d.move(blockIdx.z); - - // Load over the entire sequence length. - for (int l = blockIdx.z; l < steps; l += step_stride) { - if (l * Cta_tile_p::M >= binfo.actual_seqlen_q) - break; - - gmem_do.load(); - gmem_do.move(step_stride); - gmem_o.load(); - gmem_o.move(step_stride); - - dot_do_o( - gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx - ); - gmem_softmax_d.move(step_stride); - } // Outer loop over the sequence length. -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph, - const int loop_step_idx) { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using elem_type = typename Kernel_traits::elem_type; -#else - constexpr bool is_fp16_type = std::is_same::value; - assert(is_fp16_type); - using elem_type = __half; -#endif - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dq = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 3rd batched GEMM. - using Cta_tile_dkv = - fmha::Cta_tile_extd; - - static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128); - static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); - static_assert(Cta_tile_dkv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dq = fmha::Hmma_tile; - // The MMA tile for the 3rd GEMM. - using Mma_tile_dkv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to reload Q transposed. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K^T. Treat K^T as V - using Smem_tile_kt = typename Kernel_traits::Smem_tile_v; - - // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load dO. - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - // The shared memory tile to load dO. - // Treating dO as Q. - using Smem_tile_do = typename Kernel_traits::Smem_tile_q; - // The shared memory tile to reload dO transposed. - using Smem_tile_dot = fmha::Smem_tile_b; - - // The global memory tile to load O.Loading O here is similar to loading dO. - using Gmem_tile_o = Gmem_tile_do; - - // The global memory tile to store dQ. 
- using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_dq_tmp = fmha::Gmem_tile_o; - // The shared memory tile to swizzle dQ. - using Smem_tile_dq = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dV. - using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dV. - using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - - // The global memory tile to store dK. - using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dK. - using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - - using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; - - // using Gemm1 = Gemm_Q_K; - using Gemm1 = Gemm_Q_K; - - using Softmax = fmha::Softmax; - - // Shared memory. - extern __shared__ char smem_[]; - // Shared memory layout if we keep V in registers: - // dO | Q | K / V | dQ | S | dP | dP_sum - // dV | dK - // Shared memory layout if we keep V shared memory: - // dO | Q | K | V | dQ | S | dP | dP_sum - // dV | dK - - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - // if( binfo.stop_early() ) return; - if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; - - Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the global memory tile loader for dQ. - Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, - params.d, binfo, tidx); - Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - - fmha::Mask mask(binfo, tidx, loop_step_idx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, - params.d, binfo, tidx, false); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, - params.d, binfo, tidx, false); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! - Smem_tile_v smem_v(smem_v_, tidx); - // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!! - Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for dO. - Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the shared memory tile loader for dO. - Smem_tile_do smem_do(&smem_[0], tidx); - Smem_tile_dot smem_dot(&smem_[0], tidx); - // Allocate the shared memory tile loader for Q^T. 
- // TODO: assert that this points to the same memory as gemm_q_k.smem_q - Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); - - Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx); - Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx, true); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx); - - Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); - Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx); - - static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); - int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0; - // Otherwise we'd be reading out-of-bound memory before the loop - if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) { - // Still need to zero out dk and dv before returning - static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS); - uint4 dkv_out[Smem_tile_dk::NUM_LDS]; - #pragma unroll - for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) { dkv_out[i] = make_uint4(0u, 0u, 0u, 0u); } - Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { gmem_dk.move(loop_step_idx); } - gmem_dk.store(dkv_out); - Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { gmem_dv.move(loop_step_idx); } - gmem_dv.store(dkv_out); - return; - } - - const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin; - // Wind gmem tiles to the correct position. - gmem_q.move(begin); - gmem_do.move(begin); - gmem_o.move(begin); - if (!Seq_parallel) { gmem_dq.move(begin); } // If Seq_parallel, we're not using gmem_dq at all - gmem_dq_tmp.move(begin); - // TODO: need to move gmem_s if we want the intermediate result for debugging - gmem_softmax_lse.move(begin); - gmem_softmax_d.move(begin); - - if (!Is_first) { - gmem_k.move(loop_step_idx); - gmem_v.move(loop_step_idx); - } - - // Trigger the loads for K. - gmem_k.load(); - // Trigger the loads for Q. - gmem_q.load(); - // Trigger the loads for V. - gmem_v.load(); - // Trigger the loads for dO. - gmem_do.load(); - // Trigger the loads for O. - if (Is_first) { gmem_o.load(); } - - float p_lse[Mma_tile_p::MMAS_M * 2]; - gmem_softmax_lse.load(reinterpret_cast(p_lse)); - - if (!Is_first) { __syncthreads(); } - // Commit the data for Q, dO, and V to shared memory. 
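A key property of this backward kernel, visible in the `p_lse` load just above: the forward pass never saved the attention matrix, only the per-row log-sum-exp, and `scale_apply_exp` inside the main loop rebuilds P_ij = exp(scale * s_ij - lse_i) on the fly from the recomputed Q K^T scores. A scalar illustration (made-up scores; `scale` stands in for `params.scale_bmm1f`):

```c++
// With lse = log(sum_j exp(scale * s_j)) stored by the forward pass, a single exp per
// element reproduces the softmax row, with no new max or sum reduction needed.
#include <cmath>
#include <cstdio>

int main() {
    const int n = 3;
    float s[n]  = {1.0f, -0.5f, 2.0f};     // raw Q*K^T scores for one query row
    float scale = 0.5f;                    // stands in for 1/sqrt(head_dim)

    // What the forward pass stored in softmax_lse_ptr.
    float m = -INFINITY;
    for (int j = 0; j < n; ++j) m = std::fmax(m, scale * s[j]);
    float sum = 0.f;
    for (int j = 0; j < n; ++j) sum += std::exp(scale * s[j] - m);
    float lse = m + std::log(sum);

    // What scale_apply_exp does in the backward: P[j] = exp(scale * s[j] - lse).
    for (int j = 0; j < n; ++j)
        std::printf("P[%d] = %.6f\n", j, std::exp(scale * s[j] - lse));  // rows sum to 1
}
```

The Q, dO, and V commits that follow stage the operands for exactly that recomputation.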
- gmem_q.commit(gemm_q_k.smem_q); - gmem_do.commit(smem_do); - if (Is_first) { - dot_do_o( - gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx - ); - } - - // // Instead of scaling dP by rp_dropout, we scale V instead - // if (Is_dropout) { - // const uint32_t scale_dropout = params.scale_dropout; - // #pragma unroll - // for(int it=0; it < Gmem_tile_v::LDGS; it++){ - // gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); - // } - // } - - gmem_v.commit(smem_v); - - // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - // #pragma unroll - // for(int it=0; it < Gmem_tile_k::LDGS; it++){ - // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - // } - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); - } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N]; - if (Kernel_traits::V_IN_REGS) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); - } - } - - float dp_sum[Mma_tile_p::MMAS_M * 2]; - gmem_softmax_d.load(reinterpret_cast(dp_sum)); - - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. - gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - } - - // Load the fragments for K. - gemm_q_k.load_k(); - // Load the fragments for K^T. - // typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; - // smem_kt.load(frag_kt[0], 0); - // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N]; - // #pragma unroll - // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) { - // smem_kt.load(frag_kt[ki], ki); - // } - - // Create the object to do the softmax. - // We won't be using the shared memory for this softmax at all - Softmax softmax(params, smem_, tidx); - - // Declare the accumulators for the 3rd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for (int l = 0; l < steps; l++) { - if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) - break; - - // Load the fragments for V. - // typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N]; - if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); } - - // Load the fragments for dO. - typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M]; - smem_do.load(frag_do[0], 0); - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P^T = (Q * K^T)^T. - gemm_q_k(acc_p); - - // Load the mask for that iteration. - mask.load(begin + l); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - // Apply the mask. 
- softmax.apply_mask(mask); - // Scale by log-sum-exp of the softmax - // softmax.apply_exp(p_lse); - softmax.template scale_apply_exp(p_lse, params.scale_bmm1f); - if (Is_dropout) { - // softmax.apply_dropout(ph, params.p_dropout_in_uint); - // softmax.template apply_dropout(ph, params.p_dropout_in_uint); - // softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t); - unsigned int warp_idx = threadIdx.x / 32; - // TODO: this should change after we rearrange the warps (e.g. cutlass branch) - unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx; - unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx; - softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t, philox_subsequence); - } - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); - static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack(frag_p); - - // Store s * dmask to smem for transpose - smem_s.store(frag_p); - - // Trigger the load for the next Q values. - if (l + 1 < steps) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(); - } - - // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { - // // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - // __syncthreads(); - // } - - fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - #pragma unroll - for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) { - #pragma unroll - for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) { - #pragma unroll - for (int ii = 0; ii < 8; ++ii) { - acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)]; - } - } - } - - // Do this part of dP^T = (dO * V^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of dO values. - smem_do.load(frag_do[ki & 1], ki); - if (!Kernel_traits::V_IN_REGS) { - smem_v.load(frag_v[ki & 1], ki); - fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); - } else { - fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { - // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); - // printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y); - // tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1])); - // printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y); - // } - } - - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - if (!Kernel_traits::V_IN_REGS) { - fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); - } else { - fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); - } - } - - auto pointwise_mult = [](float p, float dp, float d) { - return p * ((!Is_dropout) || p >= 0.f ? 
dp : d); - }; - #pragma unroll - for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) { - #pragma unroll - for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) { - softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]); - softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]); - softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]); - softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]); - softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]); - softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]); - softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]); - softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]); - } - } - - // Load the fragments for K^T. - typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; - smem_kt.load(frag_kt[0], 0); - - // Trigger the load for the next dO values. - if (l + 1 < steps) { - smem_do.move_to_next_write_buffer(); - gmem_do.move(); - gmem_do.load(); - if (Is_first) { - gmem_o.move(); - gmem_o.load(); - } - } - - softmax.template pack(frag_p); - - // Store dp to smem for transpose - smem_dp.store(frag_p); - - // gmem_s.store(frag_p, mask); - // gmem_s.move(); - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dq); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_kt.load(frag_kt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); - } - // Do the final stage of math. - { - int ki = Mma_tile_dq::MMAS_K; - fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); - } - - static_assert(Gmem_tile_dq::LOOPS == 1); - - // Swizzle the elements and do the final reduction. - // Need to syncthreads here, otherwise the smem_dq reads from the previous iteration - // might happen after the smem_dq writes in this iteration. 
- __syncthreads(); - smem_dq.store(acc_dq, 0); - - typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N]; - static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dkv::MMAS_K == 1); - smem_dot.load(frag_dot[0], 0); - - // Threads in a warp is communicating via shared memory (smem_s and smem_dp) - __syncwarp(); - typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; - smem_s.load(frag_s); - - if (Is_dropout) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - frag_s[ki][mi].template hrelu_(); - } - } - } - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_dot.load(frag_dot[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); - } - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0])); - // printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y); - // float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1])); - // printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y); - // } - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3)); - // printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3)); - // } - // __syncthreads(); - // Commit the values for Q and dO into shared memory. - if (l + 1 < steps) { - gmem_q.commit(gemm_q_k.smem_q); - } - - uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP]; - if (!Is_first && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); } - - // __syncthreads(); - // Commit the values for Q and dO into shared memory. - if (l + 1 < steps) { - gmem_do.commit(smem_do); - gmem_softmax_d.move(); - if (Is_first) { - dot_do_o( - gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx - ); - } - gmem_softmax_lse.move(); - gmem_softmax_lse.load(reinterpret_cast(p_lse)); - } - - typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; - smem_dp.load(frag_dpt); - - gemm_q_k.reload_k(); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dkv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Make sure dQ is in shared memory. - __syncthreads(); - - if (l + 1 < steps) { - gmem_softmax_d.load(reinterpret_cast(dp_sum)); - } - - // Load from shared memory. 
- smem_dq.template load(dq_out); - - if (!Seq_parallel) { - const bool is_final_write = - Is_last - || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) - || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); - if (is_final_write) { - // if (Is_dropout) { - // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); - // } - for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { - // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); - dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout); - } - // Output the values. - gmem_dq.template store(dq_out, 0); - // Move to the next part of the output. - gmem_dq.move(); - // TODO: for parallel, need to deal with the dropout scaling - } else { - // Output the values. - gmem_dq_tmp.store(dq_out, 0); - } - } else { - // We always scale dq_out before writing in this case, since we don't want to - // have to scale at the end when copying from dq_tmp to dq. - for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { - // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); - dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout); - } - gmem_dq_tmp.atomic_add(dq_out, 0); - } - - // Move to the next part of the output. - if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); } - - // // Make sure the data is in shared memory. - // __syncthreads(); - - // Commit the values for Q and dO into shared memory. - if (l + 1 < steps) { - gemm_q_k.smem_q.move_to_next_read_buffer(); - gemm_q_k.reload_q(); - smem_qt.move_to_next_read_buffer(); - // smem_qt.load(frag_qt[0], 0); - smem_do.move_to_next_read_buffer(); - smem_dot.move_to_next_read_buffer(); - // smem_dot.load(frag_dot[0], 0); - } - - } // Outer loop over the sequence length. - - if (Is_dropout) { - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { - acc_dv[mi][ni].mul_(params.rp_dropout); - } - } - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3)); - // printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3)); - // } - for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { - // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f); - // acc_dk[mi][ni].mul_(params.scale_bmm1f); - acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout); - } - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1)); - // } - - __syncthreads(); - // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than - // the total amount of shared mem? 
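Everything after the query loop above is the dK/dV epilogue: `acc_dv` and `acc_dk` have been accumulating in registers across every query block of this key block (dV = sum_l P_l^T dO_l with dropout folded in, dK = sum_l dS_l^T Q_l), receive their one-time `rp_dropout` / `scale_bmm1` scaling, and are then swizzled through shared memory and written out once. A toy accumulation sketch (made-up shapes, one query row per "block"):

```c++
// The accumulate-across-the-loop, write-once pattern used for dV above, in scalar form.
#include <cstdio>

int main() {
    const int steps = 2, kv = 2, d = 2;
    float P[steps][kv] = {{0.3f, 0.7f}, {0.9f, 0.1f}};   // P rows for two query "blocks"
    float dO[steps][d] = {{1.0f, 0.0f}, {0.0f, 2.0f}};   // matching dO rows

    float dV[kv][d] = {};                        // acc_dv: lives across the whole loop
    for (int l = 0; l < steps; ++l)              // "Outer loop over the sequence length"
        for (int j = 0; j < kv; ++j)
            for (int k = 0; k < d; ++k)
                dV[j][k] += P[l][j] * dO[l][k];  // dV += P_l^T * dO_l

    for (int j = 0; j < kv; ++j)                 // written once, after the loop
        std::printf("dV[%d] = (%.2f, %.2f)\n", j, dV[j][0], dV[j][1]);
}
```

The dV/dK stores below reuse the main-loop shared memory for that final swizzle, which is what the TODO above is asking about.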
- // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.template store(acc_dv); - - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.template store(acc_dk); - - __syncthreads(); - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { - gmem_dv.move(loop_step_idx); - } - gmem_dv.store(dv_out); - - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, - params.d, binfo, tidx, false); - if (!Is_first) { - gmem_dk.move(loop_step_idx); - } - gmem_dk.store(dk_out); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N. -// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2. -template -inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) { - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - auto seed = params.rng_state[0]; - auto offset = params.rng_state[1]; - Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32); - - if (loop_steps == 1) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); - } else if (loop_steps == 2) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); - compute_dq_dk_dv_1xN_one_iter(params, ph, 1); - } else { - if (params.seqlen_k == blocksize_c) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); - } else { - const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); - for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - compute_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); - } - compute_dq_dk_dv_1xN_one_iter(params, ph, max_loop_steps - 1); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv_seqparallel(const Params ¶ms) { - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The thread index. - const int tidx = threadIdx.x; - - auto seed = params.rng_state[0]; - auto offset = params.rng_state[1]; - Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32); - - int loop_step_idx = blockIdx.z; - compute_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h deleted file mode 100644 index ee5d68dcc..000000000 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ /dev/null @@ -1,707 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
- * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "fmha_kernel.h" -#include -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Gemm_Q_K_base { - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - using Fragment_q = typename Smem_tile_q::Fragment; - using Fragment_k = typename Smem_tile_k::Fragment; - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - - static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; - - __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) - : smem_q(smem_ptr_q, tidx) - , smem_k(smem_ptr_k, tidx) { - - } - - __device__ inline void load_q() { - smem_q.load(frag_q[0], 0); - } - - __device__ inline void reload_q() { - smem_q.load(frag_q[0], 0); - } - - Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; - Smem_tile_q smem_q; - Smem_tile_k smem_k; -}; - -template -struct Gemm_Q_K : public Gemm_Q_K_base { - - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - using elem_type = elem_type_; - - static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; - // If V is stored in shared memory, we can't load K using the same shared memory. 
- static_assert(Kernel_traits::V_IN_REGS); - - static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE; - static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; - static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); - - // Q | K / V - // | O | SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE, - Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { - } - - __device__ inline void load_k(){ - #pragma unroll - for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { - Base::smem_k.load(frag_k[ki], ki); - } - } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } - } - - __device__ inline void reload_k(){ - // Noop. - } - - Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; -}; - - -template -struct Gemm_Q_K : public Gemm_Q_K_base { - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - using elem_type = elem_type_; - Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; - - static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; - static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS; - static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V); - - static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); - static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE); - static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE; - static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; - - // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX - // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE - + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { - } - - __device__ inline void load_k(){ - Base::smem_k.load(frag_k[0], 0); - } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - Base::smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. 
- fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - } - - __device__ inline void reload_k(){ - Base::smem_k.load(frag_k[0], 0); - } -}; - -template -constexpr size_t get_dynamic_smem_size(){ - return Gemm_Q_K::SMEM_BYTES; -} - -template -inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using elem_type = typename Kernel_traits::elem_type; -#else - constexpr bool is_fp16_type = std::is_same::value; - assert(is_fp16_type); - using elem_type = __half; -#endif - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_o = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_o_tmp = fmha::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; - - using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; - - using Gemm1 = Gemm_Q_K; - - using Softmax = fmha::Softmax; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - // How many steps to jump per iteration, which is the same as params.num_splits. - const int step_stride = gridDim.z; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - // if( binfo.stop_early() ) return; - if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; - - Gemm1 gemm_q_k(smem_, tidx); - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, - params.d, binfo, tidx, true); - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, - params.d, binfo, tidx); - Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, - params.o_tmp_head_stride_in_elts, params.d, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); - - // Wind gmem tiles to the correct position. - static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); - int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0; - // We want begin to be a multiple of gridDim.z - // This is because the row indices processed by each threadblock must align between the - // loop steps, otherwise we have a dependency between the blocks. 
- // For example, threadblock with blockIdx.z == 1 must process row indices that are - // k * gridDim.z + 1 for integer k. - const int begin_mod_z = begin % gridDim.z; - begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z; - // Otherwise we'd be reading out-of-bound memory before the loop - if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return; - const int steps_og = steps; - steps -= begin; - gmem_q.move(begin + blockIdx.z); - gmem_o.move(begin + blockIdx.z); - gmem_o_tmp.move(begin + blockIdx.z); - if (Return_softmax) { - gmem_s.move(begin + blockIdx.z); - } - gmem_softmax_lse.move(begin + blockIdx.z); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("begin = %d, steps = %d\n", begin, steps); - // } - - fmha::Mask mask(binfo, tidx, loop_step_idx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, - params.d, binfo, tidx, false); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, - params.d, binfo, tidx, false); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! - Smem_tile_v smem_v(smem_v_, tidx); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); - - if (!Is_first) { - gmem_k.move(loop_step_idx); - gmem_v.move(loop_step_idx); - if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } - } - - // Trigger the loads for K. - gmem_k.load(); - // Trigger the loads for Q. - gmem_q.load(); - // Trigger the loads for V. - gmem_v.load(); - - if (!Is_first) { __syncthreads(); } - - float p_prev_lse[Mma_tile_p::MMAS_M * 2]; - if (!Is_first) { - gmem_softmax_lse.load(reinterpret_cast(p_prev_lse)); - } - - // Commit the data for Q and V to shared memory. - gmem_q.commit(gemm_q_k.smem_q); - gmem_v.commit(smem_v); - - // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - // #pragma unroll - // for(int it=0;it < Gmem_tile_k::LDGS;it++){ - // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - // } - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); - } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); - } - - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. - gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - } - - // Load the fragments for K. - gemm_q_k.load_k(); - - // Create the object to do the softmax. - Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); - - Smem_softmax_sum smem_softmax_lse(reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); - - // Load over the entire sequence length. 
- for (int l = blockIdx.z; l < steps; l += step_stride) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) { - // printf("l = %d\n", l); - // } - if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break; - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P = Q * K^T. - gemm_q_k(acc_p); - - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); - // } - - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - if (!Is_first) { gmem_o_tmp.load(out, 0); } - - // Trigger the load for the next Q values. - if (l + step_stride < steps) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(step_stride); - gmem_q.load(); - } - - // Load the mask for that iteration. - mask.load(begin + l); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - - // Apply the mask. - softmax.apply_mask(mask); - - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) { - // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - __syncthreads(); - } - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) { - // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); - // } - // } - // Compute the max. - float p_max[Mma_tile_p::MMAS_M * 2]; - if (!Is_first) { - smem_softmax_lse.store_pair(p_prev_lse); - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; } - for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; } - } - - // Trigger the load for the next LSE values. - if (l + step_stride < steps) { - if (!Is_first) { - gmem_softmax_lse.load_next(reinterpret_cast(p_prev_lse), - step_stride); - } - } - - softmax.template reduce_max(p_max); - - // if ((threadIdx.x == 0) && (l == 38)) { - // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); - // } - - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); - // } - // } - - // Compute the exponential value. - // softmax.apply_exp(p_max); - softmax.scale_apply_exp(p_max, params.scale_bmm1f); - - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); - // } - // } - - // Compute the sum. - float p_sum[Mma_tile_p::MMAS_M * 2]; - // if (!Is_first) { - // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; - // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { - // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0; - // } - // } - // softmax.reduce_sum(p_sum); - softmax.reduce_sum_before_sync_(p_sum); - // softmax.template reduce_sum_before_sync_(p_sum); - - // float p_sum_log[Mma_tile_p::MMAS_M * 2]; - // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { - // float sum = p_sum[mi]; - // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? 
INFINITY : p_max[mi] + __logf(sum); - // constexpr float kLog2e = M_LOG2E; - // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); - // } - // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); - // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); - // gmem_softmax_lse.move(); - - // // Finalize softmax on the accumulators of P^T. - // softmax.scale(p_sum); - - constexpr bool encode_dropout_in_sign_bit = Return_softmax; - if (Is_dropout) { - // softmax.template apply_dropout(ph, params.p_dropout_in_uint); - // softmax.template apply_dropout(ph, ph1, params.p_dropout_in_uint); - // softmax.template apply_dropout_16bits(ph, ph1, params.p_dropout_in_uint16_t); - unsigned int warp_idx = threadIdx.x / 32; - // TODO: this should change after we rearrange the warps (e.g. cutlass branch) - unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx; - // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded - // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k - // to multiples of 256 while bwd rounds seqlen_k to multiples of 128. - unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx; - softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t, philox_subsequence); - } - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); - static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack(frag_p); - if (Return_softmax) { - gmem_s.store(frag_p, mask); - gmem_s.move(step_stride); - } - - // Commit the values for Q into shared memory. - if (l + step_stride < steps) { - gmem_q.commit(gemm_q_k.smem_q); - } - - if (Is_dropout && encode_dropout_in_sign_bit) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].template hrelu_(); - } - } - } - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); - // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki])); - // float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki])); - // printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0)); - // } - } - - // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0)); - // } - - // The mapping from tidx to rows changes between the softmax and the - // O-reduction. So we recalculate the max. 
- float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - int rows[Gmem_tile_o::STGS_PER_LOOP]; - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; - } - softmax.reduce_max_after_sync_(p_max_o, rows); - static_assert(Mma_tile_o::MMAS_M == 1); - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - p_max_o[jj][0] *= params.scale_bmm1f; - } - float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; - if (!Is_first) { - smem_softmax_lse.load(p_prev_scale_o, rows); - } - // if (!Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); - // } - // } - - static_assert(Gmem_tile_o::LOOPS == 1); - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, 0); - - // Make sure the data is in shared memory. - __syncthreads(); - - static_assert(Mma_tile_o::MMAS_M == 1); - float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - softmax.reduce_sum_after_sync_(p_sum_o, rows); - if (!Is_first) { - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); - p_sum_o[jj][0] += p_prev_scale_o[jj]; - } - } - - float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - #pragma unroll - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - float sum = p_sum_o[jj][0]; - p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); - // if (sum == 0.f || sum != sum) { - // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); - // } - // if (Is_first) { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); - // } - // } - if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) { - gmem_softmax_lse.store_row( - reinterpret_cast(p_sum_log[jj]), rows[jj]); - } - } - gmem_softmax_lse.move(step_stride); - - // Load from shared memory. - if (!Is_first) { - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); - } - } - smem_o.template load(out); - - const bool is_final_write = - Is_last - || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) - || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); - #pragma unroll - for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - float sum = p_sum_o[jj][0]; - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - if (Is_dropout && is_final_write) { - inv_sum *= params.rp_dropout; - } - out[jj] = fmha::fmul4(out[jj], inv_sum); - } - - // if (Is_dropout && Is_last) { - // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { - // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); - // } - // } - - // Output the values. - if (is_final_write) { - gmem_o.template store(out, 0); - gmem_o.move(step_stride); - } else { - gmem_o_tmp.store(out, 0); - } - - // Move to the next part of the output. - if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); } - gemm_q_k.reload_k(); - - // Make sure we are reading from the correct buffer. - gemm_q_k.smem_q.move_to_next_read_buffer(); - // Trigger the load from shared memory for the next series of Q values. - if (l + step_stride < steps) { - gemm_q_k.reload_q(); - } - } // Outer loop over the sequence length. 
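The loop above processes one K/V column block for successive query row blocks, merging each row's partial result with what earlier column blocks left in `o_tmp`. A minimal scalar C++ sketch of that per-row merge, with illustrative names only — the kernel works on tiled register fragments and folds the previous running sum into `p_prev_scale_o` because `o_tmp` is stored already normalized:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar model of one query row accumulated over K/V column blocks.
// max_so_far / sum_so_far / out play the roles of p_max, p_sum and the O tile.
struct RowState {
    double max_so_far = -INFINITY;   // running row max
    double sum_so_far = 0.0;         // running softmax denominator
    std::vector<double> out;         // running unnormalized output row (size = head_dim)
};

// scores[j] = (q . k_j) * softmax_scale for this column block; v_block[j] is the j-th V row.
void accumulate_block(RowState &s, const std::vector<double> &scores,
                      const std::vector<std::vector<double>> &v_block) {
    double block_max = s.max_so_far;
    for (double x : scores) block_max = std::max(block_max, x);
    // Rescale everything accumulated so far to the new running max.
    const double correction = std::exp(s.max_so_far - block_max);
    s.sum_so_far *= correction;
    for (double &o : s.out) o *= correction;
    // Add this block's softmax numerators and their weighted V rows.
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const double p = std::exp(scores[j] - block_max);
        s.sum_so_far += p;
        for (std::size_t d = 0; d < s.out.size(); ++d) s.out[d] += p * v_block[j][d];
    }
    s.max_so_far = block_max;
}
// After the last block the kernel writes out / sum_so_far (times rp_dropout when dropout
// is enabled) and stores the log-sum-exp max_so_far + log(sum_so_far) for the backward pass.
```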
-} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void device_1xN_loop(const Params ¶ms) { - - // The block index for the batch. - const int bidb = blockIdx.x; - // The block index for the head. - const int bidh = blockIdx.y; - // The block index. - const int bidx = gridDim.x * bidh + bidb; - // The thread index. - const int tidx = threadIdx.x; - - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern. - auto seeds = at::cuda::philox::unpack(params.philox_args); - if (bidx == 0 && tidx == 0) { - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); - } - Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32); - constexpr int M = Kernel_traits::Cta_tile_p::M; - const int STEPS = (params.seqlen_q + M - 1) / M; - - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0); - } else { - const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0); - for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - fmha::device_1xN_(params, bidb, bidh, STEPS, ph, loop_step_idx); - } - fmha::device_1xN_(params, bidb, bidh, STEPS, ph, max_loop_steps - 1); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha - diff --git a/csrc/flash_attn/src/fmha_fwd_hdim128.cu b/csrc/flash_attn/src/fmha_fwd_hdim128.cu deleted file mode 100644 index 66532e651..000000000 --- a/csrc/flash_attn/src/fmha_fwd_hdim128.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "fmha_fwd_launch_template.h" - -void run_fmha_fwd_hdim128(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ([&] { - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fwd_loop(launch_params); - })); -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_fwd_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_hdim32.cu deleted file mode 100644 index f569ca5f6..000000000 --- a/csrc/flash_attn/src/fmha_fwd_hdim32.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. 
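These per-head-dimension forward launchers (hdim32/64/128) are selected by head size; a hypothetical host-side dispatcher over them might look like the sketch below. The `Launch_params<FMHA_fprop_params>` parameter type and the `<= 32` / `<= 64` thresholds are assumptions for illustration — the template argument does not appear in the excerpt above.

```cpp
#include "fmha.h"  // assumed to declare Launch_params, FMHA_fprop_params and the launchers used below

// Hypothetical dispatcher (illustration only): route to the per-head-dim launcher,
// so that each .cu file instantiates just one family of kernel traits.
inline void run_fmha_fwd_dispatch(Launch_params<FMHA_fprop_params> &launch_params) {
    const int d = launch_params.params.d;     // head dimension
    if (d <= 32) {
        run_fmha_fwd_hdim32(launch_params);
    } else if (d <= 64) {
        run_fmha_fwd_hdim64(launch_params);
    } else {
        run_fmha_fwd_hdim128(launch_params);  // head dims up to 128
    }
}
```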
- -#include "fmha_fwd_launch_template.h" - -void run_fmha_fwd_hdim32(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ([&] { - if (launch_params.params.seqlen_k == 128) { - using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fwd_loop(launch_params); - } else if (launch_params.params.seqlen_k >= 256) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fwd_loop(launch_params); - } - })); -} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_fwd_hdim64.cu b/csrc/flash_attn/src/fmha_fwd_hdim64.cu deleted file mode 100644 index 134efa63b..000000000 --- a/csrc/flash_attn/src/fmha_fwd_hdim64.cu +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "fmha_fwd_launch_template.h" - -void run_fmha_fwd_hdim64(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ([&] { - if (launch_params.params.seqlen_k == 128) { - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fwd_loop(launch_params); - } else if (launch_params.params.seqlen_k >= 256) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fwd_loop(launch_params); - } - })); -} diff --git a/csrc/flash_attn/src/fmha_fwd_launch_template.h b/csrc/flash_attn/src/fmha_fwd_launch_template.h deleted file mode 100644 index ec1d3df0a..000000000 --- a/csrc/flash_attn/src/fmha_fwd_launch_template.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2022, Tri Dao. - -#pragma once - -#include - -#include -#include - -#include "static_switch.h" -#include "fmha.h" -#include "fmha_fprop_kernel_1xN.h" - -// Find the number of splits that maximizes the occupancy. For example, if we have -// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is -// better than having 3 splits (efficiency = 0.67). However, we also don't want too many -// splits as that would incur more HBM reads/writes. -// So we find the best efficiency, then find the smallest number of splits that gets 95% -// of the best efficiency. -// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error. 
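To make the 48-jobs-on-108-SMs example above concrete, here is a small host-only sketch that restates the heuristic implemented just below and prints the chosen split count; the demo function name and the `max_splits` value of 8 are illustrative.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone restatement of the occupancy heuristic described above.
int num_splits_heuristic_fwd_demo(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
        float eff = n_waves / std::ceil(n_waves);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    // Smallest split count within 95% of the best efficiency.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] > 0.95f * max_efficiency) { return num_splits; }
    }
    return 1;
}

int main() {
    // 2 splits: 96 CTAs on 108 SMs -> ~0.89 of a wave -> efficiency ~0.89
    // 3 splits: 144 CTAs -> ~1.33 waves, rounded up to 2 -> efficiency ~0.67
    std::printf("chosen num_splits = %d\n", num_splits_heuristic_fwd_demo(48, 108, 1, 8));  // prints 2
    return 0;
}
```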
-inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) { - float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm); - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if (eff > max_efficiency) { max_efficiency = eff; } - efficiency.push_back(eff); - } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (efficiency[num_splits - 1] > 0.95 * max_efficiency) { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; - } - } - return 1; -} - -template -__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) { - fmha::device_1xN_loop(params); -} - -template -void run_fmha_fwd_loop(Launch_params &launch_params) { - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; - - constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; - // Don't need smem_size_softmax_lse if we're not looping - const int smem_size = fmha::get_dynamic_smem_size() - + (loop_steps > 1 ? smem_size_softmax_lse : 0); - - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fwd_loop_kernel - : &fmha_fwd_loop_kernel) - : (launch_params.return_softmax - ? &fmha_fwd_loop_kernel - : &fmha_fwd_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // Automatically set num_splits to maximize occupancy - if (launch_params.params.num_splits <= 0) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size); - auto dprops = at::cuda::getCurrentDeviceProperties(); - // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount); - constexpr int M = Kernel_traits::Cta_tile_p::M; - launch_params.params.num_splits = num_splits_heuristic_fwd( - launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount, - ctas_per_sm, - /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M)) - ); - } - // printf("smem_size = %d\n", smem_size); - dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - })); -} diff --git a/csrc/flash_attn/src/fmha_kernel.h b/csrc/flash_attn/src/fmha_kernel.h deleted file mode 100644 index 62879769a..000000000 --- a/csrc/flash_attn/src/fmha_kernel.h +++ /dev/null @@ -1,78 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockInfoPadded { - - template - __device__ BlockInfoPadded(const Params ¶ms, - const int bidb, - const int bidh, - const int tidx) - : bidb(bidb), bidh(bidh), h(params.h) { - - // The block index. - sum_s_k = params.cu_seqlens_k[bidb]; - actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k; - sum_s_q = params.cu_seqlens_q[bidb]; - actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q; - - tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; - } - - __device__ bool stop_early(const int start_col = 0) const { - return actual_seqlen_k <= start_col; - } - - int actual_seqlen_q; - int actual_seqlen_k; - int sum_s_q; - int sum_s_k; - int bidh; - int bidb; - int tidx_global; - int h; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/csrc/flash_attn/src/fmha_utils.h b/csrc/flash_attn/src/fmha_utils.h deleted file mode 100644 index 865ddc0b7..000000000 --- a/csrc/flash_attn/src/fmha_utils.h +++ /dev/null @@ -1,100 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define FMHA_CHECK_CUDA( call ) \ - do { \ - cudaError_t status_ = call; \ - if( status_ != cudaSuccess ) { \ - fprintf( stderr, \ - "CUDA error (%s:%d): %s\n", \ - __FILE__, \ - __LINE__, \ - cudaGetErrorString( status_ ) ); \ - exit( 1 ); \ - } \ - } while( 0 ) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) { - if( dtype == DATA_TYPE_FP16 ) { - half x = __float2half_rn( norm ); - uint16_t h = reinterpret_cast( x ); - ushort2 h2 = { h, h }; - alpha = reinterpret_cast( h2 ); - } else if( dtype == DATA_TYPE_BF16 ) { - __nv_bfloat16 x = __float2bfloat16( norm ); - uint16_t h = reinterpret_cast( x ); - ushort2 h2 = { h, h }; - alpha = reinterpret_cast( h2 ); - } else if( dtype == DATA_TYPE_FP32 ) { - alpha = reinterpret_cast( norm ); - } else if( dtype == DATA_TYPE_INT32 ) { - int32_t inorm = static_cast( norm ); - alpha = reinterpret_cast( inorm ); - } else { - assert( false ); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) { - switch( dtype ) { - case DATA_TYPE_FP32: - return n * 4; - case DATA_TYPE_FP16: - return n * 2; - case DATA_TYPE_BF16: - return n * 2; - case DATA_TYPE_INT32: - return n * 4; - case DATA_TYPE_INT8: - return n; - default: - assert( false ); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h new file mode 100644 index 000000000..3468e4bff --- /dev/null +++ b/csrc/flash_attn/src/kernel_traits.h @@ -0,0 +1,366 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 
128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); + static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + + kSmemdSSize + kSmemPSize; + + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/kernel_traits_sm90.h b/csrc/flash_attn/src/kernel_traits_sm90.h new file mode 100644 index 000000000..e07f38390 --- /dev/null +++ b/csrc/flash_attn/src/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. 
+ static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/philox.cuh b/csrc/flash_attn/src/philox.cuh index a1e4c641d..6ce1440f2 100644 --- a/csrc/flash_attn/src/philox.cuh +++ b/csrc/flash_attn/src/philox.cuh @@ -1,8 +1,55 @@ -// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh -// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h #pragma once // Philox CUDA. +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +inline __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash + namespace { class Philox { @@ -10,7 +57,10 @@ public: __device__ inline Philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) - : key(reinterpret_cast(seed)) { + : STATE(0) + , seed_(seed) + , offset_(offset) + , key(reinterpret_cast(seed)) { //key.x = (unsigned int)seed; //key.y = (unsigned int)(seed >> 32); //counter = make_uint4(0, 0, 0, 0); @@ -19,6 +69,7 @@ public: //STATE = 0; //incr_n(offset / 4); + // key = reinterpret_cast(seed); ull2 * tmp = reinterpret_cast(&counter); tmp->x = offset / 4; tmp->y = subsequence; @@ -26,72 +77,64 @@ public: // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); // } } - __device__ inline uint4 
operator()() { - uint4 counter_ = counter; - uint2 key_ = key; - // 7-round philox - #pragma unroll - for (int i = 0; i < 6; i++) { - counter_ = single_round(counter_, key_); - key_.x += (kPhilox10A); - key_.y += (kPhilox10B); - } - uint4 output = single_round(counter_, key_); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); - // } - incr(); - return output; - } - - __device__ inline uint4 operator()(const unsigned long long subsequence) { - uint4 counter_ = counter; - ull2 * tmp = reinterpret_cast(&counter_); - tmp->y = subsequence; - // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w); - // } - uint2 key_ = key; - // 7-round philox - #pragma unroll - for (int i = 0; i < 6; i++) { - counter_ = single_round(counter_, key_); - key_.x += (kPhilox10A); - key_.y += (kPhilox10B); - } - uint4 output = single_round(counter_, key_); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); - // } - return output; + // // if (STATE == 0) { + // uint4 counter_ = counter; + // uint2 key_ = key; + // // 7-round philox + // #pragma unroll + // for (int i = 0; i < 6; i++) { + // counter_ = flash::philox_single_round(counter_, key_); + // key_.x += (kPhilox10A); + // key_.y += (kPhilox10B); + // } + // // output = philox_single_round(counter_, key_); + // uint4 output = flash::philox_single_round(counter_, key_); + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // // } + // incr(); + // // } + // // return a float4 directly + // // unsigned long ret; + // // switch(STATE) { + // // case 0: ret = output.x; break; + // // case 1: ret = output.y; break; + // // case 2: ret = output.z; break; + // // case 3: ret = output.w; break; + // //} + // // STATE = (STATE + 1) % 4; + // return output; + return flash::philox(seed_, offset_, offset_); } private: + unsigned long long offset_, seed_; struct ull2 { uint64_t x; uint64_t y; }; uint4 counter; + // uint4 output; const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } - // __device__ inline void incr_n(unsigned long long n) { - // unsigned int nlo = (unsigned int)(n); - // unsigned int nhi = (unsigned int)(n >> 32); - // counter.x += nlo; - // if (counter.x < nlo) - // nhi++; - // counter.y += nhi; - // if (nhi <= counter.y) - // return; - // if (++counter.z) - // return; - // ++counter.w; - // } - - __device__ uint4 incr(uint4 ctr) { + __device__ uint4 incr128 (uint4 ctr) + { uint4 res; asm ("add.cc.u32 %0, %4, %8;\n\t" "addc.cc.u32 %1, %5, %9;\n\t" @@ -107,51 +150,16 @@ private: // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 
0)) { // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); // } - counter = incr(counter); + counter = incr128(counter); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); // } } - // __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, - // unsigned int *result_high) { - // *result_high = __umulhi(a, b); - // return a * b; - // } - - __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { - uint2 *res; - unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); - return *res; - } - - __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) { - //unsigned int hi0; - //unsigned int hi1; - //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); - //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); - //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; - uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; - return ret; - } - static const unsigned long kPhilox10A = 0x9E3779B9; static const unsigned long kPhilox10B = 0xBB67AE85; - static const unsigned long kPhiloxSA = 0xD2511F53; - static const unsigned long kPhiloxSB = 0xCD9E8D57; + // static const unsigned long kPhiloxSA = 0xD2511F53; + // static const unsigned long kPhiloxSB = 0xCD9E8D57; }; -// Inverse of 2^32. -constexpr float M_RAN_INVM32 = 2.3283064e-10f; -__device__ __inline__ float4 uniform4(const uint4 x) { - return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, - x.w * M_RAN_INVM32); -} - } // namespace diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h new file mode 100644 index 000000000..3e9a7b459 --- /dev/null +++ b/csrc/flash_attn/src/softmax.h @@ -0,0 +1,272 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include +#include + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
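For reference, a minimal NumPy sketch (illustrative only, not part of this diff) of the exp2 rescaling that `scale_apply_exp2` implements, including the guard for rows that are entirely -inf after masking. The kernel's function name is reused for clarity; here `scale` plays the role of `softmax_scale * log2(e)`.

```python
# Editor's sketch of the base-2 softmax rescaling used in scale_apply_exp2:
# instead of exp(x - max) it computes 2**(x*scale - max*scale), and a fully
# masked row (max == -inf) is guarded so it produces zeros rather than NaNs.
import numpy as np

def scale_apply_exp2(scores: np.ndarray, scale: float) -> np.ndarray:
    """scores: (rows, cols) attention logits; returns unnormalized exponentials."""
    row_max = scores.max(axis=-1, keepdims=True)
    # If a whole row was masked to -inf, (-inf) - (-inf) would give NaN; use 0 instead,
    # so exp2(-inf - 0) = 0 and the row contributes nothing to the softmax sum.
    max_scaled = np.where(np.isneginf(row_max), 0.0, row_max * scale)
    return np.exp2(scores * scale - max_scaled)

logits = np.array([[1.0, 2.0, 3.0], [-np.inf, -np.inf, -np.inf]])
print(scale_apply_exp2(logits, np.log2(np.e)))  # second row is all zeros, not NaN
```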
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, + const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, + const uint32_t warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; + const uint32_t row_idx_offset = row_idx_offset_; + const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const uint32_t row_idx = row_idx_base + i * 8; + const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const uint32_t col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, + unsigned long long seed, unsigned long long offset, + 
uint32_t block_row_start, uint32_t block_col_start, + uint32_t block_row_stride) { + // tensor has shape (8, MMA_M, MMA_N / 2) + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
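For reference, a small Python sketch (illustrative only, not part of this diff) of the packed-threshold comparison described in the comment above: the 8-bit keep threshold is duplicated into both 16-bit halves of a 32-bit word, each random byte is widened to 16 bits, and a per-half less-or-equal test yields a 0xFFFF / 0x0000 mask that is ANDed with the packed fp16 values. The function name `dropout_mask_pair` is an assumption for illustration.

```python
# Editor's sketch of the mask construction that set.le.u32.f16x2 performs per 16-bit half.
def dropout_mask_pair(rnd_lo: int, rnd_hi: int, p_dropout_uint8: int) -> int:
    """rnd_lo/rnd_hi: random bytes (0..255); returns a 32-bit AND-mask for two fp16 lanes."""
    threshold2 = (p_dropout_uint8 << 16) | p_dropout_uint8   # threshold duplicated into both halves
    rnd2 = (rnd_hi << 16) | rnd_lo                           # two random bytes widened to 16 bits each
    mask = 0
    for shift in (0, 16):                                    # emulate the per-half comparison
        if ((rnd2 >> shift) & 0xFFFF) <= ((threshold2 >> shift) & 0xFFFF):
            mask |= 0xFFFF << shift                          # keep this lane
    return mask

# Keep if random byte <= threshold: byte 10 is kept, byte 200 is dropped (threshold 128).
print(hex(dropout_mask_pair(10, 200, 128)))  # 0xffff -> low half kept, high half zeroed
```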
+ if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } +} + +} // namespace flash diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 53bcf35d6..b4a4b488d 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -1,6 +1,5 @@ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h -// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8 #pragma once @@ -10,31 +9,57 @@ /// /// Usage: /// ``` -/// BOOL_SWITCH(flag, BoolConst, ([&] { +/// BOOL_SWITCH(flag, BoolConst, [&] { /// some_function(...); -/// })); +/// }); /// ``` -/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro. -#define BOOL_SWITCH(COND, CONST_NAME, F) \ - { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - F(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - F(); \ - } \ - } +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() -// modified from BOOL_SWITCH -// because MSVC cannot handle std::conditional with constexpr variable -#define FP16_SWITCH(COND, F) \ - { \ - if (COND) { \ - using elem_type = __nv_bfloat16; \ - F(); \ - } else { \ - using elem_type = __half; \ - F(); \ - } \ - } +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) 
\ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h new file mode 100644 index 000000000..2221a2faf --- /dev/null +++ b/csrc/flash_attn/src/utils.h @@ -0,0 +1,388 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 
half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = flash::half2_unpack(a); + float2 bf = flash::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = flash::hfma2_to_float(a.x, b.x); + sum += flash::hfma2_to_float(a.y, b.y); + sum += flash::hfma2_to_float(a.z, b.z); + sum += flash::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + 
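For reference, a minimal NumPy sketch (illustrative only, not part of this diff) of the software-pipelining pattern in `flash::gemm` above: the copy of K-slice i+1 out of shared memory into registers is issued before the MMA on slice i, so the loads can overlap with the math. Sequential Python cannot actually overlap anything; the sketch only shows the ordering, and the name `pipelined_gemm` and the tile size are illustrative assumptions.

```python
# Editor's sketch of the prefetch-then-compute loop structure used in flash::gemm.
import numpy as np

def pipelined_gemm(A: np.ndarray, B: np.ndarray, k_tile: int) -> np.ndarray:
    """A: (M, K), B: (K, N); accumulate over K in tiles, prefetching the next tile."""
    M, K = A.shape
    _, N = B.shape
    assert K % k_tile == 0
    n_tiles = K // k_tile
    acc = np.zeros((M, N), dtype=np.float32)
    # "Registers" holding the current K-slice (the copy issued before the loop).
    a_reg = A[:, 0:k_tile].copy()
    b_reg = B[0:k_tile, :].copy()
    for i in range(n_tiles):
        if i < n_tiles - 1:  # prefetch slice i+1 before computing on slice i
            a_next = A[:, (i + 1) * k_tile:(i + 2) * k_tile].copy()
            b_next = B[(i + 1) * k_tile:(i + 2) * k_tile, :].copy()
        acc += a_reg @ b_reg  # the cute::gemm call on slice i
        if i < n_tiles - 1:
            a_reg, b_reg = a_next, b_next
    return acc

A = np.random.randn(16, 64).astype(np.float32)
B = np.random.randn(64, 8).astype(np.float32)
assert np.allclose(pipelined_gemm(A, B, k_tile=16), A @ B, atol=1e-4)
```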
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 
1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. 
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + copy(thr_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 39e0411d5..7498a2d80 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1 +1,8 @@ -__version__ = "1.0.9" +__version__ = "2.0.0.post1" + +from flash_attn.flash_attn_interface import flash_attn_func +from flash_attn.flash_attn_interface import flash_attn_kvpacked_func +from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_func diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py deleted file mode 100644 index 0719d0a0d..000000000 --- a/flash_attn/flash_attention.py +++ /dev/null @@ -1,101 +0,0 @@ -import math -import torch -import torch.nn as nn - -from einops import rearrange - -from flash_attn.flash_attn_interface import 
flash_attn_unpadded_qkvpacked_func -from flash_attn.bert_padding import unpad_input, pad_input - - -class FlashAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - def __init__(self, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, - max_s=None, need_weights=False): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None - if unpadded: (nnz, 3, h, d) - key_padding_mask: a bool tensor of shape (B, S) - """ - assert not need_weights - assert qkv.dtype in [torch.float16, torch.bfloat16] - assert qkv.is_cuda - - if cu_seqlens is None: - batch_size = qkv.shape[0] - seqlen = qkv.shape[1] - if key_padding_mask is None: - qkv = rearrange(qkv, 'b s ... -> (b s) ...') - max_s = seqlen - cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=qkv.device) - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=causal - ) - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) - else: - nheads = qkv.shape[-2] - x = rearrange(qkv, 'b s three h d -> b s (three h d)') - x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) - x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) - output_unpad = flash_attn_unpadded_qkvpacked_func( - x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=causal - ) - output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), - indices, batch_size, seqlen), - 'b s (h d) -> b s h d', h=nheads) - else: - assert max_s is not None - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=causal - ) - - return output, None - - -class FlashMHA(nn.Module): - - def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0, - causal=False, device=None, dtype=None) -> None: - assert batch_first - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" - - self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) - self.inner_attn = FlashAttention(attention_dropout=attention_dropout) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - - def forward(self, x, key_padding_mask=None, need_weights=False): - """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) - key_padding_mask: bool tensor of shape (batch, seqlen) - """ - qkv = self.Wqkv(x) - qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) - 
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask, - need_weights=need_weights, causal=self.causal) - return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 07e97ce96..c6e43408e 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1,48 +1,86 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import flash_attn_cuda - - -def _get_block_size(device, head_dim, is_dropout): - assert head_dim % 8 == 0 and head_dim <= 128 - return 256 if head_dim <= 64 else 128 - - -def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_softmax, num_splits=0, - generator=None): - """ - num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means - it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. - Don't change it unless you know what you're doing. - """ - softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( - q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, False, causal, return_softmax, num_splits, generator +import flash_attn_2_cuda as flash_attn_cuda +from einops import rearrange + + +def _get_block_size(device, head_dim, is_dropout, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128, 128 + if head_dim <= 64: + return (128, 128) if not is_dropout else (128, 64) + elif head_dim <= 96: + return (64, 64) if (is_sm8x and is_causal) else (128, 64) + elif head_dim <= 128: + if is_sm8x: + return (64, 64) if (not is_dropout and is_causal) else (128, 32) + else: + return 128, (64 if not is_dropout else 32) + elif head_dim <= 160: + if is_sm8x: + return (128, 64) if not is_causal else (64, 64) + else: + return 128, 32 + elif head_dim <= 192: + return (128, 64) if not is_dropout else (64, 64) + elif head_dim <= 224: + return (128, 64) if (is_sm80 or is_sm90) else (64, 64) + elif head_dim <= 256: + return (128, 64) if is_sm80 else (64, 64) + + +def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): + if q.stride(-1) != 1: + q = q.contiguous() + if k.stride(-1) != 1: + k = k.contiguous() + if v.stride(-1) != 1: + v = v.contiguous() + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd( + q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask + + +def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax): + if q.stride(-1) != 1: + q = q.contiguous() + if k.stride(-1) != 1: + k = k.contiguous() + if v.stride(-1) != 1: + v = v.contiguous() + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, None ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() - S_dmask = rest[0] if return_softmax else None - return out, softmax_lse, rng_state, S_dmask + return out, q, k, 
v, out_padded, softmax_lse, S_dmask -def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, - rng_state=None, num_splits=0, generator=None): - """ - num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or - not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic. - Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel - as num_splits=3), so effectively the choices are 0, 1, and 2. - This hyperparameter can be tuned for performance, but default value (heuristic) should work fine. - """ - dout = dout.contiguous() # CUDA code assumes that dout is contiguous - _, _, _, softmax_d = flash_attn_cuda.bwd( +def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, + dropout_p, softmax_scale, causal): + dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( + dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None + ) + return dq, dk, dv, softmax_d + + +def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal): + dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, - num_splits, generator, rng_state) + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None + ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() return dq, dk, dv, softmax_d @@ -51,191 +89,249 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, - return_softmax, deterministic): + def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( - qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens, - max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal, - return_softmax=return_softmax + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, + causal=causal, return_softmax=return_softmax and dropout_p > 0 ) - ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale ctx.causal = causal - ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod def backward(ctx, dout, *args): - qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - dqkv = torch.empty_like(qkv) + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) _flash_attn_backward( - dout, qkv[:, 0], 
qkv[:, 1], qkv[:, 2], out, softmax_lse, - dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens, - ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal, - rng_state=rng_state, num_splits=1 if ctx.deterministic else 0, + dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + ctx.dropout_p, ctx.softmax_scale, ctx.causal + ) + dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 ) - return dqkv, None, None, None, None, None, None, None + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], + cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, + ctx.dropout_p, ctx.softmax_scale, ctx.causal + ) + dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax, deterministic): + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( - q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0 ) - ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal - 
ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod def backward(ctx, dout, *args): - q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) dq = torch.empty_like(q) - dkv = torch.empty_like(kv) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) _flash_attn_backward( - dout, q, kv[:, 0], kv[:, 1], out, softmax_lse, - dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k, - ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, - rng_state=rng_state, num_splits=1 if ctx.deterministic else 0, + dout, q, k, v, out, softmax_lse, + dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal ) - return dq, dkv, None, None, None, None, None, None, None, None, None + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dkv, None, None, None, None -class FlashAttnFunc(torch.autograd.Function): +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax, deterministic): + def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( - q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 ) - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal - ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1], + cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, + ctx.dropout_p, ctx.softmax_scale, ctx.causal + ) + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, 
dkv, None, None, None, None, None, None, None, None + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + q, k, v, dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( - dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, - rng_state=rng_state, num_splits=1 if ctx.deterministic else 0, + dout, q, k, v, out, softmax_lse, + dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., :dout.shape[-1]] + dv = dv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dk, dv, None, None, None, None, None, None, None, None -class FlashAttnQKVPackedSplitFunc(torch.autograd.Function): +class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, - softmax_scale, causal, return_softmax, deterministic): + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax): # Save rng_state because the backward pass will regenerate the dropout mask - if dropout_p > 0: - rng_state0 = torch.cuda.get_rng_state() - generator1 = torch.Generator(device='cuda') - rng_state1 = generator1.get_state() - else: - rng_state0, generator1, rng_state1 = None, None, None + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out = torch.empty_like(qkv[:, 0]) - _, softmax_lse0, S_dmask0 = _flash_attn_forward( - qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1], - cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale, - causal=causal, return_softmax=return_softmax + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 ) - s = torch.cuda.Stream() - with torch.cuda.stream(s): - _, softmax_lse1, S_dmask1 = _flash_attn_forward( - qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:], - cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale, - causal=causal, return_softmax=return_softmax, generator=generator1 - ) - 
torch.cuda.current_stream().wait_stream(s) - ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, - rng_state0, rng_state1) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) ctx.dropout_p = dropout_p - ctx.max_seqlen0 = max_seqlen0 - ctx.max_seqlen1 = max_seqlen1 - ctx.batch_size0 = batch_size0 + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal - ctx.deterministic = deterministic - if not return_softmax: - return out - else: - max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2]) - max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3]) - softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])), - F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))], - dim=0) - return out, softmax_lse, S_dmask0, S_dmask1 + return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod def backward(ctx, dout, *args): - qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors - batch_size0 = ctx.batch_size0 - if rng_state0 is not None: + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + if rng_state is not None: cur_rng_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(rng_state0) - if rng_state1 is not None: - generator1 = torch.Generator(device='cuda') - generator1.set_state(rng_state1) - else: - generator1 = None - dqkv = torch.empty_like(qkv) - _flash_attn_backward( - dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0, - dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1], - cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p, - ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0, + torch.cuda.set_rng_state(rng_state) + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal ) - s = torch.cuda.Stream() - with torch.cuda.stream(s): - _flash_attn_backward( - dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1, - dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:], - cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p, - ctx.softmax_scale, ctx.causal, generator=generator1, - num_splits=1 if ctx.deterministic else 0, - ) - torch.cuda.current_stream().wait_stream(s) - if rng_state0 is not None: + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., :dout.shape[-1]] + dv = dv[..., :dout.shape[-1]] + if rng_state is not None: torch.cuda.set_rng_state(cur_rng_state) - return dqkv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None -def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, - causal=False, return_attn_probs=False, deterministic=False): +def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. 
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. + qkv: (batch_size, seqlen, 3, nheads, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -243,9 +339,8 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). - deterministic: bool. Whether or not to ensure deterministic execution. Return: - out: (total, nheads, headdim). + out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). @@ -253,23 +348,87 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, - causal, return_attn_probs, deterministic) + return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) -def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale=None, causal=False, - return_attn_probs=False, deterministic=False): +def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch.
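To make the MQA/GQA head-grouping rule described in the docstrings above concrete, a small pure-PyTorch reference sketch follows; it mirrors the repeat() pattern used in the updated tests later in this diff, and it omits masking and dropout for clarity (it is not the fused kernel):

```python
# Reference sketch of GQA head grouping (illustrative only).
# q: (batch, seqlen_q, nheads, d); k, v: (batch, seqlen_k, nheads_k, d).
import math
import torch
from einops import repeat

def gqa_attention_ref(q, k, v):
    nheads, nheads_k = q.shape[2], k.shape[2]
    assert nheads % nheads_k == 0, "number of Q heads must be divisible by number of K/V heads"
    g = nheads // nheads_k
    # Expand K/V so that K/V head i serves query heads i*g ... i*g + g - 1; e.g. with
    # 6 Q heads and 2 KV heads, Q heads 0-2 attend to KV head 0 and Q heads 3-5 to KV head 1.
    k = repeat(k, "b s h d -> b s (h g) d", g=g)
    v = repeat(v, "b s h d -> b s (h g) d", g=g)
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(q.shape[-1]), k)
    return torch.einsum("bhts,bshd->bthd", scores.softmax(dim=-1), v)
```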
+ q: (batch_size, seqlen, nheads, headdim) + kv: (batch_size, seqlen, 2, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) + + +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) + + +def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, + causal=False, return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -277,9 +436,8 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). - deterministic: bool. Whether or not to ensure deterministic execution. Return: - out: (total_q, nheads, headdim). + out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). @@ -287,19 +445,26 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, - return_attn_probs, deterministic) + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + ) -def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, - deterministic=False): +def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv. @@ -313,9 +478,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). - deterministic: bool. Whether or not to ensure deterministic execution. Return: - out: (total_q, nheads, headdim).
+ out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). @@ -323,27 +487,31 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_attn_probs, deterministic) - + return FlashAttnVarlenKVPackedFunc.apply( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs + ) -def flash_attn_unpadded_qkvpacked_split_func( - qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None, - causal=False, return_attn_probs=False, deterministic=False): - """ - Split attention into 2 kernels running on 2 separate streams for performance reason: - e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to - have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128. - dropout_p should be set to 0.0 during evaluation. +def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen0: int. Maximum sequence length in 1st part of the batch. - max_seqlen1: int. Maximum sequence length in 2nd part of the batch. - batch_size0: int. Number of sequences in the 1st part of the batch. + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -351,7 +519,6 @@ def flash_attn_unpadded_qkvpacked_split_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). - deterministic: bool. Whether or not to ensure deterministic execution.
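A small sketch of how the cu_seqlens convention used by the varlen interfaces above is built from per-sequence lengths, assuming flash-attn 2.x on a CUDA GPU; the two sequence lengths (3 and 5) and head sizes are illustrative:

```python
# Illustrative sketch: building cu_seqlens for flash_attn_varlen_func.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

lengths = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 3, 8], int32
total, nheads, d = int(lengths.sum()), 8, 64
q = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, d, device="cuda", dtype=torch.float16)
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                             int(lengths.max()), int(lengths.max()), causal=True)
# out: (total, nheads, d); tokens of the two sequences never attend across the boundary at index 3
```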
Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -361,15 +528,7 @@ def flash_attn_unpadded_qkvpacked_split_func( The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, - dropout_p, softmax_scale, causal, return_attn_probs, - deterministic) - - -def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False, - return_attn_probs=False): - """For backward-compatibility only, will remove soon. - dropout_p should be set to 0.0 during evaluation - """ - return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale, - causal, return_attn_probs) + return FlashAttnVarlenFunc.apply( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs + ) diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index a4ff5a260..5a25d79a9 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from torchvision.ops import StochasticDepth +# from torchvision.ops import StochasticDepth from flash_attn.modules.mha import MHA from flash_attn.modules.mlp import Mlp @@ -70,12 +70,12 @@ def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, mlp_cls = partial(Mlp, hidden_features=4 * dim) self.mixer = mixer_cls(dim) self.dropout1 = dropout_cls(resid_dropout1) - self.drop_path1 = StochasticDepth(drop_path1, mode='row') + # self.drop_path1 = StochasticDepth(drop_path1, mode='row') self.norm1 = norm_cls(dim) self.mlp = mlp_cls(dim) if not isinstance(self.mlp, nn.Identity): self.dropout2 = dropout_cls(resid_dropout2) - self.drop_path2 = StochasticDepth(drop_path2, mode='row') + # self.drop_path2 = StochasticDepth(drop_path2, mode='row') self.norm2 = norm_cls(dim) if self.fused_dropout_add_ln: @@ -129,13 +129,14 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, if self.residual_in_fp32: residual = residual.to(torch.float32) else: - if self.drop_path1.p == 0 or not self.training: - rowscale1 = None - else: - rowscale1 = self.drop_path1(torch.ones( - hidden_states.shape[:-1], device=hidden_states.device, - dtype=hidden_states.dtype) - ) + rowscale1 = None + # if self.drop_path1.p == 0 or not self.training: + # rowscale1 = None + # else: + # rowscale1 = self.drop_path1(torch.ones( + # hidden_states.shape[:-1], device=hidden_states.device, + # dtype=hidden_states.dtype) + # ) hidden_states, residual = fused_add_norm_fn( hidden_states, residual, self.norm1.weight, self.norm1.bias, self.dropout1.p if self.training else 0.0, self.norm1.eps, @@ -156,13 +157,14 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, if self.residual_in_fp32: residual = residual.to(torch.float32) else: - if self.drop_path2.p == 0 or not self.training: - rowscale2 = None - else: - rowscale2 = self.drop_path2(torch.ones( - hidden_states.shape[:-1], device=hidden_states.device, - dtype=hidden_states.dtype) - ) + # if self.drop_path2.p == 0 or not self.training: + # rowscale2 = None + # else: + # rowscale2 = self.drop_path2(torch.ones( + # hidden_states.shape[:-1], device=hidden_states.device, + # dtype=hidden_states.dtype) + # ) + rowscale2 = None hidden_states, residual = 
fused_add_norm_fn( hidden_states, residual, self.norm2.weight, self.norm2.bias, self.dropout2.p if self.training else 0.0, self.norm2.eps, diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 684935dac..b70a58928 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -10,14 +10,10 @@ from einops import rearrange try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func -except ImportError: - flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None - -try: - from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func + from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func + from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func except ImportError: + flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None try: @@ -46,17 +42,13 @@ class FlashSelfAttention(nn.Module): attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, - triton=False): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() - if attention_dropout != 0.0 or not triton: - assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed' - if attention_dropout == 0.0 and triton: - assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed' + assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed' + assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed' self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) - self.triton = triton def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): """Implements the multihead softmax attention. @@ -83,26 +75,13 @@ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None assert isinstance(max_seqlen, int) - return flash_attn_unpadded_qkvpacked_func( + return flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal ) else: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - # Triton version doesn't support dropout - if self.triton and (self.drop.p == 0 or not self.training): - output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale) - else: - qkv = rearrange(qkv, 'b s ... -> (b s) ...') - max_seqlen = seqlen - cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=qkv.device) - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=causal - ) - output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) - return output + return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal) class FlashCrossAttention(nn.Module): @@ -115,17 +94,13 @@ class FlashCrossAttention(nn.Module): attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, - triton=False): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() - if attention_dropout != 0.0 or not triton: - assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed' - if attention_dropout == 0.0 and triton: - assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed' + assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed' + assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed' self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) - self.triton = triton def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None): @@ -133,7 +108,7 @@ def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, Arguments --------- q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) causal: if passed, will override self.causal cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. @@ -154,7 +129,7 @@ def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, assert cu_seqlens_k.dtype == torch.int32 assert max_seqlen_k is not None assert isinstance(max_seqlen, int) - return flash_attn_unpadded_kvpacked_func( + return flash_attn_varlen_kvpacked_func( q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k, self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal @@ -162,23 +137,9 @@ def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, else: batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = kv.shape[1] - assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] - if self.triton and (self.drop.p == 0.0 or not self.training): # Triton version doesn't support dropout - output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale) - else: - q = rearrange(q, 'b s ... -> (b s) ...') - kv = rearrange(kv, 'b s ... -> (b s) ...') - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, - dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, - dtype=torch.int32, device=kv.device) - output = flash_attn_unpadded_kvpacked_func( - q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, - self.drop.p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=causal - ) - output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) - return output + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0, + causal=causal, softmax_scale=self.softmax_scale) class SelfAttention(nn.Module): diff --git a/setup.py b/setup.py index 7597ea318..88353f8c2 100644 --- a/setup.py +++ b/setup.py @@ -111,28 +111,52 @@ def append_nvcc_threads(nvcc_extra_args): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.0"): raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_75,code=sm_75") +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_75,code=sm_75") cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") if bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") -subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"]) +subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) ext_modules.append( CUDAExtension( - name="flash_attn_cuda", + name="flash_attn_2_cuda", sources=[ - "csrc/flash_attn/fmha_api.cpp", - "csrc/flash_attn/src/fmha_fwd_hdim32.cu", - "csrc/flash_attn/src/fmha_fwd_hdim64.cu", - "csrc/flash_attn/src/fmha_fwd_hdim128.cu", - "csrc/flash_attn/src/fmha_bwd_hdim32.cu", - "csrc/flash_attn/src/fmha_bwd_hdim64.cu", - "csrc/flash_attn/src/fmha_bwd_hdim128.cu", - "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu", - "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu", + "csrc/flash_attn/flash_api.cpp", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, @@ -157,11 +181,12 @@ def 
append_nvcc_threads(nvcc_extra_args): include_dirs=[ Path(this_dir) / 'csrc' / 'flash_attn', Path(this_dir) / 'csrc' / 'flash_attn' / 'src', - Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include', + Path(this_dir) / 'csrc' / 'cutlass' / 'include', ], ) ) + def get_package_version(): with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) @@ -172,6 +197,7 @@ def get_package_version(): else: return str(public_version) + setup( name="flash_attn", version=get_package_version(), @@ -179,11 +205,9 @@ def get_package_version(): exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) ), author="Tri Dao", - author_email="trid@stanford.edu", + author_email="trid@cs.stanford.edu", description="Flash Attention: Fast and Memory-Efficient Exact Attention", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/HazyResearch/flash-attention", + url="https://github.com/Dao-AILab/flash-attention", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: BSD License", diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 3486f9b06..27223ebfb 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,5 +1,4 @@ import math -from functools import partial import torch import torch.nn.functional as F @@ -8,100 +7,87 @@ from einops import rearrange, repeat -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func +from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func +from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func +from flash_attn import flash_attn_varlen_func +from flash_attn.flash_attn_interface import _get_block_size from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis -try: - from flash_attn.flash_attn_triton import flash_attn_func -except (ImportError, AttributeError): # Older version of Triton doesn't have tl.constexpr - flash_attn_func = None + +MAX_HEADDIM_SM8x = 192 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5) +is_sm8x = torch.cuda.get_device_capability('cuda')[0] == 8 is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0) +is_sm90 = torch.cuda.get_device_capability('cuda') == (9, 0) def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): - assert mode in ['full', 'random', 'third', 'split'] + assert mode in ['full', 'random', 'third'] if mode == 'full': lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == 'random': - lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) elif mode == 'third': - lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) - elif mode == 'split': - lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1, - (batch_size // 4 * 3, 1), device=device) - lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1, - (batch_size - batch_size // 4 * 3, 1), device=device) - lengths = torch.cat([lengths0, lengths1], dim=0) + lengths = 
torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths return padding_mask -def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None, +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - x: (batch_size, seqlen, nheads * d) - Wqkv: nn.Linear(nheads * d, 3 * nheads * d) + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen, dim = x.shape - q, k, v = Wqkv(x).chunk(3, dim=-1) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - q_unpad = rearrange(q_unpad, 'nnz (h d) -> nnz h d', h=nheads) - output_pad_fn = lambda output_unpad: rearrange( - pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen), - 'b s (h d) -> b s h d', h=nheads - ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) else: - q_unpad = rearrange(q, 'b s (h d) -> (b s) h d', h=nheads) - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + q_unpad = rearrange(q, 'b s h d -> (b s) h d') + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) - max_seqlen_q = seqlen + max_seqlen_q = seqlen_q output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size) if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - k_unpad = rearrange(k_unpad, 'nnz (h d) -> nnz h d', h=nheads) v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - v_unpad = rearrange(v_unpad, 'nnz (h d) -> nnz h d', h=nheads) else: - k_unpad = rearrange(k, 'b s (h d) -> (b s) h d', h=nheads) - v_unpad = rearrange(v, 'b s (h d) -> (b s) h d', h=nheads) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=q_unpad.device) - max_seqlen_k = seqlen + k_unpad = rearrange(k, 'b s h d -> (b s) h d') + v_unpad = rearrange(v, 'b s h d -> (b s) h d') + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=k_unpad.device) + max_seqlen_k = seqlen_k if qkvpacked: assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = rearrange(torch.stack([q, k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads) + qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - pad_input(rearrange(dqkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_q, batch_size, seqlen), - 'b s (t h d) -> b s t h d', t=3, h=nheads - ) + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q, qkv.detach().requires_grad_(), 
output_pad_fn, dqkv_pad_fn) elif kvpacked: kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - q = rearrange(q, 'b s (h d) -> b s h d', h=nheads) - kv = rearrange(torch.stack([k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads) + kv = torch.stack([k, v], dim=2) dq_pad_fn = output_pad_fn if key_padding_mask is not None: - dkv_pad_fn = lambda dkv_unpad: rearrange( - pad_input(rearrange(dkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_k, batch_size, seqlen), - 'b s (t h d) -> b s t h d', t=2, h=nheads - ) + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), @@ -109,35 +95,30 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None q.detach().requires_grad_(), kv.detach().requires_grad_(), output_pad_fn, dq_pad_fn, dkv_pad_fn) else: - q, k, v = [rearrange(z, 'b s (h d) -> b s h d', h=nheads).detach().requires_grad_() - for z in [q, k, v]] dq_pad_fn = output_pad_fn if key_padding_mask is not None: - dk_pad_fn = lambda dk_unpad: rearrange( - pad_input(rearrange(dk_unpad, 'nnz h d -> nnz (h d)'), indices_k, batch_size, seqlen), - 'b s (h d) -> b s h d', h=nheads - ) + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) else: dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size) return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - q, k, v, + q.detach().requires_grad_(), k.detach().requires_grad_(), + v.detach().requires_grad_(), output_pad_fn, dq_pad_fn, dk_pad_fn) def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0, - dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False): + dropout_mask=None, causal=False, upcast=True, reorder_ops=False): """ Arguments: q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - bias: (batch_size, nheads, seqlen_q, seqlen_k) upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
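The generate_qkv helper above leans on flash_attn.bert_padding to move between padded (batch, seqlen, nheads, headdim) tensors and the packed varlen layout; a minimal round-trip sketch, assuming flash-attn 2.x is installed (the shapes and lengths here are illustrative):

```python
# Illustrative sketch of the unpad_input/pad_input round trip used by the test helpers;
# positions outside the mask come back as zeros after pad_input.
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, nheads, d = 2, 8, 4, 16
x = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16)
lengths = torch.tensor([5, 8], device="cuda")
mask = torch.arange(seqlen, device="cuda")[None, :] < lengths[:, None]  # (batch, seqlen) bool

x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, mask)  # x_unpad: (total_tokens, nheads, d)
x_repad = pad_input(x_unpad, indices, batch, seqlen)             # back to (batch, seqlen, nheads, d)
assert torch.equal(x_repad[mask], x[mask])
```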
@@ -151,13 +132,13 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo if upcast: q, k, v = q.float(), k.float(), v.float() seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] if not reorder_ops: scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k) else: scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d)) - if bias is not None: - scores = (scores + bias).to(dtype=scores.dtype) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf')) if causal: @@ -238,37 +219,40 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea causal=False): """FlashAttention stores the S matrix in a different way. Arguments: - S: (batch_size, nheads, seqlen_q, seqlen_k) + S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) """ - S_flat = rearrange(S, 'b h t s -> b h (t s)') seqlen_q, seqlen_k = S.shape[-2:] - block_size = _get_block_size(S.device, head_dim, is_dropout) - loop_steps = (seqlen_k + block_size - 1) // block_size warps_n = 4 - mmas_n = (seqlen_k // warps_n // 16) if seqlen_k <= block_size else (block_size // warps_n // 16) - S_converted = rearrange(S_flat, 'b h (loop nsteps mmas_n warps_n eight t r c0 c1) -> b h (nsteps r eight) (loop mmas_n warps_n c0 t c1)', - loop=loop_steps, nsteps=seqlen_q // 16, mmas_n=mmas_n, warps_n=warps_n, eight=8, t=4, - r=2, c0=2, c1=2) - - # Need to zero out things not in attention_mask in case S was initialized with random values - # and some of those values aren't overwritten. - seqlen_q_og = query_padding_mask.shape[-1] - if seqlen_q_og < seqlen_q: - query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) - else: - query_padding_mask = query_padding_mask[:, :seqlen_q] - S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) - seqlen_k_og = key_padding_mask.shape[-1] - if seqlen_k_og < seqlen_k: - key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) - else: - key_padding_mask = key_padding_mask[:, :seqlen_k] - S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0) + blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal) + nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n + nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m + mmas_n = (blocksize_n + 16 - 1) // 16 + S_flat = rearrange(S, 'b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)', + blocksize_m=blocksize_m, blocksize_n=blocksize_n) + S_converted = rearrange(S_flat, 'b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)', + mmas_n=mmas_n, warps_n=warps_n, eight=8, c0=2, c1=2, c2=2, four=4) if causal: causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1) S_converted.masked_fill_(causal_mask, 0.0) + + # Need to zero out things not in attention_mask in case S was initialized with random values + # and some of those values aren't overwritten. 
+ seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q + if query_padding_mask is not None: + if seqlen_q_og < seqlen_q: + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) + else: + query_padding_mask = query_padding_mask[:, :seqlen_q] + S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) + seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k + if key_padding_mask is not None: + if seqlen_k_og < seqlen_k: + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) + else: + key_padding_mask = key_padding_mask[:, :seqlen_k] + S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0) if seqlen_q_og < seqlen_q: S_converted = S_converted[:, :, :seqlen_q_og, :] else: @@ -300,16 +284,15 @@ def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_pa if causal: causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) scores.masked_fill_(causal_mask, float('-inf')) - block_size = _get_block_size(scores.device, head_dim, is_dropout) - scores_block = scores.split(block_size, dim=-1) + _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) + scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) - lcse_block = torch.logcumsumexp(lse_block, dim=-1).unbind(dim=-1) - scores_max_block = ([torch.amax(scores_block[0], dim=-1)] - + [torch.maximum(torch.amax(s, dim=-1), lcse) - for s, lcse in zip(scores_block[1:], lcse_block[:-1])]) - attn_unnorm_block = attn_unnorm.split(block_size, dim=-1) - attn_norm = torch.cat([a / rearrange(torch.exp(lcse_block[-1] - m), 'b h s -> b h s 1') - for a, m in zip(attn_unnorm_block, scores_max_block)], dim=-1) + lse = torch.logsumexp(lse_block, dim=-1) + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) + attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) + attn_norm = torch.cat([a / rearrange(torch.exp(lse - m), 'b h s -> b h s 1') + for a, m in zip(attn_unnorm_block, cummax_block)], dim=-1) if query_padding_mask is not None: attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) return attn_norm.to(dtype=attn_unnorm.dtype) @@ -350,68 +333,79 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('causal', [False]) -@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16]) +# @pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) +# @pytest.mark.parametrize('seqlen', [97]) @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) 
-def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): +# @pytest.mark.parametrize('dropout_p', [0.17]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) # set seed torch.random.manual_seed(0) - # Set smaller batch size so it would trigger num_splits > 1 - batch_size = 8 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') - - qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True - ) - - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func( - qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal + batch_size = 16 + nheads = 9 + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, + requires_grad=True) + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, dropout_p, return_attn_probs=True, causal=causal ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, - causal=causal).item() - - output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal) - output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal, upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - g = torch.randn_like(output) - dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) - dqkv = dqkv_pad_fn(dqkv_unpad) - dqkv_ref, = torch.autograd.grad(output_ref, qkv, g) - dqkv_pt, = torch.autograd.grad(output_pt, qkv, g) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, None, None, d, dropout_p > 0.0, causal=causal + )[:, :, :seqlen, :seqlen] + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], + None, None, dropout_p > 0.0, 
causal=causal) + dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() + print(f'Actual dropout fraction: {dropout_fraction}') + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal) + out_pt, attn_pt = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal, + upcast=False, reorder_ops=True) + # v = qkv[:, :, 2].float() + # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() + # if causal: + # causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1) + # qk.masked_fill_(causal_mask, float('-inf')) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # p_tmp = torch.softmax(qk / math.sqrt(d), -1) + # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values + # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values + # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values + # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values + # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:]) + # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:]) + # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:]) + # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :]) + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') + if dropout_p > 0.0: + print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') + print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') + + g = torch.randn_like(out) + # do_o = (g.float() * out.float()).sum(-1) + # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) + # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + dqkv, = torch.autograd.grad(out, qkv, g) + dqkv_ref, = torch.autograd.grad(out_ref, qkv, g) + dqkv_pt, = torch.autograd.grad(out_pt, qkv, g) print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') @@ -423,584 +417,411 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
- assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.98 <= dropout_fraction / dropout_p <= 1.02 + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + assert abs(dropout_fraction - dropout_p) <= 0.01 + + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - # Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension - assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item() - # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): +def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) # set seed torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) + batch_size = 5 + nheads = 6 + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, + requires_grad=True) - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') + # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') - (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv, - output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask, kvpacked=True - ) - - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_kvpacked_func( - q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True ) - 
output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal + + out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( + qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, q, kv[:, :, 0], kv[:, :, 1], - query_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask, - causal=causal) - - output_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal) - output_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal, - upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - g = torch.randn_like(output) - dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g) - dq = dq_pad_fn(dq_unpad) - dkv = dkv_pad_fn(dkv_unpad) - dq_ref, dkv_ref, = torch.autograd.grad(output_ref, (q, kv), g) - dq_pt, dkv_pt = torch.autograd.grad(output_pt, (q, kv), g) - print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') - print(f'dK max diff: {(dkv[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}') - print(f'dV max diff: {(dkv[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}') - print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') - print(f'dK Pytorch max diff: {(dkv_pt[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}') - print(f'dV Pytorch max diff: {(dkv_pt[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}') + out = output_pad_fn(out_unpad) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal + )[:, :, :seqlen, :seqlen] + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], + key_padding_mask, key_padding_mask, dropout_p > 0.0, + causal=causal) + dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, + causal=causal).item() + print(f'Actual dropout fraction: {dropout_fraction}') + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, + causal=causal) + out_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') + if dropout_p > 0.0: + print(f'Attention max diff: 
{(attn - attn_ref).abs().max().item()}') + print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') + + g = torch.randn_like(out) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + dqkv_unpad, = torch.autograd.grad(out, qkv_unpad, g) + dqkv = dqkv_pad_fn(dqkv_unpad) + dqkv_ref, = torch.autograd.grad(out_ref, qkv, g) + dqkv_pt, = torch.autograd.grad(out_pt, qkv, g) + print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') + print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') + print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') + print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') + print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') + print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') + print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') + print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() - assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item() - # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dkv, dkv_ref, rtol=rtol, atol=atol) + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + assert abs(dropout_fraction - dropout_p) <= 0.01 + + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() +@pytest.mark.parametrize('kvpacked', [True, False]) +# @pytest.mark.parametrize('kvpacked', [False]) @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize('mha_type', ["mha"]) @pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), 
(108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: +def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked): + if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) # set seed torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - - (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, - output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask - ) + batch_size = 16 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if kvpacked: + kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) + else: + k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) + + if kvpacked: + out, lse, S_dmask = flash_attn_kvpacked_func( + q, kv, dropout_p, return_attn_probs=True, causal=causal + ) + else: + out, lse, S_dmask = flash_attn_func( + q, k, v, dropout_p, return_attn_probs=True, causal=causal + ) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, None, None, d, dropout_p > 0.0, causal=causal + )[:, :, :seqlen_q, :seqlen_k] + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep, + None, None, dropout_p > 0.0, causal=causal) + dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() + print(f'Actual dropout fraction: {dropout_fraction}') + else: + dropout_mask = None - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func( - q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 
0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask, key_padding_mask, - dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask, - causal=causal) - - output_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal) - output_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal, - upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - g = torch.randn_like(output) - dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g) - dq = dq_pad_fn(dq_unpad) - dk = dk_pad_fn(dk_unpad) - dv = dk_pad_fn(dv_unpad) - dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) - dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask, + causal=causal) + out_pt, attn_pt = attention_kvpacked_ref(q, kv, None, None, dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + else: + out_ref, attn_ref = attention_ref(q, k, v, None, None, dropout_p, dropout_mask, + causal=causal) + out_pt, attn_pt = attention_ref(q, k, v, None, None, dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') + if dropout_p > 0.0: + print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') + print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + if kvpacked: + dq, dkv, = torch.autograd.grad(out, (q, kv), g) + dk, dv = dkv.unbind(2) + dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + dq, dk, dv, = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g) + dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g) print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') + print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}') + print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}') + print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}') print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') 
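As an aside to the MQA/GQA handling in the hunks above (illustration only, not part of the patch): the rewritten tests expand the `nheads_k` key/value heads back to `nheads` with `einops.repeat` so that a plain PyTorch attention reference can be compared head-for-head against the kernel output. A minimal sketch of that reference-side expansion, assuming `einops` is installed and using made-up sizes (no masking or dropout):

```python
# Hypothetical sketch of the reference-side GQA/MQA head expansion used in the tests.
import math
import torch
from einops import repeat

batch_size, seqlen, nheads, nheads_k, d = 2, 16, 8, 2, 32  # illustrative sizes
q = torch.randn(batch_size, seqlen, nheads, d)
k = torch.randn(batch_size, seqlen, nheads_k, d)
v = torch.randn(batch_size, seqlen, nheads_k, d)

# Repeat each of the nheads_k key/value heads g times so the reference sees nheads heads,
# mirroring repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) in the test code.
g = nheads // nheads_k
k_rep = repeat(k, "b s h d -> b s (h g) d", g=g)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=g)

# Plain scaled-dot-product attention over the expanded heads (no mask, no dropout).
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k_rep)
probs = torch.softmax(scores, dim=-1)
out_ref = torch.einsum("bhts,bshd->bthd", probs, v_rep)
print(out_ref.shape)  # torch.Size([2, 16, 8, 32])
```

The group size `g = nheads // nheads_k` is why the tests assert `nheads % nheads_k == 0` before constructing q, k, and v.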
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') + print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}') + print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}') + print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}') # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + assert abs(dropout_fraction - dropout_p) <= 0.01 - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() - # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dk, dk_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol) -@pytest.mark.skipif(True, reason='Experimental, not being used') +@pytest.mark.parametrize('kvpacked', [True, False]) +# @pytest.mark.parametrize('kvpacked', [False]) @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize('mha_type', ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize('mha_type', ["mqa"]) @pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('causal', [False]) -@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16]) +# @pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('d', [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [512]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: +def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, + kvpacked): + if max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) # set seed torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = 
torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split') - batch_size0 = batch_size // 4 * 3 # this must match what's in generate_random_padding_mask - # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') - - qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True - ) - max_seqlen1 = 128 - - output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func( - qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, - return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask0_converted = convert_flash_attn_S_to_softmax( - S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal - ) - S_dmask1_converted = convert_flash_attn_S_to_softmax( - S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal - ) - padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1], - S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2]) - S_dmask_converted = torch.cat([S_dmask0_converted, - F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, - causal=causal).item() - - output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal) - output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal, upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - g = torch.randn_like(output) - dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) - dqkv = dqkv_pad_fn(dqkv_unpad) - dqkv_ref, = torch.autograd.grad(output_ref, qkv, g) - dqkv_pt, = torch.autograd.grad(output_pt, qkv, g) - print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') - print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV Pytorch mean diff: 
{(dqkv_pt - dqkv_ref).abs().mean().item()}') - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() + batch_size = 16 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if kvpacked: + kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() - # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - - (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, - output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask - ) - - torch.random.manual_seed(0) - output_unpad_0, sm_lse_0, S_dmask_0 = flash_attn_unpadded_func( - q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - S_dmask_converted_0 = convert_flash_attn_S_to_softmax( - S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - g = torch.randn_like(output_unpad_0) - dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0, - (q_unpad, k_unpad, v_unpad), g) - # Parallelizing over seqlen_k makes dq non-deterministic - deterministic_dq = False - # Numerical error if we just do any arithmetic on dq - dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item() - equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol) - - for _ in range(10): - 
torch.random.manual_seed(0) - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func( + k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, + requires_grad=True) + + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode='random') + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='random') + # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') + + if kvpacked: + (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv, + output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv( + q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True + ) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func( + q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, return_attn_probs=True, causal=causal + ) + else: + (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, + output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( + q, k, v, query_padding_mask, key_padding_mask, kvpacked=False + ) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, return_attn_probs=True, causal=causal ) + out = output_pad_fn(out_unpad) + if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - assert torch.equal(output_unpad, output_unpad_0) - # sm_lse has some parts that are uninitialized from torch.empty - # assert torch.equal(sm_lse, sm_lse_0) - assert torch.equal(S_dmask_converted, S_dmask_converted_0) - - if is_sm80 or d <= 64: # Only run backward for d=128 on A100 - dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad, - (q_unpad, k_unpad, v_unpad), g) - assert equal_fn(dq_unpad, dq_unpad_0) - assert torch.equal(dk_unpad, dk_unpad_0) - assert torch.equal(dv_unpad, dv_unpad_0) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs') -def test_flash_attn_multigpu(): - seqlen = 256 - d = 64 - dropout_p = 0.0 - causal = False - dtype = torch.float16 - device = 'cuda:1' - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') - - qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True - ) - - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func( - qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = 
get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, - causal=causal).item() - - output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal) - output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal, upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - g = torch.randn_like(output) - dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) - dqkv = dqkv_pad_fn(dqkv_unpad) - dqkv_ref, = torch.autograd.grad(output_ref, qkv, g) - dqkv_pt, = torch.autograd.grad(output_pt, qkv, g) - print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') - print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') + )[:, :, :seqlen_q, :seqlen_k] + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + attn = normalize_flash_attn_S(attn_unnorm, q, k_rep, v_rep, + query_padding_mask, key_padding_mask, + dropout_p > 0.0, causal=causal) + dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, + key_padding_mask, causal=causal).item() + print(f'Actual dropout fraction: {dropout_fraction}') + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, + dropout_p, dropout_mask, causal=causal) + out_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, + dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + else: + out_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask, + dropout_p, dropout_mask, causal=causal) + out_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask, + dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(out_pt - out_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}') + if dropout_p > 0.0: + print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') + print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') + + g 
= torch.randn_like(out) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + if kvpacked: + dq_unpad, dkv_unpad, = torch.autograd.grad(out, (q_unpad, kv_unpad), g) + dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) + dq_ref, dkv_ref, = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + dq_pt, dkv_pt, = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + dq_ref, dk_ref, dv_ref, = torch.autograd.grad(out_ref, (q, k, v), g) + dq_pt, dk_pt, dv_pt, = torch.autograd.grad(out_pt, (q, k, v), g) + dq = dq_pad_fn(dq_unpad) + print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') + print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') + print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') + print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}') + print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}') + print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}') + print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') + print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') + print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') + print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}') + print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}') + print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}') # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + if dropout_p > 0.0: + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + assert abs(dropout_fraction - dropout_p) <= 0.01 + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() -@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old') -@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100') -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) # @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96]) -# @pytest.mark.parametrize('d', [48]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 
1024), (1024, 1023), (2048, 2048)]) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)]) -@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk'])) -# @pytest.mark.parametrize('bias_shape', (['1hqk'])) -def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape): - if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +@pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) +@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) +# @pytest.mark.parametrize('seqlen', [193]) +# @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) +@pytest.mark.parametrize('dropout_p', [0.0]) +def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): + if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = 'cuda' # set seed torch.random.manual_seed(0) batch_size = 32 nheads = 4 - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2) - if bias_shape == '1h1k': - bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == '1hqk': - bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == 'b11k': - bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == 'b1qk': - bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device) - else: - bias = None - - q, k, v = [x.detach().requires_grad_() for x in [q, k, v]] - output = flash_attn_func(q, k, v, bias, causal) - - output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal) - output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False, - reorder_ops=True) - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - - g = torch.randn_like(output) - dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) - dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) - dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) - print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') - print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') - print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') - print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}') - print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}') - print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}') - print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') - print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') - print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') - print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}') - print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}') - print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}') - - # Check that FlashAttention's numerical error is at most twice the numerical error 
- # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + out0, lse0, _ = flash_attn_qkvpacked_func( + qkv, dropout_p, return_attn_probs=True, causal=causal + ) + g = torch.randn_like(out0) + dqkv0, = torch.autograd.grad(out0, qkv, g) + for _ in range(200): + torch.random.manual_seed(0) + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, dropout_p, return_attn_probs=True, causal=causal + ) + assert torch.equal(out, out0) + assert torch.equal(lse, lse0) + # sm_lse has some parts that are uninitialized from torch.empty + # assert torch.equal(sm_lse, sm_lse_0) -@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old') -@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100') -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) -@pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)]) -@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk'])) -# @pytest.mark.parametrize('bias_shape', (['b1qk'])) -def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape): - if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2) - if bias_shape == '1h1k': - bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == '1hqk': - bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == 'b11k': - bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device) - elif bias_shape == 'b1qk': - bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device) - else: - bias = None - - q, k, v = [x.detach().requires_grad_() for x in [q, k, v]] - output_0 = flash_attn_func(q, k, v, bias, causal) - - g = torch.randn_like(output_0) - dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g) - - # The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic - deterministic_dq = False - # Numerical error if we just do any arithmetic on dq - dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item() - equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol) - # Run 10000 times and check that the results don't change - for i in 
range(10000): - output = flash_attn_func(q, k, v, bias, causal) - output_equal = torch.equal(output, output_0) - if not output_equal: # Printing / computing diff sometimes makes the race condition disappear - print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }') - print(f'Output max diff: {(output - output_0).abs().max().item()}') - assert torch.equal(output, output_0) - dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) - dq_equal = equal_fn(dq, dq_0) - dk_equal = torch.equal(dk, dk_0) - dv_equal = torch.equal(dv, dv_0) - if not (dq_equal and dk_equal and dv_equal): - print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }') - print(f'dQ max diff: {(dq - dq_0).abs().max().item()}') - print(f'dK max diff: {(dk - dk_0).abs().max().item()}') - print(f'dV max diff: {(dv - dv_0).abs().max().item()}') - assert equal_fn(dq, dq_0) - assert torch.equal(dk, dk_0) - assert torch.equal(dv, dv_0) + if not (is_sm75 and d == 128): + dqkv, = torch.autograd.grad(out, qkv, g) + assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0]) + assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1]) + assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2]) diff --git a/training/Dockerfile b/training/Dockerfile index cef847f3d..de535e7ec 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==1.0.9 +RUN pip install flash-attn==2.0.0.post1 # Install CUDA extensions for cross-entropy, fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v1.0.9 \ + && cd flash-attention && git checkout v2.0.0.post1 \ && cd csrc/fused_softmax && pip install . && cd ../../ \ && cd csrc/rotary && pip install . && cd ../../ \ && cd csrc/xentropy && pip install . && cd ../../ \
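For readers following the version bump in the Dockerfile hunk above, here is a hedged usage sketch (not part of the patch) of the FlashAttention-2 entry points that the rewritten tests exercise. It assumes `flash-attn==2.0.0.post1` is installed, that the top-level `flash_attn` package exports these functions, and that a CUDA GPU with fp16 support is available; all sizes are made up.

```python
# Illustrative only: forward calls mirroring the test invocations above.
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func  # assumed export path

device, dtype = "cuda", torch.float16
batch_size, seqlen, nheads, nheads_k, d = 2, 512, 8, 2, 64  # illustrative sizes

# Packed QKV path, as in test_flash_attn_race_condition.
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype)
out_packed = flash_attn_qkvpacked_func(qkv, 0.0, causal=True)  # second arg is dropout_p

# Separate q/k/v path with fewer key/value heads (MQA/GQA), as in test_flash_attn_output.
q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
out = flash_attn_func(q, k, v, 0.0, causal=True)
print(out.shape)  # torch.Size([2, 512, 8, 64])
```

The varlen variants (`flash_attn_varlen_func`, `flash_attn_varlen_qkvpacked_func`, `flash_attn_varlen_kvpacked_func`) follow the same pattern but take unpadded inputs plus `cu_seqlens` and `max_seqlen`, exactly as shown in the test hunks above.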