# CausalTransformer

Code for the paper "Causal Transformer for Estimating Counterfactual Outcomes".
Causal Transformer for estimating counterfactual outcomes over time.
<img width="1518" alt="Screenshot 2022-06-03 at 16 41 44" src="https://user-images.githubusercontent.com/23198776/171877145-c7cba15e-9787-4594-8f1f-cbb8b337b74a.png">

The project is built with the following Python libraries:
- Pytorch-Lightning - deep learning models
- Hydra - simplified command line arguments management
- MlFlow - experiments tracking
## Installation

First, create a virtual environment and install all the requirements:

```console
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
```
## MlFlow Setup / Connection

To start an experiment-tracking server, run:

```console
mlflow server --port=5000
```

To access the MlFlow web UI with all the experiments from a remote machine, open an ssh tunnel:

```console
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
```

Then open http://localhost:5000 in a local browser.
## Experiments

The main training script is universal across models and datasets. For details on mandatory arguments, see the main configuration file `config/config.yaml` and the other files in the `configs/` folder.

The generic script with logging and a fixed random seed is the following (where `<training-type>` is one of `enc_dec`, `gnet`, `rmsn`, or `multi`):

```console
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True
```
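Hydra composes dotted command-line overrides such as `exp.seed=10` into a nested configuration. As a rough mental model (this is a schematic sketch, not Hydra's actual implementation, and `apply_overrides` is a hypothetical helper), the overrides fill in a nested dict:

```python
# Schematic of how Hydra-style "key.subkey=value" overrides compose into a
# nested config dict. Illustrative only; Hydra's real override grammar is richer.
def apply_overrides(config, overrides):
    for item in overrides:
        dotted_key, raw = item.split("=", 1)
        # Crude literal parsing: try int, then float, else bool/string.
        for cast in (int, float):
            try:
                value = cast(raw)
                break
            except ValueError:
                value = {"True": True, "False": False}.get(raw, raw)
        node = config
        *parents, leaf = dotted_key.split(".")
        for key in parents:
            node = node.setdefault(key, {})  # create nested groups on demand
        node[leaf] = value
    return config

cfg = apply_overrides({}, ["exp.seed=10", "exp.logging=True", "dataset.coeff=1.0"])
print(cfg)  # {'exp': {'seed': 10, 'logging': True}, 'dataset': {'coeff': 1.0}}
```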
### Backbones (baselines)

One needs to choose a backbone and then fill in the specific hyperparameters (they are left blank in the configs):

- Causal Transformer (this paper): `runnables/train_multi.py +backbone=ct`
- Encoder-Decoder Causal Transformer (this paper): `runnables/train_enc_dec.py +backbone=edct`
- Marginal Structural Models (MSMs): `runnables/train_msm.py +backbone=msm`
- Recurrent Marginal Structural Networks (RMSNs): `runnables/train_rmsn.py +backbone=rmsn`
- Counterfactual Recurrent Network (CRN): `runnables/train_enc_dec.py +backbone=crn`
- G-Net: `runnables/train_gnet.py +backbone=gnet`
Models already have the best hyperparameters saved (for each model and dataset). One can access them via `+backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value>` or `+backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure`.
For CT, EDCT, and CRN, several adversarial balancing objectives are available:

- counterfactual domain confusion loss (this paper): `exp.balancing=domain_confusion`
- gradient reversal (originally in CRN, but can be used for all the methods): `exp.balancing=grad_reverse`
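The intuition behind the domain confusion objective can be sketched in a few lines: the treatment classifier's predicted distribution is pushed toward the uniform distribution via a cross-entropy term, so that the learned representation carries no treatment information. This is a schematic scalar version under that assumption, not the repo's actual loss implementation:

```python
import math

def domain_confusion_loss(treatment_probs):
    """Cross-entropy between a uniform distribution over treatments and the
    classifier's predicted probabilities; minimized when predictions are
    uniform, i.e. when the representation is balanced (schematic sketch)."""
    k = len(treatment_probs)
    return -sum((1.0 / k) * math.log(p) for p in treatment_probs)

balanced = domain_confusion_loss([0.25, 0.25, 0.25, 0.25])      # uniform predictions
informative = domain_confusion_loss([0.97, 0.01, 0.01, 0.01])   # treatment is predictable
print(balanced < informative)  # True: uniform predictions give the lowest loss
```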
To train a decoder (for CRN and RMSNs), use the flag `model.train_decoder=True`.
To perform manual hyperparameter tuning, use the flag `model.<sub_model>.tune_hparams=True`, and then see `model.<sub_model>.hparams_grid`. Use `model.<sub_model>.tune_range` to specify the number of trials for random search.
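Conceptually, random search over `hparams_grid` with `tune_range` trials amounts to sampling one value per hyperparameter for each trial. A minimal sketch, assuming an illustrative grid (the names and values below are examples, not the ones shipped in the repo's configs):

```python
import random

# Illustrative hyperparameter grid; the repo's real grids live in the configs.
hparams_grid = {
    "learning_rate": [1e-4, 3e-4, 1e-3],
    "dropout_rate": [0.1, 0.2, 0.3],
    "seq_hidden_units": [24, 48, 96],
}

def random_search(grid, tune_range, seed=10):
    """Sample `tune_range` candidate configurations uniformly from the grid."""
    rng = random.Random(seed)  # fixed seed for reproducible trials
    return [{name: rng.choice(values) for name, values in grid.items()}
            for _ in range(tune_range)]

trials = random_search(hparams_grid, tune_range=5)
print(len(trials))  # 5 candidate configurations
```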
### Datasets

One needs to specify a dataset / dataset generator (and some additional parameters, e.g. gamma for cancer_sim with `dataset.coeff=1.0`):

- Synthetic Tumor Growth Simulator: `+dataset=cancer_sim`
- MIMIC III Semi-synthetic Simulator (multiple treatments and outcomes): `+dataset=mimic3_synthetic`
- MIMIC III Real-world dataset: `+dataset=mimic3_real`
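For orientation, simulators in the cancer_sim family combine logistic tumor growth with chemotherapy and radiotherapy kill terms. The following one-step update is a schematic sketch in that spirit; all coefficients are illustrative placeholders, not the simulator's calibrated values:

```python
import math
import random

# Schematic one-step tumor-volume update: logistic growth minus treatment
# effects plus noise. Illustrative coefficients, not the repo's simulator.
def tumor_growth_step(volume, chemo, dose, rng,
                      rho=7.0e-5, K=30.0, beta_c=0.03,
                      alpha_r=0.03, beta_r=0.003, noise_std=0.01):
    growth = rho * math.log(K / volume)             # logistic growth term
    chemo_kill = beta_c * chemo                     # chemotherapy effect
    radio_kill = alpha_r * dose + beta_r * dose**2  # radiotherapy effect
    eps = rng.gauss(0.0, noise_std)                 # process noise
    return volume * (1.0 + growth - chemo_kill - radio_kill + eps)

v = 10.0
untreated = tumor_growth_step(v, chemo=0.0, dose=0.0, rng=random.Random(0))
treated = tumor_growth_step(v, chemo=5.0, dose=2.0, rng=random.Random(0))
print(treated < untreated)  # True: treatment shrinks the volume relative to no treatment
```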
Before running MIMIC III experiments, place the MIMIC-III-extract dataset (`all_hourly_data.h5`) into `data/processed/`.
Example of running the Causal Transformer on the Synthetic Tumor Growth Simulator with gamma = [1.0, 2.0, 3.0] and different random seeds (total of 30 subruns), using saved hyperparameters:

```console
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf=\'0\',\'1\',\'2\' exp.seed=10,101,1010,10101,101010
```
## Updated results

### Self- and cross-attention bug

New results for the semi-synthetic and real-world experiments after fixing a bug with self- and cross-attention (https://github.com/Valentyn1997/CausalTransformer/issues/7). The bug affected only Tables 1 and 2 and Figure 5 of the paper (https://arxiv.org/pdf/2204.07258.pdf). Nevertheless, the performance of CT with the bug fixed did not change drastically.
Table 1 (updated). Results for semi-synthetic data for $\tau$-step-ahead prediction based on real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.
| | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ | $\tau = 6$ | $\tau = 7$ | $\tau = 8$ | $\tau = 9$ | $\tau = 10$ |
|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|:--|
| MSMs | 0.37 ± 0.01 | 0.57 ± 0.03 | 0.74 ± 0.06 | 0.88 ± 0.03 | 1.14 ± 0.10 | 1.95 ± 1.48 | 3.44 ± 4.57 | > 10.0 | > 10.0 | > 10.0 |
| RMSNs | 0.24 ± 0.01 | 0.47 ± 0.01 | 0.60 ± 0.01 | 0.70 ± 0.02 | 0.78 ± 0.04 | 0.84 ± 0.05 | 0.89 ± 0.06 | 0.94 ± 0.08 | 0.97 ± 0.09 | 1.00 ± 0.11 |
| CRN | 0.30 ± 0.01 | 0.48 ± 0.02 | 0.59 ± 0.02 | 0.65 ± 0.02 | 0.68 ± 0.02 | 0.71 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.02 |
| G-Net | 0.34 ± 0.01 | 0.67 ± 0.03 | 0.83 ± 0.04 | 0.94 ± 0.04 | 1.03 ± 0.05 | 1.10 ± 0.05 | 1.16 ± 0.05 | 1.21 ± 0.06 | 1.25 ± 0.06 | 1.29 ± 0.06 |
| EDCT (GR; $\lambda = 1$) | 0.29 ± 0.01 | 0.46 ± 0.01 | 0.56 ± 0.01 | 0.62 ± 0.01 | 0.67 ± 0.01 | 0.70 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.01 |
| CT ($\alpha = 0$) (ours, fixed) | 0.20 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.52 ± 0.01 | 0.54 ± 0.01 | 0.56 ± 0.01 | 0.57 ± 0.01 | 0.59 ± 0.01 | 0.60 ± 0.01 |
| CT (ours, fixed) | 0.21 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.53 ± 0.01 | 0.54 ± 0.01 | 0.55 ± 0.01 | 0.57 ± 0.01 | 0.58 ± 0.01 | 0.59 ± 0.01 |
Table 2 (updated). Results for experiments with real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.
| | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ |
|:--|:--|:--|:--|:--|:--|
| MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
| RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
| CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
| G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
| CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |
Figure 6 (updated). Subnetworks importance scores based on semi-synthetic benchmark (higher values correspond to higher importance of subnetwork connectivity via cross-attentions). Shown: RMSE differences between model with isolated subnetwork and full CT, means ± standard errors.
### Last active entry zeroing bug

New results after fixing a bug in the synthetic tumor-growth simulator: the outcome corresponding to the last entry of every time series was zeroed.

Table 9 (updated). Normalized RMSE for one-step-ahead prediction. Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.
| | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
|:--|:--|:--|:--|:--|:--|
