FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference
Authors: Xunhao Lai, Jianqiao Lu, Yao Luo, Yiyuan Ma, Xun Zhou
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We conduct extensive experiments using state-of-the-art LLMs, including Meta-Llama-3.1-8B-Instruct (Xiong et al., 2024), GLM-4-9B-Chat (Zeng et al., 2024), Yi-9B-200K (Young et al., 2024), and Qwen2-7B-Instruct (Yang et al., 2024), on challenging long-context benchmarks such as RULER (Hsieh et al., 2024) and Infinite Bench (Zhang et al., 2024c). The results demonstrate significant improvements in both speed and accuracy over prior methods, with FlexPrefill consistently preserving or even enhancing model performance across various context lengths and tasks. |
| Researcher Affiliation | Collaboration | Xunhao Lai1, Jianqiao Lu2, Yao Luo3, Yiyuan Ma3, Xun Zhou3 (1Peking University, 2The University of Hong Kong, 3ByteDance Inc.) |
| Pseudocode | Yes | Algorithm 1 presents the overall procedure of our proposed method for efficient sparse attention computation. The algorithm takes the query matrix Q, key matrix K, value matrix V, sparse pattern threshold τ, and cumulative attention threshold γ as input. It is divided into the following three parts: (i) Sparse Pattern Determination: Algorithm 2 determines whether to use the Query-Aware pattern or fall back to the Vertical-Slash pattern for each attention head. (ii) Sparse Index Selection: Based on the attention patterns obtained in (i) and the given cumulative attention threshold γ, the sparse index set S that needs to be computed for each attention head is obtained by Algorithm 4 (Query-Aware) or Algorithm 3 (Vertical-Slash). ... |
| Open Source Code | Yes | https://github.com/bytedance/FlexPrefill |
| Open Datasets | Yes | We evaluate the models on two datasets, each offering unique challenges in long-context understanding: (i) RULER (Hsieh et al., 2024): a synthetic benchmark dataset created to evaluate long-context LLMs with customizable sequence lengths and task complexities. It extends the basic needle-in-a-haystack test as well as introduces new task categories such as multi-hop tracing and aggregation. (ii) Infinite Bench (Zhang et al., 2024c): a benchmark dataset designed to test LLMs' understanding of long dependencies within extensive contexts, with an average token count of 214k. |
| Dataset Splits | No | The paper evaluates models on RULER and Infinite Bench benchmarks across various context lengths, but does not explicitly specify training, validation, or test splits used for its own experimental setup. |
| Hardware Specification | Yes | Our experiments are conducted in a computing environment equipped with a single NVIDIA A100 GPU with 80GB of memory. |
| Software Dependencies | No | For our experiments, we implement a custom pipeline using PyTorch, building on FlashAttention (Dao, 2024), to ensure efficient attention mechanisms over long-context inputs. Our implementation leverages Triton (Tillet et al., 2019) for optimizing the performance of GPU-accelerated computations and uses a block_size = 128 in all experiments. |
| Experiment Setup | Yes | Our implementation leverages Triton (Tillet et al., 2019) for optimizing the performance of GPU-accelerated computations and uses a block_size = 128 in all experiments... The sparse pattern threshold τ is set to 0.1 for all models. Additionally, by adjusting the parameter γ, we assign a tailored computational budget to each attention head in real time for each input scenario. We set γ = 0.9 for the Yi-9B-200K and Qwen2-7B-Instruct models, and 0.95 for the other models. To ensure that all attention heads work normally, we retain the first and last key blocks of each query block, while also requiring each attention head to compute at least 1024 tokens. All experiments were conducted using greedy decoding to maintain consistency across results. |
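The cumulative-attention-threshold step quoted above (choosing, per head, the smallest index set whose attention mass reaches γ) can be sketched roughly as follows. This is a minimal NumPy illustration, not the repository's Triton implementation; the function name `select_blocks_by_cumulative_attention`, the pooled per-block `block_scores` input, and the `min_blocks` floor are all assumptions for the sketch.

```python
import numpy as np

def select_blocks_by_cumulative_attention(block_scores, gamma=0.95, min_blocks=2):
    """Pick the smallest set of key blocks whose softmax-normalised
    attention mass reaches the cumulative threshold gamma.

    block_scores : 1-D array of pooled query-key scores, one per key block.
    Returns the indices of the selected blocks, highest-mass first.
    """
    # Normalise pooled scores into an attention distribution over key blocks.
    probs = np.exp(block_scores - block_scores.max())
    probs /= probs.sum()
    # Greedily take the highest-probability blocks until the cumulative
    # attention mass crosses gamma.
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    k = int(np.searchsorted(cum, gamma)) + 1
    # Keep a minimum budget so every head computes something (the paper
    # similarly enforces at least 1024 tokens per head).
    k = max(k, min_blocks)
    return order[:k]
```

With two dominant blocks and γ = 0.9, the selection covers both of them and skips the near-zero rest; raising γ toward 1.0 monotonically enlarges the computed index set, which is how γ trades speed against accuracy.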
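For reference, the reported hyperparameters can be collected into a single configuration fragment. The dictionary keys below are illustrative names chosen for this summary, not the actual config schema of the FlexPrefill repository.

```python
# Hypothetical configuration mirroring the reported experiment setup.
FLEXPREFILL_SETUP = {
    "block_size": 128,              # Triton block size, all experiments
    "tau": 0.1,                     # sparse pattern threshold, all models
    "gamma": {                      # cumulative attention threshold, per model
        "Yi-9B-200K": 0.90,
        "Qwen2-7B-Instruct": 0.90,
        "Meta-Llama-3.1-8B-Instruct": 0.95,
        "GLM-4-9B-Chat": 0.95,
    },
    "min_tokens_per_head": 1024,    # each head computes at least this many tokens
    "keep_first_last_key_blocks": True,  # retained for every query block
    "decoding": "greedy",
    "hardware": "1x NVIDIA A100 80GB",
}
```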