FlashAttention: Memory-Efficient Attention for Large Language Models

Arun Rajendran

Introduction

FlashAttention is an attention algorithm that computes exact attention with far fewer memory accesses. Its main goal is to minimize reads and writes to GPU memory, which are the dominant bottleneck in attention performance at long sequence lengths. The paper proposes an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high-bandwidth memory (HBM) and fast on-chip SRAM.

Technique

The FlashAttention algorithm is designed to optimize both memory requirements and wall-clock time. It incorporates IO-awareness, meaning it explicitly accounts for the GPU memory hierarchy: the working data is staged in small but very fast on-chip SRAM, while traffic to the much larger and slower HBM is kept to a minimum. Tiling is the mechanism that makes this possible, since the inputs can be processed in blocks small enough to fit in SRAM.
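In practice you rarely need to write such a kernel yourself. Assuming PyTorch 2.x, the built-in scaled_dot_product_attention dispatches to a fused, FlashAttention-style backend on supported GPUs; here is a minimal sketch, with illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

# PyTorch >= 2.0. On supported CUDA GPUs this call dispatches to a fused,
# FlashAttention-style kernel; elsewhere it falls back to the standard
# math implementation and still produces the same exact result.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Shapes are illustrative: (batch, heads, sequence length, head dimension).
q, k, v = (torch.randn(1, 8, 1024, 64, device=device, dtype=dtype) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)  # exact attention, IO-aware backend
```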

Tiling attention is not trivial, because softmax normally needs an entire row of attention scores before it can normalize. FlashAttention gets around this by splitting Q, K, and V into blocks, loading them into SRAM, and computing attention block by block while maintaining running softmax statistics: the row-wise maximum and the running sum of exponentials. As each new block is processed, the previously accumulated partial outputs are rescaled, so the final result is exactly the same as standard attention. Crucially, the full N x N attention matrix is never materialized in HBM, which slashes the number of memory accesses and makes the whole process much faster and more memory-efficient.
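To make the online-softmax idea concrete, here is a minimal single-head sketch in NumPy. It is an illustration rather than the paper's implementation: the block size and shapes are arbitrary, and the real algorithm runs as a fused GPU kernel over SRAM-resident tiles.

```python
import numpy as np

def flash_attention_sketch(Q, K, V, block_size=64):
    """Exact attention computed one key/value block at a time,
    using an online softmax so the full score matrix never exists."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((n, d))          # accumulated (unnormalized) output
    m = np.full(n, -np.inf)       # running row-wise max of the scores
    l = np.zeros(n)               # running row-wise sum of exponentials
    for start in range(0, n, block_size):
        Kb = K[start:start + block_size]          # one key tile
        Vb = V[start:start + block_size]          # matching value tile
        S = Q @ Kb.T * scale                      # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))      # updated running max
        P = np.exp(S - m_new[:, None])            # tile's softmax numerator
        rescale = np.exp(m - m_new)               # correct earlier partials
        l = l * rescale + P.sum(axis=1)
        O = O * rescale[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

# Sanity check against naive attention.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
naive = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_sketch(Q, K, V), naive)
```

Because the rescaling step is exact, the tiled result matches naive attention to floating-point precision; the trade-off is purely about memory traffic, not accuracy.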


Arun Rajendran

Lead Machine Learning Engineer focused on NLP. I hope to write articles on machine learning, travel, personal finance, and investment.