PWM: Policy Learning with Multi-Task World Models
Authors: Ignat Georgiev, Varun Giridhar, Nick Hansen, Animesh Garg
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our empirical evaluations on high-dimensional tasks indicate that PWM not only achieves higher reward than baselines but also outperforms methods that use ground-truth dynamics. In a multi-task scenario utilizing a pre-trained 48M-parameter world model from TD-MPC2, PWM achieves up to 27% higher reward than TD-MPC2 without relying on online planning. This underscores the efficacy of PWM and supports our broader contributions: 1. Correlation between world model smoothness and policy performance: Through pedagogical examples and ablations, we demonstrate that smoother, better-regularized world models significantly enhance policy performance. Notably, this results in an inverse correlation between model accuracy and policy performance. 2. Efficiency of first-order gradient (FoG) optimization: We show that combining FoG optimization with well-regularized world models enables more efficient policy learning compared to zeroth-order methods. Furthermore, policies learned from world models asymptotically outperform those trained with ground-truth simulation dynamics, emphasizing the tight relationship between FoG optimization and world model design. 3. Scalable multi-task algorithm: Instead of training a single multi-task policy model, we propose PWM, a framework where a multi-task world model is first pre-trained on offline data. Per-task expert policies are then extracted in <10 minutes per task, offering a clear and scalable alternative to existing methods focused on unified multi-task models. |
| Researcher Affiliation | Academia | Ignat Georgiev¹, Varun Giridhar¹, Nicklas Hansen², Animesh Garg¹ — ¹Georgia Institute of Technology, ²UC San Diego |
| Pseudocode | Yes | Algorithm 1 (PWM: Policy learning with multi-task World Models). Given: multi-task dataset B; discount rate γ; learning rates αθ, αψ, αϕ. Initialize learnable parameters θ, ψ, ϕ. Pre-train world model once: for N epochs do: sample s1:H, a1:H, r1:H, e ∼ B; ϕ ← ϕ − αϕ∇ϕLwm(ϕ) (Eq. 10); end. Train policy on task embedding e: for M epochs do: sample s1 ∼ B; z1 = Eϕ(s1, e); for h = 1, ..., H do: ah ∼ πθ(·\|zh); rh = Rϕ(zh, ah, e); zh+1 = Fϕ(zh, ah, e); end; θ ← θ + αθ∇θLπ(θ) (Eq. 6); ψ ← ψ − αψ∇ψLV(ψ) (Eq. 7–9); end. |
| Open Source Code | Yes | Visualizations and code are available at imgeorgiev.com/pwm. |
| Open Datasets | Yes | We harness the same data and world model architecture as TD-MPC2. The data consists of 120k and 40k trajectories per dm_control and Meta-World task, respectively, generated by 3 random seeds of TD-MPC2 runs. The world models we use are the 48M-parameter models introduced in Hansen et al. (2024) with slight modifications to make them differentiable (Appendix C). Reproducibility statement. Code, training data and checkpoints are made available at imgeorgiev.com/pwm. We rely on dflex, Meta-World, DMControl and MuJoCo for simulation, which are publicly available under MIT and Apache 2.0 licenses. We use multi-task data from TD-MPC2, which is publicly available. Implementation details and a full list of hyper-parameters are available in Appendix C. |
| Dataset Splits | No | The paper describes the data used for training and evaluation (e.g., "120k and 40k trajectories per dm_control and Meta-World task" and "evaluate task performance for 10 seeds for each task"), and notes that policies are trained on "offline datasets." However, it does not explicitly provide training, validation, or test splits (e.g., percentages, absolute counts, or references to predefined splits) for these datasets in the experimental setup for PWM's policy learning. |
| Hardware Specification | Yes | Then we train a PWM policy on each particular task using the offline datasets for 10k gradient steps, which take 9.3 minutes on an Nvidia RTX6000 GPU. |
| Software Dependencies | No | The paper mentions several software components, such as "dflex, Meta-World, DMControl and MuJoCo for simulation" and the use of "MLPs with LayerNorm ... and Mish activation." However, it does not provide specific version numbers for any of these software dependencies. |
| Experiment Setup | Yes | Policy components: Horizon (H) = 16; Batch size = 64; αθ = 5×10⁻⁴; αψ = 5×10⁻⁴; Actor grad norm = 1; Critic grad norm = 100; Actor hidden layers = [400, 200, 100]; Critic hidden layers = [400, 200]; Number of critics = 3; λ = 0.95; γ = 0.99; Critic batch split = 4; Critic iterations = 8. World model components (48M): Latent state (z) dimension = 768; Horizon (H) = 16; Batch size = 1024; αϕ = 3×10⁻⁴; World model grad norm = 20.0; SimNorm V = 8; Reward bins = 101; Encoder Eϕ hidden layers = [1792, 1792, 1792]; Dynamics Fϕ hidden layers = [1792, 1792]; Reward Rϕ hidden layers = [1792, 1792]; Task encoding dimension = 96. |
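The policy-training loop of Algorithm 1 can be sketched as follows. This is a minimal toy sketch, not the paper's implementation: `encode`, `dynamics`, `reward`, and the linear deterministic `policy` are hypothetical stand-ins for the frozen world-model components Eϕ, Fϕ, Rϕ and the actor πθ (in the paper these are large MLPs with LayerNorm and Mish, and the return uses a TD(λ) value bootstrap omitted here). It shows only the structure that matters for first-order gradient (FoG) optimization: an H-step imagined rollout through a differentiable model, accumulating a discounted return as the policy objective.

```python
import numpy as np

# Hypothetical stand-ins for the frozen world-model components
# (the real E_phi, F_phi, R_phi are 48M-parameter networks).
def encode(s, e):       # E_phi: state + task embedding -> latent z
    return np.tanh(s + e)

def dynamics(z, a, e):  # F_phi: differentiable latent transition
    return 0.9 * z + 0.1 * (a + e)

def reward(z, a, e):    # R_phi: scalar reward head
    return float(np.sum(z) - 0.01 * np.sum(a * a))

def policy(z, theta):   # pi_theta: deterministic linear actor (toy)
    return theta @ z

def rollout_return(s1, e, theta, H=16, gamma=0.99):
    """H-step imagined discounted return -- the FoG policy objective.

    In the paper, gradients of this quantity w.r.t. theta flow back
    through F_phi and R_phi (first-order gradients); here the rollout
    itself is shown, with autodiff left out.
    """
    z, ret, disc = encode(s1, e), 0.0, 1.0
    for _ in range(H):
        a = policy(z, theta)
        ret += disc * reward(z, a, e)
        z = dynamics(z, a, e)
        disc *= gamma
    return ret
```

In the full method this objective is differentiated with an autodiff framework and θ is updated by gradient ascent for 10k steps per task, while ϕ stays frozen after the one-time multi-task pre-training.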