Scaling Offline Model-Based RL via Jointly-Optimized World-Action Model Pretraining
Authors: Jie Cheng, Ruixi Qiao, Yingwei Ma, Binhua Li, Gang Xiong, Qinghai Miao, Yongbin Li, Yisheng Lv
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experimental results indicate that our largest agent, with 150 million parameters, achieves 78.9% human-level performance on pretrained games using only 10% subsampled offline data, outperforming existing state-of-the-art large-scale offline RL baselines by 71.4% on average. Furthermore, JOWA scales favorably with model capacity and can sample-efficiently transfer to novel games using only 5k offline fine-tuning data (approximately 4 trajectories) per game, demonstrating superior generalization. Our ablation studies highlight the significance of two key design features of JOWA: joint optimization and planning, along with other training choices. |
| Researcher Affiliation | Collaboration | Jie Cheng1,2, Ruixi Qiao1,2, Yingwei Ma3, Binhua Li3, Gang Xiong1,2, Qinghai Miao2, Yongbin Li3, Yisheng Lv1,2 — 1State Key Laboratory of Multimodal Artificial Intelligence Systems, Institute of Automation, Chinese Academy of Sciences; 2School of Artificial Intelligence, University of Chinese Academy of Sciences; 3Alibaba Group |
| Pseudocode | No | The paper describes the planning algorithm in Section 4.2 'PARALLELIZABLE PLANNING AT INFERENCE TIME' using paragraph text, but does not include a clearly labeled pseudocode or algorithm block. |
| Open Source Code | Yes | The code and checkpoints will be released at https://github.com/CJReinforce/JOWA. |
| Open Datasets | Yes | Dataset. We use the Atari dataset from Agarwal et al. (2020), which consists of 50M transitions from each of 5 separate training runs. |
| Dataset Splits | Yes | 15 of those games are used for training, and 5 games are held out for OOD generalization experiments. Following Lee et al. (2022), we use data from 2 of the 5 training runs. To investigate performance in the low-data regime, we uniformly draw 10% of transitions at random, as per Agarwal et al. (2020), resulting in 10M transitions per game. For fine-tuning, we train models for 50k gradient steps with 5k transitions. We fine-tune pretrained agents on 5 held-out games using 5k uniformly subsampled expert-level transitions (from the last 20% of DQN-Replay (Agarwal et al., 2020)) per game as the benchmark. We additionally fine-tune JOWA-150M using 5k suboptimal and 5k highly-suboptimal transitions, uniformly sampled from the complete and the initial 20% of the DQN-Replay dataset, respectively. |
| Hardware Specification | Yes | The whole training process took around 12 days on A100 GPUs. |
| Software Dependencies | No | The paper mentions models/libraries like GPT-2 (Brown et al., 2020), min GPT2 with Flash Attention (Dao et al., 2022), IRIS (Micheli et al., 2022), EDT, and a PyTorch version of Scaled-QL, but does not specify version numbers for Python, PyTorch, or other specific software dependencies. |
| Experiment Setup | Yes | The sequence length L is set to 8. We pretrain all JOWA models on A100 GPUs for 1.75M steps. For fine-tuning, we train models for 50k gradient steps with 5k transitions. The training hyperparameters of JOWA are listed in Table 13, the evaluation settings of Atari in Table 16, and the planning hyperparameters for each game in Table 17. |
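The dataset-splits row describes three subsampling schemes: a uniform 10% draw for pretraining, 5k "expert-level" transitions from the last 20% of each DQN-Replay run, and 5k "highly-suboptimal" transitions from the initial 20%. A minimal sketch of these draws, assuming transitions are held in a Python list (the function names and list representation are illustrative, not from the JOWA codebase):

```python
import random

def uniform_subsample(transitions, fraction=0.10, seed=0):
    """Uniformly draw a fraction of transitions at random,
    as in the 10% low-data pretraining regime."""
    rng = random.Random(seed)
    k = int(len(transitions) * fraction)
    return rng.sample(transitions, k)

def tail_subset(replay_buffer, n=5000, tail_fraction=0.20, seed=0):
    """Sample n 'expert-level' transitions uniformly from the
    last tail_fraction of a replay buffer (late-training data)."""
    rng = random.Random(seed)
    start = int(len(replay_buffer) * (1 - tail_fraction))
    return rng.sample(replay_buffer[start:], n)

def head_subset(replay_buffer, n=5000, head_fraction=0.20, seed=0):
    """Sample n 'highly-suboptimal' transitions uniformly from the
    initial head_fraction of a replay buffer (early-training data)."""
    rng = random.Random(seed)
    end = int(len(replay_buffer) * head_fraction)
    return rng.sample(replay_buffer[:end], n)
```

Because DQN-Replay stores transitions in training order, slicing the buffer's tail or head before sampling is what distinguishes expert-level from highly-suboptimal data; the uniform draw over the complete buffer yields the "suboptimal" mixture.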