%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from matplotlib import pyplot as plt
from lifelines import CoxPHFitter
import numpy as np
import pandas as pd
from lifelines.datasets import load_rossi
plt.style.use('bmh')
This tutorial covers some common use cases for the (many) residuals of the Cox model. Residuals can be used to diagnose a model's poor fit to a dataset and to improve an existing model's fit.
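df = load_rossi()
# Bin age into right-closed 5-year intervals (0, 5], (5, 10], ..., (70, 75] and use the bins as strata
df['age_strata'] = pd.cut(df['age'], np.arange(0, 80, 5))
df = df.drop('age', axis=1)
cph = CoxPHFitter()
# Stratified fit: each combination of age bin and wexp gets its own baseline hazard
cph.fit(df, 'week', 'arrest', strata=['age_strata', 'wexp'])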
cph.print_summary()
cph.plot();
<lifelines.CoxPHFitter: fitted with 432 observations, 318 censored>
duration col = 'week'
event col = 'arrest'
strata = ['age_strata', 'wexp']
number of subjects = 432
number of events = 114
log-likelihood = -434.50
time fit was run = 2019-02-19 17:39:13 UTC

| | coef | exp(coef) | se(coef) | z | p | -log2(p) | lower 0.95 | upper 0.95 |
|---|---|---|---|---|---|---|---|---|
| fin | -0.41 | 0.67 | 0.19 | -2.10 | 0.04 | 4.82 | -0.79 | -0.03 |
| race | 0.29 | 1.33 | 0.31 | 0.93 | 0.35 | 1.50 | -0.32 | 0.90 |
| mar | -0.34 | 0.71 | 0.39 | -0.87 | 0.38 | 1.38 | -1.10 | 0.42 |
| paro | -0.10 | 0.91 | 0.20 | -0.50 | 0.62 | 0.70 | -0.48 | 0.29 |
| prio | 0.08 | 1.08 | 0.03 | 2.83 | <0.005 | 7.73 | 0.02 | 0.14 |

Concordance = 0.57
Log-likelihood ratio test = 481.75 on 5 df, -log2(p)=336.05
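Martingale residuals are defined as:
$$ \delta_i - \Lambda(T_i) = \delta_i - \Lambda_0(T_i)\exp(\beta^T x_i)$$
where $T_i$ is the total observation time of subject $i$, $\Lambda_0$ is the baseline cumulative hazard (of subject $i$'s stratum, in a stratified model like the one above), and $\delta_i$ denotes whether they died under observation or not (event_observed in lifelines).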
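A minimal NumPy sketch of this definition follows. The event indicators, baseline cumulative hazards $\Lambda_0(T_i)$ and linear predictors $\beta^T x_i$ are hypothetical numbers for four made-up subjects, not values from the Rossi fit, and `martingale_residuals` is an illustrative helper rather than a lifelines function; in practice `cph.compute_residuals(df, 'martingale')` derives these quantities from the fitted model.

```python
import numpy as np

def martingale_residuals(event, baseline_cum_hazard, linear_predictor):
    """r_i = delta_i - Lambda_0(T_i) * exp(beta^T x_i) for each subject i."""
    event = np.asarray(event, dtype=float)
    # Lambda(T_i): each subject's cumulative hazard at its own observed time
    cum_hazard = np.asarray(baseline_cum_hazard, dtype=float) * np.exp(
        np.asarray(linear_predictor, dtype=float)
    )
    return event - cum_hazard

# Four hypothetical subjects: two observed events followed by two censored ones
event = np.array([1, 1, 0, 0])                            # delta_i
baseline_cum_hazard = np.array([0.02, 0.50, 0.10, 0.80])  # Lambda_0(T_i)
linear_predictor = np.array([0.3, -0.2, 0.1, 0.5])        # beta^T x_i
resid = martingale_residuals(event, baseline_cum_hazard, linear_predictor)
```

The subject with an early event (small cumulative hazard) gets a residual close to 1, while both censored subjects come out negative, consistent with the ranges quoted below.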
From [1]:
Martingale residuals take values in $(-\infty, 1]$ for uncensored observations and in $(-\infty, 0]$ for censored observations. Martingale residuals can be used to assess the true functional form of a particular covariate (Therneau et al. (1990)). It is often useful to overlay a LOESS curve on this plot, as such plots can be noisy when there are many observations. Martingale residuals can also be used to identify outliers in the dataset, i.e. subjects for whom the survival function predicts the event either too early or too late; however, the deviance residual is often better suited to this.
From [2]:
Positive values mean that the patient died sooner than expected (according to the model); negative values mean that the patient lived longer than expected (or was censored).
r = cph.compute_residuals(df, 'martingale')
r.head()
| | week | arrest | martingale |
|---|---|---|---|
| 313 | 1.0 | True | 0.989383 |
| 79 | 5.0 | True | 0.972812 |
| 60 | 6.0 | True | 0.947727 |
| 225 | 7.0 | True | 0.976976 |
| 138 | 8.0 | True | 0.920273 |
# Martingale residuals against observation time; blue = arrest observed, red = censored
r.plot.scatter(
    x='week', y='martingale', c=np.where(r['arrest'], '#008fd5', '#fc4f30'),
    alpha=0.75
)
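The quote above recommends overlaying a LOESS curve on these noisy scatter plots. As a sketch, assuming NumPy only, the hypothetical `lowess` helper below implements a single pass of LOWESS (a tricube-weighted local linear fit at each point) and omits the robustness iterations of the full method; it is not part of lifelines. If statsmodels is available, `statsmodels.nonparametric.smoothers_lowess.lowess` is a complete implementation.

```python
import numpy as np

def lowess(x, y, frac=0.5):
    """Single-pass LOWESS smoother (no robustness iterations).

    For every point, fits a straight line by weighted least squares to its
    nearest neighbours, using tricube weights, and returns the fitted value there.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    # number of neighbours that defines each local bandwidth
    k = min(n, max(2, int(np.ceil(frac * n))))
    design = np.column_stack([np.ones(n), x])
    fitted = np.empty(n)
    for i in range(n):
        dist = np.abs(x - x[i])
        h = np.sort(dist)[k - 1]  # distance to the k-th nearest point
        if h > 0:
            u = np.clip(dist / h, 0.0, 1.0)
        else:
            # heavily tied x values: only exact ties receive weight
            u = np.where(dist == 0, 0.0, 1.0)
        w = (1.0 - u**3) ** 3  # tricube weights, zero at and beyond the bandwidth
        sw = np.sqrt(w)
        # weighted least squares via lstsq, which tolerates rank-deficient local designs
        coef, *_ = np.linalg.lstsq(design * sw[:, None], y * sw, rcond=None)
        fitted[i] = coef[0] + coef[1] * x[i]
    return fitted
```

To inspect the functional form of a covariate such as `prio`, join it onto the martingale residuals by index, sort by `prio`, and draw `lowess(r['prio'], r['martingale'])` as a line over the scatter plot. Ties are common in count covariates like `prio`, which is why the sketch falls back to exact-tie weighting when the bandwidth is zero.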
One problem with martingale residuals is that they are not symmetric around 0. Deviance residuals are a transformation of martingale residuals that makes them more symmetric around 0.
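The snippet below is a sketch of the commonly used transform $d_i = \operatorname{sign}(r_i)\sqrt{-2\left[r_i + \delta_i \log(\delta_i - r_i)\right]}$, where $r_i$ is the martingale residual and the $\delta_i \log(\delta_i - r_i)$ term is taken as zero for censored subjects. Applied to the five martingale residuals printed earlier, it gives values that agree with the deviance residuals lifelines reports below to about four decimal places, so this appears to be the definition lifelines uses; treat that as an inference from the numbers, not documented behaviour. `deviance_from_martingale` is a hypothetical helper.

```python
import numpy as np

def deviance_from_martingale(martingale, event):
    """Map martingale residuals r_i to deviance residuals d_i."""
    r = np.asarray(martingale, dtype=float)
    delta = np.asarray(event, dtype=float)
    # delta_i * log(delta_i - r_i): log(1 - r_i) for observed events, 0 when censored
    log_term = np.zeros_like(r)
    observed = delta == 1
    log_term[observed] = np.log(1.0 - r[observed])
    return np.sign(r) * np.sqrt(-2.0 * (r + log_term))

# Martingale residuals and arrest indicators from the first five rows printed earlier
martingale = np.array([0.989383, 0.972812, 0.947727, 0.976976, 0.920273])
arrest = np.array([1, 1, 1, 1, 1])
d = deviance_from_martingale(martingale, arrest)
# approximately [2.6668, 2.2944, 2.0018, 2.3640, 1.7938]
```

The square root stretches residuals near 1 and compresses large negative ones, which is what pulls the distribution toward symmetry.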
r = cph.compute_residuals(df, 'deviance')
r.head()
| | week | arrest | deviance |
|---|---|---|---|
| 313 | 1.0 | True | 2.666807 |
| 79 | 5.0 | True | 2.294411 |
| 60 | 6.0 | True | 2.001769 |
| 225 | 7.0 | True | 2.363998 |
| 138 | 8.0 | True | 1.793808 |
# Deviance residuals against observation time; blue = arrest observed, red = censored
r.plot.scatter(
    x='week', y='deviance', c=np.where(r['arrest'], '#008fd5', '#fc4f30'),
    alpha=0.75
)
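# Attach the covariates by index (week and arrest are already columns of r)
r = r.join(df.drop(['week', 'arrest'], axis=1))
# Deviance residuals against the number of prior convictions; blue = arrest observed, red = censored
plt.scatter(r['prio'], r['deviance'], color=np.where(r['arrest'], '#008fd5', '#fc4f30'))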
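Since deviance residuals are the recommended tool for spotting outliers, a natural follow-up is to list the subjects with the largest absolute values rather than reading them off a plot. The sketch below assumes a cutoff of 2.5 on $|d_i|$, which is only a rule of thumb and not a value prescribed by lifelines or the sources quoted above; `flag_outliers` and the sample residuals are hypothetical.

```python
import numpy as np

def flag_outliers(deviance, threshold=2.5):
    """Positions of residuals with |d_i| > threshold, most extreme first."""
    d = np.asarray(deviance, dtype=float)
    idx = np.flatnonzero(np.abs(d) > threshold)
    return idx[np.argsort(-np.abs(d[idx]), kind="stable")]

# Made-up deviance residuals for five subjects
deviance = np.array([2.7, 1.1, -0.4, -2.9, 0.1])
print(flag_outliers(deviance))  # [3 0]
```

With the joined DataFrame above, `r[r['deviance'].abs() > 2.5]` selects the same kind of rows while keeping their original index labels.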
r = cph.compute_residuals(df, 'delta_beta')
r.head()
r = r.join(df[['week', 'arrest']])
r.head()
| | fin | race | mar | paro | prio | week | arrest |
|---|---|---|---|---|---|---|---|
| 313 | -0.005650 | -0.011593 | 0.012142 | -0.027450 | -0.020486 | 1 | 1 |
| 79 | -0.005761 | -0.005810 | 0.007687 | -0.020926 | -0.013372 | 5 | 1 |
| 60 | -0.005783 | -0.000146 | 0.003277 | -0.014325 | -0.006315 | 6 | 1 |
| 225 | 0.014998 | -0.041568 | 0.004855 | -0.002254 | -0.015725 | 7 | 1 |
| 138 | 0.011572 | 0.005331 | -0.004241 | 0.013036 | 0.004405 | 8 | 1 |
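Delta-beta residuals (often called dfbeta) approximate how much each fitted coefficient would change if observation $i$ were removed from the data. A standard approximation multiplies the matrix of score residuals by the estimated covariance matrix of $\hat\beta$; the sketch below shows only that matrix product. Both input matrices are hypothetical, with the score residuals chosen so each column sums to zero, as score residuals do at the maximum partial likelihood estimate. lifelines' own computation may differ in details such as scaling.

```python
import numpy as np

# Hypothetical score residuals: one row per observation, one column per covariate
score_residuals = np.array([
    [0.5, -1.0],
    [-0.2, 0.4],
    [-0.3, 0.6],
])
# Hypothetical estimated covariance matrix of the two coefficients (symmetric)
variance_matrix = np.array([
    [0.04, 0.01],
    [0.01, 0.09],
])

# Row i approximates the change in each coefficient attributable to observation i
delta_beta = score_residuals @ variance_matrix
```

Because the score residual columns sum to zero, so do the delta-beta columns, which is one quick sanity check on real output.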
# y-axis: the delta-beta residual for the prio coefficient, not the prio covariate itself
plt.scatter(r['week'], r['prio'], color=np.where(r['arrest'], '#008fd5', '#fc4f30'))
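This plot shows the delta-beta for the `prio` coefficient against time. To list the most influential observations instead, one option is to rank them by absolute delta-beta. The sketch below applies this to the five `prio` delta-beta values printed in the table above; `most_influential` is a hypothetical helper, and positions 0 to 4 correspond to index labels 313, 79, 60, 225 and 138.

```python
import numpy as np

def most_influential(delta_beta_column, k=3):
    """Positions of the k observations with the largest absolute delta-beta."""
    db = np.asarray(delta_beta_column, dtype=float)
    # stable sort keeps the original order among ties
    return np.argsort(-np.abs(db), kind="stable")[:k]

# delta-beta values for prio from the first five rows of the table above
prio_delta_beta = np.array([-0.020486, -0.013372, -0.006315, -0.015725, 0.004405])
top = most_influential(prio_delta_beta, k=3)
print(top)  # [0 3 1]
```

On the full DataFrame, `r['prio'].abs().nlargest(3)` gives the same ranking while keeping the original index labels.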