pbt — Population Based Training

Code to replicate Figure 2 of "Population Based Training of Neural Networks" (Jaderberg et al.).

Install / Use: /learn @bkj/PbtREADME
from __future__ import print_function
import numpy as np
import torch
from torch.optim import SGD
from torch.autograd import Variable
from torch.nn import Parameter
from matplotlib import pyplot as plt
# Use large square figures for the plots below.
plt.rcParams['figure.figsize'] = (10, 10)
# Seed numpy and torch so the experiments below are reproducible.
_ = np.random.seed(123)
_ = torch.manual_seed(123)
# Define PBT worker
class Worker():
def __init__(self, theta, h, objective, surrogate_objective, id):
self.theta = theta
self.h = h
self.objective = objective
self.surrogate_objective = surrogate_objective
self.id = id
self._opt = SGD([theta], lr=0.01)
self._history = {"theta" : [], "h" : [], "score" : []}
@property
def history(self):
return {
"theta" : np.vstack(self._history['theta']),
"h" : np.vstack(self._history['h']),
"score" : np.array(self._history['score']),
}
def _log(self):
self._history['theta'].append(self.theta.data.numpy().copy())
self._history['h'].append(self.h.data.numpy().copy())
self._history['score'].append(self.eval())
def step(self):
""" Take an optimization step, given current hyperparemeters and surrogate objective """
self._log()
self._opt.zero_grad()
surrogate_loss = -1 * self.surrogate_objective(self.theta, self.h)
surrogate_loss.backward()
self._opt.step()
def eval(self):
""" Evalute actual objective -- eg measure accuracy on the hold-out set """
return self.objective(self.theta).data[0]
def exploit(self, population):
""" Copy theta from best member of the population """
current_scores = [{
"id": worker.id,
"score": worker.eval()
} for worker in population]
best_worker = sorted(current_scores, key=lambda x: x['score'])[-1]
if best_worker['id'] != self.id:
self.theta.data.set_(population[best_worker['id']].theta.data.clone())
def explore(self, sd=0.1):
""" Add normal noise to hyperparameter vector """
self.h.add_(Variable(torch.randn(2) * sd))
def run_experiment(do_explore=False, do_exploit=False, interval=5, n_steps=200):
    """Train a two-worker population and return it.

    do_explore -- perturb hyperparameters every `interval` steps
    do_exploit -- copy theta from the best worker every `interval` steps
    interval   -- optimization steps between PBT events
    n_steps    -- total optimization steps per worker
    """
    # True objective (maximized) and the differentiable surrogate the
    # workers actually train against.
    objective = lambda theta: 1.2 - (theta ** 2).sum()
    surrogate_objective = lambda theta, h: 1.2 - ((h * theta) ** 2).sum()

    # Both workers start from the same theta but see complementary
    # hyperparameter vectors, as in Figure 2 of the paper.
    initial_hs = [[1.0, 0.0], [0.0, 1.0]]
    population = [
        Worker(
            theta=Parameter(torch.FloatTensor([0.9, 0.9])),
            h=Variable(torch.FloatTensor(h0)),
            objective=objective,
            surrogate_objective=surrogate_objective,
            id=worker_id,
        )
        for worker_id, h0 in enumerate(initial_hs)
    ]

    # Train: every `interval` steps, optionally exploit then explore
    # before taking the next gradient step.
    for step_num in range(n_steps):
        for worker in population:
            if (step_num + 1) % interval == 0:
                if do_exploit:
                    worker.exploit(population)
                if do_explore:
                    worker.explore()
            worker.step()

    return population
# Run experiments w/ various PBT settings.
# Each call returns the trained 2-worker population; histories are plotted below.
pbt = run_experiment(do_explore=True, do_exploit=True) # Explore and exploit
explore = run_experiment(do_explore=True, do_exploit=False) # Explore only
exploit = run_experiment(do_explore=False, do_exploit=True) # Exploit only
grid = run_experiment(do_explore=False, do_exploit=False) # Independent training runs -- eg, regular grid search
def plot_score(ax, workers, run_name):
    """Plot each worker's true-objective trajectory on the given axes."""
    for w in workers:
        ax.plot(w.history['score'],
                label="%s worker %d" % (run_name, w.id), alpha=0.5)
    ax.set_title(run_name)
    ax.set_ylim(-1, 1.3)
    # Horizontal reference line at the optimum of the true objective.
    ax.axhline(1.2, c='lightgrey')
# Figure: objective value over time for each experimental condition.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
plot_score(ax1, pbt, 'pbt')
plot_score(ax2, explore, 'explore')
plot_score(ax3, exploit, 'exploit')
plot_score(ax4, grid, 'grid')
_ = plt.tight_layout(pad=1)
plt.show()

def plot_theta(ax, workers, run_name):
    """Scatter each worker's theta trajectory in parameter space."""
    for w in workers:
        trajectory = w.history['theta']
        ax.scatter(trajectory[:, 0], trajectory[:, 1],
                   s=2, alpha=0.5, label="%s worker %d" % (run_name, w.id))
    ax.set_title(run_name)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
# Figure: theta trajectories in parameter space for each condition
# (this reproduces Figure 2 of the paper).
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
plot_theta(ax1, pbt, 'pbt')
plot_theta(ax2, explore, 'explore')
plot_theta(ax3, exploit, 'exploit')
plot_theta(ax4, grid, 'grid')
_ = plt.tight_layout(pad=1)
plt.show()

Related Skills

YC-Killer (2.7k stars) — A library of enterprise-grade AI agents designed to democratize artificial intelligence and provide free, open-source alternatives to overvalued Y Combinator startups. If you are excited about democratizing AI access & AI agents, please star ⭐️ this repository and use the link in the readme to join our open source AI research team.

groundhog (399 stars) — Groundhog's primary purpose is to teach people how Cursor and all these other coding agents work under the hood. If you understand how these coding assistants work from first principles, then you can drive these tools harder (or perhaps make your own!).

last30days-skill (18.8k stars) — AI agent skill that researches any topic across Reddit, X, YouTube, HN, Polymarket, and the web — then synthesizes a grounded summary.

sec-edgar-agentkit (10 stars) — AI agent toolkit for accessing and analyzing SEC EDGAR filing data. Build intelligent agents with LangChain, MCP-use, Gradio, Dify, and smolagents to analyze financial statements, insider trading, and company filings.
