
Jigsaw: Training Multi-Billion-Parameter AI Weather Models With Optimized Model Parallelism
Authors: Deifilia Kieckhefen, Markus Götz, Lars H. Heyen, Achim Streit, and Charlotte Debus (Karlsruhe Institute of Technology, Helmholtz AI)
The paper introduces WeatherMixer (WM), a multi-layer perceptron (MLP)-based architecture designed for atmospheric forecasting, which serves as a competitive alternative to Transformer-based models. WM's workload scales linearly with input size, addressing the scaling challenges and quadratic computational complexity associated with the self-attention mechanism in Transformers when dealing with gigabyte-sized atmospheric data.
A novel parallelization scheme called Jigsaw parallelism is proposed, combining both domain parallelism and tensor parallelism to efficiently train multi-billion-parameter models. Jigsaw is optimized for large input data by fully sharding the data, model parameters, and optimizer states across devices, eliminating memory redundancy.
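To make the idea of sharding both the data and the parameters concrete, the following is a minimal sketch in pure Python. It is not the paper's actual Jigsaw layout. It assumes a single linear layer Y = XW whose contraction dimension is split across two simulated devices, so each device holds a distinct slice of the input and the matching slice of the weights, with no copies. The helper names (`shard_columns`, `shard_rows`, `all_reduce_sum`) are hypothetical, and the element-wise sum only stands in for a real all-reduce collective.

```python
# Toy illustration (an assumption, not the paper's scheme): shard the
# contraction dimension of Y = X @ W across simulated devices.

def matmul(a, b):
    """Plain matrix product of two lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def shard_columns(x, n):
    """Split the columns of x into n equal contiguous slices (input shards)."""
    width = len(x[0]) // n
    return [[row[k * width:(k + 1) * width] for row in x] for k in range(n)]

def shard_rows(w, n):
    """Split the rows of w into n equal contiguous slices (weight shards)."""
    height = len(w) // n
    return [w[k * height:(k + 1) * height] for k in range(n)]

def all_reduce_sum(partials):
    """Element-wise sum of the partial results, standing in for an all-reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]            # 2 samples x 4 input features
W = [[1, 0],
     [0, 1],
     [1, 1],
     [2, 0]]                  # 4 input features x 2 output features

n_devices = 2
x_shards = shard_columns(X, n_devices)   # device k only loads its input slice
w_shards = shard_rows(W, n_devices)      # device k only stores its weight slice

# Each device multiplies its local shards; the partial products are then summed.
partials = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
Y_sharded = all_reduce_sum(partials)
Y_reference = matmul(X, W)

print(Y_sharded == Y_reference)
```

In this sketch no device ever holds the full input or the full weight matrix, which is the property the summary describes as eliminating memory redundancy. The price is the communication step that sums the partial results.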
Jigsaw effectively mitigates hardware bottlenecks, particularly I/O-bandwidth limitations frequently encountered in training large scientific AI models. Due to its partitioned data loading (domain parallelism), the scheme achieves superscalar weak scaling in I/O-bandwidth-limited systems.
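The effect of partitioned data loading on an I/O-bound system can be estimated with simple arithmetic. The sketch below is a back-of-the-envelope model, not a measurement from the paper. It assumes that each device reads only its own slice of a sample and that the read bandwidth per device is fixed and independent of the other devices. The function names, the 1 GB sample size, and the 2 GB/s bandwidth are hypothetical values chosen for illustration.

```python
def bytes_read_per_device(sample_bytes, domain_parallel_degree):
    """Bytes each device must read when one sample is split evenly across devices."""
    return sample_bytes / domain_parallel_degree

def load_time_seconds(sample_bytes, domain_parallel_degree, bandwidth_bytes_per_s):
    """Idealized time to load one sample, ignoring latency and file-system contention."""
    return bytes_read_per_device(sample_bytes, domain_parallel_degree) / bandwidth_bytes_per_s

SAMPLE_BYTES = 1e9   # hypothetical 1 GB sample
BANDWIDTH = 2e9      # hypothetical 2 GB/s read bandwidth per device

for degree in (1, 2, 4):
    t = load_time_seconds(SAMPLE_BYTES, degree, BANDWIDTH)
    print(f"domain-parallel degree {degree}: {t:.3f} s per sample")
```

Under these assumptions the per-sample load time falls in proportion to the domain-parallel degree, whereas pure data parallelism leaves every device reading full samples.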
The method demonstrates excellent scaling behavior on high-performance computing systems, exceeding state-of-the-art performance in strong scaling in computation–communication-limited systems. The training was successfully scaled up to 256 GPUs, reaching peak performances of 9 and 11 PFLOPs.
Beyond hardware efficiency, Jigsaw improves predictive performance. Partitioning the model across more GPUs (model parallelism), instead of relying solely on data parallelism, naturally enforces smaller global batch sizes, which empirically helps mitigate the problematic large-batch effects observed in AI weather models and leads to lower loss values.
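The link between model parallelism and global batch size follows from simple counting. The sketch below assumes that GPUs are grouped into model-parallel groups of equal size and that each group acts as one data-parallel replica processing a fixed number of samples per step. The function name and the per-replica batch size of 1 are illustrative assumptions, not settings reported in the paper.

```python
def global_batch_size(total_gpus, model_parallel_degree, samples_per_replica=1):
    """Global batch size when each model-parallel group forms one data-parallel replica."""
    if total_gpus % model_parallel_degree != 0:
        raise ValueError("total_gpus must be divisible by model_parallel_degree")
    data_parallel_replicas = total_gpus // model_parallel_degree
    return data_parallel_replicas * samples_per_replica

# With 256 GPUs, a larger model-parallel degree leaves fewer replicas
# and therefore a smaller global batch.
for degree in (1, 4, 16):
    print(degree, global_batch_size(256, degree))
```

With pure data parallelism (degree 1) the 256 GPUs give a global batch of 256 samples, while a model-parallel degree of 4 or 16 reduces it to 64 or 16 at the same GPU count.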