Change is Hard: A Closer Look at Subpopulation Shift


Yuzhe Yang*      Haoran Zhang*      Dina Katabi      Marzyeh Ghassemi
MIT CSAIL     



Abstract


Machine learning models often perform poorly on subgroups that are underrepresented in the training data. Yet, little is understood about the variation in the mechanisms that cause subpopulation shifts, and how algorithms generalize across such diverse shifts at scale. In this work, we provide a fine-grained analysis of subpopulation shift. We first propose a unified framework that dissects and explains common shifts in subgroups. We then establish a comprehensive benchmark of 20 state-of-the-art algorithms evaluated on 12 real-world datasets in vision, language, and healthcare domains. With results obtained from training over 10,000 models, we reveal intriguing observations for future progress in this space. First, existing algorithms only improve subgroup robustness over certain types of shifts but not others. Moreover, while current algorithms rely on group-annotated validation data for model selection, we find that a simple selection criterion based on worst-class accuracy is surprisingly effective even without any group information. Finally, unlike existing works that solely aim to improve worst-group accuracy (WGA), we demonstrate the fundamental tradeoff between WGA and other important metrics, highlighting the need to carefully choose testing metrics.


Paper


Change is Hard: A Closer Look at Subpopulation Shift
Yuzhe Yang*, Haoran Zhang*, Dina Katabi, and Marzyeh Ghassemi
International Conference on Machine Learning (ICML 2023)
[Paper]  •  [Code]  •  [Poster]  •  [Blog Post]  •  [BibTeX]


Talk




Code, Data, and Models


SubpopBench Contents


Available Algorithms (20 algorithms)
  • Empirical Risk Minimization (ERM)
  • Invariant Risk Minimization (IRM)
  • Group Distributionally Robust Optimization (GroupDRO); a minimal loss sketch follows this list
  • Conditional Value-at-Risk Distributionally Robust Optimization (CVaRDRO)
  • Mixup (Mixup)
  • Just Train Twice (JTT)
  • Learning from Failure (LfF)
  • Learning Invariant Predictors with Selective Augmentation (LISA)
  • Deep Feature Reweighting (DFR)
  • Maximum Mean Discrepancy (MMD)
  • Deep Correlation Alignment (CORAL)
  • Data Re-Sampling (ReSample)
  • Cost-Sensitive Re-Weighting (ReWeight)
  • Square-Root Re-Weighting (SqrtReWeight)
  • Focal Loss (Focal)
  • Class-Balanced Loss (CBLoss)
  • Label-Distribution-Aware Margin Loss (LDAM)
  • Balanced Softmax (BSoftmax)
  • Classifier Re-Training (CRT)
  • Reweighted Classifier Re-Training (ReWeightCRT)
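
As referenced above, here is a minimal sketch of the GroupDRO objective (our illustration following Sagawa et al., 2020, not the benchmark's exact implementation): groups with higher running loss receive exponentially larger weights.

import torch
import torch.nn.functional as F

def group_dro_loss(logits, y, group_idx, q, eta=0.01):
    """One GroupDRO step; q holds per-group weights updated across steps."""
    losses = F.cross_entropy(logits, y, reduction="none")
    group_losses = torch.zeros(q.numel(), device=logits.device)
    for g in range(q.numel()):
        mask = group_idx == g
        if mask.any():  # groups can be absent from a mini-batch
            group_losses[g] = losses[mask].mean()
    # Exponentiated-gradient ascent on group weights, then renormalize.
    q.mul_(torch.exp(eta * group_losses.detach()))
    q.div_(q.sum())
    return (q * group_losses).sum()
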
Model Architectures & Pretraining Methods
  • ResNet-50 on ImageNet-1K using supervised pretraining (resnet_sup_in1k); a loading sketch follows this list
  • ResNet-50 on ImageNet-21K using supervised pretraining (resnet_sup_in21k)
  • ResNet-50 on ImageNet-1K using SimCLR (resnet_simclr_in1k)
  • ResNet-50 on ImageNet-1K using Barlow Twins (resnet_barlow_in1k)
  • ResNet-50 on ImageNet-1K using DINO (resnet_dino_in1k)
  • ViT-B on ImageNet-1K using supervised pretraining (vit_sup_in1k)
  • ViT-B on ImageNet-21K using supervised pretraining (vit_sup_in21k)
  • ViT-B from OpenAI CLIP (vit_clip_oai)
  • ViT-B pretrained using CLIP on LAION-2B (vit_clip_laion)
  • ViT-B pretrained with weak supervision on SWAG (vit_sup_swag)
  • ViT-B on ImageNet-1K using DINO (vit_dino_in1k)
  • BERT-base-uncased (bert-base-uncased)
  • GPT-2 (gpt2)
  • XLM-RoBERTa-base (xlm-roberta-base)
  • SciBERT (allenai/scibert_scivocab_uncased)
  • DistilBERT-uncased (distilbert-base-uncased)
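
As noted in the first item above, a minimal sketch of loading two of these backbones with standard libraries (our illustration; the benchmark's own model loaders may differ):

import torchvision
from transformers import AutoModel, AutoTokenizer

# Supervised ImageNet-1K ResNet-50 (resnet_sup_in1k) via torchvision.
resnet = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)

# BERT-base text backbone (bert-base-uncased) via Hugging Face transformers.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")
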

Available Datasets (13 datasets)
Subpopulation Shift Scenarios

We characterize four basic types of subpopulation shift using our framework, and categorize each dataset into its most dominant shift type. A toy sketch of the four types follows the list below.

  • Spurious Correlations (SC): certain attributes a are spuriously correlated with the label y in p_train but not in p_test.
  • Attribute Imbalance (AI): certain attributes are sampled with a much smaller probability than others in p_train, but not in p_test.
  • Class Imbalance (CI): certain (minority) classes are underrepresented in p_train, but not in p_test.
  • Attribute Generalization (AG): certain attributes can be entirely missing in p_train, but present in p_test.
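
To make the four definitions concrete, here is a toy sketch (ours, not the paper's code) with a binary label y and a binary attribute a, where each shift type is expressed as a train-time joint distribution P(y, a) that differs from a balanced test distribution:

import numpy as np

p_test = np.full((2, 2), 0.25)  # balanced P(y, a) at test time, indexed [y, a]

shifts = {
    # SC: a tracks y during training (groups with y == a dominate), not at test.
    "SC": np.array([[0.45, 0.05],
                    [0.05, 0.45]]),
    # AI: attribute a = 1 is rare in training for both classes.
    "AI": np.array([[0.45, 0.05],
                    [0.45, 0.05]]),
    # CI: class y = 1 is rare in training regardless of the attribute.
    "CI": np.array([[0.45, 0.45],
                    [0.05, 0.05]]),
    # AG: attribute a = 1 never appears in training but does at test time.
    "AG": np.array([[0.50, 0.00],
                    [0.50, 0.00]]),
}

for name, p_train in shifts.items():
    assert np.isclose(p_train.sum(), 1.0)
    print(name, "train P(y,a):", p_train.ravel(), "test P(y,a):", p_test.ravel())
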

Evaluation Metrics

We include a variety of metrics for a thorough evaluation from different aspects; a sketch of the worst-group and worst-class accuracy computations follows the list:

  • Average Accuracy & Worst Accuracy
  • Average Precision & Worst Precision
  • Average F1-score & Worst F1-score
  • Adjusted Accuracy
  • Balanced Accuracy
  • AUROC & AUPRC
  • Expected Calibration Error (ECE)
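
As referenced above, a minimal sketch (function names are ours) of how worst-group accuracy and worst-class accuracy can be computed from per-sample predictions:

import numpy as np

def worst_group_accuracy(y_true, y_pred, attr):
    """Worst accuracy over groups, where a group is a (class, attribute) pair."""
    y_true, y_pred, attr = map(np.asarray, (y_true, y_pred, attr))
    accs = []
    for y in np.unique(y_true):
        for a in np.unique(attr):
            mask = (y_true == y) & (attr == a)
            if mask.any():  # skip groups absent from this split
                accs.append((y_pred[mask] == y).mean())
    return min(accs)

def worst_class_accuracy(y_true, y_pred):
    """Worst per-class accuracy; requires no attribute annotations."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return min((y_pred[y_true == y] == y).mean() for y in np.unique(y_true))
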
Model Selection Criteria

We highlight the impact of whether attributes are known in (1) the training set and (2) the validation set.

We show a few important selection criteria; a sketch of the attribute-free criterion follows the list:

  • OracleWorstAcc: Picks the best test-set worst-group accuracy (oracle)
  • ValWorstAccAttributeYes: Picks the best val-set worst-group accuracy (attributes known in validation)
  • ValWorstAccAttributeNo: Picks the best val-set worst-class accuracy (attributes unknown in validation; group degenerates to class)
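
As referenced above, a minimal sketch (ours; names are hypothetical) of ValWorstAccAttributeNo, which needs only class labels on the validation set:

import numpy as np

def worst_class_acc(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return min((y_pred[y_true == c] == c).mean() for c in np.unique(y_true))

def select_checkpoint(checkpoints):
    """checkpoints: dicts with 'val_y' and 'val_pred' arrays (hypothetical keys).

    Picks the checkpoint with the best validation worst-class accuracy,
    i.e., groups degenerate to classes when attributes are unknown.
    """
    return max(checkpoints, key=lambda c: worst_class_acc(c["val_y"], c["val_pred"]))
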


Highlights


(1) Characterizing Basic Types of Subpopulation Shift


(2) SOTA Algorithms Only Improve Subgroup Robustness on Certain Types of Shift


(3) The Role of Representation and Classifier


(4) Impact of Model Selection & Attribute Availability


(5) Fundamental Tradeoff Between WGA and Other Metrics



Press



Citation


@inproceedings{yang2023change,
  title={Change is Hard: A Closer Look at Subpopulation Shift},
  author={Yang, Yuzhe and Zhang, Haoran and Katabi, Dina and Ghassemi, Marzyeh},
  booktitle={International Conference on Machine Learning},
  year={2023}
}