AdaMAE: Adaptive Masking for Efficient Spatiotemporal Learning with Masked Autoencoders


In this paper, we propose an adaptive masking technique for MAEs. Our adaptive masking first estimates a categorical distribution over the input tokens of a given video by analyzing their spatiotemporal information, and then samples the visible tokens for the MAE from that distribution. Since the sampling process is non-differentiable, we introduce an auxiliary loss term to optimize the sampling network, in which we maximize the expected reconstruction loss, motivated by the REINFORCE algorithm in reinforcement learning. We empirically show that our adaptive masking samples more tokens from the foreground than from the background. We conduct ablation experiments on the SSv2 dataset to demonstrate the superiority of our adaptive sampling over existing techniques. Finally, we compare AdaMAE with SOTA results on action classification benchmarks, achieving 70.0% and 81.7% top-1 accuracy on SSv2 and K400, respectively, with a ViT-Base backbone, a masking ratio of 95%, and 800 pre-training epochs.
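
The sketch below illustrates the two pieces described above: sampling visible-token indices from the predicted categorical distribution, and a REINFORCE-style auxiliary loss that rewards tokens with high reconstruction error. It is a minimal PyTorch illustration; the function names, tensor shapes, and normalization are assumptions for exposition and do not reproduce the official AdaMAE implementation.

```python
import torch
import torch.nn.functional as F

def sample_visible_tokens(logits: torch.Tensor, num_visible: int):
    """Sample visible-token indices from a categorical distribution over tokens.

    logits: (B, N) unnormalized scores from the sampling network, one per token.
    num_visible: number of visible tokens, e.g. round(N * (1 - 0.95)) for a 95% mask ratio.
    """
    probs = F.softmax(logits, dim=-1)                    # (B, N) categorical distribution
    visible_idx = torch.multinomial(probs, num_visible)  # sample without replacement
    return probs, visible_idx

def sampling_loss(probs: torch.Tensor, mask_idx: torch.Tensor, recon_error: torch.Tensor):
    """REINFORCE-style auxiliary loss for the sampling network (a sketch, not the paper's exact form).

    probs:       (B, N) predicted categorical distribution.
    mask_idx:    (B, N_m) indices of the masked tokens.
    recon_error: (B, N_m) per-token reconstruction error at the masked positions.
    Maximizing the expected reconstruction error == minimizing its negative;
    the error is detached so the gradient only reaches the sampling network.
    """
    log_p = torch.log(probs.gather(1, mask_idx) + 1e-10)
    return -(log_p * recon_error.detach()).mean()
```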

Comparison with existing masking techniques

Comparison of our adaptive masking with existing random, cube, and frame masking at a masking ratio of 80%. Our adaptive sampling selects more tokens from regions with high spatiotemporal information and only a small number of tokens from the background as the visible tokens for the MAE.
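
For contrast, a uniform random mask (the most common existing baseline) picks visible tokens without looking at the content at all. A minimal sketch with illustrative names is given below; cube and frame masking differ only in that they drop contiguous spacetime blocks or whole frames instead of individual tokens.

```python
import torch

def random_visible_tokens(batch_size: int, num_tokens: int, num_visible: int) -> torch.Tensor:
    """Random masking baseline: every token is equally likely to stay visible."""
    noise = torch.rand(batch_size, num_tokens)     # uniform scores, independent of content
    return noise.argsort(dim=-1)[:, :num_visible]  # keep the first num_visible indices
```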

Adaptive Mask Visualizations

  • Here, we show the predicted categorical distribution and the sampled mask from our adaptive sampling network at different pre-training checkpoints.

  • In each visualization, from left to right, we show the selected video, the reconstructed video, the reconstruction error (MSE), the predicted categorical distribution, and the sampled mask; a minimal plotting sketch follows the per-epoch examples below.

@5th Epoch-SSv2, From left to right: Input video, Predicted video, Reconstruction error, Predicted categorical distribution, and Sampled mask.

@75th Epoch-SSv2, From left to right: Input video, Predicted video, Reconstruction error, Predicted categorical distribution, and Sampled mask.

@150th Epoch-SSv2, From left to right: Input video, Predicted video, Reconstruction error, Predicted categorical distribution, and Sampled mask.

@300th Epoch-SSv2, From left to right: Input video, Predicted video, Reconstruction error, Predicted categorical distribution, and Sampled mask.
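
The distribution and mask panels above can be reproduced from a per-token probability vector by reshaping it back onto the token grid. The helper below is a hypothetical plotting sketch; it assumes the usual video-MAE tokenization (e.g. T'=8, H'=W'=14 for a 16x224x224 clip with 2x16x16 tube patches) and is not part of the released code.

```python
import matplotlib.pyplot as plt
import torch

def show_token_distribution(probs: torch.Tensor, t: int = 8, h: int = 14, w: int = 14,
                            frame: int = 0) -> None:
    """Reshape a flat (N,) per-token distribution to (T', H', W') and plot one temporal slice."""
    grid = probs.reshape(t, h, w)
    plt.imshow(grid[frame].detach().cpu().numpy(), cmap="viridis")
    plt.colorbar(label="token probability")
    plt.title(f"Predicted categorical distribution (temporal index {frame})")
    plt.show()
```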

Ablation Experiments

Ablation experiments on the SSv2 dataset. We use ViT-Base as the backbone for all experiments. MHA (D=2, d=384) denotes our adaptive mask sampling network with a depth of two and an embedding dimension of 384. All pre-trained models are evaluated following the evaluation protocol detailed in Sec. 4 of the paper. The default choice of our AdaMAE is highlighted in gray. GPU memory consumption is reported for a batch size of 16 on a single GPU.
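
For reference, the MHA (D=2, d=384) entry corresponds to a small attention-based head on top of the patch embeddings that outputs one score per token. The module below is a sketch of that configuration using standard PyTorch layers; the exact block design, head count, and normalization in the official implementation may differ.

```python
import torch
import torch.nn as nn

class AdaptiveSamplingNetwork(nn.Module):
    """Depth-D stack of attention blocks followed by a per-token scoring head."""

    def __init__(self, embed_dim: int = 384, depth: int = 2, num_heads: int = 6):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                       dim_feedforward=4 * embed_dim,
                                       batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.head = nn.Linear(embed_dim, 1)  # one score per token

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N, embed_dim) patch embeddings of the input video
        for blk in self.blocks:
            tokens = blk(tokens)
        logits = self.head(tokens).squeeze(-1)   # (B, N)
        return torch.softmax(logits, dim=-1)     # categorical distribution over tokens
```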

Main Analysis

SSv2 dataset: Comparison of our AdaMAE with SOTA methods on SSv2. We report results for the ViT-B [44] architecture. Our model is pre-trained with the default setting in Table 1. A ✓ in the extra labels column denotes that supervised data was used for pre-training, while ✗ denotes that only unlabeled data was used. N/A denotes numbers that are not available/reported in the paper.

K400 dataset: Comparison of our AdaMAE with SOTA methods on K400. We report results for the ViT-B [44] architecture. Our model is pre-trained with the default setting in Table 1. A ✓ in the extra labels column denotes that supervised data was used for pre-training, while ✗ denotes that only unlabeled data was used. N/A denotes numbers that are not available/reported in the paper.

Conclusions

In this paper, we proposed an adaptive masking technique for spatiotemporal representation learning with MAEs. The proposed technique is entirely different from existing masking methods such as random, cube, and tube masking: we sample the visible token indices from a categorical distribution, which is estimated by an auxiliary network optimized by maximizing the expected reconstruction error. We empirically show that optimizing this auxiliary loss results in sampling more tokens from the foreground than from the background, which yields better transferability of the extracted features to different downstream applications. The ablation experiments conducted on the SSv2 dataset demonstrate the effectiveness of our adaptive sampling over existing mask sampling techniques. Finally, we compare AdaMAE with previous SOTA methods and achieve better results on the SSv2 and Kinetics-400 datasets with the ViT-Base architecture.

Citation