Near-optimal Sketchy Natural Gradients for Physics-Informed Neural Networks

Authors: Maricela Best Mckay, Avleen Kaur, Chen Greif, Brian Wetton

ICML 2025

Reproducibility Variable: Result / LLM Response

Research Type: Experimental
  We develop a randomized algorithm for natural gradient descent for PINNs that uses sketching to approximate the natural gradient descent direction. We prove that the change-of-coordinates Gram matrix used in a natural gradient descent update has rapidly decaying eigenvalues for a one-layer, one-dimensional neural network and empirically demonstrate that this structure holds for four different example problems. Under this structure, our sketching algorithm is guaranteed to provide a near-optimal low-rank approximation of the Gramian. Our algorithm dramatically speeds up computation time and reduces memory overhead. Additionally, in our experiments, the sketched natural gradient outperforms the original natural gradient in terms of accuracy, often achieving an error that is an order of magnitude smaller.

Researcher Affiliation: Academia
  (1) Department of Mathematics, University of British Columbia, BC, Canada; (2) Department of Computer Science, University of British Columbia, BC, Canada. Correspondence to: Maricela Best Mckay <EMAIL>.

Pseudocode: Yes
  Algorithm 1: Sketchy Natural Gradient Descent (SNGD)

Open Source Code: Yes
  Code to reproduce the experiments in this manuscript is available at https://github.com/MaricelaM/ICML25SNGD.git.

Open Datasets: No
  The paper uses standard partial differential equations (PDEs) for its experiments (the heat equation, Poisson's equation, a non-linear boundary-value problem, and the transport equation). These are described mathematically, but no external, publicly available datasets are used or referenced with access information in the usual machine-learning sense.

Dataset Splits: No
  The paper solves PDEs with Physics-Informed Neural Networks (PINNs), which typically use collocation points drawn from the problem domain rather than pre-defined dataset splits. It does not mention explicit training/validation/test splits in the conventional machine-learning sense.

Hardware Specification: Yes
  All experiments were run on Google Colaboratory using an NVIDIA A100 GPU.

Software Dependencies: No
  The paper mentions several software libraries: Python, JAX, Equinox, Optax, and JAXopt. However, it does not provide version numbers for any of these components, which is necessary for a reproducible description of ancillary software.

Experiment Setup: Yes
  The first two architectures are trained for 2,500 iterations for ENGD and SNGD, and 3,000 for BFGS. The last architecture is trained using SNGD for 1,000 iterations. ... ADAM is run for 100,000 iterations for all network sizes. ... We thus choose tol to be 1E-13, which we have found to be performant and accurate enough. ... All experiments were run using double precision.
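To make the core idea concrete, the kind of sketched natural-gradient step the paper describes can be illustrated with a minimal randomized low-rank solver. The sketch below is not the authors' Algorithm 1; it is a generic randomized range-finder applied to a synthetic symmetric positive semi-definite "Gram matrix" with rapidly decaying eigenvalues (the structure the paper establishes for PINN Gramians), using a relative pseudoinverse tolerance of 1E-13 as in the paper's experimental setup. All function and variable names here are illustrative.

```python
import numpy as np

def sketched_natural_gradient(G, g, sketch_size, tol=1e-13, seed=0):
    """Approximate the natural-gradient step x = G^+ g via a randomized sketch.

    Assumes G is symmetric positive semi-definite with rapidly decaying
    eigenvalues. This is a generic randomized range-finder for illustration,
    not the paper's Algorithm 1.
    """
    rng = np.random.default_rng(seed)
    n = G.shape[0]
    Omega = rng.standard_normal((n, sketch_size))  # Gaussian test matrix
    Q, _ = np.linalg.qr(G @ Omega)                 # orthonormal basis for the sketched range of G
    B = Q.T @ G @ Q                                # small (sketch_size x sketch_size) projection
    evals, V = np.linalg.eigh(B)
    keep = evals > tol * evals.max()               # truncated pseudoinverse, tol = 1E-13
    U = Q @ V[:, keep]                             # approximate dominant eigenvectors of G
    return U @ ((U.T @ g) / evals[keep])

# Toy problem: an exactly rank-20 PSD "Gramian" with geometric eigenvalue decay.
rng = np.random.default_rng(1)
n = 200
W, _ = np.linalg.qr(rng.standard_normal((n, n)))
lams = np.where(np.arange(n) < 20, 0.5 ** np.arange(n), 0.0)
G = (W * lams) @ W.T
g = rng.standard_normal(n)

x_sketch = sketched_natural_gradient(G, g, sketch_size=40)
x_ref = np.linalg.pinv(G, rcond=1e-12, hermitian=True) @ g  # dense reference solve
rel_err = np.linalg.norm(x_sketch - x_ref) / np.linalg.norm(x_ref)
```

When the eigenvalues decay rapidly, the sketch only needs a modest `sketch_size` to capture the dominant eigenspace, so the small projected eigenproblem replaces a dense solve in the full parameter dimension; this is the source of the speed and memory savings the paper reports.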