MaxxVit (CoAtNet, MaxVit, and related experimental weights)
timm trained weights

Weights were created by reproducing the paper architectures and exploring timm-specific additions such as ConvNeXt blocks, parallel partitioning, and other experiments.
Weights were trained on a mix of TPU and GPU systems. The bulk of the weights were trained on TPUs via the TRC program (https://sites.research.google/trc/about/).
CoAtNet variants run particularly well on TPU; it's a great combination. MaxViT is better suited to GPU due to the window partitioning, although there are some optimizations that can improve TPU padding/utilization, including using a 256x256 image size with an (8, 8) window/grid size, and keeping the format in NCHW for partition attention when using PyTorch XLA.
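As an illustration of the partitioning referred to above, the sketch below shows block (window) and grid partitioning on an NCHW feature map. The helper names and shapes are mine for illustration, not timm's internals, and assume dimensions divisible by the window size (which holds at every stage for 256x256 inputs with an (8, 8) window/grid, so no padding is needed).

```python
import torch

def window_partition_nchw(x: torch.Tensor, window_size: int = 8):
    # Partition an NCHW feature map into non-overlapping windows (block attention).
    # Assumes H and W are divisible by window_size, true for 256x256 inputs at
    # strides 4/8/16/32 with an 8x8 window.
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    # -> (num_windows * B, window_size * window_size, C) token sequences per window
    return x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size * window_size, C)

def grid_partition_nchw(x: torch.Tensor, grid_size: int = 8):
    # Grid partitioning: tokens within each group are strided across the feature map,
    # giving each grid attention group a global (sparse) field of view.
    B, C, H, W = x.shape
    x = x.view(B, C, grid_size, H // grid_size, grid_size, W // grid_size)
    return x.permute(0, 3, 5, 2, 4, 1).reshape(-1, grid_size * grid_size, C)

# e.g. a 64x64 feature map (stride-4 stage of a 256x256 input) splits evenly into 8x8 windows.
feat = torch.randn(2, 96, 64, 64)
print(window_partition_nchw(feat).shape)  # torch.Size([128, 64, 96])
print(grid_partition_nchw(feat).shape)    # torch.Size([128, 64, 96])
```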
Glossary:
- coatnet - CoAtNet (MBConv + transformer blocks)
- coatnext - CoAtNet w/ ConvNeXt conv blocks
- maxvit - MaxViT (MBConv + block (ala swin) and grid partitioning transformer blocks)
- maxxvit - MaxViT w/ ConvNeXt conv blocks
- rmlp - relative position embedding w/ MLP (can be resized) -- if this isn't in the model name, it's using relative position bias (ala swin)
- rw - my variations on the models, with slight differences in sizing / pooling / etc. from the Google paper spec
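The naming components above appear verbatim in the timm model names, so the matching variants can be enumerated with `timm.list_models`. A small example (the wildcard patterns are just illustrations, and the exact results depend on the installed timm version):

```python
import timm

# Filter the model registry by the name components described in the glossary.
for pattern in ('coatnet*', 'coatnext*', 'maxvit_rmlp*', 'maxxvit*'):
    names = timm.list_models(pattern, pretrained=True)  # only names with released weights
    print(f'{pattern}: {len(names)} models, e.g. {names[:3]}')
```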
Results:

- maxvit_rmlp_pico_rw_256 - 80.5 @ 256, 81.3 @ 320 (T)
- coatnet_nano_rw_224 - 81.7 @ 224 (T)
- coatnext_nano_rw_224 - 82.0 @ 224 (G) -- (uses ConvNeXt block, no BatchNorm)
- coatnet_rmlp_nano_rw_224 - 82.0 @ 224, 82.8 @ 320 (T)
- coatnet_0_rw_224 - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks
- coatnet_bn_0_rw_224 - 82.4 (T) -- all BatchNorm, no LayerNorm
- maxvit_nano_rw_256 - 82.9 @ 256 (T)
- maxvit_rmlp_nano_rw_256 - 83.0 @ 256, 83.6 @ 320 (T)
- maxxvit_rmlp_nano_rw_256 - 83.0 @ 256, 83.7 @ 320 (G) -- (uses ConvNeXt conv block, no BatchNorm)
- coatnet_rmlp_1_rw_224 - 83.4 @ 224, 84.0 @ 320 (T)
- maxvit_tiny_rw_224 - 83.5 @ 224 (G)
- coatnet_1_rw_224 - 83.6 @ 224 (G)
- maxvit_rmlp_tiny_rw_256 - 84.2 @ 256, 84.8 @ 320 (T)
- maxvit_rmlp_small_rw_224 - 84.5 @ 224, 85.1 @ 320 (G)
- maxxvit_rmlp_small_rw_256 - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt conv block, no BN)
- coatnet_rmlp_2_rw_224 - 84.6 @ 224, 85.0 @ 320 (T)

(T) = TPU trained with bits_and_tpu branch training code, (G) = GPU trained
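Below is a minimal evaluation sketch for one of the weights listed above, using standard timm APIs to load the checkpoint and build the matching preprocessing. The model name is taken from the results list; the rest is generic usage under my assumptions, not a verbatim recipe from the training runs.

```python
import torch
import timm
from timm.data import resolve_data_config, create_transform

# Load one of the released checkpoints listed above (weights download on first use).
model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True).eval()

# Build the eval transform from the model's pretrained config (256x256 for *_256 models).
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config['input_size'])  # e.g. (3, 256, 256)

# Sanity-check a forward pass with a dummy batch.
with torch.no_grad():
    logits = model(torch.randn(1, *config['input_size']))
print(logits.shape)  # torch.Size([1, 1000])

# The 'rmlp' variants use a resizable relative-position MLP, so the higher-resolution
# numbers above (e.g. @ 320) come from evaluating at a larger image size; depending on
# the timm version this can typically be done by passing img_size=320 to create_model.
```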