# CBIL: Collective Behavior Imitation Learning for Fish from Real Videos <img src="https://github.com/user-attachments/assets/3e1647cd-65a8-4b76-8033-db0736208271" height="30px" align="center">

[SIGGRAPH Asia 2024 Journal Track (TOG)]

<div align="center">Official CBIL implementation</div>
The project is still under patent review, and we are acquiring permission from SoftBank for a full release; however, we release the core part of the simulator to perform basic pipeline training and inference.
You can download the core part of the simulator from the `Simulator` folder. The repo is still being updated.
## News
- 2025-03: 🔥Demo code released.
- 2024-12: 🔥The code is coming soon.
- 2024-12: We present our work at SIGGRAPH ASIA 2024 in Tokyo.
- 2024-07: CBIL is accepted as SIGGRAPH ASIA 2024 Journal Track (TOG).
## 📌 To-Do List
- [x] Release the core code of CBIL (Fish Simulator Unity).
- [x] Release the python server implementation.
- [x] Demo code with example fish asset and texture.
- [ ] Keep updating
- [ ] Add more features
## Installation

1. Download the `Simulator` folder.
2. Create a conda environment via `conda create -n CBIL python=3.9` and install the pip dependencies via `pip install -r requirements.txt`.
## Preprocessing

- For the segmentation of the reference videos, refer to the state-of-the-art method SAM2.
- For MVAE training, refer to the code in `basic_server.py` (a 3D-conv-based VAE; see the `MVAE` folder).
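As background for the MVAE step, here is a minimal NumPy illustration of the two pieces of VAE machinery it relies on: the reparameterization trick and the KL regularizer. The actual 3D-convolutional model lives in `basic_server.py`; the latent size below is a toy assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var):
    # z = mu + sigma * eps: sample the latent while keeping gradients
    # flowing through mu and log_var (the reparameterization trick)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_divergence(mu, log_var):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions;
    # added to the reconstruction loss during VAE training
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))

# Toy encoder output: an 8-dimensional latent (hypothetical size)
mu = np.zeros(8)
log_var = np.zeros(8)
z = reparameterize(mu, log_var)
kl = kl_divergence(mu, log_var)  # 0 when the posterior equals the prior
```

In the real model, `mu` and `log_var` come from a 3D-conv encoder over short clips of rendered frames, and the reconstruction loss is computed on the decoded frames.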
## Unity Simulator <img src="https://github.com/littlecobber/CBIL/blob/main/Image/uni.svg" height="30px" align="center">
The Unity version we use is 2020.2.2f1.

### Reward Function
Agents' states, action spaces, and reward functions are all defined in `Simulator\Assets\Scripts\DeepFoids\FishAgent\FishAgentMulti.cs`. For example, the circling reward uses a `bool clockwise` argument to control the direction:
```csharp
private float CalculateCirclingReward(bool clockwise)
{
    Vector3 centerPoint = new Vector3(2.5f, 2.5f, 2.5f);
    Vector3 fishPosition = transform.localPosition;
    Vector3 toCenter = centerPoint - fishPosition;
    Vector3 desiredDirection;
    if (clockwise)
    {
        desiredDirection = Vector3.Cross(Vector3.up, toCenter).normalized;
    }
    else
    {
        desiredDirection = Vector3.Cross(toCenter, Vector3.up).normalized;
    }
    goal_direction = desiredDirection;

    // Reward alignment of the fish's forward vector with the tangential
    // direction, and penalize deviation from the target speed.
    float dot_vel = Vector3.Dot(desiredDirection, transform.forward);
    float target_velocity = 1.0f;
    return high_level_weight * Scale_Reward(
        10.0f * (dot_vel / transform.forward.magnitude)
        - 10.0f * (rBody.velocity.magnitude - target_velocity)
                * (rBody.velocity.magnitude - target_velocity));
}
```
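The cross products above pick out the tangential direction around the tank center: crossing the up axis with the vector to the center yields a direction perpendicular to both, and swapping the operands flips the circling sense. A quick NumPy check of that geometry (not the Unity code itself; the component formula of `Vector3.Cross` matches `np.cross`):

```python
import numpy as np

center = np.array([2.5, 2.5, 2.5])
up = np.array([0.0, 1.0, 0.0])

def desired_direction(fish_pos, clockwise):
    # Mirrors CalculateCirclingReward's choice of tangential direction
    to_center = center - fish_pos
    d = np.cross(up, to_center) if clockwise else np.cross(to_center, up)
    return d / np.linalg.norm(d)

pos = np.array([3.5, 2.5, 2.5])  # fish one unit off-center along x
cw = desired_direction(pos, True)
ccw = desired_direction(pos, False)
```

`cw` is perpendicular to both `up` and the vector to the center (so the fish is steered along the circle, not into it), and `ccw` is its exact opposite.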
### Observation Space

For each trained policy, we predefine the dimension of the observation space, which must match the definition in `FishAgentMulti.cs` in the Simulator:
```csharp
// States needed for the different high-level tasks.
// Circling:
// Goal
sensor.AddObservation(goal_direction);
sensor.AddObservation(new Vector3(2.5f, 2.5f, 2.5f) - transform.localPosition);
sensor.AddObservation(Vector3.Distance(transform.localPosition, new Vector3(2.5f, 2.5f, 2.5f)));
// Fish states
sensor.AddObservation(rBody.velocity);
sensor.AddObservation(transform.localPosition);
sensor.AddObservation(transform.forward);
```
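Tallying the observations above gives the space size that the agent's Behavior Parameters in Unity must be set to (each `Vector3` contributes 3 floats, the distance contributes 1):

```python
# Per-observation sizes for the circling task listed above
obs_sizes = {
    "goal_direction": 3,      # Vector3
    "vector_to_center": 3,    # Vector3
    "distance_to_center": 1,  # float
    "velocity": 3,            # Vector3
    "local_position": 3,      # Vector3
    "forward": 3,             # Vector3
}
space_size = sum(obs_sizes.values())
print(space_size)  # 16
```

If the configured space size and the `AddObservation` calls disagree, ML-Agents raises a shape mismatch at training time, so it is worth recounting after editing the observations.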
### Virtual Camera

For circling, we place a virtual camera named `agentCameraRT` at the bottom of the fish tank (you can manually set a different location). It collects frames at runtime both for pretraining the MVAE and for policy training with imitation learning.
```csharp
void SendImageToPython()
{
    lock (sendImageLock) // thread-safe
    {
        VRcount++;
        // Capture the image from the virtual camera
        RenderTexture renderTexture = agentCamera.targetTexture;
        Texture2D texture2D = new Texture2D(renderTexture.width, renderTexture.height, TextureFormat.RGB24, false);
        RenderTexture.active = renderTexture;
        texture2D.ReadPixels(new Rect(0, 0, renderTexture.width, renderTexture.height), 0, 0);
        texture2D.Apply();

        // Convert the image to bytes
        byte[] imageBytes = texture2D.EncodeToPNG();
        // Clean up the Texture2D after use to free memory
        UnityEngine.Object.Destroy(texture2D);

        string lengthStr = imageBytes.Length.ToString("D10"); // length as a 10-digit string
        byte[] lengthBytes = System.Text.Encoding.UTF8.GetBytes(lengthStr);
        stream.Write(lengthBytes, 0, lengthBytes.Length); // send the length of the image first
        stream.Write(imageBytes, 0, imageBytes.Length);   // then send the actual image bytes

        if (VRcount == 10)
        {
            VRcount = 0;
            // Start receiving rewards from Python
            byte[] buffer = new byte[4]; // the reward is a 4-byte float
            stream.Read(buffer, 0, buffer.Length);
            float probability = System.BitConverter.ToSingle(buffer, 0);
            // Compute the style reward
            float reward = rewardFunction.ComputeReward(probability);
            // Add the reward
            AddReward(style_weight * Scale_Style_Reward(reward));
        }
    }
}
```
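`basic_server.py` implements the receiving side of this exchange. As a minimal sketch of the wire protocol (a 10-digit ASCII length prefix, then the PNG payload, with a 4-byte little-endian float sent back every 10 frames), the helper names below are hypothetical and `io.BytesIO` stands in for the real socket stream:

```python
import io
import struct

def read_frame(stream):
    """Read one length-prefixed image: 10-digit ASCII length, then PNG bytes."""
    length = int(stream.read(10).decode("utf-8"))
    return stream.read(length)

def write_frame(stream, image_bytes):
    """Mirror of the framing in Unity's SendImageToPython."""
    stream.write(f"{len(image_bytes):010d}".encode("utf-8"))
    stream.write(image_bytes)

def pack_reward(probability):
    """4-byte float, matching C#'s BitConverter.ToSingle on little-endian hosts."""
    return struct.pack("<f", probability)

# Round-trip the framing through an in-memory buffer
buf = io.BytesIO()
write_frame(buf, b"\x89PNG-fake-payload")
buf.seek(0)
payload = read_frame(buf)
```

The fixed-width length prefix keeps the protocol trivial to parse on both sides; if you change it, the `"D10"` format string in the C# sender and the `read(10)` in the Python receiver must change together.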
## ML-Agents

See the official ML-Agents documentation.

We use the ML-Agents PPO framework for reinforcement learning.

### Example Training Config
```yaml
behaviors:
  FishAgent:
    trainer_type: ppo
    hyperparameters:
      batch_size: 10 # 10, 1000, 1024
      buffer_size: 100000 # 100, 100000, 10240
      learning_rate: 0.0001 # 0.0003, 0.0001, 0.0006, 0.001
      beta: 0.0005
      epsilon: 0.2 # 0.2, 0.1
      lambd: 0.99
      num_epoch: 3 # 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 128
      num_layers: 4 # 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 50
    checkpoint_interval: 200000
    max_steps: 4000000 # 1000000, 2000000
    time_horizon: 64
    summary_freq: 10000
```
## Python Server <img src="https://github.com/user-attachments/assets/6c050651-8596-4023-bcc0-7d66e26d007e" height="30px" align="center">

- First make sure `ConnectToServer()` in `FishAgentMulti.cs` is NOT commented out
- Run `basic_server.py`
## Training Scripts <img src="https://github.com/user-attachments/assets/91a2d2ac-dbd8-476a-8c06-5099726edd1f" height="30px" align="center">

Before training, you need to:

- specify the fish prefab used for training in `TrainingManager`
- select the number of fish for training
- choose a behavior type (use `Prefab_Ginjake_veryLow_DeepFoids` as the default)
### Fish Parameters

Each fish prefab, for example `Prefab_Ginjake_veryLow_DeepFoids`, has multiple components. Use the default settings for a quick start.
### Training a High-Level Policy

- Simply run `mlagents-learn fishagents.yaml --run-id=circling_demo` to begin training
### Training a Policy with Imitation Learning

- Run the Python server first (make sure `ConnectToServer()` in `FishAgentMulti.cs` is NOT commented out)
- Then run `mlagents-learn fishagents.yaml --run-id=circling_demo` to begin training
- The discriminator is trained at the same time
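During imitation learning, the discriminator's output probability is mapped to a style reward by `rewardFunction.ComputeReward` on the Unity side. The exact mapping is not shown in this README; a common choice in adversarial imitation learning (GAIL-style) is `-log(1 - p)`, which the sketch below uses purely as an illustrative assumption, not CBIL's exact formula:

```python
import math

def style_reward(p, eps=1e-6):
    # GAIL-style mapping (an assumption, not necessarily CBIL's
    # ComputeReward): reward grows as the discriminator probability p
    # that the simulated frames look "real" approaches 1.
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -math.log(1.0 - p)
```

Whatever mapping is used, frames the discriminator judges realistic must earn more reward than frames it judges fake, so the policy is pushed toward the reference video's motion style.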
## Inference

- Load the scene `TagajoSmall_Ocean`
- Select the fish prefab `Prefab_Ginjake_veryLow_DeepFoids` in `TrainingManager`
- Choose the number of fish in `TrainingManager` (`fish count`)
- Load the trained policy and select `inference only`
- Run the simulation
- Optional: set the inference device to CPU or GPU

Before running the simulation, make sure `DataGeneratorManager` has the following preferences:

- [x] MDE Mode
- [x] Disable UI
## Models
## Citation

If our work assists your research, feel free to give us a star ⭐ or cite us using:
```bibtex
@article{10.1145/3687904,
  author = {Wu, Yifan and Dou, Zhiyang and Ishiwaka, Yuko and Ogawa, Shun and Lou, Yuke and Wang, Wenping and Liu, Lingjie and Komura, Taku},
  title = {CBIL: Collective Behavior Imitation Learning for Fish from Real Videos},
  year = {2024},
  issue_date = {December 2024},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  journal = {ACM Trans. Graph.},
  volume = {43},
  number = {6},
  issn = {0730-0301},
  url = {https://doi.org/10.1145/3687904},
  doi = {10.1145/3687904}
}
```
