GOAT
The official implementation of Global Occlusion-Aware Transformer for Robust Stereo Matching (WACV 2024).
GOAT: Global Occlusion-Aware Transformer for Robust Stereo Matching
<div align="center"> <img src="docs/assets/architecture.png" width="800px"/> </div>

Global Occlusion-Aware Transformer for Robust Stereo Matching
WACV 2024
📋 Table of Contents
- Overview
- Features
- Installation
- Project Structure
- Dataset Preparation
- Training
- Citation
- Acknowledgements
- License
🎯 Overview
GOAT (Global Occlusion-Aware Transformer) is a robust stereo matching network that achieves state-of-the-art performance on multiple benchmarks. The network features:
- Global Context Modeling: Transformer-based architecture for capturing long-range dependencies
- Occlusion Awareness: Explicit occlusion detection and handling mechanism
- Multi-Scale Processing: Pyramid cost volume construction for robust matching
- Model Architecture: GOAT-T (Tiny) optimized for accuracy and efficiency
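To make the cost-volume idea concrete, here is a minimal sketch of a single-level correlation cost volume on one scanline, followed by a winner-take-all disparity pick. It is an illustration of the general technique only, not GOAT's implementation: the function names `correlation_cost_volume` and `winner_take_all` are hypothetical, and a pyramid would repeat the same construction at several feature resolutions.

```python
from typing import List

Features = List[List[float]]  # one feature vector per pixel of a scanline


def correlation_cost_volume(left: Features, right: Features, max_disp: int) -> List[List[float]]:
    """Return volume[d][x] = <left[x], right[x - d]> for one scanline.

    Positions where x - d falls outside the image keep a score of 0.0.
    """
    width = len(left)
    volume = [[0.0] * width for _ in range(max_disp)]
    for d in range(max_disp):
        for x in range(d, width):
            volume[d][x] = sum(a * b for a, b in zip(left[x], right[x - d]))
    return volume


def winner_take_all(volume: List[List[float]]) -> List[int]:
    """Pick, for each pixel, the disparity with the highest correlation."""
    width = len(volume[0])
    return [max(range(len(volume)), key=lambda d: volume[d][x]) for x in range(width)]


# Toy scanline: the right view has one-hot features, and the left view is the
# right view shifted by two pixels (the first two left pixels have no match).
right = [[1.0 if i == j else 0.0 for j in range(5)] for i in range(5)]
left = [[0.0] * 5, [0.0] * 5] + right[:3]

disparity = winner_take_all(correlation_cost_volume(left, right, max_disp=4))
print(disparity)  # [0, 0, 2, 2, 2]
```

The two leftmost pixels have no counterpart in the right view, so their scores are all zero and the pick defaults to disparity 0. Unmatched regions like these are what an occlusion-aware design has to handle explicitly.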
✨ Features
- ✅ Global attention mechanism for robust feature matching
- ✅ Occlusion-aware design for handling challenging scenarios
- ✅ Support for multiple datasets: SceneFlow, KITTI, Middlebury, ETH3D, FAT
- ✅ Distributed training (DDP) support
- ✅ TensorBoard integration for training visualization
- ✅ Flexible loss configuration
- ✅ Comprehensive evaluation metrics (EPE, P1-Error, mIOU)
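As a rough guide to two of the metrics named above, the sketch below computes EPE (mean absolute disparity error) and P1-Error (fraction of pixels whose error exceeds 1 px) on flat lists of disparities. These are the common definitions and are assumed here; the repository's own metric code may differ, for example in how invalid pixels are masked or whether P1 is reported as a percentage. The function name `epe_and_p1` and the "ground truth <= 0 means invalid" convention are assumptions for illustration.

```python
from typing import Sequence, Tuple


def epe_and_p1(pred: Sequence[float], gt: Sequence[float], threshold: float = 1.0) -> Tuple[float, float]:
    """Return (EPE, P1) over pixels with a valid ground-truth disparity.

    Assumption: an invalid ground-truth pixel is encoded as a value <= 0.
    """
    errors = [abs(p - g) for p, g in zip(pred, gt) if g > 0]
    if not errors:
        raise ValueError("no valid ground-truth pixels")
    epe = sum(errors) / len(errors)
    p1 = sum(e > threshold for e in errors) / len(errors)
    return epe, p1


# Four pixels; the last one has no ground truth and is skipped.
epe, p1 = epe_and_p1(pred=[10.0, 20.5, 33.0, 5.0], gt=[10.0, 20.0, 30.0, 0.0])
print(f"EPE = {epe:.3f} px, P1 = {p1:.1%}")
```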
🔧 Installation
Prerequisites
- Python >= 3.7
- PyTorch >= 1.7.0
- CUDA >= 10.2 (for GPU support)
- GCC >= 5.4 (for building deformable convolution)
Quick Setup
Step 1: Clone the repository
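# SSH clone (requires a GitHub SSH key); the HTTPS form of the same URL is https://github.com/Magicboomliu/GOAT.git
git clone git@github.com:Magicboomliu/GOAT.git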
cd GOAT
Step 2: Create and activate conda environment
conda create -n goat python=3.8
conda activate goat
Step 3: Install PyTorch
# For CUDA 11.0
pip install torch==1.7.0+cu110 torchvision==0.8.0+cu110 -f https://download.pytorch.org/whl/torch_stable.html
# Or for CUDA 10.2
pip install torch==1.7.0+cu102 torchvision==0.8.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
Step 4: Install dependencies
pip install -r requirements.txt
Step 5: Install GOAT package
pip install -e .
Step 6: Build Deformable Convolution (Optional)
cd third_party/deform
bash make.sh
cd ../..
Note: Deformable convolution is optional. The model works without it, but may achieve slightly better performance with it enabled via the --use_deform flag.
Verify Installation
# Test imports
python -c "import goat; from goat.models.networks.Methods.GOAT_T import GOAT_T; print('Installation successful!')"
# Check GPU availability
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}, GPU count: {torch.cuda.device_count()}')"
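The one-liners above confirm that the imports work but do not compare versions against the prerequisites. The following is a small sketch of such a check, assuming the minimums listed under Prerequisites (Python >= 3.7, PyTorch >= 1.7.0). The helper names `version_tuple` and `meets_minimum` are hypothetical and not part of the GOAT package.

```python
import re
import sys
from typing import Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Turn '1.7.0+cu110' into (1, 7, 0); local build tags after '+' are ignored."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare versions numerically, so that '1.13.1' counts as newer than '1.7.0'."""
    return version_tuple(installed) >= version_tuple(minimum)


if __name__ == "__main__":
    python_version = ".".join(str(n) for n in sys.version_info[:3])
    print("Python >= 3.7:", meets_minimum(python_version, "3.7"))
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print("PyTorch >= 1.7.0:", meets_minimum(torch.__version__, "1.7.0"))
        print("CUDA version PyTorch was built with:", torch.version.cuda)
```

Comparing integer tuples rather than raw strings matters here because a plain string comparison would rank "1.13.1" below "1.7.0".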
📂 Project Structure
The repository is organized as follows:
GOAT/
├── goat/ # Main source package (models, losses, utilities)
├── data/ # Dataloaders and dataset file lists
├── scripts/ # Training and evaluation scripts
├── configs/ # Configuration files
├── third_party/ # External dependencies (deformable convolution)
├── docs/ # Documentation and assets
└── tests/ # Unit tests (future)
Key directories:
- goat/models/: Network architectures (GOAT-T, GOAT-L, attention modules, etc.)
- goat/losses/: Loss functions for training
- goat/utils/: Utility functions and metrics
- data/dataloaders/: Dataset loaders for KITTI, SceneFlow, Middlebury, ETH3D, FAT
- data/filenames/: Dataset file lists organized by dataset type
- scripts/: Executable training scripts
For a detailed structure guide, see STRUCTURE.md.
📋 Quick Links:
- 🚀 Quick Start Guide - Get started in 5 minutes
- ✅ GPU Ready Checklist - Verify code is ready to run
- 🔍 Verification Report - Complete verification details
- 🔄 Migration Guide - Update existing code
📊 Dataset Preparation
SceneFlow Dataset
- Download the SceneFlow dataset
- Organize the data structure as:
/path/to/sceneflow/
├── frames_cleanpass/
├── frames_finalpass/
└── disparity/
- Update the dataset path in training script:
--datapath /path/to/sceneflow
KITTI Dataset
- Download KITTI 2012 and/or KITTI 2015
- Organize as:
/path/to/kitti/
├── 2012/
│   ├── training/
│   └── testing/
└── 2015/
    ├── training/
    └── testing/
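Before launching a long training run, it can help to confirm that the dataset root matches the layouts shown above. The sketch below only checks that the listed directories exist. The `EXPECTED_LAYOUT` table and `missing_dirs` helper are hypothetical names for illustration, and the check does not inspect the files inside each directory.

```python
from pathlib import Path
from typing import List

# Directory layouts copied from the trees in this section.
EXPECTED_LAYOUT = {
    "sceneflow": ["frames_cleanpass", "frames_finalpass", "disparity"],
    "kitti": ["2012/training", "2012/testing", "2015/training", "2015/testing"],
}


def missing_dirs(root: str, dataset: str) -> List[str]:
    """Return the expected sub-directories that are absent under root."""
    base = Path(root)
    return [rel for rel in EXPECTED_LAYOUT[dataset] if not (base / rel).is_dir()]


if __name__ == "__main__":
    # Replace the placeholder with your own dataset root.
    absent = missing_dirs("/path/to/sceneflow", "sceneflow")
    print("Layout OK" if not absent else f"Missing directories: {absent}")
```

For KITTI, a report that only the 2012 or only the 2015 directories are missing is expected if you downloaded just one of the two benchmarks.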
Other Datasets
See data/filenames/ directory for supported datasets:
- Middlebury
- ETH3D
- FAT (Flying Automotive Things)
🚀 Training
Quick Start
Prepare your dataset first (see Dataset Preparation), then run:
# Create necessary directories
mkdir -p models_saved logs experiments_logdir
# Single GPU training (recommended to test first)
python scripts/train.py \
--cuda \
--model GOAT_T_Origin \
--loss configs/loss_config_disp.json \
--lr 1e-3 \
--batch_size 4 \
--dataset sceneflow \
--trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
--vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
--datapath /path/to/sceneflow \
--outf models_saved/goat_experiment \
--logFile logs/train.log \
--save_logdir experiments_logdir/goat_experiment \
--devices 0 \
--local_rank 0 \
--datathread 4 \
--manualSeed 1024
Multi-GPU Training (Recommended)
For faster training, use distributed data parallel (DDP) with multiple GPUs:
# 2 GPUs
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node=2 \
scripts/train.py \
--cuda \
--model GOAT_T_Origin \
--loss configs/loss_config_disp.json \
--lr 1e-3 \
--batch_size 2 \
--dataset sceneflow \
--trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
--vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
--datapath /path/to/sceneflow \
--outf models_saved/goat_experiment \
--logFile logs/train.log \
--save_logdir experiments_logdir/goat_experiment \
--devices 0,1 \
--datathread 4 \
--manualSeed 1024
# 4 GPUs (adjust batch_size and nproc_per_node accordingly)
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
--nproc_per_node=4 \
scripts/train.py \
--cuda \
--model GOAT_T_Origin \
--loss configs/loss_config_disp.json \
--lr 1e-3 \
--batch_size 1 \
--dataset sceneflow \
--trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
--vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
--datapath /path/to/sceneflow \
--outf models_saved/goat_experiment \
--logFile logs/train.log \
--save_logdir experiments_logdir/goat_experiment \
--devices 0,1,2,3 \
--datathread 4 \
--manualSeed 1024
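The three example commands pair 1, 2, and 4 GPUs with --batch_size 4, 2, and 1. Assuming --batch_size is the per-GPU value (as the arguments table describes it), each configuration processes 4 samples per optimization step in total. The arithmetic can be sketched as follows; `per_gpu_batch_size` is a hypothetical helper, not part of the training script.

```python
def per_gpu_batch_size(global_batch: int, num_gpus: int) -> int:
    """Split a target global batch evenly across DDP processes (one per GPU)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if global_batch % num_gpus != 0:
        raise ValueError(f"global batch {global_batch} is not divisible by {num_gpus} GPUs")
    return global_batch // num_gpus


# Reproduce the --batch_size values used in the commands above.
for gpus in (1, 2, 4):
    print(f"--nproc_per_node={gpus} --batch_size {per_gpu_batch_size(4, gpus)}")
```

Keeping the global batch fixed lets the examples reuse the same learning rate. If you change the global batch instead, the learning rate may need retuning.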
Using the Training Script
Modify and use the provided shell script:
# 1. Edit scripts/train.sh and update these variables:
# - datapath: path to your dataset
# - pretrain_name: experiment name
# - Other parameters as needed
# 2. Run the script
bash scripts/train.sh
Training Arguments
Required Arguments:
| Argument | Description | Example |
|----------|-------------|---------|
| --cuda | Enable CUDA | (flag, no value) |
| --model | Model architecture | GOAT_T_Origin |
| --loss | Loss configuration file | configs/loss_config_disp.json |
| --dataset | Dataset name | sceneflow |
| --trainlist | Training file list | data/filenames/SceneFlow/SceneFlow_With_Occ.list |
| --vallist | Validation file list | data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list |
| --datapath | Dataset root path | /path/to/sceneflow |
| --outf | Output directory for models | models_saved/experiment_name |
| --logFile | Log file path | logs/train.log |
| --save_logdir | TensorBoard log directory | experiments_logdir/experiment_name |
| --devices | GPU device IDs | 0 or 0,1,2,3 |
| --local_rank | Local rank (DDP) | Auto-set by launcher |
| --datathread | Number of data loading workers | 4 |
Optional Arguments:
| Argument | Description | Default |
|----------|-------------|---------|
| --net | Legacy network name | simplenet |
| --lr | Learning rate | 0.0002 |
| --batch_size | Batch size per GPU | 8 |
| --test_batch | Test batch size | 4 |
| --maxdisp | Maximum disparity | -1 (auto) |
| --pretrain | Path to pretrained model | none |
| --initial_pretrain | Partial weight loading | none |
| --use_deform | Use deformable convolution | False |
| --startRound | Start training round | 0 |
| --startEpoch | Start epoch | 0 |
| --manualSeed | Random seed | Random |
| --workers | Number of workers | 8 |
| --momentum | SGD momentum | 0.9 |
| --beta | Adam beta | 0.999 |
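To show how the defaults in the table interact with command-line overrides, here is an illustrative argparse sketch covering a subset of the optional arguments. It mirrors the flag names and defaults listed above, but the argument types are assumptions and this is not the parser used in scripts/train.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative subset of the optional training arguments (types are assumed)."""
    parser = argparse.ArgumentParser(description="Subset of GOAT training options")
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size per GPU")
    parser.add_argument("--test_batch", type=int, default=4, help="test batch size")
    parser.add_argument("--maxdisp", type=int, default=-1, help="maximum disparity (-1 = auto)")
    parser.add_argument("--use_deform", action="store_true", help="use deformable convolution")
    parser.add_argument("--manualSeed", type=int, default=None, help="random seed (random if unset)")
    parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
    parser.add_argument("--beta", type=float, default=0.999, help="Adam beta")
    return parser


# Flags given on the command line override the defaults; the rest keep them.
args = build_parser().parse_args(["--lr", "1e-3", "--batch_size", "4", "--use_deform"])
print(args.lr, args.batch_size, args.use_deform, args.maxdisp)
```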
With Deformable Convolution:
# First build deformable convolution (see Installation)
python scripts/train.py \
--use_deform \
... other arguments ...
Tracked Metrics:
- Total loss (combined dispari
