MetaDE
MetaDE is a GPU-accelerated evolutionary framework that optimizes Differential Evolution (DE) strategies via meta-level evolution. Supporting both JAX and PyTorch, it dynamically adapts mutation and crossover strategies for efficient large-scale black-box optimization.
MetaDE is an advanced evolutionary framework that dynamically optimizes the strategies and hyperparameters of Differential Evolution (DE) through meta-level evolution. By leveraging DE to fine-tune its own configurations, MetaDE adapts mutation and crossover strategies to suit varying problem landscapes in real-time. With GPU acceleration, it handles large-scale, complex black-box optimization problems efficiently, delivering faster convergence and superior performance. MetaDE is compatible with the <a href="https://github.com/EMI-Group/evox">EvoX</a> framework.
New in this version:
MetaDE now fully supports both JAX and PyTorch backends.
- PyTorch backend:
  - Now supports Brax-based RL tasks (requires an additional JAX installation).
  - Significant performance improvements when used with EvoX v1.1.1, with up to a 2x speedup over previous versions.
  - EvoX 1.1.1 with GPU-enabled PyTorch is strongly recommended for optimal performance.
- JAX backend:
  - Remains fully supported; recommended if you prefer CUDA-enabled JAX.
  - Generally runs about 2x faster than the PyTorch backend.

To replicate the experiments from the paper exactly, you may still opt for the JAX backend with a CUDA-enabled JAX (and Brax) installation.
Features
- Meta-level Evolution 🌱: Uses DE at a meta-level to evolve hyperparameters and strategies of DE applied at a problem-solving level.
- Parameterized DE (PDE) 🛠️: A customizable variant of DE that offers dynamic mutation and crossover strategies adaptable to different optimization problems.
- Multi-Backend Support 🔥: Provides both JAX and PyTorch implementations for broader hardware/software compatibility.
- GPU-Accelerated 🚀: Integrated with GPU acceleration on both JAX and PyTorch, enabling efficient large-scale optimizations.
- End-to-End Optimization 🔄: MetaDE provides a seamless workflow from hyperparameter tuning to solving optimization problems in a fully automated process.
- Wide Applicability 🤖: Supports various benchmarks, including CEC2022, and real-world tasks like evolutionary reinforcement learning in robotics.
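The evolver/executor split behind the first feature can be sketched in plain Python. This is a toy illustration, not MetaDE's API: the function names, the (F, CR)-only search space, and the simple truncate-and-perturb meta-update are all invented here for clarity (MetaDE evolves full strategies, not just these two scalars).

```python
import random

def sphere(x):
    """Classic benchmark: f(x) = sum(x_i^2), minimum 0 at the origin."""
    return sum(xi * xi for xi in x)

def inner_de(f, cr, dim=5, pop_size=20, steps=30, seed=0):
    """Problem-level DE/rand/1/bin with fixed (F, CR); returns the best fitness."""
    rng = random.Random(seed)
    pop = [[rng.uniform(-5, 5) for _ in range(dim)] for _ in range(pop_size)]
    fit = [sphere(x) for x in pop]
    for _ in range(steps):
        for i in range(pop_size):
            r1, r2, r3 = rng.sample([j for j in range(pop_size) if j != i], 3)
            donor = [pop[r1][d] + f * (pop[r2][d] - pop[r3][d]) for d in range(dim)]
            j_rand = rng.randrange(dim)
            trial = [donor[d] if (rng.random() < cr or d == j_rand) else pop[i][d]
                     for d in range(dim)]
            t_fit = sphere(trial)
            if t_fit <= fit[i]:  # greedy one-to-one selection
                pop[i], fit[i] = trial, t_fit
    return min(fit)

def meta_evolve(generations=5, meta_pop=8, seed=1):
    """Meta-level search over (F, CR): each candidate's fitness is the outcome
    of a full inner DE run -- the evolver/executor split MetaDE builds on."""
    rng = random.Random(seed)
    cands = [(rng.uniform(0.1, 1.0), rng.uniform(0.1, 1.0)) for _ in range(meta_pop)]
    for _ in range(generations):
        ranked = sorted(cands, key=lambda p: inner_de(*p))
        keep = ranked[: meta_pop // 2]
        # refill by perturbing the survivors (a stand-in for DE's own operators)
        cands = keep + [(min(1.0, max(0.05, f + rng.gauss(0, 0.1))),
                         min(1.0, max(0.05, c + rng.gauss(0, 0.1))))
                        for f, c in keep]
    return min(cands, key=lambda p: inner_de(*p))
```

The key point is that the meta-level fitness of a configuration is defined by running an entire inner optimization with it; MetaDE does the same, but evaluates the whole meta-population of configurations in parallel on the GPU.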
RL Tasks Visualization
Using the MetaDE algorithm to solve RL tasks.
The following animations show the behaviors in Brax environments:
<table width="81%"> <tr> <td width="27%"> <img width="200" height="200" style="display:block; margin:auto;" src="./assets/hopper.gif"> </td> <td width="27%"> <img width="200" height="200" style="display:block; margin:auto;" src="./assets/swimmer.gif"> </td> <td width="27%"> <img width="200" height="200" style="display:block; margin:auto;" src="./assets/reacher.gif"> </td> </tr> <tr> <td align="center"> Hopper </td> <td align="center"> Swimmer </td> <td align="center"> Reacher </td> </tr> </table>

- Hopper: Aiming for maximum speed and jumping height.
- Swimmer: Enhancing movement efficiency in fluid environments.
- Reacher: Moving the fingertip to a random target.
Requirements
Depending on which backend you plan to use (JAX or PyTorch), you should install the proper libraries and GPU dependencies:
- Common:
  - evox (version == 1.1.1 for PyTorch support)
  - brax == 0.10.3 (optional, if you want to run Brax RL problems)
- JAX-based version:
  - jax >= 0.4.16
  - jaxlib >= 0.3.0
- PyTorch-based version:
  - torch (GPU version recommended, e.g. torch >= 2.5.0)
  - torchvision, torchaudio (optional, depending on your environment/needs)
Installation
You can install MetaDE with either the JAX or PyTorch backend (or both).
Below are some example installation steps; please adapt versions as needed:
Option A: Install for PyTorch Only
- Install PyTorch (with CUDA, if you want GPU acceleration). For example, with CUDA 12.4:
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
- Install EvoX == 1.1.1 (for PyTorch support):
  pip install git+https://github.com/EMI-Group/evox.git@v1.1.1
- Install MetaDE:
  pip install git+https://github.com/EMI-Group/metade.git
- Install Brax (optional, if you want to solve Brax RL problems; also requires a JAX installation):
  pip install brax==0.10.3
  pip install -U "jax[cuda12]"
Option B: Install for JAX Only
- Install JAX (we recommend jax >= 0.4.16). For the CPU-only version:
  pip install -U jax
  For NVIDIA GPUs:
  pip install -U "jax[cuda12]"
  For details on installing JAX, see https://github.com/google/jax.
- Install EvoX == 1.1.1:
  pip install git+https://github.com/EMI-Group/evox.git@v1.1.1
- Install MetaDE:
  pip install git+https://github.com/EMI-Group/metade.git
- Install Brax (optional, if you want to solve Brax RL problems):
  pip install brax==0.10.3
Components
Evolver
MetaDE employs Differential Evolution (DE) as the evolver to optimize the parameters of its executor.
- Mutation: DE's mutation strategies evolve based on feedback from the problem landscape.
- Crossover: Different crossover strategies (binomial, exponential, arithmetic) can be used and adapted.

<img src="./assets/evolver.png" alt="Evolver Image" width="90%">
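The three crossover operators named above can be sketched as follows. This is a minimal pure-Python illustration of the standard operator definitions; MetaDE's actual implementations are vectorized over whole populations in JAX/PyTorch.

```python
import random

def binomial(target, donor, cr, rng=random):
    """Each gene comes from the donor with probability CR; one randomly
    chosen gene (j_rand) always comes from the donor."""
    j_rand = rng.randrange(len(target))
    return [donor[j] if (rng.random() < cr or j == j_rand) else target[j]
            for j in range(len(target))]

def exponential(target, donor, cr, rng=random):
    """Copy a contiguous (wrapping) run of donor genes starting at a random
    index; the run stops as soon as a uniform draw exceeds CR."""
    n = len(target)
    trial = list(target)
    start = rng.randrange(n)
    length = 0
    while True:
        trial[(start + length) % n] = donor[(start + length) % n]
        length += 1
        if length >= n or rng.random() >= cr:
            break
    return trial

def arithmetic(target, donor, k=0.5):
    """Blend the two parents: trial = target + k * (donor - target)."""
    return [t + k * (d - t) for t, d in zip(target, donor)]
```

Binomial treats genes independently, exponential preserves contiguous blocks of the donor, and arithmetic produces points on the line between the two parents; which of these helps most depends on the problem landscape, which is exactly what the meta-level evolution selects for.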
Executor
The executor is a Parameterized Differential Evolution (PDE), a variant of DE designed to accommodate various mutation and crossover strategies dynamically.
- Parameterization: Flexible mutation strategies like DE/rand/1/bin or DE/best/2/exp can be selected based on problem characteristics.
- Parallel Execution: Core operations of PDE are optimized for parallel execution on GPUs (via JAX or PyTorch).

<img src="./assets/executor.png" alt="Executor Image" width="90%">
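In the examples below, each PDE configuration is encoded as a 6-dimensional real vector (bounded by param_lb/param_ub) that the meta-level DE evolves. A hypothetical decoder is sketched here to convey the idea; the real mapping is implemented by metade's decoder_de, and the field names and strategy tables below are illustrative assumptions only.

```python
# Hypothetical decoding of the 6-dimensional parameter vectors used in the
# examples. NOT the library's actual encoding -- the strategy tables and
# field names are assumptions for illustration.
BASE_VECTORS = ["rand", "best", "current_to_best", "rand_to_best"]  # assumed
CROSSOVERS = ["binomial", "exponential", "arithmetic"]              # assumed

def decode(vec):
    f, cr, base, base2, n_diff, cx = vec
    return {
        "differential_weight": f,                 # F, continuous in [0, 1]
        "cross_probability": cr,                  # CR, continuous in [0, 1]
        "base_vector": BASE_VECTORS[int(base)],   # real value floored to an index
        "secondary_strategy": int(base2),         # assumed second strategy slot
        "num_difference_vectors": int(n_diff),    # e.g. the "1" in DE/rand/1
        "crossover": CROSSOVERS[int(cx)],
    }
```

Encoding discrete strategy choices as real intervals that get floored to indices is what lets a continuous optimizer like DE search over them.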
GPU Acceleration
MetaDE integrates with the EvoX framework for distributed, GPU-accelerated evolutionary computation, significantly enhancing performance on large-scale optimization tasks.
Examples
Global Optimization Benchmark Functions
```python
import jax
import jax.numpy as jnp
from tqdm import tqdm

from metade.util import StdSOMonitor, StdWorkflow
from metade.algorithms.jax import create_batch_algorithm, decoder_de, MetaDE, ParamDE, DE
from metade.problems.jax.sphere import Sphere

D = 10
BATCH_SIZE = 100
NUM_RUNS = 1
key_start = 42
STEPS = 50
POP_SIZE = BATCH_SIZE
BASE_ALG_POP_SIZE = 100
BASE_ALG_STEPS = 100

tiny_num = 1e-5
param_lb = jnp.array([0, 0, 0, 0, 1, 0])
param_ub = jnp.array([1, 1, 4 - tiny_num, 4 - tiny_num, 5 - tiny_num, 3 - tiny_num])

evolver = DE(
    lb=param_lb,
    ub=param_ub,
    pop_size=POP_SIZE,
    base_vector="rand",
    differential_weight=0.5,
    cross_probability=0.9,
)

BatchDE = create_batch_algorithm(ParamDE, BATCH_SIZE, NUM_RUNS)
batch_de = BatchDE(
    lb=jnp.full((D,), -100),
    ub=jnp.full((D,), 100),
    pop_size=BASE_ALG_POP_SIZE,
)

base_problem = Sphere()
decoder = decoder_de
key = jax.random.PRNGKey(key_start)
monitor = StdSOMonitor(record_fit_history=False)

meta_problem = MetaDE(
    batch_de,
    base_problem,
    batch_size=BATCH_SIZE,
    num_runs=NUM_RUNS,
    base_alg_steps=BASE_ALG_STEPS,
)

workflow = StdWorkflow(
    algorithm=evolver,
    problem=meta_problem,
    pop_transform=decoder,
    monitor=monitor,
    record_pop=True,
)

key, subkey = jax.random.split(key)
state = workflow.init(subkey)

power_up = 0
last_iter = False
for step in tqdm(range(STEPS)):
    state = state.update_child("problem", {"power_up": power_up})
    state = workflow.step(state)
    if step == STEPS - 1:
        power_up = 1
        if last_iter:
            break
        last_iter = True

print(f"Best fitness: {monitor.get_best_fitness()}")
```
If you want to use the PyTorch backend, please refer to the PyTorch examples under examples/pytorch/example.py in this repository.
CEC Benchmark Problems
MetaDE supports several benchmark suites such as CEC2022. Here’s an example (JAX-based) for the CEC2022 test suite:
```python
import jax
import jax.numpy as jnp
from tqdm import tqdm

from metade.util import StdSOMonitor, StdWorkflow
from metade.algorithms.jax import create_batch_algorithm, decoder_de, MetaDE, ParamDE, DE
from metade.problems.jax import CEC2022TestSuit

D = 10
FUNC_LIST = jnp.arange(12) + 1
BATCH_SIZE = 100
NUM_RUNS = 1
key_start = 42
STEPS = 50
POP_SIZE = BATCH_SIZE
BASE_ALG_POP_SIZE = 100
BASE_ALG_STEPS = 100

tiny_num = 1e-5
param_lb = jnp.array([0, 0, 0, 0, 1, 0])
param_ub = jnp.array([1, 1, 4 - tiny_num, 4 - tiny_num, 5 - tiny_num, 3 - tiny_num])

evolver = DE(
    lb=param_lb,
    ub=param_ub,
    pop_size=POP_SIZE,
    base_vector="rand",
    differential_weight=0.5,
    cross_probability=0.9,
)

BatchDE = create_batch_algorithm(ParamDE, BATCH_SIZE, NUM_RUNS)
batch_de = BatchDE(
    lb=jnp.full((D,), -100),
    ub=jnp.full((D,), 100),
    pop_size=BASE_ALG_POP_SIZE,
)

# The per-function loop mirrors the Sphere example above.
for func_num in FUNC_LIST:
    base_problem = CEC2022TestSuit.create(int(func_num))
    decoder = decoder_de
    key = jax.random.PRNGKey(key_start)
    monitor = StdSOMonitor(record_fit_history=False)

    meta_problem = MetaDE(
        batch_de,
        base_problem,
        batch_size=BATCH_SIZE,
        num_runs=NUM_RUNS,
        base_alg_steps=BASE_ALG_STEPS,
    )
    workflow = StdWorkflow(
        algorithm=evolver,
        problem=meta_problem,
        pop_transform=decoder,
        monitor=monitor,
        record_pop=True,
    )

    key, subkey = jax.random.split(key)
    state = workflow.init(subkey)

    power_up = 0
    last_iter = False
    for step in tqdm(range(STEPS)):
        state = state.update_child("problem", {"power_up": power_up})
        state = workflow.step(state)
        if step == STEPS - 1:
            power_up = 1
            if last_iter:
                break
            last_iter = True

    print(f"F{int(func_num)} best fitness: {monitor.get_best_fitness()}")
```