CausalTransformer

Code for the paper "Causal Transformer for Estimating Counterfactual Outcomes"


Causal Transformer for estimating counterfactual outcomes over time.

<img width="1518" alt="Screenshot 2022-06-03 at 16 41 44" src="https://user-images.githubusercontent.com/23198776/171877145-c7cba15e-9787-4594-8f1f-cbb8b337b74a.png">

The project is built with the following Python libraries:

  1. PyTorch Lightning - deep learning models
  2. Hydra - simplified command-line argument management
  3. MLflow - experiment tracking

Installation

First, create a virtual environment and install the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

MLflow Setup / Connection

To start an experiment-tracking server, run:

mlflow server --port=5000

To access the MLflow web UI with all the experiments, connect via SSH:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then, open http://localhost:5000 in a local browser.

Experiments

The main training script is universal across models and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the configs/ folder.

A generic run with logging and a fixed random seed looks as follows (where <training-type> is one of enc_dec, gnet, rmsn, or multi):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True

Backbones (baselines)

Choose a backbone and then fill in its specific hyperparameters (they are left blank in the configs):

The best hyperparameters are already saved for each model and dataset; one can access them via +backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value> or +backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure.
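As a concrete illustration (the exact config-group and value names depend on the files under the configs folder; cancer_sim_domain_conf and the coefficient value '1' here mirror the multirun example later in this README), loading saved CT hyperparameters for the tumor-growth simulator might look like:

```shell
# Train CT on cancer_sim with the saved hyperparameters for the
# domain-confusion balancing objective and coefficient value 1
PYTHONPATH=. python3 runnables/train_multi.py \
    +dataset=cancer_sim +backbone=ct \
    +backbone/ct_hparams/cancer_sim_domain_conf=\'1\' \
    exp.seed=10 exp.logging=True
```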

For CT, EDCT, and CRN, several adversarial balancing objectives are available:

  • counterfactual domain confusion loss (this paper): exp.balancing=domain_confusion
  • gradient reversal (originally in CRN, but can be used for all the methods): exp.balancing=grad_reverse

To train a decoder (for CRN and RMSNs), use the flag model.train_decoder=True.
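For example (a sketch; crn as the backbone name follows the <backbone> pattern above and is an assumption about the config naming), a CRN decoder training run could look like:

```shell
# Train the decoder stage of CRN on the tumor-growth simulator
PYTHONPATH=. python3 runnables/train_enc_dec.py \
    +dataset=cancer_sim +backbone=crn \
    model.train_decoder=True exp.seed=10
```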

To perform manual hyperparameter tuning, use the flag model.<sub_model>.tune_hparams=True, and then see model.<sub_model>.hparams_grid. Use model.<sub_model>.tune_range to specify the number of trials for random search.
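A sketch of a tuning run (multi as the <sub_model> name and 50 as the trial count are assumptions here; the actual sub-model names are defined in the configs):

```shell
# Random-search tuning: enable tuning and cap the number of trials
PYTHONPATH=. python3 runnables/train_multi.py \
    +dataset=cancer_sim +backbone=ct \
    model.multi.tune_hparams=True model.multi.tune_range=50 \
    exp.seed=10
```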

Datasets

One needs to specify a dataset / dataset generator (and some additional parameters, e.g. set gamma for cancer_sim with dataset.coeff=1.0):

  • Synthetic Tumor Growth Simulator: +dataset=cancer_sim
  • MIMIC III Semi-synthetic Simulator (multiple treatments and outcomes): +dataset=mimic3_synthetic
  • MIMIC III Real-world dataset: +dataset=mimic3_real

Before running MIMIC-III experiments, place the MIMIC-extract dataset (all_hourly_data.h5) into data/processed/.
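For example, assuming the MIMIC-extract file has been downloaded to ~/Downloads (the source path is an assumption; only the destination data/processed/ is prescribed by this README):

```shell
# Create the expected folder and copy the MIMIC-extract file into it
mkdir -p data/processed
cp ~/Downloads/all_hourly_data.h5 data/processed/
```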

Example: running the Causal Transformer on the synthetic tumor-growth generator with gamma = [1.0, 2.0, 3.0] and different random seeds (30 subruns in total), using the saved hyperparameters:

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf=\'0\',\'1\',\'2\' exp.seed=10,101,1010,10101,101010

Updated results

Self- and cross-attention bug

New results for the semi-synthetic and real-world experiments after fixing a bug with self- and cross-attention (https://github.com/Valentyn1997/CausalTransformer/issues/7). The bug affected only Tables 1 and 2 and Figure 5 of the paper (https://arxiv.org/pdf/2204.07258.pdf). Nevertheless, the performance of CT after the fix did not change drastically.

Table 1 (updated). Results for semi-synthetic data for $\tau$-step-ahead prediction based on real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ | $\tau = 6$ | $\tau = 7$ | $\tau = 8$ | $\tau = 9$ | $\tau = 10$ |
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
| MSMs | 0.37 ± 0.01 | 0.57 ± 0.03 | 0.74 ± 0.06 | 0.88 ± 0.03 | 1.14 ± 0.10 | 1.95 ± 1.48 | 3.44 ± 4.57 | > 10.0 | > 10.0 | > 10.0 |
| RMSNs | 0.24 ± 0.01 | 0.47 ± 0.01 | 0.60 ± 0.01 | 0.70 ± 0.02 | 0.78 ± 0.04 | 0.84 ± 0.05 | 0.89 ± 0.06 | 0.94 ± 0.08 | 0.97 ± 0.09 | 1.00 ± 0.11 |
| CRN | 0.30 ± 0.01 | 0.48 ± 0.02 | 0.59 ± 0.02 | 0.65 ± 0.02 | 0.68 ± 0.02 | 0.71 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.02 |
| G-Net | 0.34 ± 0.01 | 0.67 ± 0.03 | 0.83 ± 0.04 | 0.94 ± 0.04 | 1.03 ± 0.05 | 1.10 ± 0.05 | 1.16 ± 0.05 | 1.21 ± 0.06 | 1.25 ± 0.06 | 1.29 ± 0.06 |
| EDCT (GR; $\lambda = 1$) | 0.29 ± 0.01 | 0.46 ± 0.01 | 0.56 ± 0.01 | 0.62 ± 0.01 | 0.67 ± 0.01 | 0.70 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.01 |
| CT ($\alpha = 0$) (ours, fixed) | 0.20 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.52 ± 0.01 | 0.54 ± 0.01 | 0.56 ± 0.01 | 0.57 ± 0.01 | 0.59 ± 0.01 | 0.60 ± 0.01 |
| CT (ours, fixed) | 0.21 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.53 ± 0.01 | 0.54 ± 0.01 | 0.55 ± 0.01 | 0.57 ± 0.01 | 0.58 ± 0.01 | 0.59 ± 0.01 |

Table 2 (updated). Results for experiments with real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ |
|:---|:---|:---|:---|:---|:---|
| MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
| RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
| CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
| G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
| CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |

Figure 6 (updated). Subnetwork importance scores based on the semi-synthetic benchmark (higher values correspond to higher importance of subnetwork connectivity via cross-attention). Shown: RMSE differences between the model with an isolated subnetwork and the full CT, mean ± standard error.


Last active entry zeroing bug

New results after fixing a bug in the synthetic tumor-growth simulator: the outcome corresponding to the last entry of every time series was zeroed.

Table 9 (updated). Normalized RMSE for one-step-ahead prediction. Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment-assignment bias.

| | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
|:---|:---|:---|:---|:---|:---|
