%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import entropy
import warnings
from causalml.inference.meta import LRSRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.nn import DragonNet
from causalml.match import NearestNeighborMatch, MatchOptimizer, create_table_one
from causalml.propensity import ElasticNetPropensityModel
from causalml.dataset.regression import *
from causalml.metrics import *
import os, sys
%matplotlib inline
warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')
sns.set_palette('Paired')
plt.rcParams['figure.figsize'] = (12,8)
Hill (2011) introduced a semi-synthetic dataset constructed from the Infant Health and Development Program (IHDP). The dataset is based on a randomized experiment investigating the effect of home visits by specialists on future cognitive test scores. It has 747 observations (rows), and the IHDP simulation is considered the de facto standard benchmark for neural-network-based treatment effect estimation methods.
The original paper uses 1000 realizations from the NPCI package, but for illustration purposes we use a single dataset (realization) as an example below.
df = pd.read_csv('data/ihdp_npci_3.csv', header=None)
cols = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"] + [f'x{i}' for i in range(1,26)]
df.columns = cols
df.shape
(747, 30)
df.head()
|   | treatment | y_factual | y_cfactual | mu0 | mu1 | x1 | x2 | x3 | x4 | x5 | ... | x16 | x17 | x18 | x19 | x20 | x21 | x22 | x23 | x24 | x25 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 5.931652 | 3.500591 | 2.253801 | 7.136441 | -0.528603 | -0.343455 | 1.128554 | 0.161703 | -0.316603 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 2.175966 | 5.952101 | 1.257592 | 6.553022 | -1.736945 | -1.802002 | 0.383828 | 2.244320 | -0.629189 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 2.180294 | 7.175734 | 2.384100 | 7.192645 | -0.807451 | -0.202946 | -0.360898 | -0.879606 | 0.808706 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 3.587662 | 7.787537 | 4.009365 | 7.712456 | 0.390083 | 0.596582 | -1.850350 | -0.879606 | -0.004017 | ... | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 2.372618 | 5.461871 | 2.481631 | 7.232739 | -1.045229 | -0.602710 | 0.011465 | 0.161703 | 0.683672 | ... | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 30 columns
df['treatment'].value_counts(normalize=True)
0    0.813922
1    0.186078
Name: treatment, dtype: float64
X = df.loc[:,'x1':]
treatment = df['treatment']
y = df['y_factual']
tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment']==1
else d['y_cfactual'] - d['y_factual'],
axis=1)
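Since both potential outcomes are available in this semi-synthetic data, the same ground-truth ITE can be computed without the row-wise apply. A vectorized equivalent (a sketch; tau_vec is just an illustrative name):
# Vectorized version of the ITE construction above: for treated rows the
# factual outcome is y(1), for control rows it is y(0).
tau_vec = np.where(df['treatment'] == 1,
                   df['y_factual'] - df['y_cfactual'],
                   df['y_cfactual'] - df['y_factual'])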
# p_model = LogisticRegressionCV(penalty='elasticnet', solver='saga', l1_ratios=np.linspace(0,1,5),
# cv=StratifiedKFold(n_splits=4, shuffle=True))
# p_model.fit(X, treatment)
# p = p_model.predict_proba(X)[:, 1]
p_model = ElasticNetPropensityModel()
p = p_model.fit_predict(X, treatment)
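Before modeling, it is worth eyeballing overlap: the treated and control propensity score distributions should share common support. A minimal diagnostic sketch, not part of the original workflow:
# Optional overlap check: propensity score distributions by treatment group.
plt.hist(p[treatment == 1], bins=20, alpha=0.5, density=True, label='treated')
plt.hist(p[treatment == 0], bins=20, alpha=0.5, density=True, label='control')
plt.xlabel('propensity score')
plt.legend()
plt.show()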
s_learner = BaseSRegressor(LGBMRegressor())
s_ate = s_learner.estimate_ate(X, treatment, y)[0]
s_ite = s_learner.fit_predict(X, treatment, y)
t_learner = BaseTRegressor(LGBMRegressor())
t_ate = t_learner.estimate_ate(X, treatment, y)[0][0]
t_ite = t_learner.fit_predict(X, treatment, y)
x_learner = BaseXRegressor(LGBMRegressor())
x_ate = x_learner.estimate_ate(X, treatment, y, p)[0][0]
x_ite = x_learner.fit_predict(X, treatment, y, p)
r_learner = BaseRRegressor(LGBMRegressor())
r_ate = r_learner.estimate_ate(X, treatment, y, p)[0][0]
r_ite = r_learner.fit_predict(X, treatment, y, p)
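Note that estimate_ate returns the point estimate together with lower and upper bounds, which the [0][0] indexing above discards. A sketch that keeps the interval, assuming the (ate, lb, ub) tuple order implied by that indexing:
# Keep the confidence bounds instead of discarding them via [0][0].
ate, ate_lb, ate_ub = t_learner.estimate_ate(X, treatment, y)
print('T-learner ATE: {:.3f} ({:.3f}, {:.3f})'.format(ate[0], ate_lb[0], ate_ub[0]))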
dragon = DragonNet(neurons_per_layer=200, targeted_reg=True)
dragon_ite = dragon.fit_predict(X, treatment, y, return_components=False)
dragon_ate = dragon_ite.mean()
(DragonNet Keras training log truncated: two fits on 597 training / 150 validation samples; the first runs 12 of 30 epochs, the second stops after 53 of 300 epochs with ReduceLROnPlateau repeatedly cutting the learning rate, validation loss settling around 182 and validation treatment accuracy at 0.7244.)
df_preds = pd.DataFrame([s_ite.ravel(),
t_ite.ravel(),
x_ite.ravel(),
r_ite.ravel(),
dragon_ite.ravel(),
tau.ravel(),
treatment.ravel(),
y.ravel()],
index=['S','T','X','R','dragonnet','tau','w','y']).T
df_cumgain = get_cumgain(df_preds)
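Both get_cumgain and auuc_score rely on the column names used in df_preds. If the installed causalml version exposes these keyword arguments (they are the defaults, so this call should be equivalent), they can be made explicit:
# Explicit column mapping: outcome 'y', treatment flag 'w', true effect 'tau'.
df_cumgain = get_cumgain(df_preds, outcome_col='y', treatment_col='w',
                         treatment_effect_col='tau')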
df_result = pd.DataFrame([s_ate, t_ate, x_ate, r_ate, dragon_ate, tau.mean()],
index=['S','T','X','R','dragonnet','actual'], columns=['ATE'])
df_result['MAE'] = [mean_absolute_error(t,p) for t,p in zip([s_ite, t_ite, x_ite, r_ite, dragon_ite],
[tau.values.reshape(-1,1)]*5 )
] + [None]
df_result['AUUC'] = auuc_score(df_preds)
df_result
|   | ATE | MAE | AUUC |
|---|---|---|---|
| S | 4.054511 | 1.027666 | 0.575822 |
| T | 4.100199 | 0.980788 | 0.580929 |
| X | 4.020589 | 1.115693 | 0.564634 |
| R | 3.867016 | 2.033445 | 0.557536 |
| dragonnet | 4.003578 | 1.182555 | 0.553948 |
| actual | 4.098887 | NaN | NaN |
plot_gain(df_preds)
(Figure: cumulative gain curves for the S, T, X, R learners and DragonNet.)
Synthetic Data Generation Method
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)
X_train, X_val, y_train, y_val, w_train, w_val, tau_train, tau_val, b_train, b_val, e_train, e_val = \
train_test_split(X, y, w, tau, b, e, test_size=0.2, random_state=123, shuffle=True)
preds_dict_train = {}
preds_dict_valid = {}
preds_dict_train['Actuals'] = tau_train
preds_dict_valid['Actuals'] = tau_val
preds_dict_train['generated_data'] = {
'y': y_train,
'X': X_train,
'w': w_train,
'tau': tau_train,
'b': b_train,
'e': e_train}
preds_dict_valid['generated_data'] = {
'y': y_val,
'X': X_val,
'w': w_val,
'tau': tau_val,
'b': b_val,
'e': e_val}
# Estimate p_hat, since the true propensity score e would not be observed in real data
p_model = ElasticNetPropensityModel()
p_hat_train = p_model.fit_predict(X_train, w_train)
p_hat_val = p_model.fit_predict(X_val, w_val)
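Note that fit_predict refits the propensity model on the validation set. An out-of-sample alternative, assuming ElasticNetPropensityModel exposes separate fit and predict methods, is to fit on the training data only:
# Alternative sketch: fit the propensity model once on training data and
# score the validation set out-of-sample.
p_model = ElasticNetPropensityModel()
p_model.fit(X_train, w_train)
p_hat_train = p_model.predict(X_train)
p_hat_val = p_model.predict(X_val)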
for base_learner, label_l in zip([BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor],
['S', 'T', 'X', 'R']):
for model, label_m in zip([LinearRegression, XGBRegressor], ['LR', 'XGB']):
# The R-learner needs p_hat at fit time; the other learners take it (if at all) at predict time
if label_l != 'R':
learner = base_learner(model())
# fit the model on training data only
learner.fit(X=X_train, treatment=w_train, y=y_train)
try:
preds_dict_train['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_train, p=p_hat_train).flatten()
preds_dict_valid['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_val, p=p_hat_val).flatten()
except TypeError:
preds_dict_train['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_train, treatment=w_train, y=y_train).flatten()
preds_dict_valid['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
else:
learner = base_learner(model())
learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
preds_dict_train['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_train).flatten()
preds_dict_valid['{} Learner ({})'.format(
label_l, label_m)] = learner.predict(X=X_val).flatten()
learner = DragonNet(verbose=False)
learner.fit(X_train, treatment=w_train, y=y_train)
preds_dict_train['DragonNet'] = learner.predict_tau(X=X_train).flatten()
preds_dict_valid['DragonNet'] = learner.predict_tau(X=X_val).flatten()
actuals_train = preds_dict_train['Actuals']
actuals_validation = preds_dict_valid['Actuals']
synthetic_summary_train = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_train)] for label, preds
in preds_dict_train.items() if 'generated' not in label.lower()},
index=['ATE', 'MSE']).T
synthetic_summary_train['Abs % Error of ATE'] = np.abs(
(synthetic_summary_train['ATE']/synthetic_summary_train.loc['Actuals', 'ATE']) - 1)
synthetic_summary_validation = pd.DataFrame({label: [preds.mean(), mse(preds, actuals_validation)]
for label, preds in preds_dict_valid.items()
if 'generated' not in label.lower()},
index=['ATE', 'MSE']).T
synthetic_summary_validation['Abs % Error of ATE'] = np.abs(
(synthetic_summary_validation['ATE']/synthetic_summary_validation.loc['Actuals', 'ATE']) - 1)
# calculate kl divergence for training
for label in synthetic_summary_train.index:
stacked_values = np.hstack((preds_dict_train[label], actuals_train))
stacked_low = np.percentile(stacked_values, 0.1)
stacked_high = np.percentile(stacked_values, 99.9)
bins = np.linspace(stacked_low, stacked_high, 100)
distr = np.histogram(preds_dict_train[label], bins=bins)[0]
distr = np.clip(distr/distr.sum(), 0.001, 0.999)
true_distr = np.histogram(actuals_train, bins=bins)[0]
true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)
kl = entropy(distr, true_distr)
synthetic_summary_train.loc[label, 'KL Divergence'] = kl
# calculate kl divergence for validation
for label in synthetic_summary_validation.index:
stacked_values = np.hstack((preds_dict_valid[label], actuals_validation))
stacked_low = np.percentile(stacked_values, 0.1)
stacked_high = np.percentile(stacked_values, 99.9)
bins = np.linspace(stacked_low, stacked_high, 100)
distr = np.histogram(preds_dict_valid[label], bins=bins)[0]
distr = np.clip(distr/distr.sum(), 0.001, 0.999)
true_distr = np.histogram(actuals_validation, bins=bins)[0]
true_distr = np.clip(true_distr/true_distr.sum(), 0.001, 0.999)
kl = entropy(distr, true_distr)
synthetic_summary_validation.loc[label, 'KL Divergence'] = kl
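The two loops above differ only in their inputs; a small helper (hypothetical, not part of causalml) removes the duplication:
def kl_divergence(preds, actuals, n_bins=100):
    """KL divergence between histograms of predicted and actual effects."""
    stacked = np.hstack((preds, actuals))
    bins = np.linspace(np.percentile(stacked, 0.1),
                       np.percentile(stacked, 99.9), n_bins)
    distr = np.histogram(preds, bins=bins)[0]
    distr = np.clip(distr / distr.sum(), 0.001, 0.999)
    true_distr = np.histogram(actuals, bins=bins)[0]
    true_distr = np.clip(true_distr / true_distr.sum(), 0.001, 0.999)
    return entropy(distr, true_distr)

# Usage, equivalent to the loop bodies above:
# synthetic_summary_train.loc[label, 'KL Divergence'] = kl_divergence(
#     preds_dict_train[label], actuals_train)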
df_preds_train = pd.DataFrame([preds_dict_train['S Learner (LR)'].ravel(),
preds_dict_train['S Learner (XGB)'].ravel(),
preds_dict_train['T Learner (LR)'].ravel(),
preds_dict_train['T Learner (XGB)'].ravel(),
preds_dict_train['X Learner (LR)'].ravel(),
preds_dict_train['X Learner (XGB)'].ravel(),
preds_dict_train['R Learner (LR)'].ravel(),
preds_dict_train['R Learner (XGB)'].ravel(),
preds_dict_train['DragonNet'].ravel(),
preds_dict_train['generated_data']['tau'].ravel(),
preds_dict_train['generated_data']['w'].ravel(),
preds_dict_train['generated_data']['y'].ravel()],
index=['S Learner (LR)','S Learner (XGB)',
'T Learner (LR)','T Learner (XGB)',
'X Learner (LR)','X Learner (XGB)',
'R Learner (LR)','R Learner (XGB)',
'DragonNet','tau','w','y']).T
synthetic_summary_train['AUUC'] = auuc_score(df_preds_train).iloc[:-1]
df_preds_validation = pd.DataFrame([preds_dict_valid['S Learner (LR)'].ravel(),
preds_dict_valid['S Learner (XGB)'].ravel(),
preds_dict_valid['T Learner (LR)'].ravel(),
preds_dict_valid['T Learner (XGB)'].ravel(),
preds_dict_valid['X Learner (LR)'].ravel(),
preds_dict_valid['X Learner (XGB)'].ravel(),
preds_dict_valid['R Learner (LR)'].ravel(),
preds_dict_valid['R Learner (XGB)'].ravel(),
preds_dict_valid['DragonNet'].ravel(),
preds_dict_valid['generated_data']['tau'].ravel(),
preds_dict_valid['generated_data']['w'].ravel(),
preds_dict_valid['generated_data']['y'].ravel()],
index=['S Learner (LR)','S Learner (XGB)',
'T Learner (LR)','T Learner (XGB)',
'X Learner (LR)','X Learner (XGB)',
'R Learner (LR)','R Learner (XGB)',
'DragonNet','tau','w','y']).T
synthetic_summary_validation['AUUC'] = auuc_score(df_preds_validation).iloc[:-1]
synthetic_summary_train
|   | ATE | MSE | Abs % Error of ATE | KL Divergence | AUUC |
|---|---|---|---|---|---|
| Actuals | 0.484486 | 0.000000 | 0.000000 | 0.000000 | NaN |
| S Learner (LR) | 0.528743 | 0.044194 | 0.091349 | 3.473087 | 0.492660 |
| S Learner (XGB) | 0.358208 | 0.310652 | 0.260643 | 0.817620 | 0.544115 |
| T Learner (LR) | 0.493815 | 0.022688 | 0.019255 | 0.289978 | 0.610855 |
| T Learner (XGB) | 0.397053 | 1.350928 | 0.180465 | 1.452143 | 0.521719 |
| X Learner (LR) | 0.493815 | 0.022688 | 0.019255 | 0.289978 | 0.610855 |
| X Learner (XGB) | 0.341013 | 0.620823 | 0.296134 | 1.098308 | 0.534908 |
| R Learner (LR) | 0.471610 | 0.030968 | 0.026577 | 0.378494 | 0.614607 |
| R Learner (XGB) | 0.413902 | 4.850255 | 0.145688 | 1.950556 | 0.510872 |
| DragonNet | 0.415214 | 0.038613 | 0.142980 | 0.405291 | 0.612157 |
synthetic_summary_validation
|   | ATE | MSE | Abs % Error of ATE | KL Divergence | AUUC |
|---|---|---|---|---|---|
| Actuals | 0.511242 | 0.000000 | 0.000000 | 0.000000 | NaN |
| S Learner (LR) | 0.528743 | 0.042236 | 0.034233 | 4.574498 | 0.494022 |
| S Learner (XGB) | 0.434208 | 0.260496 | 0.150680 | 0.854890 | 0.544212 |
| T Learner (LR) | 0.541503 | 0.025840 | 0.059191 | 0.686602 | 0.604712 |
| T Learner (XGB) | 0.483404 | 0.679398 | 0.054451 | 1.215394 | 0.526918 |
| X Learner (LR) | 0.541503 | 0.025840 | 0.059191 | 0.686602 | 0.604712 |
| X Learner (XGB) | 0.328046 | 0.352812 | 0.358335 | 1.310631 | 0.535895 |
| R Learner (LR) | 0.526797 | 0.034872 | 0.030426 | 0.732823 | 0.608290 |
| R Learner (XGB) | 0.377533 | 2.174835 | 0.261537 | 1.734253 | 0.512412 |
| DragonNet | 0.464221 | 0.037349 | 0.091973 | 0.695660 | 0.606139 |
plot_gain(df_preds_train)
plot_gain(df_preds_validation)