
<h1 align="center"><b>EchoJEPA</b></h1> <h3 align="center">A Latent Predictive Foundation Model for Echocardiography</h3> <p align="center"> <a href="https://arxiv.org/abs/2602.02603" target="_blank"><img src="https://img.shields.io/badge/arXiv-Paper-B31B1B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a> <a href="https://github.com/bowang-lab/EchoJEPA"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a> <a href="https://echojepa.com/"><img src="https://img.shields.io/badge/Website-Online-00B89E?style=for-the-badge&logo=internet-explorer&logoColor=white" alt="Website"></a> </p>

Abstract

Foundation models for echocardiography often struggle to disentangle anatomical signal from the stochastic speckle and acquisition artifacts inherent to ultrasound. We present EchoJEPA, a foundation model trained on 18 million echocardiograms across 300K patients, representing the largest pretraining corpus for this modality to date. By leveraging a latent predictive objective, EchoJEPA learns robust anatomical representations that ignore speckle noise. We validate this using a novel multi-view probing framework with frozen backbones, where EchoJEPA outperforms state-of-the-art baselines by approximately 20% in left ventricular ejection fraction (LVEF) estimation and 17% in right ventricular systolic pressure (RVSP) estimation. The model also exhibits remarkable sample efficiency, reaching 79% view classification accuracy with only 1% of labeled data versus 42% for the best baseline trained on 100%. Crucially, EchoJEPA demonstrates superior generalization, degrading by only 2% under physics-informed acoustic perturbations compared to 17% for competitors. Notably, its zero-shot performance on pediatric patients surpasses fully fine-tuned baselines, establishing latent prediction as a superior paradigm for robust, generalizable medical AI.

<p align="center"> <img src="assets/echo_fig1a.png" width=100%> </p>

EchoJEPA models trained on just 1% of labeled data outperform baselines trained on 100%. This efficiency implies that latent prediction yields dense representations capable of defining the view manifold with minimal supervision, as evidenced by the distinct anatomical clustering in the figure below.

<p align="center"> <img src="assets/umap_views.png" width=100%> </p>

EchoJEPA demonstrates anatomical localization, focusing on the mitral valve leaflets, ventricular walls, and annulus while ignoring the sector background. Attention received clusters at Doppler jet edges, while attention given localizes on the valve structures generating flow. Across the cardiac cycle, focus shifts from valve tips during opening to chamber walls during relaxation, indicating that the model interprets the echocardiogram as a functional biological system.

<p align="center"> <img src="assets/echo_attention.png" width=100%> </p>

Getting Started

Setup

conda create -n vjepa2-312 python=3.12
conda activate vjepa2-312
pip install .  # or `pip install -e .` for development mode

Pretraining

Pretraining can be run either locally or distributed. The initial pretraining and cooldown phases use the same command with different configs. These sample commands launch initial training of a ViT-L model on MIMIC-IV-ECHO, a dataset of 525K echocardiograms available through PhysioNet.

Local

python -m app.main --fname configs/train/vitl16/pretrain-mimic-224px-16f.yaml \
  --devices cuda:0

Distributed

python -m app.main_distributed \
  --fname configs/train/vitl16/pretrain-mimic-224px-16f.yaml \
  --time 6000 \
  --account my_account --qos=my_qos

Dataset Format

The path to the pretraining dataset file is set under data.datasets; the file itself looks something like this:

mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0001.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0006.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0007.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0008.mp4 0
mimic-echo-224px/files/p10/p10002221/s94106955/94106955_0009.mp4 0

Since we are doing self-supervised pretraining, all video labels in this file are set to zero.

Pretrained Checkpoints

You can begin pretraining from any of the pretrained V-JEPA 2 models below:

<table> <tr> <th colspan="1">Model</th> <th colspan="1">#Parameters</th> <th colspan="1">Resolution</th> <th colspan="1">Download Link</th> <th colspan="1">Pretraining Config</th> </tr> <tr> <td>ViT-L/16</td> <td>300M</td> <td>256</td> <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitl.pt">checkpoint</a></td> <td><a href="configs/train/vitl16">configs</a></td> </tr> <tr> <td>ViT-H/16</td> <td>600M</td> <td>256</td> <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vith.pt">checkpoint</a></td> <td><a href="configs/train/vith16/">configs</a></td> </tr> <tr> <td>ViT-g/16</td> <td>1B</td> <td>256</td> <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg.pt">checkpoint</a></td> <td><a href="configs/train/vitg16">configs</a></td> </tr> <tr> <td>ViT-g/16<sub>384</sub></td> <td>1B</td> <td>384</td> <td><a href="https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt">checkpoint</a></td> <td><a href="configs/train/vitg16">configs</a></td> </tr> </table>

We keep the pretraining configuration mostly the same as in V-JEPA 2, but adjust some of the sampling and augmentation parameters for echocardiography:

app: vjepa
nodes: 1
tasks_per_node: 8
cpus_per_task: 16
mem_per_gpu: 220G
folder: checkpoints/pretrain/mimic/vjepa2_vitl_224px_16f
data:
  dataset_type: VideoDataset
  datasets:
  - /home/sagemaker-user/user-default-efs/vjepa2/data/csv/mimic_annotations_s3.csv # 525k echocardiogram video clips (224px)
  datasets_weights:
  - 1.0
  batch_size: 128
  crop_size: 224              # <--- thanks to RoPE scaling, this crop size is flexible, but we keep 224 to match with other models
  patch_size: 16
  dataset_fpcs:
  - 16                        # <--- frames per clip, 16 works well in practice
  fps: 8                      # <--- set this lower for greater temporal coverage, higher for greater fidelity
  tubelet_size: 2
  num_workers: 8
  persistent_workers: true
  pin_mem: true
data_aug:
  auto_augment: false
  motion_shift: false
  random_resize_aspect_ratio: # <--- We narrow this range from [0.75, 1.35]
  - 0.9 
  - 1.1 
  random_resize_scale:        # <--- We narrow this range from [0.3, 1.0]
  - 0.5
  - 1.0

If you are not training from scratch, set optimization.checkpoint to the path of your downloaded checkpoint, and remember to scale your learning rate accordingly.
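For example, resuming from the downloaded ViT-L weights might look like the fragment below. This is a sketch: the exact key layout and learning-rate value should be checked against the configs in configs/train/vitl16, and the lr shown is purely illustrative.

```yaml
optimization:
  checkpoint: checkpoints/vitl.pt   # path to the downloaded V-JEPA 2 weights
  lr: 1.5e-4                        # scaled down from the from-scratch schedule (illustrative value)
```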

Probe-based evaluation

Probe-based evaluation consists of training an attentive probe on top of frozen V-JEPA 2 features. We provide scripts for training your own probes, as well as checkpoints for running inference directly.

<p align="center"> <img src="assets/echo_fig2.png" width=100%> </p>

Classification Dataset Format

For classification, prepare a two-column CSV. It should be space-delimited, with the first column being the path to the MP4, and the second being your integer label.

data/echo_views_22k/19068955.mp4 5
data/echo_views_22k/19076133.mp4 7
data/echo_views_22k/19083831.mp4 2
data/echo_views_22k/19086809.mp4 2
data/echo_views_22k/19089161.mp4 5
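If your labels start out as view names rather than integers, the CSV above can be produced with a few lines of pandas. The DataFrame below is a made-up example with hypothetical column names ("path", "view"); only the space-delimited, headerless output format is prescribed by the loader:

```python
import pandas as pd

# Hypothetical label table: one row per clip, with a string view label
df = pd.DataFrame({
    "path": ["data/echo_views_22k/19068955.mp4", "data/echo_views_22k/19076133.mp4"],
    "view": ["PLAX", "A4C"],
})

# Map view names to integer class ids (sorted for reproducibility)
view_to_id = {v: i for i, v in enumerate(sorted(df["view"].unique()))}
df["label"] = df["view"].map(view_to_id)

# Space-delimited, no header, no index -- the two-column format the loader expects
df[["path", "label"]].to_csv("train_views.csv", sep=" ", header=False, index=False)
```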

Regression Dataset Format

For regression, we apply standard scaling (Z-score normalization): the mean and standard deviation are fit solely on the training split to prevent data leakage, and then used to transform all splits to zero mean and unit variance, which stabilizes optimization. You can use the following code:

import pickle

import pandas as pd
from sklearn.preprocessing import StandardScaler

# train_clean, val_clean, test_clean are pandas DataFrames (one per split)
# with a 'Value' column holding the raw regression target, e.g. LVEF %

# 1. Initialize the scaler
# StandardScaler centers the data (mean=0, std=1)
scaler = StandardScaler()

# 2. Fit ONLY on the training set
# This prevents information leakage from the validation/test splits
train_values = train_clean['Value'].values.reshape(-1, 1)
scaler.fit(train_values)

print(f"Scaler fitted. Mean: {scaler.mean_[0]:.4f}, Std: {scaler.scale_[0]:.4f}")

# 3. Transform all splits
# The model is trained to predict the new 'norm_value' column
train_clean['norm_value'] = scaler.transform(train_clean['Value'].values.reshape(-1, 1))
val_clean['norm_value']   = scaler.transform(val_clean['Value'].values.reshape(-1, 1))
test_clean['norm_value']  = scaler.transform(test_clean['Value'].values.reshape(-1, 1))

# 4. Save the scaler
# CRITICAL: you need this file to convert predictions back to real LVEF % later
with open('lvef_scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# --- Verification ---
print("\nNormalization check (train set):")
print(train_clean['norm_value'].describe())
# Expected: mean ~ 0.0, std ~ 1.0

print("\nExample data:")
print(train_clean[['Value', 'norm_value']].head(3))

The resulting CSV for train, val, and test should look something like this:

data/echo_a4c_lvef/2230801.mp4 -3.4486802913030026
data/echo_a4c_lvef/3260170.mp4 -0.16931876118450664
data/echo_a4c_lvef/2758271.mp4 0.7278549632852218
data/echo_a4c_lvef/4291596.mp4 0.7278549632852218
data/echo_a4c_lvef/2350500.mp4 -0.9335615677497962
data/echo_a4c_lvef/2351242.mp4 -0.9335615677497962
data/echo_a4c_lvef/2761632.mp4 -0.9335615677497962
data/echo_a4c_lvef/2351284.mp4 -0.9335615677497962
data/echo_a4c_lvef/3257799.mp4 -0.9335615677497962
data/echo_a4c_lvef/2759135.mp4 0.8186436513424096

Make a note of the mean and standard deviation of your dataset (or save the scaler pickle), as you will need them to invert the normalization at inference time. If you are preparing a multi-view dataset, the CSV will have one additional column per video; the last column is still the label.
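At inference time, the saved scaler maps normalized model outputs back to raw LVEF %. A minimal sketch, assuming the lvef_scaler.pkl produced by the code above (denormalize is a hypothetical helper name):

```python
import pickle

import numpy as np
from sklearn.preprocessing import StandardScaler

def denormalize(preds_norm, scaler_path="lvef_scaler.pkl"):
    """Map normalized model outputs back to raw target units (e.g. LVEF %)."""
    with open(scaler_path, "rb") as f:
        scaler: StandardScaler = pickle.load(f)
    # inverse_transform expects a 2-D (n_samples, 1) array
    preds = np.asarray(preds_norm, dtype=float).reshape(-1, 1)
    return scaler.inverse_transform(preds).ravel()
```

Keeping the inverse step in one place avoids the classic bug of reporting normalized values as percentages.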
