Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Replay Buffer

ReplayBuffer<S, R> stores transitions for off-policy algorithms like DQN. It’s generic over the sample type and a deterministic RNG for reproducible sampling.

API

use rand::SeedableRng;

let mut buffer = ReplayBuffer::new(10_000, rand::rngs::SmallRng::seed_from_u64(42));

buffer.extend(transitions);          // add samples
let batch = buffer.sample(64);       // random references
let batch = buffer.sample_cloned(64); // random clones (for owned data)
let groups = buffer.group_by(|t| t.episode_id); // group by key

Eviction

When the buffer exceeds capacity, the oldest samples are dropped first (FIFO).

With DQN

use rl4burn::dqn::Transition;
use rl4burn::replay::ReplayBuffer;

let mut buffer = ReplayBuffer::new(10_000, rand::rngs::SmallRng::seed_from_u64(42));

// Store transitions
buffer.extend(std::iter::once(Transition {
    obs: obs.clone(),
    action: action as i32,
    reward: result.reward,
    next_obs: result.observation.clone(),
    done: result.done(),
}));

// dqn_update samples from the buffer internally
(online, stats) = dqn_update(online, &target, &mut optim, &mut buffer, &config, &device);

Trajectory grouping

The group_by method groups sample indices by an arbitrary key function. Useful for V-trace rescoring where you need to process entire trajectories:

let groups = buffer.group_by(|sample| sample.trajectory_id);
for (traj_id, indices) in &groups {
    let trajectory: Vec<_> = indices.iter().map(|&i| &buffer.samples()[i]).collect();
    // rescore this trajectory
}