Reorganize directories, add banner figure
tridao committed May 29, 2022
1 parent 7025a09 commit 67c3779
Showing 12 changed files with 21 additions and 32 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -1,4 +1,12 @@
-## FlashAttention - Alpha release (0.1).
+# FlashAttention
+This repository provides the official implementation of FlashAttention from the
+following paper.
+
+**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
+Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
+![FlashAttention](assets/flashattn_banner.pdf)
+
+## Alpha release (0.1).

To compile (requiring NVCC and an A100 GPU):
```
@@ -40,14 +48,14 @@ Our graphs show sequence lengths between 128 and 4096 (when standard attention r

#### Speedup

-![FlashAttention speedup](images/flashattn_speedup.jpg)
+![FlashAttention speedup](assets/flashattn_speedup.jpg)

We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.

#### Memory

-![FlashAttention memory](images/flashattn_memory.jpg)
+![FlashAttention memory](assets/flashattn_memory.jpg)

We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking).
Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
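
To make the quadratic-versus-linear contrast concrete, here is a rough back-of-the-envelope sketch; the head count and fp16 storage are illustrative assumptions, not figures from the paper or this commit.

```
def attention_memory_bytes(seqlen, nheads=16, batch=1, bytes_per_el=2):
    """Illustrative estimate: standard attention materializes an N x N score
    matrix per head, while FlashAttention keeps only O(N) softmax statistics."""
    standard = batch * nheads * seqlen * seqlen * bytes_per_el  # N x N attention matrix
    flash_extra = batch * nheads * seqlen * 2 * 4               # per-row max and sum, fp32
    return standard, flash_extra

# At seqlen=4096 with 16 heads in fp16, the N x N matrices alone take
# 16 * 4096**2 * 2 bytes ~= 0.5 GiB per sequence, versus ~0.5 MB of O(N) statistics.
```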
Binary file added assets/flashattn_banner.pdf
Binary file not shown.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_flash_attention.py
@@ -7,8 +7,8 @@
from einops import rearrange, repeat

from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from bert_padding import unpad_input, pad_input
-from flash_attn_interface import flash_attn_func
+from src.bert_padding import unpad_input, pad_input
+from src.flash_attn_interface import flash_attn_func


def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
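
As context for what the benchmark compares against, below is a minimal sketch of a non-fused reference attention in the spirit of attention_ref; the packed qkv layout (batch, seqlen, 3, nheads, headdim) and the boolean key-padding attn_mask convention are assumptions here, not necessarily what the repo uses.

```
import math
import torch
import torch.nn.functional as F

def attention_ref_sketch(qkv, attn_mask=None, dropout_p=0.0, causal=False):
    """Plain PyTorch attention over packed qkv of shape (B, S, 3, H, D)."""
    q, k, v = qkv.unbind(dim=2)                                   # each (B, S, H, D)
    scores = torch.einsum('bshd,bthd->bhst', q, k) / math.sqrt(q.shape[-1])
    if attn_mask is not None:                                     # attn_mask: (B, S), True = keep
        scores = scores.masked_fill(~attn_mask[:, None, None, :], float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool,
                                            device=scores.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float('-inf'))
    attn = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
    return torch.einsum('bhst,bthd->bshd', attn, v)               # (B, S, H, D)
```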
22 changes: 1 addition & 21 deletions benchmarks/utils.py
@@ -99,27 +99,7 @@ def pytorch_profiler(fn, *inputs, repeats=10):
    ) as p:
        # benchmark_forward(repeats, fn, *inputs)
        fn(*inputs)
-    print(p.key_averages().table(
-        sort_by="self_cuda_time_total", row_limit=-1))
-
-
-def convert_data(*tensors, device='cuda'):
-    tensors = tuple(t.to(device) for t in tensors)
-    for t in tensors:
-        if t.is_leaf: t.requires_grad = True
-        t.retain_grad()
-    return tensors
-
-
-def log_backward(output, *inputs):
-    """ Perform backward pass of output and print gradients of input tensors. """
-
-    #print(f"{output=}")
-    output.sum().backward(retain_graph=True)
-    print("Gradients:")
-    for t in inputs:
-        print(t.grad)
-        t.grad.zero_()
+    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))


def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
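
Besides pytorch_profiler, benchmarks/utils.py also defines benchmark_memory (its signature is the last context line above). A rough sketch of how peak GPU memory can be measured with PyTorch's allocator counters, as an approximation rather than the repo's exact implementation:

```
import torch

def benchmark_memory_sketch(fn, *inputs, desc='', verbose=True, **kwinputs):
    """Measure the peak CUDA memory allocated while running fn once."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem_gb = torch.cuda.max_memory_allocated() / 2**30
    if verbose:
        print(f'{desc} max memory: {mem_gb:.2f} GB')
    return mem_gb
```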
File renamed without changes.
6 changes: 3 additions & 3 deletions flash_attention.py → src/flash_attention.py
@@ -4,9 +4,9 @@

from einops import rearrange

-from rotary import RotaryEmbedding, RotaryEmbedding2D
-from flash_attn_interface import flash_attn_func
-from bert_padding import unpad_input, pad_input, index_first_axis
+from src.rotary import RotaryEmbedding, RotaryEmbedding2D
+from src.flash_attn_interface import flash_attn_func
+from src.bert_padding import unpad_input, pad_input, index_first_axis


class FlashAttention(nn.Module):
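
The imports above pull unpad_input / pad_input from src.bert_padding, which convert between a padded (batch, seqlen, ...) tensor and a packed (total_tokens, ...) layout plus cumulative sequence lengths. A from-scratch illustration of that round trip, using hypothetical helper names rather than the repo's implementation:

```
import torch
import torch.nn.functional as F

def unpad_sketch(x, key_padding_mask):
    """Gather valid tokens of x (B, S, ...) into (total_tokens, ...), returning
    the flat indices, cu_seqlens offsets, and the longest sequence length."""
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)                  # (B,)
    indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    x_unpad = x.reshape(-1, *x.shape[2:])[indices]
    return x_unpad, indices, cu_seqlens, int(seqlens.max())

def pad_sketch(x_unpad, indices, batch, seqlen):
    """Scatter (total_tokens, ...) back into a zero-padded (B, S, ...) tensor."""
    out = x_unpad.new_zeros(batch * seqlen, *x_unpad.shape[1:])
    out[indices] = x_unpad
    return out.reshape(batch, seqlen, *x_unpad.shape[1:])
```

A packed tensor plus cu_seqlens of this kind is what variable-length fused attention kernels typically consume, which is why the padding helpers sit next to the attention wrapper.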
File renamed without changes.
@@ -6,9 +6,10 @@

import hydra

-from flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from flash_blocksparse_attn_interface import convert_blockmask
-from bert_padding import unpad_input, pad_input, index_first_axis
+from src.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from src.flash_blocksparse_attn_interface import convert_blockmask
+from src.bert_padding import unpad_input, pad_input, index_first_axis


class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
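
FlashBlocksparseAttention restricts attention to a subset of blocks of the score matrix, described by a block-level mask; convert_blockmask (imported above) prepares such a mask in whatever layout the CUDA kernel expects. Purely as a conceptual illustration, not the repo's format, a (nblocks, nblocks) boolean block mask expands to a token-level mask like this:

```
import torch

def expand_blockmask(blockmask, block_size):
    """Conceptual only: turn a (nblocks, nblocks) boolean block mask into a
    (seqlen, seqlen) token-level mask with seqlen = nblocks * block_size."""
    return blockmask.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)

# e.g. a 4 x 4 block mask with block_size=256 yields a 1024 x 1024 token mask.
```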
File renamed without changes.
File renamed without changes.
