Solaris
The first multiplayer video world model in Minecraft
Install / Use
/learn @solaris-wm/SolarisREADME
Overview
This repository contains the JAX implementation of the Solaris multiplayer world model for Minecraft. It supports GCP TPU training and inference, and GPU inference. It also contains the source code for the VLM-as-a-judge multiplayer self-consistency metric.
Inference (GPU)
Set up Python env
conda env create -f environment.yml
conda activate solaris
pip install -r requirements_gpu.txt
pip install -e .
Download pretrained model weights
hf download nyu-visionx/solaris --local-dir ./pretrained
See the nyu-visionx/solaris HF model repo for all available model weights.
Download eval datasets
hf download nyu-visionx/solaris-eval-datasets --local-dir ./datasets --repo-type dataset
See the nyu-visionx/solaris-eval-datasets HF dataset repo for all available evaluation datasets.
Simple inference
For the simplest scenario, run this:
CUDA_VISIBLE_DEVICES=0 python src/inference.py experiment_name=solaris device.eval_num_samples=1
This command assumes the datasets are in ./datasets and uses the pretrained model weights at ./pretrained/solaris.pt. It generates one video per eval dataset and writes the generated videos to ./output/. To run on multiple GPUs, adjust the CUDA_VISIBLE_DEVICES env variable, making sure device.eval_num_samples is divisible by the number of visible GPUs. Inference always uses a per-device batch size of 1, which requires each GPU to have at least 48GB of memory. Refer to the sharding section for details.
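As a concrete sketch, the divisibility constraint can be checked before launching. The GPU count and sample count below are illustrative, not recommended values:

```shell
# Illustrative multi-GPU launch: 4 GPUs, 8 eval samples.
# eval_num_samples must be divisible by the number of visible GPUs.
GPUS="0,1,2,3"
NUM_GPUS=$(echo "$GPUS" | tr ',' '\n' | wc -l)
EVAL_SAMPLES=8
if [ $((EVAL_SAMPLES % NUM_GPUS)) -ne 0 ]; then
  echo "eval_num_samples=$EVAL_SAMPLES is not divisible by $NUM_GPUS GPUs" >&2
  exit 1
fi
# Drop the leading echo to actually launch:
echo CUDA_VISIBLE_DEVICES=$GPUS python src/inference.py \
  experiment_name=solaris device.eval_num_samples=$EVAL_SAMPLES
```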
You might see the following GPU log messages:
2026-02-25 08:28:29.343101: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-25 08:28:29.472418: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-25 08:28:34.231109: W external/xla/xla/tsl/framework/bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 36.68GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
These are warnings and you can disregard them.
Evaluation
VLM metric
The code for the VLM evaluation metric lives under vlm_eval/. Refer to vlm_eval/README.md for how to run it and the implementation details.
FID
To get FID numbers, check the inference script's log file; it reports FID as log messages by default.
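For example, assuming FID appears in log lines as `FID=<value>` (the exact format in the real logs may differ), the numbers can be pulled out with grep:

```shell
# Hypothetical log excerpt; check your actual log for the exact format.
cat > /tmp/solaris_inference.log <<'EOF'
2026-02-25 09:01:12 INFO eval dataset=duet FID=23.41
2026-02-25 09:05:47 INFO eval dataset=vpt FID=19.87
EOF
# Extract just the FID values.
grep -o 'FID=[0-9.]*' /tmp/solaris_inference.log | cut -d= -f2
```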
Training
Only TPU training is supported. It uses a per-device batch size of 1 and requires a device with at least 95GB of memory (e.g. v5p). Refer to the sharding section for details.
Set up Python env
conda env create -f environment.yml
conda activate solaris
pip install -r requirements_tpu.txt
pip install -e .
In a multi-host TPU setting, you will need the conda environment on all hosts, which can be done by wrapping the installation commands with gcloud alpha compute tpus tpu-vm ssh --command {COMMAND}.
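A sketch of such a wrapped invocation. The TPU name, zone, and the --worker=all flag are illustrative; adjust them to your gcloud setup:

```shell
# Placeholders: replace with your TPU VM name and zone.
TPU_NAME=my-tpu
ZONE=us-central2-b
SETUP='conda env create -f environment.yml && conda activate solaris && pip install -r requirements_tpu.txt && pip install -e .'
# Drop the leading echo to actually run the setup on all hosts:
echo gcloud alpha compute tpus tpu-vm ssh "$TPU_NAME" \
  --zone="$ZONE" --worker=all --command="$SETUP"
```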
TPU Storage Setup
There are many ways to store data on GCP TPUs, such as Persistent Disks or GCS buckets. Refer to the official guide for how to set it up. Note that your storage option will need to support writing as well to save training checkpoints and generated outputs.
Download pretrained model weights
hf download nyu-visionx/solaris --local-dir YOUR_STORAGE_PATH/pretrained
See the nyu-visionx/solaris HF model repo for all available model weights.
Download eval datasets
hf download nyu-visionx/solaris-eval-datasets --local-dir YOUR_STORAGE_PATH/datasets --repo-type dataset
Download training datasets
Multiplayer Duet dataset
hf download nyu-visionx/solaris-training-dataset --local-dir YOUR_STORAGE_PATH/datasets
The multiplayer Duet dataset is stored in sharded form on HuggingFace. The above command downloads it into YOUR_STORAGE_PATH/datasets/duet_sharded. Run the command below to unshard it into the original format this codebase works with:
python unshard_dataset.py --shards YOUR_STORAGE_PATH/datasets/duet_sharded --out YOUR_STORAGE_PATH/datasets/duet
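Conceptually, unsharding just reassembles files that were split into numbered shards. The toy example below illustrates the idea only; the real duet_sharded layout and the logic in unshard_dataset.py are repo-specific:

```shell
# Toy illustration of unsharding: concatenate numbered shard files
# back into one file. Not the actual duet dataset format.
mkdir -p /tmp/toy_sharded /tmp/toy
printf 'hello ' > /tmp/toy_sharded/clip0.bin.shard000
printf 'world'  > /tmp/toy_sharded/clip0.bin.shard001
cat /tmp/toy_sharded/clip0.bin.shard* > /tmp/toy/clip0.bin
cat /tmp/toy/clip0.bin
```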
Single player VPT dataset
The full training pipeline requires training on the single-player VPT dataset. Refer to vpt_datasets/README.md for instructions on how to set it up.
Training stages
The training pipeline consists of four stages, each backed by a dedicated runner:
- Stage 1 — Single-player bidirectional pretraining
- Stage 2 — Multiplayer bidirectional training
- Stage 3 — Multiplayer causal training
- Stage 4 — Multiplayer self-forcing training
Below are example commands for each of the four training stages. Edit the folder paths to match your setup; in a multi-host setting, run each command via gcloud alpha compute tpus tpu-vm ssh --command {COMMAND}.
Note that running training automatically runs inference on the test split of the datasets. The training step and inference are JIT-compiled functions that can take a while to compile on first run, so the script might appear to hang at the beginning of training and at the first evaluation.
Stage 1 – Single-player bidirectional pretraining
This stage pretrains the initial Matrix Game 2.0 weights (available as matrix-game-init) on the VPT dataset, extending the action space.
python src/train.py \
runner=trainer_sp \
model=single_player \
dataset=vpt \
+dataset@eval_datasets.vpt=vpt \
~dataset@eval_datasets.duet \
experiment_name=sp_bidirectional_pretrain \
wandb_entity="YOUR_WANDB_ENTITY" \
device.batch_size=64 \
device.eval_num_samples=64 \
device.data_dir="YOUR_DATASETS_DIR" \
device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
device.output_dir="YOUR_OUTPUT_DIR" \
device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
It will train for 120K steps. The final model weights are the initialization for Stage 2. It will save them to YOUR_PRETRAINED_MODEL_DIR/sp_bidirectional_pretrain_120000.pt.
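Since each stage initializes from the previous stage's checkpoint, a small guard like this can catch a missing file before a long launch. This is a sketch; the helper name is made up and not part of the codebase:

```shell
# Hypothetical helper: fail fast if a required checkpoint is missing.
require_ckpt() {
  [ -f "$1" ] || { echo "missing checkpoint: $1" >&2; return 1; }
}
# e.g. before launching Stage 2:
# require_ckpt "YOUR_PRETRAINED_MODEL_DIR/sp_bidirectional_pretrain_120000.pt"
```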
Stage 2 – Multiplayer bidirectional
This stage trains the multiplayer bidirectional model on the Duet dataset obtained from SolarisEngine, starting from the pretrained single-player model.
python src/train.py \
runner=trainer_mp_bidirectional \
experiment_name=mp_bidirectional \
wandb_entity="YOUR_WANDB_ENTITY" \
device.data_dir="YOUR_DATASETS_DIR" \
device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
device.output_dir="YOUR_OUTPUT_DIR" \
device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
It starts from the model weights at YOUR_PRETRAINED_MODEL_DIR/sp_bidirectional_pretrain_120000.pt and trains for 120K steps. Its final model weights initialize Stage 3 and serve as the teacher and critic in Stage 4. It will save them to YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt.
Stage 3 – Multiplayer causal
This stage converts the multiplayer bidirectional model to causal using the Diffusion Forcing objective and a causal attention mask, training on the same Duet dataset.
python src/train.py \
runner=trainer_mp_causal \
experiment_name=mp_causal \
wandb_entity="YOUR_WANDB_ENTITY" \
device.data_dir="YOUR_DATASETS_DIR" \
device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
device.output_dir="YOUR_OUTPUT_DIR" \
device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
It starts from the model weights at YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt and trains for 60K steps. Its final model weights initialize the student in Stage 4. It will save them to YOUR_PRETRAINED_MODEL_DIR/mp_causal_60000.pt.
Stage 4 – Multiplayer self-forcing
This stage finetunes the multiplayer causal model (the student) on its own rollouts, distilling from the multiplayer bidirectional model (the teacher). This removes the test-time distribution mismatch and turns the final multiplayer causal model into a few-step diffusion model.
python src/train.py \
runner=trainer_mp_sf \
experiment_name=mp_sf \
wandb_entity="YOUR_WANDB_ENTITY" \
device.data_dir="YOUR_DATASETS_DIR" \
device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
device.output_dir="YOUR_OUTPUT_DIR" \
device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
device.jax_cache_dir="YOUR_JAX_CACHE_DIR" \
save_model_state_to="YOUR_PRETRAINED_MODEL_DIR/solaris.pt"
It initializes the student from YOUR_PRETRAINED_MODEL_DIR/mp_causal_60000.pt, and the teacher and critic from YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt, and trains for 1.2K steps. It will save the final model weights to YOUR_PRETRAINED_MODEL_DIR/solaris.pt which can be used for inference and evaluation.
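To try the final weights with the GPU inference flow described above, which reads ./pretrained/solaris.pt by default, copy them into place first. The paths here are the placeholders used throughout this README:

```shell
# Copy the Stage 4 output to where GPU inference looks by default.
SRC="YOUR_PRETRAINED_MODEL_DIR/solaris.pt"
mkdir -p pretrained
if [ -f "$SRC" ]; then
  cp "$SRC" pretrained/solaris.pt
fi
# Then: CUDA_VISIBLE_DEVICES=0 python src/inference.py experiment_name=solaris device.eval_num_samples=1
```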
TPU inference
TPU Inference requires t
