Maximum Likelihood With Bias-Corrected Calibration is Hard-To-Beat at Label Shift Adaptation
08 Jul 2020

Accepted to ICML 2020
Authors: Amr Alexandari*, Anshul Kundaje†, Avanti Shrikumar*†
(*co-first authors, †co-corresponding authors)

Introduction

Imagine we train a classifier to predict whether or not a person has a disease based on observed symptoms, and the classifier predicts reliably when deployed in the clinic. Suppose that there is a sudden surge in cases of the disease. During such an outbreak, the probability of persons having the disease given that they show symptoms rises, but the symptoms generated by the disease do not change. How can we adapt the classifier to cope with the difference in the baseline prevalence of the disease?

The problem of distribution shift in its most general form is intractable. However, there are reasonable assumptions that can make it solvable. If the disease generates similar symptoms in patients regardless of its prevalence in the community, the setting is known as label shift or prior probability shift (Storkey, 2009), and it corresponds to anti-causal learning (i.e. predicting the cause \(y\) from its effects \(\boldsymbol{x}\)). Anti-causal learning is appropriate for diagnosing diseases from observed symptoms because diseases cause symptoms. In this example, a pandemic would induce a shift in \(p(y)\); there would be a surge in cases. However, the process generating \(\boldsymbol{x}\) given \(y\) is fixed, i.e. \(p_s(\boldsymbol{x} \mid y) = p_t(\boldsymbol{x} \mid y)\); the symptoms generated by the disease do not change. An automated diagnostic system should cope with the difference in the baseline prevalence of the disease and adjust its predictions accordingly. In our paper, we propose a maximum likelihood procedure with appropriate calibration for doing exactly that, and we show that this approach is hard to beat at adapting to label shift.

Bias-Corrected Temperature Scaling

Calibration has a long history in the machine learning literature (DeGroot and Fienberg, 1983; Platt, 1999; Zadrozny and Elkan, 2002; Niculescu-Mizil and Caruana, 2005; Kuleshov and Liang, 2015; Naeini et al., 2015; Kuleshov and Ermon, 2016). In the context of modern neural networks, Guo et al. (2017) showed that Temperature Scaling (TS), a single-parameter variant of Platt Scaling (Platt, 1999), is effective at reducing miscalibration. Let \(z(\boldsymbol{x_k})\) denote the vector of original logits for example \(\boldsymbol{x_k}\). With temperature scaling, we have:

\[p(y_i\mid \boldsymbol{x_k})=\frac{e^{z(\boldsymbol{x_k})_i / T}}{\sum_j e^{z(\boldsymbol{x_k})_j /T}}\]

The parameter \(T\) is optimized to minimize the Negative Log Likelihood (NLL) on a held-out portion of the training set, such as the validation set.
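As a concrete illustration, here is a minimal sketch (not the authors' released implementation) of fitting \(T\) on held-out logits with scipy; the array names `val_logits` and `val_labels` are assumed placeholders for the validation-set logits and integer class labels.

```python
import numpy as np
from scipy.optimize import minimize_scalar

def softmax(logits):
    # numerically stable softmax over the last axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def temperature_nll(T, logits, labels):
    # average negative log-likelihood of the true labels under logits / T
    probs = softmax(logits / T)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

def fit_temperature(val_logits, val_labels):
    # search for the T > 0 that minimizes the held-out NLL
    result = minimize_scalar(temperature_nll, bounds=(0.05, 20.0),
                             method="bounded", args=(val_logits, val_labels))
    return result.x

# usage:
# T = fit_temperature(val_logits, val_labels)
# calibrated_probs = softmax(test_logits / T)
```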

Guo et al. (2017) compared TS to an approach termed Vector Scaling (VS), in which a different scaling parameter is used for each class along with class-specific bias parameters, namely:

\[p(y_i\mid \boldsymbol{x_k})=\frac{e^{z(\boldsymbol{x_k})_i W_i+b_i}}{\sum_j e^{z(\boldsymbol{x_k})_j W_j+b_j}}\]

The authors found that VS tended to perform slightly worse than TS as measured by the Expected Calibration Error (Naeini et al., 2015).

Figure: Temperature Scaling exhibits systematic bias. On CIFAR10, systematic bias was quantified by the JS divergence between the true class label proportions and the average class predictions on a held-out test set drawn from the same distribution as the calibration set. TS: Temperature Scaling; NBVS: No-Bias Vector Scaling; BCTS: Bias-Corrected Temperature Scaling; VS: Vector Scaling. BCTS and VS had significantly lower systematic bias than TS and NBVS.

As shown in the figure, we often found that TS alone resulted in systematically biased estimates of \(p(y_i \mid \boldsymbol{x_k})\), while VS, a generalization of TS that contains both class-specific bias terms and class-specific scaling terms, did not exhibit as much systematic bias. Intrigued by this observation, we investigated the performance of two intermediaries between Temperature Scaling and Vector Scaling. The first, which we refer to as No-Bias Vector Scaling (NBVS), is equivalent to VS but with all the class-specific bias parameters fixed at zero. The second, which we refer to as Bias-Corrected Temperature Scaling (BCTS), is equivalent to TS but with the addition of the class-specific bias terms from VS. BCTS is defined as follows:

\[p(y_i\mid \boldsymbol{x_k})=\frac{e^{z(\boldsymbol{x_k})_i / T +b_i}}{\sum_j e^{z(\boldsymbol{x_k})_j / T +b_j}}\]

As with TS and VS, the parameters are optimized to minimize the Negative Log Likelihood (NLL) on the validation set. Note that in the case of binary classification, the parameterization of BCTS reduces to Platt Scaling (Platt, 1999). Thus, BCTS can be viewed as a multi-class generalization of Platt scaling.
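Below is a minimal sketch of fitting BCTS parameters by minimizing the validation-set NLL (again assuming placeholder arrays `val_logits` and `val_labels`; this is our own illustration, not the authors' released code). Fixing the biases at zero recovers TS, while replacing the single \(1/T\) with per-class scales gives NBVS and VS.

```python
import numpy as np
from scipy.optimize import minimize

def softmax(logits):
    # numerically stable softmax over the last axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def fit_bcts(val_logits, val_labels):
    """Fit BCTS: one temperature T plus per-class biases b, by minimizing NLL.
    Fixing b at zero recovers TS; replacing the shared 1/T with per-class
    scales gives NBVS (no biases) or VS (with biases)."""
    num_classes = val_logits.shape[1]

    def nll(params):
        T = np.exp(params[0])            # parameterize log T so T stays positive
        biases = params[1:]
        probs = softmax(val_logits / T + biases)
        log_liks = np.log(probs[np.arange(len(val_labels)), val_labels] + 1e-12)
        return -np.mean(log_liks)

    init = np.zeros(1 + num_classes)     # T = 1, biases = 0
    result = minimize(nll, init, method="Nelder-Mead")
    return np.exp(result.x[0]), result.x[1:]

# usage:
# T, b = fit_bcts(val_logits, val_labels)
# calibrated_probs = softmax(test_logits / T + b)
```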

Another Look at Maximum Likelihood

Saerens et al. (2002) proposed an Expectation-Maximization (EM) algorithm that estimates \(p_t(y)\), assuming access to a classifier that outputs the true source-distribution conditional probabilities \(p_s(y \mid \boldsymbol{x})\). The algorithm has the following update steps:

\[p_t(y_i)^{(0)}=p_s(y_i)\] \[p_t(y_i \mid \boldsymbol{x_k} )^{(r)}=\frac{\frac{p_t(y_i)^{(r)}}{p_s(y_i)}p_s(y_i \mid \boldsymbol{x_k} )}{\sum_{j=1}^{n}\frac{p_t(y_j)^{(r)}}{p_s(y_j)}p_s(y_j \mid \boldsymbol{x_k} )}\] \[p_t(y_i)^{(r+1)}=\frac{1}{N}\sum_{k=1}^{N} p_t(y_i \mid \boldsymbol{x_k} )^{(r)}\]

where \(p_s(y_i)\) is our estimate of the prior probability of observing class \(i\) on the training set, \(p_t(y_i)^{(r)}\) is the estimate in EM step \(r\) of the prior probability of observing class \(i\) on the testing set, \(p_s(y_i \mid \boldsymbol{x_k} )\) is the conditional probability of observing class \(i\) given features \(\boldsymbol{x_k}\) on the training set, \(p_t(y_i \mid \boldsymbol{x_k} )^{(r)}\) is the conditional probability in EM step \(r\) of observing class \(i\) given features \(\boldsymbol{x_k}\) on the testing set, and \(N\) is the number of examples in the testing set.
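For concreteness, here is a minimal sketch of these EM updates (the function and argument names are our own, not from the paper's codebase); `preds` is assumed to hold the calibrated source-domain predictions \(p_s(y \mid \boldsymbol{x_k})\) on the unlabeled target-domain examples, and `source_priors` the training-set class frequencies \(p_s(y)\).

```python
import numpy as np

def em_label_shift(preds, source_priors, tolerance=1e-8, max_iters=1000):
    """EM updates of Saerens et al. (2002) for estimating target-domain priors.
    preds: (N, C) calibrated source-domain predictions p_s(y | x_k) on the
           unlabeled target-domain examples.
    source_priors: (C,) class priors p_s(y) estimated on the training set.
    Returns the estimated target priors p_t(y) and the adapted posteriors."""
    target_priors = source_priors.copy()          # p_t(y)^(0) = p_s(y)
    for _ in range(max_iters):
        # E-step: reweight each posterior by the current prior ratio, renormalize
        weighted = preds * (target_priors / source_priors)
        adapted = weighted / weighted.sum(axis=1, keepdims=True)
        # M-step: the new prior estimate is the mean adapted posterior
        new_priors = adapted.mean(axis=0)
        if np.max(np.abs(new_priors - target_priors)) < tolerance:
            break
        target_priors = new_priors
    return new_priors, adapted

# usage:
# target_priors, adapted_preds = em_label_shift(calibrated_test_preds, train_priors)
```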

Since there is no need to estimate \(p(\boldsymbol{x} \mid y)\) at any step of the EM procedure, the algorithm scales to high-dimensional datasets. Unfortunately, estimates of \(p(y \mid \boldsymbol{x})\) derived from modern neural networks are often poorly calibrated (Guo et al., 2017), and this lack of calibration can undermine the effectiveness of EM, which may explain why comparisons against the EM algorithm have been absent from the recent label shift adaptation literature.

Recently, Black Box Shift Estimation (BBSE) (Lipton et al., 2018) and a variant, Regularized Learning under Label Shifts (RLLS) (Azizzadenesheli et al., 2019), were proposed; both leverage the (possibly uncalibrated) predictions of off-the-shelf classifiers to estimate the shift. However, both of these moment-matching estimators require model retraining with importance weights, which can be challenging at large scales.

In our paper, we revisit maximum likelihood. We show that, in combination with good calibration, a maximum likelihood procedure empirically outperforms the alternatives and achieves state-of-the-art results. Let's examine the maximum likelihood objective. Let \(\omega_i\) denote membership in class \(i\). We seek the target-domain priors \(p_t(\omega_i)\) that maximize the log-likelihood of the unlabeled target-domain data \(\boldsymbol{X}\):

\[l(\boldsymbol{X}; p_t(\omega_i)) = \sum_{k} \log \sum_{i} p_t(\boldsymbol{x_k}\mid\omega_i)\, p_t(\omega_i)\]

which, using the label shift assumption \(p_t(\boldsymbol{x}\mid\omega_i)=p_s(\boldsymbol{x}\mid\omega_i)\) together with Bayes' rule, we show is equivalent to

\[l(\boldsymbol{X}; p_t(\omega_i)) = \sum_{k} \left[ \log p_s(\boldsymbol{x_k}) + \log \sum_{i}\frac{p_s(\omega_i\mid\boldsymbol{x_k})}{p_s(\omega_i)}\, p_t(\omega_i) \right]\]

The paper contains a proof that the maximum likelihood objective above is concave in the target priors, so EM converges to the global maximum. Furthermore, since the \(\log p_s(\boldsymbol{x_k})\) term does not depend on \(p_t(\omega_i)\), any optimization algorithm over the probability simplex could be used; we are not restricted to EM.
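To illustrate that last point, here is a sketch (our own illustration, not the paper's implementation) that maximizes the objective directly with a generic optimizer instead of EM, dropping the constant \(\log p_s(\boldsymbol{x_k})\) term and, for simplicity, parameterizing the target priors through a softmax so that they stay on the simplex.

```python
import numpy as np
from scipy.optimize import minimize

def simplex_softmax(v):
    # map an unconstrained vector to a point on the probability simplex
    exp = np.exp(v - v.max())
    return exp / exp.sum()

def estimate_target_priors(preds, source_priors):
    """Maximize sum_k log sum_i (p_s(w_i | x_k) / p_s(w_i)) * p_t(w_i) over p_t,
    i.e. the objective above with the constant log p_s(x_k) term dropped.
    preds: (N, C) calibrated source-domain predictions on target examples.
    source_priors: (C,) source-domain class priors."""
    ratios = preds / source_priors                 # p_s(w_i | x_k) / p_s(w_i)
    num_classes = preds.shape[1]

    def negative_log_likelihood(logits):
        target_priors = simplex_softmax(logits)
        return -np.sum(np.log(ratios @ target_priors + 1e-12))

    result = minimize(negative_log_likelihood, np.zeros(num_classes),
                      method="L-BFGS-B")
    return simplex_softmax(result.x)

# usage:
# target_priors = estimate_target_priors(calibrated_test_preds, train_priors)
```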

Our approach is as follows:

1. Train a classifier on the source-domain (training) data.
2. Calibrate its predictions on a held-out validation set using BCTS.
3. Run the maximum likelihood procedure (e.g. EM) on the unlabeled target-domain data to estimate the target-domain label priors \(p_t(y)\).
4. Reweight the calibrated predictions by the estimated prior ratios \(p_t(y)/p_s(y)\) and renormalize, as in the EM update above.

In our experiments, the proposed approach results in state-of-the-art performance.

Figure: BCTS-calibrated maximum likelihood is effective. On CIFAR10, performance was quantified by the mean squared error in the estimated ratio of target-domain to source-domain label probabilities \(p_t(y)/p_s(y)\). Dirichlet shift (shift=0.1) was simulated over 10 trials for each of 10 different trained models (100 trials in total), with N=2000 samples in the validation and test sets (results are qualitatively similar for other shifts and values of N).

This work demonstrates that maximum likelihood with appropriate calibration is a formidable and efficient baseline for label shift adaptation. Follow-up work by Garg et al. (2020) that studies why well-calibrated maximum likelihood is effective independently verified our findings: “Across all shifts, MLLS (with BCTS-calibrated classifiers) uniformly dominates BBSE, RLLS, …”

Summary & Additional Resources

In summary:

- Temperature Scaling alone can leave systematically biased estimates of the class probabilities; adding class-specific bias terms (BCTS) largely removes this bias.
- Given well-calibrated predictions, the maximum likelihood (EM) procedure of Saerens et al. (2002) is simple, scales to high-dimensional data, and requires no model retraining.
- The maximum likelihood objective is concave, so EM (or any other optimizer) converges to the global maximum.
- Empirically, BCTS-calibrated maximum likelihood is hard to beat at label shift adaptation, outperforming moment-matching estimators such as BBSE and RLLS.

A number of useful resources: