OmniKV: Dynamic Context Selection for Efficient Long-Context LLMs

Authors: Jitai Hao, Yuke Zhu, Tian Wang, Jun Yu, Xin Xin, Bo Zheng, Zhaochun Ren, Sheng Guo

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Extensive experiments demonstrate that OmniKV achieves state-of-the-art performance across multiple benchmarks, with particular advantages in chain-of-thought scenarios. OmniKV extends the maximum context length supported by a single A100 for Llama-3-8B from 128K to 450K. Our code is available at https://github.com/antgroup/OmniKV.git
Researcher Affiliation | Collaboration | Jitai Hao2, Yuke Zhu1, Tian Wang1, Jun Yu3, Xin Xin2, Bo Zheng1, Zhaochun Ren4, Sheng Guo1; 1MYbank, Ant Group; 2Shandong University; 3Harbin Institute of Technology; 4Leiden University
Pseudocode | Yes | Algorithm 1: Attention Forward of OmniKV
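Algorithm 1 itself is not reproduced in this report. As a rough illustration of the dynamic context selection it describes (a few dense "filter" layers attend over the full KV cache and pick a token budget that the remaining layers reuse for sparse attention), here is a minimal single-query, single-head NumPy sketch. The function names and the exact selection rule are assumptions, not the authors' implementation:

```python
import numpy as np

def select_context(scores, budget):
    """Pick the `budget` highest-scoring token positions (hypothetical helper)."""
    k = min(budget, scores.shape[-1])
    idx = np.argpartition(-scores, k - 1)[:k]
    return np.sort(idx)

def attention_forward(q, K, V, layer, filter_layers, budget, state):
    """One layer's attention; non-filter layers reuse the last selected subset."""
    if state.get("idx") is not None and layer not in filter_layers:
        K, V = K[state["idx"]], V[state["idx"]]  # sparse: attend to selected tokens only
    scores = (q @ K.T) / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    if layer in filter_layers:                    # dense filter layer: re-select context
        state["idx"] = select_context(scores, budget)
    return w @ V
```

In this sketch the token budget plays the role of the 2048-token sparse-attention budget quoted later in the setup; the real method operates on full multi-head caches during decoding.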
Open Source Code Yes Our code is available at https://github.com/antgroup/OmniKV.git
Open Datasets | Yes | To demonstrate the effectiveness of OmniKV, we conducted extensive experiments on Llama-3-8B-262K, Yi-9B-200K and Llama-3.1-70B-Instruct, mainly using the datasets InfiniteBench (Zhang et al., 2024a) and LongBench (Bai et al., 2023).
Dataset Splits | No | To test OmniKV's performance in single-step reasoning, we primarily used two widely applied benchmarks: 1) InfiniteBench (Zhang et al., 2024a), with an average length of 145.1K, covering multiple tasks. We uniformly adopted a 128K context for testing and truncated inputs exceeding 128K at the middle. 2) We tested LongBench's 18 tasks across multiple categories, with most tasks' average length ranging from 5K to 15K (Bai et al., 2023). During testing, all models supported a context length longer than the longest sample, eliminating the need for truncation.
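The quote says over-length inputs were "truncated at the middle"; a common way to do this keeps the head and tail of the token sequence and drops the middle. A minimal sketch (the helper name and the exact head/tail split are assumptions, not details confirmed by the paper):

```python
def truncate_middle(tokens, max_len=131072):
    """Keep the first and last halves of `tokens`, dropping the middle,
    so the result is at most `max_len` tokens long."""
    if len(tokens) <= max_len:
        return tokens
    head = max_len // 2
    tail = max_len - head
    return tokens[:head] + tokens[len(tokens) - tail:]
```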
Hardware Specification | Yes | All performance and latency experiments were conducted on Nvidia A100 GPUs. Llama-3.1-70B utilized 4-bit weight quantization via bitsandbytes (Dettmers et al., 2021), while other models employed float16 formatting. We made minor modifications based on Hugging Face's transformers (Wolf et al., 2020). For exponential and uniform context selectors, we set the local window size to 16. The filter layers L are set to {2, 8, 18}, {6, 11, 30}, and {4, 19, 41} for Llama-3-8B-262K, Yi-9B-200K, and Llama-3.1-70B-Instruct, respectively. We evaluated the end-to-end inference latency of OmniKV using a single NVIDIA A100 80GB GPU and 12 cores of an Intel Xeon Platinum 8369B CPU at 2.90GHz.
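The quoted 4-bit loading via bitsandbytes can be expressed as a short transformers configuration fragment. This is a sketch under the assumption that the standard `BitsAndBytesConfig` path was used; the checkpoint id and compute dtype are assumptions, not values confirmed by the paper:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit weight quantization via bitsandbytes, float16 compute (assumed settings)
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-70B-Instruct",  # assumed checkpoint id
    quantization_config=bnb,
    device_map="auto",
)
```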
Software Dependencies | No | Llama-3.1-70B utilized 4-bit weight quantization via bitsandbytes (Dettmers et al., 2021), while other models employed float16 formatting. We made minor modifications based on Hugging Face's transformers (Wolf et al., 2020). For exponential and uniform context selectors, we set the local window size to 16. The filter layers L are set to {2, 8, 18}, {6, 11, 30}, and {4, 19, 41} for Llama-3-8B-262K, Yi-9B-200K, and Llama-3.1-70B-Instruct, respectively.
Experiment Setup | Yes | For most tasks, we adopted greedy decoding. To prevent repetitive outputs, we employed top-p decoding with p = 0.95 and temperature = 0.8 for summarization tasks in InfiniteBench. All performance and latency experiments were conducted on Nvidia A100 GPUs. Llama-3.1-70B utilized 4-bit weight quantization via bitsandbytes (Dettmers et al., 2021), while other models employed float16 formatting. We made minor modifications based on Hugging Face's transformers (Wolf et al., 2020). For exponential and uniform context selectors, we set the local window size to 16. The filter layers L are set to {2, 8, 18}, {6, 11, 30}, and {4, 19, 41} for Llama-3-8B-262K, Yi-9B-200K, and Llama-3.1-70B-Instruct, respectively. We evaluated the end-to-end inference latency of OmniKV using a single NVIDIA A100 80GB GPU and 12 cores of an Intel Xeon Platinum 8369B CPU at 2.90GHz. A 32-layer LLaMA-3-8B-262K model was utilized, with filter layers L = {2, 8, 18}, applying flash attention (Dao et al., 2022), and a batch size of 1. For different context length settings, the token budget for sparse attention was set to 2048. During decoding, 50 tokens were generated, and the mean latency per token was calculated over all decoding steps.
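The latency protocol quoted above (generate 50 tokens, report mean per-token latency over all decoding steps) can be sketched with a simple timing loop. The helper name and the idea of passing a one-token decode step as a callable are assumptions for illustration, not the authors' harness:

```python
import time

def mean_decode_latency(step_fn, n_tokens=50):
    """Mean per-token latency: time `step_fn` (one decoding step that emits
    a single token) over `n_tokens` steps and average."""
    times = []
    for _ in range(n_tokens):
        t0 = time.perf_counter()
        step_fn()
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)
```

In a real harness `step_fn` would wrap one forward pass of the model with its KV cache, and GPU timing would need a synchronization point before each `perf_counter` call.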