Sensor-Invariant Tactile Representation (SITR)
This is the official codebase for the paper "Sensor-Invariant Tactile Representation" (ICLR 2025). For more details, visit our project website or read the arXiv paper.
For instructions on generating your own simulation dataset, see the Generation Codebase.
Table of Contents
- Overview
- Installation
- Quick Start
- Datasets and Weights
- Usage
- Data Structure and Loading
- Network Architectures
- Project Structure
- Citation
- License
Overview
This project implements SITR (Sensor-Invariant Tactile Representation), a novel method for extracting sensor-invariant representations that enable zero-shot transfer across optical tactile sensors. Our approach uses a transformer-based architecture trained on a diverse dataset of simulated sensor designs, allowing it to generalize to new real-world sensors with minimal calibration.
The codebase provides implementations for:
- Base representation learning using a transformer architecture with supervised contrastive learning
- Object classification across multiple sensor designs (GelSight Mini [4 variations], DIGIT, Hex, Wedge)
- Pose estimation across multiple sensor designs (GelSight Mini, DIGIT, Hex, Wedge)
Installation
1. Clone this repository:

   git clone https://github.com/hgupt3/gsrl.git
   cd gsrl

2. Install the required dependencies:

   pip install -r requirements.txt
Quick Start
To quickly get started with SITR:
- Download the pre-trained weights and dataset from our HuggingFace dataset; download instructions are provided there.
- Run the classification demo:

  python plot_classifier.py --load your_path/checkpoints/classification/ --val-path your_path/classification_dataset/

- Run the pose estimation demo:

  python plot_pose_estimater.py --load your_path/checkpoints/pose_estimation/ --val-path your_path/pose_dataset/
Usage
Classification
Evaluate and visualize classification results:
python plot_classifier.py
Arguments:
- --base_model: Model architecture (default: 'SITR_base')
- --load: Path to load model weights from (default: 'checkpoints/classification/')
- --batch-size: Batch size for evaluation (default: 16)
- --amp: Enable mixed precision (default: True)
- --calibration-config: Number of calibration images to use (0, 4, 8, 9, or 18) (default: 18)
- --device: Device to run evaluation on (default: 'cuda:0')
- --val-path: Path to validation dataset (default: 'classification_dataset/val_set')
Pose Estimation
Evaluate and visualize pose estimation results:
python plot_pose_estimater.py
Arguments:
- --base_model: Model architecture (default: 'SITR_base')
- --load: Path to load model weights from (default: 'checkpoints/pose_estimation/')
- --batch-size: Batch size for evaluation (default: 32)
- --amp: Enable mixed precision (default: True)
- --calibration-config: Number of calibration images to use (0, 4, 8, 9, or 18) (default: 18)
- --device: Device to run evaluation on (default: 'cuda:0')
- --val-path: Path to validation dataset (default: 'pose_dataset/val_set')
Data Structure and Loading
The codebase provides data loader classes for different tasks. Each data loader expects a specific data structure and supports various configuration options.
Directory Structures
1. Base Representation Learning (sim_dataset)
data_root/
├── sensor_0000/
│ ├── calibration/ # Calibration images
│ │ ├── 0000.png # Background image
│ │ ├── 0001.png # Calibration image 1
│ │ └── ... # More calibration images
│ ├── samples/ # Sample images
│ │ ├── 0000.png
│ │ ├── 0001.png
│ │ └── ...
│ ├── dmaps/ # Depth maps (optional)
│ │ ├── 0000.npy
│ │ └── ...
│ └── norms/ # Surface normals (optional)
│ ├── 0000.npy
│ └── ...
├── sensor_0001/
└── ...
2. Classification (classification_dataset)
data_root/
├── sensor_0000/
│ ├── calibration/ # Same as above
│ ├── samples/ # Organized by class
│ │ ├── class_0000/ # Class 0 samples
│ │ │ ├── 0000.png
│ │ │ └── ...
│ │ ├── class_0001/ # Class 1 samples
│ │ │ ├── 0000.png
│ │ │ └── ...
│ │ └── ...
│ └── ...
├── sensor_0001/
└── ...
3. Pose Estimation (pose_dataset)
data_root/
├── sensor_0000/
│ ├── calibration/ # Same as above
│ ├── samples/ # Sample images
│ │ ├── 0000.png
│ │ ├── 0001.png
│ │ └── ...
│ └── locations/ # Location data for pose estimation
│ ├── 0000.npy
│ ├── 0001.npy
│ └── ...
├── sensor_0001/
└── ...
Data Preprocessing
The data loaders automatically handle:
- Background subtraction using calibration images
- Image normalization using pre-computed statistics
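The exact preprocessing lives inside the data loaders; as a sketch of the idea only (these helpers are illustrative, not the codebase's actual functions), background subtraction and normalization amount to:

```python
import numpy as np

def subtract_background(sample, background):
    """Difference image: signed contact signal relative to the unloaded sensor.

    Both inputs are H x W x 3 uint8 arrays; output is float32, roughly in [-1, 1].
    """
    diff = sample.astype(np.float32) - background.astype(np.float32)
    return diff / 255.0

def normalize(diff, mean, std):
    """Channel-wise normalization with pre-computed statistics."""
    return (diff - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
```

The background frame (`calibration/0000.png`) is the sensor with no object in contact, so the difference image isolates the deformation signal.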
Data Loader Classes
- sim_dataset: Base dataset for representation learning
  - Supports contrastive learning with paired samples
  - Handles calibration images and background subtraction
  - Configuration options:
    - calibration_config: Number of calibration images (0, 4, 8, 9, or 18)
    - augment: Enable data augmentation
    - sendTwo: Enable contrastive learning with paired samples
    - num_samples: Number of samples to use (optional)
    - num_sensors: Number of sensors to use (optional)
- classification_dataset: For object classification tasks
  - Supports multiple sensor configurations
  - Handles class labels and sensor IDs
  - Configuration options:
    - sensor_list: List of sensor IDs to use
    - class_list: List of class IDs to include
    - calibration_config: Number of calibration images
    - augment: Enable data augmentation
- pose_dataset: For pose estimation tasks
  - Supports sequential and random sample pairs
  - Handles location data for pose estimation
  - Configuration options:
    - sensor_idx: Specific sensor to use
    - random_final: Use random or sequential final states
    - augment: Enable data augmentation
Using Your Own Data
To use the data loaders with your own data:
1. Organize your data following the appropriate directory structure above.
2. Ensure all images (calibration and samples) are:
   - Resolution: 224x224 pixels
   - Ideally cropped around the center of the sensor for best results
3. Ensure calibration images are properly named and organized as follows:

   Calibration Image Layout

   Background: 0000.png (sensor without any object)

   Sphere Calibration (4mm diameter sphere):
   ┌────────────┬────────────┬────────────┐
   │  0001.png  │  0002.png  │  0003.png  │
   │  Top Left  │ Top Middle │ Top Right  │
   ├────────────┼────────────┼────────────┤
   │  0004.png  │  0005.png  │  0006.png  │
   │  Mid Left  │   Center   │ Mid Right  │
   ├────────────┼────────────┼────────────┤
   │  0007.png  │  0008.png  │  0009.png  │
   │  Bot Left  │ Bot Middle │ Bot Right  │
   └────────────┴────────────┴────────────┘

   Cube Corner Calibration (any reasonable cube):
   ┌────────────┬────────────┬────────────┐
   │  0010.png  │  0011.png  │  0012.png  │
   │  Top Left  │ Top Middle │ Top Right  │
   ├────────────┼────────────┼────────────┤
   │  0013.png  │  0014.png  │  0015.png  │
   │  Mid Left  │   Center   │ Mid Right  │
   ├────────────┼────────────┼────────────┤
   │  0016.png  │  0017.png  │  0018.png  │
   │  Bot Left  │ Bot Middle │ Bot Right  │
   └────────────┴────────────┴────────────┘

   For each calibration image, indent the object (sphere or cube) roughly in the corresponding region of the sensor. For example, in 0001.png the sphere should be indented in the top-left region.
4. Configure the appropriate data loader with your sensor IDs or class labels.
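The calibration naming scheme can be captured as a small lookup. This helper is illustrative only (not part of the codebase); it maps an image index to the object and grid region described in the layout above:

```python
# Map a calibration image index (0-18) to (object, row, column) per the
# layout above. Index 0 is the background image; 1-9 are sphere presses,
# 10-18 are cube-corner presses, each in row-major grid order.
ROWS = ["Top", "Mid", "Bot"]
COLS = ["Left", "Middle", "Right"]

def calibration_slot(index):
    """Return (object, row, column) for calibration image `index`."""
    if index == 0:
        return ("background", None, None)
    if not 1 <= index <= 18:
        raise ValueError("calibration index must be in 0-18")
    obj = "sphere" if index <= 9 else "cube"
    k = (index - 1) % 9          # position within the 3x3 grid
    return (obj, ROWS[k // 3], COLS[k % 3])
```

For example, `calibration_slot(10)` gives `("cube", "Top", "Left")`, matching 0010.png in the cube-corner table.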
Examples:
# Base representation learning
from dataloaders import sim_dataset
dataset = sim_dataset(
path='your_data_path',
augment=True,
calibration_config=18,
sendTwo=True # for contrastive learning
)
# Classification
from dataloaders import classification_dataset
dataset = classification_dataset(
path='your_data_path',
sensor_list=[0, 1],
class_list=[0, 1, 2],
augment=True,
calibration_config=18
)
# Pose estimation
from dataloaders import pose_dataset
dataset = pose_dataset(
path='your_data_path',
sensor_idx=0,
random_final=False,
augment=True
)
Network Architectures
The codebase implements several neural network architectures for different tasks:
Base Model (SITR)
The core architecture is the Sensor-Invariant Tactile Representation (SITR) model, which consists of:
- Vision Transformer backbone with patch embedding
- Calibration-specific components for processing calibration images
- Contrastive learning head for sensor-invariant feature extraction
- Reconstruction decoder for self-supervised learning
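To make the patch-embedding step concrete: a 224x224 RGB input split into non-overlapping 16x16 patches yields 196 tokens of dimension 768. The patch size here is the standard ViT-Base value and is an assumption for illustration, not read from the model code:

```python
import numpy as np

def patchify(image, patch=16):
    """Split an (H, W, C) image into flattened non-overlapping patches.

    For a 224x224x3 input and patch=16 this yields a (196, 768) array: the
    token sequence a ViT-style backbone linearly embeds and feeds to its
    transformer layers.
    """
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image must tile evenly"
    # (rows, patch, cols, patch, c) -> (rows, cols, patch, patch, c) -> flat
    tokens = image.reshape(h // patch, patch, w // patch, patch, c)
    tokens = tokens.transpose(0, 2, 1, 3, 4)
    return tokens.reshape(-1, patch * patch * c)
```

In the full model each token is then projected to the transformer width and combined with the processed calibration features.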
Task-Specific Networks
- Classification Network (classification_net)
- Pose Estimation Network (pose_estimation_net)
- MLP Network (MLP_net)
Example usage:
from models.networks import SITR_base, classification_net, pose_estimation_net, MLP_net
# Initialize base model
base = SITR_base(num_calibration=18)
# Cr
