Implicit Language Models are RNNs: Balancing Parallelization and Expressivity

Authors: Mark Schöne, Babak Rahmani, Heiner Kremer, Fabian Falck, Hitesh Ballani, Jannes Gladrow

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Empirically, we find that only approximate fixed-point convergence suffices, enabling the design of a scalable training curriculum that largely retains parallelization, with full convergence required only for a small subset of tokens. Our approach demonstrates superior state-tracking capabilities on regular languages, surpassing transformers and SSMs. We further scale implicit SSMs to natural language reasoning tasks and pretraining of large-scale language models up to 1.3B parameters on 207B tokens, representing, to our knowledge, the largest implicit model trained to date. Notably, our implicit models outperform their explicit counterparts on standard benchmarks. Our code is publicly available at github.com/microsoft/implicit_languagemodels.
Researcher Affiliation | Collaboration | 1Chair of Highly-Parallel VLSI Systems and Neuro-Microelectronics, TUD Dresden University of Technology, Dresden, Germany. 2Microsoft Research, Cambridge, United Kingdom. Correspondence to: Jannes Gladrow <EMAIL>.
Pseudocode | No | The paper describes its methods using mathematical equations and diagrams (e.g., Figure 5) to illustrate fixed-point iteration and gradient computation. However, it does not include a clearly labeled pseudocode or algorithm block with structured steps.
Open Source Code | Yes | Our code is publicly available at github.com/microsoft/implicit_languagemodels.
Open Datasets | Yes | These models are pretrained in an autoregressive manner for next-token prediction across all sizes on the D-PILE (Gao et al., 2020) dataset, which consists of 207B tokens. We use the catbAbI dataset (Schlag et al., 2021), a modified version of the bAbI dataset (Weston et al., 2015).
Dataset Splits | Yes | These models are pretrained in an autoregressive manner for next-token prediction across all sizes on the D-PILE (Gao et al., 2020) dataset, which consists of 207B tokens. For baselines, we use both Mamba2 (Dao & Gu, 2024) and Llama (Beck et al., 2024) models previously trained on a corpus of 300B tokens. Additionally, we reproduce Mamba2 and Llama as baselines trained with the same code and data as our implicit models. We evaluate the pretrained models on the test set of the D-PILE, examine their length extrapolation capabilities, and assess their common sense reasoning performance on downstream tasks. See Appendix D.3 for pretraining details. The test split represents a random selection of 0.1 percent of the entire dataset. This size is in line with the proportion used for the PILE's validation set (Gao et al., 2020). All models were trained with an effective batch size of 1M tokens and a training sequence length of 2048 tokens. We tested models trained with 70, 80, and 90 percent of the bounded phase duration before starting the full fixed-point search in the free phase; refer to Fig. 14a.
Hardware Specification | Yes | All word problem models were trained on sequences of length L = 256 and a batch size of 512 on 32 GB V100s. We trained our suite of models on a cluster with AMD Instinct MI300X GPUs. Each node within the cluster comprises 8 GPUs, and we employed distributed multi-node processing to train our models on up to 32 GPUs simultaneously. The evaluation of models on downstream tasks was performed on one 80 GB Nvidia H100 GPU.
Software Dependencies | No | The paper mentions software components like 'AdamW' and 'RMSNorm' and refers to 'PyTorch' in the context of automatic differentiation, but it does not specify version numbers for any software dependencies required to reproduce the experiments.
Experiment Setup | Yes | The learning rate is set to 0.001. We disable dropout and weight decay, which appear to harm learning on the word problem. We trained the models using batch sizes of 128 and 256, and learning rates of 0.0001, 0.0005, 0.001, and 0.005. The models were trained for 15,000 steps, with the implicit model specifically trained for 5,000 steps in unrolling mode, utilizing 32 steps with normal gradient checkpointing, followed by 10,000 steps of self-iteration fixed-point search. The self-iteration used a stop threshold of 0.03, a maximum number of steps of 50 (training) and 200 (testing), respectively, and phantom gradient parameters of 6 steps with λ = 0.5. In particular, we used a weight decay of 0.1, no bias for the LLM head, AdamW hyperparameters β = (0.9, 0.95), RMSNorm instead of LayerNorm, and a linear warm-up to the peak learning rate, which is chosen as 5 times the value of the GPT-3 model. For the learning rate scheduler, we used a constant learning rate followed by a square-root decay to a minimum value of 10^-5... All models were trained with an effective batch size of 1M tokens and a training sequence length of 2048 tokens. In the bounded stage, we train with four self-iterations and a single step of phantom gradient, which we refer to as the (4 + 1)-model... The free stage starts from a checkpoint of the (4 + 1)-model and increases the number of self-iterations to 24/32, followed by four steps of phantom gradient. We refer to these models as (24 + 4)/(32 + 4)-models for Mamba2/Llama, respectively. A stopping criterion of ε = 0.05 is implemented during this second phase, allowing models to terminate the fixed-point search once this threshold is met.
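The "(self-iterations + phantom gradient)" scheme quoted above can be sketched in PyTorch. This is an illustrative reconstruction, not the authors' implementation: the function name `fixed_point_forward`, the damped-update form of the phantom gradient, and the relative-residual stopping criterion are assumptions based on the description in the Experiment Setup and Dataset Splits rows.

```python
import torch

def fixed_point_forward(f, x, z0, max_steps=24, eps=0.05,
                        phantom_steps=4, lam=0.5):
    """Sketch of the two-phase fixed-point solve described in the review.

    Phase 1: gradient-free self-iteration z <- f(z, x) until the relative
    residual falls below the stop threshold eps (or max_steps is reached).
    Phase 2: a few damped "phantom gradient" steps tracked by autograd,
    so backpropagation only unrolls through these final iterations.
    """
    z = z0
    with torch.no_grad():
        for _ in range(max_steps):
            z_next = f(z, x)
            # relative residual ||f(z) - z|| / ||z|| as stopping criterion
            rel = torch.norm(z_next - z) / (torch.norm(z) + 1e-8)
            z = z_next
            if rel < eps:
                break
    z = z.detach()  # cut the graph; gradients flow only through phase 2
    for _ in range(phantom_steps):
        # damped update: z <- (1 - lam) * z + lam * f(z, x)
        z = (1 - lam) * z + lam * f(z, x)
    return z
```

With a toy contraction such as `f(z, x) = 0.5 * z + x` (fixed point `z* = 2x`), the solver converges to within the `eps` tolerance, and calling `backward()` on the output produces an approximate gradient with respect to `x` through the phantom steps only, which is what keeps memory cost independent of the number of self-iterations.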