rl4burn

Reinforcement learning algorithms for the Burn ML framework.

rl4burn lets you write RL algorithms once with B: AutodiffBackend and run them on any Burn backend — WGPU, CUDA, NdArray, or LibTorch. No per-backend reimplementation.

What’s included

Algorithms

Algorithm | Type | Status
--- | --- | ---
PPO | On-policy, actor-critic | Solves CartPole in <1s
Dual-Clip PPO | PPO for distributed training | JueWu/HoK-style
DQN | Off-policy, value-based | Solves CartPole in ~9s
Behavioral Cloning | Supervised imitation | Cross-entropy on demonstrations
Policy Distillation | Teacher-student transfer | Temperature-scaled KL

Neural Network Modules

LSTM, GRU, and block-diagonal GRU cells. Transformer encoder blocks. Multi-head attention, target attention, attention pooling, and pointer networks. FiLM conditioning. Auto-regressive action distributions.

World Models (DreamerV3)

RSSM world model with imagination rollouts. Symlog/twohot distributional encoding. KL balancing with free bits. Sequence replay buffer. Percentile return normalization.

Game AI Infrastructure

Self-play with opponent pools. League training with agent roles (AlphaStar-style). PFSP matchmaking. Multi-agent shared-weight training. Privileged critic. Goal-conditioned RL. Agent branching. MCTS for drafting. Beta-VAE opponent modeling. Curriculum self-play learning (CSPL).

Building Blocks

GAE, V-trace, UPGO, replay buffers, multi-head value decomposition, intrinsic rewards, polyak updates, loss functions, orthogonal initialization, global gradient clipping, and logging.

Workspace architecture

rl4burn is organized as a Cargo workspace of five focused crates (rl4burn-core, rl4burn-nn, rl4burn-collect, rl4burn-algo, rl4burn-envs) plus an umbrella rl4burn crate that re-exports the full API. Users depend only on rl4burn — no need to manage individual crate dependencies. See the Architecture chapter for details.

Cookbook

The repository includes 15 runnable examples organized into five tiers:

  1. Fundamentals — quickstart, annotated PPO, config-driven training
  2. Environment Variations — custom environments, continuous actions, multi-discrete actions
  3. Techniques — action masking, reward shaping, LSTM policies
  4. Multi-Agent & Game AI — self-play, multi-agent, curriculum learning
  5. Production — diagnostics, hyperparameter tuning, policy deployment

Run any example with cargo run -p <name> --release. See the Cookbook for the full list and a decision guide for choosing the right algorithm.

Why Burn?

Burn’s Backend trait lets you write generic code:

use burn::tensor::backend::AutodiffBackend;

fn train<B: AutodiffBackend>(model: MyModel<B>, device: &B::Device) {
    // This works on WGPU, CUDA, NdArray, LibTorch...
}

For RL, this means:

  • Train on GPU with Autodiff<Wgpu> or Autodiff<LibTorch>
  • Deploy on edge with NdArray (no GPU needed, no_std capable)
  • Run in the browser with WASM via the WGPU backend

We are not aware of another Rust RL library that offers this degree of backend portability.

Design philosophy

  • You own the training loop. ppo_collect and ppo_update are functions, not a framework. Compose them however you want.
  • Minimal API surface. Each algorithm is ~200 lines. Read the source — it’s meant to be understood.
  • Correctness first. Integration tests train both PPO and DQN on CartPole to convergence. Contract annotations enforce preconditions on core functions.
  • Match reference implementations. PPO defaults match CleanRL’s ppo.py. When Burn’s behavior differs from PyTorch (gradient clipping, parameter initialization), we provide compatible alternatives.

Architecture

rl4burn is organized as a Cargo workspace with five focused crates and one umbrella crate that re-exports everything.

Workspace layout

crates/
  rl4burn-core     — Env trait, spaces, SyncVecEnv, wrappers, Logger
  rl4burn-nn       — Neural network utilities (LSTM, GRU, attention, FiLM, policy traits, init)
  rl4burn-collect  — GAE, V-trace, UPGO, replay buffers, collection patterns
  rl4burn-algo     — PPO, DQN, AC, imitation, multi-agent, planning, losses
  rl4burn-envs     — CartPole, Pendulum, GridWorld
rl4burn/           — Umbrella crate re-exporting everything
examples/          — 15 runnable cookbook examples (see Cookbook)

Dependency DAG

The crates form a clean dependency hierarchy:

rl4burn-core          (no internal deps)
    |
    +--- rl4burn-nn       (depends on core)
    |       |
    |       +--- rl4burn-collect  (depends on core, nn)
    |       |       |
    |       |       +--- rl4burn-algo  (depends on core, nn, collect)
    |       |
    +--- rl4burn-envs     (depends on core)

Each crate has a single responsibility:

  • rl4burn-core defines the foundational abstractions: the Env trait, observation/action spaces, vectorized environments (SyncVecEnv), environment wrappers, and the logging system.
  • rl4burn-nn provides neural network building blocks: RNN cells (LSTM, GRU, block-diagonal GRU), transformer encoders, attention mechanisms, FiLM conditioning, policy traits (DiscreteActorCritic, MaskedActorCritic, QNetwork), orthogonal initialization, gradient clipping, and polyak updates.
  • rl4burn-collect handles data collection: GAE, V-trace, UPGO, replay buffers, sequence replay, intrinsic rewards, percentile normalization, and distributed collection patterns (actor-learner, centralized inference, trajectory queues).
  • rl4burn-algo contains the algorithms: PPO, DQN, actor-critic with V-trace, behavioral cloning, policy distillation, multi-agent infrastructure (self-play, league training, PFSP), planning (MCTS, imagination rollouts), and loss functions.
  • rl4burn-envs provides built-in environments for testing and examples: CartPole, Pendulum, and GridWorld.

The umbrella crate

Most users should depend only on rl4burn in their Cargo.toml. The umbrella crate re-exports the full public API at the top level:

// All of these work — no intermediate module paths needed:
use rl4burn::SyncVecEnv;
use rl4burn::PpoConfig;
use rl4burn::{ppo_collect, ppo_update};
use rl4burn::{DiscreteActorCritic, DiscreteAcOutput};
use rl4burn::{Logger, PrintLogger};
use rl4burn::ReplayBuffer;

If you need access to the sub-crate modules directly, they are also available:

use rl4burn::core;    // rl4burn_core
use rl4burn::nn;      // rl4burn_nn
use rl4burn::collect; // rl4burn_collect
use rl4burn::algo;    // rl4burn_algo
use rl4burn::envs;    // rl4burn_envs

When to depend on individual crates

For most projects, the umbrella rl4burn crate is the right choice. You might depend on individual crates if:

  • You only need environments (rl4burn-core + rl4burn-envs) and want minimal compile times.
  • You are building a custom algorithm and only need the collection primitives (rl4burn-collect).
  • You are writing a library that extends rl4burn and want to minimize your dependency footprint.

Installation

Add rl4burn and Burn to your Cargo.toml:

[dependencies]
rl4burn = { git = "https://github.com/RPP1011/rl4burn" }
burn = { version = "0.20", features = ["std", "ndarray", "autodiff"] }
rand = "0.10"

rl4burn is a workspace of focused crates (rl4burn-core, rl4burn-nn, rl4burn-collect, rl4burn-algo, rl4burn-envs), but the umbrella rl4burn crate re-exports everything so you only need one dependency.

The ndarray feature gives you a CPU backend for development and testing. For GPU training, add wgpu or tch (LibTorch) instead.

Verify the install

Create a src/main.rs:

use rl4burn::envs::CartPole;
use rl4burn::env::Env;
use rand::SeedableRng;

fn main() {
    let mut env = CartPole::new(rand::rngs::SmallRng::seed_from_u64(42));
    let obs = env.reset();
    println!("CartPole observation: {:?}", obs);

    let step = env.step(1); // push right
    println!("Reward: {}, Done: {}", step.reward, step.done());
}

Then run:

cargo run

You should see a 4-element observation vector and a reward of 1.0.

Your First Agent: PPO on CartPole

This walkthrough trains a PPO agent to balance a pole on a cart. By the end, you’ll understand the three pieces every rl4burn training script needs: a model, environments, and a training loop.

The model

PPO needs an actor-critic model: given an observation, produce action logits and a value estimate. Define a Burn module and implement DiscreteActorCritic:

use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
use rl4burn::{DiscreteAcOutput, DiscreteActorCritic};

#[derive(Module, Debug)]
struct ActorCritic<B: Backend> {
    // Separate actor and critic networks (no shared layers)
    actor_fc1: Linear<B>,
    actor_fc2: Linear<B>,
    actor_out: Linear<B>,
    critic_fc1: Linear<B>,
    critic_fc2: Linear<B>,
    critic_out: Linear<B>,
}

impl<B: Backend> ActorCritic<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            actor_fc1: LinearConfig::new(4, 64).init(device),
            actor_fc2: LinearConfig::new(64, 64).init(device),
            actor_out: LinearConfig::new(64, 2).init(device),
            critic_fc1: LinearConfig::new(4, 64).init(device),
            critic_fc2: LinearConfig::new(64, 64).init(device),
            critic_out: LinearConfig::new(64, 1).init(device),
        }
    }
}

impl<B: Backend> DiscreteActorCritic<B> for ActorCritic<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> DiscreteAcOutput<B> {
        let a = self.actor_fc1.forward(obs.clone()).tanh();
        let a = self.actor_fc2.forward(a).tanh();
        let logits = self.actor_out.forward(a);

        let c = self.critic_fc1.forward(obs).tanh();
        let c = self.critic_fc2.forward(c).tanh();
        let values = self.critic_out.forward(c).squeeze_dim::<1>(1);

        DiscreteAcOutput { logits, values }
    }
}

Key points:

  • #[derive(Module)] gives you parameter management, serialization, and device transfer for free.
  • DiscreteAcOutput holds logits: Tensor<B, 2> (shape [batch, n_actions]) and values: Tensor<B, 1> (shape [batch]).
  • The model is generic over B: Backend. The same struct works on any Burn backend.

The environments

CartPole is built in. Wrap it in SyncVecEnv to run multiple copies in parallel:

use rl4burn::envs::CartPole;
use rl4burn::SyncVecEnv;
use rand::SeedableRng;

let n_envs = 4;
let envs: Vec<CartPole<rand::rngs::SmallRng>> = (0..n_envs)
    .map(|i| CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64)))
    .collect();
let mut vec_env = SyncVecEnv::new(envs);

SyncVecEnv steps all environments in lockstep and auto-resets when episodes end.

The training loop

PPO training alternates between two phases: collect a rollout of experience, then update the model on that experience.

use burn::backend::{Autodiff, NdArray};
use burn::module::AutodiffModule;
use burn::optim::AdamConfig;
use rl4burn::{ppo_collect, ppo_update, PpoConfig};
use rl4burn::{Loggable, Logger, PrintLogger};

type AB = Autodiff<NdArray>;

let device = burn::backend::ndarray::NdArrayDevice::Cpu;
let mut rng = rand::rngs::SmallRng::seed_from_u64(42);

let mut model: ActorCritic<AB> = ActorCritic::new(&device);
let mut optim = AdamConfig::new().with_epsilon(1e-5).init();
let config = PpoConfig::default();
let mut logger = PrintLogger::new(0);

// Episode return accumulator — persists across rollouts
let mut ep_acc = vec![0.0f32; n_envs];
// Current observations — persists across rollouts
let mut current_obs = vec_env.reset();

for iter in 0..100 {
    // Collect: use the non-autodiff model for inference
    let rollout = ppo_collect::<NdArray, _, _>(
        &model.valid(),
        &mut vec_env,
        &config,
        &device,
        &mut rng,
        &mut current_obs,
        &mut ep_acc,
    );

    // Update: train on the collected data
    let (new_model, stats) = ppo_update(
        model, &mut optim, &rollout, &config,
        config.lr, // or use LR annealing
        &device, &mut rng,
    );
    model = new_model;

    // Log training stats
    let step = (iter + 1) as u64 * (config.n_steps * n_envs) as u64;
    stats.log(&mut logger, step);

    if !rollout.episode_returns.is_empty() {
        let avg = rollout.episode_returns.iter().sum::<f32>()
            / rollout.episode_returns.len() as f32;
        logger.log_scalar("rollout/avg_return", avg as f64, step);
    }
}
logger.flush();

Key points:

  • model.valid() strips the autodiff layer for efficient inference during collection.
  • current_obs holds the latest observations from the environments, persisting across rollout boundaries so the next collection starts from where the last one left off.
  • ep_acc tracks per-env cumulative reward across rollout boundaries. Without this, episodes longer than n_steps would have their returns split.
  • ppo_update returns the updated model (Burn modules are moved through optimizers, not mutated in place).
  • stats.log(...) uses the Loggable trait to log all PPO metrics. See the Logging chapter for details on logger setup.

Run it

cargo run -p quickstart --release

You should see episode returns climb from ~20 (random policy) to 500 (solved) within seconds.

Cookbook

rl4burn ships with 15 runnable examples in the examples/ directory, organized into five tiers of increasing complexity. Each example is a standalone Cargo package that you can run with cargo run -p <name> --release.

Tier 1: Fundamentals

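Example | Command | Description
--- | --- | ---
quickstart | cargo run -p quickstart --release | Minimal PPO on CartPole, the "hello world" of RL
ppo-annotated | cargo run -p ppo-annotated --release | Same as quickstart but with detailed comments explaining every line
config-driven | cargo run -p config-driven --release | Load hyperparameters from a TOML file instead of hardcoding them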
config-drivencargo run -p config-driven --releaseLoad hyperparameters from a TOML file instead of hardcoding them

Tier 2: Environment Variations

Example | Command | Description
--- | --- | ---
custom-env | cargo run -p custom-env --release | Implement the Env trait for your own environment
ppo-continuous | cargo run -p ppo-continuous --release | PPO with continuous actions on Pendulum
ppo-multi-discrete | cargo run -p ppo-multi-discrete --release | PPO with multi-discrete action spaces

Tier 3: Techniques

Example | Command | Description
--- | --- | ---
action-masking | cargo run -p action-masking --release | Invalid action masking with the masked PPO pipeline
reward-shaping | cargo run -p reward-shaping --release | Intrinsic rewards and reward shaping wrappers
lstm-policy | cargo run -p lstm-policy --release | Recurrent policy for partially observable environments

Tier 4: Multi-Agent & Game AI

Example | Command | Description
--- | --- | ---
self-play | cargo run -p self-play --release | Self-play training with an opponent pool
multi-agent | cargo run -p multi-agent --release | Shared-weight multi-agent training
curriculum | cargo run -p curriculum --release | Curriculum self-play learning (CSPL)

Tier 5: Production

Example | Command | Description
--- | --- | ---
diagnostics | cargo run -p diagnostics --release | TensorBoard logging, video recording, and training diagnostics
hyperparameter-tuning | cargo run -p hyperparameter-tuning --release | Systematic hyperparameter sweeps
deploy-policy | cargo run -p deploy-policy --release | Export a trained policy for inference on a different backend

Which algorithm should I use?

Use this decision guide to pick the right starting point:

Scenario | Recommended algorithm | Start from example
--- | --- | ---
Discrete actions (e.g., CartPole, Atari) | PPO or DQN | quickstart
Continuous actions (e.g., Pendulum, MuJoCo) | PPO with Gaussian policy | ppo-continuous
Multi-discrete actions (e.g., RTS games) | PPO with multi-head | ppo-multi-discrete
Invalid actions vary per step | Masked PPO | action-masking
Competitive game (1v1 or teams) | Self-play PPO | self-play
Partial observability | LSTM policy + PPO | lstm-policy
Multiple cooperating agents | Shared-weight PPO | multi-agent
Large observation space / model-based | DreamerV3 (future) | n/a

When in doubt, start with PPO (quickstart). It is the most versatile algorithm and works well across a wide range of problems. Switch to DQN only if you need off-policy learning or have a small discrete action space where sample efficiency matters.

Environments

The Env trait defines how an RL agent interacts with the world. It follows modern Gymnasium conventions.

The Env trait

pub trait Env {
    type Observation: Clone;
    type Action: Clone;

    fn reset(&mut self) -> Self::Observation;
    fn step(&mut self, action: Self::Action) -> Step<Self::Observation>;
    fn observation_space(&self) -> Space;
    fn action_space(&self) -> Space;
}

step returns a Step struct with separate terminated and truncated flags:

pub struct Step<O> {
    pub observation: O,
    pub reward: f32,
    pub terminated: bool,  // episode ended due to environment dynamics
    pub truncated: bool,   // episode ended due to time limit
}

The done() method returns terminated || truncated.

Implementing a custom environment

use rl4burn::env::{Env, Step};
use rl4burn::space::Space;

struct MyEnv {
    state: f32,
    step_count: usize,
}

impl Env for MyEnv {
    type Observation = Vec<f32>;
    type Action = usize;

    fn reset(&mut self) -> Vec<f32> {
        self.state = 0.0;
        self.step_count = 0;
        vec![self.state]
    }

    fn step(&mut self, action: usize) -> Step<Vec<f32>> {
        self.state += if action == 0 { -0.1 } else { 0.1 };
        self.step_count += 1;
        Step {
            observation: vec![self.state],
            reward: -self.state.abs(), // reward for staying near 0
            terminated: self.state.abs() > 1.0,
            truncated: self.step_count >= 200,
        }
    }

    fn observation_space(&self) -> Space {
        Space::Box { low: vec![-2.0], high: vec![2.0] }
    }

    fn action_space(&self) -> Space {
        Space::Discrete(2)
    }
}

Built-in environments

Environment | Obs dim | Actions | Max steps
--- | --- | --- | ---
CartPole | 4 | 2 (left/right) | 500

use rl4burn::envs::CartPole;
use rand::SeedableRng;

let mut env = CartPole::new(rand::rngs::SmallRng::seed_from_u64(42));

CartPole is generic over R: Rng, so you control the random number generator for reproducibility.

Spaces

Spaces describe the shape and bounds of observations and actions. They’re used for constructing networks (knowing input/output dimensions) and validating data.

The Space enum

pub enum Space {
    Discrete(usize),                     // {0, 1, ..., n-1}
    Box { low: Vec<f32>, high: Vec<f32> }, // continuous, per-dimension bounds
    MultiDiscrete(Vec<usize>),           // multiple independent discrete spaces
}

Methods

  • flat_dim() — total dimension (one-hot width for Discrete, number of dims for Box)
  • shape() — shape as a Vec

Usage

Spaces are returned by Env::observation_space() and Env::action_space(). Use them to size your network layers:

let obs_dim = env.observation_space().flat_dim(); // e.g., 4 for CartPole
let n_actions = match env.action_space() {
    Space::Discrete(n) => n,
    _ => panic!("expected discrete actions"),
};

let fc1 = LinearConfig::new(obs_dim, 64).init(&device);
let out = LinearConfig::new(64, n_actions).init(&device);

Vectorized Environments

SyncVecEnv runs N copies of an environment in lockstep, collecting N transitions per step. This is essential for PPO, which needs batched data from parallel environments.

Usage

use rl4burn::vec_env::SyncVecEnv;
use rl4burn::envs::CartPole;
use rand::SeedableRng;

let envs: Vec<CartPole<_>> = (0..8)
    .map(|i| CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64)))
    .collect();
let mut vec_env = SyncVecEnv::new(envs);

// Reset all environments
let observations = vec_env.reset(); // Vec of 8 observations

// Step all environments with one action each
let actions = vec![0, 1, 0, 1, 1, 0, 1, 0];
let steps = vec_env.step(actions); // Vec of 8 Step results

Auto-reset

When an environment reaches a terminal or truncated state, SyncVecEnv resets it automatically within the same step. The returned observation is the initial observation of the new episode, not the terminal observation. This matches the same-step autoreset of Gymnasium's SyncVectorEnv before version 1.0. Gymnasium 1.0 changed the default to next-step autoreset.

The reward and done flags in the returned Step are from the terminal step — only the observation is replaced.
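The order of operations can be sketched in plain Rust: step, then reset on done and overwrite only the observation. The SimpleEnv trait, StepResult struct, step_all function, and Counter toy environment are hypothetical stand-ins for illustration, not rl4burn's types or SyncVecEnv's actual code.

```rust
/// Hypothetical step result, mirroring the fields of rl4burn's Step.
pub struct StepResult {
    pub observation: Vec<f32>,
    pub reward: f32,
    pub terminated: bool,
    pub truncated: bool,
}

/// Hypothetical single-env interface used only by this sketch.
pub trait SimpleEnv {
    fn reset(&mut self) -> Vec<f32>;
    fn step(&mut self, action: usize) -> StepResult;
}

/// Steps every env once. An env whose episode ended is reset immediately and its
/// observation is replaced by the new episode's first observation; the reward and
/// terminated/truncated flags are left exactly as the terminal step produced them.
pub fn step_all<E: SimpleEnv>(envs: &mut [E], actions: &[usize]) -> Vec<StepResult> {
    assert_eq!(envs.len(), actions.len(), "one action per env");
    envs.iter_mut()
        .zip(actions)
        .map(|(env, &a)| {
            let mut s = env.step(a);
            if s.terminated || s.truncated {
                s.observation = env.reset();
            }
            s
        })
        .collect()
}

/// Toy env: observation is the step count; truncates after `limit` steps.
pub struct Counter {
    pub t: usize,
    pub limit: usize,
}

impl SimpleEnv for Counter {
    fn reset(&mut self) -> Vec<f32> {
        self.t = 0;
        vec![0.0]
    }
    fn step(&mut self, _action: usize) -> StepResult {
        self.t += 1;
        StepResult {
            observation: vec![self.t as f32],
            reward: 1.0,
            terminated: false,
            truncated: self.t >= self.limit,
        }
    }
}
```

Because the terminal observation is overwritten, code that needs it, for example to bootstrap a truncated episode, has to capture it before the reset.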

When to use SyncVecEnv

Algorithm | Vectorized? | Why
--- | --- | ---
PPO | Yes (required) | Needs batched rollouts from parallel envs
DQN | No (typically) | Single-env stepping with replay buffer

Environment Wrappers

Wrappers transform an environment's observations or rewards, or add bookkeeping such as episode statistics, without modifying the environment itself. Each wrapper implements Env and wraps an inner Env.

EpisodeStats

Tracks cumulative episode reward and length. The last_episode_reward and last_episode_length fields are updated each time an episode completes.

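use rl4burn::wrapper::EpisodeStats;
use rl4burn::envs::CartPole;
use rl4burn::env::Env;
use rand::SeedableRng;

let rng = rand::rngs::SmallRng::seed_from_u64(0);
let mut env = EpisodeStats::new(CartPole::new(rng));
env.reset();

loop {
    let step = env.step(1); // always push right, so the episode ends quickly
    if step.done() {
        println!("Episode return: {}", env.last_episode_reward.unwrap());
        println!("Episode length: {}", env.last_episode_length.unwrap());
        break;
    }
}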

RewardClip

Clips rewards to [-limit, limit]. Useful for environments with large or unbounded rewards.

use rl4burn::wrapper::RewardClip;

let env = RewardClip::new(my_env, 1.0); // rewards clipped to [-1, 1]

NormalizeObservation

Normalizes observations to zero mean and unit variance, using running statistics updated with Welford's online algorithm. The normalized observations are then clipped to [-clip, clip].

use rl4burn::wrapper::NormalizeObservation;

let env = NormalizeObservation::new(my_env, 10.0).unwrap(); // clip normalized obs to [-10, 10]

Requires the environment to have Observation = Vec<f32> and a Box observation space.
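For reference, here is a minimal sketch of a Welford running normalizer with clipping. RunningNorm and its methods are hypothetical names. The sketch assumes population variance and an epsilon of 1e-8 in the denominator, which may not match NormalizeObservation exactly.

```rust
/// Running per-dimension mean/variance via Welford's online algorithm.
/// A sketch of the technique, not rl4burn's NormalizeObservation internals.
pub struct RunningNorm {
    count: u64,
    mean: Vec<f64>,
    m2: Vec<f64>, // running sum of squared deviations from the mean
    clip: f64,
    eps: f64, // assumed small constant to avoid division by zero
}

impl RunningNorm {
    pub fn new(dim: usize, clip: f64) -> Self {
        Self { count: 0, mean: vec![0.0; dim], m2: vec![0.0; dim], clip, eps: 1e-8 }
    }

    /// Folds one observation into the running statistics.
    pub fn update(&mut self, x: &[f32]) {
        self.count += 1;
        let n = self.count as f64;
        for i in 0..x.len() {
            let xi = x[i] as f64;
            let delta = xi - self.mean[i];
            self.mean[i] += delta / n;
            // Uses the deviation from both the old and the updated mean.
            self.m2[i] += delta * (xi - self.mean[i]);
        }
    }

    /// Population variance of dimension i (0 before any update).
    pub fn variance(&self, i: usize) -> f64 {
        if self.count == 0 { 0.0 } else { self.m2[i] / self.count as f64 }
    }

    /// Standardizes x with the current statistics, then clips to [-clip, clip].
    pub fn normalize(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .enumerate()
            .map(|(i, &xi)| {
                let z = (xi as f64 - self.mean[i]) / (self.variance(i) + self.eps).sqrt();
                z.clamp(-self.clip, self.clip) as f32
            })
            .collect()
    }
}
```

Welford's update avoids the catastrophic cancellation of the naive "sum of squares minus square of sums" formula, which matters when statistics are accumulated over millions of steps.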

Composing wrappers

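Wrappers compose naturally. Each wrapper sees the output of the one it wraps, so in this stack EpisodeStats records the clipped rewards, not the raw ones: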

let env = EpisodeStats::new(
    RewardClip::new(
        NormalizeObservation::new(my_env, 10.0).unwrap(),
        1.0
    )
);

PPO (Proximal Policy Optimization)

PPO is an on-policy actor-critic algorithm. It collects a batch of experience using the current policy, computes advantages, then performs multiple epochs of minibatch gradient descent with a clipped surrogate objective.

Our implementation matches CleanRL’s ppo.py.

API

PPO is split into two functions:

  • ppo_collect — Run the policy on vectorized environments, collect transitions, compute GAE advantages.
  • ppo_update — Perform clipped PPO gradient steps on the collected data.

You compose them in your own training loop.

The DiscreteActorCritic trait

pub trait DiscreteActorCritic<B: Backend> {
    fn forward(&self, obs: Tensor<B, 2>) -> DiscreteAcOutput<B>;
}

pub struct DiscreteAcOutput<B: Backend> {
    pub logits: Tensor<B, 2>,  // [batch, n_actions]
    pub values: Tensor<B, 1>,  // [batch]
}

Implement this on any #[derive(Module)] struct. The model must output both action logits (for the policy) and value estimates (for the critic) in a single forward pass.

Configuration

PpoConfig defaults match CleanRL:

Parameter | Default | Description
--- | --- | ---
lr | 2.5e-4 | Learning rate
gamma | 0.99 | Discount factor
gae_lambda | 0.95 | GAE smoothing parameter
clip_eps | 0.2 | Surrogate clipping range
vf_coef | 0.5 | Value loss coefficient
ent_coef | 0.01 | Entropy bonus coefficient
update_epochs | 4 | Optimization epochs per rollout
minibatch_size | 128 | Minibatch size
n_steps | 128 | Rollout length per env
clip_vloss | true | Whether to clip value loss
max_grad_norm | 0.5 | Global gradient norm clipping (0 = disabled)
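gamma and gae_lambda enter through the GAE computation in ppo_collect. Below is a minimal sketch of the standard backward recursion for a single environment. It assumes dones[t] marks an episode ending at step t and that no bootstrapping happens past a done, as in CleanRL. The function gae and its signature are illustrative, not rl4burn's API.

```rust
/// Generalized Advantage Estimation for one env's rollout.
/// `last_value` is the critic's estimate for the observation after the final step.
/// Returns (advantages, returns), where returns = advantages + values.
pub fn gae(
    rewards: &[f32],
    values: &[f32],
    dones: &[bool],
    last_value: f32,
    gamma: f32,
    lambda: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    let mut adv = vec![0.0f32; n];
    let mut running = 0.0f32;
    for t in (0..n).rev() {
        let next_value = if t + 1 < n { values[t + 1] } else { last_value };
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        // TD error: r_t + gamma * V(s_{t+1}) - V(s_t), with no bootstrap past a done.
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        // A_t = delta_t + gamma * lambda * A_{t+1}, reset at episode boundaries.
        running = delta + gamma * lambda * not_done * running;
        adv[t] = running;
    }
    let returns = adv.iter().zip(values).map(|(a, v)| a + v).collect();
    (adv, returns)
}
```

With gamma = lambda = 1 and zero values, the advantages reduce to plain reward-to-go. This makes a convenient sanity check.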

LR annealing

ppo_update takes a current_lr parameter. For linear annealing:

let frac = 1.0 - iter as f64 / n_iterations as f64;
let current_lr = config.lr * frac;

For constant LR, just pass config.lr.

Episode return tracking

ppo_collect accepts an &mut Vec<f32> accumulator for per-env episode returns. This handles episodes that span multiple rollouts correctly. Completed episode returns are in PpoRollout::episode_returns.

let mut current_obs = vec_env.reset(); // create once before the loop
let mut ep_acc = vec![0.0f32; n_envs];

let rollout = ppo_collect(..., &mut current_obs, &mut ep_acc);
for &ret in &rollout.episode_returns {
    println!("completed episode return: {ret}");
}
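The accumulator logic takes a few lines per vectorized step. Below is a minimal sketch in which track_returns is a hypothetical helper, not part of rl4burn's API. It shows how an episode that crosses a rollout boundary is reported once, with its full return.

```rust
/// Adds one vectorized step's rewards to the per-env accumulators and records the
/// return of every episode that just ended. `ep_acc` must persist across rollouts.
pub fn track_returns(
    ep_acc: &mut [f32],
    rewards: &[f32],
    dones: &[bool],
    completed: &mut Vec<f32>,
) {
    for i in 0..ep_acc.len() {
        ep_acc[i] += rewards[i];
        if dones[i] {
            completed.push(ep_acc[i]); // full return, even if it spanned rollouts
            ep_acc[i] = 0.0;           // start the next episode from zero
        }
    }
}
```

If ep_acc were recreated for every rollout, the episode in the test below would be reported with a return of 1.0 instead of 2.0.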

Multi-discrete actions and action masking

For complex action spaces (multiple discrete dimensions, per-step validity masks), use masked_ppo_collect and masked_ppo_update with an ActionDist:

use rl4burn::{ActionDist, MaskedActorCritic, masked_ppo_collect, masked_ppo_update};

// Action space: [action_type(5), target(10)]
let action_dist = ActionDist::MultiDiscrete(vec![5, 10]);

The MaskedActorCritic trait

pub trait MaskedActorCritic<B: Backend> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>);
    fn log_std(&self) -> Option<Tensor<B, 1>> { None } // continuous only
}

If you already have a DiscreteActorCritic model, the delegation is trivial:

impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
        let out = DiscreteActorCritic::forward(self, obs);
        (out.logits, out.values)
    }
}

Action masking

Environments provide per-step masks via Env::action_mask():

fn action_mask(&self) -> Option<Vec<f32>> {
    let mut mask = vec![0.0; 15]; // 5 + 10
    for valid_type in &self.valid_action_types { mask[*valid_type] = 1.0; }
    for valid_target in &self.valid_targets { mask[5 + *valid_target] = 1.0; }
    Some(mask)
}

Masked actions are never sampled and receive zero probability during training.
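A common way to get exactly zero probability is to replace masked logits with negative infinity before the softmax. Below is a minimal sketch for one discrete head, using the 1.0 = valid, 0.0 = invalid convention of the mask above. masked_softmax is a hypothetical helper, and rl4burn may implement masking differently, for example with a large negative constant.

```rust
/// Softmax over logits with invalid actions (mask == 0.0) forced to probability 0.
/// Panics if the mask allows no action.
pub fn masked_softmax(logits: &[f32], mask: &[f32]) -> Vec<f32> {
    assert_eq!(logits.len(), mask.len());
    // exp(-inf) == 0, so masked entries vanish from the distribution.
    let masked: Vec<f32> = logits
        .iter()
        .zip(mask)
        .map(|(&l, &m)| if m > 0.0 { l } else { f32::NEG_INFINITY })
        .collect();
    let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    assert!(max.is_finite(), "mask must allow at least one action");
    // Subtract the max for numerical stability before exponentiating.
    let exps: Vec<f32> = masked.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

For the multi-discrete layout above, apply the mask per head: first to logits[0..5] with mask[0..5], then to logits[5..15] with mask[5..15].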

Env action type

The masked pipeline expects Env<Action = Vec<f32>>. For existing discrete envs (Action = usize), use DiscreteEnvAdapter:

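use rl4burn::DiscreteEnvAdapter;
use rand::SeedableRng;

// Give each env its own seeded RNG (a single rng cannot be moved into every CartPole).
let envs: Vec<DiscreteEnvAdapter<CartPole<_>>> = (0..4)
    .map(|i| DiscreteEnvAdapter(CartPole::new(rand::rngs::SmallRng::seed_from_u64(i as u64))))
    .collect();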

Continuous action spaces

For continuous control (e.g. Pendulum, MuJoCo), use ActionDist::Continuous. The model outputs means (and optionally log standard deviations) for a diagonal Gaussian distribution.

ModelOutput mode

The model outputs [batch, 2 * action_dim] — first half is means, second half is log_stds:

let action_dist = ActionDist::Continuous {
    action_dim: 1,
    log_std_mode: LogStdMode::ModelOutput,
};

// Model outputs [batch, 2]: [mean, log_std]
impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
        let h = self.encoder.forward(obs);
        let logits = self.policy_head.forward(h.clone()); // [batch, 2]
        let values = self.value_head.forward(h).squeeze_dim::<1>(1);
        (logits, values)
    }
}

Separate mode

For state-independent log_std (CleanRL’s default), the model outputs only means and provides log_std via a separate learnable parameter:

let action_dist = ActionDist::Continuous {
    action_dim: 1,
    log_std_mode: LogStdMode::Separate,
};

impl<B: Backend> MaskedActorCritic<B> for MyModel<B> {
    fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1>) {
        // logits = [batch, action_dim] (means only)
        ...
    }
    fn log_std(&self) -> Option<Tensor<B, 1>> {
        Some(self.log_std_param.val())
    }
}

log_std clamping

ActionDist::Continuous automatically clamps log_std to [-5, 2] in all operations (sampling, log-prob, entropy). This prevents numerical instability from excessively large or small standard deviations — a common source of training divergence in continuous RL.
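The effect of the clamp can be seen in a plain-Rust sketch of the diagonal Gaussian log-density and entropy. The helper names are hypothetical. Only the [-5, 2] clamp range comes from the text; the formulas are the standard ones, with log-probs of independent dimensions adding.

```rust
/// Log-density of a diagonal Gaussian with log_std clamped to [-5, 2].
pub fn gaussian_log_prob(action: &[f32], mean: &[f32], log_std: &[f32]) -> f32 {
    let ln_2pi = (2.0 * std::f32::consts::PI).ln();
    action
        .iter()
        .zip(mean)
        .zip(log_std)
        .map(|((&a, &mu), &ls)| {
            let ls = ls.clamp(-5.0, 2.0);
            let z = (a - mu) / ls.exp();
            // log N(a; mu, sigma) = -z^2/2 - ln(sigma) - ln(2*pi)/2
            -0.5 * z * z - ls - 0.5 * ln_2pi
        })
        .sum() // independent dimensions: log-probs add
}

/// Entropy of the same distribution: sum over dims of ln(sigma) + ln(2*pi*e)/2.
pub fn gaussian_entropy(log_std: &[f32]) -> f32 {
    let c = 0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln();
    log_std.iter().map(|&ls| ls.clamp(-5.0, 2.0) + c).sum()
}
```

Without the lower clamp, a collapsing log_std makes z, and therefore the log-prob and its gradient, blow up for any action slightly away from the mean.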

Continuous PPO tips

  • Set ent_coef: 0.0 — entropy bonus can destabilize continuous policies
  • Use update_epochs: 10; more gradient steps per rollout tend to help in continuous control
  • Longer rollouts (n_steps: 256+) improve value estimation for dense-reward tasks
  • Environments should accept Vec<f32> actions (Pendulum does this natively)

For a complete working example, see the ppo-continuous cookbook example (cargo run -p ppo-continuous --release).

Implementation details

  • Per-minibatch advantage normalization: Advantages are z-normalized within each minibatch, not globally across the full rollout.
  • Clipped value loss: max(unclipped, clipped) using a + relu(b - a) to avoid mask_where gradient issues in Burn’s autodiff.
  • Clipped surrogate: min(surr1, surr2) using b - relu(b - a) for the same reason.
  • Global gradient clipping: Uses clip_grad_norm (PyTorch-compatible global norm), not Burn’s built-in per-parameter clipping.
  • Minibatch shuffling: Fisher-Yates shuffle each epoch.
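The relu forms in these notes are exact identities: max(a, b) = a + relu(b - a) and min(a, b) = b - relu(b - a). Below is a scalar sketch of them, the per-sample clipped surrogate, and per-minibatch advantage normalization. The helper names are hypothetical. The sketch uses the sample standard deviation plus 1e-8, as CleanRL's ppo.py does via torch's std; whether rl4burn uses the same estimator is an assumption.

```rust
/// relu(x) = max(x, 0).
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// max(a, b) written as a + relu(b - a) (clipped value loss form).
pub fn max_via_relu(a: f32, b: f32) -> f32 {
    a + relu(b - a)
}

/// min(a, b) written as b - relu(b - a) (clipped surrogate form).
pub fn min_via_relu(a: f32, b: f32) -> f32 {
    b - relu(b - a)
}

/// Per-sample clipped surrogate objective; the policy loss is its negated mean.
pub fn clipped_surrogate(ratio: f32, adv: f32, clip_eps: f32) -> f32 {
    let surr1 = ratio * adv;
    let surr2 = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * adv;
    min_via_relu(surr1, surr2)
}

/// Z-normalizes the advantages of one minibatch (not the whole rollout).
pub fn normalize_minibatch(adv: &[f32]) -> Vec<f32> {
    assert!(adv.len() >= 2, "need at least two samples for a sample std");
    let n = adv.len() as f32;
    let mean = adv.iter().sum::<f32>() / n;
    // Sample (n - 1) variance, matching torch.std()'s default.
    let var = adv.iter().map(|a| (a - mean).powi(2)).sum::<f32>() / (n - 1.0);
    let std = var.sqrt();
    adv.iter().map(|a| (a - mean) / (std + 1e-8)).collect()
}
```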

DQN (Deep Q-Network)

DQN is an off-policy, value-based algorithm. It learns a Q-function that estimates the expected return for each action in a given state, then acts by taking the argmax.

API

  • QNetwork trait — Your model implements fn q_values(&self, obs) -> Tensor returning Q-values for all actions.
  • dqn_update — One gradient step on a minibatch sampled from the replay buffer, using the target network for stable Bellman targets.
  • epsilon_greedy — Action selection with exploration.
  • epsilon_schedule — Linear epsilon annealing.
  • polyak_update — Target network update (hard or soft).

The QNetwork trait

pub trait QNetwork<B: Backend> {
    fn q_values(&self, obs: Tensor<B, 2>) -> Tensor<B, 2>;
}

Input shape: [batch, obs_dim]. Output shape: [batch, n_actions].

Example implementation:

use burn::tensor::activation::relu;

#[derive(Module, Debug)]
struct QNet<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    q_head: Linear<B>,
}

impl<B: Backend> QNetwork<B> for QNet<B> {
    fn q_values(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
        let h = relu(self.fc1.forward(obs));
        let h = relu(self.fc2.forward(h));
        self.q_head.forward(h)
    }
}

Training loop

DQN differs from PPO: it uses a single environment, a replay buffer, and epsilon-greedy exploration.

use rl4burn::{dqn_update, epsilon_greedy, epsilon_schedule, DqnConfig, Transition};
use rl4burn::ReplayBuffer;
use rl4burn::polyak_update;
use rl4burn::Env;

let config = DqnConfig::default();
let mut buffer = ReplayBuffer::new(config.buffer_capacity, rand::rngs::SmallRng::seed_from_u64(42));
let mut online: QNet<AB> = QNet::new(&device);
let mut target = online.clone();
let mut optim = AdamConfig::new().init();
let mut obs = env.reset();

for step in 0..50_000 {
    // Epsilon-greedy action selection (use non-autodiff model)
    let eps = epsilon_schedule(&config, step);
    let action = {
        let inner = online.valid();
        epsilon_greedy::<NdArray, _>(&inner, &obs, 2, eps, &device, &mut rng)
    };

    // Step environment, store transition
    let result = env.step(action);
    buffer.extend(std::iter::once(Transition {
        obs: obs.clone(),
        action: action as i32,
        reward: result.reward,
        next_obs: result.observation.clone(),
        done: result.terminated, // time-limit truncations still bootstrap
    }));
    obs = if result.done() { env.reset() } else { result.observation };

    // Train after warmup
    if step >= config.learning_starts && buffer.len() >= config.batch_size {
        (online, _) = dqn_update(
            online, &target, &mut optim, &mut buffer, &config, &device,
        );

        // Hard target update every N steps
        if step % 250 == 0 {
            target = polyak_update(target, &online, 1.0);
        }
    }
}

Configuration

Parameter | Default | Description
--- | --- | ---
lr | 1e-4 | Learning rate
gamma | 0.99 | Discount factor
buffer_capacity | 10,000 | Replay buffer size
batch_size | 32 | Minibatch size
tau | 0.005 | Polyak coefficient (1.0 = hard copy)
eps_start | 1.0 | Initial exploration rate
eps_end | 0.05 | Final exploration rate
eps_decay_steps | 10,000 | Steps to anneal epsilon
learning_starts | 1,000 | Random steps before training
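Below is a sketch of linear annealing with eps_start, eps_end, and eps_decay_steps, plus an epsilon-greedy choice. To stay dependency-free, it takes uniform draws in [0, 1) as arguments instead of an RNG. Both function names are hypothetical. rl4burn's epsilon_schedule may differ in details such as whether the ramp starts at step 0 or at learning_starts.

```rust
/// Linear decay from eps_start to eps_end over eps_decay_steps, then constant.
pub fn linear_epsilon(step: usize, eps_start: f64, eps_end: f64, eps_decay_steps: usize) -> f64 {
    if eps_decay_steps == 0 {
        return eps_end;
    }
    let frac = (step as f64 / eps_decay_steps as f64).min(1.0);
    eps_start + frac * (eps_end - eps_start)
}

/// Epsilon-greedy action from Q-values, given two uniform draws in [0, 1).
pub fn epsilon_greedy_choice(q: &[f32], eps: f64, u_explore: f64, u_action: f64) -> usize {
    assert!(!q.is_empty());
    if u_explore < eps {
        // Explore: map the second draw to a uniformly random action index.
        ((u_action * q.len() as f64) as usize).min(q.len() - 1)
    } else {
        // Exploit: index of the largest Q-value (first one on ties).
        let mut best = 0;
        for i in 1..q.len() {
            if q[i] > q[best] {
                best = i;
            }
        }
        best
    }
}
```

With the defaults, epsilon is 1.0 at step 0, about 0.525 at step 5,000, and stays at 0.05 from step 10,000 onward.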

Target network

DQN uses a slowly-updated target network for stable Bellman targets. Two strategies:

  • Hard updates (tau = 1.0): Copy all weights every N steps. Simpler, what CleanRL uses.
  • Soft updates (tau = 0.005): Polyak average every step. Smoother, what SAC/TD3 use.

The caller is responsible for calling polyak_update; dqn_update only updates the online network.
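On flat parameter vectors, the update is target <- tau * online + (1 - tau) * target. The sketch below uses a hypothetical polyak helper; rl4burn's polyak_update operates on Burn modules. It shows that tau = 1.0 reduces to the hard copy used in the training loop above.

```rust
/// In-place Polyak averaging: target <- tau * online + (1 - tau) * target.
/// tau = 1.0 copies the online weights; small tau gives a slow-moving average.
pub fn polyak(target: &mut [f32], online: &[f32], tau: f32) {
    assert_eq!(target.len(), online.len(), "parameter shapes must match");
    for (t, &o) in target.iter_mut().zip(online) {
        *t = tau * o + (1.0 - tau) * *t;
    }
}
```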

How dqn_update works

  1. Sample a minibatch from the replay buffer
  2. Compute Q(s, a) for taken actions using the online network
  3. Compute max Q(s’, a’) using the target network (detached from the computation graph by extracting tensor data)
  4. Bellman target: y = r + γ * (1 - done) * max_a' Q_target(s', a')
  5. MSE loss: mean((Q(s, a) - y)²)
  6. Backward + optimizer step
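Steps 3 to 5 reduce to simple arithmetic once the Q-values are plain numbers. Below is a sketch with the hypothetical helpers dqn_targets and mse, assuming next_q_target[i] holds the target network's Q-values for next_obs[i].

```rust
/// Bellman targets y = r + gamma * (1 - done) * max_a' Q_target(s', a').
pub fn dqn_targets(rewards: &[f32], dones: &[bool], next_q_target: &[Vec<f32>], gamma: f32) -> Vec<f32> {
    rewards
        .iter()
        .zip(dones)
        .zip(next_q_target)
        .map(|((&r, &d), q)| {
            let max_next = q.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let not_done = if d { 0.0 } else { 1.0 };
            r + gamma * not_done * max_next
        })
        .collect()
}

/// Mean squared error between Q(s, a) for the taken actions and the targets.
pub fn mse(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, y)| (p - y).powi(2)).sum::<f32>() / n
}
```

In the real update, the targets carry no gradient; only the Q(s, a) side is differentiated.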

Dual-Clip PPO

An extension of standard PPO used by JueWu and Honor of Kings for distributed training stability.

The problem

In distributed RL, the behavior policy can be several updates behind. When the ratio pi_new/pi_old is very large and the advantage is negative, standard PPO’s objective becomes excessively negative, causing destructive updates.

The fix

Add a floor: when advantage < 0, the objective can't go below c * advantage, where c > 1 (c = 3 is the usual choice):

standard_ppo = min(ratio * adv, clip(ratio, 1-ε, 1+ε) * adv)
dual_clip    = max(standard_ppo, c * adv)    // only when adv < 0
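The two formulas above translate directly into a per-sample function. This is a sketch of the objective to be maximized (the training loss is its negation), not rl4burn's batched implementation; ppo_objective is a hypothetical name.

```rust
/// Per-sample PPO surrogate objective (to be maximized) with an optional dual clip.
fn ppo_objective(ratio: f32, adv: f32, eps: f32, dual_clip: Option<f32>) -> f32 {
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
    let standard = (ratio * adv).min(clipped * adv);
    match dual_clip {
        // The floor c * adv only applies to negative advantages.
        Some(c) if adv < 0.0 => standard.max(c * adv),
        _ => standard,
    }
}
```

With ratio = 10 and adv = -1, standard PPO yields -10, while the dual clip with c = 3 bounds it at -3; positive advantages are unaffected.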

Usage

let config = PpoConfig {
    dual_clip_coef: Some(3.0),
    ..Default::default()
};

That’s it. Set dual_clip_coef: None (the default) for standard PPO.

When to use

Only needed for distributed/asynchronous training where trajectories may be significantly off-policy. For single-machine training, standard PPO is sufficient.

Behavioral Cloning

Train a policy to imitate expert demonstrations via supervised learning. JueWu showed this provides ~64% of final RL performance as initialization.

API

use rl4burn::{bc_loss_discrete, bc_step};

// Single loss computation
let loss = bc_loss_discrete(logits, expert_actions, &device);

// Full training step (forward + backward + optimizer step)
let (model, loss_val) = bc_step(model, &mut optim, obs, expert_actions, lr, &device);

Multi-head actions

For hierarchical action spaces:

use rl4burn::bc_loss_multi_head;

let loss = bc_loss_multi_head(logits, expert_actions, &[11, 30, 8], &device);
// head_sizes: action_type(11), target(30), ability(8)

Tips

  • With a near-uniform initial policy, the cross-entropy loss should start close to ln(K), where K is the number of actions. If your initial loss is much higher, something is wrong.
  • BC is most useful as RL weight initialization, not as a standalone method. BC policies are brittle — they fail on states not in the training data.
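The ln(K) sanity check from the first tip can be verified with a small cross-entropy helper. This is a standalone sketch on f32 logits; cross_entropy is a hypothetical name, not part of rl4burn.

```rust
/// Cross-entropy of a softmax policy for one expert action: -log softmax(logits)[action].
fn cross_entropy(logits: &[f32], action: usize) -> f32 {
    // Numerically stable log-sum-exp: subtract the max logit first.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();
    log_sum_exp - logits[action]
}
```

All-zero logits over 11 actions give exactly ln(11) ≈ 2.398, whichever action the expert took.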

Policy Distillation

Train a student network to match a teacher’s behavior. Used in CSPL (Phase 2) to merge multiple specialist teachers into one generalist.

API

use rl4burn::algo::imitation::distillation::{distillation_loss, DistillationConfig};

let config = DistillationConfig {
    temperature: 2.0,
    soft_weight: 1.0,
    hard_weight: 0.0,
    t_squared_scaling: true,
};

let loss = distillation_loss(teacher_logits, student_logits, &config);

Temperature

Higher temperature produces softer probability distributions. The student learns more from the relative ordering of actions, not just the best one.

  • T=1: standard softmax (peaked)
  • T=5: much softer (exposes teacher’s “second choice” preferences)

T-squared scaling

Hinton et al. recommend scaling the soft-target loss by T-squared. The gradients of the soft-target term scale as 1/T-squared, so without this correction they shrink relative to the hard-target term as T grows.
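Temperature softening and T-squared scaling can be sketched together on plain logits. This covers only the soft-target KL term (DistillationConfig's hard_weight term is omitted), and softmax_t and soft_distillation_loss are hypothetical names rather than rl4burn's API.

```rust
/// Softmax with temperature `t`: higher `t` flattens the distribution.
fn softmax_t(logits: &[f32], t: f32) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / t).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// KL(teacher_T || student_T), optionally multiplied by T^2.
fn soft_distillation_loss(teacher: &[f32], student: &[f32], t: f32, t_squared: bool) -> f32 {
    let p = softmax_t(teacher, t);
    let q = softmax_t(student, t);
    let kl: f32 = p
        .iter()
        .zip(&q)
        .filter(|(pi, _)| **pi > 0.0) // 0 * log(0 / q) contributes nothing
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum();
    if t_squared { kl * t * t } else { kl }
}
```

At T = 2, enabling t_squared_scaling multiplies the soft loss by exactly 4, which offsets the 1/T-squared shrinkage of its gradients.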

Value distillation

use rl4burn::algo::imitation::distillation::value_distillation_loss;
let vloss = value_distillation_loss(teacher_values, student_values);

Loss Functions

Backend-generic loss functions for RL training. All return Tensor<B, 1> with shape [1] (scalar), compatible with .backward().

Value loss (Huber)

pub fn value_loss<B: Backend>(pred: Tensor<B, 1>, target: Tensor<B, 1>) -> Tensor<B, 1>

Smooth L1 (Huber) loss with δ=1.0. Quadratic for small errors, linear for large errors. Prevents outlier targets from dominating the value head update.
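For reference, the Huber formula with δ = 1.0 on plain f32 slices looks like this. It is a sketch of the math, not rl4burn's tensor implementation; huber_loss is a hypothetical name.

```rust
/// Smooth L1 (Huber) loss with delta = 1.0, averaged over the batch.
fn huber_loss(pred: &[f32], target: &[f32]) -> f32 {
    let delta = 1.0_f32;
    let total: f32 = pred
        .iter()
        .zip(target)
        .map(|(p, t)| {
            let err = (p - t).abs();
            if err <= delta {
                0.5 * err * err // quadratic for small errors
            } else {
                delta * (err - 0.5 * delta) // linear for large errors
            }
        })
        .sum();
    total / pred.len() as f32
}
```

An error of 0.5 costs 0.125, while an error of 3.0 costs 2.5 instead of the 4.5 that 0.5 * MSE would charge.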

Discrete policy loss (REINFORCE)

pub fn policy_loss_discrete<B: Backend>(
    logits: Tensor<B, 2>,       // [batch, n_actions]
    actions: Tensor<B, 2, Int>, // [batch, 1] action indices
    mask: Tensor<B, 2>,         // [batch, n_actions] valid=1.0
    advantage: Tensor<B, 1>,    // [batch]
) -> Tensor<B, 1>

Standard REINFORCE: -mean(advantage * log_prob(action)). Supports action masking for environments with invalid actions.
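A plain-f32 sketch of masked REINFORCE is shown below. It assumes masking works by sending invalid logits to negative infinity before the softmax, a common technique; whether rl4burn masks exactly this way is an assumption. reinforce_loss is a hypothetical name, and each row needs at least one valid action.

```rust
/// -mean(advantage * log pi(action)), with invalid actions (mask <= 0.5) excluded.
fn reinforce_loss(logits: &[Vec<f32>], actions: &[usize], mask: &[Vec<f32>], advantage: &[f32]) -> f32 {
    let mut total = 0.0;
    for i in 0..logits.len() {
        // Invalid actions get a -inf logit, i.e. probability zero.
        let masked: Vec<f32> = logits[i]
            .iter()
            .zip(&mask[i])
            .map(|(&l, &m)| if m > 0.5 { l } else { f32::NEG_INFINITY })
            .collect();
        let max = masked.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + masked.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();
        let log_prob = masked[actions[i]] - lse;
        total += advantage[i] * log_prob;
    }
    -total / logits.len() as f32
}
```

With three equal logits but only two valid actions, the chosen valid action has probability 1/2, so an advantage of 1.0 yields a loss of ln 2.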

Continuous policy loss

pub fn policy_loss_continuous<B: Backend>(
    pred: Tensor<B, 2>,      // [batch, action_dim]
    target: Tensor<B, 2>,    // [batch, action_dim]
    advantage: Tensor<B, 1>, // [batch]
) -> Tensor<B, 1>

Advantage-weighted regression for deterministic continuous actions. Only samples with positive advantage contribute gradient: with a negative weight, minimizing the MSE term would push predictions away from the target without bound, which is degenerate.
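A minimal sketch of the weighting described above follows. It assumes each sample's weight is max(advantage, 0) and its MSE is averaged over action dimensions. The exact reduction rl4burn uses is not stated here, and awr_loss is a hypothetical name.

```rust
/// Advantage-weighted MSE: samples with advantage <= 0 get zero weight.
fn awr_loss(pred: &[Vec<f32>], target: &[Vec<f32>], advantage: &[f32]) -> f32 {
    let total: f32 = pred
        .iter()
        .zip(target)
        .zip(advantage)
        .map(|((p, t), &a)| {
            // Per-sample MSE over the action dimensions.
            let mse = p.iter().zip(t).map(|(x, y)| (x - y).powi(2)).sum::<f32>() / p.len() as f32;
            a.max(0.0) * mse
        })
        .sum();
    total / pred.len() as f32
}
```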

Note

These loss functions are standalone building blocks. PPO and DQN implement their own loss computation internally (clipped surrogate for PPO, Bellman MSE for DQN). Use these when building custom algorithms.

Multi-Head Value Decomposition

Decompose value estimation into N heads, each tracking a different reward component. Used by JueWu (Honor of Kings) with 5 heads: farming, KDA, damage, pushing, and winning.

API

use rl4burn::{MultiHeadValueConfig, multi_head_gae, multi_head_value_loss};

let config = MultiHeadValueConfig::new(5, 0.99, 0.95)
    .with_weights(vec![0.1, 0.2, 0.2, 0.2, 0.3]);

let result = multi_head_gae(
    &per_head_rewards,    // [5][T]
    &per_head_values,     // [5][T]
    &dones,               // [T]
    &per_head_last_values, // [5]
    &config,
);

// result.combined_advantages: [T] — weighted sum across heads
// result.per_head_returns: [5][T] — targets for each value head

Why decompose?

With a single value function, the agent knows how well it’s doing but not why. Multi-head decomposition provides credit assignment: “I’m farming well but my pushing is weak.”

Each head can have its own discount factor — short-term heads (damage) use lower gamma, long-term heads (winning) use higher gamma.
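The effect of per-head discounting and the weighted combination can be illustrated with two small helpers. These are hypothetical sketches, not rl4burn functions: discounted_return shows how the same delayed reward is valued under different gammas, and combine_heads mirrors the "weighted sum across heads" comment in the API example.

```rust
/// Discounted return of one reward stream from t = 0 (no bootstrapping).
fn discounted_return(rewards: &[f32], gamma: f32) -> f32 {
    rewards.iter().rev().fold(0.0, |acc, &r| r + gamma * acc)
}

/// combined[t] = sum over heads h of weights[h] * per_head[h][t]
fn combine_heads(per_head: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
    let t_len = per_head[0].len();
    (0..t_len)
        .map(|t| per_head.iter().zip(weights).map(|(head, w)| w * head[t]).sum::<f32>())
        .collect()
}
```

A reward arriving three steps later is worth 0.125 to a gamma = 0.5 head but still 1.0 to a gamma = 1.0 head. This is why short-horizon signals such as damage suit a lower gamma.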

Per-head value loss

let losses = multi_head_value_loss(&predictions, &targets);
// losses: [5] — MSE per head
let total_loss: f32 = losses.iter().sum();

KL Balancing with Free Bits

DreamerV3’s method for training the RSSM world model without the latent space collapsing.

The problem

The RSSM has an encoder (posterior: what actually happened) and a dynamics predictor (prior: what the model predicts). They’re trained with KL divergence, but:

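  • If KL goes to zero: the latent space collapses and stops carrying information about the observations.
  • If KL grows unchecked: the posterior becomes hard for the prior to predict, so the learned dynamics are unreliable.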

The solution

Split the KL loss into two terms with different stop-gradients:

| Term | Trains | Stop-gradient on | Weight |
|---|---|---|---|
| Dynamics loss | Prior (predictor) | Posterior | 0.5 |
| Representation loss | Posterior (encoder) | Prior | 0.1 |

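Plus free bits: each KL term is clipped from below at 1 nat, i.e. max(1, KL). Once the KL falls under the threshold its gradient is zero, so the model doesn't waste capacity eliminating tiny differences.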
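The forward value of the balanced loss can be sketched for a single categorical. Plain f32 code cannot express stop-gradients, so both terms share one KL value here and differ only by weight; in the real loss they differ in which side receives gradient. The names categorical_kl_sketch and kl_balanced_value are hypothetical.

```rust
/// Categorical KL(p || q) from unnormalized logits.
fn categorical_kl_sketch(p_logits: &[f32], q_logits: &[f32]) -> f32 {
    let log_softmax = |logits: &[f32]| -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + logits.iter().map(|&l| (l - max).exp()).sum::<f32>().ln();
        logits.iter().map(|&l| l - lse).collect()
    };
    let log_p = log_softmax(p_logits);
    let log_q = log_softmax(q_logits);
    log_p.iter().zip(&log_q).map(|(lp, lq)| lp.exp() * (lp - lq)).sum()
}

/// Forward value of dyn_weight * max(free_bits, KL) + rep_weight * max(free_bits, KL).
fn kl_balanced_value(post: &[f32], prior: &[f32], dyn_weight: f32, rep_weight: f32, free_bits: f32) -> f32 {
    let kl = categorical_kl_sketch(post, prior);
    let clipped = kl.max(free_bits); // free bits: flat (zero gradient) below the threshold
    dyn_weight * clipped + rep_weight * clipped
}
```

Identical posterior and prior give KL = 0, which is clipped to 1 nat, so the value is 0.5 + 0.1 = 0.6.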

API

use rl4burn::{kl_balanced_loss, KlBalanceConfig};

let config = KlBalanceConfig::default();
// dyn_weight: 0.5, rep_weight: 0.1, free_bits: 1.0

let loss = kl_balanced_loss(posterior_logits, prior_logits, &config);

For RSSM’s 32x32 grouped categoricals:

use rl4burn::kl_balanced_loss_groups;

// posterior_logits: [batch, 32, 32]
let loss = kl_balanced_loss_groups(posterior_logits, prior_logits, &config);

Standalone KL

use rl4burn::{categorical_kl, categorical_kl_groups};

let kl = categorical_kl(p_logits, q_logits);  // [batch]

GAE (Generalized Advantage Estimation)

GAE (Schulman et al., 2015) computes advantages that smoothly interpolate between high-bias/low-variance (TD) and low-bias/high-variance (Monte Carlo) estimates.

API

pub fn gae(
    rewards: &[f32],
    values: &[f32],
    dones: &[bool],
    last_value: f32,
    gamma: f32,
    lambda: f32,
) -> (Vec<f32>, Vec<f32>)  // (advantages, returns)

Pure f32 computation — no tensors, no backend dependency.

How it works

For each timestep t, GAE computes:

  • TD error: δ_t = r_t + γ * V(s_{t+1}) * (1 - done_t) - V(s_t)
  • Advantage: A_t = Σ_{l=0}^{T-t-1} (γλ)^l * δ_{t+l}

The lambda parameter controls the bias-variance tradeoff:

  • λ = 0: TD(0) — just the one-step TD error. Low variance, high bias.
  • λ = 1: Monte Carlo — full discounted return minus baseline. Low bias, high variance.
  • λ = 0.95: Standard default. Good tradeoff for most tasks.

Returns are computed as returns = advantages + values.

Done handling

When dones[t] is true, the next state is from a new episode. GAE correctly zeroes out both the bootstrap value and the accumulated advantage at episode boundaries.
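The recursion and done handling above can be written as one backward pass. This is a reference sketch with the same inputs as the gae signature, not rl4burn's source; gae_sketch is a hypothetical name.

```rust
/// Reference GAE over one rollout. Returns (advantages, returns).
fn gae_sketch(
    rewards: &[f32],
    values: &[f32],
    dones: &[bool],
    last_value: f32,
    gamma: f32,
    lambda: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    let mut advantages = vec![0.0_f32; n];
    let mut acc = 0.0_f32;
    for t in (0..n).rev() {
        // dones[t]: s_{t+1} starts a new episode, so drop both the bootstrap and the carry.
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let next_value = if t + 1 < n { values[t + 1] } else { last_value };
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        acc = delta + gamma * lambda * not_done * acc;
        advantages[t] = acc;
    }
    let returns = advantages.iter().zip(values).map(|(a, v)| a + v).collect();
    (advantages, returns)
}
```

With gamma = lambda = 1, the returns reduce to Monte Carlo returns: rewards [1, 1, 1, 0] ending in a done give returns [3, 2, 1, 0].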

Usage

GAE is called internally by ppo_collect. You only need it directly if building a custom algorithm:

use rl4burn::gae;

let rewards = vec![1.0, 1.0, 1.0, 0.0];
let values = vec![5.0, 4.0, 3.0, 2.0];
let dones = vec![false, false, false, true];
let (advantages, returns) = gae::gae(&rewards, &values, &dones, 0.0, 0.99, 0.95);

V-trace

V-trace (Espeholt et al., 2018) is an off-policy correction algorithm used in IMPALA. It computes value targets and policy gradient advantages from trajectories collected by a potentially stale behavior policy.

API

pub fn vtrace_targets(
    log_rhos: &[f32],     // log importance ratios log(π/μ)
    discounts: &[f32],    // per-step γ (can vary for terminal steps)
    rewards: &[f32],
    values: &[f32],       // V(s_t) from critic
    bootstrap: f32,       // V(s_T) for the last state
    clip_rho: f32,        // importance weight clipping (typically 1.0)
    clip_c: f32,          // trace accumulation clipping (typically 1.0)
) -> (Vec<f32>, Vec<f32>)  // (value_targets, advantages)

Pure f32 computation. Contract annotations enforce preconditions (non-empty inputs, matching lengths, positive clip thresholds).

When to use V-trace

V-trace is for actor-learner architectures (like IMPALA) where the acting policy may be several updates behind the learning policy. For standard on-policy PPO, use GAE instead.

Key parameters

  • clip_rho (ρ̄): Clips importance weights for value targets. Higher = lower bias but higher variance.
  • clip_c (c̄): Clips importance weights for trace accumulation. Controls how far back off-policy corrections propagate.
  • Both typically set to 1.0.
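The standard V-trace recursion from the IMPALA paper can be sketched with the same inputs as vtrace_targets. The recursion is v_s = V(x_s) + δ_s + γ_s c_s (v_{s+1} - V(x_{s+1})), with δ_s = ρ_s (r_s + γ_s V(x_{s+1}) - V(x_s)). vtrace_sketch is a hypothetical name, and rl4burn's internals may differ in detail.

```rust
/// Backward-recursive V-trace. Returns (value_targets, pg_advantages).
fn vtrace_sketch(
    log_rhos: &[f32],
    discounts: &[f32],
    rewards: &[f32],
    values: &[f32],
    bootstrap: f32,
    clip_rho: f32,
    clip_c: f32,
) -> (Vec<f32>, Vec<f32>) {
    let n = rewards.len();
    let rhos: Vec<f32> = log_rhos.iter().map(|lr| lr.exp()).collect();
    let mut vs = vec![0.0_f32; n];
    let mut acc = 0.0_f32; // holds v_{t+1} - V(x_{t+1})
    for t in (0..n).rev() {
        let next_value = if t + 1 < n { values[t + 1] } else { bootstrap };
        let rho = rhos[t].min(clip_rho); // clipped weight for the TD error
        let c = rhos[t].min(clip_c); // clipped weight for trace propagation
        let delta = rho * (rewards[t] + discounts[t] * next_value - values[t]);
        acc = delta + discounts[t] * c * acc;
        vs[t] = values[t] + acc;
    }
    // Policy-gradient advantages bootstrap from the next V-trace target.
    let advantages = (0..n)
        .map(|t| {
            let next_vs = if t + 1 < n { vs[t + 1] } else { bootstrap };
            rhos[t].min(clip_rho) * (rewards[t] + discounts[t] * next_vs - values[t])
        })
        .collect();
    (vs, advantages)
}
```

When the data is on-policy (all log ratios zero), V-trace reduces to the discounted n-step return. Ratios above the clip thresholds behave as if they were 1.0.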

UPGO (Self-Imitation Learning)

UPGO (Upgoing Policy Gradient) reinforces only trajectories where the agent performed better than expected. Used by ROA-Star alongside V-trace.

API

use rl4burn::upgo_advantages;

let advantages = upgo_advantages(&rewards, &values, &dones, last_value, gamma);

How it works

At each timestep, UPGO checks if the one-step TD error is positive (did better than the value predicted):

  • Positive TD: Propagate the actual return backward (learn from this)
  • Negative TD: Truncate to the value estimate (ignore this)

This creates a self-imitation effect: the agent only reinforces actions that led to above-average outcomes.
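A sketch of one common UPGO formulation follows, using the same arguments as upgo_advantages. It assumes the "did better than expected" test is the one-step TD error of the next step, r_{t+1} + γ V(s_{t+2}) - V(s_{t+1}) ≥ 0. If the test passes, the return follows the observed trajectory; otherwise it bootstraps from V(s_{t+1}). rl4burn's exact rule may differ, and upgo_sketch is a hypothetical name.

```rust
/// UPGO advantages: G_t - V(s_t), where G_t keeps following the trajectory
/// only while the next step did at least as well as its value estimate.
fn upgo_sketch(rewards: &[f32], values: &[f32], dones: &[bool], last_value: f32, gamma: f32) -> Vec<f32> {
    let n = rewards.len();
    let value_at = |i: usize| if i < n { values[i] } else { last_value };
    let mut g_next = 0.0_f32; // UPGO return of step t + 1 (set before it is read)
    let mut advantages = vec![0.0_f32; n];
    for t in (0..n).rev() {
        let g = if dones[t] {
            rewards[t] // episode ends here: nothing to propagate
        } else if t + 1 < n {
            let not_done_next = if dones[t + 1] { 0.0 } else { 1.0 };
            let td_next = rewards[t + 1] + gamma * value_at(t + 2) * not_done_next - values[t + 1];
            if td_next >= 0.0 {
                rewards[t] + gamma * g_next // positive TD: propagate the actual return
            } else {
                rewards[t] + gamma * values[t + 1] // negative TD: truncate to the value estimate
            }
        } else {
            rewards[t] + gamma * last_value
        };
        advantages[t] = g - values[t];
        g_next = g;
    }
    advantages
}
```

A good final step ([0, +1]) propagates its reward back to step 0, while a bad one ([0, -1]) is truncated, leaving step 0 with zero advantage.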

When to use

UPGO is complementary to V-trace, not a replacement. ROA-Star uses both:

  • V-trace for stable off-policy value targets
  • UPGO for the policy gradient (only reinforce good trajectories)

Replay Buffer

ReplayBuffer<S, R> stores transitions for off-policy algorithms like DQN. It is generic over the sample type S and the RNG type R; seeding the RNG makes sampling reproducible.

API

use rand::SeedableRng;

let mut buffer = ReplayBuffer::new(10_000, rand::rngs::SmallRng::seed_from_u64(42));

buffer.extend(transitions);          // add samples
let batch = buffer.sample(64);       // random references
let batch = buffer.sample_cloned(64); // random clones (for owned data)
let groups = buffer.group_by(|t| t.episode_id); // group by key

Eviction

When the buffer exceeds capacity, the oldest samples are dropped first (FIFO).
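FIFO eviction can be sketched with a std VecDeque. This illustrates the policy only; FifoBuffer is a hypothetical type, and it omits ReplayBuffer's RNG-driven sampling.

```rust
use std::collections::VecDeque;

/// Fixed-capacity FIFO storage: pushing into a full buffer evicts the oldest sample.
struct FifoBuffer<T> {
    capacity: usize,
    items: VecDeque<T>,
}

impl<T> FifoBuffer<T> {
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self { capacity, items: VecDeque::with_capacity(capacity) }
    }

    fn extend(&mut self, samples: impl IntoIterator<Item = T>) {
        for s in samples {
            if self.items.len() == self.capacity {
                self.items.pop_front(); // drop the oldest sample
            }
            self.items.push_back(s);
        }
    }
}
```

Pushing 1 through 5 into a buffer of capacity 3 leaves [3, 4, 5].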

With DQN

use rand::{rngs::SmallRng, SeedableRng};
use rl4burn::dqn::Transition;
use rl4burn::replay::ReplayBuffer;

let mut buffer = ReplayBuffer::new(10_000, SmallRng::seed_from_u64(42));

// Store transitions
buffer.extend(std::iter::once(Transition {
    obs: obs.clone(),
    action: action as i32,
    reward: result.reward,
    next_obs: result.observation.clone(),
    done: result.done(),
}));

// dqn_update samples from the buffer internally
let (new_online, stats) = dqn_update(online, &target, &mut optim, &mut buffer, &config, &device);
online = new_online;

Trajectory grouping

The group_by method groups sample indices by an arbitrary key function. Useful for V-trace rescoring where you need to process entire trajectories:

let groups = buffer.group_by(|sample| sample.trajectory_id);
for (traj_id, indices) in &groups {
    let trajectory: Vec<_> = indices.iter().map(|&i| &buffer.samples()[i]).collect();
    // rescore this trajectory
}

Sequence Replay Buffer

A FIFO buffer that samples contiguous sequences of a fixed length, respecting episode boundaries. Used by DreamerV3 for world model training.

API

use rl4burn::{SequenceReplayBuffer, SequenceStep};

let mut buffer = SequenceReplayBuffer::new(1_000_000, 64);
// capacity: 1M steps, sequence_length: 64

// Add transitions
buffer.push(SequenceStep {
    observation: obs.clone(),
    action: vec![1.0, 0.0],
    reward: 1.0,
    done: false,
});

// Sample batch of sequences
let sequences = buffer.sample(16, &mut rng);
// sequences: Vec<Vec<SequenceStep<O>>>, each of length 64

Episode boundaries

Sampled sequences never cross episode boundaries. If a done=true step appears in the buffer, no sequence will start before it and end after it.
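One way to enforce this rule is to allow only start indices whose window contains no done flag before its last step. The following is a sketch under that assumption. rl4burn tracks episode start indices instead, as noted below, and valid_starts is a hypothetical helper.

```rust
/// Start indices from which a window of `seq_len` steps stays inside one episode.
/// A window may end on a `done` step but must not continue past it.
fn valid_starts(dones: &[bool], seq_len: usize) -> Vec<usize> {
    if seq_len == 0 || dones.len() < seq_len {
        return Vec::new();
    }
    (0..=dones.len() - seq_len)
        .filter(|&s| !dones[s..s + seq_len - 1].iter().any(|&d| d))
        .collect()
}
```

With a done at index 2 of a 5-step buffer, length-2 windows may start at 0, 1, or 3, but not at 2.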

FIFO eviction

When the buffer exceeds capacity, the oldest steps are removed first. Episode start indices are automatically adjusted.

Difference from ReplayBuffer

| Feature | ReplayBuffer | SequenceReplayBuffer |
|---|---|---|
| Sample unit | Single step | Contiguous sequence |
| Episode boundaries | Not tracked | Enforced |
| Primary use | DQN, off-policy | DreamerV3 world models |

Percentile Return Normalization

DreamerV3-style advantage normalization using EMA-smoothed percentiles. More robust than per-minibatch normalization for sparse or heterogeneous reward scales.

API

use rl4burn::PercentileNormalizer;

let mut normalizer = PercentileNormalizer::new();
// default: 5th-95th percentile, EMA decay 0.99

// Update with observed returns
normalizer.update(&returns);

// Normalize advantages
let normalized = normalizer.normalize(&advantages);
// Divides by max(1.0, P95 - P5)

// Or combine both steps:
let normalized = normalizer.update_and_normalize(&returns, &advantages);

The max(1, …) floor

The critical detail: when the percentile range is less than 1.0 (sparse rewards, all-zero returns), the scale is clamped to 1.0. Without this floor, dividing by a near-zero range would inflate tiny, noise-dominated advantage differences into large values.
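The following sketch shows how EMA-smoothed percentiles and the max(1, range) floor fit together. It makes a few simplifying assumptions: percentiles use the nearest rank, the first update initializes the EMA directly, and the input must be non-empty. PercentileScaleSketch is a hypothetical type, not rl4burn's PercentileNormalizer.

```rust
/// EMA-smoothed percentile range with a max(1, range) floor.
struct PercentileScaleSketch {
    low_q: f32,
    high_q: f32,
    decay: f32,
    low: Option<f32>,
    high: Option<f32>,
}

impl PercentileScaleSketch {
    fn new(low_q: f32, high_q: f32, decay: f32) -> Self {
        Self { low_q, high_q, decay, low: None, high: None }
    }

    /// Nearest-rank percentile of a sorted, non-empty slice (q in [0, 1]).
    fn percentile(sorted: &[f32], q: f32) -> f32 {
        let idx = ((sorted.len() - 1) as f32 * q).round() as usize;
        sorted[idx]
    }

    fn update(&mut self, returns: &[f32]) {
        let mut sorted = returns.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let lo = Self::percentile(&sorted, self.low_q);
        let hi = Self::percentile(&sorted, self.high_q);
        // First update initializes the EMA; later updates blend in the new batch.
        let d = self.decay;
        self.low = Some(self.low.map_or(lo, |old| d * old + (1.0 - d) * lo));
        self.high = Some(self.high.map_or(hi, |old| d * old + (1.0 - d) * hi));
    }

    /// Divisor for advantages: the smoothed range, but never below 1.0.
    fn scale(&self) -> f32 {
        match (self.low, self.high) {
            (Some(lo), Some(hi)) => (hi - lo).max(1.0),
            _ => 1.0,
        }
    }

    fn normalize(&self, advantages: &[f32]) -> Vec<f32> {
        let s = self.scale();
        advantages.iter().map(|a| a / s).collect()
    }
}
```

All-zero returns leave the scale at 1.0, so advantages pass through unchanged. Returns spread over 0 to 100 give a 5th to 95th percentile range of 90.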

Customization

let normalizer = PercentileNormalizer::with_percentiles(0.1, 0.9)
    .with_decay(0.999);

Intrinsic Rewards

Exploration bonuses based on internal state. Useful when extrinsic rewards are sparse.

API

use rl4burn::collect::intrinsic::{IntrinsicReward, CountBasedReward, combine_rewards};

let mut explorer = CountBasedReward::new(0.1); // discretization resolution
explorer.update(&obs, action, &next_obs);
let bonus = explorer.reward(&obs, action, &next_obs);
// bonus = 1 / sqrt(visit_count)

let combined = combine_rewards(&extrinsic, &intrinsic, 0.01);
// combined[i] = extrinsic[i] + 0.01 * intrinsic[i]

Count-Based Exploration

Reward = 1 / sqrt(N(s)) where N(s) is how many times the agent has visited a discretized version of state s. Novel states get high reward; familiar states get low reward.

Entropy-Reduction Reward

ROA-Star’s scouting reward: max(H_{prev} - H_{current}, 0). Rewards the agent for reducing uncertainty about the opponent’s strategy.

use rl4burn::collect::intrinsic::EntropyReductionReward;
let mut scouting = EntropyReductionReward::new();
let reward = scouting.reward_from_entropy(current_entropy);
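The formula needs only the previous entropy as state. In this sketch, EntropyDropBonus is a hypothetical type, and it assumes the first call returns no bonus because no previous estimate exists; rl4burn's choice for that case is not specified here.

```rust
/// Scouting bonus max(H_prev - H_current, 0), remembering the previous entropy.
struct EntropyDropBonus {
    prev: Option<f32>,
}

impl EntropyDropBonus {
    fn new() -> Self {
        Self { prev: None }
    }

    fn reward_from_entropy(&mut self, current: f32) -> f32 {
        // Assumption: no previous estimate means no bonus.
        let r = self.prev.map_or(0.0, |prev| (prev - current).max(0.0));
        self.prev = Some(current);
        r
    }
}
```

Entropy going 2.0 → 1.5 earns 0.5; a later rise to 1.8 earns nothing.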

The IntrinsicReward trait

Implement for custom exploration strategies:

pub trait IntrinsicReward {
    type Observation;
    fn reward(&self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation) -> f32;
    fn update(&mut self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation);
}
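As an example of a custom strategy, here is a count-based bonus implemented against a local copy of the trait above. GridCountReward is a hypothetical type, not rl4burn's CountBasedReward. It counts the grid cell of next_obs (the state the agent arrives in); whether rl4burn counts obs or next_obs is not stated here.

```rust
use std::collections::HashMap;

/// Local copy of the trait shown above, so this sketch is self-contained.
pub trait IntrinsicReward {
    type Observation;
    fn reward(&self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation) -> f32;
    fn update(&mut self, obs: &Self::Observation, action: usize, next_obs: &Self::Observation);
}

/// Hypothetical count-based bonus: 1 / sqrt(N(cell)) over a uniform grid.
struct GridCountReward {
    resolution: f32,
    counts: HashMap<Vec<i64>, u32>,
}

impl GridCountReward {
    fn new(resolution: f32) -> Self {
        Self { resolution, counts: HashMap::new() }
    }

    /// Discretize a continuous observation into its grid cell.
    fn cell(&self, obs: &[f32]) -> Vec<i64> {
        obs.iter().map(|x| (x / self.resolution).floor() as i64).collect()
    }
}

impl IntrinsicReward for GridCountReward {
    type Observation = Vec<f32>;

    fn reward(&self, _obs: &Vec<f32>, _action: usize, next_obs: &Vec<f32>) -> f32 {
        let n = self.counts.get(&self.cell(next_obs)).copied().unwrap_or(0);
        // Unvisited cells are treated as N = 1 to avoid dividing by zero.
        1.0 / (n.max(1) as f32).sqrt()
    }

    fn update(&mut self, _obs: &Vec<f32>, _action: usize, next_obs: &Vec<f32>) {
        *self.counts.entry(self.cell(next_obs)).or_insert(0) += 1;
    }
}
```

After four visits to the same cell, any observation falling in that cell earns 1 / sqrt(4) = 0.5.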

CSPL (Curriculum Self-Play Learning)

JueWu’s 3-phase training pipeline for scaling to many heroes/unit types.

The problem

Training a single policy that handles 40+ heroes in all combinations doesn’t converge (480+ hours without success).

The solution: three phases

| Phase | What | Duration |
|---|---|---|
| 1. Specialists | Train small models on fixed team compositions | ~72h |
| 2. Distillation | Merge all specialists into one big model | ~48h |
| 3. Generalization | Continue RL with random compositions | ~144h |

API

use rl4burn::{CsplPipeline, CsplConfig, CsplPhase};

let mut pipeline = CsplPipeline::new(CsplConfig {
    phase1_steps: 100_000,
    phase2_steps: 50_000,
    phase3_steps: 200_000,
    n_specialists: 10,
});

loop {
    let phase_changed = pipeline.step();

    match pipeline.current_phase() {
        CsplPhase::SpecialistTraining => { /* train specialists via self-play */ }
        CsplPhase::Distillation => { /* distill into student */ }
        CsplPhase::Generalization => { /* continue RL with random compositions */ }
    }

    if pipeline.is_complete() { break; }
}
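The phase boundaries implied by the three step budgets can be sketched as a pure function. This assumes the phases run back to back with global step counting from zero; how CsplPipeline actually counts steps (for example, whether phase1_steps is split across n_specialists) is not stated here. Phase and phase_at are hypothetical names.

```rust
#[derive(Debug, PartialEq)]
enum Phase {
    SpecialistTraining,
    Distillation,
    Generalization,
    Complete,
}

/// Map a global step to its CSPL phase, assuming phases run back to back.
fn phase_at(step: u64, phase1: u64, phase2: u64, phase3: u64) -> Phase {
    if step < phase1 {
        Phase::SpecialistTraining
    } else if step < phase1 + phase2 {
        Phase::Distillation
    } else if step < phase1 + phase2 + phase3 {
        Phase::Generalization
    } else {
        Phase::Complete
    }
}
```

With the configuration above, distillation would begin at step 100,000, generalization at 150,000, and the pipeline would complete at 350,000.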

Polyak Updates

Polyak (soft) target network updates interpolate between a source model’s weights and a target model’s weights:

target = τ * source + (1 - τ) * target

API

pub fn polyak_update<B: Backend, M: Module<B>>(
    target: M,
    source: &M,
    tau: f32,
) -> M

  • tau = 1.0: Hard copy (replace target with source entirely).
  • tau = 0.005: Soft update (slowly track source). Standard for SAC/TD3.
  • tau = 0.0: No-op (target unchanged).
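The three tau cases follow directly from the update formula. This sketch applies it to a flat f32 slice standing in for one parameter tensor; polyak_update_slice is a hypothetical helper, not rl4burn's Module-level function.

```rust
/// target <- tau * source + (1 - tau) * target, element-wise.
fn polyak_update_slice(target: &mut [f32], source: &[f32], tau: f32) {
    assert_eq!(target.len(), source.len());
    for (t, s) in target.iter_mut().zip(source) {
        *t = tau * s + (1.0 - tau) * *t;
    }
}
```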

Usage

use rl4burn::polyak::polyak_update;

// Hard target update (DQN-style, every N steps)
if step % 250 == 0 {
    target = polyak_update(target, &online, 1.0);
}

// Soft target update (SAC/TD3-style, every step)
target = polyak_update(target, &online, 0.005);

How it works

Uses Burn’s ModuleVisitor to collect all parameter tensors from the source model, then ModuleMapper to interpolate each target parameter in place. Works with any Module<B> — nested modules, custom architectures, any number of layers.

Orthogonal Initialization

Orthogonal weight initialization is critical for PPO convergence. It’s the default in CleanRL, Stable Baselines3, and OpenAI baselines.

API

pub fn orthogonal_linear<B: Backend>(
    d_in: usize,
    d_out: usize,
    gain: f32,
    device: &B::Device,
    rng: &mut impl Rng,
) -> Linear<B>

Creates a Linear layer with orthogonal weights and zero bias. Matches PyTorch’s nn.init.orthogonal_.

Gain values

| Layer | Gain | Why |
|---|---|---|
| Hidden (tanh) | sqrt(2) ≈ 1.414 | Conventional CleanRL/SB3 choice for hidden layers |
| Actor output | 0.01 | Near-uniform initial policy (good exploration) |
| Critic output | 1.0 | Reasonable initial value scale |

Usage

use rl4burn::init::orthogonal_linear;
let sqrt2 = std::f32::consts::SQRT_2;

let actor_fc1 = orthogonal_linear(4, 64, sqrt2, &device, &mut rng);
let actor_out = orthogonal_linear(64, 2, 0.01, &device, &mut rng);
let critic_out = orthogonal_linear(64, 1, 1.0, &device, &mut rng);

Why not use Burn’s built-in initializers?

Two reasons:

  1. Burn doesn’t have orthogonal initialization. The closest is XavierUniform, which has similar scale but lacks the orthogonality property.
  2. Burn initializes bias with the same initializer as weights. CleanRL always initializes bias to zero. orthogonal_linear handles both correctly.

Implementation

Uses modified Gram-Schmidt orthogonalization on a random Gaussian matrix. Weights are loaded via Param::from_data + load_record to preserve Burn’s autodiff tracking (see Working with Burn’s Autodiff).
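The modified Gram-Schmidt step can be sketched on a small matrix. To stay dependency-free, this sketch takes the input matrix as an argument instead of drawing it from a Gaussian. It assumes linearly independent rows with rows ≤ columns. orthogonalize_rows is a hypothetical name, and rl4burn's handling of other shapes is not covered.

```rust
/// Modified Gram-Schmidt on the rows of `m`, then scale every row by `gain`.
fn orthogonalize_rows(mut m: Vec<Vec<f64>>, gain: f64) -> Vec<Vec<f64>> {
    for i in 0..m.len() {
        // Normalize row i.
        let norm = m[i].iter().map(|x| x * x).sum::<f64>().sqrt();
        for x in m[i].iter_mut() {
            *x /= norm;
        }
        // "Modified": remove the component along row i from every later row right away,
        // so later projections use already-updated rows.
        for j in (i + 1)..m.len() {
            let dot: f64 = m[i].iter().zip(&m[j]).map(|(a, b)| a * b).sum();
            let (head, tail) = m.split_at_mut(j);
            for (x, q) in tail[0].iter_mut().zip(&head[i]) {
                *x -= dot * q;
            }
        }
    }
    for row in m.iter_mut() {
        for x in row.iter_mut() {
            *x *= gain;
        }
    }
    m
}
```

After scaling by gain, every pair of distinct rows has dot product 0, and each row has squared norm gain².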

Global Gradient Clipping

clip_grad_norm clips gradients by their global L2 norm across all parameters. This matches PyTorch’s torch.nn.utils.clip_grad_norm_.

Why not use Burn’s built-in clipping?

Burn’s GradientClippingConfig::Norm clips each parameter tensor independently. PyTorch clips the global norm across all parameters at once. These produce different behavior:

  • Per-parameter (Burn): A large gradient in the critic doesn’t affect clipping of the actor’s gradient.
  • Global (PyTorch/rl4burn): The total gradient norm is computed, then all gradients are scaled by the same factor.

For PPO with shared optimizer over actor + critic, global clipping is standard.

API

pub fn clip_grad_norm<B: AutodiffBackend, M: AutodiffModule<B>>(
    model: &M,
    grads: GradientsParams,
    max_norm: f32,
) -> GradientsParams

Call between backward() and optim.step():

let grads = loss.backward();
let mut grads = GradientsParams::from_grads(grads, &model);
grads = clip_grad_norm(&model, grads, 0.5);  // max_norm = 0.5
model = optim.step(lr, model, grads);

PPO handles this automatically via PpoConfig::max_grad_norm. Set to 0.0 to disable.

Implementation

Two passes over the parameters, using the inner (non-autodiff) model:

  1. ModuleVisitor: Extract each gradient from GradientsParams, compute its L2 norm squared, accumulate the global norm.
  2. Compute clip_coef = min(1.0, max_norm / (global_norm + 1e-6)).
  3. ModuleMapper: Scale each gradient by clip_coef and re-register it in a new GradientsParams.

The visitor/mapper operate on B::InnerBackend because Burn stores gradients on the inner backend, not the autodiff wrapper.
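The clipping arithmetic, and how it differs from per-parameter clipping, can be seen on plain gradient vectors. These are hypothetical helpers operating on Vec<f32> stand-ins for gradient tensors, not rl4burn's GradientsParams-based function.

```rust
/// Global-norm clipping: every tensor is scaled by the same factor.
/// Returns the global norm before clipping.
fn clip_global(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let global_norm = grads.iter().flatten().map(|g| g * g).sum::<f32>().sqrt();
    let clip_coef = (max_norm / (global_norm + 1e-6)).min(1.0);
    for g in grads.iter_mut().flat_map(|t| t.iter_mut()) {
        *g *= clip_coef;
    }
    global_norm
}

/// Per-tensor clipping (for comparison): each tensor uses only its own norm.
fn clip_per_tensor(grads: &mut [Vec<f32>], max_norm: f32) {
    for t in grads.iter_mut() {
        let norm = t.iter().map(|g| g * g).sum::<f32>().sqrt();
        let coef = (max_norm / (norm + 1e-6)).min(1.0);
        for g in t.iter_mut() {
            *g *= coef;
        }
    }
}
```

For two tensors with norms 3 and 4 and max_norm = 1, global clipping scales both by 1/5, preserving their 3:4 ratio. Per-tensor clipping scales each to norm 1.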

Logging

rl4burn provides a lightweight, feature-gated logging system for training metrics. The core Logger trait and built-in loggers ship with zero extra dependencies. TensorBoard and JSON output are opt-in via feature flags.

The Logger trait

All loggers implement Logger:

pub trait Logger {
    fn log_scalar(&mut self, key: &str, value: f64, step: u64);
    fn log_scalars(&mut self, key: &str, values: &[(&str, f64)], step: u64);
    fn log_text(&mut self, key: &str, text: &str, step: u64);
    fn log_histogram(&mut self, key: &str, values: &[f32], step: u64);
    fn flush(&mut self);
}
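A custom logger only needs these five methods. The sketch below re-declares the trait locally so it compiles on its own and stores scalars in memory, which can be handy in tests. MemoryLogger is hypothetical, and flattening log_scalars into "key/name" entries is this sketch's own convention.

```rust
/// Local copy of the trait shown above, so this sketch is self-contained.
pub trait Logger {
    fn log_scalar(&mut self, key: &str, value: f64, step: u64);
    fn log_scalars(&mut self, key: &str, values: &[(&str, f64)], step: u64);
    fn log_text(&mut self, key: &str, text: &str, step: u64);
    fn log_histogram(&mut self, key: &str, values: &[f32], step: u64);
    fn flush(&mut self);
}

/// Hypothetical logger that keeps scalars in memory.
#[derive(Default)]
struct MemoryLogger {
    scalars: Vec<(String, f64, u64)>,
}

impl Logger for MemoryLogger {
    fn log_scalar(&mut self, key: &str, value: f64, step: u64) {
        self.scalars.push((key.to_string(), value, step));
    }

    fn log_scalars(&mut self, key: &str, values: &[(&str, f64)], step: u64) {
        // Store each entry as "key/name" so it reads like its own scalar series.
        for (name, v) in values {
            self.log_scalar(&format!("{key}/{name}"), *v, step);
        }
    }

    // Text and histograms are ignored in this sketch.
    fn log_text(&mut self, _key: &str, _text: &str, _step: u64) {}
    fn log_histogram(&mut self, _key: &str, _values: &[f32], _step: u64) {}
    fn flush(&mut self) {}
}
```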

Built-in loggers (no feature flags)

PrintLogger — prints scalars to stderr in a formatted line. Accepts a throttle interval so you don’t flood the terminal:

use rl4burn::PrintLogger;

// Print at most every 1000 steps
let mut logger = PrintLogger::new(1000);
logger.log_scalar("train/loss", 0.42, 5000);
// stderr: [step     5000] train/loss: 0.4200

NoopLogger — discards everything. Useful as a default when the caller doesn’t care about logging.

CompositeLogger — fans out to multiple loggers simultaneously:

use rl4burn::{CompositeLogger, PrintLogger};
use rl4burn::TensorBoardLogger; // requires `tensorboard` feature

let mut logger = CompositeLogger::new(vec![
    Box::new(PrintLogger::new(0)),
    Box::new(TensorBoardLogger::new("runs/ppo_cartpole").unwrap()),
]);
logger.log_scalar("train/loss", 0.5, 100); // goes to both

Logging stats from algorithms

PpoStats and DqnStats implement the Loggable trait, so you can log all their fields in one call:

use rl4burn::Loggable;

let (model, stats) = ppo_update(model, &mut optim, &rollout, &config, lr, &device, &mut rng);
stats.log(&mut logger, step);
// Logs: train/policy_loss, train/value_loss, train/entropy, train/approx_kl

For DQN:

let (online, stats) = dqn_update(online, &target, &mut optim, &mut buffer, &config, &device);
stats.log(&mut logger, step);
// Logs: train/loss, train/mean_q, train/epsilon

TensorBoard (feature-gated)

Enable the tensorboard feature in your Cargo.toml:

[dependencies]
rl4burn = { version = "0.1", features = ["tensorboard"] }

Then create a TensorBoardLogger pointing at a run directory:

use rl4burn::TensorBoardLogger;

let mut logger = TensorBoardLogger::new("runs/experiment_1").unwrap();
logger.log_scalar("train/loss", 0.42, 1000);
logger.log_histogram("weights", &weight_data, 1000);
logger.log_text("info", "training started", 0);
logger.flush();

View results with:

tensorboard --logdir runs/

The logger writes standard TFEvent files (events.out.tfevents.*) with hand-serialized protobufs — no prost or protobuf dependency required. Supports scalars, histograms, and text.

JSON output (feature-gated)

Enable the json-log feature:

[dependencies]
rl4burn = { version = "0.1", features = ["json-log"] }

JsonLogger writes one JSON object per line (JSONL format) to any Write sink:

use rl4burn::JsonLogger;

let mut logger = JsonLogger::from_path("train_log.jsonl").unwrap();
logger.log_scalar("train/loss", 0.42, 1000);
logger.flush();

Each line looks like:

{"type":"scalar","key":"train/loss","value":0.42,"step":1000,"wall_time":1234567890.123}

Bridging to Weights & Biases

A thin Python bridge script is included at scripts/wandb_bridge.py:

cargo run --example ppo_cartpole --features "ndarray,json-log" 2>&1 \
  | python scripts/wandb_bridge.py

The same JSONL format can be ingested by neptune, mlflow, comet, or any custom dashboard.

Video recording (feature-gated)

Enable the video feature to record CartPole episodes as GIFs:

[dependencies]
rl4burn = { version = "0.1", features = ["video"] }

Any environment implementing the Renderable trait can produce RGB frames. CartPole and GridWorld both implement it:

use rl4burn::envs::CartPole;
use rl4burn::{write_gif, Env, Renderable};

let mut env = CartPole::new(rng);
env.reset();

let mut frames = vec![env.render()];
loop {
    let step = env.step(action); // action comes from your policy
    frames.push(env.render());
    if step.done() { break; }
}

write_gif("episode.gif", &frames, 4).unwrap(); // 4 centiseconds per frame

Putting it all together

A typical training script with logging:

use rl4burn::{CompositeLogger, Loggable, Logger, PrintLogger, TensorBoardLogger};

let mut logger = CompositeLogger::new(vec![
    Box::new(PrintLogger::new(5000)),
    Box::new(TensorBoardLogger::new("runs/ppo").unwrap()),
]);

for iter in 0..n_iterations {
    let rollout = ppo_collect::<NdArray, _, _>(&model.valid(), &mut vec_env, &config, &device, &mut rng, &mut current_obs, &mut ep_acc);

    let step = (iter + 1) as u64 * steps_per_iter as u64;
    if !rollout.episode_returns.is_empty() {
        let avg = rollout.episode_returns.iter().sum::<f32>() / rollout.episode_returns.len() as f32;
        logger.log_scalar("rollout/avg_return", avg as f64, step);
    }

    let (new_model, stats) = ppo_update(model, &mut optim, &rollout, &config, lr, &device, &mut rng);
    model = new_model;
    stats.log(&mut logger, step);
}
logger.flush();

Feature flags summary

| Feature | Dependency | What you get |
|---|---|---|
| (none) | | Logger trait, PrintLogger, NoopLogger, CompositeLogger, Loggable |
| tensorboard | crc32c | TensorBoardLogger (TFEvent files) |
| json-log | | JsonLogger (JSONL output) |
| video | gif | write_gif(), CartPole::render() |

Saving & Sharing

After training, you typically want to save the model weights for later inference and share visualizations of agent behavior. This chapter covers both.

Saving model weights

Burn models derive Module, which gives them save_file and load_file for free. No rl4burn-specific API is needed — just use Burn’s recorder system directly.

Save a trained model

use burn::module::Module; // brings save_file / load_file into scope
use burn::record::CompactRecorder;

// After training completes:
model
    .save_file("checkpoints/ppo_cartpole", &CompactRecorder::new())
    .expect("failed to save model");

This writes checkpoints/ppo_cartpole.mpk (MessagePack format). The file contains all learnable parameters.

Load a saved model

use burn::module::Module; // brings save_file / load_file into scope
use burn::record::CompactRecorder;

// Initialize a fresh model, then load weights into it
let model: ActorCritic<AB> = ActorCritic::new(&device)
    .load_file("checkpoints/ppo_cartpole", &CompactRecorder::new(), &device)
    .expect("failed to load model");

The model architecture must match — load_file loads parameter values, not the structure.

Recorder types

Burn provides several recorders. Choose based on your needs:

| Recorder | Format | Good for |
|---|---|---|
| CompactRecorder | MessagePack (.mpk) | Production: small files, fast I/O |
| NamedMpkGzFileRecorder | gzipped MessagePack | Sharing: even smaller files |
| PrettyJsonFileRecorder | JSON (.json) | Debugging: human-readable weights |
| BinFileRecorder | Raw binary (.bin) | Maximum speed, no compression |

CompactRecorder is the default choice for most use cases.

Checkpointing during training

Save periodically so you can resume after interruptions or pick the best checkpoint:

use burn::module::Module;
use burn::record::CompactRecorder;

for iter in 0..n_iterations {
    // ... collect and update ...

    // Save every 50 iterations
    if (iter + 1) % 50 == 0 {
        let path = format!("checkpoints/ppo_step_{}", (iter + 1) * steps_per_iter);
        // save_file takes the module by value, so save a clone and keep training
        model
            .clone()
            .save_file(&path, &CompactRecorder::new())
            .expect("failed to save checkpoint");
    }
}

// Always save the final model
model
    .save_file("checkpoints/ppo_final", &CompactRecorder::new())
    .expect("failed to save final model");

DQN: saving online and target networks

For DQN, save both networks so you can resume training correctly:

// save_file consumes the module; clone if training continues afterwards
online.clone().save_file("checkpoints/dqn_online", &CompactRecorder::new())?;
target.clone().save_file("checkpoints/dqn_target", &CompactRecorder::new())?;

For inference only, you just need the online network.

Sharing visualizations

GIF recordings

With the video feature, record an episode of your trained agent and save it as a GIF:

use rl4burn::envs::CartPole;
use rl4burn::{write_gif, greedy_action, Env, Renderable};

let mut env = CartPole::new(rng);
let mut obs = env.reset();

let mut frames = vec![env.render()];
loop {
    // greedy_action runs a forward pass and returns the argmax action
    let action = greedy_action(&model, &obs, &device);
    let step = env.step(action);
    frames.push(env.render());
    if step.done() { break; }
    obs = step.observation;
}

write_gif("agent_demo.gif", &frames, 4).unwrap();

The resulting GIF can be embedded in READMEs, papers, blog posts, or Slack messages.

TensorBoard

TensorBoard logs are shareable as a directory: zip the run folder and send it, and the recipient can open it with tensorboard --logdir. TensorBoard.dev, the former public hosting service, has been shut down, so tensorboard dev upload no longer works.

# View locally
tensorboard --logdir runs/

When comparing experiments, use separate run directories:

// Each experiment gets its own subdirectory
let logger = TensorBoardLogger::new(format!("runs/ppo_lr{}", config.lr)).unwrap();

Then tensorboard --logdir runs/ overlays all runs for comparison.

JSONL logs

JSONL files are plain text and easy to share. Post-process them with any tool:

# Quick plot with Python
python -c "
import json, matplotlib.pyplot as plt
data = [json.loads(l) for l in open('train_log.jsonl') if '\"scalar\"' in l]
returns = [(d['step'], d['value']) for d in data if d['key'] == 'rollout/avg_return']
plt.plot(*zip(*returns))
plt.xlabel('Step'); plt.ylabel('Avg Return')
plt.savefig('training_curve.png')
"

# Send to W&B
python scripts/wandb_bridge.py < train_log.jsonl

Putting it all together

A complete training script that checkpoints the model and records a final evaluation GIF:

use burn::module::Module;
use burn::record::CompactRecorder;
use rl4burn::{
    CompositeLogger, Loggable, Logger, PrintLogger, TensorBoardLogger,
    envs::CartPole, write_gif, Env, Renderable, greedy_action,
};

let mut logger = CompositeLogger::new(vec![
    Box::new(PrintLogger::new(5000)),
    Box::new(TensorBoardLogger::new("runs/ppo").unwrap()),
]);

// Training loop
for iter in 0..n_iterations {
    let rollout = ppo_collect::<NdArray, _, _>(
        &model.valid(), &mut vec_env, &config, &device, &mut rng, &mut current_obs, &mut ep_acc,
    );

    let (new_model, stats) = ppo_update(
        model, &mut optim, &rollout, &config, lr, &device, &mut rng,
    );
    model = new_model;

    let step = (iter + 1) as u64 * steps_per_iter as u64;
    stats.log(&mut logger, step);

    // Checkpoint every 100 iterations (save_file consumes the module, so save a clone)
    if (iter + 1) % 100 == 0 {
        model.clone().save_file(
            &format!("checkpoints/ppo_{step}"),
            &CompactRecorder::new(),
        ).unwrap();
    }
}
logger.flush();

// Save final weights (a clone, since the model is still used for evaluation below)
model.clone().save_file("checkpoints/ppo_final", &CompactRecorder::new()).unwrap();

// Record evaluation episode
let mut env = CartPole::new(rng);
let mut obs = env.reset();
let mut frames = vec![env.render()];
loop {
    let action = greedy_action(&model.valid(), &obs, &device);
    let step = env.step(action);
    frames.push(env.render());
    if step.done() { break; }
    obs = step.observation;
}
write_gif("evaluation.gif", &frames, 4).unwrap();

LSTM, GRU, and Block GRU

Recurrent cells for temporal reasoning under partial observability. The game AI systems referenced throughout this book (AlphaStar, SCC, JueWu) all use an LSTM or GRU for sequence processing.

LSTM Cell

use rl4burn::{LstmCell, LstmCellConfig, LstmState};

let cell = LstmCellConfig::new(input_size, hidden_size).init(&device);
let state = LstmState::zeros(batch_size, hidden_size, &device);

// Single step
let new_state = cell.forward(input, &state);

// Full sequence
let (outputs, final_state) = cell.forward_seq(inputs, &state);
// outputs: [batch, seq_len, hidden_size]

GRU Cell

Same API, simpler internals (2 gates vs LSTM’s 3). Uses PyTorch’s convention for reset gate application.

use rl4burn::{GruCell, GruCellConfig};

let cell = GruCellConfig::new(input_size, hidden_size).init(&device);
let h = Tensor::zeros([batch_size, hidden_size], &device);
let new_h = cell.forward(input, h);
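The sketch below shows what "PyTorch's convention for the reset gate" means. It runs a single GRU step on plain `f64` vectors. `Linear` and `GruSketch` are hypothetical names, and the code is not rl4burn's GruCell. In PyTorch the reset gate multiplies `W_hn h + b_hn` after the matrix product. The original Cho et al. formulation instead resets `h` before multiplying.

```rust
// Illustrative sketch of one GRU step using PyTorch's gate equations.
// `Linear` and `GruSketch` are hypothetical names, not rl4burn types.

/// Dense layer y = W x + b, with W stored row-major as [out][in].
struct Linear {
    w: Vec<Vec<f64>>,
    b: Vec<f64>,
}

impl Linear {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.w
            .iter()
            .zip(&self.b)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b)
            .collect()
    }
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

struct GruSketch {
    // Input-to-hidden maps for the reset (r), update (z) and candidate (n) paths.
    x_r: Linear,
    x_z: Linear,
    x_n: Linear,
    // Hidden-to-hidden maps for the same three paths.
    h_r: Linear,
    h_z: Linear,
    h_n: Linear,
}

impl GruSketch {
    fn step(&self, x: &[f64], h: &[f64]) -> Vec<f64> {
        let (xr, xz, xn) = (self.x_r.forward(x), self.x_z.forward(x), self.x_n.forward(x));
        let (hr, hz, hn) = (self.h_r.forward(h), self.h_z.forward(h), self.h_n.forward(h));
        (0..h.len())
            .map(|i| {
                let r = sigmoid(xr[i] + hr[i]);
                let z = sigmoid(xz[i] + hz[i]);
                // PyTorch convention: r scales (W_hn h + b_hn), i.e. after the matmul.
                let n = (xn[i] + r * hn[i]).tanh();
                (1.0 - z) * n + z * h[i]
            })
            .collect()
    }
}
```

With all weights and biases at zero, both gates equal 0.5 and the candidate is 0, so one step halves the hidden state.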

Block GRU (DreamerV3)

Block-diagonal GRU reduces recurrent parameters by a factor of n_blocks. The recurrent weight matrix is split into independent blocks, each operating on a partition of the hidden state.

DreamerV3 uses 8 blocks with a 4096-dimensional hidden state. This shrinks each 4096 × 4096 recurrent weight matrix from about 16.8M to about 2.1M parameters.

use rl4burn::{BlockGruCell, BlockGruCellConfig};

let cell = BlockGruCellConfig::new(input_size, hidden_size)
    .with_n_blocks(8)
    .init(&device);

When n_blocks = 1, Block GRU is identical to standard GRU.
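The sketch below checks the parameter arithmetic and shows how a block-diagonal recurrent product works. Block `k` reads and writes only its own slice of the hidden state. The function names are hypothetical and the code is not rl4burn's BlockGruCell. It covers a single recurrent weight matrix; a full GRU has one such matrix per gate.

```rust
// Illustrative sketch (hypothetical helpers, not rl4burn's BlockGruCell).

/// Parameters in one hidden x hidden recurrent matrix split into n_blocks diagonal blocks.
fn block_diag_params(hidden: usize, n_blocks: usize) -> usize {
    assert!(hidden % n_blocks == 0, "hidden size must divide evenly into blocks");
    let block = hidden / n_blocks;
    n_blocks * block * block // = hidden^2 / n_blocks
}

/// Block-diagonal matrix-vector product. blocks[k] is a block x block matrix that
/// only sees h[k * block..(k + 1) * block]; off-diagonal blocks are implicitly zero.
fn block_diag_matvec(blocks: &[Vec<Vec<f64>>], h: &[f64]) -> Vec<f64> {
    let block = h.len() / blocks.len();
    let mut out = vec![0.0_f64; h.len()];
    for (k, w) in blocks.iter().enumerate() {
        let slice = &h[k * block..(k + 1) * block];
        for (i, row) in w.iter().enumerate() {
            out[k * block + i] = row.iter().zip(slice).map(|(a, b)| a * b).sum::<f64>();
        }
    }
    out
}
```

For a 4096-dimensional state, one dense matrix has 4096² = 16,777,216 weights. Eight blocks bring that down to 2,097,152.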

Transformer Encoder

Reusable multi-head self-attention blocks for entity processing. Used by ROA-Star and SCC to encode sets of game units.

Multi-Head Attention

use rl4burn::{MultiHeadAttention, MultiHeadAttentionConfig};

let attn = MultiHeadAttentionConfig::new(128, 4).init(&device);
// d_model=128, 4 heads (d_k = 32 per head)

let output = attn.forward(query, key, value, None);
// All inputs: [batch, seq_len, 128]
// Optional mask: [batch, seq_len] (true = attend, false = ignore)
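The computation at the core of each head is scaled dot-product attention with optional key masking. The sketch below is single-head and works on plain `f64` rows. It assumes the masking convention shown above (true = attend). It omits the learned Q/K/V/output projections and the head split that a full multi-head module applies, so it is a hypothetical helper rather than rl4burn's implementation.

```rust
// Illustrative sketch: single-head scaled dot-product attention, no learned
// projections. Hypothetical helper, not rl4burn's MultiHeadAttention.

/// q: [n_q][d], k: [n_k][d], v: [n_k][d_v]; mask: optional [n_k] with
/// true = attend. Assumes at least one key is unmasked.
fn attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], mask: Option<&[bool]>) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    q.iter()
        .map(|qi| {
            // Masked keys score -inf, so softmax assigns them zero weight.
            let scores: Vec<f64> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| match mask {
                    Some(m) if !m[j] => f64::NEG_INFINITY,
                    _ => qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale,
                })
                .collect();
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let weights: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = weights.iter().sum();
            // Each output row is the softmax-weighted sum of the value rows.
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / total * x;
                }
            }
            out
        })
        .collect()
}
```

Each query row is processed independently, and no positional information enters. Reordering the inputs therefore reorders the outputs in the same way.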

Transformer Block

Pre-norm residual block: self-attention + feedforward.

use rl4burn::{TransformerBlock, TransformerBlockConfig};

let block = TransformerBlockConfig::new(128, 4, 512).init(&device);
// d_model=128, 4 heads, d_ff=512
let output = block.forward(input, None);  // pre-norm residual: x = x + attn(norm(x)), then x = x + ffn(norm(x))

Stacked Encoder

use rl4burn::{TransformerEncoder, TransformerEncoderConfig};

let encoder = TransformerEncoderConfig::new(128, 4, 2, 512).init(&device);
// d_model=128, 4 heads, 2 stacked transformer blocks, d_ff=512
let encoded = encoder.forward(entities, None);

Properties

  • Permutation equivariant: reordering input tokens reorders output tokens identically (no positional encoding).
  • Variable-length: use masking for padded sequences.
  • For 30 entities with 128-dim embeddings, a 2-layer encoder runs in microseconds on CPU.

Attention Mechanisms

Three specialized attention modules for game AI architectures.

Target Attention

Scaled dot-product attention for selecting a target entity. The LSTM output serves as query; encoded entities serve as keys. Returns a probability distribution over entities.

use rl4burn::{TargetAttention, TargetAttentionConfig};

let attn = TargetAttentionConfig::new(256, 128).init(&device);
// query_dim=256 (LSTM output), key_dim=128 (entity embedding)

let probs = attn.forward(query, keys, Some(mask));
// query: [batch, 256], keys: [batch, n_entities, 128]
// mask: [batch, n_entities] (true = valid target)
// probs: [batch, n_entities] (sums to 1 over valid targets)
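The query (256) and key (128) dimensions differ, so the sketch below assumes the query is first projected into key space. It then scores every entity by a scaled dot product and takes a softmax over valid entities only. The projection matrix `w_q` and the function name are hypothetical, and the code does not claim to be rl4burn's TargetAttention.

```rust
// Illustrative sketch (hypothetical helper, not rl4burn's TargetAttention):
// project the query into key space, score each entity by a scaled dot
// product, and softmax over valid entities only.

/// w_q: [key_dim][query_dim] projection (assumed); keys: [n_entities][key_dim].
fn target_probs(w_q: &[Vec<f64>], query: &[f64], keys: &[Vec<f64>], valid: &[bool]) -> Vec<f64> {
    let q: Vec<f64> = w_q
        .iter()
        .map(|row| row.iter().zip(query).map(|(a, b)| a * b).sum::<f64>())
        .collect();
    let scale = 1.0 / (q.len() as f64).sqrt();
    let scores: Vec<f64> = keys
        .iter()
        .zip(valid)
        .map(|(k, &ok)| {
            if ok {
                q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>() * scale
            } else {
                f64::NEG_INFINITY // dead or absent entity: probability 0
            }
        })
        .collect();
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```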

Attention Pooling

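Aggregates a variable number of entity embeddings into a fixed-size vector using learned query vectors. Mean/max pooling treats every entity the same way. The learned queries instead let the model focus on the entities that matter most.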

use rl4burn::{AttentionPool, AttentionPoolConfig};

let pool = AttentionPoolConfig::new(128, 4, 2).init(&device);
// embed_dim=128, 4 learned queries, 2 attention heads

let pooled = pool.forward(entities, None);
// entities: [batch, n_entities, 128]
// pooled: [batch, 512] (4 queries * 128 dims)
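In the sketch below, each learned query attends over the entity set and yields one `embed_dim` vector. The vectors are then concatenated, so the output size depends only on the number of queries. The sketch is a single-head simplification with no learned projections, and the function name is hypothetical rather than rl4burn's AttentionPool.

```rust
// Illustrative sketch (hypothetical, single head, no learned projections):
// each learned query attends over the entities and yields one embed_dim
// vector; the per-query results are concatenated.
fn attention_pool(queries: &[Vec<f64>], entities: &[Vec<f64>]) -> Vec<f64> {
    let d = entities[0].len();
    let scale = 1.0 / (d as f64).sqrt();
    let mut pooled: Vec<f64> = Vec::with_capacity(queries.len() * d);
    for q in queries {
        let scores: Vec<f64> = entities
            .iter()
            .map(|e| q.iter().zip(e).map(|(a, b)| a * b).sum::<f64>() * scale)
            .collect();
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f64 = exps.iter().sum();
        // Attention-weighted average of the entity embeddings for this query.
        for i in 0..d {
            pooled.push(exps.iter().zip(entities).map(|(w, e)| w / total * e[i]).sum::<f64>());
        }
    }
    pooled // length = n_queries * embed_dim, independent of the entity count
}
```

A query of all zeros gives every entity equal weight and reduces to mean pooling. Learned queries therefore strictly generalize it.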

Pointer Network

Additive (Bahdanau) attention for entity selection: score = v^T * tanh(W_q * query + W_k * keys). Used by AlphaStar and SCC for selecting subsets of units.

use rl4burn::{PointerNet, PointerNetConfig};

let ptr = PointerNetConfig::new(256, 128, 64).init(&device);
// query_dim=256, key_dim=128, hidden_dim=64

let probs = ptr.forward(query, keys, Some(mask));
// probs: [batch, n_entities] (selection probabilities)

All three modules support masking for absent/dead entities.
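The sketch below implements the pointer-network scoring rule stated above, score = vᵀ tanh(W_q query + W_k key), followed by a masked softmax. The weight layouts and function names are hypothetical, and the code is not rl4burn's PointerNet. Because tanh is bounded, each score is bounded by the L1 norm of `v`.

```rust
// Illustrative sketch (hypothetical, not rl4burn's PointerNet) of additive
// attention: score_j = v . tanh(W_q q + W_k k_j), softmax over valid entities.

fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    w.iter().map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f64>()).collect()
}

/// w_q: [hidden][query_dim], w_k: [hidden][key_dim], v: [hidden].
fn pointer_probs(
    w_q: &[Vec<f64>],
    w_k: &[Vec<f64>],
    v: &[f64],
    query: &[f64],
    keys: &[Vec<f64>],
    valid: &[bool],
) -> Vec<f64> {
    let qh = matvec(w_q, query); // query projected once, reused for every key
    let scores: Vec<f64> = keys
        .iter()
        .zip(valid)
        .map(|(k, &ok)| {
            if !ok {
                return f64::NEG_INFINITY; // masked entity
            }
            let kh = matvec(w_k, k);
            v.iter().zip(qh.iter().zip(&kh)).map(|(vi, (a, b))| vi * (a + b).tanh()).sum::<f64>()
        })
        .collect();
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```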

Auto-Regressive Action Distributions

For games where actions decompose into sequential decisions: what action -> which target -> which ability. Competitive game AI systems such as AlphaStar and SCC use this pattern.

CompositeDistribution

use rl4burn::{CompositeDistribution, ActionHead};

// 3-head action space: action_type(11) -> target(30) -> ability(8)
let dist = CompositeDistribution::from_heads(
    &["action_type", "target", "ability"],
    &[11, 30, 8],
);

// Total logits needed from the model: 11 + 30 + 8 = 49
assert_eq!(dist.total_logits(), 49);

Sampling

Given flat logits from the model (all heads concatenated), sample independently per head:

let actions = dist.sample(&logits, mask.as_ref(), &mut rng);
// actions: Vec<Vec<f32>> — [batch][n_heads], integer-valued

For fully auto-regressive sampling (where head 2’s logits depend on head 1’s sample), call the model multiple times and feed actions back.

Log-probabilities

Joint log-prob is the sum of per-head log-probs:

let log_prob = dist.log_prob(logits, &actions, mask.as_ref(), &device);
// log_prob: [batch] — log P(a) = log P(a1) + log P(a2) + log P(a3)

Entropy

Sum of per-head entropies (exact when heads are independent, upper bound otherwise):

let entropy = dist.entropy(logits, mask.as_ref());
// entropy: [batch]
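The sketch below computes the joint log-probability and summed entropy for one example. It splits the flat logits by head size, applies a log-softmax per head, adds the chosen actions' log-probs, and adds the per-head entropies. The function names are hypothetical, the code is not rl4burn's CompositeDistribution, and it omits masking.

```rust
// Illustrative sketch (hypothetical helpers, no masking), not rl4burn's
// CompositeDistribution.

/// Numerically stable log-softmax over one head's logits.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lse = max + logits.iter().map(|l| (l - max).exp()).sum::<f64>().ln();
    logits.iter().map(|l| l - lse).collect()
}

/// Returns (joint log-prob of `actions`, sum of per-head entropies).
fn joint_log_prob_and_entropy(flat: &[f64], head_sizes: &[usize], actions: &[usize]) -> (f64, f64) {
    let (mut log_prob, mut entropy, mut offset) = (0.0, 0.0, 0);
    for (&size, &a) in head_sizes.iter().zip(actions) {
        let lp = log_softmax(&flat[offset..offset + size]);
        log_prob += lp[a]; // log P(a) = sum over heads of log P(a_i)
        entropy -= lp.iter().map(|l| l.exp() * l).sum::<f64>(); // H_i = -sum p log p
        offset += size;
    }
    (log_prob, entropy)
}
```

With all 49 logits at zero, every head is uniform. The joint log-prob is then −ln(11 · 30 · 8) = −ln 2640, and the summed entropy is ln 2640.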

With action masking

Pass a flat mask tensor [batch, total_logits] where 1.0 = valid, 0.0 = invalid. Masked actions are never sampled and get zero probability.
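A common way to honor such a mask is to replace invalid logits with −∞ before the softmax, so that their probability is exactly zero. The sketch below does this for one head. The 0.5 threshold and the function name are assumptions, and rl4burn may use a large finite negative value instead.

```rust
// Illustrative sketch (hypothetical helper): mask with 1.0 = valid,
// 0.0 = invalid; invalid logits become -inf so they get probability 0.
// Assumes at least one valid entry.
fn masked_softmax(logits: &[f64], mask: &[f64]) -> Vec<f64> {
    let masked: Vec<f64> = logits
        .iter()
        .zip(mask)
        .map(|(&l, &m)| if m > 0.5 { l } else { f64::NEG_INFINITY })
        .collect();
    let max = masked.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = masked.iter().map(|&l| (l - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```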

FiLM Conditioning

FiLM (Feature-wise Linear Modulation) applies a context-dependent affine transform to features. Used by SCC to condition spatial action heads on action type.

API

use rl4burn::{Film, FilmConfig};

let film = FilmConfig::new(32, 128).init(&device);
// context_dim=32, feature_dim=128

let output = film.forward(features, context);
// features: [batch, 128], context: [batch, 32]
// output: [batch, 128]

How it works

output = (1 + gamma(context)) * features + beta(context)

The +1 on gamma means the layer starts close to an identity transform when gamma and beta output values near zero at initialization. This improves training stability.
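The sketch below implements the formula above on plain vectors. It assumes gamma and beta are bias-free linear maps of the context. The struct is hypothetical and not rl4burn's Film.

```rust
// Illustrative sketch (hypothetical struct, not rl4burn's Film):
// out = (1 + gamma(ctx)) * x + beta(ctx), with gamma/beta as bias-free linear maps.
struct FilmSketch {
    w_gamma: Vec<Vec<f64>>, // [feature_dim][context_dim]
    w_beta: Vec<Vec<f64>>,  // [feature_dim][context_dim]
}

impl FilmSketch {
    fn forward(&self, x: &[f64], ctx: &[f64]) -> Vec<f64> {
        let lin = |w: &Vec<Vec<f64>>| -> Vec<f64> {
            w.iter().map(|r| r.iter().zip(ctx).map(|(a, b)| a * b).sum::<f64>()).collect()
        };
        let (g, b) = (lin(&self.w_gamma), lin(&self.w_beta));
        // Per-feature scale and shift, both chosen by the context.
        x.iter().zip(g.iter().zip(&b)).map(|(xi, (gi, bi))| (1.0 + gi) * xi + bi).collect()
    }
}
```

With zero weights the layer returns its input unchanged, which is the identity start described above.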

Symlog and Twohot Encoding

DreamerV3’s solution for scale-free predictions. Symlog compresses large values; twohot turns regression into classification.

Symlog / Symexp

use rl4burn::{symlog, symexp};

let compressed = symlog(values);   // sign(x) * ln(|x| + 1)
let recovered = symexp(compressed); // sign(x) * (exp(|x|) - 1)
// Round-trip: symexp(symlog(x)) ≈ x

Key properties:

  • symlog(0) = 0
  • symlog(1000) ≈ 6.9 (massive compression)
  • symlog(-x) = -symlog(x) (symmetric)
  • Monotonically increasing
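The scalar sketch below implements the two formulas and checks the listed properties. It uses `ln_1p` and `exp_m1` for accuracy near zero. The free functions are hypothetical stand-ins; rl4burn's `symlog`/`symexp` operate on tensors.

```rust
// Illustrative scalar sketch (rl4burn's versions operate on tensors).
/// symlog(x) = sign(x) * ln(|x| + 1)
fn symlog(x: f64) -> f64 {
    x.signum() * x.abs().ln_1p()
}

/// symexp(x) = sign(x) * (exp(|x|) - 1), the inverse of symlog
fn symexp(x: f64) -> f64 {
    x.signum() * x.abs().exp_m1()
}
```

`symlog(1000) = ln 1001 ≈ 6.91` and `symlog(1) = ln 2 ≈ 0.69`, so the ratio between those two targets shrinks from 1000 to about 10.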

Twohot Encoder

Encodes scalar values as soft distributions over 255 bins in symlog space.

use rl4burn::TwohotEncoder;

let encoder = TwohotEncoder::new(); // 255 bins, [-20, 20]

// Encode: scalar → distribution
let targets = encoder.encode(values, &device);  // [batch, 255]

// Decode: distribution → scalar
let values = encoder.decode(probs, &device);  // [batch]

// Loss: cross-entropy against twohot targets
let loss = encoder.loss(logits, values, &device);  // [1]
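Twohot encoding maps a value into symlog space and splits its unit of mass between the two neighboring bins. The weights are chosen so that the expected bin value reproduces the input exactly. The sketch below assumes 255 evenly spaced bins in [-20, 20], matching the defaults above. The helper names are hypothetical and not TwohotEncoder's internals.

```rust
// Illustrative sketch (hypothetical helpers, not rl4burn's TwohotEncoder).
fn symlog(x: f64) -> f64 { x.signum() * x.abs().ln_1p() }
fn symexp(x: f64) -> f64 { x.signum() * x.abs().exp_m1() }

/// Evenly spaced bin centers in symlog space.
fn bins(n: usize, lo: f64, hi: f64) -> Vec<f64> {
    (0..n).map(|i| lo + (hi - lo) * i as f64 / (n - 1) as f64).collect()
}

/// Split the mass between the two bins that bracket symlog(x).
fn twohot_encode(x: f64, bins: &[f64]) -> Vec<f64> {
    let y = symlog(x).clamp(bins[0], bins[bins.len() - 1]);
    // Last bin <= y, capped so that k + 1 is still a valid index.
    let k = bins.iter().rposition(|&b| b <= y).unwrap().min(bins.len() - 2);
    let w_hi = (y - bins[k]) / (bins[k + 1] - bins[k]);
    let mut target = vec![0.0; bins.len()];
    target[k] = 1.0 - w_hi;
    target[k + 1] = w_hi;
    target
}

/// Expected bin value in symlog space, mapped back with symexp.
fn twohot_decode(probs: &[f64], bins: &[f64]) -> f64 {
    symexp(probs.iter().zip(bins).map(|(p, b)| p * b).sum::<f64>())
}
```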

Why this matters

Without symlog+twohot, you need to tune learning rates per domain. A reward of 1000 produces roughly 1000x larger regression gradients than a reward of 1. In symlog space the two targets become about 6.9 and 0.69, shrinking the gap to roughly 10x. Twohot then converts regression into classification, further stabilizing gradients.

DreamerV3 Overview

DreamerV3 learns a model of the world, then trains a policy entirely inside imagined trajectories. It’s architecturally different from the model-free papers (AlphaStar, JueWu) but its sample efficiency could be transformative for fast simulations.

The DreamerV3 training loop

repeat:
    1. Collect experience in the real environment
    2. Store in sequence replay buffer
    3. Sample sequences, train the world model (RSSM)
    4. Imagine trajectories from the world model
    5. Train actor-critic on imagined data

Steps 4-5 are “free” — no environment interaction needed.

rl4burn modules for DreamerV3

| Component | Module | Page |
|---|---|---|
| World model | Rssm | RSSM |
| Imagination | imagine_rollout | Imagination |
| Value targets | lambda_returns | Imagination |
| Replay | SequenceReplayBuffer | Sequence Replay |
| Transforms | symlog, TwohotEncoder | Symlog |
| KL training | kl_balanced_loss | KL Balance |
| Normalization | PercentileNormalizer | Percentile |
| Block GRU | BlockGruCell | RNN |

RSSM (Recurrent State-Space Model)

The world model at the heart of DreamerV3. Predicts what happens next in a compact latent space.

State representation

The RSSM state has two parts:

  • h (deterministic): GRU hidden state, captures long-term memory
  • z (stochastic): 32 categorical distributions x 32 classes, captures uncertainty

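Together they form a state sufficient to reconstruct observations, predict rewards, and determine whether episodes continue. The flattened z alone has 32 × 32 = 1024 dimensions, and h adds deterministic_size more, so the combined state has more than 1024 dimensions.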

API

use rl4burn::{Rssm, RssmConfig, RssmState};

let config = RssmConfig::new(obs_dim, action_dim);
let rssm = config.init(&device);

// Initial state (all zeros)
let state = rssm.initial_state(batch_size, &device);

// Training: observe → update
let (new_state, posterior_logits, prior_logits) = rssm.obs_step(&state, action, obs);

// Imagination: predict without observing
let new_state = rssm.imagine_step(&state, action);

// Predictions
let reward_logits = rssm.predict_reward(state.h, state.z);   // [batch, 255]
let cont_logits = rssm.predict_continue(state.h, state.z);   // [batch, 1]

Training the RSSM

Train with KL-balanced loss between posterior and prior:

use rl4burn::{kl_balanced_loss, KlBalanceConfig, TwohotEncoder};

let kl_loss = kl_balanced_loss(posterior_logits, prior_logits, &KlBalanceConfig::default());
let reward_loss = TwohotEncoder::new().loss(reward_logits, actual_rewards, &device);
let total_loss = kl_loss + reward_loss;
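The sketch below shows the KL-balancing idea with free bits on plain probability vectors. The KL is summed over the categorical groups and clipped from below at `free_bits` nats. It is then weighted by a dynamics scale and a representation scale. The values 0.5, 0.1, and 1 nat are DreamerV3's published defaults. Whether `KlBalanceConfig::default()` uses the same values is an assumption. Without autodiff both terms have the same value. In a real model the dynamics term stops gradients into the posterior and the representation term stops gradients into the prior.

```rust
// Illustrative sketch (hypothetical helpers, not rl4burn's kl_balanced_loss).

/// KL(q || p) for two categoricals given as probability vectors.
fn kl_categorical(q: &[f64], p: &[f64]) -> f64 {
    q.iter()
        .zip(p)
        .filter(|&(&qi, _)| qi > 0.0)
        .map(|(qi, pi)| qi * (qi / pi).ln())
        .sum()
}

/// post/prior: [n_groups][n_classes]. Values only; no stop-gradient here.
fn kl_balanced(post: &[Vec<f64>], prior: &[Vec<f64>], dyn_scale: f64, rep_scale: f64, free_bits: f64) -> f64 {
    let kl: f64 = post.iter().zip(prior).map(|(q, p)| kl_categorical(q, p)).sum();
    let clipped = kl.max(free_bits); // below free_bits nats the loss is constant (zero gradient)
    dyn_scale * clipped + rep_scale * clipped
}
```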

Configuration

let config = RssmConfig {
    obs_dim: 64,
    action_dim: 11,
    deterministic_size: 512,   // GRU hidden size
    n_categories: 32,          // stochastic groups
    n_classes: 32,             // classes per group
    hidden_size: 512,          // MLP hidden dim
    n_blocks: 8,               // block GRU blocks (0 = standard GRU)
    unimix: 0.01,              // uniform mixture for categoricals
};
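The `unimix` field refers to mixing each categorical with a uniform distribution. With `unimix = 0.01`, every class keeps at least 1% / n_classes probability, which keeps KL terms finite. The sketch below assumes the DreamerV3 formula `probs = (1 - unimix) * softmax(logits) + unimix / n_classes`. The function name is hypothetical.

```rust
// Illustrative sketch (hypothetical helper): softmax followed by uniform
// mixing, probs = (1 - unimix) * softmax(logits) + unimix / n_classes.
fn unimix_probs(logits: &[f64], unimix: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    let n = logits.len() as f64;
    exps.iter().map(|e| (1.0 - unimix) * e / total + unimix / n).collect()
}
```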

Imagination Rollouts

Generate trajectories entirely within the RSSM latent space for actor-critic training.

API

use rl4burn::algo::planning::imagination::{imagine_rollout, lambda_returns};

let trajectory = imagine_rollout(
    &rssm,
    initial_states,
    |h, z| actor_network.forward(h, z),  // actor closure
    15,  // horizon (DreamerV3 default)
);

// trajectory.states: [16] states (initial + 15 imagined)
// trajectory.reward_logits: [15] reward predictions
// trajectory.continue_logits: [15] continue predictions

Lambda-returns

Compute value targets from imagined rewards:

let returns = lambda_returns(
    &rewards,     // decoded from reward_logits
    &values,      // critic predictions at each state
    &continues,   // sigmoid(continue_logits)
    0.997,        // gamma
    0.95,         // lambda
);
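The recursion behind lambda-returns can be sketched as follows:

R_t = r_t + γ c_t ((1 − λ) v_{t+1} + λ R_{t+1}), with R_H = v_H.

The sketch below computes it backwards on plain slices. The index convention is an assumption: `rewards[t]` and `continues[t]` describe step t → t+1, and `values` has one extra bootstrap entry. rl4burn's `lambda_returns` may index differently.

```rust
// Illustrative sketch (assumed index convention, not rl4burn's lambda_returns).
/// rewards, continues: [H]; values: [H + 1] (last entry bootstraps the final state).
fn lambda_returns(rewards: &[f64], values: &[f64], continues: &[f64], gamma: f64, lambda: f64) -> Vec<f64> {
    let h = rewards.len();
    assert_eq!(values.len(), h + 1);
    let mut returns = vec![0.0; h];
    let mut next = values[h]; // R_H = v_H
    for t in (0..h).rev() {
        // Blend the one-step bootstrap with the longer return, discounted by gamma * continue.
        next = rewards[t] + gamma * continues[t] * ((1.0 - lambda) * values[t + 1] + lambda * next);
        returns[t] = next;
    }
    returns
}
```

Setting λ = 1 gives discounted Monte Carlo returns with a final bootstrap. Setting λ = 0 gives one-step TD targets.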

Stop-gradient rules

During imagination training:

  1. World model: frozen (its parameters receive no updates from the actor-critic losses). The actor learns to generate actions that lead to high-value states.
  2. Value targets: stop-gradiented. The critic trains on fixed targets.
  3. Rewards: gradients flow through the dynamics model to the actor (the actor is indirectly optimizing for states that the world model predicts will be rewarding).

R2-Dreamer

R2-Dreamer (ICLR 2026) is a computationally efficient world model for RL that achieves strong performance without decoders or augmentation. It replaces the standard reconstruction loss with self-supervised representation objectives.

Key Idea

Standard DreamerV3 trains the encoder via a decoder that reconstructs observations. R2-Dreamer eliminates this bottleneck by using redundancy reduction (Barlow Twins loss) to learn representations directly.

Representation Variants

rl4burn supports all four variants from the paper:

| Variant | Loss | Description |
|---|---|---|
| Dreamer | Decoder MSE | Standard DreamerV3 reconstruction baseline |
| R2Dreamer | Barlow Twins | Invariance + decorrelation on cross-correlation matrix |
| InfoNCE | Contrastive | Positive pair matching with temperature-scaled cosine similarity |
| DreamerPro | Prototype | Sinkhorn-Knopp assignment to learned prototypes |
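The sketch below shows the standard Barlow Twins objective on two batches of embeddings. Each feature is standardized over the batch, and then the cross-correlation matrix C is computed. The loss pushes the diagonal of C toward 1 (invariance) and the off-diagonal entries toward 0 (redundancy reduction, weighted by λ). Which two views R2-Dreamer pairs, and its λ, are not specified here, so both are treated as inputs. The code is not rl4burn's representation loss.

```rust
// Illustrative sketch of the standard Barlow Twins objective on plain f64
// batches; not rl4burn's representation loss. za, zb: [batch][dim].
fn barlow_twins(za: &[Vec<f64>], zb: &[Vec<f64>], lambda: f64) -> f64 {
    // Standardize every feature over the batch (zero mean, unit variance).
    let standardize = |z: &[Vec<f64>]| -> Vec<Vec<f64>> {
        let n = z.len() as f64;
        let mut out = z.to_vec();
        for j in 0..z[0].len() {
            let mean = z.iter().map(|r| r[j]).sum::<f64>() / n;
            let var = z.iter().map(|r| (r[j] - mean).powi(2)).sum::<f64>() / n;
            let std = (var + 1e-12).sqrt(); // small epsilon avoids division by zero
            for r in out.iter_mut() {
                r[j] = (r[j] - mean) / std;
            }
        }
        out
    };
    let (a, b) = (standardize(za), standardize(zb));
    let (n, d) = (a.len() as f64, a[0].len());
    let mut loss = 0.0;
    for i in 0..d {
        for j in 0..d {
            // Cross-correlation of feature i (view A) with feature j (view B).
            let c = a.iter().zip(&b).map(|(ra, rb)| ra[i] * rb[j]).sum::<f64>() / n;
            // Invariance on the diagonal, redundancy reduction off it.
            loss += if i == j { (1.0 - c).powi(2) } else { lambda * c * c };
        }
    }
    loss
}
```

Identical, decorrelated views give a loss near zero. Swapping the two features in one view breaks invariance on both diagonal entries and correlates the off-diagonal entries, giving 2 + 2λ.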

Usage

use rl4burn::algo::dreamer::{DreamerConfig, dreamer_world_model_loss, dreamer_actor_critic_loss};
use rl4burn::algo::loss::representation::RepresentationVariant;

// Configure with R2-Dreamer (Barlow Twins)
let config = DreamerConfig {
    rep_variant: RepresentationVariant::R2Dreamer,
    action_dim: 4,
    discrete_actions: true,
    ..DreamerConfig::default()
};
let agent = config.init::<B>(&device);

// Train world model on observed sequences
let (wm_loss, wm_stats) = dreamer_world_model_loss(
    &agent, observations, actions, rewards, continues,
);

// Train actor-critic via imagination
let (actor_loss, critic_loss, ac_stats) = dreamer_actor_critic_loss(
    &agent, initial_states,
);

Architecture

The agent composes existing rl4burn building blocks:

  • RSSM (rl4burn_nn::rssm) — recurrent state-space model with deterministic GRU + stochastic categorical states
  • Imagination rollouts (rl4burn_algo::planning::imagination) — generate trajectories in latent space
  • KL-balanced loss (rl4burn_algo::loss::kl_balance) — train posterior and prior with free bits
  • Symlog + Twohot (rl4burn_nn::symlog) — distributional value prediction
  • Representation losses (rl4burn_algo::loss::representation) — Barlow Twins, InfoNCE, DreamerPro, decoder
  • MLP with RMSNorm (rl4burn_nn::mlp) — prediction heads and actor/critic networks
  • CNN encoder/decoder (rl4burn_nn::conv) — image observation processing
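Several of these blocks rely on RMSNorm. It rescales a vector by its root-mean-square and applies a learned per-feature gain. Unlike LayerNorm, it does not subtract the mean. The sketch below is a hypothetical helper on plain slices, not rl4burn_nn::mlp, and the choice of `eps` is left to the caller.

```rust
// Illustrative sketch (hypothetical helper, not rl4burn_nn::mlp):
// y_i = x_i / sqrt(mean(x^2) + eps) * gain_i
fn rms_norm(x: &[f64], gain: &[f64], eps: f64) -> Vec<f64> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gain).map(|(v, g)| v * inv_rms * g).collect()
}
```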

New Modules

| Module | Crate | Description |
|---|---|---|
| mlp | rl4burn-nn | Configurable MLP with RMSNorm or LayerNorm |
| conv | rl4burn-nn | CNN encoder (images → features) and decoder (features → images) |
| multi_encoder | rl4burn-nn | Routes mixed observations (images + vectors) |
| representation | rl4burn-algo | Four self-supervised representation losses |
| dreamer | rl4burn-algo | DreamerAgent, world model loss, actor-critic loss |

Example

See examples/dreamer/ for a complete training loop on CartPole.

Reference

Nauman & Straffelini, “R2-Dreamer: Redundancy Reduction for Computationally Efficient World Models” (ICLR 2026).

Self-Play

Train agents by playing against past versions of themselves. The core mechanism for competitive game AI.

API

use rl4burn::algo::multi_agent::self_play::{SelfPlayPool, branch_agent};

let mut pool = SelfPlayPool::new();

// Snapshot current model every N steps
pool.add_snapshot(&model, training_step);

// Get a random past opponent
if let Some(opponent) = pool.sample(&mut rng) {
    // Run game: model vs opponent
}

// Keep only the 50 most recent
pool.retain_recent(50);

How it works

SelfPlayPool stores cloned copies of the model at different training stages. Opponents are sampled uniformly. For smarter opponent selection, see PFSP Matchmaking.

Important: snapshots are deep copies

When you call add_snapshot, the model is .clone()’d. Mutating the original model afterward does not affect stored snapshots. This is essential — without true deep copies, all “opponents” would have the same weights as the current model.

League Training

AlphaStar-style multi-agent training with role-based specialization. Multiple agents train simultaneously with different objectives.

Agent Roles

| Role | Opponents | Purpose |
|---|---|---|
| Main Agent | 35% self-play + 65% PFSP pool | General strength |
| Main Exploiter | Only the main agent | Find main agent’s weaknesses |
| League Exploiter | PFSP across full pool | Find weaknesses across all strategies |

API

use rl4burn::{League, AgentRole, LeagueAgentConfig};

let mut league = League::new();
league.set_initial_model(initial_model.clone());

// Add agents
league.add_agent(model.clone(), LeagueAgentConfig {
    role: AgentRole::MainAgent,
    checkpoint_interval: 1000,
    reset_threshold: 0,
});
league.add_agent(model.clone(), LeagueAgentConfig {
    role: AgentRole::MainExploiter,
    checkpoint_interval: 2000,
    reset_threshold: 50000,
});

// Training loop
let opponent = league.get_opponent(agent_idx, &mut rng);
// ... play game, update model ...
league.update_agent(agent_idx); // handles checkpointing

Checkpointing

Every checkpoint_interval steps, the agent’s current weights are frozen and added to the opponent pool. All agents can then play against these frozen snapshots.

Exploiter resets

Exploiters that stop improving get reset to the initial model weights:

league.reset_exploiter(exploiter_idx);

PFSP Matchmaking

Prioritized Fictitious Self-Play samples harder opponents more frequently. The opponent you lose to most often is the one you practice against most.

API

use rl4burn::{PfspMatchmaking, PfspConfig};

let mut mm = PfspMatchmaking::new(PfspConfig {
    power: 1.0,    // higher = more focus on hard opponents
    min_prob: 0.01, // every opponent has at least 1% chance
});

mm.add_opponent(0);
mm.add_opponent(1);
mm.add_opponent(2);

// Record results
mm.record_result(0, true, false);   // beat opponent 0
mm.record_result(1, false, false);  // lost to opponent 1

// Sample: opponent 1 (harder) is sampled more often
let opponent = mm.sample_opponent(&mut rng);

Weighting formula

Selection probability is proportional to (1 - win_rate) ^ power:

  • Win rate 90% → weight 0.1
  • Win rate 50% → weight 0.5
  • Win rate 10% → weight 0.9

Higher power makes the distribution more extreme.
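The sketch below turns win rates into sampling probabilities using the weighting above. To honor `min_prob`, it assumes a uniform floor is mixed in: each opponent gets `min_prob`, and the remaining mass is shared in proportion to the weights. This requires `n * min_prob <= 1`. How rl4burn enforces the floor is not stated here, and the function name is hypothetical.

```rust
// Illustrative sketch (hypothetical helper, not rl4burn's PfspMatchmaking).
// weight_i = (1 - win_rate_i)^power; a floor of min_prob is mixed in
// (assumed scheme, requires n * min_prob <= 1).
fn pfsp_probs(win_rates: &[f64], power: f64, min_prob: f64) -> Vec<f64> {
    let n = win_rates.len() as f64;
    let weights: Vec<f64> = win_rates.iter().map(|w| (1.0 - w).powf(power)).collect();
    let total: f64 = weights.iter().sum();
    if total == 0.0 {
        // We beat every opponent every time: fall back to uniform sampling.
        return vec![1.0 / n; win_rates.len()];
    }
    let spread = 1.0 - n * min_prob; // mass left after the per-opponent floors
    weights.iter().map(|w| spread * w / total + min_prob).collect()
}
```

With win rates 90%, 50%, and 10% and `power = 1`, the weights 0.1, 0.5, and 0.9 normalize to about 0.07, 0.33, and 0.6. Raising `power` to 2 pushes the hardest opponent's share higher.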

Multi-Agent Shared-Weight Training

Efficiently control multiple units with a single shared policy network. Used by JueWu for 5 heroes and applicable to any game with multiple controlled units.

API

use rl4burn::{batch_multi_agent_obs, unbatch_actions, broadcast_team_reward};

// Batch observations from all agents across all environments
let (obs_tensor, n_envs, n_agents) = batch_multi_agent_obs::<B>(
    &per_env_per_agent_obs,
    &device,
);
// obs_tensor: [n_envs * n_agents, obs_dim] — one big batch

// Single forward pass for all agents
let output = model.forward(obs_tensor);

// Unbatch actions back to per-env, per-agent
let actions = unbatch_actions(&flat_actions, n_envs, n_agents);
// actions: [n_envs][n_agents]

// Broadcast team reward to all agents
let per_agent_rewards = broadcast_team_reward(&env_rewards, n_agents);
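The batching above is index bookkeeping around a single forward pass. The sketch below assumes env-major row order, so row `e * n_agents + a` holds env `e` and agent `a`. The helper names are hypothetical, not rl4burn's functions, and they work on `Vec`s instead of tensors.

```rust
// Illustrative sketch with hypothetical names (not rl4burn's API).
// Row order is env-major: row e * n_agents + a holds env e, agent a.

/// Flatten [n_envs][n_agents][obs_dim] into [n_envs * n_agents][obs_dim].
fn flatten_obs(obs: &[Vec<Vec<f32>>]) -> (Vec<Vec<f32>>, usize, usize) {
    let n_envs = obs.len();
    let n_agents = obs[0].len();
    let flat = obs.iter().flat_map(|env| env.iter().cloned()).collect();
    (flat, n_envs, n_agents)
}

/// Inverse mapping for per-row outputs such as actions.
fn unflatten_actions<T: Clone>(flat: &[T], n_envs: usize, n_agents: usize) -> Vec<Vec<T>> {
    assert_eq!(flat.len(), n_envs * n_agents);
    flat.chunks(n_agents).map(|row| row.to_vec()).collect()
}

/// Every agent in an environment receives that environment's team reward.
fn broadcast_reward(env_rewards: &[f32], n_agents: usize) -> Vec<Vec<f32>> {
    env_rewards.iter().map(|&r| vec![r; n_agents]).collect()
}
```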

Why shared weights?

With 30 units, running 30 separate forward passes is expensive. With shared weights, batch all 30 observations into one forward pass. The policy generalizes across unit types through the observation encoding.

Privileged Critic

The policy sees only what a player would see. The critic sees everything — including enemy positions behind fog of war.

The Trait

use rl4burn::algo::privileged_critic::{PrivilegedActorCritic, make_critic_input};

impl<B: Backend> PrivilegedActorCritic<B> for MyModel<B> {
    fn actor_forward(&self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
        self.actor.forward(obs)  // partial observation only
    }

    fn critic_forward(&self, obs: Tensor<B, 2>, privileged: Tensor<B, 2>) -> Tensor<B, 1> {
        let input = make_critic_input(obs, privileged);
        self.critic.forward(input)
    }
}

Why it works

Value estimation under partial observability is noisy — the critic can’t tell if you’re winning or losing without seeing the full game state. Giving the critic privileged information dramatically reduces variance. At deployment time, only the actor is needed.

Goal-Conditioned RL (z-Conditioning)

Condition policies on strategy descriptors to enable rapid specialization. Used by ROA-Star and SCC for exploiter training.

API

use rl4burn::algo::z_conditioning::{ZConditioning, ZConditioningConfig, z_reward};

let z_mod = ZConditioningConfig::new(16, obs_dim).init(&device);
// z_dim=16 (strategy embedding), obs_dim from environment

let conditioned_obs = z_mod.forward(obs, z);
// conditioned_obs: [batch, obs_dim + 64] — ready for policy network

// Pseudo-reward for following target strategy
let reward = z_reward(&observed_stats, &target_z);
// negative L2 distance: closer to target = higher reward
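The pseudo-reward can be sketched as the negative Euclidean distance between observed statistics and the target z. Some implementations use the squared distance instead; which one rl4burn uses is an assumption here, and the function is a hypothetical stand-in for `z_reward`.

```rust
// Illustrative sketch (hypothetical stand-in for z_reward): negative
// Euclidean distance, so matching the target strategy exactly scores 0.
fn z_reward(observed: &[f64], target: &[f64]) -> f64 {
    -observed
        .iter()
        .zip(target)
        .map(|(o, t)| (o - t).powi(2))
        .sum::<f64>()
        .sqrt()
}
```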

What is z?

A low-dimensional vector describing a play style, computed from human replay statistics. Examples:

  • Aggressive: high damage, low farming
  • Defensive: low damage, high survival
  • Rush: high early-game activity

By conditioning on different z vectors, the same policy can exhibit different strategies.

Agent Branching

Clone an agent’s weights to create a new specialized agent. Used by SCC for initializing exploiters from the current main agent rather than from scratch.

API

use rl4burn::algo::multi_agent::self_play::branch_agent;

let exploiter = branch_agent(&main_agent);
let mut exploiter_optim = AdamConfig::new().init();  // fresh optimizer!

Why branch?

Starting exploiters from the main agent’s current weights (instead of the supervised model) gives them a head start. They already know how to play — they just need to specialize in exploiting weaknesses.

The key: the optimizer state must be reset. branch_agent clones the model; you create a fresh optimizer.

MCTS for Drafting

UCT-based Monte Carlo Tree Search for pre-game decisions like unit composition or hero drafting.

API

use rl4burn::algo::planning::mcts::{MctsTree, MctsConfig};

let mut tree = MctsTree::new(MctsConfig {
    n_simulations: 800,
    exploration_constant: 1.41,
    n_actions: 30,  // number of possible picks
});

let visit_counts = tree.search(|action_path| {
    // Evaluate this sequence of picks.
    // Return estimated win rate (0.0 to 1.0).
    evaluate_composition(action_path)
}, &mut rng);

let best_pick = tree.best_action();
let pick_probs = tree.action_probs();

How UCT works

  1. Select: walk down the tree, choosing children by UCT score = mean_value + c * sqrt(ln(parent_visits) / visits)
  2. Expand: add a new child for an unexplored action
  3. Evaluate: call your evaluation function on the action sequence
  4. Backpropagate: update visit counts and values up to the root

After all simulations, pick the most-visited action (not the highest-value — visit count is more robust).
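The selection score and the final choice rule can be sketched as two small functions. Treating unvisited children as having an infinite score (so each action is tried once) is a common convention assumed here. The helpers are hypothetical and not MctsTree's internals.

```rust
// Illustrative sketch (hypothetical helpers, not rl4burn's MctsTree).

/// UCT = mean_value + c * sqrt(ln(parent_visits) / visits).
/// Unvisited children score +inf so every action is tried once.
fn uct_score(total_value: f64, visits: u32, parent_visits: u32, c: f64) -> f64 {
    if visits == 0 {
        return f64::INFINITY;
    }
    let mean = total_value / visits as f64;
    mean + c * ((parent_visits as f64).ln() / visits as f64).sqrt()
}

/// Final decision: the most-visited child, not the highest mean value.
fn best_action(visit_counts: &[u32]) -> usize {
    visit_counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &v)| v)
        .map(|(i, _)| i)
        .unwrap()
}
```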

Beta-VAE Opponent Modeling

ROA-Star’s approach: train an encoder to predict opponent behavior hidden by fog of war, freeze it, and then use its latent embedding as extra context for all agents.

API

use rl4burn::nn::vae::{BetaVae, BetaVaeConfig};

let vae = BetaVaeConfig::new(obs_dim)
    .with_latent_dim(32)
    .with_beta(4.0)
    .init(&device);

// Training
let output = vae.forward(opponent_features);
let loss = vae.loss(opponent_features, &output);

// Inference: extract strategy embedding
let z = vae.strategy_embedding(opponent_features);
// z: [batch, 32] — feed this as extra context to the policy

Why beta-VAE?

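A standard VAE (beta = 1) balances reconstruction against the KL term. Setting beta > 1 strengthens the KL penalty and limits how much information the latent code can carry. This encourages more disentangled and interpretable strategy embeddings. Too large a beta pushes the posterior toward the prior (posterior collapse), so beta needs tuning.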
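The sketch below computes the beta-VAE objective for one example. It assumes a squared-error reconstruction term and a diagonal Gaussian posterior N(μ, exp(log_var)) with a standard-normal prior, so the KL has the closed form shown in the comment. It is a hypothetical helper and not `BetaVae::loss`.

```rust
// Illustrative sketch (hypothetical helper, not rl4burn's BetaVae::loss).
// loss = reconstruction error + beta * KL(N(mu, exp(log_var)) || N(0, I))
fn beta_vae_loss(x: &[f64], recon: &[f64], mu: &[f64], log_var: &[f64], beta: f64) -> f64 {
    let recon_err: f64 = x.iter().zip(recon).map(|(a, b)| (a - b).powi(2)).sum();
    // Closed-form Gaussian KL: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    let kl: f64 = mu
        .iter()
        .zip(log_var)
        .map(|(m, lv)| -0.5 * (1.0 + lv - m * m - lv.exp()))
        .sum();
    recon_err + beta * kl
}
```

A latent whose posterior mean moves one unit from the prior costs 0.5 nats, and beta = 4 quadruples that penalty.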

Scouting reward

The entropy of the opponent model’s predictions can be used as an intrinsic reward: the agent is rewarded for actions that reduce uncertainty about the opponent.

use rl4burn::collect::intrinsic::EntropyReductionReward;

Distributed Training

Abstractions for multi-GPU/multi-machine gradient synchronization.

The GradientSync Trait

use rl4burn::algo::distributed::{GradientSync, ReduceStrategy};

pub trait GradientSync {
    fn all_reduce_f32(&self, values: &[f32], strategy: ReduceStrategy) -> Vec<f32>;
    fn rank(&self) -> usize;
    fn world_size(&self) -> usize;
    fn barrier(&self);
}

Local Development

Use LocalSync for single-machine development. All operations are no-ops.

use rl4burn::algo::distributed::LocalSync;
let sync = LocalSync;
assert_eq!(sync.world_size(), 1);

Custom Implementations

Implement GradientSync for your cluster’s communication library (MPI, NCCL, gRPC, etc.):

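struct MpiSync { /* ... */ }

impl GradientSync for MpiSync {
    fn all_reduce_f32(&self, values: &[f32], strategy: ReduceStrategy) -> Vec<f32> {
        // Call MPI_Allreduce through your MPI bindings and return the reduced values
        todo!("call MPI_Allreduce")
    }
    // ... rank, world_size, and barrier
}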

At scale (SCC: ~1000 envs per agent, HoK: 320 GPUs), ring all-reduce is the standard choice for gradient averaging.
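The sketch below simulates ring all-reduce in-process over equal-length buffers, one per rank. In the reduce-scatter phase, partial sums travel around the ring until each rank holds one fully summed chunk. In the all-gather phase, those chunks travel around again. Each rank sends about 2(n − 1)/n of its buffer in total, regardless of cluster size. A real GradientSync backend would exchange the chunks over MPI, NCCL, or gRPC; this function is hypothetical. To average gradients, divide the result by the world size.

```rust
// Illustrative in-process simulation of ring all-reduce (sum); hypothetical,
// not an rl4burn API. bufs[r] is rank r's buffer; assumes at least one rank.
fn ring_all_reduce(bufs: &mut [Vec<f32>]) {
    let n = bufs.len();
    let len = bufs[0].len();
    // Chunk c covers [lo, hi); chunk sizes may differ by one element.
    let bounds = |c: usize| (c * len / n, (c + 1) * len / n);

    // Phase 1, reduce-scatter: afterwards rank r holds the full sum of chunk (r + 1) % n.
    for step in 0..n - 1 {
        // Snapshot all outgoing chunks first: every rank sends simultaneously.
        let msgs: Vec<(usize, usize, Vec<f32>)> = (0..n)
            .map(|r| {
                let c = (r + n - step) % n;
                let (lo, hi) = bounds(c);
                ((r + 1) % n, c, bufs[r][lo..hi].to_vec())
            })
            .collect();
        for (dst, c, data) in msgs {
            let (lo, _) = bounds(c);
            for (i, v) in data.iter().enumerate() {
                bufs[dst][lo + i] += v; // right neighbour accumulates the partial sum
            }
        }
    }

    // Phase 2, all-gather: circulate each completed chunk around the ring.
    for step in 0..n - 1 {
        let msgs: Vec<(usize, usize, Vec<f32>)> = (0..n)
            .map(|r| {
                let c = (r + 1 + n - step) % n;
                let (lo, hi) = bounds(c);
                ((r + 1) % n, c, bufs[r][lo..hi].to_vec())
            })
            .collect();
        for (dst, c, data) in msgs {
            let (lo, hi) = bounds(c);
            bufs[dst][lo..hi].copy_from_slice(&data); // overwrite with the final sum
        }
    }
}
```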

Cloud GPU Deployment

rl4burn provides the rl4burn-cloud crate with provider-agnostic abstractions for launching training jobs on cloud GPU marketplaces. Currently supported:

  • Vast.ai — peer-to-peer GPU marketplace with competitive pricing
  • RunPod — managed GPU cloud with on-demand and spot pods

Enable the feature in your Cargo.toml:

[dependencies]
rl4burn = { git = "https://github.com/RPP1011/rl4burn", features = ["cloud"] }

The CloudProvider Trait

All providers implement a common interface:

use rl4burn::cloud::{CloudProvider, InstanceRequirements, GpuOffer, Instance};

pub trait CloudProvider {
    fn name(&self) -> &'static str;
    fn search_offers(&self, reqs: &InstanceRequirements) -> CloudResult<Vec<GpuOffer>>;
    fn launch(&self, offer: &GpuOffer) -> CloudResult<Instance>;
    fn status(&self, instance_id: &str) -> CloudResult<Instance>;
    fn stop(&self, instance_id: &str) -> CloudResult<()>;
}

The trait is HTTP-client agnostic — providers produce structured HttpRequest values that you execute with your preferred client (reqwest, ureq, curl, etc.).

Specifying Requirements

use rl4burn::cloud::{InstanceRequirements, GpuType};

let reqs = InstanceRequirements {
    min_gpu_ram_gib: 24.0,
    num_gpus: 1,
    gpu_types: vec![GpuType::Rtx4090, GpuType::A100Pcie],
    min_ram_gib: 32.0,
    min_disk_gib: 100.0,
    docker_image: "pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel".into(),
    max_price_per_hour: 1.00,
    on_start_cmd: Some("apt-get update && apt-get install -y cargo".into()),
    ..Default::default()
};

Vast.ai

use rl4burn::cloud::{VastAiProvider, CloudProvider};

let provider = VastAiProvider::new(std::env::var("VASTAI_API_KEY").unwrap());

// Search for offers
let offers = provider.search_offers(&reqs)?;
println!("Found {} offers, cheapest: ${}/hr",
    offers.len(), offers[0].price_per_hour);

// Launch the cheapest offer
let instance = provider.launch(&offers[0])?;
println!("Instance {} status: {}", instance.instance_id, instance.status);

// Check status
let inst = provider.status(&instance.instance_id)?;
if let Some(ssh) = &inst.ssh_connection {
    println!("Connect: {}", ssh);
}

// Cleanup
provider.stop(&instance.instance_id)?;

RunPod

use rl4burn::cloud::{RunPodProvider, CloudProvider};

let provider = RunPodProvider::new(std::env::var("RUNPOD_API_KEY").unwrap());

let offers = provider.search_offers(&reqs)?;
let instance = provider.launch(&offers[0])?;
println!("Pod {} status: {}", instance.instance_id, instance.status);

provider.stop(&instance.instance_id)?;

Comparing Providers

You can write provider-agnostic code by working with &dyn CloudProvider:

fn cheapest_offer(
    providers: &[&dyn CloudProvider],
    reqs: &InstanceRequirements,
) -> Option<(GpuOffer, &'static str)> {
    let mut best: Option<(GpuOffer, &str)> = None;
    for provider in providers {
        if let Ok(offers) = provider.search_offers(reqs) {
            for offer in offers {
                if best.as_ref().map_or(true, |(b, _)| offer.price_per_hour < b.price_per_hour) {
                    best = Some((offer, provider.name()));
                }
            }
        }
    }
    best
}

HTTP Client Integration

The providers are designed to be dependency-free. Each method can either:

  1. Execute requests directly — register an HTTP function with .with_http(fn).
  2. Return request descriptions — use .search_request(), .launch_request(), etc. to get HttpRequest structs you execute yourself.

Example with a custom HTTP function:

fn my_http(req: &rl4burn::cloud::vastai::HttpRequest) -> Result<String, String> {
    // Use reqwest, ureq, curl, etc.
    todo!()
}

let provider = VastAiProvider::new("key").with_http(my_http);

Supported GPU Types

The GpuType enum covers common training GPUs:

| Variant | GPU |
|---|---|
| Rtx3090 | NVIDIA RTX 3090 (24 GB) |
| Rtx4090 | NVIDIA RTX 4090 (24 GB) |
| RtxA4000 | NVIDIA RTX A4000 (16 GB) |
| RtxA5000 | NVIDIA RTX A5000 (24 GB) |
| RtxA6000 | NVIDIA RTX A6000 (48 GB) |
| A100Pcie | NVIDIA A100 PCIe (40/80 GB) |
| A100Sxm | NVIDIA A100 SXM (80 GB) |
| H100Pcie | NVIDIA H100 PCIe (80 GB) |
| H100Sxm | NVIDIA H100 SXM (80 GB) |
| L40 | NVIDIA L40 (48 GB) |
| L40s | NVIDIA L40S (48 GB) |

AlphaStar & ROA-Star

The 30-second version

AlphaStar (DeepMind, 2019) was the first AI to beat a top professional StarCraft II player. ROA-Star (Tencent, NeurIPS 2023) achieves the same level with 4x less compute by adding opponent modeling and smarter exploiter training.

Both are massive RL systems, but their core ideas decompose into modular building blocks — most of which are in rl4burn.

What makes StarCraft II hard for RL?

Imagine playing chess, except:

  • You can only see part of the board (fog of war)
  • Both players move simultaneously
  • You control 200 pieces at once
  • Each piece has 10+ possible actions
  • Games last 20+ minutes (thousands of decisions)

Standard RL algorithms break under this complexity. AlphaStar’s solution: decompose the problem.

Key ideas (and where they are in rl4burn)

1. Auto-regressive action space

Instead of choosing from millions of possible joint actions, AlphaStar samples one decision at a time:

action_type → delay → queue → selected_units → target_unit → target_location

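Each head is conditioned on the previous samples. CompositeDistribution provides the per-head machinery (splitting logits, summing log-probs and entropies). Full conditioning comes from calling the model once per head and feeding the earlier samples back in.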

use rl4burn::CompositeDistribution;

let dist = CompositeDistribution::from_heads(
    &["action_type", "target", "ability"],
    &[11, 30, 8],
);

See Auto-Regressive Action Distributions for details.

2. V-trace for off-policy correction

With thousands of parallel actors, the behavior policy is always slightly stale. V-trace corrects for this. Already in rl4burn as vtrace_targets.

See V-trace.

3. UPGO (self-imitation learning)

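Only learn from the parts of a trajectory that went better than expected. UPGO follows the observed return while the next action beat the value baseline. Where it did not, UPGO falls back to the value estimate, so poor later actions do not drag down earlier good ones.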

use rl4burn::upgo_advantages;
let advantages = upgo_advantages(&rewards, &values, &dones, last_value, gamma);
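The sketch below follows AlphaStar's UPGO definition, using the same argument order as the call above. The return bootstraps through step t+1 only when that step's one-step estimate r + γV(next) is at least V(s_{t+1}); otherwise it uses V(s_{t+1}). The handling of `dones` and `last_value` is an assumption about rl4burn's semantics: `dones[t]` is taken to mean the episode ended after step t. The function is a hypothetical re-implementation, not the library code.

```rust
// Illustrative sketch of AlphaStar-style UPGO (hypothetical re-implementation).
// values[t] = V(s_t); last_value = V(s_T); dones[t]: episode ended after step t.
fn upgo_advantages(rewards: &[f64], values: &[f64], dones: &[bool], last_value: f64, gamma: f64) -> Vec<f64> {
    let n = rewards.len();
    let value_at = |t: usize| if t < n { values[t] } else { last_value };
    let mut adv = vec![0.0; n];
    let mut g_next = last_value; // UPGO return of step t + 1
    for t in (0..n).rev() {
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        // Follow the trajectory only if the next action looked at least as good as the baseline.
        let next_target = if t + 1 < n {
            let next_not_done = if dones[t + 1] { 0.0 } else { 1.0 };
            let q_next = rewards[t + 1] + gamma * next_not_done * value_at(t + 2);
            if q_next >= values[t + 1] { g_next } else { values[t + 1] }
        } else {
            last_value
        };
        let g = rewards[t] + gamma * not_done * next_target;
        adv[t] = g - values[t];
        g_next = g;
    }
    adv
}
```

In the second test case below, the final action is worse than its baseline. Step 0 therefore bootstraps from V(s₁) = 0.5 instead of inheriting the −1 return.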

See UPGO.

4. League training with PFSP

Instead of just self-play, AlphaStar trains a league of agents:

  • Main agent: plays against everyone
  • Main exploiter: specializes in beating the main agent
  • League exploiters: find weaknesses across the entire pool

Opponents are sampled using PFSP — harder opponents (lower win rate) get sampled more often.

use rl4burn::{League, AgentRole, LeagueAgentConfig, PfspMatchmaking};

See League Training and PFSP Matchmaking.

5. ROA-Star’s additions

ROA-Star adds two ideas:

  • Beta-VAE opponent modeling: A frozen encoder predicts what the opponent is doing behind fog of war. The latent embedding is fed to all agents as extra context. See Beta-VAE Opponent Modeling.
  • Goal-conditioned exploiters: Exploiters are conditioned on strategy descriptors z, letting them specialize rapidly. See Goal-Conditioned RL.

Further reading

SCC (StarCraft Commander)

The 30-second version

SCC (inspir.ai, ICML 2021) reaches GrandMaster in StarCraft II with 10x less compute than AlphaStar. Its trick: a more efficient architecture (49M vs 139M parameters) and smarter training (agent branching instead of training exploiters from scratch).

Key innovations

Group Transformer

Instead of processing all game units with one big attention layer, SCC groups them:

  • Intra-group self-attention: ally units attend to each other, enemy units attend to each other
  • Inter-group cross-attention: ally representations attend to enemy representations

This is more efficient for games with natural groupings (teams, unit types).

rl4burn provides the building blocks: TransformerEncoder for self-attention, MultiHeadAttention for cross-attention. See Transformer Encoder and Attention Mechanisms.

Attention-based pooling

Variable numbers of units get aggregated into fixed-size vectors using learned query vectors. Better than mean-pooling because the model learns which units matter most.

use rl4burn::{AttentionPool, AttentionPoolConfig};

let pool = AttentionPoolConfig::new(128, 4, 2).init(&device);
// 128-dim entity embeddings, 4 learned queries, 2 attention heads
// Output: [batch, 4 * 128] = [batch, 512]

See Attention Mechanisms.

FiLM conditioning

The target position head is conditioned on the action type using FiLM: output = gamma(ctx) * input + beta(ctx). (rl4burn’s Film parameterizes the scale as 1 + gamma.) This lets the same network produce different spatial distributions depending on whether you’re attacking, moving, or casting.

use rl4burn::{Film, FilmConfig};
let film = FilmConfig::new(action_embed_dim, spatial_feature_dim).init(&device);

See FiLM Conditioning.

Agent branching

When creating a new exploiter, SCC clones the current main agent’s weights instead of starting from the supervised model. The optimizer state is reset. This gives exploiters a head start.

use rl4burn::algo::multi_agent::self_play::branch_agent;
let exploiter = branch_agent(&main_agent);
// Create a fresh optimizer for the exploiter

See Agent Branching.

Pointer networks

For selecting “which of my units should do this?”, SCC uses pointer networks — attention over encoder outputs producing a selection distribution.

use rl4burn::{PointerNet, PointerNetConfig};
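A pointer network's output is a softmax over the encoder outputs themselves, so its size follows the number of units. The plain-Rust sketch below assumes one query vector, scaled dot-product scores, and a validity mask (e.g. dead or unselectable units). `pointer_distribution` is a hypothetical name, not the `PointerNet` API.

```rust
/// Returns one probability per unit; units with `valid[i] == false` get probability 0.
/// Assumes at least one unit is valid (otherwise every score is -inf and the result is NaN).
fn pointer_distribution(query: &[f32], unit_keys: &[Vec<f32>], valid: &[bool]) -> Vec<f32> {
    let scale = 1.0 / (query.len() as f32).sqrt();
    // Score each unit against the query; masked units get -inf so softmax gives them zero mass.
    let scores: Vec<f32> = unit_keys
        .iter()
        .zip(valid)
        .map(|(k, &ok)| {
            if ok {
                k.iter().zip(query).map(|(a, b)| a * b).sum::<f32>() * scale
            } else {
                f32::NEG_INFINITY
            }
        })
        .collect();
    // Numerically stable softmax over the units.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}
```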

The architecture in one sentence

Group Transformer encodes entities → attention pooling aggregates them → a residual LSTM carries memory across time steps → FiLM-conditioned heads produce actions → pointer networks select units.

Further reading

JueWu & Honor of Kings

The 30-second version

JueWu (Tencent, NeurIPS 2020) is the first AI to beat top professional players in Honor of Kings, a 5v5 MOBA. The key insight: macro strategy (which lane to go to, when to fight) emerges from micro rewards — you don’t need a separate strategic layer.

Why MOBAs are different from StarCraft

In a MOBA, you control 1 hero (or 5 with shared weights). The action space is simpler than StarCraft's, but the strategic depth comes from teamwork, timing, and map control. Games are shorter (~15 minutes), but the true reward (win/lose) is extremely sparse.

Key ideas

Multi-head value decomposition

Instead of one value function, JueWu uses 5:

  • Farming (gold/XP)
  • KDA (kills/deaths/assists)
  • Damage dealt
  • Tower pushing
  • Win/lose

Each head learns independently with its own discount factor. The weighted sum of the per-head advantages drives the policy update.

use rl4burn::{MultiHeadValueConfig, multi_head_gae};

let config = MultiHeadValueConfig::new(5, 0.99, 0.95)
    .with_weights(vec![0.1, 0.2, 0.2, 0.2, 0.3]);  // win/lose weighted highest
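The sketch below shows the arithmetic behind per-head advantages in plain Rust: standard GAE per reward stream, each head with its own gamma, then a weighted sum. It is an illustration under stated assumptions, not `multi_head_gae` itself. The names `gae` and `combined_advantage`, and the choice of one shared lambda, are assumptions.

```rust
/// Generalized Advantage Estimation for one reward stream.
/// `values` has one more entry than `rewards` (bootstrap value at the end);
/// `dones[t]` is true if the episode ended after step t.
fn gae(rewards: &[f32], values: &[f32], dones: &[bool], gamma: f32, lambda: f32) -> Vec<f32> {
    let mut adv = vec![0.0; rewards.len()];
    let mut last = 0.0;
    for t in (0..rewards.len()).rev() {
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let delta = rewards[t] + gamma * values[t + 1] * not_done - values[t];
        last = delta + gamma * lambda * not_done * last;
        adv[t] = last;
    }
    adv
}

/// One (rewards, values) pair per head, each with its own discount; the policy
/// receives the weighted sum of the per-head advantages.
fn combined_advantage(
    heads: &[(Vec<f32>, Vec<f32>)],
    dones: &[bool],
    gammas: &[f32],
    weights: &[f32],
    lambda: f32,
) -> Vec<f32> {
    let mut total = vec![0.0; dones.len()];
    for ((rewards, values), (&gamma, &w)) in heads.iter().zip(gammas.iter().zip(weights)) {
        let adv = gae(rewards, values, dones, gamma, lambda);
        for (t, a) in adv.iter().enumerate() {
            total[t] += w * a;
        }
    }
    total
}
```

Keeping the heads separate until the final sum lets short-horizon components (e.g. farming) use a smaller gamma while the win/lose head looks far ahead.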

This helps with credit assignment: the agent knows why it’s doing well, not just that it’s doing well.

See Multi-Head Value Decomposition.

Dual-clip PPO

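Standard PPO clips the policy ratio to prevent too-large updates. However, when the advantage is negative and the ratio r is very large, min(r * A, clip(r) * A) picks r * A, which is unbounded below. Dual-clip adds a second constraint: when the advantage is negative, the objective can't go below c * advantage (c=3). This prevents catastrophic updates in distributed training, where trajectories come from slightly stale policies and large ratios are more common.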

let config = PpoConfig {
    dual_clip_coef: Some(3.0),
    ..Default::default()
};
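Per sample, the dual-clip objective is a small extension of the PPO surrogate. This plain-Rust sketch computes the quantity to be maximized; `dual_clip_objective` is a hypothetical name, and eps = 0.2 in the usage below is an assumed clip range.

```rust
/// Per-sample dual-clip PPO objective (to be maximized).
/// `ratio` = pi_new(a|s) / pi_old(a|s); `eps` is the usual PPO clip range; `c` > 1.
fn dual_clip_objective(ratio: f32, advantage: f32, eps: f32, c: f32) -> f32 {
    let unclipped = ratio * advantage;
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * advantage;
    // Standard PPO clipped surrogate.
    let standard = unclipped.min(clipped);
    if advantage < 0.0 {
        // Second clip: the objective never drops below c * A.
        standard.max(c * advantage)
    } else {
        standard
    }
}
```

For example, with ratio 10 and advantage -1, standard PPO yields -10, while dual-clip with c = 3 yields -3. The training loss is the negated batch mean of this objective.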

See Dual-Clip PPO.

Supervised pre-training matters (a lot)

JueWu-SL (a separate paper) showed that behavioral cloning from top human players provides 64% of the final RL performance. RL then refines and exceeds human play.

use rl4burn::bc_loss_discrete;
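Behavioral cloning for discrete actions is cross-entropy against the demonstrated action. The plain-Rust sketch below shows that computation with a stable log-softmax. The names `log_softmax` and `bc_cross_entropy` are hypothetical, and the sketch does not reproduce `bc_loss_discrete`'s signature.

```rust
/// Log-softmax of one row of logits, using the log-sum-exp trick for stability.
fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let lse = max + logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln();
    logits.iter().map(|l| l - lse).collect()
}

/// Mean negative log-likelihood of the expert's actions under the policy's logits.
fn bc_cross_entropy(batch_logits: &[Vec<f32>], expert_actions: &[usize]) -> f32 {
    let total: f32 = batch_logits
        .iter()
        .zip(expert_actions)
        .map(|(logits, &a)| -log_softmax(logits)[a])
        .sum();
    total / batch_logits.len() as f32
}
```

A uniform policy over 4 actions scores ln 4 ≈ 1.386 on any demonstration; the loss approaches 0 as the policy concentrates on the expert's choice.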

See Behavioral Cloning.

Curriculum Self-Play Learning (CSPL)

Training 40+ heroes at once doesn’t converge. CSPL breaks it into 3 phases:

  1. Specialist training: Train small models on fixed team compositions
  2. Distillation: Merge all specialists into one big model
  3. Generalization: Continue RL with random compositions

Without CSPL, training fails after 480+ hours. With it, convergence in ~264 hours.

use rl4burn::{CsplPipeline, CsplConfig, CsplPhase};

See CSPL.

Privileged critic

During training, the value function sees everything — including enemy positions behind fog of war. The policy only sees what the player would see. This dramatically improves value estimation.

use rl4burn::algo::privileged_critic::PrivilegedActorCritic;
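The essential property is the asymmetry between the two input pipelines. The sketch below shows it in plain Rust, assuming the observation can be split into visible and hidden parts. `Observation`, `actor_input` and `critic_input` are hypothetical names, not the `PrivilegedActorCritic` API.

```rust
/// An observation split into what the player sees and what fog of war hides.
struct Observation {
    visible: Vec<f32>,
    // e.g. enemy positions behind fog of war; only available during training
    hidden: Vec<f32>,
}

/// Actor input: visible features only, so the policy runs unchanged at deployment time.
fn actor_input(obs: &Observation) -> Vec<f32> {
    obs.visible.clone()
}

/// Critic input: visible plus hidden features. The critic only produces training
/// targets (advantages), so it may use information the policy never sees.
fn critic_input(obs: &Observation) -> Vec<f32> {
    let mut x = obs.visible.clone();
    x.extend_from_slice(&obs.hidden);
    x
}
```

Because the critic is discarded after training, nothing it sees can leak into the deployed policy.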

See Privileged Critic.

The architecture in one sentence

Shared-weight policy across 5 heroes → LSTM for temporal memory → multi-head value for credit assignment → dual-clip PPO for stability → CSPL for scaling to many heroes.

Further reading

DreamerV3

The 30-second version

DreamerV3 (Hafner et al., Nature 2025) learns a model of the world and then trains a policy entirely inside imagined trajectories. It works across wildly different domains (Atari, robotic control, Minecraft) with zero hyperparameter tuning.

The secret: symlog transforms that make gradient magnitudes independent of reward scale.

What’s a world model?

Most RL algorithms learn by trial and error in the real environment. World models flip this:

  1. Play the game a bit, store transitions
  2. Train a model to predict what happens next (the “world model”)
  3. Imagine thousands of trajectories inside the model
  4. Train the policy on imagined data

Step 3 needs no environment interaction, only compute. This makes DreamerV3 extremely sample-efficient.

The RSSM (how the world model works)

The RSSM (Recurrent State-Space Model) has 5 networks:

| Component | Input | Output | Purpose |
|---|---|---|---|
| Sequence model (GRU) | h_{t-1}, z_{t-1}, a_{t-1} | h_t | Deterministic memory |
| Encoder (posterior) | h_t, observation | z_t | What actually happened |
| Dynamics (prior) | h_t | z_hat_t | What the model predicts |
| Reward predictor | h_t, z_t | reward | Expected reward |
| Continue predictor | h_t, z_t | continue prob | Is the episode over? |

The state is (h_t, z_t) where h is a deterministic GRU hidden state and z is a stochastic categorical variable (32 groups x 32 classes = 1024 dims).
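Concretely, z_t is produced from 32 x 32 logits: each of the 32 groups picks one class, giving a 1024-dim vector with exactly 32 ones. The plain-Rust sketch below uses inverse-CDF sampling with caller-supplied uniform numbers in place of an RNG. It omits the straight-through gradient and any uniform-mixing used in practice. `sample_stochastic_state` is a hypothetical name.

```rust
const GROUPS: usize = 32;
const CLASSES: usize = 32;

/// Turns GROUPS * CLASSES logits into a flat one-hot vector of the same size,
/// with exactly one active class per group. `uniforms` holds one number in [0, 1)
/// per group (in practice these would come from an RNG).
fn sample_stochastic_state(logits: &[f32], uniforms: &[f32]) -> Vec<f32> {
    assert_eq!(logits.len(), GROUPS * CLASSES);
    let mut z = vec![0.0; GROUPS * CLASSES];
    for g in 0..GROUPS {
        let row = &logits[g * CLASSES..(g + 1) * CLASSES];
        // Softmax over this group's classes.
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|l| (l - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        // Inverse-CDF sampling; falls back to the last class to absorb rounding error.
        let mut chosen = CLASSES - 1;
        let mut cdf = 0.0;
        for (k, e) in exps.iter().enumerate() {
            cdf += e / total;
            if uniforms[g] < cdf {
                chosen = k;
                break;
            }
        }
        z[g * CLASSES + chosen] = 1.0;
    }
    z
}
```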

use rl4burn::{Rssm, RssmConfig};

let rssm = RssmConfig::new(obs_dim, action_dim).init(&device);
let state = rssm.initial_state(batch_size, &device);

// Training: use observations
let (next_state, post_logits, prior_logits) = rssm.obs_step(&state, action, obs);

// Imagination: no observations needed
let next_state = rssm.imagine_step(&state, action);

See RSSM.

Symlog: the key to fixed hyperparameters

The biggest problem with RL across domains is reward scale. Atari rewards are 0-1000. Robotic rewards are -1 to 0. Without normalization, you need different learning rates for each.

DreamerV3 solves this with symlog: symlog(x) = sign(x) * ln(|x| + 1). This compresses large values and keeps small values linear. Combined with twohot encoding (distributional predictions), gradient magnitudes become independent of value scale.

use rl4burn::{symlog, symexp, TwohotEncoder};
use burn::tensor::activation::softmax;

let encoder = TwohotEncoder::new();  // 255 bins, [-20, 20] symlog space
let targets = encoder.encode(values.clone(), &device);   // [batch, 255]
let loss = encoder.loss(logits.clone(), values, &device); // cross-entropy
let decoded = encoder.decode(softmax(logits, 1), &device);  // back to scalars (clones above: Burn tensors move by value)
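To see what twohot does numerically, here is a scalar plain-Rust sketch, assuming 255 bins evenly spaced over [-20, 20] in symlog space as the comment above states. Each target's weight is split between the two bins that bracket symlog(x), so the weighted mean of the bin centers recovers symlog(x) exactly. `twohot_encode`, `twohot_decode` and `bin_value` are hypothetical names, not `TwohotEncoder` methods.

```rust
fn symlog(x: f32) -> f32 {
    x.signum() * x.abs().ln_1p()
}

fn symexp(y: f32) -> f32 {
    y.signum() * y.abs().exp_m1()
}

const NUM_BINS: usize = 255;
const LOW: f32 = -20.0;
const HIGH: f32 = 20.0;

/// Center of bin `i`, evenly spaced in symlog space.
fn bin_value(i: usize) -> f32 {
    LOW + (HIGH - LOW) * i as f32 / (NUM_BINS - 1) as f32
}

/// Puts weight on the two bins bracketing symlog(x), split by linear interpolation.
fn twohot_encode(x: f32) -> Vec<f32> {
    let y = symlog(x).clamp(LOW, HIGH);
    let pos = (y - LOW) / (HIGH - LOW) * (NUM_BINS - 1) as f32; // fractional bin index
    let lo = (pos.floor() as usize).min(NUM_BINS - 2);
    let frac = pos - lo as f32;
    let mut target = vec![0.0; NUM_BINS];
    target[lo] = 1.0 - frac;
    target[lo + 1] = frac;
    target
}

/// Decodes a distribution over bins (e.g. softmax of logits) back to a scalar.
fn twohot_decode(probs: &[f32]) -> f32 {
    let y: f32 = probs.iter().enumerate().map(|(i, p)| p * bin_value(i)).sum();
    symexp(y)
}
```

Because the loss is a cross-entropy over fixed bins, its gradient magnitude does not grow with the size of the target value, which is the property the text attributes to symlog plus twohot.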

See Symlog and Twohot Encoding.

KL balancing: training the world model

The RSSM is trained with two KL losses:

  • Dynamics loss: Make the prior match the posterior (train the predictor)
  • Representation loss: Make the posterior predictable (don’t be too complex)

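The dynamics loss stops gradients through the posterior, and the representation loss stops gradients through the prior. Both are clipped by a “free bits” threshold: KL below 1 nat is ignored.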

use rl4burn::{kl_balanced_loss, KlBalanceConfig};

let config = KlBalanceConfig {
    dyn_weight: 0.5,
    rep_weight: 0.1,
    free_bits: 1.0,
};
let loss = kl_balanced_loss(posterior_logits, prior_logits, &config);
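Numerically, the two terms share the same KL value. The plain-Rust sketch below computes that value with free bits, summing the KL over the categorical groups. Plain f32 cannot express stop-gradients, so the sketch only shows the forward value. The names `categorical_kl` and `kl_balanced_value` are hypothetical, not the `kl_balanced_loss` implementation.

```rust
/// KL(p || q) between two categorical distributions given as probabilities.
fn categorical_kl(p: &[f32], q: &[f32]) -> f32 {
    p.iter()
        .zip(q)
        .filter(|(pi, _)| **pi > 0.0)
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum()
}

/// Forward value of the KL-balanced loss for one latent state made of several groups.
/// In an autodiff framework the two terms differ only in where the stop-gradient sits
/// (dynamics: sg(posterior), representation: sg(prior)); plain f32 cannot show that.
fn kl_balanced_value(
    posterior: &[Vec<f32>],
    prior: &[Vec<f32>],
    dyn_weight: f32,
    rep_weight: f32,
    free_bits: f32,
) -> f32 {
    let kl: f32 = posterior.iter().zip(prior).map(|(p, q)| categorical_kl(p, q)).sum();
    // Free bits: below the threshold the loss is constant, so it produces no gradient.
    let clipped = kl.max(free_bits);
    dyn_weight * clipped + rep_weight * clipped
}
```

With the weights above, a perfectly matched prior and posterior still report 0.6 (= 0.5 + 0.1 times the 1-nat floor), but contribute zero gradient.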

See KL Balancing with Free Bits.

Imagination rollouts

Once the world model is trained, generate trajectories purely in latent space:

use rl4burn::algo::planning::imagination::{imagine_rollout, lambda_returns};

let trajectory = imagine_rollout(&rssm, initial_states, |h, z| actor(h, z), 15);
// trajectory.states: 16 states (initial + 15 steps)
// trajectory.reward_logits: 15 predicted reward distributions

Compute lambda-returns on the imagined rewards, then train actor and critic on these imagined trajectories. The world model parameters are frozen during actor-critic training.
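The lambda-return recursion is short enough to write out. This plain-Rust sketch uses the DreamerV3-style form R_t = r_t + gamma * c_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}), bootstrapped with R_H = v_H, where c_t is the predicted continue probability. The function name and argument layout are assumptions, not the `lambda_returns` signature.

```rust
/// `rewards[t]` and `continues[t]` for t in 0..H; `values` has H + 1 entries,
/// the last one bootstrapping the tail of the imagined trajectory.
fn lambda_returns(
    rewards: &[f32],
    continues: &[f32],
    values: &[f32],
    gamma: f32,
    lambda: f32,
) -> Vec<f32> {
    let h = rewards.len();
    assert_eq!(values.len(), h + 1);
    let mut returns = vec![0.0; h];
    let mut next = values[h]; // R_H = v_H
    for t in (0..h).rev() {
        // Blend the one-step bootstrap v_{t+1} with the longer return R_{t+1}.
        next = rewards[t] + gamma * continues[t] * ((1.0 - lambda) * values[t + 1] + lambda * next);
        returns[t] = next;
    }
    returns
}
```

lambda = 0 gives one-step TD targets and lambda = 1 gives discounted returns bootstrapped only at the end of the horizon; a predicted continue of 0 cuts the return at that step.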

See Imagination Rollouts.

Sequence replay buffer

DreamerV3 samples contiguous sequences (T=64) from a FIFO buffer, never crossing episode boundaries.

use rl4burn::{SequenceReplayBuffer, SequenceStep};
let mut buffer = SequenceReplayBuffer::new(1_000_000, 64);
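The boundary rule can be stated simply: a window is valid if no step after its first one begins a new episode. The plain-Rust sketch below implements a FIFO buffer with that rule. `Step`, `SeqBuffer` and the `is_first` flag are hypothetical, not the `SequenceReplayBuffer`/`SequenceStep` API. Sampling takes an index in place of an RNG, and `valid_starts` is O(N * T), which is fine for illustration only.

```rust
use std::collections::VecDeque;

#[derive(Clone, Debug, PartialEq)]
struct Step {
    obs: Vec<f32>,
    action: usize,
    reward: f32,
    is_first: bool, // true on the first step of an episode
}

struct SeqBuffer {
    steps: VecDeque<Step>,
    capacity: usize,
    seq_len: usize,
}

impl SeqBuffer {
    fn new(capacity: usize, seq_len: usize) -> Self {
        Self { steps: VecDeque::with_capacity(capacity), capacity, seq_len }
    }

    /// FIFO: once full, the oldest step is evicted.
    fn push(&mut self, step: Step) {
        if self.steps.len() == self.capacity {
            self.steps.pop_front();
        }
        self.steps.push_back(step);
    }

    /// Start indices whose `seq_len`-step window stays inside one episode.
    fn valid_starts(&self) -> Vec<usize> {
        if self.steps.len() < self.seq_len {
            return Vec::new();
        }
        (0..=self.steps.len() - self.seq_len)
            .filter(|&s| (s + 1..s + self.seq_len).all(|i| !self.steps[i].is_first))
            .collect()
    }

    /// Returns the `k`-th valid window (modulo the count); `k` would normally come from an RNG.
    fn sample(&self, k: usize) -> Option<Vec<Step>> {
        let starts = self.valid_starts();
        if starts.is_empty() {
            return None;
        }
        let s = starts[k % starts.len()];
        Some(self.steps.range(s..s + self.seq_len).cloned().collect())
    }
}
```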

See Sequence Replay Buffer.

Percentile return normalization

Instead of per-minibatch normalization, DreamerV3 tracks the 5th-95th percentile range of returns with an EMA and divides by max(1, range). The floor of 1 prevents amplifying noise.

use rl4burn::PercentileNormalizer;
let mut normalizer = PercentileNormalizer::new();
normalizer.update(&returns);
let normalized = normalizer.normalize(&advantages);
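The normalizer's arithmetic fits in a few lines. This plain-Rust sketch computes the 5th and 95th percentiles by sorting with linear interpolation, tracks their difference with an EMA, and divides by max(1, range). Seeding the EMA with the first observed range is an assumption of this sketch. `percentile` and `PercentileNormSketch` are hypothetical names, not the `PercentileNormalizer` API.

```rust
/// Percentile of a non-empty, NaN-free sample, interpolating between order statistics.
fn percentile(values: &[f32], q: f32) -> f32 {
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let pos = q / 100.0 * (sorted.len() - 1) as f32;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    sorted[lo] + (sorted[hi] - sorted[lo]) * (pos - lo as f32)
}

struct PercentileNormSketch {
    scale: Option<f32>, // EMA of the 5th-95th percentile range
    decay: f32,
}

impl PercentileNormSketch {
    fn new(decay: f32) -> Self {
        Self { scale: None, decay }
    }

    fn update(&mut self, returns: &[f32]) {
        let range = percentile(returns, 95.0) - percentile(returns, 5.0);
        self.scale = Some(match self.scale {
            None => range, // assumption: seed the EMA with the first range
            Some(s) => self.decay * s + (1.0 - self.decay) * range,
        });
    }

    /// Divides by max(1, range): large ranges shrink advantages, small ranges leave them alone.
    fn normalize(&self, advantages: &[f32]) -> Vec<f32> {
        let denom = self.scale.unwrap_or(1.0).max(1.0);
        advantages.iter().map(|a| a / denom).collect()
    }
}
```

Percentiles rather than min/max make the scale robust to a few outlier returns, and the max(1, ·) floor means sparse-reward tasks with tiny returns are not amplified.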

See Percentile Return Normalization.

Further reading

Working with Burn’s Autodiff

We discovered several Burn 0.20 behaviors that affect RL implementations. This chapter documents them so you don’t hit the same issues.

1. Custom parameter initialization must use from_data + load_record

Problem: Param::initialized(id, tensor) creates parameters that are invisible to Burn’s autodiff. The optimizer will silently produce zero updates — the model trains but weights never change.

Cause: Tensors created via Tensor::from_data(TensorData::new(...), device) are leaf nodes without gradient tracking. Wrapping them in Param::initialized doesn’t register them.

Fix: Use Burn’s record system:

// WRONG: gradients won't flow
let weight = Tensor::from_data(my_data, &device);
let param = Param::initialized(old_param.id.clone(), weight);

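// RIGHT: use Param::from_data + load_record
use burn::module::{Module, Param};  // Module provides load_record
use burn::nn::{LinearConfig, LinearRecord};

// Burn's Linear stores its weight as [d_in, d_out] and its bias as [d_out]
let record = LinearRecord {
    weight: Param::from_data(weight_data, &device),
    bias: Some(Param::from_data(bias_data, &device)),
};
let linear = LinearConfig::new(d_in, d_out).init(&device).load_record(record);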

orthogonal_linear handles this correctly. If you implement custom initialization, follow the same pattern.

Alternative: If you must use Param::initialized, call .set_require_grad(true) on the result. But from_data + load_record is the canonical approach.

2. Gradient clipping is per-parameter, not global

Problem: GradientClippingConfig::Norm(0.5) on the optimizer clips each parameter tensor’s gradient independently. PyTorch’s clip_grad_norm_ clips the global norm across all parameters.

Impact: With per-parameter clipping, the actor’s small gradients are unaffected while the critic’s large gradients are clipped, which silently shifts the balance between actor and critic updates. With global clipping, all gradients are scaled by the same factor, which is the standard behavior for PPO.

Fix: Use clip::clip_grad_norm instead of Burn’s optimizer clipping:

use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use rl4burn::clip::clip_grad_norm;

// Don't configure clipping on the optimizer
let mut optim = AdamConfig::new().init();

// Clip manually between backward and step
let grads = loss.backward();
let mut grads = GradientsParams::from_grads(grads, &model);
grads = clip_grad_norm(&model, grads, 0.5);
model = optim.step(lr, model, grads);

PPO’s ppo_update does this automatically when max_grad_norm > 0.
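The difference between the two clipping modes shows up clearly on a toy actor/critic pair. This plain-Rust sketch treats each parameter tensor's gradient as a `Vec<f32>`. The global version follows the usual recipe of one norm and one scale factor for everything, with a small epsilon as in PyTorch's `clip_grad_norm_`. The function names are hypothetical.

```rust
/// Global clipping: one norm over all parameters, one shared scale factor.
/// Returns the norm measured before clipping.
fn clip_global(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let total_norm = grads.iter().flatten().map(|g| g * g).sum::<f32>().sqrt();
    if total_norm > max_norm {
        let scale = max_norm / (total_norm + 1e-6);
        for g in grads.iter_mut().flatten() {
            *g *= scale;
        }
    }
    total_norm
}

/// Per-parameter clipping: each tensor is clipped against its own norm.
fn clip_per_param(grads: &mut [Vec<f32>], max_norm: f32) {
    for tensor in grads.iter_mut() {
        let norm = tensor.iter().map(|g| g * g).sum::<f32>().sqrt();
        if norm > max_norm {
            let scale = max_norm / (norm + 1e-6);
            for g in tensor.iter_mut() {
                *g *= scale;
            }
        }
    }
}
```

With an actor gradient of norm 0.05 and a critic gradient of norm 5, global clipping to 0.5 keeps their 1:100 ratio. Per-parameter clipping leaves the actor untouched and shrinks only the critic, changing the ratio to 1:10.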

3. mask_where may not propagate gradients through the source argument

Problem: tensor.mask_where(mask, source) selects from source where mask is true, otherwise keeps self. Burn’s autodiff may not propagate gradients through the source argument.

Impact: If you use mask_where to compute max(a, b) by selecting the larger value, and the mask happens to select from source for all elements, the gradient can be zero.

Fix: Use arithmetic alternatives:

use burn::tensor::activation::relu;

// Instead of mask_where for max:
// max(a, b) = a + relu(b - a)
// Burn tensors are moved by value, so clone before reusing a or b
let max_val = a.clone() + relu(b.clone() - a.clone());

// Instead of mask_where for min:
// min(a, b) = b - relu(b - a)
let min_val = b.clone() - relu(b - a);

These have correct gradients everywhere (except at the exact boundary where a == b, which has measure zero in practice).

General advice

  • Always test gradient flow. After implementing a custom model or loss, verify that optim.step actually changes the model’s weights. A simple test:
let before: Vec<f32> = model.weight.val().into_data().to_vec().unwrap();
// ... forward, loss, backward, step ...
let after: Vec<f32> = model.weight.val().into_data().to_vec().unwrap();
assert!(before != after, "weights should change");
  • Use model.valid() for inference. This strips the autodiff layer, avoiding unnecessary computation graph construction during rollout collection.

  • Extract tensor data to break the computation graph. When you need to use a value as a constant (e.g., target Q-values), call .into_data().to_vec() and create a fresh tensor from the result. This is equivalent to PyTorch’s .detach(). Burn’s own Tensor::detach() also removes a tensor from the autodiff graph, without copying the data to host memory.