FlashAttention: Memory-Efficient Attention for Large Language Models
Introduction
FlashAttention is an attention algorithm that computes exact attention with far fewer memory accesses than standard implementations. Its main goal is to minimize reads from and writes to GPU memory, which are a major bottleneck in the performance of attention mechanisms. The paper proposes an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and fast on-chip SRAM.
Technique
The FlashAttention algorithm is designed to optimize both memory requirements and wall-clock time. It is IO-aware: it explicitly accounts for the GPU memory hierarchy, in which HBM is large but relatively slow, while on-chip SRAM is small but roughly an order of magnitude faster. By tiling the computation so that each block of work fits in SRAM, the algorithm requires only Θ(N^2 d^2 / M) HBM accesses for sequence length N, head dimension d, and SRAM size M, compared with Θ(N d + N^2) for standard attention, which reads and writes the full attention matrix.
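To make the SRAM budget concrete, Algorithm 1 in the paper derives the tile sizes directly from the on-chip memory capacity M: the key/value block size is B_c = ceil(M / (4d)) and the query block size is B_r = min(B_c, d). A minimal sketch of that rule follows, assuming 4-byte (fp32) elements and an illustrative 192 KB of shared memory per streaming multiprocessor; the function name is ours, and the real kernel's accounting is more detailed.

```python
import math

def flash_block_sizes(sram_bytes, d, elem_bytes=4):
    """Tile sizes per Algorithm 1 of the FlashAttention paper:
    B_c = ceil(M / (4d)), B_r = min(B_c, d), with M the on-chip SRAM
    capacity measured in elements. elem_bytes=4 assumes fp32 storage."""
    M = sram_bytes // elem_bytes      # SRAM capacity in elements
    Bc = math.ceil(M / (4 * d))       # key/value block size
    Br = min(Bc, d)                   # query block size
    return Br, Bc

# Illustrative numbers: ~192 KB of shared memory per SM, head dim 64.
print(flash_block_sizes(192 * 1024, d=64))   # -> (64, 192)
```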
The tiling approach does not rely on sparsity: FlashAttention computes exact, dense attention. The algorithm splits Q, K, and V into blocks small enough to fit in SRAM, streams the K and V blocks through on-chip memory, and maintains running softmax statistics (the row-wise maximum and the normalizing sum) so that partial outputs can be rescaled as each new block is processed. As a result, the full N x N attention matrix is never materialized in HBM, which sharply reduces memory traffic and makes the whole computation faster and more memory-efficient, as the sketch below illustrates.
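Below is a minimal NumPy sketch of this blockwise computation with an online softmax. It is illustrative rather than the paper's fused CUDA kernel: the function name and block_size parameter are our choices, and for simplicity it tiles only K and V, whereas the real algorithm tiles Q as well so that every working block fits in SRAM.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=64):
    """Exact attention computed block-by-block over K/V (illustrative)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))             # running (unnormalized) output
    m = np.full(N, -np.inf)          # running row-wise max of the scores
    l = np.zeros(N)                  # running softmax denominator

    # Stream over blocks of K and V; on a GPU each block would sit in SRAM.
    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]        # (B, d) key block
        Vb = V[start:start + block_size]        # (B, d) value block
        S = (Q @ Kb.T) * scale                  # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))    # updated running max
        alpha = np.exp(m - m_new)               # rescales previous partials
        P = np.exp(S - m_new[:, None])          # block softmax numerator
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vb
        m = m_new

    return O / l[:, None]                       # normalize once at the end
```

The result matches a naive softmax(Q K^T / sqrt(d)) V to floating-point tolerance, which is easy to verify:

```python
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward(Q, K, V), ref)
```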