rl4burn
Reinforcement learning algorithms for the Burn ML framework.
rl4burn lets you write RL algorithms once with B: AutodiffBackend and run them on any Burn backend — WGPU, CUDA, NdArray, or LibTorch. No per-backend reimplementation.
What’s included
Algorithms
| Algorithm | Type | Status |
|---|---|---|
| PPO | On-policy, actor-critic | Solves CartPole in <1s |
| Dual-Clip PPO | PPO for distributed training | JueWu/HoK-style |
| DQN | Off-policy, value-based | Solves CartPole in ~9s |
| Behavioral Cloning | Supervised imitation | Cross-entropy on demonstrations |
| Policy Distillation | Teacher-student transfer | Temperature-scaled KL |
Neural Network Modules
LSTM, GRU, and block-diagonal GRU cells. Transformer encoder blocks. Multi-head attention, target attention, attention pooling, and pointer networks. FiLM conditioning. Auto-regressive action distributions.
World Models (DreamerV3)
RSSM world model with imagination rollouts. Symlog/twohot distributional encoding. KL balancing with free bits. Sequence replay buffer. Percentile return normalization.
Game AI Infrastructure
Self-play with opponent pools. League training with agent roles (AlphaStar-style). PFSP matchmaking. Multi-agent shared-weight training. Privileged critic. Goal-conditioned RL. Agent branching. MCTS for drafting. Beta-VAE opponent modeling. Curriculum self-play learning (CSPL).
Building Blocks
GAE, V-trace, UPGO, replay buffers, multi-head value decomposition, intrinsic rewards, polyak updates, loss functions, orthogonal initialization, global gradient clipping, and logging.
Workspace architecture
rl4burn is organized as a Cargo workspace of five focused crates (rl4burn-core, rl4burn-nn, rl4burn-collect, rl4burn-algo, rl4burn-envs) plus an umbrella rl4burn crate that re-exports the full API. Users depend only on rl4burn — no need to manage individual crate dependencies. See the Architecture chapter for details.
Cookbook
The repository includes 15 runnable examples organized into five tiers:
- Fundamentals — quickstart, annotated PPO, config-driven training
- Environment Variations — custom environments, continuous actions, multi-discrete actions
- Techniques — action masking, reward shaping, LSTM policies
- Multi-Agent & Game AI — self-play, multi-agent, curriculum learning
- Production — diagnostics, hyperparameter tuning, policy deployment
Run any example with cargo run -p <name> --release. See the Cookbook for the full list and a decision guide for choosing the right algorithm.
Why Burn?
Burn’s Backend trait lets you write generic code:
fn train<B: AutodiffBackend>(model: MyModel<B>, device: &B::Device) {
// This works on WGPU, CUDA, NdArray, LibTorch...
}
For RL, this means:
- Train on GPU with Autodiff<Wgpu> or Autodiff<LibTorch>
- Deploy on edge with NdArray (no GPU needed, no_std capable)
- Run in the browser with WASM via the WGPU backend
No other Rust RL library achieves this level of backend portability.
Design philosophy
- You own the training loop. ppo_collect and ppo_update are functions, not a framework. Compose them however you want.
- Minimal API surface. Each algorithm is ~200 lines. Read the source; it’s meant to be understood.
- Correctness first. Integration tests train both PPO and DQN on CartPole to convergence. Contract annotations enforce preconditions on core functions.
- Match reference implementations. PPO defaults match CleanRL’s ppo.py. When Burn’s behavior differs from PyTorch (gradient clipping, parameter initialization), we provide compatible alternatives.
Architecture
rl4burn is organized as a Cargo workspace with five focused crates and one umbrella crate that re-exports everything.
Workspace layout
crates/
rl4burn-core — Env trait, spaces, SyncVecEnv, wrappers, Logger
rl4burn-nn — Neural network utilities (LSTM, GRU, attention, FiLM, policy traits, init)
rl4burn-collect — GAE, V-trace, UPGO, replay buffers, collection patterns
rl4burn-algo — PPO, DQN, AC, imitation, multi-agent, planning, losses
rl4burn-envs — CartPole, Pendulum, GridWorld
rl4burn/ — Umbrella crate re-exporting everything
examples/ — 15 runnable cookbook examples (see Cookbook)
Dependency DAG
The crates form a clean dependency hierarchy:
rl4burn-core (no internal deps)
|
+--- rl4burn-nn (depends on core)
| |
| +--- rl4burn-collect (depends on core, nn)
| | |
| | +--- rl4burn-algo (depends on core, nn, collect)
| |
+--- rl4burn-envs (depends on core)
Each crate has a single responsibility:
- rl4burn-core defines the foundational abstractions: the Env trait, observation/action spaces, vectorized environments (SyncVecEnv), environment wrappers, and the logging system.
- rl4burn-nn provides neural network building blocks: RNN cells (LSTM, GRU, block-diagonal GRU), transformer encoders, attention mechanisms, FiLM conditioning, policy traits (DiscreteActorCritic, MaskedActorCritic, QNetwork), orthogonal initialization, gradient clipping, and polyak updates.
- rl4burn-collect handles data collection: GAE, V-trace, UPGO, replay buffers, sequence replay, intrinsic rewards, percentile normalization, and distributed collection patterns (actor-learner, centralized inference, trajectory queues).
- rl4burn-algo contains the algorithms: PPO, DQN, actor-critic with V-trace, behavioral cloning, policy distillation, multi-agent infrastructure (self-play, league training, PFSP), planning (MCTS, imagination rollouts), and loss functions.
- rl4burn-envs provides built-in environments for testing and examples: CartPole, Pendulum, and GridWorld.
The umbrella crate
Most users should depend only on rl4burn in their Cargo.toml. The umbrella crate re-exports the full public API at the top level:
// All of these work — no intermediate module paths needed:
use rl4burn::SyncVecEnv;
use rl4burn::PpoConfig;
use rl4burn::{ppo_collect, ppo_update};
use rl4burn::{DiscreteActorCritic, DiscreteAcOutput};
use rl4burn::{Logger, PrintLogger};
use rl4burn::ReplayBuffer;
If you need access to the sub-crate modules directly, they are also available:
use rl4burn::core; // rl4burn_core
use rl4burn::nn; // rl4burn_nn
use rl4burn::collect; // rl4burn_collect
use rl4burn::algo; // rl4burn_algo
use rl4burn::envs; // rl4burn_envs
When to depend on individual crates
For most projects, the umbrella rl4burn crate is the right choice. You might depend on individual crates if:
- You only need environments (rl4burn-core + rl4burn-envs) and want minimal compile times.
- You are building a custom algorithm and only need the collection primitives (rl4burn-collect).
- You are writing a library that extends rl4burn and want to minimize your dependency footprint.
Installation
Add rl4burn and Burn to your Cargo.toml:
[dependencies]
rl4burn = { git = "https://github.com/RPP1011/rl4burn" }
burn = { version = "0.20", features = ["std", "ndarray", "autodiff"] }
rand = "0.10"
rl4burn is a workspace of focused crates (rl4burn-core, rl4burn-nn, rl4burn-collect, rl4burn-algo, rl4burn-envs), but the umbrella rl4burn crate re-exports everything so you only need one dependency.
The ndarray feature gives you a CPU backend for development and testing. For GPU training, add wgpu or tch (LibTorch) instead.
Verify the install
Create a src/main.rs:
use rl4burn::envs::CartPole;
use rl4burn::env::Env;
use rand::SeedableRng;
fn main() {
let mut env = CartPole::new(rand::rngs::SmallRng::seed_from_u64(42));
let obs = env.reset();
println!("CartPole observation: {:?}", obs);
let step = env.step(1); // push right
println!("Reward: {}, Done: {}", step.reward, step.done());
}
cargo run
You should see a 4-element observation vector and a reward of 1.0.
Your First Agent: PPO on CartPole
This walkthrough trains a PPO agent to balance a pole on a cart. By the end, you’ll understand the three pieces every rl4burn training script needs: a model, environments, and a training loop.
The model
PPO needs an actor-critic model: given an observation, produce action logits and a value estimate. Define a Burn module and implement DiscreteActorCritic:
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
use rl4burn::{DiscreteAcOutput, DiscreteActorCritic};
#[derive(Module, Debug)]
struct ActorCritic<B: Backend> {
// Separate actor and critic networks (no shared layers)
actor_fc1: Linear<B>,
actor_fc2: Linear<B>,
actor_out: Linear<B>,
critic_fc1: Linear<B>,
critic_fc2: Linear<B>,
critic_out: Linear<B>,
}
impl<B: Backend> ActorCritic<B> {
fn new(device: &B::Device) -> Self {
Self {
actor_fc1: LinearConfig::new(4, 64).init(device),
actor_fc2: LinearConfig::new(64, 64).init(device),
actor_out: LinearConfig::new(64, 2).init(device),
critic_fc1: LinearConfig::new(4, 64).init(device),
critic_fc2: LinearConfig::new(64, 64).init(device),
critic_out: LinearConfig::new(64, 1).init(device),
}
}
}
impl<B: Backend> DiscreteActorCritic<B> for ActorCritic<B> {
fn forward(&self, obs: Tensor<B, 2>) -> DiscreteAcOutput<B> {
let a = self.actor_fc1.forward(obs.clone()).tanh();
let a = self.actor_fc2.forward(a).tanh();
let logits = self.actor_out.forward(a);
let c = self.critic_fc1.forward(obs).tanh();
let c = self.critic_fc2.forward(c).tanh();
let values = self.critic_out.forward(c).squeeze_dim::<1>(1);
DiscreteAcOutput { logits, values }
}
}
Key points:
- #[derive(Module)] gives you parameter management, serialization, and device transfer for free.
- DiscreteAcOutput holds logits: Tensor<B, 2> (shape [batch, n_actions]) and values: Tensor<B, 1> (shape [batch]).
- The model is generic over B: Backend. The same struct works on any Burn backend.
The environments
CartPole is built in. Wrap it in SyncVecEnv to run multiple copies in parallel:
use rl4burn::envs::CartPole;
use rl4burn::SyncVecEnv;
use rand::SeedableRng;
let n_envs = 4;
let envs: Vec<CartPole<rand::rngs::SmallRng>> = (0..n_envs)
.map(|i| CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64)))
.collect();
let mut vec_env = SyncVecEnv::new(envs);
SyncVecEnv steps all environments in lockstep and auto-resets when episodes end.
The training loop
PPO training alternates between two phases: collect a rollout of experience, then update the model on that experience.
use burn::backend::{Autodiff, NdArray};
use burn::module::AutodiffModule;
use burn::optim::AdamConfig;
use rl4burn::{ppo_collect, ppo_update, PpoConfig};
use rl4burn::{Loggable, Logger, PrintLogger};
type AB = Autodiff<NdArray>;
let device = burn::backend::ndarray::NdArrayDevice::Cpu;
let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
let mut model: ActorCritic<AB> = ActorCritic::new(&device);
let mut optim = AdamConfig::new().with_epsilon(1e-5).init();
let config = PpoConfig::default();
let mut logger = PrintLogger::new(0);
// Episode return accumulator — persists across rollouts
let mut ep_acc = vec![0.0f32; n_envs];
// Current observations — persists across rollouts
let mut current_obs = vec_env.reset();
for iter in 0..100 {
// Collect: use the non-autodiff model for inference
let rollout = ppo_collect::<NdArray, _, _>(
&model.valid(),
&mut vec_env,
&config,
&device,
&mut rng,
&mut current_obs,
&mut ep_acc,
);
// Update: train on the collected data
let (new_model, stats) = ppo_update(
model, &mut optim, &rollout, &config,
config.lr, // or use LR annealing
&device, &mut rng,
);
model = new_model;
// Log training stats
let step = (iter + 1) as u64 * (config.n_steps * n_envs) as u64;
stats.log(&mut logger, step);
if !rollout.episode_returns.is_empty() {
let avg = rollout.episode_returns.iter().sum::<f32>()
/ rollout.episode_returns.len() as f32;
logger.log_scalar("rollout/avg_return", avg as f64, step);
}
}
logger.flush();
Key points:
- model.valid() strips the autodiff layer for efficient inference during collection.
- current_obs holds the latest observations from the environments. It persists across rollout boundaries, so the next collection starts where the last one left off.
- ep_acc tracks per-env cumulative reward across rollout boundaries. Without it, any episode that crosses a rollout boundary would have its return split in two.
- ppo_update returns the updated model (Burn modules are moved through optimizers, not mutated in place).
- stats.log(...) uses the Loggable trait to log all PPO metrics. See the Logging chapter for details on logger setup.
Run it
cargo run -p quickstart --release
You should see episode returns climb from ~20 (random policy) to 500 (solved) within seconds.
Cookbook
rl4burn ships with 15 runnable examples in the examples/ directory, organized into five tiers of increasing complexity. Each example is a standalone Cargo package that you can run with cargo run -p <name> --release.
Tier 1: Fundamentals
| Example | Command | Description |
|---|---|---|
| quickstart | cargo run -p quickstart --release | Minimal PPO on CartPole — the “hello world” of RL |
| ppo-annotated | cargo run -p ppo-annotated --release | Same as quickstart but with detailed comments explaining every line |
| config-driven | cargo run -p config-driven --release | Load hyperparameters from a TOML file instead of hardcoding them |
Tier 2: Environment Variations
| Example | Command | Description |
|---|---|---|
| custom-env | cargo run -p custom-env --release | Implement the Env trait for your own environment |
| ppo-continuous | cargo run -p ppo-continuous --release | PPO with continuous actions on Pendulum |
| ppo-multi-discrete | cargo run -p ppo-multi-discrete --release | PPO with multi-discrete action spaces |
Tier 3: Techniques
| Example | Command | Description |
|---|---|---|
| action-masking | cargo run -p action-masking --release | Invalid action masking with the masked PPO pipeline |
| reward-shaping | cargo run -p reward-shaping --release | Intrinsic rewards and reward shaping wrappers |
| lstm-policy | cargo run -p lstm-policy --release | Recurrent policy for partially observable environments |
Tier 4: Multi-Agent & Game AI
| Example | Command | Description |
|---|---|---|
| self-play | cargo run -p self-play --release | Self-play training with an opponent pool |
| multi-agent | cargo run -p multi-agent --release | Shared-weight multi-agent training |
| curriculum | cargo run -p curriculum --release | Curriculum self-play learning (CSPL) |
Tier 5: Production
| Example | Command | Description |
|---|---|---|
| diagnostics | cargo run -p diagnostics --release | TensorBoard logging, video recording, and training diagnostics |
| hyperparameter-tuning | cargo run -p hyperparameter-tuning --release | Systematic hyperparameter sweeps |
| deploy-policy | cargo run -p deploy-policy --release | Export a trained policy for inference on a different backend |
Which algorithm should I use?
Use this decision guide to pick the right starting point:
| Scenario | Recommended algorithm | Start from example |
|---|---|---|
| Discrete actions (e.g., CartPole, Atari) | PPO or DQN | quickstart |
| Continuous actions (e.g., Pendulum, MuJoCo) | PPO with Gaussian policy | ppo-continuous |
| Multi-discrete actions (e.g., RTS games) | PPO with multi-head | ppo-multi-discrete |
| Invalid actions vary per step | Masked PPO | action-masking |
| Competitive game (1v1 or teams) | Self-play PPO | self-play |
| Partial observability | LSTM policy + PPO | lstm-policy |
| Multiple cooperating agents | Shared-weight PPO | multi-agent |
| Large observation space / model-based | DreamerV3 (future) | — |
When in doubt, start with PPO (quickstart). It is the most versatile algorithm and works well across a wide range of problems. Switch to DQN only if you need off-policy learning or have a small discrete action space where sample efficiency matters.
Environments
The Env trait defines how an RL agent interacts with the world. It follows modern Gymnasium conventions.
The Env trait
pub trait Env {
type Observation: Clone;
type Action: Clone;
fn reset(&mut self) -> Self::Observation;
fn step(&mut self, action: Self::Action) -> Step<Self::Observation>;
fn observation_space(&self) -> Space;
fn action_space(&self) -> Space;
}
step returns a Step struct with separate terminated and truncated flags:
pub struct Step<O> {
pub observation: O,
pub reward: f32,
pub terminated: bool, // episode ended due to environment dynamics
pub truncated: bool, // episode ended due to time limit
}
The done() method returns terminated || truncated.
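The two flags are kept separate because they affect value targets differently. A time-limit cutoff does not make the next state worthless, so a one-step target should still bootstrap from the next state's value. A true termination should not bootstrap. The sketch below shows the distinction using a local copy of Step and a hypothetical td_target helper. It is not rl4burn code, and it does not describe how rl4burn's own GAE treats truncation.

```rust
/// Local copy of the `Step` struct described above (illustration only).
pub struct Step<O> {
    pub observation: O,
    pub reward: f32,
    pub terminated: bool,
    pub truncated: bool,
}

impl<O> Step<O> {
    /// An episode is over if it terminated or was truncated.
    pub fn done(&self) -> bool {
        self.terminated || self.truncated
    }
}

/// Hypothetical one-step TD target. Bootstraps from `next_value` unless the
/// episode truly terminated. Truncation (a time limit) still bootstraps,
/// because the state at the cutoff would have had future reward.
pub fn td_target<O>(step: &Step<O>, next_value: f32, gamma: f32) -> f32 {
    if step.terminated {
        step.reward
    } else {
        step.reward + gamma * next_value
    }
}
```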
Implementing a custom environment
use rl4burn::env::{Env, Step};
use rl4burn::space::Space;
struct MyEnv {
state: f32,
step_count: usize,
}
impl Env for MyEnv {
type Observation = Vec<f32>;
type Action = usize;
fn reset(&mut self) -> Vec<f32> {
self.state = 0.0;
self.step_count = 0;
vec![self.state]
}
fn step(&mut self, action: usize) -> Step<Vec<f32>> {
self.state += if action == 0 { -0.1 } else { 0.1 };
self.step_count += 1;
Step {
observation: vec![self.state],
reward: -self.state.abs(), // reward for staying near 0
terminated: self.state.abs() > 1.0,
truncated: self.step_count >= 200,
}
}
fn observation_space(&self) -> Space {
Space::Box { low: vec![-2.0], high: vec![2.0] }
}
fn action_space(&self) -> Space {
Space::Discrete(2)
}
}
Built-in environments
| Environment | Obs dim | Actions | Max steps |
|---|---|---|---|
| CartPole | 4 | 2 (left/right) | 500 |
use rl4burn::envs::CartPole;
use rand::SeedableRng;
let mut env = CartPole::new(rand::rngs::SmallRng::seed_from_u64(42));
CartPole is generic over R: Rng, so you control the random number generator for reproducibility.
Spaces
Spaces describe the shape and bounds of observations and actions. They’re used for constructing networks (knowing input/output dimensions) and validating data.
The Space enum
pub enum Space {
Discrete(usize), // {0, 1, ..., n-1}
Box { low: Vec<f32>, high: Vec<f32> }, // continuous, per-dimension bounds
MultiDiscrete(Vec<usize>), // multiple independent discrete spaces
}
Methods
- flat_dim(): total dimension (one-hot width for Discrete, number of dims for Box)
- shape(): shape as a Vec
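Here is a minimal sketch of flat_dim on a local copy of the enum. The Discrete and Box cases follow the description above. The MultiDiscrete case is an assumption, not documented rl4burn behavior: it treats the space as concatenated one-hot blocks and returns the sum of the sub-space sizes.

```rust
/// Local copy of the `Space` enum above, for illustration only.
#[derive(Debug, Clone, PartialEq)]
pub enum Space {
    Discrete(usize),
    Box { low: Vec<f32>, high: Vec<f32> },
    MultiDiscrete(Vec<usize>),
}

impl Space {
    /// Total flattened width of a sample from this space.
    pub fn flat_dim(&self) -> usize {
        match self {
            // One-hot encoding: one slot per discrete value.
            Space::Discrete(n) => *n,
            // One slot per continuous dimension.
            Space::Box { low, .. } => low.len(),
            // ASSUMPTION: concatenated one-hot blocks, one per sub-space.
            Space::MultiDiscrete(sizes) => sizes.iter().sum(),
        }
    }
}
```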
Usage
Spaces are returned by Env::observation_space() and Env::action_space(). Use them to size your network layers:
let obs_dim = env.observation_space().flat_dim(); // e.g., 4 for CartPole
let n_actions = match env.action_space() {
Space::Discrete(n) => n,
_ => panic!("expected discrete actions"),
};
let fc1 = LinearConfig::new(obs_dim, 64).init(&device);
let out = LinearConfig::new(64, n_actions).init(&device);
Vectorized Environments
SyncVecEnv runs N copies of an environment in lockstep, collecting N transitions per step. This is essential for PPO, which needs batched data from parallel environments.
Usage
use rl4burn::vec_env::SyncVecEnv;
use rl4burn::envs::CartPole;
use rand::SeedableRng;
let envs: Vec<CartPole<_>> = (0..8)
.map(|i| CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64)))
.collect();
let mut vec_env = SyncVecEnv::new(envs);
// Reset all environments
let observations = vec_env.reset(); // Vec of 8 observations
// Step all environments with one action each
let actions = vec![0, 1, 0, 1, 1, 0, 1, 0];
let steps = vec_env.step(actions); // Vec of 8 Step results
Auto-reset
When an environment reaches a terminal or truncated state, SyncVecEnv automatically resets it. The returned observation is the initial observation of the new episode, not the terminal observation. This matches the same-step auto-reset of Gymnasium’s SyncVectorEnv before version 1.0 (Gymnasium 1.0 changed the default to resetting on the next step).
The reward and done flags in the returned Step are from the terminal step — only the observation is replaced.
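The auto-reset rule can be sketched with a hypothetical MiniEnv trait and a toy Counter environment (neither is part of rl4burn). Each env is stepped once. If its episode ended, the returned observation is replaced by a fresh reset, while the reward and flags from the terminal step are kept.

```rust
/// Minimal stand-in for an environment (hypothetical, for illustration).
pub trait MiniEnv {
    fn reset(&mut self) -> Vec<f32>;
    /// Returns (observation, reward, terminated, truncated).
    fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool, bool);
}

/// Step every env once and auto-reset any env whose episode ended.
/// Each result keeps the terminal reward/flags but carries the reset observation.
pub fn step_all<E: MiniEnv>(envs: &mut [E], actions: &[usize]) -> Vec<(Vec<f32>, f32, bool, bool)> {
    assert_eq!(envs.len(), actions.len(), "one action per env");
    envs.iter_mut()
        .zip(actions)
        .map(|(env, &a)| {
            let (obs, reward, terminated, truncated) = env.step(a);
            // Swap in the first observation of the next episode.
            let obs = if terminated || truncated { env.reset() } else { obs };
            (obs, reward, terminated, truncated)
        })
        .collect()
}

/// Toy env: observation is the step count; terminates after `limit` steps.
pub struct Counter {
    pub t: usize,
    pub limit: usize,
}

impl MiniEnv for Counter {
    fn reset(&mut self) -> Vec<f32> {
        self.t = 0;
        vec![0.0]
    }
    fn step(&mut self, _action: usize) -> (Vec<f32>, f32, bool, bool) {
        self.t += 1;
        (vec![self.t as f32], 1.0, self.t >= self.limit, false)
    }
}
```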
When to use SyncVecEnv
| Algorithm | Vectorized? | Why |
|---|---|---|
| PPO | Yes (required) | Needs batched rollouts from parallel envs |
| DQN | No (typically) | Single-env stepping with replay buffer |
Environment Wrappers
Wrappers transform an environment’s observations, rewards, or tracking without modifying the environment itself. They implement Env and wrap an inner Env.
EpisodeStats
Tracks cumulative episode reward and length. Updated when episodes complete.
use rl4burn::wrapper::EpisodeStats;
let mut env = EpisodeStats::new(CartPole::new(rng));
env.reset();
loop {
let step = env.step(action);
if step.done() {
println!("Episode return: {}", env.last_episode_reward.unwrap());
println!("Episode length: {}", env.last_episode_length.unwrap());
break;
}
}
RewardClip
Clips rewards to [-limit, limit]. Useful for environments with large or unbounded rewards.
use rl4burn::wrapper::RewardClip;
let env = RewardClip::new(my_env, 1.0); // rewards clipped to [-1, 1]
NormalizeObservation
Normalizes observations to zero mean, unit variance using Welford’s online algorithm. Observations are also clipped to [-clip, clip].
use rl4burn::wrapper::NormalizeObservation;
let env = NormalizeObservation::new(my_env, 10.0).unwrap(); // clip normalized obs to [-10, 10]
Requires the environment to have Observation = Vec<f32> and a Box observation space.
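Welford's update keeps a running mean and a running sum of squared deviations per dimension. This lets the variance be tracked without storing past observations. The standalone RunningNorm type below is a hypothetical sketch. Two details are assumptions, not confirmed features of NormalizeObservation: the 1e-8 epsilon, and updating the statistics before normalizing each observation.

```rust
/// Per-dimension running mean/variance via Welford's online algorithm.
pub struct RunningNorm {
    pub count: u64,
    pub mean: Vec<f64>,
    pub m2: Vec<f64>, // running sum of squared deviations from the mean
    pub clip: f64,
}

impl RunningNorm {
    pub fn new(dim: usize, clip: f64) -> Self {
        Self { count: 0, mean: vec![0.0; dim], m2: vec![0.0; dim], clip }
    }

    /// Fold one observation into the running statistics.
    pub fn update(&mut self, obs: &[f32]) {
        self.count += 1;
        let n = self.count as f64;
        for (i, &x) in obs.iter().enumerate() {
            let x = x as f64;
            let delta = x - self.mean[i];
            self.mean[i] += delta / n;
            // Uses the updated mean, which is what makes Welford numerically stable.
            self.m2[i] += delta * (x - self.mean[i]);
        }
    }

    /// Population variance of dimension `i` (0 before any data).
    pub fn var(&self, i: usize) -> f64 {
        if self.count == 0 { 0.0 } else { self.m2[i] / self.count as f64 }
    }

    /// Update the stats, then return the z-scored observation clipped to [-clip, clip].
    pub fn normalize(&mut self, obs: &[f32]) -> Vec<f32> {
        self.update(obs);
        obs.iter()
            .enumerate()
            .map(|(i, &x)| {
                // ASSUMPTION: 1e-8 guards against division by zero.
                let z = (x as f64 - self.mean[i]) / (self.var(i) + 1e-8).sqrt();
                z.clamp(-self.clip, self.clip) as f32
            })
            .collect()
    }
}
```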
Composing wrappers
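Wrappers compose naturally. Each wrapper sees the output of the one it wraps, so order matters. In the stack below, EpisodeStats records the clipped rewards produced by RewardClip, not the environment's raw rewards: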
let env = EpisodeStats::new(
RewardClip::new(
NormalizeObservation::new(my_env, 10.0).unwrap(),
1.0
)
);
PPO (Proximal Policy Optimization)
PPO is an on-policy actor-critic algorithm. It collects a batch of experience using the current policy, computes advantages, then performs multiple epochs of minibatch gradient descent with a clipped surrogate objective.
Our implementation matches CleanRL’s ppo.py.
API
PPO is split into two functions:
- ppo_collect: Run the policy on vectorized environments, collect transitions, and compute GAE advantages.
- ppo_update: Perform clipped PPO gradient steps on the collected data.
You compose them in your own training loop.
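ppo_collect computes GAE advantages using gamma and gae_lambda from the configuration. The backward recursion over one environment's rollout can be sketched on plain slices. The following are assumptions for illustration, not rl4burn's signature or internal layout:

- the gae function and its argument order;
- the convention that dones[t] means the episode ended after step t.

```rust
/// Generalized Advantage Estimation for one environment's rollout.
/// `dones[t]` is true if the episode ended after step `t`;
/// `last_value` is V(s_T) for the observation after the final step.
/// Returns (advantages, returns) with returns = advantages + values.
pub fn gae(
    rewards: &[f32],
    values: &[f32],
    dones: &[bool],
    last_value: f32,
    gamma: f32,
    lambda: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    assert!(values.len() == n && dones.len() == n, "mismatched rollout lengths");
    let mut adv = vec![0.0f32; n];
    let mut next_adv = 0.0f32;
    let mut next_value = last_value;
    // Walk backwards so each step reuses the advantage of the step after it.
    for t in (0..n).rev() {
        // Cut bootstrapping across an episode boundary.
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        next_adv = delta + gamma * lambda * not_done * next_adv;
        adv[t] = next_adv;
        next_value = values[t];
    }
    let returns: Vec<f32> = adv.iter().zip(values).map(|(a, v)| a + v).collect();
    (adv, returns)
}
```

With gamma = lambda = 1 and zero values, each advantage is simply the undiscounted reward-to-go within the episode.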
The DiscreteActorCritic trait
pub trait DiscreteActorCritic<B: Backend> {
fn forward(&self, obs: Tensor<B, 2>) -> DiscreteAcOutput<B>;
}
pub struct DiscreteAcOutput<B: Backend> {
pub logits: Tensor<B, 2>, // [batch, n_actions]
pub values: Tensor<B, 1>, // [batch]
}
Implement this on any #[derive(Module)] struct. The model must output both action logits (for the policy) and value estimates (for the critic) in a single forward pass.
Configuration
PpoConfig defaults match CleanRL:
| Parameter | Default | Description |
|---|---|---|
| lr | 2.5e-4 | Learning rate |
| gamma | 0.99 | Discount factor |
| gae_lambda | 0.95 | GAE smoothing parameter |
| clip_eps | 0.2 | Surrogate clipping range |
| vf_coef | 0.5 | Value loss coefficient |
| ent_coef | 0.01 | Entropy bonus coefficient |
| update_epochs | 4 | Optimization epochs per rollout |
| minibatch_size | 128 | Minibatch size |
| n_steps | 128 | Rollout length per env |
| clip_vloss | true | Whether to clip value loss |
| max_grad_norm | 0.5 | Global gradient norm clipping (0 = disabled) |
LR annealing
ppo_update takes a current_lr parameter. For linear annealing:
let frac = 1.0 - iter as f64 / n_iterations as f64;
let current_lr = config.lr * frac;
For constant LR, just pass config.lr.
Episode return tracking
ppo_collect accepts an &mut Vec<f32> accumulator for per-env episode returns. This handles episodes that span multiple rollouts correctly. Completed episode returns are in PpoRollout::episode_returns.
let mut current_obs = vec_env.reset(); // create once before the loop
let mut ep_acc = vec![0.0f32; n_envs];
let rollout = ppo_collect(..., &mut current_obs, &mut ep_acc);
for &ret in &rollout.episode_returns {
println!("completed episode return: {ret}");
}
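The accumulator pattern can be sketched with a hypothetical accumulate_returns helper. Each step adds rewards to every env's slot. When an episode ends, its total is pushed to the completed list and the slot is reset. Because the slots live outside the rollout, an episode that straddles two rollouts is reported once, with its full return.

```rust
/// Add one vectorized step's rewards into `ep_acc` and push finished returns.
pub fn accumulate_returns(
    ep_acc: &mut [f32],
    rewards: &[f32],
    dones: &[bool],
    completed: &mut Vec<f32>,
) {
    assert!(rewards.len() == ep_acc.len() && dones.len() == ep_acc.len());
    for i in 0..ep_acc.len() {
        ep_acc[i] += rewards[i];
        if dones[i] {
            completed.push(ep_acc[i]);
            ep_acc[i] = 0.0; // the next episode starts from zero
        }
    }
}
```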
Multi-discrete actions and action masking
For complex action spaces (multiple discrete dimensions, per-step validity masks), use masked_ppo_collect and masked_ppo_update with an ActionDist:
use rl4burn::{ActionDist, MaskedActorCritic, masked_ppo_collect, masked_ppo_update};
// Action space: [action_type(5), target(10)]
let action_dist = ActionDist::MultiDiscrete(vec![5, 10]);
The MaskedActorCritic trait
pub trait MaskedActorCritic<B: Backend> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
fn log_std(&self) -> Option<Tensor<B, 1>> { None } // continuous only
}
If you already have a DiscreteActorCritic model, the delegation is trivial:
impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
let out = DiscreteActorCritic::forward(self, obs);
(out.logits, out.values)
}
}
Action masking
Environments provide per-step masks via Env::action_mask():
fn action_mask(&self) -> Option<Vec<f32>> {
let mut mask = vec![0.0; 15]; // 5 + 10
for valid_type in &self.valid_action_types { mask[*valid_type] = 1.0; }
for valid_target in &self.valid_targets { mask[5 + *valid_target] = 1.0; }
Some(mask)
}
Masked actions are never sampled and receive zero probability during training.
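A common way to get this behavior is to replace masked logits with a very large negative number before the softmax. Their probability then underflows to exactly zero. The code below is a sketch of that technique, not rl4burn's internal code. For multi-discrete spaces, the same step would be applied per head over that head's slice of the mask (indices 0..5 and 5..15 in the example above).

```rust
/// Softmax over `logits` with invalid actions (mask == 0.0) excluded.
/// Masked entries get probability exactly 0.0 when at least one action is valid.
pub fn masked_softmax(logits: &[f32], mask: &[f32]) -> Vec<f32> {
    assert_eq!(logits.len(), mask.len());
    // Replace invalid logits with a huge negative number.
    let masked: Vec<f32> = logits
        .iter()
        .zip(mask)
        .map(|(&l, &m)| if m > 0.0 { l } else { -1e9 })
        .collect();
    // Subtract the max for numerical stability before exponentiating.
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```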
Env action type
The masked pipeline expects Env<Action = Vec<f32>>. For existing discrete envs (Action = usize), use DiscreteEnvAdapter:
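use rand::SeedableRng;
use rl4burn::envs::CartPole;
use rl4burn::DiscreteEnvAdapter;
// Seed each env from its index; a single rng cannot be moved into every env.
let envs: Vec<DiscreteEnvAdapter<CartPole<_>>> = (0..4)
.map(|i| DiscreteEnvAdapter(CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64))))
.collect();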
Continuous action spaces
For continuous control (e.g. Pendulum, MuJoCo), use ActionDist::Continuous. The model outputs means (and optionally log standard deviations) for a diagonal Gaussian distribution.
ModelOutput mode
The model outputs [batch, 2 * action_dim] — first half is means, second half is log_stds:
let action_dist = ActionDist::Continuous {
action_dim: 1,
log_std_mode: LogStdMode::ModelOutput,
};
// Model outputs [batch, 2]: [mean, log_std]
impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
let h = self.encoder.forward(obs);
let logits = self.policy_head.forward(h.clone()); // [batch, 2]
let values = self.value_head.forward(h).squeeze_dim::<1>(1);
(logits, values)
}
}
Separate mode
For state-independent log_std (CleanRL’s default), the model outputs only means and provides log_std via a separate learnable parameter:
let action_dist = ActionDist::Continuous {
action_dim: 1,
log_std_mode: LogStdMode::Separate,
};
impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
// logits = [batch, action_dim] (means only)
...
}
fn log_std(&self) -> Option<Tensor<B, 1>> {
Some(self.log_std_param.val())
}
}
log_std clamping
ActionDist::Continuous automatically clamps log_std to [-5, 2] in all operations (sampling, log-prob, entropy). This prevents numerical instability from excessively large or small standard deviations — a common source of training divergence in continuous RL.
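For reference, the diagonal Gaussian log-probability and entropy with the [-5, 2] clamp can be written per dimension. The function names are hypothetical and the formulas are the textbook ones. rl4burn works on Burn tensors rather than slices.

```rust
const LOG_STD_MIN: f32 = -5.0;
const LOG_STD_MAX: f32 = 2.0;

/// Log-probability of `action` under a diagonal Gaussian, summed over dims.
pub fn gaussian_log_prob(action: &[f32], mean: &[f32], log_std: &[f32]) -> f32 {
    let ln_2pi = (2.0 * std::f32::consts::PI).ln();
    action
        .iter()
        .zip(mean)
        .zip(log_std)
        .map(|((&a, &mu), &ls)| {
            let ls = ls.clamp(LOG_STD_MIN, LOG_STD_MAX);
            let z = (a - mu) / ls.exp();
            // log N(a; mu, sigma) = -z^2/2 - ln(sigma) - ln(2*pi)/2
            -0.5 * z * z - ls - 0.5 * ln_2pi
        })
        .sum()
}

/// Entropy of the same distribution, summed over dims.
pub fn gaussian_entropy(log_std: &[f32]) -> f32 {
    let ln_2pi = (2.0 * std::f32::consts::PI).ln();
    log_std
        .iter()
        .map(|&ls| 0.5 + 0.5 * ln_2pi + ls.clamp(LOG_STD_MIN, LOG_STD_MAX))
        .sum()
}
```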
Continuous PPO tips
- Set ent_coef: 0.0; an entropy bonus can destabilize continuous policies.
- Use update_epochs: 10; more gradient steps per rollout help with continuous control.
- Longer rollouts (n_steps: 256+) improve value estimation for dense-reward tasks.
- Environments should accept Vec<f32> actions (Pendulum does this natively).
See the ppo-continuous example (cargo run -p ppo-continuous --release) for a complete working example on Pendulum.
Implementation details
- Per-minibatch advantage normalization: Advantages are z-normalized within each minibatch, not globally across the full rollout.
- Clipped value loss: max(unclipped, clipped), computed as a + relu(b - a) to avoid mask_where gradient issues in Burn’s autodiff.
- Clipped surrogate: min(surr1, surr2), computed as b - relu(b - a) for the same reason.
- Global gradient clipping: Uses clip_grad_norm (PyTorch-compatible global norm), not Burn’s built-in per-parameter clipping.
- Minibatch shuffling: Fisher-Yates shuffle each epoch.
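The relu identities above can be checked case by case:

- If b > a, then a + relu(b - a) = b. Otherwise it equals a. So it is max(a, b).
- Likewise, b - relu(b - a) equals min(a, b).

The scalar sketch below applies these identities to the per-sample surrogate and value-loss terms, using CleanRL's formulation. It also adds a global-norm clip over flat gradient vectors in the style of PyTorch's clip_grad_norm_. All names are illustrative; this is not rl4burn's tensor code.

```rust
/// relu(x) = max(x, 0)
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// max(a, b) written as a + relu(b - a).
pub fn max_via_relu(a: f32, b: f32) -> f32 {
    a + relu(b - a)
}

/// min(a, b) written as b - relu(b - a).
pub fn min_via_relu(a: f32, b: f32) -> f32 {
    b - relu(b - a)
}

/// Per-sample clipped surrogate objective (to be maximized).
pub fn clipped_surrogate(ratio: f32, adv: f32, clip_eps: f32) -> f32 {
    let surr1 = ratio * adv;
    let surr2 = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * adv;
    min_via_relu(surr1, surr2)
}

/// Per-sample clipped value loss: the larger of the unclipped and clipped
/// squared errors (CleanRL then takes 0.5 * mean over the batch).
pub fn clipped_value_loss(new_v: f32, old_v: f32, ret: f32, clip_eps: f32) -> f32 {
    let unclipped = (new_v - ret).powi(2);
    let v_clipped = old_v + (new_v - old_v).clamp(-clip_eps, clip_eps);
    let clipped = (v_clipped - ret).powi(2);
    max_via_relu(unclipped, clipped)
}

/// Scale all gradients so their combined L2 norm is at most `max_norm`
/// (max_norm <= 0 disables clipping). Returns the pre-clip norm.
pub fn clip_grad_norm(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let total = grads.iter().flatten().map(|g| g * g).sum::<f32>().sqrt();
    if max_norm > 0.0 && total > max_norm {
        let scale = max_norm / (total + 1e-6);
        for g in grads.iter_mut().flatten() {
            *g *= scale;
        }
    }
    total
}
```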
DQN (Deep Q-Network)
DQN is an off-policy, value-based algorithm. It learns a Q-function that estimates the expected return for each action in a given state, then acts by taking the argmax.
API
- QNetwork trait: Your model implements fn q_values(&self, obs) -> Tensor, returning Q-values for all actions.
- dqn_update: One gradient step on a minibatch sampled from the replay buffer, using the target network for stable Bellman targets.
- epsilon_greedy: Action selection with exploration.
- epsilon_schedule: Linear epsilon annealing.
- polyak_update: Target network update (hard or soft).
The QNetwork trait
pub trait QNetwork<B: Backend> {
fn q_values(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}
Input shape: [batch, obs_dim]. Output shape: [batch, n_actions].
Example implementation:
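use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
use burn::tensor::activation::relu;
use rl4burn::QNetwork;
#[derive(Module, Debug)]
struct QNet<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
q_head: Linear<B>,
}
impl<B: Backend> QNet<B> {
// Sized for CartPole: 4 observation dims, 2 actions.
fn new(device: &B::Device) -> Self {
Self {
fc1: LinearConfig::new(4, 64).init(device),
fc2: LinearConfig::new(64, 64).init(device),
q_head: LinearConfig::new(64, 2).init(device),
}
}
}
impl<B: Backend> QNetwork<B> for QNet<B> {
fn q_values(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
let h = relu(self.fc1.forward(obs));
let h = relu(self.fc2.forward(h));
self.q_head.forward(h)
}
}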
Training loop
DQN differs from PPO: it uses a single environment, a replay buffer, and epsilon-greedy exploration.
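use burn::backend::{Autodiff, NdArray};
use burn::module::AutodiffModule; // provides .valid()
use burn::optim::AdamConfig;
use rand::SeedableRng;
use rl4burn::envs::CartPole;
use rl4burn::{dqn_update, epsilon_greedy, epsilon_schedule, DqnConfig, Transition};
use rl4burn::ReplayBuffer;
use rl4burn::polyak_update;
use rl4burn::Env;
type AB = Autodiff<NdArray>;
let device = burn::backend::ndarray::NdArrayDevice::Cpu;
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
let mut env = CartPole::new(rand::rngs::SmallRng::seed_from_u64(1));
let config = DqnConfig::default();
let mut buffer = ReplayBuffer::new(config.buffer_capacity, rand::rngs::SmallRng::seed_from_u64(42));
let mut online: QNet<AB> = QNet::new(&device);
let mut target = online.clone();
let mut optim = AdamConfig::new().init();
let mut obs = env.reset();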
for step in 0..50_000 {
// Epsilon-greedy action selection (use non-autodiff model)
let eps = epsilon_schedule(&config, step);
let action = {
let inner = online.valid();
epsilon_greedy::<NdArray, _>(&inner, &obs, 2, eps, &device, &mut rng)
};
// Step environment, store transition
let result = env.step(action);
buffer.extend(std::iter::once(Transition {
obs: obs.clone(),
action: action as i32,
reward: result.reward,
next_obs: result.observation.clone(),
done: result.done(),
}));
obs = if result.done() { env.reset() } else { result.observation };
// Train after warmup
if step >= config.learning_starts && buffer.len() >= config.batch_size {
(online, _) = dqn_update(
online, &target, &mut optim, &mut buffer, &config, &device,
);
// Hard target update every N steps
if step % 250 == 0 {
target = polyak_update(target, &online, 1.0);
}
}
}
Configuration
| Parameter | Default | Description |
|---|---|---|
| lr | 1e-4 | Learning rate |
| gamma | 0.99 | Discount factor |
| buffer_capacity | 10,000 | Replay buffer size |
| batch_size | 32 | Minibatch size |
| tau | 0.005 | Polyak coefficient (1.0 = hard copy) |
| eps_start | 1.0 | Initial exploration rate |
| eps_end | 0.05 | Final exploration rate |
| eps_decay_steps | 10,000 | Steps to anneal epsilon |
| learning_starts | 1,000 | Random steps before training |
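Here is a sketch of linear epsilon annealing with these parameters, plus the greedy argmax used when not exploring. The helper names are hypothetical. The sketch counts from step 0 rather than from learning_starts; that choice is an assumption about epsilon_schedule.

```rust
/// Linearly anneal epsilon from `start` to `end` over `decay_steps`, then hold at `end`.
pub fn linear_epsilon(start: f64, end: f64, decay_steps: u64, step: u64) -> f64 {
    let frac = (step as f64 / decay_steps as f64).min(1.0);
    start + frac * (end - start)
}

/// Greedy action: index of the largest Q-value.
pub fn argmax(q: &[f32]) -> usize {
    q.iter()
        .enumerate()
        .fold((0, f32::NEG_INFINITY), |best, (i, &v)| if v > best.1 { (i, v) } else { best })
        .0
}
```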
Target network
DQN uses a slowly-updated target network for stable Bellman targets. Two strategies:
- Hard updates (tau = 1.0): Copy all weights every N steps. Simpler; this is what CleanRL uses.
- Soft updates (tau = 0.005): Polyak average every step. Smoother; this is what SAC/TD3 use.
The caller is responsible for calling polyak_update — dqn_update only updates the online network.
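The standard Polyak rule is target = tau * online + (1 - tau) * target, applied to every parameter. With tau = 1.0 it reduces to a hard copy. Here is a sketch over flat parameter slices; polyak is a hypothetical helper, while polyak_update itself works on Burn modules.

```rust
/// Soft target update: target <- tau * online + (1 - tau) * target.
/// tau = 1.0 copies the online weights exactly (a hard update).
pub fn polyak(target: &mut [f32], online: &[f32], tau: f32) {
    assert_eq!(target.len(), online.len());
    for (t, &o) in target.iter_mut().zip(online) {
        *t = tau * o + (1.0 - tau) * *t;
    }
}
```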
How dqn_update works
- Sample a minibatch from the replay buffer.
- Compute Q(s, a) for the taken actions using the online network.
- Compute max Q(s’, a’) using the target network (detached from the computation graph by extracting tensor data).
- Bellman target: y = r + γ * (1 - done) * max_a' Q_target(s', a')
- MSE loss: mean((Q(s, a) - y)²)
- Backward + optimizer step.
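The target and loss from these steps can be written out on plain slices for one minibatch. This is an illustrative sketch: bellman_targets and mse are hypothetical helpers. In dqn_update, only the online Q-values would carry gradients.

```rust
/// Bellman targets y = r + gamma * (1 - done) * max_a' Q_target(s', a').
/// `next_q` holds one row of target-network Q-values per transition.
pub fn bellman_targets(rewards: &[f32], dones: &[bool], next_q: &[Vec<f32>], gamma: f32) -> Vec<f32> {
    rewards
        .iter()
        .zip(dones)
        .zip(next_q)
        .map(|((&r, &d), q)| {
            let max_next = q.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            // Branch instead of multiplying by (1 - done) to avoid 0 * -inf on empty rows.
            if d { r } else { r + gamma * max_next }
        })
        .collect()
}

/// Mean squared error between Q(s, a) for the taken actions and the targets.
pub fn mse(q_taken: &[f32], targets: &[f32]) -> f32 {
    let n = q_taken.len() as f32;
    q_taken.iter().zip(targets).map(|(q, y)| (q - y).powi(2)).sum::<f32>() / n
}
```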
Dual-Clip PPO
An extension of standard PPO used by JueWu and Honor of Kings for distributed training stability.
The problem
In distributed RL, the behavior policy can be several updates behind. When the ratio pi_new/pi_old is very large and the advantage is negative, standard PPO’s objective becomes excessively negative, causing destructive updates.
The fix
Add a floor: when advantage < 0, the objective can’t go below c * advantage (c = 3):
standard_ppo = min(ratio * adv, clip(ratio, 1-ε, 1+ε) * adv)
dual_clip = max(standard_ppo, c * adv) // only when adv < 0
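The same two formulas can be written as a per-sample sketch, with c playing the role of dual_clip_coef (the function names are illustrative). The floor only changes the result when c * adv is above the standard objective, which happens for large ratios with negative advantages. c is expected to be greater than 1.

```rust
/// Standard PPO clipped surrogate for one sample.
pub fn ppo_objective(ratio: f32, adv: f32, eps: f32) -> f32 {
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
    (ratio * adv).min(clipped * adv)
}

/// Dual-clip PPO: for negative advantages, floor the objective at c * adv.
pub fn dual_clip_objective(ratio: f32, adv: f32, eps: f32, c: f32) -> f32 {
    let standard = ppo_objective(ratio, adv, eps);
    if adv < 0.0 { standard.max(c * adv) } else { standard }
}
```

With ratio = 10 and adv = -1, standard PPO yields -10 while the dual-clip objective is floored at -3.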
Usage
let config = PpoConfig {
dual_clip_coef: Some(3.0),
..Default::default()
};
That’s it. Set dual_clip_coef: None (the default) for standard PPO.
When to use
Only needed for distributed/asynchronous training where trajectories may be significantly off-policy. For single-machine training, standard PPO is sufficient.
Behavioral Cloning
Train a policy to imitate expert demonstrations via supervised learning. JueWu showed this provides ~64% of final RL performance as initialization.
API
use rl4burn::{bc_loss_discrete, bc_step};
// Single loss computation
let loss = bc_loss_discrete(logits, expert_actions, &device);
// Full training step (forward + backward + optimizer step)
let (model, loss_val) = bc_step(model, &mut optim, obs, expert_actions, lr, &device);
Multi-head actions
For hierarchical action spaces:
use rl4burn::bc_loss_multi_head;
let loss = bc_loss_multi_head(logits, expert_actions, &[11, 30, 8], &device);
// head_sizes: action_type(11), target(30), ability(8)
Tips
- With a near-uniform initial policy, the cross-entropy loss should start near ln(K), where K is the number of actions. If your initial loss is much higher, something is wrong.
- BC is most useful as RL weight initialization, not as a standalone method. BC policies are brittle: they fail on states not in the training data.
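The ln(K) check follows directly from the definition of cross-entropy. With all logits equal, every action has probability 1/K, so the loss is -ln(1/K) = ln(K). The sketch below uses a hypothetical cross_entropy helper with a numerically stable log-sum-exp.

```rust
/// Cross-entropy -log softmax(logits)[target], via a stable log-sum-exp.
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();
    log_sum_exp - logits[target]
}
```

For the 11-way action_type head used in the multi-head example, an untrained model should therefore start near ln(11) ≈ 2.40.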
Policy Distillation
Train a student network to match a teacher’s behavior. Used in CSPL (Phase 2) to merge multiple specialist teachers into one generalist.
API
use rl4burn::algo::imitation::distillation::{distillation_loss, DistillationConfig};
let config = DistillationConfig {
temperature: 2.0,
soft_weight: 1.0,
hard_weight: 0.0,
t_squared_scaling: true,
};
let loss = distillation_loss(teacher_logits, student_logits, &config);
Temperature
Higher temperature produces softer probability distributions. The student learns more from the relative ordering of actions, not just the best one.
- T=1: standard softmax (peaked)
- T=5: much softer (exposes teacher’s “second choice” preferences)
T-squared scaling
Hinton et al. recommend scaling the soft-target loss by T-squared. Without this scaling, soft-target gradients shrink by a factor of 1/T-squared, so their contribution fades as the temperature grows.
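Temperature and T-squared scaling can be illustrated together. The sketch below computes KL(teacher || student) between temperature-softened distributions, following Hinton et al. The helpers softmax_t and soft_distillation_loss are hypothetical, and rl4burn's distillation_loss may differ in KL direction and batch reduction.

```rust
/// Softmax of `logits / temperature`, computed stably.
fn softmax_t(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Soft-target loss KL(teacher_T || student_T), optionally multiplied by T^2.
fn soft_distillation_loss(teacher: &[f32], student: &[f32], temperature: f32, t_squared_scaling: bool) -> f32 {
    let p = softmax_t(teacher, temperature);
    let q = softmax_t(student, temperature);
    // Zero-probability teacher entries contribute nothing to the KL.
    let kl: f32 = p
        .iter()
        .zip(&q)
        .filter(|(pi, _)| **pi > 0.0)
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum();
    if t_squared_scaling { kl * temperature * temperature } else { kl }
}
```

Raising the temperature lowers the teacher's top-action probability and exposes its preferences among the remaining actions. The T-squared factor restores the gradient scale that the softening removed.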
Value distillation
use rl4burn::algo::imitation::distillation::value_distillation_loss;
let vloss = value_distillation_loss(teacher_values, student_values);
Loss Functions
Backend-generic loss functions for RL training. All return Tensor<B, 1> with shape [1] (scalar), compatible with .backward().
Value loss (Huber)
pub fn value_loss<B: Backend>(pred: Tensor<B, 1>, target: Tensor<B, 1>) -> Tensor<B, 1>
Smooth L1 (Huber) loss with δ=1.0. Quadratic for small errors, linear for large errors. Prevents outlier targets from dominating the value head update.
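With δ = 1 this matches PyTorch's smooth_l1_loss with beta = 1. The slice-based sketch below (huber_loss is a hypothetical name) shows the two regimes. An error of 3.0 costs 2.5, where a squared loss would charge 4.5.

```rust
/// Mean Huber (smooth L1) loss with threshold `delta`.
/// Quadratic for |error| < delta, linear beyond it.
fn huber_loss(pred: &[f32], target: &[f32], delta: f32) -> f32 {
    let total: f32 = pred
        .iter()
        .zip(target)
        .map(|(p, t)| {
            let err = (p - t).abs();
            if err < delta {
                0.5 * err * err
            } else {
                delta * (err - 0.5 * delta)
            }
        })
        .sum();
    total / pred.len() as f32
}
```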
Discrete policy loss (REINFORCE)
pub fn policy_loss_discrete<B: Backend>(
logits: Tensor<B, 2>, // [batch, n_actions]
actions: Tensor<B, 2, Int>, // [batch, 1] action indices
mask: Tensor<B, 2>, // [batch, n_actions] valid=1.0
advantage: Tensor<B, 1>, // [batch]
) -> Tensor<B, 1>
Standard REINFORCE: -mean(advantage * log_prob(action)). Supports action masking for environments with invalid actions.
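Masking amounts to normalizing the softmax over valid actions only, which is equivalent to setting invalid logits to negative infinity. The sketch below computes this per sample on plain slices, one row per sample. masked_log_prob and reinforce_loss are hypothetical helpers, and rl4burn's tensor implementation of the mask (for example, with a large negative constant) is not shown.

```rust
/// Log-probability of `action` under a softmax restricted to valid actions (mask == 1.0).
fn masked_log_prob(logits: &[f32], mask: &[f32], action: usize) -> f32 {
    assert!(mask[action] > 0.0, "chosen action must be valid");
    let max = logits
        .iter()
        .zip(mask)
        .filter(|&(_, &m)| m > 0.0)
        .map(|(&l, _)| l)
        .fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits
        .iter()
        .zip(mask)
        .filter(|&(_, &m)| m > 0.0)
        .map(|(&l, _)| (l - max).exp())
        .sum();
    logits[action] - max - sum_exp.ln()
}

/// REINFORCE loss: -mean(advantage * log_prob(action)).
fn reinforce_loss(logits: &[Vec<f32>], masks: &[Vec<f32>], actions: &[usize], advantages: &[f32]) -> f32 {
    let total: f32 = (0..actions.len())
        .map(|i| advantages[i] * masked_log_prob(&logits[i], &masks[i], actions[i]))
        .sum();
    -total / actions.len() as f32
}
```

A masked action with a huge logit has no effect on the probabilities of the valid actions.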
Continuous policy loss
pub fn policy_loss_continuous<B: Backend>(
pred: Tensor<B, 2>, // [batch, action_dim]
target: Tensor<B, 2>, // [batch, action_dim]
advantage: Tensor<B, 1>, // [batch]
) -> Tensor<B, 1>
Advantage-weighted regression for deterministic continuous actions. Only positive advantages contribute gradient. Weighting an MSE by a negative advantage would reward moving predictions away from the target without bound, so those samples are dropped.
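A per-sample sketch, assuming each sample's MSE is weighted by max(advantage, 0) and then averaged over the batch. awr_loss is a hypothetical name, and rl4burn's exact reduction may differ.

```rust
/// Advantage-weighted regression: mean over the batch of max(adv, 0) * MSE(pred, target).
fn awr_loss(pred: &[Vec<f32>], target: &[Vec<f32>], advantage: &[f32]) -> f32 {
    let total: f32 = pred
        .iter()
        .zip(target)
        .zip(advantage)
        .map(|((p, t), &adv)| {
            let mse = p.iter().zip(t).map(|(x, y)| (x - y).powi(2)).sum::<f32>() / p.len() as f32;
            // Negative advantages are clamped to zero, so those samples contribute no gradient.
            adv.max(0.0) * mse
        })
        .sum();
    total / pred.len() as f32
}
```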
Note
These loss functions are standalone building blocks. PPO and DQN implement their own loss computation internally (clipped surrogate for PPO, Bellman MSE for DQN). Use these when building custom algorithms.
Multi-Head Value Decomposition
Decompose value estimation into N heads, each tracking a different reward component. Used by JueWu (Honor of Kings) with 5 heads: farming, KDA, damage, pushing, and winning.
API
use rl4burn::{MultiHeadValueConfig, multi_head_gae, multi_head_value_loss};
let config = MultiHeadValueConfig::new(5, 0.99, 0.95)
.with_weights(vec![0.1, 0.2, 0.2, 0.2, 0.3]);
let result = multi_head_gae(
&per_head_rewards, // [5][T]
&per_head_values, // [5][T]
&dones, // [T]
&per_head_last_values, // [5]
&config,
);
// result.combined_advantages: [T] — weighted sum across heads
// result.per_head_returns: [5][T] — targets for each value head
Why decompose?
With a single value function, the agent knows how well it’s doing but not why. Multi-head decomposition provides credit assignment: “I’m farming well but my pushing is weak.”
Each head can have its own discount factor — short-term heads (damage) use lower gamma, long-term heads (winning) use higher gamma.
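The combined advantage is described above as a weighted sum across heads, that is, combined[t] = Σ_h w_h * A_h[t]. The sketch below assumes the weights are applied exactly this way, with no extra normalization. combine_heads is a hypothetical helper.

```rust
/// Weighted sum of per-head advantages: combined[t] = sum_h weights[h] * per_head[h][t].
fn combine_heads(per_head: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
    assert_eq!(per_head.len(), weights.len(), "one weight per head");
    let horizon = per_head[0].len();
    (0..horizon)
        .map(|t| per_head.iter().zip(weights).map(|(adv, w)| w * adv[t]).sum::<f32>())
        .collect()
}
```

Each head's advantages can come from its own GAE pass with its own gamma. The weights only matter at this final mixing step.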
Per-head value loss
let losses = multi_head_value_loss(&predictions, &targets);
// losses: [5] — MSE per head
let total_loss: f32 = losses.iter().sum();
KL Balancing with Free Bits
DreamerV3’s method for training the RSSM world model without the latent space collapsing.
The problem
The RSSM has an encoder (posterior: what actually happened) and a dynamics predictor (prior: what the model predicts). They’re trained with KL divergence, but:
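- If KL goes to zero: the latent space collapses (the posterior ignores the observation and carries nothing beyond the prior)
- If KL grows unchecked: the prior cannot predict the posterior, so imagined rollouts, which rely on the prior alone, drift away from reality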
The solution
Split the KL loss into two terms with different stop-gradients:
| Term | Trains | Stop-gradient on | Weight |
|---|---|---|---|
| Dynamics loss | Prior (predictor) | Posterior | 0.5 |
| Representation loss | Posterior (encoder) | Prior | 0.1 |
Plus free bits: ignore KL below 1 nat (don’t waste capacity eliminating tiny differences).
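In a tensor implementation, the two terms differ in where the stop-gradient sits. The dynamics term detaches the posterior logits, and the representation term detaches the prior logits. Both terms still have the same forward value, KL(posterior || prior), so on plain f32 values only the weighting and the free-bits clamp are visible. The sketch below shows that forward arithmetic with hypothetical helpers; it is not rl4burn's kl_balanced_loss.

```rust
/// KL(p || q) between two categorical distributions given as logits.
fn categorical_kl(p_logits: &[f32], q_logits: &[f32]) -> f32 {
    let log_softmax = |logits: &[f32]| -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();
        logits.iter().map(|&l| l - lse).collect()
    };
    let (log_p, log_q) = (log_softmax(p_logits), log_softmax(q_logits));
    log_p.iter().zip(&log_q).map(|(lp, lq)| lp.exp() * (lp - lq)).sum()
}

/// Forward value of the balanced loss for kl = KL(posterior || prior).
/// Below `free_bits` nats the loss is constant, so it exerts no gradient pressure.
fn kl_balanced_value(kl: f32, dyn_weight: f32, rep_weight: f32, free_bits: f32) -> f32 {
    let clipped = kl.max(free_bits);
    dyn_weight * clipped + rep_weight * clipped
}
```

With the default weights, a KL of 0.3 nats is clamped to 1 and costs 0.5 + 0.1 = 0.6, exactly what a KL of 1 nat would cost.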
API
use rl4burn::{kl_balanced_loss, KlBalanceConfig};
let config = KlBalanceConfig::default();
// dyn_weight: 0.5, rep_weight: 0.1, free_bits: 1.0
let loss = kl_balanced_loss(posterior_logits, prior_logits, &config);
For RSSM’s 32x32 grouped categoricals:
use rl4burn::kl_balanced_loss_groups;
// posterior_logits: [batch, 32, 32]
let loss = kl_balanced_loss_groups(posterior_logits, prior_logits, &config);
Standalone KL
use rl4burn::{categorical_kl, categorical_kl_groups};
let kl = categorical_kl(p_logits, q_logits); // [batch]
GAE (Generalized Advantage Estimation)
GAE (Schulman et al., 2015) computes advantages that smoothly interpolate between high-bias/low-variance (TD) and low-bias/high-variance (Monte Carlo) estimates.
API
pub fn gae(
rewards: &[f32],
values: &[f32],
dones: &[bool],
last_value: f32,
gamma: f32,
lambda: f32,
) -> (Vec<f32>, Vec<f32>) // (advantages, returns)
Pure f32 computation — no tensors, no backend dependency.
How it works
For each timestep t, GAE computes:
- TD error: δ_t = r_t + γ * V(s_{t+1}) * (1 - done_t) - V(s_t)
- Advantage: A_t = Σ_{l=0}^{T-t-1} (γλ)^l * δ_{t+l}
The lambda parameter controls the bias-variance tradeoff:
- λ = 0: TD(0), just the one-step TD error. Low variance, high bias.
- λ = 1: Monte Carlo, the full discounted return minus the baseline. Low bias, high variance.
- λ = 0.95: the standard default and a good tradeoff for most tasks.
Returns are computed as returns = advantages + values.
Done handling
When dones[t] is true, the next state is from a new episode. GAE correctly zeroes out both the bootstrap value and the accumulated advantage at episode boundaries.
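The recursion can be computed in a single backward pass, since A_t = δ_t + γλ(1 - done_t) A_{t+1}. The sketch below is written from the formulas above rather than copied from rl4burn's gae. It is named gae_sketch to keep that distinction clear.

```rust
/// GAE over one rollout. Returns (advantages, returns) with returns = advantages + values.
/// `dones[t]` means the episode ended at step t, so nothing is bootstrapped across it.
fn gae_sketch(
    rewards: &[f32],
    values: &[f32],
    dones: &[bool],
    last_value: f32,
    gamma: f32,
    lambda: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    let mut advantages = vec![0.0_f32; n];
    let mut running = 0.0_f32; // A_{t+1}, accumulated backwards
    for t in (0..n).rev() {
        let next_value = if t + 1 < n { values[t + 1] } else { last_value };
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        // TD error with the bootstrap zeroed at episode ends.
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        // The accumulated advantage is also reset at episode boundaries.
        running = delta + gamma * lambda * not_done * running;
        advantages[t] = running;
    }
    let returns = advantages.iter().zip(values).map(|(a, v)| a + v).collect();
    (advantages, returns)
}
```

Take rewards [1, 1, 1, 0] and values [5, 4, 3, 2], with the last step terminal, γ = 0.99 and λ = 0.95. The final advantage is 0 - 2 = -2, because nothing is bootstrapped. The one before it is -0.02 + 0.99 × 0.95 × (-2) = -1.901.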
Usage
GAE is called internally by ppo_collect. You only need it directly if building a custom algorithm:
use rl4burn::gae;
let rewards = vec![1.0, 1.0, 1.0, 0.0];
let values = vec![5.0, 4.0, 3.0, 2.0];
let dones = vec![false, false, false, true];
let (advantages, returns) = gae::gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);
V-trace
V-trace (Espeholt et al., 2018) is an off-policy correction algorithm used in IMPALA. It computes value targets and policy gradient advantages from trajectories collected by a potentially stale behavior policy.
API
pub fn vtrace_targets(
log_rhos: &[f32], // log importance ratios log(π/μ)
discounts: &[f32], // per-step γ (can vary for terminal steps)
rewards: &[f32],
values: &[f32], // V(s_t) from critic
bootstrap: f32, // V(s_T) for the last state
clip_rho: f32, // importance weight clipping (typically 1.0)
clip_c: f32, // trace accumulation clipping (typically 1.0)
) -> (Vec<f32>, Vec<f32>) // (value_targets, advantages)
Pure f32 computation. Contract annotations enforce preconditions (non-empty inputs, matching lengths, positive clip thresholds).
When to use V-trace
V-trace is for actor-learner architectures (like IMPALA) where the acting policy may be several updates behind the learning policy. For standard on-policy PPO, use GAE instead.
Key parameters
- clip_rho (ρ̄): clips importance weights for value targets. Higher values give lower bias but higher variance.
- clip_c (c̄): clips importance weights for trace accumulation. Controls how far back off-policy corrections propagate.
- Both are typically set to 1.0.
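The recursion from Espeholt et al. can be sketched backwards in time. It uses v_t - V(s_t) = δ_t + γ_t c_t (v_{t+1} - V(s_{t+1})), where δ_t = ρ_t (r_t + γ_t V(s_{t+1}) - V(s_t)), ρ_t = min(ρ̄, π/μ) and c_t = min(c̄, π/μ). The policy-gradient advantage is ρ_t (r_t + γ_t v_{t+1} - V(s_t)). vtrace_sketch is a hypothetical name, and it reuses clip_rho for the policy-gradient weights, where IMPALA's reference code has a separate threshold.

```rust
/// V-trace targets and policy-gradient advantages, computed backwards.
fn vtrace_sketch(
    log_rhos: &[f32],
    discounts: &[f32],
    rewards: &[f32],
    values: &[f32],
    bootstrap: f32,
    clip_rho: f32,
    clip_c: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    let rhos: Vec<f32> = log_rhos.iter().map(|l| l.exp()).collect();
    let mut vs = vec![0.0_f32; n];
    let mut acc = 0.0_f32; // v_{t+1} - V(s_{t+1}); zero past the end
    for t in (0..n).rev() {
        let next_value = if t + 1 < n { values[t + 1] } else { bootstrap };
        let rho = rhos[t].min(clip_rho);
        let c = rhos[t].min(clip_c);
        let delta = rho * (rewards[t] + discounts[t] * next_value - values[t]);
        acc = delta + discounts[t] * c * acc;
        vs[t] = values[t] + acc;
    }
    let advantages = (0..n)
        .map(|t| {
            let next_vs = if t + 1 < n { vs[t + 1] } else { bootstrap };
            rhos[t].min(clip_rho) * (rewards[t] + discounts[t] * next_vs - values[t])
        })
        .collect();
    (vs, advantages)
}
```

On-policy (all log ratios 0, clips at 1.0), v_t reduces to the ordinary bootstrapped n-step return. If the behavior policy almost never took the chosen action (ρ ≈ 0), the target stays at the critic's own estimate.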
UPGO (Self-Imitation Learning)
UPGO (Upgoing Policy Gradient) reinforces only trajectories where the agent performed better than expected. Used by ROA-Star alongside V-trace.
API
use rl4burn::upgo_advantages;
let advantages = upgo_advantages(&rewards, &values, &dones, last_value, gamma);
How it works
At each timestep, UPGO checks if the one-step TD error is positive (did better than the value predicted):
- Positive TD: Propagate the actual return backward (learn from this)
- Negative TD: Truncate to the value estimate (ignore this)
This creates a self-imitation effect: the agent only reinforces actions that led to above-average outcomes.
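The rule above can be sketched as a backward pass. A return keeps bootstrapping through step t+1 only if step t+1 itself beat its value estimate; otherwise it stops at V(s_{t+1}). AlphaStar states the condition as Q(s_{t+1}, a_{t+1}) ≥ V(s_{t+1}), and a one-step estimate of Q turns that into the TD-error test used here. upgo_sketch is hypothetical, and treating a TD error of exactly zero as "better" is an assumption.

```rust
/// UPGO advantages from one rollout.
fn upgo_sketch(rewards: &[f32], values: &[f32], dones: &[bool], last_value: f32, gamma: f32) -> Vec<f32> {
    let n = rewards.len();
    let mut advantages = vec![0.0_f32; n];
    let mut bootstrap = last_value; // what step t bootstraps from
    for t in (0..n).rev() {
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let upgo_return = rewards[t] + gamma * not_done * bootstrap;
        advantages[t] = upgo_return - values[t];
        // Decide what step t - 1 will bootstrap from.
        let next_value = if t + 1 < n { values[t + 1] } else { last_value };
        let td_error = rewards[t] + gamma * not_done * next_value - values[t];
        bootstrap = if td_error >= 0.0 { upgo_return } else { values[t] };
    }
    advantages
}
```

With rewards [1, -1] and zero values, the second step did worse than expected. The first step's return is therefore truncated to 1 + V(s_1) = 1 instead of 1 + (-1) = 0.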
When to use
UPGO is complementary to V-trace, not a replacement. ROA-Star uses both:
- V-trace for stable off-policy value targets
- UPGO for the policy gradient (only reinforce good trajectories)
Replay Buffer
ReplayBuffer<S, R> stores transitions for off-policy algorithms like DQN. It is generic over the sample type S and the RNG type R; seeding the RNG makes sampling reproducible.
API
use rand::SeedableRng;
use rl4burn::replay::ReplayBuffer;
let mut buffer = ReplayBuffer::new(10_000, rand::rngs::SmallRng::seed_from_u64(42));
buffer.extend(transitions); // add samples
let batch = buffer.sample(64); // random references
let batch = buffer.sample_cloned(64); // random clones (for owned data)
let groups = buffer.group_by(|t| t.episode_id); // group by key
Eviction
When the buffer exceeds capacity, the oldest samples are dropped first (FIFO).
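FIFO eviction with seeded sampling can be sketched using only the standard library. This is not rl4burn's ReplayBuffer. FifoBuffer is hypothetical, and it uses a tiny xorshift generator in place of the rand crate so the example has no dependencies.

```rust
use std::collections::VecDeque;

/// FIFO buffer with seeded uniform sampling (with replacement).
struct FifoBuffer<S> {
    capacity: usize,
    samples: VecDeque<S>,
    rng_state: u64,
}

impl<S: Clone> FifoBuffer<S> {
    fn new(capacity: usize, seed: u64) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        // xorshift64 must not start at 0, so map seed 0 to 1.
        Self { capacity, samples: VecDeque::with_capacity(capacity), rng_state: seed.max(1) }
    }

    fn extend(&mut self, items: impl IntoIterator<Item = S>) {
        for item in items {
            if self.samples.len() == self.capacity {
                self.samples.pop_front(); // evict the oldest sample
            }
            self.samples.push_back(item);
        }
    }

    /// xorshift64 step: a tiny deterministic generator, for illustration only.
    fn next_index(&mut self) -> usize {
        let mut x = self.rng_state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state = x;
        (x % self.samples.len() as u64) as usize
    }

    fn sample_cloned(&mut self, n: usize) -> Vec<S> {
        assert!(!self.samples.is_empty(), "cannot sample from an empty buffer");
        (0..n).map(|_| { let i = self.next_index(); self.samples[i].clone() }).collect()
    }
}
```

Two buffers built with the same seed and the same inserts produce identical batches. This is the reproducibility property the seeded RNG provides.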
With DQN
use rl4burn::dqn::Transition;
use rl4burn::replay::ReplayBuffer;
let mut buffer = ReplayBuffer::new(10_000, rand::rngs::SmallRng::seed_from_u64(42));
// Store transitions
buffer.extend(std::iter::once(Transition {
obs: obs.clone(),
action: action as i32,
reward: result.reward,
next_obs: result.observation.clone(),
done: result.done(),
}));
// dqn_update samples from the buffer internally
let (new_online, stats) = dqn_update(online, &target, &mut optim, &mut buffer, &config, &device);
online = new_online;
Trajectory grouping
The group_by method groups sample indices by an arbitrary key function. Useful for V-trace rescoring where you need to process entire trajectories:
let groups = buffer.group_by(|sample| sample.trajectory_id);
for (traj_id, indices) in &groups {
let trajectory: Vec<_> = indices.iter().map(|&i| &buffer.samples()[i]).collect();
// rescore this trajectory
}
Sequence Replay Buffer
A FIFO buffer that samples contiguous sequences of a fixed length, respecting episode boundaries. Used by DreamerV3 for world model training.
API
use rl4burn::{SequenceReplayBuffer, SequenceStep};
let mut buffer = SequenceReplayBuffer::new(1_000_000, 64);
// capacity: 1M steps, sequence_length: 64
// Add transitions
buffer.push(SequenceStep {
observation: obs.clone(),
action: vec![1.0, 0.0],
reward: 1.0,
done: false,
});
// Sample batch of sequences
let sequences = buffer.sample(16, &mut rng);
// sequences: Vec<Vec<SequenceStep<O>>>, each of length 64
Episode boundaries
Sampled sequences never cross episode boundaries. If a done=true step appears in the buffer, no sequence will start before it and end after it.
FIFO eviction
When the buffer exceeds capacity, the oldest steps are removed first. Episode start indices are automatically adjusted.
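The boundary rule can be checked with a small helper. It lists every start index whose window of seq_len steps has no done flag before its final step. This is a hypothetical illustration, not the buffer's internals. It assumes a sequence may end on a terminal step, since that does not cross the boundary.

```rust
/// Start indices whose `seq_len`-step window contains no episode end
/// before its final step (a window may end on the terminal step itself).
fn valid_starts(dones: &[bool], seq_len: usize) -> Vec<usize> {
    if seq_len == 0 || dones.len() < seq_len {
        return Vec::new();
    }
    (0..=dones.len() - seq_len)
        .filter(|&start| !dones[start..start + seq_len - 1].iter().any(|&d| d))
        .collect()
}
```

With dones = [false, false, true, false, false] and length 2, start 2 is rejected. A window starting there would run from the terminal step into the next episode.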
Difference from ReplayBuffer
| Feature | ReplayBuffer | SequenceReplayBuffer |
|---|---|---|
| Sample unit | Single step | Contiguous sequence |
| Episode boundaries | Not tracked | Enforced |
| Primary use | DQN, off-policy | DreamerV3 world models |
Percentile Return Normalization
DreamerV3-style advantage normalization using EMA-smoothed percentiles. More robust than per-minibatch normalization for sparse or heterogeneous reward scales.
API
use rl4burn::PercentileNormalizer;
let mut normalizer = PercentileNormalizer::new();
// default: 5th-95th percentile, EMA decay 0.99
// Update with observed returns
normalizer.update(&returns);
// Normalize advantages
let normalized = normalizer.normalize(&advantages);
// Divides by max(1.0, P95 - P5)
// Or combine both steps:
let normalized = normalizer.update_and_normalize(&returns, &advantages);
The max(1, …) floor
The critical detail: when the percentile range is less than 1.0 (sparse rewards, all-zero returns), the scale is clamped to 1.0. Without this, you’d amplify noise.
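A minimal sketch of the mechanism: EMA-smoothed low and high percentiles, with the scale floored at 1.0. PercentileScale is hypothetical. It uses nearest-rank percentiles and initializes the EMA from the first batch; both are assumptions that rl4burn's PercentileNormalizer need not share.

```rust
/// EMA-smoothed percentile scaling with a floor of 1.0.
struct PercentileScale {
    low_q: f32,
    high_q: f32,
    decay: f32,
    low: Option<f32>,
    high: Option<f32>,
}

impl PercentileScale {
    fn new(low_q: f32, high_q: f32, decay: f32) -> Self {
        Self { low_q, high_q, decay, low: None, high: None }
    }

    /// Nearest-rank percentile of an already sorted, non-empty slice.
    fn percentile(sorted: &[f32], q: f32) -> f32 {
        let idx = ((sorted.len() - 1) as f32 * q).round() as usize;
        sorted[idx]
    }

    fn update(&mut self, returns: &[f32]) {
        assert!(!returns.is_empty(), "need at least one return");
        let mut sorted = returns.to_vec();
        sorted.sort_by(|a, b| a.total_cmp(b));
        let lo = Self::percentile(&sorted, self.low_q);
        let hi = Self::percentile(&sorted, self.high_q);
        let d = self.decay;
        // The first update initializes the EMA; later updates blend with decay d.
        self.low = Some(self.low.map_or(lo, |old| d * old + (1.0 - d) * lo));
        self.high = Some(self.high.map_or(hi, |old| d * old + (1.0 - d) * hi));
    }

    fn scale(&self) -> f32 {
        match (self.low, self.high) {
            (Some(lo), Some(hi)) => (hi - lo).max(1.0),
            _ => 1.0,
        }
    }

    fn normalize(&self, advantages: &[f32]) -> Vec<f32> {
        let s = self.scale();
        advantages.iter().map(|a| a / s).collect()
    }
}
```

With all-zero returns the range is 0, so the floor keeps the scale at 1.0 and the advantages pass through unchanged. With returns 0 to 100 the 5th and 95th percentiles are 5 and 95, so advantages are divided by 90.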
Customization
let normalizer = PercentileNormalizer::with_percentiles(0.1, 0.9)
.with_decay(0.999);
Intrinsic Rewards
Exploration bonuses based on internal state. Useful when extrinsic rewards are sparse.
API
use rl4burn::collect::intrinsic::{IntrinsicReward, CountBasedReward, combine_rewards};
let mut explorer = CountBasedReward::new(0.1); // discretization resolution
explorer.update(&obs, action, &next_obs);
let bonus = explorer.reward(&obs, action, &next_obs);
// bonus = 1 / sqrt(visit_count)
let combined = combine_rewards(&extrinsic, &intrinsic, 0.01);
// combined[i] = extrinsic[i] + 0.01 * intrinsic[i]
Count-Based Exploration
Reward = 1 / sqrt(N(s)) where N(s) is how many times the agent has visited a discretized version of state s. Novel states get high reward; familiar states get low reward.
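A sketch of the counting, assuming states are discretized into grid cells of side resolution by flooring each coordinate. CountBonus is hypothetical, and rl4burn's CountBasedReward may discretize differently.

```rust
use std::collections::HashMap;

/// Count-based bonus 1 / sqrt(N(s)) over a grid discretization of the state.
struct CountBonus {
    resolution: f32,
    counts: HashMap<Vec<i64>, u32>,
}

impl CountBonus {
    fn new(resolution: f32) -> Self {
        Self { resolution, counts: HashMap::new() }
    }

    /// Map a continuous state to its grid cell.
    fn cell(&self, state: &[f32]) -> Vec<i64> {
        state.iter().map(|&x| (x / self.resolution).floor() as i64).collect()
    }

    fn visit(&mut self, state: &[f32]) {
        let cell = self.cell(state);
        *self.counts.entry(cell).or_insert(0) += 1;
    }

    fn bonus(&self, state: &[f32]) -> f32 {
        // Unvisited cells are treated as visited once to avoid dividing by zero.
        let n = self.counts.get(&self.cell(state)).copied().unwrap_or(0).max(1);
        1.0 / (n as f32).sqrt()
    }
}
```

Nearby states that fall into the same cell share a count. After four visits to a cell, its bonus has dropped from 1.0 to 0.5.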
Entropy-Reduction Reward
ROA-Star’s scouting reward: max(H_{prev} - H_{current}, 0). Rewards the agent for reducing uncertainty about the opponent’s strategy.
use rl4burn::collect::intrinsic::EntropyReductionReward;
let mut scouting = EntropyReductionReward::new();
let reward = scouting.reward_from_entropy(current_entropy);
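The reward only needs the previous entropy. The sketch below pairs it with a Shannon-entropy helper over a belief distribution, such as a belief over opponent strategies. EntropyReduction and entropy are hypothetical. Returning 0 on the first call, when there is no previous entropy, is an assumption.

```rust
/// Scouting-style reward: max(H_prev - H_current, 0).
struct EntropyReduction {
    prev: Option<f32>,
}

impl EntropyReduction {
    fn new() -> Self {
        Self { prev: None }
    }

    fn reward_from_entropy(&mut self, current: f32) -> f32 {
        let reward = self.prev.map_or(0.0, |prev| (prev - current).max(0.0));
        self.prev = Some(current);
        reward
    }
}

/// Shannon entropy in nats; zero-probability entries contribute nothing.
fn entropy(probs: &[f32]) -> f32 {
    -probs.iter().filter(|&&p| p > 0.0).map(|&p| p * p.ln()).sum::<f32>()
}
```

Sharpening a uniform belief over four strategies earns a positive reward. Becoming less certain again earns nothing rather than a penalty.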
The IntrinsicReward trait
Implement for custom exploration strategies:
pub trait IntrinsicReward {
type Observation;
fn reward(&self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation) -> f32;
fn update(&mut self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation);
}
CSPL (Curriculum Self-Play Learning)
JueWu’s 3-phase training pipeline for scaling to many heroes/unit types.
The problem
Training a single policy that handles 40+ heroes in all combinations doesn’t converge (480+ hours without success).
The solution: three phases
| Phase | What | Duration |
|---|---|---|
| 1. Specialists | Train small models on fixed team compositions | ~72h |
| 2. Distillation | Merge all specialists into one big model | ~48h |
| 3. Generalization | Continue RL with random compositions | ~144h |
API
use rl4burn::{CsplPipeline, CsplConfig, CsplPhase};
let mut pipeline = CsplPipeline::new(CsplConfig {
phase1_steps: 100_000,
phase2_steps: 50_000,
phase3_steps: 200_000,
n_specialists: 10,
});
loop {
let phase_changed = pipeline.step();
match pipeline.current_phase() {
CsplPhase::SpecialistTraining => { /* train specialists via self-play */ }
CsplPhase::Distillation => { /* distill into student */ }
CsplPhase::Generalization => { /* continue RL with random compositions */ }
}
if pipeline.is_complete() { break; }
}
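If the three phases run back to back and a single step counter drives them, the current phase is a cumulative-threshold lookup. This is an assumption about CsplPipeline that the text above does not state. Phase and phase_at below are hypothetical stand-ins, not the library's types.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Phase {
    SpecialistTraining,
    Distillation,
    Generalization,
    Complete,
}

/// Phase for a global step when phases run back to back.
fn phase_at(step: u64, phase1: u64, phase2: u64, phase3: u64) -> Phase {
    if step < phase1 {
        Phase::SpecialistTraining
    } else if step < phase1 + phase2 {
        Phase::Distillation
    } else if step < phase1 + phase2 + phase3 {
        Phase::Generalization
    } else {
        Phase::Complete
    }
}
```

With the configuration above (100,000 + 50,000 + 200,000 steps), distillation would begin at step 100,000, generalization at 150,000, and the pipeline would finish at 350,000.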
Polyak Updates
Polyak (soft) target network updates interpolate between a source model’s weights and a target model’s weights:
target = τ * source + (1 - τ) * target
API
pub fn polyak_update<B: Backend, M: Module<B>>(
target: M,
source: &M,
tau: f32,
) -> M
- tau = 1.0: hard copy (replace target with source entirely).
- tau = 0.005: soft update (slowly track source). Standard for SAC/TD3.
- tau = 0.0: no-op (target unchanged).
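Per parameter, the update is a single linear interpolation. The slice-level sketch below (polyak_slices is hypothetical) shows the three regimes; rl4burn applies the same arithmetic to every parameter tensor of a Burn Module.

```rust
/// Polyak update on flat parameter slices: target = tau * source + (1 - tau) * target.
fn polyak_slices(target: &mut [f32], source: &[f32], tau: f32) {
    assert_eq!(target.len(), source.len());
    for (t, &s) in target.iter_mut().zip(source) {
        *t = tau * s + (1.0 - tau) * *t;
    }
}
```

With tau = 0.005 the remaining gap shrinks by a factor of 0.995 per update. Against a fixed source, the target closes half of the gap after about 138 updates (0.995^138 ≈ 0.5).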
Usage
use rl4burn::polyak::polyak_update;
// Hard target update (DQN-style, every N steps)
if step % 250 == 0 {
target = polyak_update(target, &online, 1.0);
}
// Soft target update (SAC/TD3-style, every step)
target = polyak_update(target, &online, 0.005);
How it works
Uses Burn’s ModuleVisitor to collect all parameter tensors from the source model, then ModuleMapper to interpolate each target parameter in place. Works with any Module<B> — nested modules, custom architectures, any number of layers.
Orthogonal Initialization
Orthogonal weight initialization is critical for PPO convergence. It’s the default in CleanRL, Stable Baselines3, and OpenAI baselines.
API
pub fn orthogonal_linear<B: Backend>(
d_in: usize,
d_out: usize,
gain: f32,
device: &B::Device,
rng: &mut impl Rng,
) -> Linear<B>
Creates a Linear layer with orthogonal weights and zero bias. Matches PyTorch’s nn.init.orthogonal_.
Gain values
| Layer | Gain | Why |
|---|---|---|
| Hidden (tanh) | sqrt(2) ≈ 1.414 | Convention from OpenAI baselines and CleanRL (sqrt(2) is nominally the ReLU gain, but these codebases use it with tanh too) |
| Actor output | 0.01 | Near-uniform initial policy (good exploration) |
| Critic output | 1.0 | Reasonable initial value scale |
Usage
use rl4burn::init::orthogonal_linear;
let sqrt2 = std::f32::consts::SQRT_2;
let actor_fc1 = orthogonal_linear(4, 64, sqrt2, &device, &mut rng);
let actor_out = orthogonal_linear(64, 2, 0.01, &device, &mut rng);
let critic_out = orthogonal_linear(64, 1, 1.0, &device, &mut rng);
Why not use Burn’s built-in initializers?
Two reasons:
- Burn doesn't have orthogonal initialization. The closest is XavierUniform, which has a similar scale but lacks the orthogonality property.
- Burn initializes the bias with the same initializer as the weights, whereas CleanRL always initializes the bias to zero.
orthogonal_linear handles both correctly.
Implementation
Uses modified Gram-Schmidt orthogonalization on a random Gaussian matrix. Weights are loaded via Param::from_data + load_record to preserve Burn’s autodiff tracking (see Working with Burn’s Autodiff).
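The orthogonalization step can be sketched on a plain row-major matrix. This is not rl4burn's code: orthogonalize and transpose are hypothetical, and the random Gaussian input is assumed to be generated elsewhere. A zero row would divide by zero, but that has probability zero for Gaussian input. In exact arithmetic, modified Gram-Schmidt yields the same Q factor as a QR decomposition whose R has a positive diagonal, which is the sign convention PyTorch's orthogonal_ applies.

```rust
fn transpose(m: &[Vec<f32>]) -> Vec<Vec<f32>> {
    (0..m[0].len()).map(|j| m.iter().map(|row| row[j]).collect()).collect()
}

/// Orthonormalize the rows of `m` (or its columns, for tall matrices) with
/// modified Gram-Schmidt, then scale by `gain`.
fn orthogonalize(m: &[Vec<f32>], gain: f32) -> Vec<Vec<f32>> {
    let (rows, cols) = (m.len(), m[0].len());
    if rows > cols {
        // More rows than dimensions: orthonormalize the columns instead.
        return transpose(&orthogonalize(&transpose(m), gain));
    }
    let mut q = m.to_vec();
    for i in 0..rows {
        let norm = q[i].iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in q[i].iter_mut() {
            *x /= norm;
        }
        // Modified GS: remove the q[i] component from every later (already updated) row.
        let qi = q[i].clone();
        for j in (i + 1)..rows {
            let dot: f32 = qi.iter().zip(&q[j]).map(|(a, b)| a * b).sum();
            for (x, a) in q[j].iter_mut().zip(&qi) {
                *x -= dot * a;
            }
        }
    }
    for row in q.iter_mut() {
        for x in row.iter_mut() {
            *x *= gain;
        }
    }
    q
}
```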
Global Gradient Clipping
clip_grad_norm clips gradients by their global L2 norm across all parameters. This matches PyTorch’s torch.nn.utils.clip_grad_norm_.
Why not use Burn’s built-in clipping?
Burn’s GradientClippingConfig::Norm clips each parameter tensor independently. PyTorch clips the global norm across all parameters at once. These produce different behavior:
- Per-parameter (Burn): A large gradient in the critic doesn’t affect clipping of the actor’s gradient.
- Global (PyTorch/rl4burn): The total gradient norm is computed, then all gradients are scaled by the same factor.
For PPO with shared optimizer over actor + critic, global clipping is standard.
API
pub fn clip_grad_norm<B: AutodiffBackend, M: AutodiffModule<B>>(
model: &M,
grads: GradientsParams,
max_norm: f32,
) -> GradientsParams
Call between backward() and optim.step():
let grads = loss.backward();
let mut grads = GradientsParams::from_grads(grads, &model);
grads = clip_grad_norm(&model, grads, 0.5); // max_norm = 0.5
model = optim.step(lr, model, grads);
PPO handles this automatically via PpoConfig::max_grad_norm. Set to 0.0 to disable.
Implementation
Two-pass approach using the inner (non-autodiff) model:
- ModuleVisitor: extract each gradient from GradientsParams, compute its squared L2 norm, and add it to a running sum. The global norm is the square root of that sum.
- Compute clip_coef = min(1.0, max_norm / (global_norm + 1e-6)).
- ModuleMapper: scale each gradient by clip_coef and re-register it in a new GradientsParams.
The visitor/mapper operate on B::InnerBackend because Burn stores gradients on the inner backend, not the autodiff wrapper.
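The difference between per-parameter and global clipping shows up with just two one-element gradients, [3] and [4], and max_norm = 1.0. Per-tensor clipping rescales each one to [1], which changes the update's direction. Global clipping divides both by the global norm of 5, giving about [0.6] and [0.8]. The sketch below follows the same two passes on flat Vec<f32> gradients. clip_global_norm is hypothetical and does not touch Burn's GradientsParams.

```rust
/// Global-norm clipping across all parameter gradients.
/// Returns the global L2 norm measured before clipping.
fn clip_global_norm(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    // Pass 1: accumulate squared norms over every gradient tensor.
    let sum_sq: f32 = grads.iter().flat_map(|g| g.iter()).map(|x| x * x).sum();
    let global_norm = sum_sq.sqrt();
    let clip_coef = (max_norm / (global_norm + 1e-6)).min(1.0);
    // Pass 2: scale every gradient by the same factor.
    for g in grads.iter_mut() {
        for x in g.iter_mut() {
            *x *= clip_coef;
        }
    }
    global_norm
}
```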
Logging
rl4burn provides a lightweight, feature-gated logging system for training metrics. The core Logger trait and built-in loggers ship with zero extra dependencies. TensorBoard and JSON output are opt-in via feature flags.
The Logger trait
All loggers implement Logger:
pub trait Logger {
fn log_scalar(&mut self, key: &str, value: f64, step: u64);
fn log_scalars(&mut self, key: &str, values: &[(&str, f64)], step: u64);
fn log_text(&mut self, key: &str, text: &str, step: u64);
fn log_histogram(&mut self, key: &str, values: &[f32], step: u64);
fn flush(&mut self);
}
Built-in loggers (no feature flags)
PrintLogger — prints scalars to stderr in a formatted line. Accepts a throttle interval so you don’t flood the terminal:
use rl4burn::PrintLogger;
// Print at most every 1000 steps
let mut logger = PrintLogger::new(1000);
logger.log_scalar("train/loss", 0.42, 5000);
// stderr: [step 5000] train/loss: 0.4200
NoopLogger — discards everything. Useful as a default when the caller doesn’t care about logging.
CompositeLogger — fans out to multiple loggers simultaneously:
use rl4burn::{CompositeLogger, PrintLogger};
use rl4burn::TensorBoardLogger; // requires `tensorboard` feature
let mut logger = CompositeLogger::new(vec![
Box::new(PrintLogger::new(0)),
Box::new(TensorBoardLogger::new("runs/ppo_cartpole").unwrap()),
]);
logger.log_scalar("train/loss", 0.5, 100); // goes to both
Logging stats from algorithms
PpoStats and DqnStats implement the Loggable trait, so you can log all their fields in one call:
use rl4burn::Loggable;
let (model, stats) = ppo_update(model, &mut optim, &rollout, &config, lr, &device, &mut rng);
stats.log(&mut logger, step);
// Logs: train/policy_loss, train/value_loss, train/entropy, train/approx_kl
For DQN:
let (online, stats) = dqn_update(online, &target, &mut optim, &mut buffer, &config, &device);
stats.log(&mut logger, step);
// Logs: train/loss, train/mean_q, train/epsilon
TensorBoard (feature-gated)
Enable the tensorboard feature in your Cargo.toml:
[dependencies]
rl4burn = { version = "0.1", features = ["tensorboard"] }
Then create a TensorBoardLogger pointing at a run directory:
use rl4burn::TensorBoardLogger;
let mut logger = TensorBoardLogger::new("runs/experiment_1").unwrap();
logger.log_scalar("train/loss", 0.42, 1000);
logger.log_histogram("weights", &weight_data, 1000);
logger.log_text("info", "training started", 0);
logger.flush();
View results with:
tensorboard --logdir runs/
The logger writes standard TFEvent files (events.out.tfevents.*) with hand-serialized protobufs — no prost or protobuf dependency required. Supports scalars, histograms, and text.
JSON output (feature-gated)
Enable the json-log feature:
[dependencies]
rl4burn = { version = "0.1", features = ["json-log"] }
JsonLogger writes one JSON object per line (JSONL format) to any Write sink:
use rl4burn::JsonLogger;
let mut logger = JsonLogger::from_path("train_log.jsonl").unwrap();
logger.log_scalar("train/loss", 0.42, 1000);
logger.flush();
Each line looks like:
{"type":"scalar","key":"train/loss","value":0.42,"step":1000,"wall_time":1234567890.123}
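The line format needs nothing beyond the standard library. A std-only writer like the one below can emit it from another tool. This sketch is not JsonLogger's code: write_scalar and escape_json are hypothetical, and the sketch does not handle NaN or infinite values, which JSON cannot represent.

```rust
use std::io::Write;

/// Escape a string for use inside a JSON string literal.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

/// Write one scalar record as a single JSONL line.
fn write_scalar<W: Write>(w: &mut W, key: &str, value: f64, step: u64, wall_time: f64) -> std::io::Result<()> {
    writeln!(
        w,
        "{{\"type\":\"scalar\",\"key\":\"{}\",\"value\":{},\"step\":{},\"wall_time\":{:.3}}}",
        escape_json(key),
        value,
        step,
        wall_time
    )
}
```

In real use, wall_time would come from SystemTime::now().duration_since(UNIX_EPOCH) converted with as_secs_f64().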
Bridging to Weights & Biases
A thin Python bridge script is included at scripts/wandb_bridge.py:
cargo run --example ppo_cartpole --features "ndarray,json-log" 2>&1 \
| python scripts/wandb_bridge.py
The same JSONL format can be ingested by neptune, mlflow, comet, or any custom dashboard.
Video recording (feature-gated)
Enable the video feature to record CartPole episodes as GIFs:
[dependencies]
rl4burn = { version = "0.1", features = ["video"] }
Any environment implementing the Renderable trait can produce RGB frames. CartPole and GridWorld both implement it:
use rl4burn::envs::CartPole;
use rl4burn::{write_gif, Env, Renderable};
let mut env = CartPole::new(rng);
env.reset();
let mut frames = vec![env.render()];
loop {
let step = env.step(action);
frames.push(env.render());
if step.done() { break; }
}
write_gif("episode.gif", &frames, 4).unwrap(); // 4 centiseconds per frame
Putting it all together
A typical training script with logging:
use rl4burn::{CompositeLogger, Loggable, Logger, PrintLogger, TensorBoardLogger};
let mut logger = CompositeLogger::new(vec![
Box::new(PrintLogger::new(5000)),
Box::new(TensorBoardLogger::new("runs/ppo").unwrap()),
]);
for iter in 0..n_iterations {
let rollout = ppo_collect::<NdArray, _, _>(&model.valid(), &mut vec_env, &config, &device, &mut rng, &mut current_obs, &mut ep_acc);
let step = (iter + 1) as u64 * steps_per_iter as u64;
if !rollout.episode_returns.is_empty() {
let avg = rollout.episode_returns.iter().sum::<f32>() / rollout.episode_returns.len() as f32;
logger.log_scalar("rollout/avg_return", avg as f64, step);
}
let (new_model, stats) = ppo_update(model, &mut optim, &rollout, &config, lr, &device, &mut rng);
model = new_model;
stats.log(&mut logger, step);
}
logger.flush();
Feature flags summary
| Feature | Dependency | What you get |
|---|---|---|
| (none) | — | Logger trait, PrintLogger, NoopLogger, CompositeLogger, Loggable |
| tensorboard | crc32c | TensorBoardLogger (TFEvent files) |
| json-log | — | JsonLogger (JSONL output) |
| video | gif | write_gif(), CartPole::render() |
Saving & Sharing
After training, you typically want to save the model weights for later inference and share visualizations of agent behavior. This chapter covers both.
Saving model weights
Burn models derive Module, which gives them save_file and load_file for free. No rl4burn-specific API is needed — just use Burn’s recorder system directly.
Save a trained model
use burn::module::Module; // brings save_file and load_file into scope
use burn::record::{CompactRecorder, Recorder};
// After training completes:
model
.save_file("checkpoints/ppo_cartpole", &CompactRecorder::new())
.expect("failed to save model");
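This writes checkpoints/ppo_cartpole.mpk (MessagePack format), which contains all of the model's parameters. CompactRecorder stores floating-point values at half precision. That keeps files small, but the saved weights are not bit-exact.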
Load a saved model
use burn::module::Module; // brings save_file and load_file into scope
use burn::record::{CompactRecorder, Recorder};
// Initialize a fresh model, then load weights into it
let model: ActorCritic<AB> = ActorCritic::new(&device)
.load_file("checkpoints/ppo_cartpole", &CompactRecorder::new(), &device)
.expect("failed to load model");
The model architecture must match — load_file loads parameter values, not the structure.
Recorder types
Burn provides several recorders. Choose based on your needs:
| Recorder | Format | Good for |
|---|---|---|
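| CompactRecorder | MessagePack (.mpk), half precision | Production: small files, fast I/O |
| NamedMpkGzFileRecorder | Gzipped MessagePack | Sharing: even smaller files |
| PrettyJsonFileRecorder | JSON (.json) | Debugging: human-readable weights |
| BinFileRecorder | Raw binary (.bin) | Maximum speed, no compression |
CompactRecorder is the default choice for most use cases. If you need bit-exact weights, for example to resume training with no precision loss, use a full-precision recorder such as NamedMpkFileRecorder<FullPrecisionSettings>.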
Checkpointing during training
Save periodically so you can resume after interruptions or pick the best checkpoint:
use burn::module::Module; // brings save_file into scope
use burn::record::{CompactRecorder, Recorder};
for iter in 0..n_iterations {
// ... collect and update ...
// Save every 50 iterations
if (iter + 1) % 50 == 0 {
let path = format!("checkpoints/ppo_step_{}", (iter + 1) * steps_per_iter);
model
.clone() // save_file takes the module by value; save a clone and keep training
.save_file(&path, &CompactRecorder::new())
.expect("failed to save checkpoint");
}
}
// Always save the final model
model
.save_file("checkpoints/ppo_final", &CompactRecorder::new())
.expect("failed to save final model");
DQN: saving online and target networks
For DQN, save both networks so you can resume training correctly:
online.clone().save_file("checkpoints/dqn_online", &CompactRecorder::new())?;
target.clone().save_file("checkpoints/dqn_target", &CompactRecorder::new())?;
For inference only, you just need the online network.
Sharing visualizations
GIF recordings
With the video feature, record an episode of your trained agent and save it as a GIF:
use rl4burn::envs::CartPole;
use rl4burn::{write_gif, greedy_action, Env, Renderable};
let mut env = CartPole::new(rng);
let mut obs = env.reset();
let mut frames = vec![env.render()];
loop {
// greedy_action runs a forward pass and returns the argmax action
let action = greedy_action(&model, &obs, &device);
let step = env.step(action);
frames.push(env.render());
if step.done() { break; }
obs = step.observation;
}
write_gif("agent_demo.gif", &frames, 4).unwrap();
The resulting GIF can be embedded in READMEs, papers, blog posts, or Slack messages.
TensorBoard
TensorBoard logs are shareable as a directory: zip the run folder and send it, and the recipient can view it locally. The hosted TensorBoard.dev service was shut down at the end of 2023, so tensorboard dev upload no longer works.
# View locally
tensorboard --logdir runs/
When comparing experiments, use separate run directories:
// Each experiment gets its own subdirectory
let logger = TensorBoardLogger::new(format!("runs/ppo_lr{}", config.lr)).unwrap();
Then tensorboard --logdir runs/ overlays all runs for comparison.
JSONL logs
JSONL files are plain text and easy to share. Post-process them with any tool:
# Quick plot with Python
python -c "
import json, matplotlib.pyplot as plt
data = [json.loads(l) for l in open('train_log.jsonl') if '\"scalar\"' in l]
returns = [(d['step'], d['value']) for d in data if d['key'] == 'rollout/avg_return']
plt.plot(*zip(*returns))
plt.xlabel('Step'); plt.ylabel('Avg Return')
plt.savefig('training_curve.png')
"
# Send to W&B
python scripts/wandb_bridge.py < train_log.jsonl
Putting it all together
A complete training script that checkpoints the model and records a final evaluation GIF:
use burn::module::{AutodiffModule, Module}; // Module: save_file; AutodiffModule: valid()
use burn::record::{CompactRecorder, Recorder};
use rl4burn::{
CompositeLogger, Loggable, Logger, PrintLogger, TensorBoardLogger,
envs::CartPole, write_gif, Env, Renderable, greedy_action,
};
let mut logger = CompositeLogger::new(vec![
Box::new(PrintLogger::new(5000)),
Box::new(TensorBoardLogger::new("runs/ppo").unwrap()),
]);
// Training loop
for iter in 0..n_iterations {
let rollout = ppo_collect::<NdArray, _, _>(
&model.valid(), &mut vec_env, &config, &device, &mut rng, &mut current_obs, &mut ep_acc,
);
let (new_model, stats) = ppo_update(
model, &mut optim, &rollout, &config, lr, &device, &mut rng,
);
model = new_model;
let step = (iter + 1) as u64 * steps_per_iter as u64;
stats.log(&mut logger, step);
// Checkpoint every 100 iterations
if (iter + 1) % 100 == 0 {
model.clone().save_file(
&format!("checkpoints/ppo_{step}"),
&CompactRecorder::new(),
).unwrap();
}
}
logger.flush();
// Save final weights
model.clone().save_file("checkpoints/ppo_final", &CompactRecorder::new()).unwrap(); // keep `model` for evaluation
// Record evaluation episode
let mut env = CartPole::new(rng);
let mut obs = env.reset();
let mut frames = vec![env.render()];
loop {
let action = greedy_action(&model.valid(), &obs, &device);
let step = env.step(action);
frames.push(env.render());
if step.done() { break; }
obs = step.observation;
}
write_gif("evaluation.gif", &frames, 4).unwrap();
LSTM, GRU, and Block GRU
Recurrent cells for temporal reasoning under partial observability. AlphaStar, SCC, and JueWu all use LSTM or GRU cores for sequence processing.
LSTM Cell
use rl4burn::{LstmCell, LstmCellConfig, LstmState};
let cell = LstmCellConfig::new(input_size, hidden_size).init(&device);
let state = LstmState::zeros(batch_size, hidden_size, &device);
// Single step
let new_state = cell.forward(input, &state);
// Full sequence
let (outputs, final_state) = cell.forward_seq(inputs, &state);
// outputs: [batch, seq_len, hidden_size]
GRU Cell
Same API, simpler internals (2 gates vs LSTM’s 3). Uses PyTorch’s convention for reset gate application.
use rl4burn::{GruCell, GruCellConfig};
let cell = GruCellConfig::new(input_size, hidden_size).init(&device);
let h = Tensor::zeros([batch_size, hidden_size], &device);
let new_h = cell.forward(input, h);
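The reset-gate convention is easiest to see in a scalar GRU. PyTorch computes the candidate state as n = tanh(W_in x + b_in + r ⊙ (W_hn h + b_hn)), applying the reset gate after the hidden-to-hidden product. The original formulation by Cho et al. applies it before, as W_hn (r ⊙ h). The struct below is a hypothetical illustration with input and hidden size 1, not rl4burn's GruCell.

```rust
/// GRU weights for input size 1 and hidden size 1 (scalars keep the gate math visible).
struct ScalarGru {
    w_ir: f32, w_hr: f32, b_ir: f32, b_hr: f32, // reset gate
    w_iz: f32, w_hz: f32, b_iz: f32, b_hz: f32, // update gate
    w_in: f32, w_hn: f32, b_in: f32, b_hn: f32, // candidate state
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

impl ScalarGru {
    /// One PyTorch-convention step: r multiplies (w_hn * h + b_hn), bias included.
    fn step(&self, x: f32, h: f32) -> f32 {
        let r = sigmoid(self.w_ir * x + self.b_ir + self.w_hr * h + self.b_hr);
        let z = sigmoid(self.w_iz * x + self.b_iz + self.w_hz * h + self.b_hz);
        let n = (self.w_in * x + self.b_in + r * (self.w_hn * h + self.b_hn)).tanh();
        (1.0 - z) * n + z * h
    }
}
```

The two conventions disagree whenever b_hn ≠ 0. With h = 0 and b_hn = 1, PyTorch's candidate is tanh(0.5), because r = 0.5 scales the bias. Applying r to h first would give tanh(1).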
Block GRU (DreamerV3)
Block-diagonal GRU reduces recurrent parameters by a factor of n_blocks. The recurrent weight matrix is split into independent blocks, each operating on a partition of the hidden state.
DreamerV3 uses 8 blocks with a 4096-dim hidden state. This shrinks each recurrent weight matrix from 4096² ≈ 16.8M parameters to 8 × 512² ≈ 2.1M.
use rl4burn::{BlockGruCell, BlockGruCellConfig};
let cell = BlockGruCellConfig::new(input_size, hidden_size)
.with_n_blocks(8)
.init(&device);
When n_blocks = 1, Block GRU is identical to standard GRU.
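For a hidden size H split into n blocks, each recurrent weight matrix keeps n·(H/n)² = H²/n weights. The sketch below computes that count and applies a block-diagonal matrix without ever materializing the zero off-diagonal blocks. block_diag_matvec and recurrent_params are hypothetical helpers, not BlockGruCell's code.

```rust
/// y = W h for a block-diagonal W stored as `blocks.len()` dense b x b blocks.
/// Block k only reads and writes the k-th slice of the hidden state.
fn block_diag_matvec(blocks: &[Vec<Vec<f32>>], h: &[f32]) -> Vec<f32> {
    let b = blocks[0].len();
    assert_eq!(blocks.len() * b, h.len(), "blocks must tile the hidden state");
    let mut y: Vec<f32> = Vec::with_capacity(h.len());
    for (k, block) in blocks.iter().enumerate() {
        let part = &h[k * b..(k + 1) * b];
        for row in block {
            y.push(row.iter().zip(part).map(|(w, x)| w * x).sum::<f32>());
        }
    }
    y
}

/// Weights in one recurrent matrix: n_blocks * (hidden / n_blocks)^2 = hidden^2 / n_blocks.
fn recurrent_params(hidden: usize, n_blocks: usize) -> usize {
    assert_eq!(hidden % n_blocks, 0, "hidden size must divide evenly into blocks");
    n_blocks * (hidden / n_blocks).pow(2)
}
```

With a single block, block_diag_matvec is an ordinary dense matrix-vector product, which is why n_blocks = 1 recovers the standard GRU.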
Transformer Encoder
Reusable multi-head self-attention blocks for entity processing. Used by ROA-Star and SCC to encode sets of game units.
Multi-Head Attention
use rl4burn::{MultiHeadAttention, MultiHeadAttentionConfig};
let attn = MultiHeadAttentionConfig::new(128, 4).init(&device);
// d_model=128, 4 heads (d_k = 32 per head)
let output = attn.forward(query, key, value, None);
// All inputs: [batch, seq_len, 128]
// Optional mask: [batch, seq_len] (true = attend, false = ignore)
Transformer Block
Pre-norm residual block: self-attention + feedforward.
use rl4burn::{TransformerBlock, TransformerBlockConfig};
let block = TransformerBlockConfig::new(128, 4, 512).init(&device);
// d_model=128, 4 heads, d_ff=512
let output = block.forward(input, None); // residual: output ≈ input + attention + ffn
Stacked Encoder
use rl4burn::{TransformerEncoder, TransformerEncoderConfig};
let encoder = TransformerEncoderConfig::new(128, 4, 2, 512).init(&device);
// 2 layers of transformer blocks
let encoded = encoder.forward(entities, None);
Properties
- Permutation equivariant: reordering input tokens reorders output tokens identically (no positional encoding).
- Variable-length: use masking for padded sequences.
- For 30 entities with 128-dim embeddings, a 2-layer encoder is small enough to run comfortably on CPU.
Attention Mechanisms
Three specialized attention modules for game AI architectures.
Target Attention
Scaled dot-product attention for selecting a target entity. The LSTM output serves as query; encoded entities serve as keys. Returns a probability distribution over entities.
use rl4burn::{TargetAttention, TargetAttentionConfig};
let attn = TargetAttentionConfig::new(256, 128).init(&device);
// query_dim=256 (LSTM output), key_dim=128 (entity embedding)
let probs = attn.forward(query, keys, Some(mask));
// query: [batch, 256], keys: [batch, n_entities, 128]
// mask: [batch, n_entities] (true = valid target)
// probs: [batch, n_entities] (sums to 1 over valid targets)
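Conceptually, the target distribution is a masked softmax over query-key scores. The sketch below assumes a single learned projection W that maps the query into key space, so score_j = (W·q)·k_j / √key_dim. The function and the weight layout are illustrative, not rl4burn's internals.

```rust
/// Target-attention sketch for one sample.
/// w: [key_dim][query_dim] (hypothetical projection), query: [query_dim],
/// keys: [n_entities][key_dim], mask: [n_entities] (true = valid target).
fn target_probs(w: &[Vec<f32>], query: &[f32], keys: &[Vec<f32>], mask: &[bool]) -> Vec<f32> {
    // Project the query into key space.
    let q: Vec<f32> = w
        .iter()
        .map(|row| row.iter().zip(query).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    let scale = 1.0 / (q.len() as f32).sqrt();
    // Invalid targets get -inf so they receive exactly zero probability.
    let scores: Vec<f32> = keys
        .iter()
        .zip(mask)
        .map(|(k, &valid)| {
            if valid {
                q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
            } else {
                f32::NEG_INFINITY
            }
        })
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    assert!(max > f32::NEG_INFINITY, "at least one target must be valid");
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```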
Attention Pooling
Aggregates variable-count entity embeddings into a fixed-size vector using learned query vectors. Unlike mean or max pooling, the learned queries let the model decide which entities matter most.
use rl4burn::{AttentionPool, AttentionPoolConfig};
let pool = AttentionPoolConfig::new(128, 4, 2).init(&device);
// embed_dim=128, 4 learned queries, 2 attention heads
let pooled = pool.forward(entities, None);
// entities: [batch, n_entities, 128]
// pooled: [batch, 512] (4 queries * 128 dims)
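The output width comes from concatenating one pooled vector per learned query. The sketch below uses a single head per query and leaves out the multi-head split and any output projection that AttentionPool may apply. The names are illustrative.

```rust
/// Attention pooling sketch: each learned query attends over all entities and
/// yields one embed_dim vector; the per-query results are concatenated.
/// queries: [n_queries][embed_dim], entities: [n_entities][embed_dim].
/// Returns a vector of length n_queries * embed_dim.
fn attention_pool(queries: &[Vec<f32>], entities: &[Vec<f32>]) -> Vec<f32> {
    let d = entities[0].len();
    let scale = 1.0 / (d as f32).sqrt();
    let mut pooled = Vec::with_capacity(queries.len() * d);
    for q in queries {
        let scores: Vec<f32> = entities
            .iter()
            .map(|e| q.iter().zip(e).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        // Softmax-weighted average of the entity embeddings for this query.
        let mut out = vec![0.0f32; d];
        for (w, e) in exps.iter().zip(entities) {
            for (o, x) in out.iter_mut().zip(e) {
                *o += (w / sum) * x;
            }
        }
        pooled.extend(out);
    }
    pooled
}
```

With 4 queries and 128-dim entities this returns 512 values per sample, which matches the comment above.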
Pointer Network
Additive (Bahdanau) attention for entity selection: score = v^T * tanh(W_q * query + W_k * keys). Used by AlphaStar and SCC for selecting subsets of units.
use rl4burn::{PointerNet, PointerNetConfig};
let ptr = PointerNetConfig::new(256, 128, 64).init(&device);
// query_dim=256, key_dim=128, hidden_dim=64
let probs = ptr.forward(query, keys, Some(mask));
// probs: [batch, n_entities] (selection probabilities)
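The additive score can be written out directly. The sketch below handles one query and uses hypothetical weights W_q [hidden][query_dim], W_k [hidden][key_dim] and v [hidden]. Masked entities are excluded before the softmax.

```rust
fn matvec(m: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    m.iter().map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>()).collect()
}

/// Pointer-network (Bahdanau) sketch: score_j = v . tanh(W_q q + W_k k_j),
/// followed by a softmax over the valid entities.
fn pointer_probs(
    w_q: &[Vec<f32>], w_k: &[Vec<f32>], v: &[f32],
    query: &[f32], keys: &[Vec<f32>], mask: &[bool],
) -> Vec<f32> {
    let q_proj = matvec(w_q, query);
    let scores: Vec<f32> = keys
        .iter()
        .zip(mask)
        .map(|(k, &valid)| {
            if !valid {
                return f32::NEG_INFINITY; // masked: zero probability after softmax
            }
            let k_proj = matvec(w_k, k);
            q_proj.iter().zip(&k_proj).zip(v)
                .map(|((a, b), vi)| vi * (a + b).tanh())
                .sum::<f32>()
        })
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    assert!(max.is_finite(), "at least one entity must be valid");
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```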
All three modules support masking for absent/dead entities.
Auto-Regressive Action Distributions
For games where actions decompose into sequential decisions: what action -> which target -> which ability. Large-scale game AI systems such as AlphaStar and SCC use this pattern.
CompositeDistribution
use rl4burn::{CompositeDistribution, ActionHead};
// 3-head action space: action_type(11) -> target(30) -> ability(8)
let dist = CompositeDistribution::from_heads(
&["action_type", "target", "ability"],
&[11, 30, 8],
);
// Total logits needed from the model: 11 + 30 + 8 = 49
assert_eq!(dist.total_logits(), 49);
Sampling
Given flat logits from the model (all heads concatenated), sample independently per head:
let actions = dist.sample(&logits, mask.as_ref(), &mut rng);
// actions: Vec<Vec<f32>> — [batch][n_heads], integer-valued
For fully auto-regressive sampling (where head 2’s logits depend on head 1’s sample), call the model multiple times and feed actions back.
Log-probabilities
Joint log-prob is the sum of per-head log-probs:
let log_prob = dist.log_prob(logits, &actions, mask.as_ref(), &device);
// log_prob: [batch] — log P(a) = log P(a1) + log P(a2) + log P(a3)
Entropy
Sum of per-head entropies (exact when heads are independent, upper bound otherwise):
let entropy = dist.entropy(logits, mask.as_ref());
// entropy: [batch]
With action masking
Pass a flat mask tensor [batch, total_logits] where 1.0 = valid, 0.0 = invalid. Masked actions are never sampled and get zero probability.
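The bookkeeping behind log_prob, entropy and masking for a single sample works like this. Split the flat logits into per-head slices, push masked logits to -inf, take a log-softmax per head, and sum across heads. The sketch below is a plain-Rust illustration of that math under the mask convention above (1.0 = valid). It is not rl4burn's code, and the function names are hypothetical.

```rust
/// Log-softmax of one head's logits; entries with mask 0.0 become -inf (probability 0).
fn masked_log_softmax(logits: &[f32], mask: &[f32]) -> Vec<f32> {
    let masked: Vec<f32> = logits.iter().zip(mask)
        .map(|(&l, &m)| if m > 0.5 { l } else { f32::NEG_INFINITY })
        .collect();
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum = max + masked.iter().map(|x| (x - max).exp()).sum::<f32>().ln();
    masked.iter().map(|x| x - log_sum).collect()
}

/// Joint log-prob and summed per-head entropy for one sample.
/// head_sizes: e.g. [11, 30, 8]; logits and mask are flat with length sum(head_sizes);
/// actions holds one chosen index per head.
fn composite_log_prob_entropy(head_sizes: &[usize], logits: &[f32], mask: &[f32], actions: &[usize]) -> (f32, f32) {
    let (mut log_prob, mut entropy, mut offset) = (0.0, 0.0, 0);
    for (h, &size) in head_sizes.iter().enumerate() {
        let lp = masked_log_softmax(&logits[offset..offset + size], &mask[offset..offset + size]);
        // log P(a) = sum over heads of log P(a_h)
        log_prob += lp[actions[h]];
        // Entropy = -sum p * log p, skipping masked entries (p = 0).
        entropy -= lp.iter().filter(|x| x.is_finite()).map(|&x| x.exp() * x).sum::<f32>();
        offset += size;
    }
    (log_prob, entropy)
}
```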
FiLM Conditioning
FiLM (Feature-wise Linear Modulation) applies a context-dependent affine transform to features. Used by SCC to condition spatial action heads on action type.
API
use rl4burn::{Film, FilmConfig};
let film = FilmConfig::new(32, 128).init(&device);
// context_dim=32, feature_dim=128
let output = film.forward(features, context);
// features: [batch, 128], context: [batch, 32]
// output: [batch, 128]
How it works
output = (1 + gamma(context)) * features + beta(context)
The +1 on gamma ensures the layer starts as an identity transform, improving training stability.
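Here is the formula for one sample as a minimal sketch. It assumes gamma and beta are plain linear maps of the context; rl4burn's actual layers may differ. When the gamma and beta weights and biases are all zero, the output equals the input features, which is the identity start that the +1 provides.

```rust
/// Linear map y = W x + b, with W: [out][in].
fn linear(w: &[Vec<f32>], b: &[f32], x: &[f32]) -> Vec<f32> {
    w.iter().zip(b)
        .map(|(row, bi)| bi + row.iter().zip(x).map(|(a, c)| a * c).sum::<f32>())
        .collect()
}

/// FiLM: output = (1 + gamma(context)) * features + beta(context), element-wise.
fn film(
    features: &[f32], context: &[f32],
    w_gamma: &[Vec<f32>], b_gamma: &[f32],
    w_beta: &[Vec<f32>], b_beta: &[f32],
) -> Vec<f32> {
    let gamma = linear(w_gamma, b_gamma, context);
    let beta = linear(w_beta, b_beta, context);
    features.iter().zip(gamma.iter().zip(&beta))
        .map(|(f, (g, b))| (1.0 + g) * f + b)
        .collect()
}
```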
Symlog and Twohot Encoding
DreamerV3’s solution for scale-free predictions. Symlog compresses large values; twohot turns regression into classification.
Symlog / Symexp
use rl4burn::{symlog, symexp};
let compressed = symlog(values); // sign(x) * ln(|x| + 1)
let recovered = symexp(compressed); // sign(x) * (exp(|x|) - 1)
// Round-trip: symexp(symlog(x)) ≈ x
Key properties:
- symlog(0) = 0
- symlog(1000) ≈ 6.9 (massive compression)
- symlog(-x) = -symlog(x) (symmetric)
- Monotonically increasing
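Both transforms are one-liners. The sketch below works on plain f32 scalars, while rl4burn's versions operate on tensors. ln_1p and exp_m1 keep precision near zero.

```rust
/// symlog(x) = sign(x) * ln(|x| + 1)
fn symlog(x: f32) -> f32 {
    x.signum() * x.abs().ln_1p()
}

/// symexp(x) = sign(x) * (exp(|x|) - 1), the inverse of symlog.
fn symexp(x: f32) -> f32 {
    x.signum() * x.abs().exp_m1()
}
```

For example, symlog(1000.0) is ln(1001) ≈ 6.91 and symlog(1.0) is ln(2) ≈ 0.69.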
Twohot Encoder
Encodes scalar values as soft distributions over 255 bins in symlog space.
use rl4burn::TwohotEncoder;
let encoder = TwohotEncoder::new(); // 255 bins, [-20, 20]
// Encode: scalar → distribution
let targets = encoder.encode(values, &device); // [batch, 255]
// Decode: distribution → scalar
let values = encoder.decode(probs, &device); // [batch]
// Loss: cross-entropy against twohot targets
let loss = encoder.loss(logits, values, &device); // [1]
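The sketch below shows twohot encode and decode for a single scalar. It assumes 255 bins evenly spaced over [-20, 20] in symlog space, as the comment above states. The target puts weight on the two nearest bins so that the expected bin value equals symlog(x), and decoding inverts that. Whether rl4burn spaces its bins exactly this way is an assumption. The loss is then the cross-entropy between this target and the softmax of the predicted logits.

```rust
const N_BINS: usize = 255;
const LOW: f32 = -20.0;
const HIGH: f32 = 20.0;

fn symlog(x: f32) -> f32 { x.signum() * x.abs().ln_1p() }
fn symexp(x: f32) -> f32 { x.signum() * x.abs().exp_m1() }

/// Center of bin i, evenly spaced over [LOW, HIGH].
fn bin_value(i: usize) -> f32 {
    LOW + (HIGH - LOW) * i as f32 / (N_BINS - 1) as f32
}

/// Encode a scalar as a twohot distribution over the bins (in symlog space).
fn twohot_encode(x: f32) -> Vec<f32> {
    let y = symlog(x).clamp(LOW, HIGH);
    // Continuous bin position; split the mass between its two neighbouring bins.
    let pos = (y - LOW) / (HIGH - LOW) * (N_BINS - 1) as f32;
    let lo = (pos.floor() as usize).min(N_BINS - 2);
    let frac = pos - lo as f32;
    let mut target = vec![0.0; N_BINS];
    target[lo] = 1.0 - frac;
    target[lo + 1] = frac;
    target
}

/// Decode: expected bin value under `probs`, mapped back with symexp.
fn twohot_decode(probs: &[f32]) -> f32 {
    let y: f32 = probs.iter().enumerate().map(|(i, p)| p * bin_value(i)).sum();
    symexp(y)
}
```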
Why this matters
Without symlog+twohot, you need to tune learning rates per domain. With a squared-error loss, a reward of 1000 produces roughly 1000x larger gradients than a reward of 1. Symlog maps 1000 to about 6.9 and 1 to about 0.69, which shrinks that gap to roughly 10x. Twohot converts regression to classification, further stabilizing gradients.
DreamerV3 Overview
DreamerV3 learns a model of the world, then trains a policy entirely inside imagined trajectories. It’s architecturally different from the model-free papers (AlphaStar, JueWu) but its sample efficiency could be transformative for fast simulations.
The DreamerV3 training loop
repeat:
1. Collect experience in the real environment
2. Store in sequence replay buffer
3. Sample sequences, train the world model (RSSM)
4. Imagine trajectories from the world model
5. Train actor-critic on imagined data
Steps 4-5 are “free” — no environment interaction needed.
rl4burn modules for DreamerV3
| Component | Module | Page |
|---|---|---|
| World model | Rssm | RSSM |
| Imagination | imagine_rollout | Imagination |
| Value targets | lambda_returns | Imagination |
| Replay | SequenceReplayBuffer | Sequence Replay |
| Transforms | symlog, TwohotEncoder | Symlog |
| KL training | kl_balanced_loss | KL Balance |
| Normalization | PercentileNormalizer | Percentile |
| Block GRU | BlockGruCell | RNN |
RSSM (Recurrent State-Space Model)
The world model at the heart of DreamerV3. Predicts what happens next in a compact latent space.
State representation
The RSSM state has two parts:
- h (deterministic): GRU hidden state, captures long-term memory
- z (stochastic): 32 categorical distributions x 32 classes, captures uncertainty
Together they form a 1024+ dimensional state sufficient to reconstruct observations, predict rewards, and determine if episodes continue.
API
use rl4burn::{Rssm, RssmConfig, RssmState};
let config = RssmConfig::new(obs_dim, action_dim);
let rssm = config.init(&device);
// Initial state (all zeros)
let state = rssm.initial_state(batch_size, &device);
// Training: observe → update
let (new_state, posterior_logits, prior_logits) = rssm.obs_step(&state, action, obs);
// Imagination: predict without observing
let new_state = rssm.imagine_step(&state, action);
// Predictions
let reward_logits = rssm.predict_reward(state.h.clone(), state.z.clone()); // [batch, 255]
let cont_logits = rssm.predict_continue(state.h, state.z); // [batch, 1]
Training the RSSM
Train with KL-balanced loss between posterior and prior:
use rl4burn::{kl_balanced_loss, KlBalanceConfig, TwohotEncoder};
let kl_loss = kl_balanced_loss(posterior_logits, prior_logits, &KlBalanceConfig::default());
let reward_loss = TwohotEncoder::new().loss(reward_logits, actual_rewards, &device);
let total_loss = kl_loss + reward_loss;
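The sketch below shows what a balanced KL computes in the forward pass, summed over categorical latent groups. In DreamerV3 the same KL(posterior || prior) appears twice, with stop-gradients on opposite sides. The dynamics term (weight 0.5) trains the prior, and the representation term (weight 0.1) trains the posterior. Each term is clipped below at 1 nat of free bits. Without autodiff the stop-gradients do not change the value, so only the forward computation is shown. The 0.5/0.1/1.0 constants come from the DreamerV3 paper. Whether KlBalanceConfig::default() uses the same values is an assumption.

```rust
/// KL(q || p) between two categorical distributions given as probabilities.
fn categorical_kl(q: &[f32], p: &[f32]) -> f32 {
    q.iter().zip(p)
        .filter(|(qi, _)| **qi > 0.0) // 0 * ln 0 contributes nothing
        .map(|(qi, pi)| qi * (qi.ln() - pi.ln()))
        .sum()
}

/// Forward value of a DreamerV3-style balanced KL loss.
/// posterior, prior: [n_groups][n_classes] probabilities.
fn kl_balanced_value(posterior: &[Vec<f32>], prior: &[Vec<f32>], dyn_scale: f32, rep_scale: f32, free_bits: f32) -> f32 {
    let kl: f32 = posterior.iter().zip(prior).map(|(q, p)| categorical_kl(q, p)).sum();
    // Dynamics loss: KL(sg(posterior) || prior), trains the prior.
    let dyn_loss = kl.max(free_bits);
    // Representation loss: KL(posterior || sg(prior)), trains the posterior.
    let rep_loss = kl.max(free_bits);
    dyn_scale * dyn_loss + rep_scale * rep_loss
}
```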
Configuration
let config = RssmConfig {
obs_dim: 64,
action_dim: 11,
deterministic_size: 512, // GRU hidden size
n_categories: 32, // stochastic groups
n_classes: 32, // classes per group
hidden_size: 512, // MLP hidden dim
n_blocks: 8, // block GRU blocks (0 = standard GRU)
unimix: 0.01, // uniform mixture for categoricals
};
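This configuration implies two quantities, sketched below in plain Rust. The first is the size of the concatenated model state: h plus the flattened one-hot z, or 512 + 32 × 32 = 1536 here. The second is the unimix trick, which DreamerV3 uses to mix each categorical's softmax with a uniform distribution so that no class probability reaches zero. Reading unimix: 0.01 as that 1% mixture is an interpretation of the field name, and the helper names are hypothetical.

```rust
/// Size of the concatenated RSSM state [h, flatten(z)].
fn state_dim(deterministic_size: usize, n_categories: usize, n_classes: usize) -> usize {
    deterministic_size + n_categories * n_classes
}

/// Unimix: probs = (1 - unimix) * softmax(logits) + unimix / n_classes.
fn unimix_probs(logits: &[f32], unimix: f32) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let k = logits.len() as f32;
    exps.iter().map(|e| (1.0 - unimix) * e / sum + unimix / k).collect()
}
```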
Imagination Rollouts
Generate trajectories entirely within the RSSM latent space for actor-critic training.
API
use rl4burn::algo::planning::imagination::{imagine_rollout, lambda_returns};
let trajectory = imagine_rollout(
&rssm,
initial_states,
|h, z| actor_network.forward(h, z), // actor closure
15, // horizon (DreamerV3 default)
);
// trajectory.states: [16] states (initial + 15 imagined)
// trajectory.reward_logits: [15] reward predictions
// trajectory.continue_logits: [15] continue predictions
Lambda-returns
Compute value targets from imagined rewards:
let returns = lambda_returns(
&rewards, // decoded from reward_logits
&values, // critic predictions at each state
&continues, // sigmoid(continue_logits)
0.997, // gamma
0.95, // lambda
);
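The recursion behind lambda_returns can be written as R_t = r_t + γ·c_t·((1 − λ)·v_{t+1} + λ·R_{t+1}), seeded with the critic's value at the last imagined state. The plain-Rust sketch below follows that recursion. Its indexing convention, where values has one more entry than rewards, is an assumption for illustration and may not match rl4burn's function.

```rust
/// rewards, continues: [H]; values: [H + 1] (critic value at every imagined state).
fn lambda_returns(rewards: &[f32], values: &[f32], continues: &[f32], gamma: f32, lambda: f32) -> Vec<f32> {
    let h = rewards.len();
    assert_eq!(values.len(), h + 1);
    assert_eq!(continues.len(), h);
    let mut returns = vec![0.0; h];
    // Bootstrap from the critic's value at the final state.
    let mut next_return = values[h];
    for t in (0..h).rev() {
        // Blend the one-step bootstrap with the longer return.
        let mix = (1.0 - lambda) * values[t + 1] + lambda * next_return;
        next_return = rewards[t] + gamma * continues[t] * mix;
        returns[t] = next_return;
    }
    returns
}
```

With λ = 1 this becomes the discounted sum of imagined rewards plus the final value. With λ = 0 it becomes the one-step target r_t + γ·c_t·v_{t+1}.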
Stop-gradient rules
During imagination training:
- World model: frozen. Its parameters receive no updates from the actor or critic losses. The actor learns to generate actions that lead to high-value states under this fixed model.
- Value targets: stop-gradiented. The critic trains on fixed targets.
- Rewards: gradients can still flow through the frozen dynamics model to the actor (the actor is indirectly optimizing for states that the world model predicts will be rewarding).
R2-Dreamer
R2-Dreamer (ICLR 2026) is a computationally efficient world model for RL that achieves strong performance without decoders or augmentation. It replaces the standard reconstruction loss with self-supervised representation objectives.
Key Idea
Standard DreamerV3 trains the encoder via a decoder that reconstructs observations. R2-Dreamer eliminates this bottleneck by using redundancy reduction (Barlow Twins loss) to learn representations directly.
Representation Variants
rl4burn supports all four variants from the paper:
| Variant | Loss | Description |
|---|---|---|
| Dreamer | Decoder MSE | Standard DreamerV3 reconstruction baseline |
| R2Dreamer | Barlow Twins | Invariance + decorrelation on cross-correlation matrix |
| InfoNCE | Contrastive | Positive pair matching with temperature-scaled cosine similarity |
| DreamerPro | Prototype | Sinkhorn-Knopp assignment to learned prototypes |
Usage
use rl4burn::algo::dreamer::{DreamerConfig, dreamer_world_model_loss, dreamer_actor_critic_loss};
use rl4burn::algo::loss::representation::RepresentationVariant;
// Configure with R2-Dreamer (Barlow Twins)
let config = DreamerConfig {
rep_variant: RepresentationVariant::R2Dreamer,
action_dim: 4,
discrete_actions: true,
..DreamerConfig::default()
};
let agent = config.init::<B>(&device);
// Train world model on observed sequences
let (wm_loss, wm_stats) = dreamer_world_model_loss(
&agent, observations, actions, rewards, continues,
);
// Train actor-critic via imagination
let (actor_loss, critic_loss, ac_stats) = dreamer_actor_critic_loss(
&agent, initial_states,
);
Architecture
The agent composes existing rl4burn building blocks:
- RSSM (rl4burn_nn::rssm): recurrent state-space model with deterministic GRU + stochastic categorical states
- Imagination rollouts (rl4burn_algo::planning::imagination): generate trajectories in latent space
- KL-balanced loss (rl4burn_algo::loss::kl_balance): train posterior and prior with free bits
- Symlog + Twohot (rl4burn_nn::symlog): distributional value prediction
- Representation losses (rl4burn_algo::loss::representation): Barlow Twins, InfoNCE, DreamerPro, decoder
- MLP with RMSNorm (rl4burn_nn::mlp): prediction heads and actor/critic networks
- CNN encoder/decoder (rl4burn_nn::conv): image observation processing
New Modules
| Module | Crate | Description |
|---|---|---|
| mlp | rl4burn-nn | Configurable MLP with RMSNorm or LayerNorm |
| conv | rl4burn-nn | CNN encoder (images → features) and decoder (features → images) |
| multi_encoder | rl4burn-nn | Routes mixed observations (images + vectors) |
| representation | rl4burn-algo | Four self-supervised representation losses |
| dreamer | rl4burn-algo | DreamerAgent, world model loss, actor-critic loss |
Example
See examples/dreamer/ for a complete training loop on CartPole.
Reference
Nauman & Straffelini, “R2-Dreamer: Redundancy Reduction for Computationally Efficient World Models” (ICLR 2026).
Self-Play
Train agents by playing against past versions of themselves. The core mechanism for competitive game AI.
API
use rl4burn::algo::multi_agent::self_play::{SelfPlayPool, branch_agent};
let mut pool = SelfPlayPool::new();
// Snapshot current model every N steps
pool.add_snapshot(&model, training_step);
// Get a random past opponent
if let Some(opponent) = pool.sample(&mut rng) {
// Run game: model vs opponent
}
// Keep only the 50 most recent
pool.retain_recent(50);
How it works
SelfPlayPool stores cloned copies of the model at different training stages. Opponents are sampled uniformly. For smarter opponent selection, see PFSP Matchmaking.
Important: snapshots are deep copies
When you call add_snapshot, the model is .clone()’d. Mutating the original model afterward does not affect stored snapshots. This is essential — without true deep copies, all “opponents” would have the same weights as the current model.
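The deep-copy behaviour can be shown with a hypothetical generic pool (not rl4burn's SelfPlayPool). add_snapshot stores model.clone(), so later in-place updates to the live model leave stored opponents unchanged. Uniform sampling is left to the caller here, who can draw an index from any RNG.

```rust
#[derive(Clone, Debug, PartialEq)]
struct ToyModel {
    weights: Vec<f32>,
}

/// Hypothetical minimal pool of (training_step, snapshot) pairs.
struct Pool<M: Clone> {
    snapshots: Vec<(u64, M)>,
}

impl<M: Clone> Pool<M> {
    fn new() -> Self {
        Self { snapshots: Vec::new() }
    }
    /// Store an independent copy of the model.
    fn add_snapshot(&mut self, model: &M, step: u64) {
        self.snapshots.push((step, model.clone()));
    }
    /// Keep only the `n` most recent snapshots.
    fn retain_recent(&mut self, n: usize) {
        let excess = self.snapshots.len().saturating_sub(n);
        self.snapshots = self.snapshots.split_off(excess);
    }
    /// Look up a snapshot by index (e.g. an index drawn uniformly at random).
    fn get(&self, i: usize) -> Option<&M> {
        self.snapshots.get(i).map(|(_, m)| m)
    }
}
```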
League Training
AlphaStar-style multi-agent training with role-based specialization. Multiple agents train simultaneously with different objectives.
Agent Roles
| Role | Opponents | Purpose |
|---|---|---|
| Main Agent | 35% self-play + 65% PFSP pool | General strength |
| Main Exploiter | Only the main agent | Find main agent’s weaknesses |
| League Exploiter | PFSP across full pool | Find weaknesses across all strategies |
API
use rl4burn::{League, AgentRole, LeagueAgentConfig};
let mut league = League::new();
league.set_initial_model(initial_model.clone());
// Add agents
league.add_agent(model.clone(), LeagueAgentConfig {
role: AgentRole::MainAgent,
checkpoint_interval: 1000,
reset_threshold: 0,
});
league.add_agent(model.clone(), LeagueAgentConfig {
role: AgentRole::MainExploiter,
checkpoint_interval: 2000,
reset_threshold: 50000,
});
// Training loop
let opponent = league.get_opponent(agent_idx, &mut rng);
// ... play game, update model ...
league.update_agent(agent_idx); // handles checkpointing
Checkpointing
Every checkpoint_interval steps, the agent’s current weights are frozen and added to the opponent pool. All agents can then play against these frozen snapshots.
Exploiter resets
Exploiters that stop improving get reset to the initial model weights:
league.reset_exploiter(exploiter_idx);
PFSP Matchmaking
Prioritized Fictitious Self-Play samples harder opponents more frequently. The opponent you lose to most often is the one you practice against most.
API
use rl4burn::{PfspMatchmaking, PfspConfig};
let mut mm = PfspMatchmaking::new(PfspConfig {
power: 1.0, // higher = more focus on hard opponents
min_prob: 0.01, // every opponent has at least 1% chance
});
mm.add_opponent(0);
mm.add_opponent(1);
mm.add_opponent(2);
// Record results
mm.record_result(0, true, false); // beat opponent 0
mm.record_result(1, false, false); // lost to opponent 1
// Sample: opponent 1 (harder) is sampled more often
let opponent = mm.sample_opponent(&mut rng);
Weighting formula
Selection probability is proportional to (1 - win_rate) ^ power:
- Win rate 90% → weight 0.1
- Win rate 50% → weight 0.5
- Win rate 10% → weight 0.9
Higher power makes the distribution more extreme.
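The sketch below turns this into a sampling distribution. It assumes the weights (1 − win_rate)^power are normalized, that every opponent is then floored at min_prob, and that the result is renormalized. How PfspMatchmaking applies min_prob internally, and how it treats opponents with no recorded games, is not specified here, so treat this as an illustration.

```rust
/// PFSP sampling probabilities from per-opponent win rates.
fn pfsp_probs(win_rates: &[f32], power: f32, min_prob: f32) -> Vec<f32> {
    let weights: Vec<f32> = win_rates.iter().map(|w| (1.0 - w).max(0.0).powf(power)).collect();
    let total: f32 = weights.iter().sum();
    let n = win_rates.len() as f32;
    // If every weight is zero (we beat everyone), fall back to uniform.
    let probs: Vec<f32> = if total > 0.0 {
        weights.iter().map(|w| w / total).collect()
    } else {
        vec![1.0 / n; win_rates.len()]
    };
    // Floor each probability at min_prob, then renormalize.
    let floored: Vec<f32> = probs.iter().map(|p| p.max(min_prob)).collect();
    let z: f32 = floored.iter().sum();
    floored.iter().map(|p| p / z).collect()
}
```

For win rates of 90%, 50% and 10% with power 1.0, the weights 0.1, 0.5 and 0.9 normalize to about 0.067, 0.333 and 0.6.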
Multi-Agent Shared-Weight Training
Efficiently control multiple units with a single shared policy network. Used by JueWu for 5 heroes and applicable to any game with multiple controlled units.
API
use rl4burn::{batch_multi_agent_obs, unbatch_actions, broadcast_team_reward};
// Batch observations from all agents across all environments
let (obs_tensor, n_envs, n_agents) = batch_multi_agent_obs::<B>(
&per_env_per_agent_obs,
&device,
);
// obs_tensor: [n_envs * n_agents, obs_dim] — one big batch
// Single forward pass for all agents
let output = model.forward(obs_tensor);
// Unbatch actions back to per-env, per-agent
let actions = unbatch_actions(&flat_actions, n_envs, n_agents);
// actions: [n_envs][n_agents]
// Broadcast team reward to all agents
let per_agent_rewards = broadcast_team_reward(&env_rewards, n_agents);
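The three helpers are reshaping operations. The plain-Rust sketch below does the same bookkeeping on nested Vecs. The function names are hypothetical re-implementations, not rl4burn's code, and actions are assumed to be discrete indices.

```rust
/// Flatten obs[env][agent][feature] into one row-major batch of n_envs * n_agents rows.
fn flatten_obs(obs: &[Vec<Vec<f32>>]) -> (Vec<f32>, usize, usize) {
    let n_envs = obs.len();
    let n_agents = obs.first().map_or(0, |e| e.len());
    let flat = obs.iter().flatten().flatten().copied().collect();
    (flat, n_envs, n_agents)
}

/// Row i of the batch belongs to env i / n_agents, agent i % n_agents.
fn split_actions(flat: &[usize], n_envs: usize, n_agents: usize) -> Vec<Vec<usize>> {
    assert_eq!(flat.len(), n_envs * n_agents);
    flat.chunks(n_agents).map(|c| c.to_vec()).collect()
}

/// Every agent in an environment receives that environment's team reward.
fn broadcast_reward(env_rewards: &[f32], n_agents: usize) -> Vec<Vec<f32>> {
    env_rewards.iter().map(|&r| vec![r; n_agents]).collect()
}
```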
Why shared weights?
With 30 units, running 30 separate forward passes is expensive. With shared weights, batch all 30 observations into one forward pass. The policy generalizes across unit types through the observation encoding.
Privileged Critic
The policy sees only what a player would see. The critic sees everything — including enemy positions behind fog of war.
The Trait
use rl4burn::algo::privileged_critic::{PrivilegedActorCritic, make_critic_input};
impl<B: Backend> PrivilegedActorCritic<B> for MyModel<B> {
fn actor_forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
self.actor.forward(obs) // partial observation only
}
fn critic_forward(&self, obs: Tensor<B, 2>, privileged: Tensor<B, 2>) -> Tensor<B, 1> {
let input = make_critic_input(obs, privileged);
self.critic.forward(input)
}
}
Why it works
Value estimation under partial observability is noisy — the critic can’t tell if you’re winning or losing without seeing the full game state. Giving the critic privileged information dramatically reduces variance. At deployment time, only the actor is needed.
Goal-Conditioned RL (z-Conditioning)
Condition policies on strategy descriptors to enable rapid specialization. Used by ROA-Star and SCC for exploiter training.
API
use rl4burn::algo::z_conditioning::{ZConditioning, ZConditioningConfig, z_reward};
let z_mod = ZConditioningConfig::new(16, obs_dim).init(&device);
// z_dim=16 (strategy embedding), obs_dim from environment
let conditioned_obs = z_mod.forward(obs, z);
// conditioned_obs: [batch, obs_dim + 64] — ready for policy network
// Pseudo-reward for following target strategy
let reward = z_reward(&observed_stats, &target_z);
// negative L2 distance: closer to target = higher reward
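Here is the pseudo-reward described in the comment, written out for plain vectors. The sketch assumes that observed_stats and target_z live in the same space and have the same length.

```rust
/// z-reward sketch: negative Euclidean (L2) distance between observed statistics and target z.
fn z_reward(observed: &[f32], target: &[f32]) -> f32 {
    assert_eq!(observed.len(), target.len());
    let sq: f32 = observed.iter().zip(target).map(|(o, t)| (o - t).powi(2)).sum();
    -sq.sqrt()
}
```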
What is z?
A low-dimensional vector describing a play style, computed from human replay statistics. Examples:
- Aggressive: high damage, low farming
- Defensive: low damage, high survival
- Rush: high early-game activity
By conditioning on different z vectors, the same policy can exhibit different strategies.
Agent Branching
Clone an agent’s weights to create a new specialized agent. Used by SCC for initializing exploiters from the current main agent rather than from scratch.
API
use burn::optim::AdamConfig;
use rl4burn::algo::multi_agent::self_play::branch_agent;
let exploiter = branch_agent(&main_agent);
let mut exploiter_optim = AdamConfig::new().init(); // fresh optimizer!
Why branch?
Starting exploiters from the main agent’s current weights (instead of the supervised model) gives them a head start. They already know how to play — they just need to specialize in exploiting weaknesses.
The key: the optimizer state must be reset. branch_agent clones the model; you create a fresh optimizer.
MCTS for Drafting
UCT-based Monte Carlo Tree Search for pre-game decisions like unit composition or hero drafting.
API
use rl4burn::algo::planning::mcts::{MctsTree, MctsConfig};
let mut tree = MctsTree::new(MctsConfig {
n_simulations: 800,
exploration_constant: 1.41,
n_actions: 30, // number of possible picks
});
let visit_counts = tree.search(|action_path| {
// Evaluate this sequence of picks.
// Return estimated win rate (0.0 to 1.0).
evaluate_composition(action_path)
}, &mut rng);
let best_pick = tree.best_action();
let pick_probs = tree.action_probs();
How UCT works
- Select: walk down the tree, choosing children by UCT score = mean_value + c * sqrt(ln(parent_visits) / visits)
- Expand: add a new child for an unexplored action
- Evaluate: call your evaluation function on the action sequence
- Backpropagate: update visit counts and values up to the root
After all simulations, pick the most-visited action (not the highest-value — visit count is more robust).
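The selection and final-move rules can be sketched on their own. Compute the UCT score for each child and pick the maximum, treating unvisited children as infinitely promising so that every action is tried at least once. The Child struct is hypothetical. The exploration_constant of 1.41 used above is approximately √2, the classic UCB1 value.

```rust
struct Child {
    visits: u32,
    total_value: f32,
}

/// UCT = mean_value + c * sqrt(ln(parent_visits) / visits); unvisited children score +inf.
fn uct_score(child: &Child, parent_visits: u32, c: f32) -> f32 {
    if child.visits == 0 {
        return f32::INFINITY;
    }
    let mean = child.total_value / child.visits as f32;
    mean + c * ((parent_visits as f32).ln() / child.visits as f32).sqrt()
}

/// Index of the child with the highest UCT score (used during Select).
fn select_child(children: &[Child], parent_visits: u32, c: f32) -> usize {
    let mut best = 0;
    let mut best_score = f32::NEG_INFINITY;
    for (i, ch) in children.iter().enumerate() {
        let s = uct_score(ch, parent_visits, c);
        if s > best_score {
            best = i;
            best_score = s;
        }
    }
    best
}

/// Final move: the most-visited child, not the highest mean value.
fn most_visited(children: &[Child]) -> usize {
    children.iter().enumerate().max_by_key(|(_, ch)| ch.visits).map(|(i, _)| i).unwrap_or(0)
}
```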
Beta-VAE Opponent Modeling
ROA-Star’s approach: train a frozen encoder to predict opponent behavior behind fog of war, then use the latent embedding as extra context for all agents.
API
use rl4burn::nn::vae::{BetaVae, BetaVaeConfig};
let vae = BetaVaeConfig::new(obs_dim)
.with_latent_dim(32)
.with_beta(4.0)
.init(&device);
// Training
let output = vae.forward(opponent_features);
let loss = vae.loss(opponent_features, &output);
// Inference: extract strategy embedding
let z = vae.strategy_embedding(opponent_features);
// z: [batch, 32] — feed this as extra context to the policy
Why beta-VAE?
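A standard VAE can learn entangled latents, where each dimension mixes several factors of variation. Setting beta > 1 weights the KL term more heavily, which pushes the posterior toward the factorized prior. This tends to produce more disentangled and interpretable strategy embeddings, at the cost of some reconstruction accuracy. Too large a beta can instead cause posterior collapse, where the encoder ignores its input.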
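For reference, here is a sketch of the standard beta-VAE objective for a diagonal Gaussian encoder: reconstruction error plus beta times the closed-form KL to a unit Gaussian prior, KL = ½ Σ (μ² + σ² − ln σ² − 1). Using squared error for reconstruction is an assumption; vae.loss may use a different likelihood or reduction.

```rust
/// Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dimensions.
fn gaussian_kl(mu: &[f32], log_var: &[f32]) -> f32 {
    mu.iter().zip(log_var)
        .map(|(m, lv)| 0.5 * (m * m + lv.exp() - lv - 1.0))
        .sum()
}

/// beta-VAE loss for one sample: squared reconstruction error + beta * KL.
fn beta_vae_loss(x: &[f32], x_recon: &[f32], mu: &[f32], log_var: &[f32], beta: f32) -> f32 {
    let recon: f32 = x.iter().zip(x_recon).map(|(a, b)| (a - b).powi(2)).sum();
    recon + beta * gaussian_kl(mu, log_var)
}
```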
Scouting reward
The entropy of the opponent model’s predictions can be used as an intrinsic reward: the agent is rewarded for actions that reduce uncertainty about the opponent.
use rl4burn::collect::intrinsic::EntropyReductionReward;
Distributed Training
Abstractions for multi-GPU/multi-machine gradient synchronization.
The GradientSync Trait
use rl4burn::algo::distributed::{GradientSync, ReduceStrategy};
pub trait GradientSync {
fn all_reduce_f32(&self, values: &[f32], strategy: ReduceStrategy) -> Vec<f32>;
fn rank(&self) -> usize;
fn world_size(&self) -> usize;
fn barrier(&self);
}
Local Development
Use LocalSync for single-machine development. All operations are no-ops.
use rl4burn::algo::distributed::LocalSync;
let sync = LocalSync;
assert_eq!(sync.world_size(), 1);
Custom Implementations
Implement GradientSync for your cluster’s communication library (MPI, NCCL, gRPC, etc.):
struct MpiSync { /* ... */ }
impl GradientSync for MpiSync {
fn all_reduce_f32(&self, values: &[f32], strategy: ReduceStrategy) -> Vec<f32> {
// Call MPI_Allreduce on `values` and return the reduced result
todo!()
}
// ...
}
At scale (SCC: ~1000 envs per agent, HoK: 320 GPUs), ring all-reduce is the standard choice for gradient averaging.
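To make the term concrete, here is a single-process simulation of ring all-reduce averaging. It has two phases. In the reduce-scatter phase, each rank accumulates one chunk as partial sums travel around the ring. In the all-gather phase, the finished chunks circulate. Each rank transfers about 2·(n − 1)/n of its buffer in total, almost independent of the number of ranks, which is why the scheme scales well. This is a teaching sketch, not something GradientSync provides. A real implementation would exchange the chunks over MPI, NCCL, or similar.

```rust
/// Single-process simulation of ring all-reduce (mean).
/// buffers[r] is rank r's local gradient; on return every rank holds the element-wise mean.
fn ring_all_reduce_mean(buffers: &mut [Vec<f32>]) {
    let n = buffers.len();
    if n <= 1 {
        return;
    }
    let len = buffers[0].len();
    assert!(buffers.iter().all(|b| b.len() == len), "ranks must hold equal-length buffers");
    assert!(len % n == 0, "this sketch assumes the length divides evenly into n chunks");
    let chunk = len / n;
    let range = |c: usize| (c * chunk)..((c + 1) * chunk);

    // Phase 1, reduce-scatter: at step s, rank r sends chunk (r - s) mod n to rank r + 1,
    // which adds it in. Afterwards rank r holds the full sum of chunk (r + 1) mod n.
    for step in 0..n - 1 {
        // Collect all messages first so the exchanges in one step happen "simultaneously".
        let msgs: Vec<(usize, usize, Vec<f32>)> = (0..n)
            .map(|r| {
                let c = (r + n - step) % n;
                ((r + 1) % n, c, buffers[r][range(c)].to_vec())
            })
            .collect();
        for (dst, c, data) in msgs {
            for (x, y) in buffers[dst][range(c)].iter_mut().zip(data) {
                *x += y;
            }
        }
    }

    // Phase 2, all-gather: at step s, rank r forwards chunk (r + 1 - s) mod n,
    // and the receiver overwrites its copy with the finished sum.
    for step in 0..n - 1 {
        let msgs: Vec<(usize, usize, Vec<f32>)> = (0..n)
            .map(|r| {
                let c = (r + 1 + n - step) % n;
                ((r + 1) % n, c, buffers[r][range(c)].to_vec())
            })
            .collect();
        for (dst, c, data) in msgs {
            buffers[dst][range(c)].copy_from_slice(&data);
        }
    }

    // Convert sums to means.
    for b in buffers.iter_mut() {
        for x in b.iter_mut() {
            *x /= n as f32;
        }
    }
}
```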
Cloud GPU Deployment
rl4burn provides the rl4burn-cloud crate with provider-agnostic abstractions for launching training jobs on cloud GPU marketplaces. Currently supported:
- Vast.ai — peer-to-peer GPU marketplace with competitive pricing
- RunPod — managed GPU cloud with on-demand and spot pods
Enable the feature in your Cargo.toml:
[dependencies]
rl4burn = { git = "https://github.com/RPP1011/rl4burn", features = ["cloud"] }
The CloudProvider Trait
All providers implement a common interface:
use rl4burn::cloud::{CloudProvider, InstanceRequirements, GpuOffer, Instance};
pub trait CloudProvider {
fn name(&self) -> &'static str;
fn search_offers(&self, reqs: &InstanceRequirements) -> CloudResult<Vec<GpuOffer>>;
fn launch(&self, offer: &GpuOffer) -> CloudResult<Instance>;
fn status(&self, instance_id: &str) -> CloudResult<Instance>;
fn stop(&self, instance_id: &str) -> CloudResult<()>;
}
The trait is HTTP-client agnostic: providers produce structured HttpRequest values that you execute with your preferred client (reqwest, ureq, curl, etc.).
Specifying Requirements
use rl4burn::cloud::{InstanceRequirements, GpuType};
let reqs = InstanceRequirements {
min_gpu_ram_gib: 24.0,
num_gpus: 1,
gpu_types: vec![GpuType::Rtx4090, GpuType::A100Pcie],
min_ram_gib: 32.0,
min_disk_gib: 100.0,
docker_image: "pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel".into(),
max_price_per_hour: 1.00,
on_start_cmd: Some("apt-get update && apt-get install -y cargo".into()),
..Default::default()
};
Vast.ai
use rl4burn::cloud::{VastAiProvider, CloudProvider};
let provider = VastAiProvider::new(std::env::var("VASTAI_API_KEY").unwrap());
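// Search for offers
let offers = provider.search_offers(&reqs)?;
// Pick the cheapest offer explicitly instead of assuming the list is sorted
let cheapest = offers
.iter()
.min_by(|a, b| a.price_per_hour.partial_cmp(&b.price_per_hour).unwrap_or(std::cmp::Ordering::Equal))
.expect("no offers matched the requirements");
println!("Found {} offers, cheapest: ${}/hr", offers.len(), cheapest.price_per_hour);
// Launch the cheapest offer
let instance = provider.launch(cheapest)?;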
println!("Instance {} status: {}", instance.instance_id, instance.status);
// Check status
let inst = provider.status(&instance.instance_id)?;
if let Some(ssh) = &inst.ssh_connection {
println!("Connect: {}", ssh);
}
// Cleanup
provider.stop(&instance.instance_id)?;
RunPod
use rl4burn::cloud::{RunPodProvider, CloudProvider};
let provider = RunPodProvider::new(std::env::var("RUNPOD_API_KEY").unwrap());
let offers = provider.search_offers(&reqs)?;
let instance = provider.launch(&offers[0])?;
println!("Pod {} status: {}", instance.instance_id, instance.status);
provider.stop(&instance.instance_id)?;
Comparing Providers
You can write provider-agnostic code by working with &dyn CloudProvider:
fn cheapest_offer(
providers: &[&dyn CloudProvider],
reqs: &InstanceRequirements,
) -> Option<(GpuOffer, &'static str)> {
let mut best: Option<(GpuOffer, &str)> = None;
for provider in providers {
if let Ok(offers) = provider.search_offers(reqs) {
for offer in offers {
if best.as_ref().map_or(true, |(b, _)| offer.price_per_hour < b.price_per_hour) {
best = Some((offer, provider.name()));
}
}
}
}
best
}
HTTP Client Integration
The providers are designed to be dependency-free. Each method can either:
- Execute requests directly: register an HTTP function with .with_http(fn).
- Return request descriptions: use .search_request(), .launch_request(), etc. to get HttpRequest structs you execute yourself.
Example with a custom HTTP function:
fn my_http(req: &rl4burn::cloud::vastai::HttpRequest) -> Result<String, String> {
// Use reqwest, ureq, curl, etc.
todo!()
}
let provider = VastAiProvider::new("key").with_http(my_http);
Supported GPU Types
The GpuType enum covers common training GPUs:
| Variant | GPU |
|---|---|
| Rtx3090 | NVIDIA RTX 3090 (24 GB) |
| Rtx4090 | NVIDIA RTX 4090 (24 GB) |
| RtxA4000 | NVIDIA RTX A4000 (16 GB) |
| RtxA5000 | NVIDIA RTX A5000 (24 GB) |
| RtxA6000 | NVIDIA RTX A6000 (48 GB) |
| A100Pcie | NVIDIA A100 PCIe (40/80 GB) |
| A100Sxm | NVIDIA A100 SXM (80 GB) |
| H100Pcie | NVIDIA H100 PCIe (80 GB) |
| H100Sxm | NVIDIA H100 SXM (80 GB) |
| L40 | NVIDIA L40 (48 GB) |
| L40s | NVIDIA L40S (48 GB) |
AlphaStar & ROA-Star
The 30-second version
AlphaStar (DeepMind, 2019) was the first AI to beat a top professional StarCraft II player. ROA-Star (Tencent, NeurIPS 2023) achieves the same level with 4x less compute by adding opponent modeling and smarter exploiter training.
Both are massive RL systems, but their core ideas decompose into modular building blocks — most of which are in rl4burn.
What makes StarCraft II hard for RL?
Imagine playing chess, except:
- You can only see part of the board (fog of war)
- Both players move simultaneously
- You control 200 pieces at once
- Each piece has 10+ possible actions
- Games last 20+ minutes (thousands of decisions)
Standard RL algorithms break under this complexity. AlphaStar’s solution: decompose the problem.
Key ideas (and where they are in rl4burn)
1. Auto-regressive action space
Instead of choosing from millions of possible joint actions, AlphaStar samples one decision at a time:
action_type → delay → queue → selected_units → target_unit → target_location
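Each head is conditioned on the previous samples. CompositeDistribution handles the multi-head bookkeeping (splitting flat logits per head, joint log-probabilities, summed entropy). For fully auto-regressive conditioning, call the model once per head and feed the sampled actions back in.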
use rl4burn::CompositeDistribution;
let dist = CompositeDistribution::from_heads(
&["action_type", "target", "ability"],
&[11, 30, 8],
);
See Auto-Regressive Action Distributions for details.
2. V-trace for off-policy correction
With thousands of parallel actors, the behavior policy is always slightly stale. V-trace corrects for this. Already in rl4burn as vtrace_targets.
See V-trace.
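For orientation, V-trace (introduced with IMPALA) computes targets with truncated importance ratios ρ_t = min(ρ̄, π/μ) and c_t = min(c̄, π/μ). It uses the backward recursion v_s = V(x_s) + δ_s + γ·c_s·(v_{s+1} − V(x_{s+1})), where δ_s = ρ_s·(r_s + γ·V(x_{s+1}) − V(x_s)). The plain-Rust sketch below covers a single trajectory without episode boundaries. rl4burn's vtrace_targets signature may differ.

```rust
/// V-trace targets for one trajectory.
/// values: V(x_0..x_{T-1}); bootstrap: V(x_T); ratios: pi(a_t|x_t) / mu(a_t|x_t).
fn vtrace(rewards: &[f32], values: &[f32], bootstrap: f32, ratios: &[f32], gamma: f32, rho_bar: f32, c_bar: f32) -> Vec<f32> {
    let t_len = rewards.len();
    assert!(values.len() == t_len && ratios.len() == t_len);
    let mut targets = vec![0.0; t_len];
    // Holds v_{s+1} - V(x_{s+1}); zero past the end because v_T = V(x_T).
    let mut next_correction = 0.0;
    for s in (0..t_len).rev() {
        let next_value = if s + 1 < t_len { values[s + 1] } else { bootstrap };
        let rho = ratios[s].min(rho_bar);
        let c = ratios[s].min(c_bar);
        let delta = rho * (rewards[s] + gamma * next_value - values[s]);
        let correction = delta + gamma * c * next_correction;
        targets[s] = values[s] + correction;
        next_correction = correction;
    }
    targets
}
```

When every ratio is 1 (fully on-policy, ρ̄ = c̄ = 1), the targets reduce to ordinary n-step bootstrapped returns.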
3. UPGO (self-imitation learning)
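Learn from a trajectory only while it keeps doing at least as well as expected. UPGO follows the observed return as long as the next step's outcome matches or beats the value baseline. As soon as it falls short, UPGO bootstraps from the value estimate instead. The policy therefore imitates its own better-than-expected behavior without being dragged down by later mistakes.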
use rl4burn::upgo_advantages;
let advantages = upgo_advantages(&rewards, &values, &dones, last_value, gamma);
See UPGO.
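AlphaStar's UPGO target can be viewed as a lambda-return whose λ switches between 1 and 0. G_t = r_t + γ·G_{t+1} if r_{t+1} + γ·V(s_{t+2}) ≥ V(s_{t+1}); otherwise G_t = r_t + γ·V(s_{t+1}). The advantage is G_t − V(s_t). The plain-Rust sketch below omits episode boundaries. rl4burn's upgo_advantages also takes dones and may differ in details.

```rust
/// UPGO returns for one trajectory. values: V(s_0..s_{T-1}); bootstrap: V(s_T).
fn upgo_returns(rewards: &[f32], values: &[f32], bootstrap: f32, gamma: f32) -> Vec<f32> {
    let t_len = rewards.len();
    assert_eq!(values.len(), t_len);
    // V(s_{t+1}) for each t, using the bootstrap value after the last step.
    let next_values: Vec<f32> = (0..t_len)
        .map(|t| if t + 1 < t_len { values[t + 1] } else { bootstrap })
        .collect();
    // Follow the trace into step t+1 only if that step did at least as well as V(s_{t+1}).
    let follow: Vec<bool> = (0..t_len)
        .map(|t| {
            if t + 1 < t_len {
                rewards[t + 1] + gamma * next_values[t + 1] >= values[t + 1]
            } else {
                true
            }
        })
        .collect();
    let mut returns = vec![0.0; t_len];
    let mut g = bootstrap;
    for t in (0..t_len).rev() {
        let continuation = if follow[t] { g } else { next_values[t] };
        g = rewards[t] + gamma * continuation;
        returns[t] = g;
    }
    returns
}
```

In the second test case below, the bad final step (reward −5) does not propagate back to step 0, because the trace is cut there and replaced by the value estimate.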
4. League training with PFSP
Instead of just self-play, AlphaStar trains a league of agents:
- Main agent: plays against everyone
- Main exploiter: specializes in beating the main agent
- League exploiters: find weaknesses across the entire pool
Opponents are sampled using PFSP — harder opponents (lower win rate) get sampled more often.
use rl4burn::{League, AgentRole, LeagueAgentConfig, PfspMatchmaking};
See League Training and PFSP Matchmaking.
5. ROA-Star’s additions
ROA-Star adds two ideas:
- Beta-VAE opponent modeling: A frozen encoder predicts what the opponent is doing behind fog of war. The latent embedding is fed to all agents as extra context. See Beta-VAE Opponent Modeling.
- Goal-conditioned exploiters: Exploiters are conditioned on strategy descriptors z, letting them specialize rapidly. See Goal-Conditioned RL.
Further reading
- AlphaStar paper (Nature, 2019)
- ROA-Star paper (NeurIPS, 2023)
SCC (StarCraft Commander)
The 30-second version
SCC (inspir.ai, ICML 2021) reaches GrandMaster in StarCraft II with 10x less compute than AlphaStar. Its trick: a more efficient architecture (49M vs 139M parameters) and smarter training (agent branching instead of training exploiters from scratch).
Key innovations
Group Transformer
Instead of processing all game units with one big attention layer, SCC groups them:
- Intra-group self-attention: ally units attend to each other, enemy units attend to each other
- Inter-group cross-attention: ally representations attend to enemy representations
This is more efficient for games with natural groupings (teams, unit types).
rl4burn provides the building blocks: TransformerEncoder for self-attention, MultiHeadAttention for cross-attention. See Transformer Encoder and Attention Mechanisms.
Attention-based pooling
Variable numbers of units get aggregated into fixed-size vectors using learned query vectors. Better than mean-pooling because the model learns which units matter most.
use rl4burn::{AttentionPool, AttentionPoolConfig};
let pool = AttentionPoolConfig::new(128, 4, 2).init(&device);
// 128-dim entity embeddings, 4 learned queries, 2 attention heads
// Output: [batch, 4 * 128] = [batch, 512]
See Attention Mechanisms.
FiLM conditioning
The target position head is conditioned on the action type using FiLM: output = gamma(ctx) * input + beta(ctx). This lets the same network produce different spatial distributions depending on whether you’re attacking, moving, or casting.
use rl4burn::{Film, FilmConfig};
let film = FilmConfig::new(action_embed_dim, spatial_feature_dim).init(&device);
See FiLM Conditioning.
Agent branching
When creating a new exploiter, SCC clones the current main agent’s weights instead of starting from the supervised model. The optimizer state is reset. This gives exploiters a head start.
use rl4burn::algo::multi_agent::self_play::branch_agent;
let exploiter = branch_agent(&main_agent);
// Create a fresh optimizer for the exploiter
See Agent Branching.
Pointer networks
For selecting “which of my units should do this?”, SCC uses pointer networks — attention over encoder outputs producing a selection distribution.
use rl4burn::{PointerNet, PointerNetConfig};
The architecture in one sentence
Group Transformer encodes entities → attention pooling aggregates them → a residual LSTM models the sequence over time → FiLM-conditioned heads produce actions → pointer networks select units.
Further reading
- SCC paper (ICML, 2021)
JueWu & Honor of Kings
The 30-second version
JueWu (Tencent, NeurIPS 2020) is the first AI to beat top professional players in Honor of Kings, a 5v5 MOBA. The key insight: macro strategy (which lane to go to, when to fight) emerges from micro rewards — you don’t need a separate strategic layer.
Why MOBAs are different from StarCraft
In a MOBA, you control 1 hero (or 5 with shared weights). The action space is simpler but the strategic depth comes from teamwork, timing, and map control. Games are shorter (~15 minutes) but the reward is extremely sparse (win/lose).
Key ideas
Multi-head value decomposition
Instead of one value function, JueWu uses 5:
- Farming (gold/XP)
- KDA (kills/deaths/assists)
- Damage dealt
- Tower pushing
- Win/lose
Each head learns independently with its own discount factor. The combined advantage drives the policy.
use rl4burn::{MultiHeadValueConfig, multi_head_gae};
let config = MultiHeadValueConfig::new(5, 0.99, 0.95)
.with_weights(vec![0.1, 0.2, 0.2, 0.2, 0.3]); // win/lose weighted highest
This helps with credit assignment: the agent knows why it’s doing well, not just that it’s doing well.
See Multi-Head Value Decomposition.
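Per-head advantages can be combined as follows: run GAE separately on each reward stream with its own value head and discount, then take the weighted sum. The per-head discount reflects the "own discount factor" described above. How multi_head_gae actually exposes per-head discounts is not shown here, so the names and data layout below are illustrative.

```rust
/// Standard GAE for one reward stream. values: [T], bootstrap: V(s_T);
/// dones[t] = true means the episode ended after step t.
fn gae(rewards: &[f32], values: &[f32], dones: &[bool], bootstrap: f32, gamma: f32, lambda: f32) -> Vec<f32> {
    let t_len = rewards.len();
    let mut adv = vec![0.0; t_len];
    let mut last = 0.0;
    for t in (0..t_len).rev() {
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let next_value = if t + 1 < t_len { values[t + 1] } else { bootstrap };
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        last = delta + gamma * lambda * not_done * last;
        adv[t] = last;
    }
    adv
}

/// Weighted sum of per-head GAE advantages.
/// Each head: (rewards, values, bootstrap, gamma); heads share dones and lambda.
fn combined_advantage(heads: &[(Vec<f32>, Vec<f32>, f32, f32)], weights: &[f32], dones: &[bool], lambda: f32) -> Vec<f32> {
    let mut total = vec![0.0; dones.len()];
    for ((rewards, values, bootstrap, gamma), w) in heads.iter().zip(weights) {
        let adv = gae(rewards, values, dones, *bootstrap, *gamma, lambda);
        for (tot, a) in total.iter_mut().zip(&adv) {
            *tot += w * a;
        }
    }
    total
}
```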
Dual-clip PPO
Standard PPO clips the policy ratio to prevent too-large updates. Dual-clip adds a second constraint: when the advantage is negative, the objective can’t go below c * advantage (c=3). This prevents catastrophic updates in distributed training where trajectories are slightly off-policy.
let config = PpoConfig {
dual_clip_coef: Some(3.0),
..Default::default()
};
See Dual-Clip PPO.
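Per sample, the dual-clip objective described above can be sketched as follows. Maximize it, or minimize its negative as a loss. The clip_eps value of 0.2 in the tests is the common PPO default and serves only as an example.

```rust
/// Per-sample dual-clip PPO objective (to be maximized).
/// Standard clipped surrogate, plus a lower bound of dual_clip * advantage when advantage < 0.
fn dual_clip_objective(ratio: f32, advantage: f32, clip_eps: f32, dual_clip: f32) -> f32 {
    let unclipped = ratio * advantage;
    let clipped = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantage;
    let standard = unclipped.min(clipped);
    if advantage < 0.0 {
        // Without this bound, a huge ratio on a negative advantage gives an unbounded update.
        standard.max(dual_clip * advantage)
    } else {
        standard
    }
}
```

For example, with ratio 10 and advantage −1, standard PPO yields −10, while dual-clip with c = 3 caps it at −3.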
Supervised pre-training matters (a lot)
JueWu-SL (a separate paper) showed that behavioral cloning from top human players provides 64% of the final RL performance. RL then refines and exceeds human play.
```rust
use rl4burn::bc_loss_discrete;
```
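Behavioral cloning on discrete actions is a cross-entropy between the policy’s logits and the demonstrated action. Below is a minimal std-only sketch that assumes one logit row per demonstration step. bc_cross_entropy is a hypothetical helper, and its signature and reduction may differ from bc_loss_discrete’s.

```rust
/// Behavioral-cloning loss for discrete actions: mean cross-entropy between the
/// policy's logits and the demonstrated actions.
pub fn bc_cross_entropy(logits: &[Vec<f32>], actions: &[usize]) -> f32 {
    let n = logits.len() as f32;
    logits
        .iter()
        .zip(actions)
        .map(|(row, &a)| {
            // log-sum-exp with the max subtracted for numerical stability.
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let lse = max + row.iter().map(|x| (x - max).exp()).sum::<f32>().ln();
            lse - row[a] // equals -log softmax(row)[a]
        })
        .sum::<f32>()
        / n
}
```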
See Behavioral Cloning.
Curriculum Self-Play Learning (CSPL)
Training 40+ heroes at once doesn’t converge. CSPL breaks it into 3 phases:
- Specialist training: Train small models on fixed team compositions
- Distillation: Merge all specialists into one big model
- Generalization: Continue RL with random compositions
Without CSPL, training fails to converge even after 480+ hours. With it, training converges in ~264 hours.
```rust
use rl4burn::{CsplPipeline, CsplConfig, CsplPhase};
```
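The source does not say which loss the distillation phase uses. If it is a temperature-scaled KL from each specialist (teacher) to the merged model (student), like rl4burn’s Policy Distillation algorithm, the per-state loss might look like the sketch below. distillation_kl is a hypothetical std-only helper, not part of the CsplPipeline API.

```rust
/// Softmax of logits divided by a temperature (higher T = softer distribution).
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|x| x / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

/// KL(teacher || student) on temperature-softened action distributions.
pub fn distillation_kl(teacher_logits: &[f32], student_logits: &[f32], temperature: f32) -> f32 {
    let p = softmax_with_temperature(teacher_logits, temperature);
    let q = softmax_with_temperature(student_logits, temperature);
    p.iter()
        .zip(&q)
        .filter(|(pi, _)| **pi > 0.0) // 0 * log 0 terms contribute nothing
        .map(|(pi, qi)| pi * (pi.ln() - qi.ln()))
        .sum()
}
```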
See CSPL.
Privileged critic
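During training, the value function sees everything, including enemy positions hidden by fog of war, while the policy only sees what the player would see. The critic is used only to compute advantages during training and is discarded at deployment, so the privileged information never reaches the deployed agent, yet it dramatically improves value estimation.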
```rust
use rl4burn::algo::privileged_critic::PrivilegedActorCritic;
```
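A minimal sketch of the input split, assuming observations arrive already partitioned into public and privileged features. Observation, actor_input, and critic_input are hypothetical names, not the PrivilegedActorCritic API.

```rust
/// One timestep of observations, split by who is allowed to see what.
pub struct Observation {
    /// Features visible to the player (fog of war applied).
    pub public: Vec<f32>,
    /// Extra features available only in training (e.g. hidden enemy positions).
    pub privileged: Vec<f32>,
}

impl Observation {
    /// Policy input: public information only, so the trained policy
    /// runs unchanged at deployment time.
    pub fn actor_input(&self) -> Vec<f32> {
        self.public.clone()
    }

    /// Value-function input: public + privileged information.
    /// The critic is only used to compute advantages during training.
    pub fn critic_input(&self) -> Vec<f32> {
        let mut x = self.public.clone();
        x.extend_from_slice(&self.privileged);
        x
    }
}
```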
See Privileged Critic.
The architecture in one sentence
Shared-weight policy across 5 heroes → LSTM for temporal memory → multi-head value for credit assignment → dual-clip PPO for stability → CSPL for scaling to many heroes.
Further reading
- Honor of Kings paper (NeurIPS, 2020)
- JueWu-SL paper (IEEE TNNLS, 2020)
DreamerV3
The 30-second version
DreamerV3 (Hafner et al., Nature 2025) learns a model of the world and then trains a policy entirely inside imagined trajectories. It works across wildly different domains (Atari, robotic control, Minecraft) with zero hyperparameter tuning.
The secret: symlog transforms that make gradient magnitudes independent of reward scale.
What’s a world model?
Most RL algorithms learn by trial and error in the real environment. World models flip this:
- Play the game a bit, store transitions
- Train a model to predict what happens next (the “world model”)
- Imagine thousands of trajectories inside the model
- Train the policy on imagined data
Imagining trajectories (the third step) requires no environment interaction, only compute. This makes DreamerV3 extremely sample-efficient.
The RSSM (how the world model works)
The RSSM (Recurrent State-Space Model) has 5 networks:
| Component | Input | Output | Purpose |
|---|---|---|---|
| Sequence model (GRU) | h_{t-1}, z_{t-1}, a_{t-1} | h_t | Deterministic memory |
| Encoder (posterior) | h_t, observation | z_t | What actually happened |
| Dynamics (prior) | h_t | z_hat_t | What the model predicts |
| Reward predictor | h_t, z_t | reward | Expected reward |
| Continue predictor | h_t, z_t | continue prob | Is the episode over? |
The state is (h_t, z_t) where h is a deterministic GRU hidden state and z is a stochastic categorical variable (32 groups x 32 classes = 1024 dims).
```rust
use rl4burn::{Rssm, RssmConfig};

let rssm = RssmConfig::new(obs_dim, action_dim).init(&device);
let state = rssm.initial_state(batch_size, &device);

// Training: use observations
let (next_state, post_logits, prior_logits) = rssm.obs_step(&state, action.clone(), obs);

// Imagination: no observations needed
let next_state = rssm.imagine_step(&state, action);
```
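The stochastic part of the state is a set of independent categoricals. Below is a std-only sketch of drawing z from encoder or prior logits, assuming the logits are laid out group by group. sample_latent and the small XorShift generator are hypothetical helpers. DreamerV3 also mixes 1% uniform probability into each categorical and uses straight-through gradients; this sketch omits both.

```rust
/// Minimal xorshift PRNG so the sketch needs no external crates (seed must be nonzero).
pub struct XorShift(pub u64);

impl XorShift {
    /// Uniform sample in [0, 1).
    pub fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Samples the stochastic state z: one class per group from softmax(logits),
/// returned as a flattened one-hot vector of length groups * classes
/// (32 x 32 = 1024 in DreamerV3).
pub fn sample_latent(logits: &[f32], groups: usize, classes: usize, rng: &mut XorShift) -> Vec<f32> {
    assert_eq!(logits.len(), groups * classes);
    let mut z = vec![0.0; groups * classes];
    for g in 0..groups {
        let row = &logits[g * classes..(g + 1) * classes];
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let probs: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
        let total: f32 = probs.iter().sum();
        // Inverse-CDF sampling from this group's categorical distribution.
        let u = rng.next_f32() * total;
        let mut acc = 0.0f32;
        let mut choice = classes - 1;
        for (k, p) in probs.iter().enumerate() {
            acc += *p;
            if u < acc {
                choice = k;
                break;
            }
        }
        z[g * classes + choice] = 1.0;
    }
    z
}
```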
See RSSM.
Symlog: the key to fixed hyperparameters
The biggest problem with RL across domains is reward scale. Atari rewards are 0-1000. Robotic rewards are -1 to 0. Without normalization, you need different learning rates for each.
DreamerV3 solves this with symlog: symlog(x) = sign(x) * ln(|x| + 1), whose inverse is symexp(x) = sign(x) * (exp(|x|) - 1). Symlog compresses large values while staying approximately linear near zero. Combined with twohot encoding (distributional predictions), it makes gradient magnitudes independent of value scale.
```rust
use burn::tensor::activation::softmax;
use rl4burn::{symlog, symexp, TwohotEncoder};

let encoder = TwohotEncoder::new(); // 255 bins, [-20, 20] symlog space
let targets = encoder.encode(values.clone(), &device); // [batch, 255]
let loss = encoder.loss(logits.clone(), values, &device); // cross-entropy
let decoded = encoder.decode(softmax(logits, 1), &device); // back to scalars
```
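Under the hood, symlog, twohot encoding, and decoding fit in a few lines. This std-only sketch mirrors the 255 bins over [-20, 20] in symlog space from the comment above. bins, twohot, and decode are hypothetical helpers and do not reproduce TwohotEncoder’s API.

```rust
pub fn symlog(x: f64) -> f64 {
    x.signum() * (x.abs() + 1.0).ln()
}

pub fn symexp(x: f64) -> f64 {
    x.signum() * (x.abs().exp() - 1.0)
}

/// Bin centers: `n` points evenly spaced over [lo, hi] in symlog space.
pub fn bins(n: usize, lo: f64, hi: f64) -> Vec<f64> {
    (0..n).map(|i| lo + (hi - lo) * i as f64 / (n - 1) as f64).collect()
}

/// Twohot-encodes a scalar: its symlog value is split between the two
/// neighbouring bins so that the weighted bin average equals it.
pub fn twohot(x: f64, bins: &[f64]) -> Vec<f64> {
    let n = bins.len();
    let y = symlog(x).clamp(bins[0], bins[n - 1]);
    let step = bins[1] - bins[0];
    // Index of the bin at or below y, kept in 0..=n-2 so k + 1 is valid.
    let k = (((y - bins[0]) / step).floor() as usize).min(n - 2);
    let w_hi = (y - bins[k]) / step;
    let mut target = vec![0.0; n];
    target[k] = 1.0 - w_hi;
    target[k + 1] = w_hi;
    target
}

/// Decodes a distribution over bins (e.g. softmax of logits) back to a scalar.
pub fn decode(probs: &[f64], bins: &[f64]) -> f64 {
    symexp(probs.iter().zip(bins).map(|(p, b)| p * b).sum())
}
```

The training loss is then the cross-entropy between the twohot target and softmax(logits), so every bin gets a gradient of bounded size no matter how large the raw value is.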
See Symlog and Twohot Encoding.
KL balancing: training the world model
The RSSM is trained with two KL losses:
- Dynamics loss: Make the prior match the posterior (train the predictor)
- Representation loss: Make the posterior predictable (don’t be too complex)
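Each term puts a stop-gradient on one side (on the posterior for the dynamics loss, on the prior for the representation loss) and applies a “free bits” threshold: the KL is replaced by max(1, KL), so values below 1 nat contribute no gradient.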
```rust
use rl4burn::{kl_balanced_loss, KlBalanceConfig};

let config = KlBalanceConfig {
    dyn_weight: 0.5,
    rep_weight: 0.1,
    free_bits: 1.0,
};
let loss = kl_balanced_loss(posterior_logits, prior_logits, &config);
```
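Here is a scalar sketch of the loss value. It assumes posterior and prior logits are laid out group by group and that the KL is summed over groups before the free-bits clip. Without an autodiff framework the stop-gradients cannot be expressed, so the two terms reduce to the same number. categorical_kl and kl_balanced_value are hypothetical helpers, not kl_balanced_loss itself.

```rust
/// Log-softmax of one categorical group.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let lse = max + logits.iter().map(|x| (x - max).exp()).sum::<f32>().ln();
    logits.iter().map(|x| x - lse).collect()
}

/// KL(post || prior) summed over categorical groups of `classes` logits each.
pub fn categorical_kl(post_logits: &[f32], prior_logits: &[f32], classes: usize) -> f32 {
    post_logits
        .chunks(classes)
        .zip(prior_logits.chunks(classes))
        .map(|(p, q)| {
            let (lp, lq) = (log_softmax(p), log_softmax(q));
            lp.iter().zip(&lq).map(|(a, b)| a.exp() * (a - b)).sum::<f32>()
        })
        .sum()
}

/// Value of the balanced KL loss with free bits.
/// With autodiff, the dynamics term is KL(sg(post) || prior) (trains the prior)
/// and the representation term is KL(post || sg(prior)) (trains the encoder);
/// both have the same value, so the KL is computed once here.
pub fn kl_balanced_value(
    post_logits: &[f32],
    prior_logits: &[f32],
    classes: usize,
    dyn_weight: f32,
    rep_weight: f32,
    free_bits: f32,
) -> f32 {
    let kl = categorical_kl(post_logits, prior_logits, classes);
    // Free bits: below `free_bits` nats the loss is constant, so no gradient.
    let clipped = kl.max(free_bits);
    dyn_weight * clipped + rep_weight * clipped
}
```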
See KL Balancing with Free Bits.
Imagination rollouts
Once the world model is trained, generate trajectories purely in latent space:
```rust
use rl4burn::algo::planning::imagination::{imagine_rollout, lambda_returns};

let trajectory = imagine_rollout(&rssm, initial_states, |h, z| actor(h, z), 15);
// trajectory.states: 16 states (initial + 15 steps)
// trajectory.reward_logits: 15 predicted reward distributions
```
Compute lambda-returns on the imagined rewards, then train actor and critic on these imagined trajectories. The world model parameters are frozen during actor-critic training.
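The lambda-return recursion is short enough to write out. This std-only sketch assumes rewards[t] and continues[t] describe the transition from imagined state t to t + 1, and that values holds one critic estimate per imagined state (16 for a 15-step rollout). compute_lambda_returns is a hypothetical standalone helper, not necessarily the signature of the lambda_returns imported above.

```rust
/// Lambda-returns over an imagined trajectory of length T:
/// R_t = r_t + gamma * c_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}), with R_T = v_T.
/// `continues[t]` is the predicted probability that the episode goes on.
pub fn compute_lambda_returns(
    rewards: &[f32],
    continues: &[f32],
    values: &[f32],
    gamma: f32,
    lambda: f32,
) -> Vec<f32> {
    let t_len = rewards.len();
    assert_eq!(values.len(), t_len + 1);
    let mut returns = vec![0.0; t_len];
    // Bootstrap from the critic's estimate of the final imagined state.
    let mut next = values[t_len];
    for t in (0..t_len).rev() {
        next = rewards[t] + gamma * continues[t] * ((1.0 - lambda) * values[t + 1] + lambda * next);
        returns[t] = next;
    }
    returns
}
```

lambda = 0 reduces to one-step TD targets and lambda = 1 to discounted Monte Carlo sums bootstrapped at the horizon.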
See Imagination Rollouts.
Sequence replay buffer
DreamerV3 samples contiguous sequences (T=64) from a FIFO buffer, never crossing episode boundaries.
```rust
use rl4burn::{SequenceReplayBuffer, SequenceStep};

let mut buffer = SequenceReplayBuffer::new(1_000_000, 64);
```
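One way to guarantee that sampled sequences never cross episode boundaries is to mark each episode’s first step and only hand out windows with no episode start after their first step. This std-only sketch uses hypothetical names (Step, SeqBuffer), not the SequenceReplayBuffer API. The linear scan in valid_starts is for clarity only; a real buffer would maintain the index incrementally.

```rust
use std::collections::VecDeque;

/// One stored transition; `is_first` marks the first step of an episode.
#[derive(Clone, Debug, PartialEq)]
pub struct Step {
    pub action: usize,
    pub reward: f32,
    pub is_first: bool,
}

/// FIFO buffer that only hands out windows lying inside a single episode.
pub struct SeqBuffer {
    steps: VecDeque<Step>,
    capacity: usize,
    seq_len: usize,
}

impl SeqBuffer {
    pub fn new(capacity: usize, seq_len: usize) -> Self {
        Self { steps: VecDeque::with_capacity(capacity), capacity, seq_len }
    }

    /// Appends a step, evicting the oldest one once the buffer is full.
    pub fn push(&mut self, step: Step) {
        if self.steps.len() == self.capacity {
            self.steps.pop_front();
        }
        self.steps.push_back(step);
    }

    /// Start indices whose `seq_len`-step window contains no episode start
    /// after its first step, i.e. never crosses into a new episode.
    pub fn valid_starts(&self) -> Vec<usize> {
        let n = self.steps.len();
        if n < self.seq_len {
            return Vec::new();
        }
        (0..=n - self.seq_len)
            .filter(|&s| (s + 1..s + self.seq_len).all(|i| !self.steps[i].is_first))
            .collect()
    }

    /// The window starting at `start` (pick `start` uniformly from `valid_starts`).
    pub fn sequence(&self, start: usize) -> Vec<Step> {
        self.steps.range(start..start + self.seq_len).cloned().collect()
    }
}
```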
Percentile return normalization
Instead of per-minibatch normalization, DreamerV3 tracks the 5th-95th percentile range of returns with an EMA and divides by max(1, range). The floor of 1 prevents amplifying noise.
```rust
use rl4burn::PercentileNormalizer;

let mut normalizer = PercentileNormalizer::new();
normalizer.update(&returns);
let normalized = normalizer.normalize(&advantages);
```
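Here is a std-only sketch of the normalizer. It makes three assumptions: percentiles use linear interpolation, the caller supplies the EMA decay, and the EMA is seeded with the first batch’s percentiles instead of decaying from zero (a choice made here). PercentileNorm is a hypothetical name, not the PercentileNormalizer API.

```rust
/// Linear-interpolated percentile (q in [0, 1]) of a non-empty, NaN-free slice.
fn percentile(values: &[f32], q: f32) -> f32 {
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let pos = q * (sorted.len() - 1) as f32;
    let (lo, hi) = (pos.floor() as usize, pos.ceil() as usize);
    sorted[lo] + (sorted[hi] - sorted[lo]) * (pos - lo as f32)
}

/// EMA of the 5th and 95th return percentiles; advantages are divided by
/// max(1, p95 - p5) so small return ranges are never scaled up.
pub struct PercentileNorm {
    low: f32,
    high: f32,
    decay: f32,
    initialized: bool,
}

impl PercentileNorm {
    pub fn new(decay: f32) -> Self {
        Self { low: 0.0, high: 0.0, decay, initialized: false }
    }

    pub fn update(&mut self, returns: &[f32]) {
        let (p5, p95) = (percentile(returns, 0.05), percentile(returns, 0.95));
        if self.initialized {
            self.low = self.decay * self.low + (1.0 - self.decay) * p5;
            self.high = self.decay * self.high + (1.0 - self.decay) * p95;
        } else {
            // First batch: take the percentiles directly.
            self.low = p5;
            self.high = p95;
            self.initialized = true;
        }
    }

    pub fn scale(&self) -> f32 {
        (self.high - self.low).max(1.0)
    }

    pub fn normalize(&self, advantages: &[f32]) -> Vec<f32> {
        let s = self.scale();
        advantages.iter().map(|a| a / s).collect()
    }
}
```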
See Percentile Return Normalization.
Further reading
- DreamerV3 paper (Nature, 2025)
- DreamerV2 paper (ICLR, 2021)
- Original Dreamer paper (ICLR, 2020)
Working with Burn’s Autodiff
We discovered several Burn 0.20 behaviors that affect RL implementations. This chapter documents them so you don’t hit the same issues.
1. Custom parameter initialization must use from_data + load_record
Problem: Param::initialized(id, tensor) creates parameters that are invisible to Burn’s autodiff. The optimizer will silently produce zero updates — the model trains but weights never change.
Cause: Tensors created via Tensor::from_data(TensorData::new(...), device) are leaf nodes without gradient tracking. Wrapping them in Param::initialized doesn’t register them.
Fix: Use Burn’s record system:
```rust
use burn::module::{Module, Param};
use burn::nn::{LinearConfig, LinearRecord};
use burn::tensor::Tensor;

// WRONG: gradients won't flow
let weight = Tensor::from_data(my_data, &device);
let param = Param::initialized(old_param.id.clone(), weight);

// RIGHT: use Param::from_data + load_record
let record = LinearRecord {
    weight: Param::from_data(weight_data, &device),
    bias: Some(Param::from_data(bias_data, &device)),
};
let linear = LinearConfig::new(d_in, d_out).init(&device).load_record(record);
```
orthogonal_linear handles this correctly. If you implement custom initialization, follow the same pattern.
Alternative: If you must use Param::initialized, call .set_require_grad(true) on the result. But from_data + load_record is the canonical approach.
2. Gradient clipping is per-parameter, not global
Problem: GradientClippingConfig::Norm(0.5) on the optimizer clips each parameter tensor’s gradient independently. PyTorch’s clip_grad_norm_ clips the global norm across all parameters.
Impact: With per-parameter clipping, the actor’s small gradients are unaffected while the critic’s large gradients are clipped. With global clipping, all gradients are scaled by the same factor, which is the standard behavior for PPO.
Fix: Use clip::clip_grad_norm instead of Burn’s optimizer clipping:
```rust
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use rl4burn::clip::clip_grad_norm;

// Don't configure clipping on the optimizer
let mut optim = AdamConfig::new().init();

// Clip manually between backward and step
let grads = loss.backward();
let mut grads = GradientsParams::from_grads(grads, &model);
grads = clip_grad_norm(&model, grads, 0.5);
model = optim.step(lr, model, grads);
```
PPO’s ppo_update does this automatically when max_grad_norm > 0.
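The difference shows up clearly on plain vectors. In this std-only sketch (hypothetical helpers, not Burn or rl4burn code), a large "critic" gradient [3, 4] and a small "actor" gradient [0, 0.1] are clipped to norm 0.5. Per-parameter clipping leaves the actor gradient untouched. Global clipping shrinks both by the same factor.

```rust
fn l2_norm(g: &[f32]) -> f32 {
    g.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Per-parameter clipping: each tensor is rescaled on its own.
pub fn clip_per_param(grads: &mut [Vec<f32>], max_norm: f32) {
    for g in grads.iter_mut() {
        let norm = l2_norm(g);
        if norm > max_norm {
            g.iter_mut().for_each(|x| *x *= max_norm / norm);
        }
    }
}

/// Global clipping (as in PyTorch's clip_grad_norm_): one norm over all
/// tensors and one shared scale. Returns the norm before clipping.
pub fn clip_global(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let total = grads
        .iter()
        .map(|g| g.iter().map(|x| x * x).sum::<f32>())
        .sum::<f32>()
        .sqrt();
    if total > max_norm {
        let scale = max_norm / total;
        for g in grads.iter_mut() {
            g.iter_mut().for_each(|x| *x *= scale);
        }
    }
    total
}
```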
3. mask_where may not propagate gradients through the source argument
Problem: tensor.mask_where(mask, source) selects from source where mask is true, otherwise keeps self. Burn’s autodiff may not propagate gradients through the source argument.
Impact: If you use mask_where to compute max(a, b) by selecting the larger value, and the mask happens to select from source for all elements, the gradient can be zero.
Fix: Use arithmetic alternatives:
```rust
use burn::tensor::activation::relu;

// Instead of mask_where for max:
// max(a, b) = a + relu(b - a)
let max_val = a.clone() + relu(b.clone() - a.clone());

// Instead of mask_where for min:
// min(a, b) = b - relu(b - a)
let min_val = b.clone() - relu(b - a);
```
These have correct gradients everywhere (except at the exact boundary where a == b, which has measure zero in practice).
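Here is a scalar sketch of the two identities with their derivatives written out by hand. It shows that the gradient always reaches whichever input is selected. relu, step, max_via_relu, and min_via_relu are hypothetical helpers, and the derivative at the tie a == b is taken as relu'(0) = 0 by assumption.

```rust
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// Derivative of relu, taken as 0 at x == 0.
fn step(x: f32) -> f32 {
    if x > 0.0 { 1.0 } else { 0.0 }
}

/// max(a, b) = a + relu(b - a); returns (value, d/da, d/db).
pub fn max_via_relu(a: f32, b: f32) -> (f32, f32, f32) {
    let s = step(b - a);
    (a + relu(b - a), 1.0 - s, s)
}

/// min(a, b) = b - relu(b - a); returns (value, d/da, d/db).
pub fn min_via_relu(a: f32, b: f32) -> (f32, f32, f32) {
    let s = step(b - a);
    (b - relu(b - a), s, 1.0 - s)
}
```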
General advice
- Always test gradient flow. After implementing a custom model or loss, verify that optim.step actually changes the model’s weights. A simple test:

```rust
let before: Vec<f32> = model.weight.val().into_data().to_vec().unwrap();
// ... forward, loss, backward, step ...
let after: Vec<f32> = model.weight.val().into_data().to_vec().unwrap();
assert!(before != after, "weights should change");
```

- Use model.valid() for inference. This strips the autodiff layer, avoiding unnecessary computation graph construction during rollout collection.
- Extract tensor data to break the computation graph. When you need to use a value as a constant (e.g., target Q-values), call .into_data().to_vec() and create a fresh tensor from the result. This has the same effect as PyTorch’s .detach(); Burn’s own Tensor::detach() does the same without copying the data back to the host.