The current integration of FlashAttention [1,2] for training in most libraries (including HuggingFace transformers) is based on non-intrusive changes to these libraries: most implementations simply replace the naive attention module with FlashAttention [1,2]. While this is easy to implement, it is suboptimal for batches of variable-length sequences, i.e., when the batch contains padding. All operations in a transformer except attention are applied independently to each token position. Since FlashAttention [1,2] completely avoids any computation or memory on pad tokens, it is possible to drop all redundant computation and memory for padding tokens from the transformer model and essentially create a Padding-Free transformer model when using FlashAttention. This is also done in the original FlashAttention training codebase. It should be noted that this is an exact implementation of the model with no approximations. Similar optimizations are used in HuggingFace TGI for inference efficiency. Padding is not an issue when all examples in a batch have equal length or when examples are densely packed (as is typical for pretraining).

In this blog, we derive the theoretical memory consumption for naive attention, for FlashAttention with padded transformer blocks (the current implementation in the HuggingFace transformers library) and for the Padding-Free transformer blocks.

Let's assume an input batch of embeddings of shape $(b, \max\{s_i\}, h)$ as input to a transformer layer, where $b$, $s_i$ and $h$ denote the batch size, the unpadded sequence length of the $i^{th}$ example in the batch and the hidden size of the transformer model respectively. For training, each transformer layer needs to cache the activations of each operation (computed in the forward pass) for the backward pass. We assume 16-bit precision for training (2 bytes per value in a tensor). For simplicity, we also assume Multi-Head Attention [6] with $a$ attention heads, though the same idea applies to Multi-Query Attention [7] and Grouped-Query Attention [8].
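To make the Padding-Free layout concrete, here is a minimal sketch of the unpad/pad bookkeeping in PyTorch. The function names and exact signatures are illustrative (the flash-attn repository ships similar helpers such as `unpad_input`/`pad_input` in `flash_attn.bert_padding`); the point is simply that a padded $(b, \max\{s_i\}, h)$ tensor can be flattened to $\left(\sum s_i, h\right)$ and restored losslessly.

```python
import torch
import torch.nn.functional as F

def unpad_batch(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Drop pad positions from a (b, max_s, h) tensor.

    Returns the packed (sum_s, h) tensor, the flat indices of the kept positions,
    and the cumulative sequence lengths needed by variable-length attention kernels.
    """
    b, max_s, h = hidden_states.shape
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                # (b,)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0))   # (b + 1,)
    packed = hidden_states.reshape(b * max_s, h)[indices]                  # (sum_s, h)
    return packed, indices, cu_seqlens

def pad_batch(packed: torch.Tensor, indices: torch.Tensor, b: int, max_s: int):
    """Scatter a packed (sum_s, h) tensor back to (b, max_s, h), with zeros at pad positions."""
    h = packed.shape[-1]
    out = torch.zeros(b * max_s, h, dtype=packed.dtype, device=packed.device)
    out[indices] = packed
    return out.reshape(b, max_s, h)
```

In a fully Padding-Free model, the unpad happens once at the embedding layer and the pad (if needed) once at the output, instead of around every attention call.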

1. **Input LayerNorm:** The input of shape $(b, \max\{s_i\}, h)$ needs to be cached for the backward pass. The mean and variance also need to be cached, each of shape $(b, \max\{s_i\})$. Since $\max\{s_i\}bh \gg 2b \times \max\{s_i\}$, we can ignore the $2b \times \max\{s_i\}$ elements for the mean and variance. Total activation memory for this operation is $2\max\{s_i\}bh$ bytes.
2. **QKV projections:** The input (shared among the Q, K and V projections) needs to be cached. It has $\max\{s_i\}bh$ elements, taking $2\max\{s_i\}bh$ bytes. The outputs of the Q, K and V projections also need to be cached, each with $\max\{s_i\}bh$ elements taking $2\max\{s_i\}bh$ bytes. Total activation memory for this operation is $2\max\{s_i\}bh + 3 \times 2\max\{s_i\}bh = 8\max\{s_i\}bh$ bytes.
3. **Softmax:** The softmax output, which has $(\max\{s_i\})^2ab$ elements, needs to be cached. Total activation memory for this operation is $2(\max\{s_i\})^2ab$ bytes.
4. **Softmax dropout:** The attention softmax is followed by a dropout, which requires saving a mask of $(\max\{s_i\})^2ab$ elements. Each element takes a byte since PyTorch doesn't allow bit tensors, probably for ease of implementation since GPUs are generally byte-addressable. Total memory for this operation is $(\max\{s_i\})^2ab$ bytes.
5. **Dropout output:** The softmax dropout output, which is multiplied with V, has $(\max\{s_i\})^2ab$ elements and also needs to be cached. Total activation memory for this operation is $2(\max\{s_i\})^2ab$ bytes.
6. **Output projection:** We cache the output of the multiplication with V, which is the input to the output projection matrix. It has $\max\{s_i\}bh$ elements. Total activation memory for this operation is $2\max\{s_i\}bh$ bytes.
7. **Output projection dropout:** Only the dropout mask needs to be cached. Total memory for this operation is $\max\{s_i\}bh$ bytes.
8. **Post-attention LayerNorm:** Same as the input LayerNorm. Memory requirement is $2\max\{s_i\}bh$ bytes.
9. **MLP block:** We assume a feedforward hidden dimension $f = 4h$, as is typical for a standard transformer. The input to the first linear layer, the input to the GELU activation and the input to the second linear layer need to be cached, taking $2\max\{s_i\}bh$, $8\max\{s_i\}bh$ and $8\max\{s_i\}bh$ bytes respectively. The required memory for the MLP block is $18\max\{s_i\}bh$ bytes.
10. **MLP dropout:** Same as point (7) above, i.e., $\max\{s_i\}bh$ bytes.

Summing these up, the total activation memory per layer is

$$M_{\text{naive}} = \max\{s_i\}bh \left(34 + \frac{5a\max\{s_i\}}{h}\right).$$
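As a sanity check on the coefficients, here is a small helper (illustrative, not part of any library) that adds up the ten items above and reproduces $M_{\text{naive}}$:

```python
def naive_activation_memory(b: int, s_max: int, h: int, a: int) -> int:
    """Per-layer activation memory in bytes for the naive (padded) implementation,
    assuming 16-bit activations (2 bytes/element) and 1-byte dropout masks."""
    sbh = s_max * b * h
    s2 = a * b * s_max ** 2
    total = (
        2 * sbh        # (1) input LayerNorm
        + 8 * sbh      # (2) QKV projection: shared input + Q, K, V outputs
        + 2 * s2       # (3) softmax output
        + 1 * s2       # (4) softmax dropout mask
        + 2 * s2       # (5) dropout output (multiplied with V)
        + 2 * sbh      # (6) input to the output projection
        + 1 * sbh      # (7) output-projection dropout mask
        + 2 * sbh      # (8) post-attention LayerNorm
        + 18 * sbh     # (9) MLP block with f = 4h
        + 1 * sbh      # (10) MLP dropout mask
    )
    # matches the closed form M_naive = s_max * b * h * (34 + 5 * a * s_max / h)
    assert total == 34 * sbh + 5 * s2
    return total
```

The $\max\{s_i\}bh$ terms sum to the factor $34$ and the $(\max\{s_i\})^2ab$ terms to the factor $5a\max\{s_i\}/h$.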
FlashAttention [1,2] has been integrated into the HuggingFace transformers API. The current implementation, at the time of writing this blog, does an unpad operation just before the FlashAttention kernel is executed. This operation converts the input Q, K and V of shape $(b, \max\{s_i\}, h)$ to shape $\left(\sum s_i, h\right)$ (where the examples in the batch are concatenated one after the other, resulting in a 2D tensor) and launches the FlashAttention kernel. After the attention computation, the output is padded again to the shape $(b, \max\{s_i\}, h)$.

FlashAttention [1,2] avoids materializing the quadratic $QK^T$ matrix in memory and uses online softmax [3], thereby dropping the need to cache the activations in point (3). Instead, we only need to materialize the output matrix of shape $\left(\sum s_i, a, \frac{h}{a}\right)$, the two softmax statistics, each of shape $\left(\sum s_i, a\right)$, and the random number generator state for the dropout, which we ignore here. For the algorithm in detail, refer to the FlashAttention [1,2] papers. We also need to cache the boolean attention mask used for padding and unpadding, but we ignore it in the calculations since it is the same for every layer: it can be cached once for the entire transformer model and doesn't need to be cached per layer. Thus the memory required for attention becomes

$$2 \sum s_i a\left(\frac{h}{a}\right) + 4 \sum s_i a = 2 \sum s_i (h + 2a) \quad \text{bytes,}$$

and the total activation memory per layer with FlashAttention [1,2] is

$$M_{\text{flash}} = \max\{s_i\}bh \left(34 + \frac{\sum s_i}{\max\{s_i\}b} \left[1 + \frac{2a}{h}\right] \right).$$

Since all operations in the transformer layer except attention are the same for each token position, we can avoid the padding and unpadding operations altogether and reduce the activation memory of the transformer layer further; this requires only minor changes to the HuggingFace transformers implementation. In this implementation of the transformer, there is no memory wasted on pad token positions at all. The input to the entire transformer model is of shape $\left(\sum s_i, h\right)$, and the memory in this case is given by

$$M_{\text{padding-free}} = \left( \sum s_i \right) h \left(35 + \frac{2a}{h} \right).$$

It should be noted that $M_{\text{flash}} = M_{\text{padding-free}}$ when there is no padding, i.e., when $s_i = \max\{s_i\} \ \forall i \in \{1, 2, \ldots, b\}$. This optimization is similar to running a transformer model with nested tensors. While there has been significant effort to mitigate the padding problem by approaches like binning examples by context length, these lead to model performance degradation, especially during finetuning.
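For concreteness, the sketch below shows how a Padding-Free attention block might call the variable-length FlashAttention kernel directly on the packed $\left(\sum s_i, h\right)$ tensor, reusing the `cu_seqlens` computed once per batch. It assumes the flash-attn 2 interface (`flash_attn_varlen_func`); the module structure, the `wqkv`/`wo` projection names and the helper itself are illustrative rather than the actual HuggingFace implementation.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn >= 2.x

def padding_free_attention(packed: torch.Tensor, cu_seqlens: torch.Tensor,
                           max_seqlen: int, num_heads: int,
                           wqkv: torch.nn.Linear, wo: torch.nn.Linear) -> torch.Tensor:
    """Multi-Head Attention over a packed (sum_s, h) tensor, with no pad tokens anywhere."""
    total, h = packed.shape
    head_dim = h // num_heads
    # single fused QKV projection, then split into Q, K, V of shape (sum_s, a, h/a)
    q, k, v = wqkv(packed).view(total, 3, num_heads, head_dim).unbind(dim=1)
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )  # (sum_s, a, h/a); no (b, max_s, ...) tensor is ever materialized
    return wo(out.reshape(total, h))
```

Because the output stays in the packed layout, the subsequent LayerNorms, MLP and dropouts simply operate on $\left(\sum s_i, h\right)$ tensors and the pad/unpad round trip disappears.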
Now we analyze the memory consumption of the three transformer layer implementations. We assume a dataset whose sequence lengths follow a discrete uniform distribution, i.e., $S_i \sim U\{1, 2, \ldots, N\}$, where $S_i$ is the random variable denoting the sequence length of the $i^{th}$ sample in the batch and $N$ is the maximum sequence length for the dataset and the model. We sample batches of $b$ examples each, with sequence lengths $(S_1, S_2, \ldots, S_b)$.

We compute the expectations $\mathbb{E}[M_{\text{naive}}]$, $\mathbb{E}[M_{\text{flash}}]$ and $\mathbb{E}[M_{\text{padding-free}}]$ under this distribution. To do so, we consider another random variable $K = \max\{S_i\}$. The cumulative distribution function of $K$ can be derived as

$$P(K \le k) = P(\max\{S_i\} \le k) = P(S_1 \le k, S_2 \le k, \ldots, S_b \le k).$$

Using the fact that the examples in a batch are i.i.d., we have

$$P(K \le k) = [P(S_i \le k)]^b = \left( \frac{k}{N} \right)^b,$$

and thus the probability mass function of $K$ is

$$P(K = k) = P(K \le k) - P(K \le k - 1) = \left(\frac{k}{N}\right)^b - \left(\frac{k - 1}{N}\right)^b.$$

We can use computational methods (see the snippet after the table) or Faulhaber's formula [9] together with this result to calculate the expected memory usage of the three methods. We report the theoretical memory consumption derived from these equations for a 20B parameter model in the following table. We find that the Padding-Free version of the transformer layer saves $\sim 43\%$ of the activation memory and also avoids a lot of redundant FLOPs. We leave the FLOP analysis out of this blog, but it is easily derivable.

Table: Memory usage per transformer layer for the different attention implementations at different context lengths, for a 20B parameter model with context length $N = 8192$, hidden size $h = 6144$, FFN hidden size $f = 24576$ and $a = 48$ attention heads.
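The expectations themselves are easy to evaluate numerically. The snippet below is illustrative (the batch size passed in is an arbitrary example, not the configuration used for the table); it plugs the PMF of $K$ and $\mathbb{E}\left[\sum_i S_i\right] = b\,\frac{N+1}{2}$ into the per-layer formulas derived above.

```python
def expected_layer_memory(N: int, b: int, h: int, a: int):
    """Expected per-layer activation memory (bytes) under S_i ~ U{1, ..., N}."""
    # PMF of K = max{S_i}: P(K = k) = (k/N)^b - ((k-1)/N)^b
    pmf = [(k / N) ** b - ((k - 1) / N) ** b for k in range(1, N + 1)]
    E_K = sum(k * p for k, p in enumerate(pmf, start=1))       # E[max{S_i}]
    E_K2 = sum(k * k * p for k, p in enumerate(pmf, start=1))  # E[(max{S_i})^2]
    E_sum_s = b * (N + 1) / 2                                  # E[sum_i S_i] = b * E[S_i]

    # expand M_naive, M_flash and M_padding_free from the blog and take expectations
    E_naive = 34 * b * h * E_K + 5 * a * b * E_K2
    E_flash = 34 * b * h * E_K + (h + 2 * a) * E_sum_s
    E_padding_free = (35 * h + 2 * a) * E_sum_s
    return E_naive, E_flash, E_padding_free

# example: N = 8192, h = 6144, a = 48 as in the table; b = 4 is just an illustrative batch size
print(expected_layer_memory(N=8192, b=4, h=6144, a=48))
```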

In this blog, we present a way to completely avoid the computation and memory requirements of pad tokens during finetuning of transformer models using FlashAttention. Our changes are easily integrable into the HuggingFace transformers ecosystem for finetuning, and we also derive the equations for the corresponding theoretical memory consumption. The method doesn't involve writing any low-level device code; the only non-native PyTorch code we use is FlashAttention, which is already available.