# UniMASK
Codebase for "Uni[MASK]: Unified Inference in Sequential Decision Problems"
## Introduction
uniMASK is a generalization of BERT models with flexible abstractions for performing inference on subportions of sequences. Masking and prediction can occur either at the token level (as in a traditional transformer) or on subportions of individual tokens.
You can find the full paper here
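To make the masking idea concrete, here is a minimal sketch in plain NumPy (with hypothetical array names; this is not the uniMASK API) contrasting token-level masking with masking only a subportion of each token:

```python
import numpy as np

# A toy trajectory of 4 timesteps; each "token" concatenates two
# factors: a 2-D state and a 1-D action.
rng = np.random.default_rng(0)
seq = rng.normal(size=(4, 3))  # (timesteps, state_dim + action_dim)

# Token-level masking (as in a traditional BERT-style model):
# hide entire timesteps.
token_mask = np.array([True, False, True, True])  # False = masked out
masked_tokens = seq * token_mask[:, None]

# Sub-token (factor-level) masking: hide only part of each token,
# e.g. mask all actions while keeping states visible.
factor_mask = np.array([True, True, False])  # state kept, action masked
masked_factors = seq * factor_mask[None, :]

print(masked_tokens.shape, masked_factors.shape)  # (4, 3) (4, 3)
```

Inference then amounts to predicting the masked entries from the visible ones; uniMASK's abstractions let you choose which tokens or factors to mask.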
## Getting Started
To install uniMASK, run:

```shell
conda create -n uniMASK python=3.7
conda activate uniMASK
pip install -e .
```
uniMASK requires D4RL. You can install it as detailed in the D4RL repo, e.g. by running:

```shell
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
```
For CUDA support, you may need to reinstall PyTorch with CUDA enabled, for example:

```shell
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
```

To verify that the installation was successful, run `pytest`.
## Reproducing results from the paper
### Minigrid heatmap (Figure 7)
Note: reproducing all runs can take a long time. We recommend parallelizing them; in each script, the first line (a comment) contains an example of how to use GNU Parallel for this.
- Run the commands found in `minigrid_repro.sh`.
- Fine-tune the pre-trained models generated in the previous step by running the commands in `minigrid_ft_repro.sh`.
- Generate the heatmaps from these runs by running `minigrid_heatmap.sh` (no parallelization here).
- You may then find the heatmap at `uniMASK/scripts`.
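As an illustration of the parallelization pattern (the exact invocation used for the paper is in the first comment line of each repro script), GNU Parallel runs one command per input line, several jobs at a time. The snippet below uses harmless `echo` stand-ins for the training commands, and falls back to sequential `sh` if `parallel` is not installed, just so it runs anywhere:

```shell
# Feed one command per line to GNU Parallel, up to 4 jobs concurrently.
# The echo commands are hypothetical stand-ins for training runs.
if command -v parallel >/dev/null 2>&1; then
  printf 'echo run-1\necho run-2\necho run-3\n' | parallel -j 4
else
  # Fallback: run the same commands sequentially.
  printf 'echo run-1\necho run-2\necho run-3\n' | sh
fi
```

In practice you would feed the (non-comment) lines of a script such as `minigrid_repro.sh` to `parallel` instead of the `echo` stand-ins.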
### Maze2D results
To reproduce the Maze2D table in the paper:
- Run the wandb sweeps: `medium_maze_sweep_all.yaml` and `medium_maze_sweep_DT.yaml`
- Run the fine-tuning runs: `maze_ft_final.sh`
- Parse the results with `Parse wandb Maze Experiments.ipynb`
## File structure
- `scripts/train.py`: the main script for running uniMASK -- start here.
- `data/`: where rollouts (datasets) and trained models (`transformer_runs`) are stored.
- `envs/`: data handling and evaluation for each currently supported environment.
- `scripts/`: reproducing results from the paper, and running uniMASK in general.
- `batches.py`: all data pipeline processing classes (`FactorSeq`, `TokenSeq`, `FullTokenSeq`, `Batch`, `SubBatch`).
- `sequences.py`:
- `trainer.py`: the `Trainer` class handles the training loop for all models.
- `transformer.py`: contains the transformer model class itself.
- `transformer_train.py`: interface and config setting for training a transformer, through the `Trainer` class.
- `utils.py`: misc utilities, namely math functions, GPU handling, profiling, etc.
- `transformer_eval.py`: interface for getting predictions from a transformer (currently empty).
