Learn AI Series

Before we teach an agent to dream, let's settle last episode's three exercises. All of them build on the PPONetwork, RolloutBuffer and PPOAgent classes from episode #109, so I'm assuming those are imported and sitting in scope. As before I'm leaning on gymnasium throughout (pip install gymnasium if you skipped it).
Exercise 1: Get the PPOAgent training on CartPole-v1, plot the per-rollout average reward, then run the clip ablation -- set clip_eps enormous so the clamp never fires -- and compare.
import gymnasium as gym
import numpy as np
import torch

# Assumes PPOAgent and train_ppo from episode #109.

def run_ppo(clip_eps, seed, total_steps=80_000, rollout_len=2048):
    env = gym.make("CartPole-v1")
    torch.manual_seed(seed)
    np.random.seed(seed)
    agent = PPOAgent(env.observation_space.shape[0],
                     env.action_space.n, clip_eps=clip_eps)
    history = train_ppo(env, agent, total_steps, rollout_len)
    return history

clipped = run_ppo(0.2, seed=0)        # the real PPO
unclipped = run_ppo(100.0, seed=0)    # clamp never triggers

for name, h in [("clip=0.2", clipped), ("clip=100", unclipped)]:
    print(f"{name:>9}: final avg (last 20) = {np.mean(h[-20:]):6.1f} "
          f"| peak = {max(h):6.1f}")
Plot the two reward curves side by side and the story tells itself. With clip_eps = 0.2 the curve climbs steadily and stays up at CartPole's ceiling of 500. With clip_eps = 100.0 the clamp in torch.clamp(ratio, 1 - eps, 1 + eps) never actually bites -- surr1 and surr2 become the same number, the min is a no-op, and you've quietly turned PPO back into a multi-epoch REINFORCE-with-baseline. And because we still loop over the same rollout for several epochs, those repeated unconstrained steps march the policy miles away from the data that produced it -- exactly the "no seatbelt" death spiral that opened episode #109. You'll see the unclipped curve spike and then crater, sometimes more than once per run. Same code, one disabled clamp, night-and-day stability. Tells you precisely how much work that one little clamp is doing ;-)
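If you'd like to see the "min is a no-op" claim on actual numbers, here is a minimal NumPy sketch. The helper name clipped_surrogate and the ratios and advantages are made up for illustration; this is not code from episode #109, just the same formula on a five-sample batch.

```python
import numpy as np

def clipped_surrogate(ratio, adv, clip_eps):
    """PPO's clipped objective (to be maximised), averaged over the batch."""
    surr1 = ratio * adv
    surr2 = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    return np.minimum(surr1, surr2).mean()

# Made-up probability ratios and advantages, purely for illustration.
ratio = np.array([0.5, 0.9, 1.0, 1.3, 2.0])
adv = np.array([1.0, -1.0, 0.5, 2.0, 1.0])

unclipped = (ratio * adv).mean()                       # plain importance-weighted surrogate
with_clip = clipped_surrogate(ratio, adv, clip_eps=0.2)
no_clip = clipped_surrogate(ratio, adv, clip_eps=100.0)

print(f"plain surrogate : {unclipped:.3f}")
print(f"clip_eps = 0.2  : {with_clip:.3f}")
print(f"clip_eps = 100  : {no_clip:.3f}")
```

With clip_eps = 100 the result is identical to the plain surrogate, because no realistic ratio ever reaches the clamp's bounds. With clip_eps = 0.2 the two samples whose ratios ran ahead (1.3 and 2.0) get their contribution capped, which is the seatbelt doing its job.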
Exercise 2: Add an approximate-KL early stop. After each epoch estimate the mean KL with the cheap mean(old_log_probs - new_log_probs), and if it exceeds a threshold, break out of the epoch loop.
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# A drop-in replacement for PPOAgent.update() from episode #109.

def update_with_kl_stop(self, last_value, target_kl=0.015):
    returns, advantages = self.buffer.compute_gae(
        last_value, self.gamma, self.lam)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    stopped_at = self.epochs                       # for logging
    for epoch in range(self.epochs):
        approx_kls = []
        for (states, actions, old_log_probs,
             b_returns, b_adv) in self.buffer.batches(
                 returns, advantages, self.batch_size):
            logits, values = self.net(states)
            dist = Categorical(logits=logits)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * b_adv
            surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                1 + self.clip_eps) * b_adv
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(values.squeeze(), b_returns)
            loss = (actor_loss + self.value_coef * critic_loss
                    - self.entropy_coef * entropy)

            self.opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
            self.opt.step()

            # per-minibatch KL estimate, collected for this epoch
            approx_kls.append((old_log_probs - new_log_probs).mean().item())

        if np.mean(approx_kls) > target_kl:        # the early stop
            stopped_at = epoch + 1
            break

    self.buffer.clear()
    return stopped_at                              # log this across training
The estimator mean(old_log_probs - new_log_probs) is a sample-based estimate of the KL divergence between the old and new policy, computed on the actions the old policy actually took -- cheap, no extra forward passes, and good enough to act on, though noisy enough that it can dip below zero on a small minibatch. Log stopped_at across a whole training run and you'll notice the early stop fires often in the early, fast-moving phase (when each rollout teaches the policy a lot, so it wants to move far) and almost never late in training (when the policy is nearly converged and barely budges). What you've built here is a little hybrid: PPO's implicit clip keeps any single step honest, and this explicit KL check keeps the accumulation of several epochs honest. That explicit KL bound is exactly the leash TRPO used (episode #109) -- we've smuggled a piece of TRPO back in on top of PPO, for about eight lines of code. Having said that, vanilla PPO mostly does fine without it -- this is a belt-and-suspenders touch for the runs that misbehave.
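To convince yourself the cheap estimator tracks the real thing, here is a small NumPy check. The two four-action policies are made-up numbers and log_softmax is a local helper, so treat it as a sanity check under those assumptions rather than anything taken from the PPOAgent code.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_softmax(logits):
    """Numerically stable log-probabilities from raw logits."""
    z = logits - logits.max()
    return z - np.log(np.exp(z).sum())

# Two made-up 4-action policies: "new" is a small perturbation of "old".
old_logp = log_softmax(np.array([1.0, 0.5, 0.0, -0.5]))
new_logp = log_softmax(np.array([1.2, 0.4, 0.0, -0.6]))

# Exact KL(old || new), for reference.
true_kl = float(np.sum(np.exp(old_logp) * (old_logp - new_logp)))

# Sample actions from the OLD policy, as a rollout buffer would hold them,
# then apply the cheap estimator from the exercise.
actions = rng.choice(4, size=200_000, p=np.exp(old_logp))
approx_kl = float(np.mean(old_logp[actions] - new_logp[actions]))

print(f"true KL   = {true_kl:.5f}")
print(f"approx KL = {approx_kl:.5f}")
```

With a couple of hundred thousand samples the two numbers agree closely. On a 64-sample minibatch the estimate is far noisier, which is why the update above averages it over all minibatches of an epoch before comparing against target_kl.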
Exercise 3: Adapt PPO to a continuous action space and run it on Pendulum-v1, swapping the Categorical for a Normal.
import torch
import torch.nn as nn
from torch.distributions import Normal

class ContinuousPPONetwork(nn.Module):
    """Actor outputs a Gaussian mean; log_std is a free parameter."""

    def __init__(self, state_dim, action_dim, hidden=64):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
        )
        self.mean = nn.Linear(hidden, action_dim)              # action mean
        self.log_std = nn.Parameter(torch.zeros(action_dim))   # state-independent
        self.critic = nn.Linear(hidden, 1)

    def forward(self, state):
        features = self.shared(state)
        return self.mean(features), self.critic(features)

    def get_action(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        mean, value = self.forward(state_t)
        std = self.log_std.exp()
        dist = Normal(mean, std)
        action = dist.sample()
        # sum over action dims: independent Gaussians -> joint log-prob
        log_prob = dist.log_prob(action).sum(-1)
        return action.squeeze(0).numpy(), log_prob, value.squeeze()
Run this on Pendulum-v1 (a continuous torque-control task) and the same PPOAgent skeleton trains it -- you only have to thread Normal through the update's log-prob calls and .sum(-1) the per-dimension log-probs. Now notice carefully what changed: the distribution (Normal not Categorical), the action shape (a real-valued vector, not an index), and the actor's output layer (a mean head plus a learned log_std). And then notice everything that stayed exactly the same: the clipped surrogate, the GAE computation, the multi-epoch loop, the value loss, the entropy bonus. That invariance is the whole reason PPO travels so well -- discrete or continuous, the optimisation core does not care. It's also the reason I made you build PPO carefully last time: the investment pays off across an entire zoo of problems.
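Why is .sum(-1) the right thing to do with the per-dimension log-probs? Because independent Gaussians multiply, so their log-densities add. Here is a minimal pure-Python check against the closed form for a Gaussian with diagonal covariance. The three-dimensional action and its mean and std are made-up numbers (Pendulum-v1 itself has a one-dimensional action, so there the sum is trivial).

```python
import math

def normal_log_prob(x, mean, std):
    """Log-density of a 1-D Gaussian."""
    return (-0.5 * ((x - mean) / std) ** 2
            - math.log(std) - 0.5 * math.log(2 * math.pi))

# A made-up 3-D action with per-dimension mean and std.
action = [0.3, -1.2, 0.8]
mean = [0.0, -1.0, 1.0]
std = [1.0, 0.5, 2.0]

# What .sum(-1) does: add up the per-dimension log-probs.
per_dim = [normal_log_prob(a, m, s) for a, m, s in zip(action, mean, std)]
joint_from_sum = sum(per_dim)

# Closed form for a multivariate Gaussian with diagonal covariance.
k = len(action)
mahalanobis = sum(((a - m) / s) ** 2 for a, m, s in zip(action, mean, std))
log_det = sum(2 * math.log(s) for s in std)
joint_closed_form = -0.5 * (mahalanobis + log_det + k * math.log(2 * math.pi))

print(f"sum of per-dim log-probs: {joint_from_sum:.6f}")
print(f"diagonal-Gaussian form  : {joint_closed_form:.6f}")
```

The two agree to floating-point precision. Forget the sum and the ratio in the clipped surrogate ends up with one entry per action dimension instead of one per sample, which is one of the more common ways a continuous PPO port goes quietly wrong.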
Right -- episode 110, and we are about to cross a genuine line in this RL chapter.
Cast your eye back over everything we've done since episode #102. Q-Learning, SARSA, DQN, REINFORCE, Actor-Critic, PPO -- every single one of them is model-free. The agent reaches into the environment, pulls out rewards, and adjusts a policy or a value function accordingly. It never once asks how the environment works. It doesn't predict what the next state will be. It just reacts to whatever it observes, like a creature with very fast reflexes and absolutely no imagination.
That works, and it works famously well. But it has one nagging, expensive flaw, and today we finally confront it.
Model-free methods are sample-hungry. Embarrassingly so. DQN learning to play a single Atari game can chew through tens of millions of frames -- the equivalent of weeks of nonstop play. PPO solving a robotics task in simulation might need hundreds of millions of timesteps. In a simulator, where you can run a thousand environments in parallel and time costs nothing, that's tolerable. On a real robot, where every interaction takes real seconds and a bad action can snap a real servo, it's a non-starter. You cannot crash a real drone a million times to teach it to fly.
Now think about how you learn instead. You did not need to total a thousand cars to learn to drive. You hold a mental model of how a car behaves -- turn the wheel right, the car goes right; brake hard, you lurch forward -- and you run little simulations in your head before you ever touch the road. "If I pull out now, that lorry arrives about there..." You imagine consequences, and you learn from the imagining.
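Model-based RL is the attempt to give an agent that same gift. The plan is simple to state: learn a model of the environment from real experience, then let the agent practise inside that model.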
Real interactions are precious and slow. Model queries are cheap and fast (a forward pass, or a table lookup). So you spend your scarce real data improving the model, then mint as much synthetic data as you like from it. That is the entire pitch, and when it works, the sample-efficiency gains are enormous -- ten to a hundred times fewer real interactions is not unusual.
Everything hinges on that model. So what is it? Concretely, a function that takes the current state s and an action a and predicts two things:
- the next state s' -- where do we end up?
- the reward r -- what do we get for going there?

In a neural network, that's a small regression problem:
import torch
import torch.nn as nn
import numpy as np

class EnvironmentModel(nn.Module):
    """Learned dynamics: predicts next state and reward from (state, action)."""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden, state_dim)   # predicts delta-state
        self.reward_head = nn.Linear(hidden, 1)          # predicts reward

    def forward(self, state, action):
        # action is one-hot encoded for discrete action spaces
        x = torch.cat([state, action], dim=-1)
        features = self.net(x)
        delta_state = self.state_head(features)
        reward = self.reward_head(features)
        next_state = state + delta_state                 # predict the change
        return next_state, reward.squeeze(-1)
One detail there is doing quite some heavy lifting: the model predicts the state delta (state + delta_state) rather than the absolute next state. This is standard practice and it matters. Between one frame and the next, most of the state barely moves -- a cart's position nudges, a pole's angle ticks. Asking the network to output the full new state means re-learning all the bits that didn't change; asking it for only the change is a far easier target, with smaller, better-behaved numbers. If that reminds you of the residual connections from episode #46, good -- it's the same trick (learn the residual, not the whole function), wearing different clothes.
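A quick way to see the "smaller, better-behaved numbers" point is to compare the two candidate regression targets on a trajectory. The sketch below uses a made-up one-dimensional random walk as a stand-in for something like a cart position; the step size and length are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up 1-D trajectory: a state that drifts by small steps each frame.
steps = rng.normal(loc=0.0, scale=0.01, size=5_000)
states = np.cumsum(steps)

next_states = states[1:]             # absolute target: s'
deltas = states[1:] - states[:-1]    # delta target:    s' - s

print(f"std of absolute targets: {next_states.std():.4f}")
print(f"std of delta targets   : {deltas.std():.4f}")

# The two parameterisations carry the same information:
# adding the delta back onto the state recovers s' exactly.
assert np.allclose(states[:-1] + deltas, next_states)
```

The delta targets stay at the scale of a single step no matter how far the state has wandered, while the absolute targets spread out across the whole range the trajectory has covered. Nothing is lost by the switch, since state + delta reproduces the next state.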
The model is trained by plain old supervised learning on real transitions we've collected. No RL magic here at all -- it's the regression we've known since episode #10, just with a state-and-reward target:
class ModelTrainer:
    """Train the environment model from collected (s, a, r, s') experience."""

    def __init__(self, model, lr=1e-3):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.buffer = []   # (s, a, r, s') tuples

    def add_transition(self, state, action, reward, next_state):
        # action must already be a one-hot (or float) vector,
        # since EnvironmentModel.forward concatenates it with the state
        self.buffer.append((state, action, reward, next_state))

    def train_step(self, batch_size=64):
        if len(self.buffer) < batch_size:
            return None
        idx = np.random.choice(len(self.buffer), batch_size)
        batch = [self.buffer[i] for i in idx]
        states, actions, rewards, next_states = zip(*batch)

        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.FloatTensor(np.array(actions))
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(np.array(next_states))

        pred_next, pred_reward = self.model(states_t, actions_t)
        state_loss = nn.functional.mse_loss(pred_next, next_states_t)
        reward_loss = nn.functional.mse_loss(pred_reward, rewards_t)
        loss = state_loss + reward_loss   # joint objective

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
Notice the loss is just the sum of two MSE terms -- one for predicting the next state, one for the reward. The agent's whole "understanding" of physics boils down to getting that joint loss low. Bam, jonguh -- that's a world model in fourteen lines.
The cleanest way to see model-based RL in action is Dyna-Q (Sutton, 1991), and it is almost cheekily simple. After every real step, you do your normal Q-Learning update (the model-free part, exactly as in episode #106) -- and then you do a handful of extra Q-Learning updates on transitions conjured up by the model. Real learning and imagined learning, interleaved, sharing the same Q-table:
from collections import defaultdict
import numpy as np

class DynaQ:
    """Dyna-Q: tabular Q-Learning interleaved with model-based planning."""

    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99,
                 epsilon=0.1, planning_steps=5):
        self.n_actions = n_actions
        self.alpha, self.gamma, self.epsilon = alpha, gamma, epsilon
        self.planning_steps = planning_steps
        self.Q = defaultdict(lambda: np.zeros(n_actions))
        self.model = {}        # (state, action) -> (reward, next_state, done)
        self.visited_sa = []   # which (s, a) pairs we've actually seen

    def choose_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.Q[state]))

    def update(self, state, action, reward, next_state, done):
        # 1. Real Q-Learning update (model-free)
        target = reward + self.gamma * np.max(self.Q[next_state]) * (1 - done)
        self.Q[state][action] += self.alpha * (target - self.Q[state][action])

        # 2. Update the model with what really happened
        self.model[(state, action)] = (reward, next_state, done)
        if (state, action) not in self.visited_sa:
            self.visited_sa.append((state, action))

        # 3. Planning: replay imagined transitions from the model
        for _ in range(self.planning_steps):
            s, a = self.visited_sa[np.random.randint(len(self.visited_sa))]
            r, ns, d = self.model[(s, a)]
            target = r + self.gamma * np.max(self.Q[ns]) * (1 - d)
            self.Q[s][a] += self.alpha * (target - self.Q[s][a])
Look at what that buys you. With planning_steps = 5, every single real interaction triggers six Q-updates -- one from reality plus five replayed from the model. Q-values propagate backwards through the state space several times faster per real step (not per second of compute -- each step now costs six updates instead of one), because the agent isn't sitting idle waiting for the environment to hand it the next experience -- it's mining the experience it already has. On a small deterministic gridworld where the tabular model is essentially perfect, Dyna-Q converges in dramatically fewer real steps than plain Q-Learning, and the gap widens as you allow more planning steps, up to a point.
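Here is a stripped-down illustration of that backward propagation, separate from the DynaQ class above. It assumes a hypothetical eight-state chain with a reward only at the far end and a fixed "always go right" behaviour, so the only thing that differs between the two runs is the amount of planning. The 200 planning steps are deliberately exaggerated to make the effect obvious after a single episode.

```python
import numpy as np

def one_episode_on_chain(planning_steps, n_states=8, alpha=0.5, gamma=0.9, seed=0):
    """Walk right along a chain once; reward 1 on reaching the last state."""
    rng = np.random.default_rng(seed)
    Q = np.zeros((n_states, 2))          # actions: 0 = left, 1 = right
    model = {}                           # (s, a) -> (r, s', done)
    visited = []
    for s in range(n_states - 1):
        a, ns = 1, s + 1                 # fixed behaviour, to isolate the planning effect
        done = ns == n_states - 1
        r = 1.0 if done else 0.0
        # 1. Real Q-Learning update
        Q[s, a] += alpha * (r + gamma * Q[ns].max() * (not done) - Q[s, a])
        # 2. Record the transition in the tabular model
        if (s, a) not in model:
            visited.append((s, a))
        model[(s, a)] = (r, ns, done)
        # 3. Planning: replay transitions sampled from the model
        for _ in range(planning_steps):
            ps, pa = visited[rng.integers(len(visited))]
            pr, pns, pdone = model[(ps, pa)]
            Q[ps, pa] += alpha * (pr + gamma * Q[pns].max() * (not pdone) - Q[ps, pa])
    return Q

q_plain = one_episode_on_chain(planning_steps=0)
q_dyna = one_episode_on_chain(planning_steps=200)
print("states with a learned value, no planning :", int((q_plain[:, 1] > 0).sum()))
print("states with a learned value, 200 planning:", int((q_dyna[:, 1] > 0).sum()))
```

Without planning, one episode teaches exactly one state-action pair anything: the step that received the reward. With planning, the same single episode lets the reward leak backwards along the chain, because replayed transitions get to bootstrap from values that were updated a moment earlier.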
There's a deep idea hiding in that humble loop, and it's worth saying plainly: real data improves the model, the model improves the policy. Two learning processes feeding each other, and only one of them costs you precious real-world interactions.
NB: if Dyna-Q's replay reminds you of DQN's experience replay buffer from episode #107, you've spotted something real -- both reuse past experience instead of throwing it away. The difference is that replay only ever serves up transitions that genuinely happened, whereas a learned model can generate transitions you've never seen at all. That extra reach is the model-based superpower, and (as we'll see) also its Achilles' heel.
A tabular model only works when states are few and discrete. The moment the state is a vector of real numbers -- a robot's joint angles, a game screen's features -- the table explodes. The fix is the obvious one: replace the dictionary with the EnvironmentModel network from earlier, and let it generalise across states it has only sort-of seen:
class NeuralDynaAgent:
    """Dyna with a neural environment model for continuous-state problems."""

    def __init__(self, state_dim, n_actions, hidden=128,
                 planning_steps=10, planning_horizon=5):
        self.state_dim, self.n_actions = state_dim, n_actions
        self.planning_steps = planning_steps
        self.planning_horizon = planning_horizon

        self.q_net = nn.Sequential(                 # model-free component
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )
        self.q_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=1e-3)

        self.env_model = EnvironmentModel(state_dim, n_actions, hidden)
        self.model_optimizer = torch.optim.Adam(
            self.env_model.parameters(), lr=1e-3)
        self.replay_buffer = []

    def imagine_rollout(self, start_state, horizon):
        """Walk forward through the LEARNED model, no real environment touched."""
        state = torch.FloatTensor(start_state).unsqueeze(0)
        imagined = []
        for _ in range(horizon):
            with torch.no_grad():
                action_idx = self.q_net(state).argmax(dim=1).item()
            action = torch.zeros(1, self.n_actions)
            action[0, action_idx] = 1.0
            with torch.no_grad():
                next_state, reward = self.env_model(state, action)
            # squeeze(0) drops only the batch dimension, so 1-D states survive
            imagined.append((state.squeeze(0).numpy(), action_idx,
                             reward.item(), next_state.squeeze(0).numpy()))
            state = next_state                      # feed prediction back in
        return imagined

    def plan(self):
        """Generate imagined experience and train the Q-network on it."""
        if len(self.replay_buffer) < 100:
            return
        for _ in range(self.planning_steps):
            start = self.replay_buffer[
                np.random.randint(len(self.replay_buffer))][0]
            for s, a, r, ns in self.imagine_rollout(start, self.planning_horizon):
                self._q_update(s, a, r, ns)

    def _q_update(self, state, action, reward, next_state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        next_state_t = torch.FloatTensor(next_state).unsqueeze(0)
        q_pred = self.q_net(state_t)[0, action]
        # the learned model predicts no "done" flag, so imagined targets always bootstrap
        with torch.no_grad():
            q_target = reward + 0.99 * self.q_net(next_state_t).max().item()
        loss = (q_pred - q_target) ** 2
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()
The crucial line is in imagine_rollout: state = next_state. We feed the model's own prediction back into the model as the next input, and step forward again. The agent is walking through a world entirely of its own making -- a daydream stitched together from a learned dynamics function. As long as the dream stays faithful to reality, training on it is almost free.
That phrase -- "as long as the dream stays faithful" -- is the whole ballgame, and we'll come back to it with a vengeance shortly. First, two landmark systems that pushed the dreaming idea to its limit.
World Models (Ha & Schmidhuber, 2018) is one of those papers that's a genuine pleasure to read, because the central image is so vivid: an agent that learns a compact mental model of its world and then learns to act entirely inside that model, never touching the real environment during policy training. It splits the job into three parts:
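- Vision (V): a variational autoencoder that compresses each frame into a small latent vector z -- the gist of the scene, stripped of pixel clutter;
- Memory (M): a recurrent network that predicts the next z from the current z and the chosen action -- this is the learned dynamics;
- Controller (C): a deliberately tiny policy that picks actions from z and the memory's hidden state.

The remarkable result: in the paper's VizDoom experiment they trained the controller purely inside the learned model -- inside the "dream" -- and then dropped it into the real environment, where it still performed. The controller never saw a single real observation during its training (the world model itself, of course, was fit to real frames first). It learned to dodge fireballs, in effect, by practising in a hallucination of the level. That works only because the model was faithful enough that habits learned in the dream carried over to reality. When you hear people say a policy was "trained in imagination", this is the lineage they're pointing at.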
If World Models is the elegant idea, MuZero (Schrittwieser et al., 2020) is the heavyweight result. To appreciate it, line up DeepMind's family tree. AlphaGo was handed the rules of Go. AlphaZero was handed the rules of chess, Go and shogi. MuZero is handed... nothing. It learns to play at superhuman level without knowing the rules of the game at all -- it has to figure out the consequences of its own moves from experience.
It pulls this off with three learned functions -- a representation function, a dynamics function and a prediction function:
class MuZeroComponents(nn.Module):
    """Simplified MuZero: representation, dynamics, prediction."""

    def __init__(self, obs_dim, hidden_dim, n_actions):
        super().__init__()
        # Representation: real observation -> hidden state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # Dynamics: hidden state + action -> next hidden state (+ reward)
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim + n_actions, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.reward_pred = nn.Linear(hidden_dim, 1)
        # Prediction: hidden state -> policy + value
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def initial_inference(self, observation):
        """From a real observation: encode, then read off policy and value."""
        hidden = self.representation(observation)
        return hidden, self.policy_head(hidden), self.value_head(hidden)

    def recurrent_inference(self, hidden_state, action_onehot):
        """From a hidden state + action: imagine the next hidden state."""
        x = torch.cat([hidden_state, action_onehot], dim=-1)
        next_hidden = self.dynamics(x)
        reward = self.reward_pred(next_hidden)
        policy = self.policy_head(next_hidden)
        value = self.value_head(next_hidden)
        return next_hidden, reward, policy, value
The masterstroke is what MuZero refuses to predict. It does not try to reconstruct the next screen of pixels (hard, wasteful, and mostly irrelevant to playing well). Instead it learns an abstract hidden state -- an internal representation tuned for one job only: making the dynamics easy to predict and the planning effective. The model lives in this learned latent space, never in observation space. It doesn't need to imagine what the board looks like, only what matters about it for winning.
And then it plans with Monte Carlo Tree Search (MCTS) -- the same family of tree search we touched on in the bandits and dynamic-programming episodes (#103, #104), now run inside the learned model. From the current hidden state, MuZero simulates many candidate action sequences using its dynamics function, scores each imagined line with the value head, and commits to the move whose imagined future looks best. Search, but over a dream of the game rather than a known rulebook.
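To make "search inside a learned model" concrete without the full MCTS machinery, here is a much cruder stand-in: enumerate every action sequence up to a fixed depth, roll each one through the model, and bootstrap from a value estimate at the leaf. To be clear, this is exhaustive depth-limited search, not MCTS, and toy_dynamics and toy_value are hypothetical hand-written placeholders for what would be learned networks.

```python
from itertools import product

# Hypothetical stand-ins for learned functions, on a toy 1-D world.
def toy_dynamics(state, action):
    """Pretend learned model: returns (next_state, reward)."""
    next_state = state + action
    reward = 1.0 if next_state == 3 else 0.0
    return next_state, reward

def toy_value(state):
    """Pretend value head: being closer to state 3 is better."""
    return -0.1 * abs(3 - state)

def plan(state, actions=(-1, +1), depth=3, gamma=0.99):
    """Score every action sequence inside the model; return the best first action."""
    best_score, best_first = float("-inf"), None
    for seq in product(actions, repeat=depth):
        s, score, discount = state, 0.0, 1.0
        for a in seq:
            s, r = toy_dynamics(s, a)       # imagined step, no real environment
            score += discount * r
            discount *= gamma
        score += discount * toy_value(s)    # bootstrap from the value estimate at the leaf
        if score > best_score:
            best_score, best_first = score, seq[0]
    return best_first, best_score

first_action, score = plan(state=0)
print(first_action, round(score, 4))
```

From state 0 the only three-step line that collects the reward is +1, +1, +1, so the planner commits to +1. Exhaustive enumeration grows exponentially with depth, which is exactly the problem MCTS addresses by spending its simulations on the lines that look most promising.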
The payoff: MuZero matched AlphaZero on Go, chess and shogi -- without the rules -- and simultaneously set a new state of the art on Atari, beating the model-free champions while using far less data. That last bit is the whole thesis of this episode, proven at the highest level: a good learned model buys you sample efficiency that pure reaction simply cannot.
Now the catch, and it is a big one. Learned models are wrong. Not catastrophically, usually -- but wrong in small ways, every step. And in a multi-step rollout those small errors compound.
Do the arithmetic and it's sobering. Say your model is a very respectable 95% accurate on a single step. Chain twenty steps together in your imagination and the accuracy of the final state is roughly 0.95 ** 20, which is about 36%. Twenty steps into the dream and you're more wrong than right. The agent, of course, doesn't know this -- it cheerfully trains on the garbage as if it were gospel.
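The snippet below reproduces that back-of-envelope arithmetic and then shows the same snowball on a toy rollout. The toy is an assumption-laden sketch: the "true" dynamics are a simple rotation and the "learned" model is hard-coded to be 2% off in scale, so the numbers illustrate the mechanism rather than any real environment.

```python
import numpy as np

# The back-of-envelope version: per-step accuracy compounding over k steps.
for k in (1, 5, 10, 20):
    print(f"k = {k:2d}: 0.95 ** k = {0.95 ** k:.3f}")

# A toy rollout: true dynamics rotate the state; the "model" is 2% off in scale.
theta = 0.1
A_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
A_model = 1.02 * A_true              # assumed model error, purely for illustration

x_true = x_model = np.array([1.0, 0.0])
errors = []
for k in range(1, 21):
    x_true = A_true @ x_true
    x_model = A_model @ x_model      # feed the model's own prediction back in
    errors.append(float(np.linalg.norm(x_model - x_true)))

print(f"error after  1 step : {errors[0]:.3f}")
print(f"error after 20 steps: {errors[-1]:.3f}")
```

A 2% one-step error looks harmless in isolation. Fed back into itself twenty times, it grows to roughly half the magnitude of the state, and nothing in the rollout tells the agent that this has happened.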
Worse still is what happens when the model has an exploitable flaw. Suppose there's some weird corner of state space where the model wrongly predicts a huge reward. A model-free agent could never be fooled by this -- it only ever sees real rewards. But a model-based agent will find that phantom jackpot and exploit it relentlessly, optimising hard for a payoff that exists only in its own buggy imagination. It's the RL equivalent of a student who learns to game the practice test instead of the subject.
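The field has a toolbox for keeping the dream honest: keep imagined rollouts short, so errors have little room to compound; train an ensemble of models and treat their disagreement as a warning sign; and keep mixing real transitions in with the imagined ones.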
The honest summary: model-based RL is a constant negotiation between the sample efficiency you gain and the model error you risk. Push the horizon too far and the second eats the first.
So when do you reach for which? Here's the trade-off laid bare:
| | Model-free (DQN, PPO) | Model-based (Dyna, MuZero) |
|---|---|---|
| Sample efficiency | Low -- millions of interactions | High -- often 10-100x fewer |
| Compute per step | Low | High (model training + planning) |
| Best-case performance | Can reach optimal | Capped by model accuracy |
| Robustness | High (no model to be wrong) | Sensitive to model error |
| Moving parts | Few | Many -- more to tune and debug |
The deciding question is almost always: how expensive is a real interaction? When interactions are cheap -- games, fast simulators where you can spin up a thousand parallel worlds -- model-free wins on sheer simplicity and robustness; who cares about sample efficiency when samples are free? But when interactions are expensive or dangerous -- a physical robot arm, a drone, a treatment policy in healthcare, an industrial controller -- model-based methods stop being a nicety and become the only sane option. You can't crash a real robot ten million times. You can let it crash ten million times in a dream.
That, in one sentence, is why this family of methods exists.
Exercise 1: Implement tabular DynaQ (the class above) on a small gridworld -- FrozenLake-v1 with is_slippery=False from gymnasium does nicely. Train it with planning_steps = 0 (which is just plain Q-Learning), then 5, then 50, all under the same seed, and plot episodes-to-solve for each. Confirm with your own eyes that more planning means faster convergence -- and then explain in a sentence why the benefit eventually plateaus (hint: think about how much genuinely new information one real transition can carry).
Exercise 2: Take the EnvironmentModel and ModelTrainer, collect a few thousand random-policy transitions from CartPole-v1, and train the model on them. Then measure compounding error directly: from a real start state, roll the model forward k steps and compare its predicted state against the true environment's state for k = 1, 5, 10, 20. Plot prediction error against k. You should see roughly the snowball we discussed -- and you'll have measured the 0.95 ** k problem rather than just read about it.
Exercise 3: Build a tiny ensemble -- train five EnvironmentModels on the same data but with different random seeds (so different initialisations and minibatch orders). For a batch of states, have all five predict the next state, and compute the variance of their predictions per state. Now connect the dots: argue how you'd use that variance as an uncertainty signal to decide when to stop trusting an imagined rollout. This is the seed of how serious model-based agents keep their dreams from running off the rails.
That ensemble-disagreement idea is your stepping stone into the messier, more crowded worlds we're heading for next -- environments where the agent is no longer alone, and where "predicting what happens next" means predicting what other learning agents will do. We've now got the model-based engine built and, more importantly, we know exactly where it breaks. That honest understanding of the failure mode is worth more than the algorithm itself ;-)