FlashRNN: I/O-Aware Optimization of Traditional RNNs on Modern Hardware
Authors: Korbinian Pöppel, Maximilian Beck, Sepp Hochreiter
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We show that our kernels can achieve 50x speed-ups over a vanilla PyTorch implementation and allow 40x larger hidden sizes compared to our Triton implementation. 6 EXPERIMENTS In Section 6.1, we benchmark the runtime of our FlashRNN kernels and compare against the LSTM and Attention implementations provided in PyTorch. In Section 6.2, we measure training time with FlashRNN kernels on language modeling. Finally, in Section 6.3 we confirm that traditional RNNs like LSTM and sLSTM implemented in FlashRNN can solve state tracking problems. |
| Researcher Affiliation | Collaboration | Korbinian Pöppel & Maximilian Beck & Sepp Hochreiter, Johannes Kepler University, NXAI Lab and NXAI GmbH, Altenberger Str. 69, A-4040 Linz, Austria, EMAIL |
| Pseudocode | Yes | Algorithm 1 FlashRNN-fused forward pass ... Algorithm 2 FlashRNN-fused forward pass ... Algorithm 3 ConstrINT Resolution ... Algorithm 4 ConstrINT Global ARC-Reduce ... Algorithm 5 Triton FlashRNN Forward Pass |
| Open Source Code | Yes | We have open-sourced our kernels and the optimization library to boost research in the direction of state-tracking enabled RNNs and sequence modeling here: https://github.com/NX-AI/flashrnn. ... The code is released here: https://github.com/NX-AI/flashrnn, with the ConstrINT library in flashrnn/autotune/constrint.py as single-file implementation and optional caching. |
| Open Datasets | Yes | We use an open dataset (SlimPajama) that uses publicly crawled internet data for Language Model training. ... For Language Modeling this setup description is provided in Appendix Section J and uses the open SlimPajama dataset |
| Dataset Splits | Yes | For the parity task we train with varying training sequence lengths up to 40. For the reported validation, we evaluate on sequence lengths between 40 and 256. Sequence lengths are uniformly sampled in the respective ranges. |
| Hardware Specification | Yes | For every runtime measurement we do 25 warmup iterations and then report the average across 1000 iterations on NVIDIA H100 GPUs. ... For the A100 experiments, we use one node of eight A100-SXM (80GB) GPUs and a local batch size of 64. For H100-SXM we reduce the local batch size to 32 and use 2 gradient accumulation steps due to out-of-memory errors, even though they have the same HBM size (80 GB). |
| Software Dependencies | Yes | We use PyTorch 2.4 with CUDA version 12.4 for our experiments. ... We use PyTorch in version 2.4.0 and CUDA 12.1 for A100 and 12.4 for H100s. |
| Experiment Setup | Yes | We fix the embedding dimension d = N_H · D_H to 768 and vary the head dimension from 16 to 768. We use batch size 16 and sequence length 1024. ... We train with context length 1024 and a global batch size of 512, resulting in roughly 30k training steps for 15B tokens of the SlimPajama dataset. We use the GPT-2 tokenizer and learning rate 1e-3 with linear warmup over 750 steps and cosine decay to 1e-4 over 30k steps. |
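The measurement protocol quoted under Hardware Specification (25 warmup iterations, then the mean over 1000 timed iterations) can be sketched as a small timing harness. This is an illustrative reconstruction, not the paper's benchmark code; on a GPU one would additionally need a synchronization call (e.g. `torch.cuda.synchronize()`) before each timestamp, which is omitted here to keep the sketch framework-agnostic.

```python
import time


def benchmark(fn, warmup=25, iters=1000):
    """Average wall-clock runtime of fn() after a warmup phase.

    Mirrors the quoted protocol: run `warmup` untimed iterations to
    amortize one-off costs (JIT compilation, cache population), then
    report the mean over `iters` timed iterations.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters
```

For example, `benchmark(lambda: model(x))` would return the average per-call latency in seconds after the warmup phase.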
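The parity-task split description (training lengths up to 40, validation lengths uniformly sampled between 40 and 256) can be sketched as a data generator. The function name and the 0/1 bit encoding with a sum-mod-2 target are illustrative assumptions; the paper's actual data pipeline is not specified in the excerpt.

```python
import random


def parity_batch(n, min_len, max_len, seed=None):
    """Generate n parity examples with uniformly sampled lengths.

    Each example is a (bits, target) pair where `bits` is a random 0/1
    sequence whose length is drawn uniformly from [min_len, max_len]
    and `target` is the parity (XOR) of all bits.
    """
    rng = random.Random(seed)
    batch = []
    for _ in range(n):
        length = rng.randint(min_len, max_len)
        bits = [rng.randint(0, 1) for _ in range(length)]
        target = sum(bits) % 2  # parity of the whole sequence
        batch.append((bits, target))
    return batch


# Following the quoted split: train lengths up to 40, validate on 40-256.
train_data = parity_batch(512, 1, 40, seed=0)
val_data = parity_batch(512, 40, 256, seed=1)
```

Evaluating on sequences much longer than those seen in training is what makes this a length-generalization test of state tracking.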