Multi-Agent Shared-Weight Training

Control multiple units efficiently with a single shared policy network. JueWu uses this approach to control 5 heroes, and it applies to any game with multiple controlled units.

API

use rl4burn::{batch_multi_agent_obs, unbatch_actions, broadcast_team_reward};

// Batch observations from all agents across all environments
let (obs_tensor, n_envs, n_agents) = batch_multi_agent_obs::<B>(
    &per_env_per_agent_obs,
    &device,
);
// obs_tensor: [n_envs * n_agents, obs_dim] — one big batch

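// Single forward pass for all agents
let output = model.forward(obs_tensor);

// `flat_actions` holds one action per row of obs_tensor, in the same order.
// Deriving it from `output` (e.g. sampling from the policy head) is omitted here.

// Unbatch actions back to per-env, per-agent
let actions = unbatch_actions(&flat_actions, n_envs, n_agents);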
// actions: [n_envs][n_agents]

// Broadcast team reward to all agents
let per_agent_rewards = broadcast_team_reward(&env_rewards, n_agents);
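
To make the index bookkeeping concrete, here is a minimal sketch of the same three steps on plain vectors instead of tensors. The function names `flatten_obs`, `unflatten_actions`, and `broadcast_reward` are hypothetical stand-ins written for illustration. They are not the rl4burn implementations, and the sketch assumes a row-major layout (row `i` belongs to environment `i / n_agents`, agent `i % n_agents`) and the same number of agents in every environment.

```rust
/// Flatten [n_envs][n_agents][obs_dim] into [n_envs * n_agents][obs_dim].
/// Hypothetical helper: assumes every environment has the same agent count.
fn flatten_obs(obs: &[Vec<Vec<f32>>]) -> (Vec<Vec<f32>>, usize, usize) {
    let n_envs = obs.len();
    let n_agents = obs.first().map_or(0, |env| env.len());
    let mut rows = Vec::with_capacity(n_envs * n_agents);
    for env in obs {
        assert_eq!(env.len(), n_agents, "all envs must have the same agent count");
        // Agents of one environment stay contiguous in the flat batch.
        rows.extend(env.iter().cloned());
    }
    (rows, n_envs, n_agents)
}

/// Inverse of the flattening: split the flat action list into one
/// chunk of `n_agents` actions per environment.
fn unflatten_actions(flat: &[usize], n_envs: usize, n_agents: usize) -> Vec<Vec<usize>> {
    assert_eq!(flat.len(), n_envs * n_agents, "flat length must be n_envs * n_agents");
    if n_agents == 0 {
        // `chunks(0)` would panic, so handle the empty case explicitly.
        return vec![Vec::new(); n_envs];
    }
    flat.chunks(n_agents).map(|chunk| chunk.to_vec()).collect()
}

/// Give every agent in an environment a copy of that environment's team reward.
fn broadcast_reward(env_rewards: &[f32], n_agents: usize) -> Vec<Vec<f32>> {
    env_rewards.iter().map(|&r| vec![r; n_agents]).collect()
}
```

With 2 environments of 3 agents each, `flatten_obs` yields 6 rows, and `unflatten_actions(&[0, 1, 2, 3, 4, 5], 2, 3)` groups the actions back into two lists of three.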

Why shared weights?

With 30 units, running 30 separate forward passes is expensive. With shared weights, all 30 observations are batched into a single forward pass. The policy generalizes across unit types through the observation encoding, so each observation must carry whatever information distinguishes one unit from another.
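
One common way to let a shared policy tell units apart is to append a one-hot unit-type vector to each unit's features. The sketch below illustrates that idea only. The function `encode_with_unit_type` and its parameters are hypothetical, and the source does not say which encoding JueWu or rl4burn actually uses.

```rust
/// Append a one-hot unit-type vector to a unit's raw features.
/// Hypothetical helper: `unit_type` is an index in 0..n_unit_types.
fn encode_with_unit_type(features: &[f32], unit_type: usize, n_unit_types: usize) -> Vec<f32> {
    assert!(unit_type < n_unit_types, "unit_type out of range");
    let mut obs = Vec::with_capacity(features.len() + n_unit_types);
    // Raw per-unit features come first.
    obs.extend_from_slice(features);
    // Then one slot per unit type, set to 1.0 only for this unit's type.
    obs.extend((0..n_unit_types).map(|t| if t == unit_type { 1.0 } else { 0.0 }));
    obs
}
```

Every unit then produces an observation of the same length, `features.len() + n_unit_types`, so all units can share one batch and one network while still being distinguishable.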