# SigmaZero
Generalizing DeepMind's MuZero algorithm to stochastic environments.
This is a repo where I generalize DeepMind's MuZero reinforcement learning algorithm to stochastic environments, creating an algorithm I call SigmaZero (stochastic MuZero). For more details on the MuZero algorithm, check out the original paper and my project applying MuZero to the cartpole environment.
## Table of Contents
- MuZero
- Monte Carlo Tree Search in Stochastic Environments
- SigmaZero
- Environment
- Experiments
- Discussions
- Future Work
- File Descriptions
## MuZero
### Functions
MuZero contains 3 functions approximated by neural networks, to be learned from the environment:
- A representation function, $h(o_t) \rightarrow s^0$, which given an observation $o_t$ from the environment at time step $t$, outputs the hidden state representation $s^0$ of the observation at hypothetical time step $0$ (this hidden state will be used as the root node in MCTS, so its hypothetical time step is zero)
- The representation function is used in tandem with the dynamics function to represent the environment's state in whatever way the algorithm finds useful in order to make accurate predictions for the reward, value and policy
- A dynamics function, $g(s^k,a^{k+1}) \rightarrow s^{k+1},r^{k+1}$, which given a hidden state representation $s^k$ at hypothetical time step $k$ and action $a^{k+1}$ at hypothetical time step $k+1$, outputs the predicted resulting hidden state representation $s^{k+1}$ and transition reward $r^{k+1}$ at hypothetical time step $k+1$
- The dynamics function is the learned transition model, which allows MuZero to utilize MCTS and plan hypothetical future actions on future states
- A prediction function, $f(s^k) \rightarrow p^k,v^k$, which given a hidden state representation $s^k$, outputs the predicted policy distribution over actions $p^k$ and value $v^k$ at hypothetical time step $k$
- The prediction function is used to limit the search breadth by using the policy output to prioritize MCTS to search for more promising actions, and limit the search depth by using the value output as a substitute for a Monte Carlo rollout
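As a rough sketch, the three functions have the following shapes. This is a minimal illustration with randomly initialized linear maps standing in for trained networks; all dimensions and names here are assumptions for illustration, not the repo's actual implementation:

```python
import numpy as np

OBS_DIM, HIDDEN_DIM, NUM_ACTIONS = 4, 8, 2

rng = np.random.default_rng(0)
# Stand-ins for trained neural networks: random linear maps.
W_h = rng.normal(size=(OBS_DIM, HIDDEN_DIM))
W_g = rng.normal(size=(HIDDEN_DIM + NUM_ACTIONS, HIDDEN_DIM + 1))
W_f = rng.normal(size=(HIDDEN_DIM, NUM_ACTIONS + 1))

def representation(obs):
    """h(o_t) -> s^0: embed an observation into a hidden state."""
    return obs @ W_h

def dynamics(state, action):
    """g(s^k, a^{k+1}) -> (s^{k+1}, r^{k+1}): predict next hidden state and reward."""
    one_hot = np.eye(NUM_ACTIONS)[action]
    out = np.concatenate([state, one_hot]) @ W_g
    return out[:-1], out[-1]          # next hidden state, scalar transition reward

def prediction(state):
    """f(s^k) -> (p^k, v^k): predict a policy distribution and a value."""
    out = state @ W_f
    logits, value = out[:-1], out[-1]
    policy = np.exp(logits - logits.max())   # softmax over action logits
    return policy / policy.sum(), value

# one hypothetical unroll step: h, then g, then f
s0 = representation(np.ones(OBS_DIM))
s1, r1 = dynamics(s0, action=0)
p1, v1 = prediction(s1)
```

In the real algorithm each of these maps is a trained network, and the three are unrolled jointly during both search and training.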
### Algorithm Overview
The MuZero algorithm can be summarized as follows:
- loop for a number of episodes:
  - at every time step $t$ of the episode:
    - perform Monte Carlo tree search (MCTS):
      - pass the current observation of the environment $o_t$ to the representation function, $h(o_t) \rightarrow s^0$, and get the hidden state representation $s^0$ from the output
      - pass the hidden state representation $s^0$ into the prediction function, $f(s^0) \rightarrow p^0,v^0$, and get the predicted policy distribution over actions $p^0$ and value $v^0$ from the output
      - for a number of simulations:
        - select a leaf node by maximizing the UCB score
        - expand the node by passing the hidden state representation of its parent node $s^k$ and its corresponding action $a^{k+1}$ into the dynamics function, $g(s^k,a^{k+1}) \rightarrow s^{k+1},r^{k+1}$, and get the predicted resulting hidden state representation $s^{k+1}$ and transition reward $r^{k+1}$ from the output
        - pass the resulting hidden state representation $s^{k+1}$ into the prediction function, $f(s^{k+1}) \rightarrow p^{k+1},v^{k+1}$, and get the predicted policy distribution over actions $p^{k+1}$ and value $v^{k+1}$ from the output
        - backpropagate the predicted value $v^{k+1}$ up the search path
    - sample an action based on the visit count of each child node of the root node
    - apply the sampled action to the environment and observe the resulting transition reward
  - once the episode is over, save the game trajectory (including the MCTS results) to the replay buffer
  - sample a number of game trajectories from the replay buffer:
    - pass the first observation of the environment $o_0$ from the game trajectory to the representation function, $h(o_0) \rightarrow s^0$, and get the hidden state representation $s^0$ from the output
    - pass the hidden state representation $s^0$ into the prediction function, $f(s^0) \rightarrow p^0,v^0$, and get the predicted policy distribution over actions $p^0$ and value $v^0$ from the output
    - for every time step $t$ in the game trajectory:
      - pass the current hidden state representation $s^t$ and the corresponding action $a^{t+1}$ into the dynamics function, $g(s^t,a^{t+1}) \rightarrow s^{t+1},r^{t+1}$, and get the predicted resulting hidden state representation $s^{t+1}$ and transition reward $r^{t+1}$ from the output
        - this predicted transition reward $r^{t+1}$ is matched to the actual transition reward target received from the environment
      - pass the resulting hidden state representation $s^{t+1}$ into the prediction function, $f(s^{t+1}) \rightarrow p^{t+1},v^{t+1}$, and get the predicted policy distribution over actions $p^{t+1}$ and value $v^{t+1}$ from the output
        - this predicted policy distribution $p^{t+1}$ is matched to the child node visit count distribution outputted by MCTS at that time step in the game trajectory
        - this predicted value $v^{t+1}$ is matched to the value outputted by MCTS at that time step in the game trajectory
    - update the weights of the representation, dynamics and prediction functions based on these three targets
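The outer loop above can be sketched as follows. `ToyEnv`, `run_mcts`, and the trajectory format are hypothetical placeholders; a real implementation would expand a search tree with the dynamics and prediction functions and actually train on the sampled trajectories:

```python
import random

# Hypothetical stand-ins for the real components (all names are assumptions):
# a 5-step toy environment with 2 actions, and a "search" that returns
# placeholder visit counts instead of running real MCTS.
class ToyEnv:
    def reset(self):
        self.t = 0
        return [0.0]                       # initial observation

    def step(self, action):
        self.t += 1
        observation = [float(self.t)]
        reward, done = 1.0, self.t >= 5
        return observation, reward, done

def run_mcts(observation, num_simulations=8):
    # A real implementation would expand nodes with the dynamics function,
    # evaluate them with the prediction function, and backpropagate values;
    # here we just return uniform visit counts and a zero root value.
    visit_counts = [num_simulations // 2, num_simulations // 2]
    return visit_counts, 0.0

replay_buffer = []
for episode in range(3):
    env = ToyEnv()
    obs, done, trajectory = env.reset(), False, []
    while not done:
        visit_counts, root_value = run_mcts(obs)
        # sample an action proportional to the root's child visit counts
        action = random.choices(range(len(visit_counts)), weights=visit_counts)[0]
        obs_next, reward, done = env.step(action)
        trajectory.append((obs, action, reward, visit_counts, root_value))
        obs = obs_next
    replay_buffer.append(trajectory)       # save the full game trajectory

# sample game trajectories from the replay buffer for a training step
batch = random.sample(replay_buffer, 2)
```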
## Monte Carlo Tree Search in Stochastic Environments
MCTS requires a model of the environment when expanding leaf nodes during its search. The environment model takes in a state and action and outputs the resulting state and transition reward; this is the functional definition of the dynamics function, $g(s^k,a^{k+1}) \rightarrow s^{k+1},r^{k+1}$, which approximates the true environment model. This works for deterministic environments where there is a single outcome for any action applied to any state.
In stochastic environments, the functional definition of the environment model changes. Given a state and action, the environment model instead outputs a set of possible resulting states, transition rewards and the corresponding probabilities of those outcomes occurring. To approximate this environment model, we can re-define the dynamics function as: $g(s^k,a^{k+1}) \rightarrow [s^{k+1}_1,...,s^{k+1}_b],[r^{k+1}_1,...,r^{k+1}_b],[\pi^{k+1}_1,...,\pi^{k+1}_b]$, where $\pi^{k+1}_i$ is the predicted probability that applying action $a^{k+1}$ to state $s^k$ results in the predicted state $s^{k+1}_i$ with transition reward $r^{k+1}_i$.
Given a current state $s$ and action $a$, a perfect environment model would output a corresponding probability for every possible transition sequence $s,a \rightarrow s^{'},r$, where $s^{'}$ is the resulting state and $r$ is the resulting transition reward. To approximate this with the dynamics function, we would need to define the function to output a number of predicted transitions $(s^{k+1}_i,r^{k+1}_i,\pi^{k+1}_i)$ equal to all possible transitions of the environment. This requires additional knowledge of the environment's state space, reward space and transition dynamics.
Instead we define a stochastic branching factor hyperparameter $b$ which sets and limits the number of predicted transitions the dynamics function can output. MCTS can then use this modified dynamics function to expand nodes and account for stochastic outcomes.
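As a sketch, a dynamics function with stochastic branching factor $b$ might look like this. A random linear map stands in for the trained network; all shapes and names are illustrative assumptions:

```python
import numpy as np

HIDDEN_DIM, NUM_ACTIONS, BRANCH = 8, 3, 2   # BRANCH is the hyperparameter b

rng = np.random.default_rng(0)
# Stand-in for a trained network: one random linear map whose output is
# split into b next states, b rewards, and b outcome logits.
OUT_DIM = BRANCH * (HIDDEN_DIM + 1) + BRANCH
W = rng.normal(size=(HIDDEN_DIM + NUM_ACTIONS, OUT_DIM))

def stochastic_dynamics(state, action):
    """g(s^k, a^{k+1}) -> b predicted (state, reward) outcomes plus probabilities."""
    one_hot = np.eye(NUM_ACTIONS)[action]
    out = np.concatenate([state, one_hot]) @ W
    states = out[:BRANCH * HIDDEN_DIM].reshape(BRANCH, HIDDEN_DIM)
    rewards = out[BRANCH * HIDDEN_DIM:BRANCH * HIDDEN_DIM + BRANCH]
    logits = out[-BRANCH:]
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                    # pi^{k+1}_1 ... pi^{k+1}_b
    return states, rewards, probs

states, rewards, probs = stochastic_dynamics(np.ones(HIDDEN_DIM), action=1)
```

The only structural change from the deterministic dynamics function is that the output head is widened to $b$ parallel (state, reward) predictions plus a softmax over the $b$ outcomes.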
Below is an animation illustrating how MCTS is executed in stochastic environments (you can find the individual pictures in assets/sigmazero_graph_images/ for closer inspection). In this environment, the action space is size 3 and the stochastic branching factor hyperparameter is set to 2.

Instead of a single node representing a state, we have a node set representing the possible stochastic states that result from applying a sequence of actions starting from the state at the root node. For example, $[s^2_1,s^2_2,s^2_3,s^2_4]$ are the predicted possible states resulting from applying action $a_3$ to the state at the root node $s^0_1$, followed by applying action $a_1$ to the resulting stochastic states $[s^1_1,s^1_2]$. Thus, $[r^2_1,r^2_2,r^2_3,r^2_4]$ are the predicted transition rewards, $[v^2_1,v^2_2,v^2_3,v^2_4]$ are the predicted values, and $[p^2_1,p^2_2,p^2_3,p^2_4]$ are the predicted policy distributions over actions for each corresponding stochastic state.
Just like the nodes used for MCTS in deterministic environments, node sets contain predicted reward, value and policy distribution attributes. But whereas a node's predicted reward is obtained directly from the dynamics function output and its predicted value and policy distribution directly from the prediction function output, a node set's predicted reward, value and policy distribution are calculated as expectations over its member nodes.
More formally, to obtain the predicted reward $r_N$, value $v_N$ and policy distribution $p_N$ for a node set $N$, we calculate the following:
$r_N = \sum_{i \in N} \pi_i r_i$
$v_N = \sum_{i \in N} \pi_i v_i$
$p_N = [ \sum_{i \in N} \pi_i p_{i1} , ... , \sum_{i \in N} \pi_i p_{ia} ]$
where $\pi_i$ is the probability of transitioning to the hidden state of node $i$ in node set $N$ after applying the corresponding actions to the state of the root node, $r_i$ is the corresponding transition reward of node $i$ (obtained from the dynamics function output), $v_i$ is the corresponding value of node $i$ (obtained from the prediction function output), and $p_{ia}$ is the corresponding predicted probability of action $a$ in the policy distribution of node $i$ (obtained from the prediction function output).
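The three expectations above reduce to probability-weighted sums, e.g. with NumPy (toy numbers, purely illustrative):

```python
import numpy as np

# Per-node quantities for a node set N with 4 stochastic nodes and 3 actions:
pi = np.array([0.4, 0.3, 0.2, 0.1])        # transition probabilities pi_i
r  = np.array([1.0, 0.5, -0.5, 0.0])       # per-node transition rewards r_i
v  = np.array([2.0, 1.0, 0.0, -1.0])       # per-node values v_i
p  = np.array([[0.6, 0.3, 0.1],            # per-node policy distributions p_i
               [0.2, 0.5, 0.3],
               [0.1, 0.1, 0.8],
               [0.3, 0.4, 0.3]])

r_N = pi @ r    # expected reward of the node set
v_N = pi @ v    # expected value of the node set
p_N = pi @ p    # expected policy distribution (one entry per action)
```

Since the $\pi_i$ sum to 1 and each per-node policy sums to 1, the resulting $p_N$ is itself a valid probability distribution over actions.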