Improving Transformer World Models for Data-Efficient RL

Authors: Antoine Dedieu, Joseph Ortiz, Xinghua Lou, Carter Wendelken, J Swaroop Guntupalli, Wolfgang Lehrach, Miguel Lazaro-Gredilla, Kevin Patrick Murphy

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We present an approach to model-based RL that achieves new state-of-the-art performance on the challenging Craftax-classic benchmark, an open-world 2D survival game that requires agents to exhibit a wide range of general abilities, such as strong generalization, deep exploration, and long-term reasoning. With a series of careful design choices aimed at improving sample efficiency, our MBRL algorithm achieves a reward of 69.66% after only 1M environment steps, significantly outperforming DreamerV3, which achieves 53.2%, and, for the first time, exceeds human performance of 65.0%. Our method starts by constructing a SOTA model-free baseline, using a novel policy architecture that combines CNNs and RNNs. We then add three improvements to the standard MBRL setup: (a) Dyna with warmup, which trains the policy on both real and imaginary data; (b) a nearest-neighbor tokenizer on image patches, which improves the scheme used to create the transformer world model (TWM) inputs; and (c) block teacher forcing, which allows the TWM to reason jointly about the future tokens of the next timestep. In this section, we report our experimental results on the Craftax-classic benchmark. Each experiment is run on 8 H100 GPUs.
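To make improvement (b) concrete, here is a minimal sketch of a nearest-neighbor patch tokenizer. This is an illustrative reconstruction, not the authors' code: an image is split into patches, each patch is mapped to the index of its nearest codebook entry (squared Euclidean distance), and patches farther than a threshold from every entry are added to the codebook, so the vocabulary grows online rather than being learned like a VQ-VAE's. The function name and `threshold` parameter are assumptions.

```python
def tokenize_patches(patches, codebook, threshold):
    """Assign each patch the index of its nearest codebook entry.

    patches:  list of flat patch vectors (lists of floats)
    codebook: mutable list of stored patch vectors; grows in place
    threshold: max squared distance to reuse an existing entry
    """
    tokens = []
    for p in patches:
        if codebook:
            # Squared Euclidean distance to every stored patch.
            dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in codebook]
            best = min(range(len(dists)), key=dists.__getitem__)
            if dists[best] <= threshold:
                tokens.append(best)
                continue
        # No close-enough entry: grow the codebook with this new patch.
        codebook.append(p)
        tokens.append(len(codebook) - 1)
    return tokens
```

In practice such a tokenizer trades the VQ-VAE's learned, fixed-size codebook for a training-free, data-dependent one, which is the sample-efficiency argument the row above alludes to.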
Researcher Affiliation | Industry | Google DeepMind. Correspondence to: Antoine Dedieu <EMAIL>, Joseph Ortiz <EMAIL>.
Pseudocode | Yes | Algorithm 2 (Appendix A.1) presents pseudocode for our MFRL agent. Algorithm 1 presents the pseudocode for our MBRL approach. Algorithm 3 details the PPO-update-policy procedure, which is called in Steps 1 and 4 of our main Algorithm 1 to update the PPO parameters on a batch of trajectories. Algorithm 4 presents the rollout method, which we call in Steps 1 and 4 of Algorithm 1. Step 3 of Algorithm 1 is implemented as in Algorithm 5.
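The step structure quoted above (real rollouts and PPO updates, world-model training, then imagined rollouts after a warm-up) can be sketched as a plain-Python control loop. This is a hedged reconstruction of the "Dyna with warmup" schedule, not the authors' Algorithm 1; all function names are illustrative placeholders passed in as callables.

```python
def mbrl_loop(env_rollout, ppo_update, train_wm, imagine,
              t_total, t_bp, t_env, n_env):
    """Illustrative Dyna-with-warmup training loop.

    t_total: total environment-step budget
    t_bp:    warm-up steps before the policy trains on imagined data
    t_env:   rollout length per environment; n_env: parallel environments
    """
    steps = 0
    while steps < t_total:
        real_batch = env_rollout(t_env, n_env)  # Step 1: collect real trajectories
        ppo_update(real_batch)                  #         and update the policy on them
        train_wm(real_batch)                    # Step 3: fit the transformer world model
        if steps >= t_bp:                       # Dyna warm-up: imagine only after t_bp steps
            ppo_update(imagine())               # Step 4: update the policy on imagined rollouts
        steps += t_env * n_env
    return steps
```

The warm-up gate is the key design choice: early world-model rollouts are inaccurate, so the policy trains purely on real data until the model has seen enough environment steps.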
Open Source Code | No | The paper does not provide access to source code for the methodology it describes. It refers to third-party code used for baselines or components, such as the 'purejaxrl library (Lu et al., 2022)' and the 'IRIS VQ-VAE (Micheli et al., 2022)', but does not state that its own implementation code is released.
Open Datasets | Yes | To evaluate sample-efficient RL algorithms, it is common to use the Atari-100k benchmark (Kaiser et al., 2019). Specifically, we use the Craftax-classic environment (Matthews et al., 2024), a fast, near-replica of Crafter, implemented in JAX (Bradbury et al., 2018). We conduct additional experiments on MinAtar (Young & Tian, 2019), another grid-world environment.
Dataset Splits | Yes | All methods are compared after interacting with the environment for T_total = 1M steps. All methods collect trajectories of length T_env = 96 in N_env = 48 environments in parallel. For MBRL methods, the imaginary rollouts are of length T_WM = 20, and we start generating these (for policy training) after T_BP = 200k environment steps. We update the TWM N_WM^iters = 500 times and the policy N_AC^iters = 150 times. For evaluation, we leverage an appealing property of Craftax-classic: each observation O_t comes with an array of ground-truth symbols S_t = (S_t^{1:R}), with R = 145. Given 100k pairs (O_t, S_t), we train a CNN f_µ to predict the symbols from the observation; f_µ achieves a 99% validation accuracy. Table 7 summarizes the main parameters used in our MBRL training pipeline.
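The interaction budget quoted above implies a simple piece of arithmetic worth making explicit: each collection phase gathers T_env steps in each of N_env parallel environments, so the 1M-step budget is spent over a fixed number of phases. The variable names below are illustrative.

```python
# Budget arithmetic for the quoted settings: T_env = 96, N_env = 48, T_total = 1M.
T_TOTAL = 1_000_000
T_ENV, N_ENV = 96, 48

steps_per_phase = T_ENV * N_ENV               # environment steps per collection phase
num_phases = -(-T_TOTAL // steps_per_phase)   # ceiling division: phases needed to reach 1M
```

So each phase contributes 4,608 real environment steps, and reaching the 1M-step comparison point takes roughly 218 collection phases.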
Hardware Specification | Yes | Each experiment is run on 8 H100 GPUs. Our MFRL agent only takes 15 minutes to train for 1M environment steps on one A100 GPU.
Software Dependencies | Yes | JAX: composable transformations of Python+NumPy programs, 2018. Note that for implementing PPO, we start from the code available in the purejaxrl library (Lu et al., 2022). We use flashbax (Toledo et al., 2023) to implement our replay buffer in JAX.
Experiment Setup | Yes | Each experiment is run on 8 H100 GPUs. All methods are compared after interacting with the environment for T_total = 1M steps. All methods collect trajectories of length T_env = 96 in N_env = 48 environments in parallel. For MBRL methods, the imaginary rollouts are of length T_WM = 20, and we start generating these (for policy training) after T_BP = 200k environment steps. We update the TWM N_WM^iters = 500 times and the policy N_AC^iters = 150 times. Table 5 displays the PPO hyperparameters used for training our SOTA MFRL agent. Table 6 details the hyperparameters of the transformer world model. Table 7 summarizes the main parameters used in our MBRL training pipeline.
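For reference, the quantities quoted in this row can be collected into a single configuration fragment. The key names are hypothetical (the paper's Tables 5-7 define the full hyperparameter sets); only the values come from the text above.

```python
# Hypothetical config fragment; values are the ones quoted in this report.
MBRL_CONFIG = {
    "t_total": 1_000_000,   # total environment interaction budget (steps)
    "t_env": 96,            # trajectory length collected per environment
    "n_env": 48,            # parallel environments
    "t_wm": 20,             # imagined (world-model) rollout length
    "t_bp": 200_000,        # warm-up steps before policy trains on imagined data
    "n_iters_wm": 500,      # TWM updates per training phase
    "n_iters_ac": 150,      # policy (actor-critic) updates per training phase
}
```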