CausalMBRL

Official data and code for our paper Systematic Evaluation of Causal Discovery in Visual Model Based Reinforcement Learning

Systematic Evaluation of Causal Discovery in Visual Model Based Reinforcement Learning

We introduce a novel suite of RL environments as a platform for investigating inductive biases, causal representations, and learning algorithms. The goal is to disentangle distinct aspects of causal learning by allowing the user to choose and modulate different properties of the ground truth causal graph, such as its structure, size, and sparsity. We also provide evaluation criteria for measuring causal induction in MBRL that we argue help measure progress and facilitate further research in these directions. The paper is available at https://arxiv.org/abs/2107.00848.

Physics Environment

Data Generation

  • Observed Physics Environment
    bash scripts/gen_observed.sh num_obj Blues
  • Unobserved Physics Environment
    bash scripts/gen_unobserved.sh num_obj Sets
  • FixedUnobserved Physics Environment
    bash scripts/gen_unobserved_fixed.sh num_obj Sets

In our experiments we use num_obj = {3,5}.

Model Based Experiments

Observed Physics Environment

bash scripts/run_observed.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_observed.sh num_obj model_name encoder cmap seed loss mode emb_dim

num_obj = number of objects {3,5}
model_name = AE, VAE, Modular, GNN
encoder = medium
batch_size = 512
cmap = Blues
loss = NLL or Contrastive
emb_dim = 128
mode = test-v0
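
The parameter values above can be swept with a small loop. The following is a minimal sketch, assuming the argument order shown in the two usage lines; `print_observed_sweep` is a hypothetical helper that is not part of the repository. It only prints the commands, so nothing is trained until the output is piped to a shell.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not in the repository): prints one train/eval command
# pair per (model_name, loss) combination for the Observed Physics Environment.
print_observed_sweep() {
  local num_obj="$1" seed="$2"
  local model loss
  for model in AE VAE Modular GNN; do
    for loss in NLL Contrastive; do
      # Argument order copied from the usage lines above:
      # encoder=medium, batch_size=512, cmap=Blues, emb_dim=128, mode=test-v0.
      echo "bash scripts/run_observed.sh $num_obj $model medium 512 Blues $seed $loss 128"
      echo "bash scripts/eval_observed.sh $num_obj $model medium Blues $seed $loss test-v0 128"
    done
  done
}

# Dry run: list the commands for num_obj=3 and seed=0.
print_observed_sweep 3 0
```

To execute the listed commands instead of printing them, the output could be piped to a shell, for example `print_observed_sweep 3 0 | bash`.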

Unobserved Physics Environment

bash scripts/run_unobserved.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_unobserved.sh num_obj model_name encoder cmap seed loss mode emb_dim

num_obj = number of objects {3,5}
model_name = AE, VAE, Modular, GNN
encoder = medium
batch_size = 512
cmap = Sets
loss = NLL or Contrastive
emb_dim = 128
mode = test

FixedUnobserved Physics Environment

bash scripts/run_fixed_unobserved.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_fixed_unobserved.sh num_obj model_name encoder cmap seed loss mode emb_dim

num_obj = number of objects {3,5}
model_name = AE, VAE, Modular, GNN
encoder = medium
batch_size = 512
cmap = Sets
loss = NLL or Contrastive
emb_dim = 128
mode = test

Reinforcement Learning Experiments

The scripts below run the reinforcement learning experiments for the models trained above.

Observed Physics Environment

# This script automatically loads the pre-trained model that was trained with the arguments above.
bash scripts/run_reward_observed.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_rl_observed.sh num_obj model_name encoder cmap seed loss mode emb_dim steps

num_obj = {3,5}
model_name = AE, VAE, Modular, GNN
batch_size = 32
cmap = Blues
loss = NLL or Contrastive
emb_dim = 128
mode = test-v0
steps = {1,5,10}

Unobserved Physics Environment

# This script automatically loads the pre-trained model that was trained with the arguments above.
bash scripts/run_reward_unobserved.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_rl_unobserved.sh num_obj model_name encoder cmap seed loss mode emb_dim steps

num_obj = {3,5}
model_name = AE, VAE, Modular, GNN
batch_size = 32
cmap = Sets
loss = NLL or Contrastive
emb_dim = 128
mode = test
steps = {1,5,10}

FixedUnobserved Physics Environment

# This script automatically loads the pre-trained model that was trained with the arguments above.
bash scripts/run_reward_fixed_unobserved.sh num_obj model_name encoder batch_size cmap seed loss emb_dim
bash scripts/eval_rl_fixed_unobserved.sh num_obj model_name encoder cmap seed loss mode emb_dim steps

num_obj = {3,5}
model_name = AE, VAE, Modular, GNN
batch_size = 32
cmap = Sets
loss = NLL or Contrastive
emb_dim = 128
mode = test
steps = {1,5,10}
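
Each RL evaluation script takes the same trailing `steps` argument, so the three values can be generated in one loop. This is a minimal sketch, assuming the script naming pattern `eval_rl_<environment>.sh` seen in the three subsections above; `print_rl_eval_steps` is a hypothetical helper, not part of the repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not in the repository): prints the RL evaluation
# command once per value of steps in {1, 5, 10}.
# env is one of: observed, unobserved, fixed_unobserved.
print_rl_eval_steps() {
  local env="$1" num_obj="$2" model="$3" cmap="$4" seed="$5" loss="$6" mode="$7"
  local steps
  for steps in 1 5 10; do
    # encoder=medium and emb_dim=128 follow the values listed above.
    echo "bash scripts/eval_rl_${env}.sh $num_obj $model medium $cmap $seed $loss $mode 128 $steps"
  done
}

# Dry run for the FixedUnobserved environment.
print_rl_eval_steps fixed_unobserved 5 GNN Sets 0 Contrastive test
```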

To Reproduce Physics Environment Experiments from the paper

# Generate Data
bash scripts/gen_observed.sh 3 Blues
bash scripts/gen_observed.sh 5 Blues

bash scripts/gen_unobserved.sh 3 Sets
bash scripts/gen_unobserved.sh 5 Sets

bash scripts/gen_unobserved_fixed.sh 3 Sets
bash scripts/gen_unobserved_fixed.sh 5 Sets


# Model Based Experiments
## Observed Physics Environment
### These 8 commands are repeated for each model_name = AE, VAE, Modular, GNN
bash scripts/run_observed.sh 3 AE medium 512 Blues 0 NLL 128
bash scripts/eval_observed.sh 3 AE medium Blues 0 NLL test-v0 128

bash scripts/run_observed.sh 3 AE medium 512 Blues 0 Contrastive 128
bash scripts/eval_observed.sh 3 AE medium Blues 0 Contrastive test-v0 128


bash scripts/run_observed.sh 5 AE medium 512 Blues 0 NLL 128
bash scripts/eval_observed.sh 5 AE medium Blues 0 NLL test-v0 128

bash scripts/run_observed.sh 5 AE medium 512 Blues 0 Contrastive 128
bash scripts/eval_observed.sh 5 AE medium Blues 0 Contrastive test-v0 128

## Unobserved Physics Environment
### These 8 commands are repeated for each model_name = AE, VAE, Modular, GNN
bash scripts/run_unobserved.sh 3 AE medium 512 Sets 0 NLL 128
bash scripts/eval_unobserved.sh 3 AE medium Sets 0 NLL test 128

bash scripts/run_unobserved.sh 3 AE medium 512 Sets 0 Contrastive 128
bash scripts/eval_unobserved.sh 3 AE medium Sets 0 Contrastive test 128


bash scripts/run_unobserved.sh 5 AE medium 512 Sets 0 NLL 128
bash scripts/eval_unobserved.sh 5 AE medium Sets 0 NLL test 128

bash scripts/run_unobserved.sh 5 AE medium 512 Sets 0 Contrastive 128
bash scripts/eval_unobserved.sh 5 AE medium Sets 0 Contrastive test 128


## FixedUnobserved Physics Environment
### These 8 commands are repeated for each model_name = AE, VAE, Modular, GNN
bash scripts/run_fixed_unobserved.sh 3 AE medium 512 Sets 0 NLL 128
bash scripts/eval_fixed_unobserved.sh 3 AE medium Sets 0 NLL test 128

bash scripts/run_fixed_unobserved.sh 3 AE medium 512 Sets 0 Contrastive 128
bash scripts/eval_fixed_unobserved.sh 3 AE medium Sets 0 Contrastive test 128


bash scripts/run_fixed_unobserved.sh 5 AE medium 512 Sets 0 NLL 128
bash scripts/eval_fixed_unobserved.sh 5 AE medium Sets 0 NLL test 128

bash scripts/run_fixed_unobserved.sh 5 AE medium 512 Sets 0 Contrastive 128
bash scripts/eval_fixed_unobserved.sh 5 AE medium Sets 0 Contrastive test 128


# Reinforcement Learning 
## The experiments below can be repeated for model_name = {AE, VAE, Modular, GNN}, loss = {NLL, Contrastive}, num_obj = {3,5}, environments = {Observed, Unobserved, FixedUnobserved}
bash scripts/run_reward_observed.sh 3 AE medium 512 Blues 0 NLL 128
bash scripts/eval_rl_observed.sh 3 AE medium Blues 0 NLL Train 128 1
bash scripts/eval_rl_observed.sh 3 AE medium Blues 0 NLL Train 128 5
bash scripts/eval_rl_observed.sh 3 AE medium Blues 0 NLL Train 128 10
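
The full grid of RL training runs described in the comment above can be enumerated with nested loops. The sketch below is an illustration under assumptions: `cmap_for_env` and `print_reward_sweep` are hypothetical helpers that are not part of the repository, the colormap per environment is taken from the parameter lists earlier in this README, and batch_size is left as an argument because the example above passes 512 while the parameter lists give 32.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not in the repository): maps an environment name to
# the cmap value this README lists for it.
cmap_for_env() {
  case "$1" in
    observed) echo "Blues" ;;
    unobserved|fixed_unobserved) echo "Sets" ;;
    *) echo "unknown environment: $1" >&2; return 1 ;;
  esac
}

# Hypothetical helper: prints one run_reward command per
# (environment, num_obj, model_name, loss) combination.
print_reward_sweep() {
  local batch_size="$1" seed="$2"
  local env num_obj model loss cmap
  for env in observed unobserved fixed_unobserved; do
    cmap="$(cmap_for_env "$env")" || return 1
    for num_obj in 3 5; do
      for model in AE VAE Modular GNN; do
        for loss in NLL Contrastive; do
          echo "bash scripts/run_reward_${env}.sh $num_obj $model medium $batch_size $cmap $seed $loss 128"
        done
      done
    done
  done
}

# Dry run: list every training command for batch_size=512 and seed=0.
print_reward_sweep 512 0
```

Printing rather than executing makes it easy to inspect the grid, or to split it across machines, before committing to the full set of runs.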

Chemistry Environment

Data Generation

bash scripts/chem_data.sh num_obj num_color graph max_steps movement

num_obj = 5
num_color = 5
graph = chain<num_obj>, full<num_obj>, collider<num_obj>. For example: chain5, full5, collider5
max_steps = 10
movement = Static (positions are fixed across episodes) or
           Dynamic (positions vary across episodes)
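
The graph argument is the graph type followed by the number of objects, so the names can be built from num_obj rather than typed by hand. A minimal sketch, assuming the argument order of the usage line above; `print_chem_data_cmds` is a hypothetical helper, not part of the repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not in the repository): prints one data generation
# command per (movement, graph type) combination.
print_chem_data_cmds() {
  local num_obj="$1" num_color="$2" max_steps="$3"
  local movement kind
  for movement in Static Dynamic; do
    for kind in chain full collider; do
      # The graph name is <type><num_obj>, e.g. chain5 for num_obj=5.
      echo "bash scripts/chem_data.sh $num_obj $num_color ${kind}${num_obj} $max_steps $movement"
    done
  done
}

# Dry run with num_obj=5, num_color=5, max_steps=10.
print_chem_data_cmds 5 5 10
```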

Model Based Experiments

bash scripts/run_chem.sh num_obj model_name encoder batch_size num_colors max_steps movement graph seed loss emb_dim
bash scripts/eval_chem.sh num_obj model_name encoder num_colors max_steps movement graph seed loss mode emb_dim


num_obj = 5
model_name = AE, VAE, Modular, GNN
encoder = medium
batch_size = 512
num_colors = 5
max_steps = 10
movement = {Static, Dynamic}
graph = chain<num_obj>, full<num_obj>, collider<num_obj>. For example: chain5, full5, collider5
loss = {NLL, Contrastive}
emb_dim = 128
mode = test
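
Training and evaluation take almost the same arguments, so the pair can be wrapped in one function that evaluates only if training succeeded. The sketch below is an assumption-laden illustration: `run_cmd`, `run_chem_pair`, and the `DRY_RUN` variable are hypothetical and not part of the repository. With the default `DRY_RUN=1` the commands are printed instead of executed.

```shell
#!/usr/bin/env bash
# Hypothetical dry-run switch (not in the repository): 1 prints, 0 executes.
DRY_RUN="${DRY_RUN:-1}"

# Print the command when DRY_RUN=1, otherwise execute it.
run_cmd() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "+ $*"
  else
    "$@"
  fi
}

# Hypothetical helper: train one chemistry model, then evaluate it.
# The && means evaluation is skipped when training exits with an error.
run_chem_pair() {
  local model="$1" movement="$2" graph="$3" seed="$4" loss="$5"
  # Fixed values follow the list above: num_obj=5, encoder=medium,
  # batch_size=512, num_colors=5, max_steps=10, emb_dim=128, mode=test.
  run_cmd bash scripts/run_chem.sh 5 "$model" medium 512 5 10 "$movement" "$graph" "$seed" "$loss" 128 &&
  run_cmd bash scripts/eval_chem.sh 5 "$model" medium 5 10 "$movement" "$graph" "$seed" "$loss" test 128
}

# Dry run for one configuration.
run_chem_pair AE Dynamic chain5 0 NLL
```

Setting `DRY_RUN=0` before calling `run_chem_pair` would execute the two scripts for real.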

Reinforcement Learning Experiments

bash scripts/run_chem_reward.sh num_obj model_name encoder batch_size num_colors max_steps movement graph seed loss emb_dim
bash scripts/eval_rl_chem.sh num_obj model_name encoder num_colors max_steps movement graph seed loss mode emb_dim steps


num_obj = 5
model_name = AE, VAE, Modular, GNN
encoder = medium
batch_size = 512
num_colors = 5
max_steps = 10
movement = {Static, Dynamic}
graph = chain<num_obj>, full<num_obj>, collider<num_obj>. For example: chain5, full5, collider5
loss = {NLL, Contrastive}
emb_dim = 128
mode = test
steps = {1, 5, 10}

To Reproduce Chemistry Environment Experiments from the Paper

# Generate Data
bash scripts/chem_data.sh 5 5 chain5 10 Static
bash scripts/chem_data.sh 5 5 full5 10 Static
bash scripts/chem_data.sh 5 5 collider5 10 Static

bash scripts/chem_data.sh 5 5 chain5 10 Dynamic
bash scripts/chem_data.sh 5 5 full5 10 Dynamic
bash scripts/chem_data.sh 5 5 collider5 10 Dynamic


# Model Based Experiments
## Repeat the below experiments for model_name = {AE, VAE, Modular, GNN}
bash scripts/run_chem.sh 5 AE medium 512 5 10 Dynamic chain5 0 NLL 128
bash scripts/eval_chem.sh 5 AE medium 5 10 Dynamic chain5 0 NLL test 128

bash scripts/run_chem.sh 5 AE medium 512 5 10 Dynamic full5 0 NLL 128
bash scripts/eval_chem.sh 5 AE medium 5 10 Dynamic full5 0 NLL test 128

bash scripts/run_chem.sh 5 AE medium 512 5 10 Dynamic collider5 0 NLL 128
bash scripts/eval_chem.sh 5 AE medium 5 10 Dynamic collider5 0 NLL test 128


bash sc