Reinforcement Learning
For this project, I implemented Proximal Policy Optimization (PPO) to train agents to play Pong. Unlike standard Supervised Learning where we have a ground truth (e.g., “This image is a cat”), Reinforcement Learning (RL) involves an agent learning through trial and error by interacting with an environment.
The unique challenge here was Self-Play (Multi-Agent). The agent doesn’t play against a pre-programmed definition of a “hard” computer opponent; it plays against a copy of itself. As the agent gets better, its opponent gets better, creating a natural curriculum of increasing difficulty.
Here is the high-level architecture of the training loop:
- Environment: pong_v3 (PettingZoo), a multi-agent Atari environment.
- Algorithm: PPO (Proximal Policy Optimization).
- Model: A 3-Layer Convolutional Neural Network (CNN) processing stacked frames.
The Model Architecture
The agent “sees” the game not as a single static image, but as a stack of 4 consecutive grayscale frames (84x84 pixels each). Additionally, because this is a multi-agent environment, 2 extra channels are added to indicate which agent (paddle) the model is currently controlling (the Agent Indicator).
The input tensor shape is (Batch, 6, 84, 84) (4 Frame Stack + 2 Agent Indicators).
self.network = nn.Sequential(
    layer_init(nn.Conv2d(6, 32, 8, stride=4)),   # Layer 1
    nn.ReLU(),
    layer_init(nn.Conv2d(32, 64, 4, stride=2)),  # Layer 2
    nn.ReLU(),
    layer_init(nn.Conv2d(64, 64, 3, stride=1)),  # Layer 3
    nn.ReLU(),
    nn.Flatten(),
    layer_init(nn.Linear(64 * 7 * 7, 512)),      # Dense Representation
    nn.ReLU(),
)
This backbone feeds into two separate “heads”:
- Actor (Policy): Outputs logits for the 6 possible actions (NOOP, FIRE, UP, DOWN, UP+FIRE, DOWN+FIRE).
- Critic (Value): Estimates the Value of the current state (how likely am I to win from here?).
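For concreteness, the two heads can be as simple as single linear layers on top of the 512-dimensional embedding. Here is a minimal sketch in the style of a CleanRL agent; the initialization std values and the /255.0 pixel scaling are assumptions for illustration, not confirmed details of the project:

self.actor = layer_init(nn.Linear(512, 6), std=0.01)  # policy head: logits over the 6 actions
self.critic = layer_init(nn.Linear(512, 1), std=1)    # value head: scalar state-value estimate

def get_action_and_value(self, x, action=None):
    hidden = self.network(x / 255.0)  # scale raw pixels to [0, 1] before the CNN
    logits = self.actor(hidden)
    probs = torch.distributions.Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

The Categorical distribution turns the actor's logits into an action sample plus the log-probability and entropy terms that PPO needs during the update.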
Proximal Policy Optimization (PPO)
PPO is a policy gradient method that optimizes the agent’s decision-making policy. It iteratively improves the policy by taking small, safe update steps, preventing the unstable learning that can occur if the policy changes too drastically in a single update.
1. Policy Gradient & The “Clip”
The core idea of Policy Gradient is simple:
- If an action led to a good result (positive advantage), increase its probability.
- If it led to a bad result (negative advantage), decrease its probability.
However, standard policy gradient methods can be unstable. If we take too large a step based on a single batch of data, the policy might change drastically and never recover (the “cliff” problem).
PPO solves this with a Clipping Mechanism. It prevents the new policy from deviating too far from the old policy in a single update.
The Objective Function:

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\hat{A}_t\right)\right]$$

Where:
- $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio.
- $\hat{A}_t$ is the Advantage (how much better this action was than average).
- $\epsilon$ is the clipping parameter (usually 0.1 or 0.2).
- $\hat{\mathbb{E}}_t$ denotes the empirical average over a finite batch of samples.
Here is a simple calculation demonstrating how this works. Assume an advantage $\hat{A} = +1$ and $\epsilon = 0.2$.

Case 1: Safe Update. The policy changes slightly (ratio $r = 1.1$).

$$L = \min\big(1.1 \times 1,\ \operatorname{clip}(1.1,\, 0.8,\, 1.2) \times 1\big) = \min(1.1,\ 1.1) = 1.1$$

(Since $1.1$ is within the bounds $[0.8, 1.2]$, the clip function returns it unchanged.)

Result: The update is accepted as is.

Case 2: Dangerous Update. The policy changes drastically (ratio $r = 2.0$, doubling the probability).

$$L = \min\big(2.0 \times 1,\ \operatorname{clip}(2.0,\, 0.8,\, 1.2) \times 1\big) = \min(2.0,\ 1.2) = 1.2$$

(Since $2.0$ is greater than the upper bound $1.2$, the clip function limits it to $1.2$.)

Result: The update is clipped. The model is only rewarded as if it took the maximum safe step (1.2), removing the incentive to make such a large jump.
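The same arithmetic can be checked in a few lines of plain Python (a standalone sketch, independent of the training code):

def clipped_objective(ratio, advantage, eps=0.2):
    """Per-sample PPO clipped surrogate objective."""
    clipped = max(1 - eps, min(ratio, 1 + eps))          # clip(ratio, 1 - eps, 1 + eps)
    return min(ratio * advantage, clipped * advantage)

print(clipped_objective(1.1, 1.0))  # 1.1 -> safe update passes through unchanged
print(clipped_objective(2.0, 1.0))  # 1.2 -> dangerous update is clipped at 1 + eps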
2. Generalized Advantage Estimation (GAE)
How do we know if an action was “good”? We use the Advantage Function.
Simply using the immediate reward is shortsighted (hitting the ball now might lead to losing later). PPO uses Generalized Advantage Estimation (GAE) to balance variance and bias. It calculates a weighted sum of Bellman (TD) errors $\delta_t$ to propagate rewards backward in time.
To calculate this in practice, we use two components:
- TD (Temporal Difference) Error ($\delta_t$): The difference between “what actually happened + what we expect next” vs “what we expected originally”:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

- Advantage ($\hat{A}_t$): The recursive sum of these errors (GAE):

$$\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$$

Where:
- $r_t$ is the immediate reward.
- $V(s_t)$ is the predicted value of the current state.
- $V(s_{t+1})$ is the predicted value of the next state.
- $\gamma$ is the discount factor for future rewards.
- $\lambda$ is the discount factor for the advantage.
To see this propagation in action, let’s look at a key moment in the game: The agent hits the ball past the opponent.
- Step 3 (Winning Hit): The ball passes the opponent. Reward is +1.
- Step 2 (Setup): The agent waits in position. Reward is 0.
- Step 1 (Approach): The agent moves towards the ball. Reward is 0.
Here are the raw values the agent observes:
| Step | Reward | Value (Critic's predicted win probability) |
|---|---|---|
| 3 | 1.0 | 0.80 |
| 2 | 0.0 | 0.50 |
| 1 | 0.0 | 0.60 |
Now, let’s calculate the Advantage (credit) for each step working backwards, assuming $\gamma = 0.99$ and $\lambda = 0.95$:

Step 3 (Winning Hit): The agent hits the ball, and it passes the opponent. Reward = +1. The rally ends here, so the next value is 0:

$$\delta_3 = 1.0 + 0.99 \times 0 - 0.80 = 0.20 \qquad \hat{A}_3 = 0.20$$

Step 2 (The Setup): The agent stays in position as the ball approaches. Reward = 0.

$$\delta_2 = 0 + 0.99 \times 0.80 - 0.50 = 0.292 \qquad \hat{A}_2 = 0.292 + 0.99 \times 0.95 \times 0.20 \approx 0.480$$

Step 1 (The Approach): The agent sees the ball coming and starts moving up. Reward = 0.

$$\delta_1 = 0 + 0.99 \times 0.50 - 0.60 = -0.105 \qquad \hat{A}_1 = -0.105 + 0.99 \times 0.95 \times 0.480 \approx 0.347$$
Even though Steps 1 and 2 had 0 immediate reward, they received a large positive advantage because they led to the point scored at Step 3. The agent learns that “moving to the right spot” is just as valuable as the final hit.
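The backward recursion is easy to reproduce numerically; here is a standalone sketch using the numbers from the table above (plain Python, not the project's training code):

gamma, lam = 0.99, 0.95
rewards = [0.0, 0.0, 1.0]      # steps 1, 2, 3
values  = [0.60, 0.50, 0.80]   # critic predictions for steps 1, 2, 3
next_value = 0.0               # the rally ends after step 3
advantages = [0.0, 0.0, 0.0]
last_gae = 0.0
for t in reversed(range(3)):
    nv = next_value if t == 2 else values[t + 1]
    delta = rewards[t] + gamma * nv - values[t]   # TD error
    last_gae = delta + gamma * lam * last_gae     # GAE recursion
    advantages[t] = last_gae
print([round(a, 3) for a in advantages])  # [0.347, 0.48, 0.2]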
3. The Total Loss Function
Finally, we combine everything into a single number that we want to minimize. The total loss function consists of three parts:
- Policy Loss ($L^{CLIP}$): “Do more of what worked.” We want to maximize this.
- Value Loss ($L^{VF}$): “Predict the score better.” We minimize the error (MSE) between predictions and actual returns.
- Entropy Bonus ($S[\pi_\theta]$): “Don’t be too sure of yourself.” We want to maximize entropy to encourage exploration.
To optimize all three simultaneously using standard gradient descent (which minimizes loss), we flip the signs of the terms we want to maximize:

$$L_{\text{total}} = -L^{CLIP} + c_1 L^{VF} - c_2 S[\pi_\theta]$$

Where:
- $c_1$ is the value loss coefficient.
- $c_2$ is the entropy bonus coefficient.
Let’s calculate the loss for Step 2 of the example above.
1. Inputs
- Advantage ($\hat{A}$): 0.480 (this action was good!)
- Current Value Prediction ($V(s_2)$): 0.50
- Target Value ($V_{\text{target}}$): Since the Advantage is the difference between reality and prediction ($\hat{A} = V_{\text{target}} - V(s_2)$), we can calculate the target as $V_{\text{target}} = 0.50 + 0.480 = 0.98$.
2. Hyperparameters
- $c_1 = 0.5$ (Value Loss Coefficient)
- $c_2 = 0.01$ (Entropy Coefficient)
3. Loss Calculation

- Policy Term: Since the Advantage is positive (0.48), the agent increases the action probability. Let’s say the ratio becomes 1.2 (20% more likely):

$$L^{CLIP} = \min\big(1.2 \times 0.48,\ \operatorname{clip}(1.2,\, 0.8,\, 1.2) \times 0.48\big) = 0.576$$

- Value Term: The Critic predicted 0.50 but should have predicted 0.98:

$$L^{VF} = (0.50 - 0.98)^2 = 0.2304$$

- Entropy Term: Let’s assume the entropy is currently 0.6.

Final Loss:

$$L_{\text{total}} = -0.576 + 0.5 \times 0.2304 - 0.01 \times 0.6 \approx -0.467$$
The loss is negative, which is good. The optimizer will try to make it even more negative by:
- Increasing the probability of this good action (increasing the Policy term).
- Making the Critic’s prediction closer to 0.98 (decreasing the Value term).
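For completeness, the same loss arithmetic as a standalone sketch (plain Python; the clip range and coefficients are the example values used above):

advantage, ratio = 0.48, 1.2
value_pred, target = 0.50, 0.98
entropy = 0.6
clip_eps, vf_coef, ent_coef = 0.2, 0.5, 0.01

clipped_ratio = max(1 - clip_eps, min(ratio, 1 + clip_eps))           # clip(1.2, 0.8, 1.2) = 1.2
policy_objective = min(ratio * advantage, clipped_ratio * advantage)  # 0.576
value_loss = (value_pred - target) ** 2                               # 0.2304
total_loss = -policy_objective + vf_coef * value_loss - ent_coef * entropy
print(round(total_loss, 3))  # -0.467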
Training Loop
1. Experience Collection (Rollout)
First, the agent interacts with the environment for num_steps (typically 128) to collect a batch of data. We disable gradient calculation here (torch.no_grad()) because we are only collecting data, not training yet.
for step in range(0, args.num_steps):
    # 1. Forward Pass
    with torch.no_grad():
        action, logprob, _, value = agent.get_action_and_value(next_obs)
        values[step] = value.flatten()
    # 2. Storage
    obs[step] = next_obs      # keep the observation for the PPO update later
    dones[step] = next_done   # keep the done flag for GAE
    actions[step] = action
    logprobs[step] = logprob
    # 3. Environment Step
    next_obs, reward, done, info = envs.step(action.cpu().numpy())
    # 4. Update State
    rewards[step] = torch.tensor(reward).to(device).view(-1)
    next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
- Forward Pass: We feed `next_obs` (from the previous step) into the agent to get `action`, `logprob`, and `value`. We use `no_grad()` because we don’t want to backpropagate yet; we are just collecting data.
- Storage: We save these into our buffers (`values`, `actions`, `logprobs`, along with `obs` and `dones`) to use later for calculating advantages and losses (see the buffer sketch after this list).
- Environment Step: We execute the `action` in the game environment. Note the `.cpu().numpy()` conversion because the environment expects standard NumPy arrays, not GPU tensors.
- Update State: We convert the new `next_obs` and `reward` back to PyTorch tensors and move them to the GPU (`to(device)`), ready for the next iteration of the loop.
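For reference, the buffers mentioned above are typically preallocated once before the rollout loop. A sketch assuming a CleanRL-style setup with `args.num_envs` parallel environments and the (6, 84, 84) observation described earlier (shapes are illustrative, not taken from the project's script):

obs      = torch.zeros((args.num_steps, args.num_envs, 6, 84, 84)).to(device)  # stacked frames + agent indicator
actions  = torch.zeros((args.num_steps, args.num_envs)).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards  = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones    = torch.zeros((args.num_steps, args.num_envs)).to(device)
values   = torch.zeros((args.num_steps, args.num_envs)).to(device)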
2. Generalized Advantage Estimation (GAE)
Once we have a full batch of experience, we calculate the advantages backwards from the last step to the first. This is where we apply the GAE formula to balance bias and variance.
with torch.no_grad():
    # 1. Bootstrap Value
    next_value = agent.get_value(next_obs).reshape(1, -1)
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0
    # 2. Reverse Loop
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # 3. Delta Calculation (TD Error)
        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
        # 4. Recursive Advantage
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
    # 5. Returns
    returns = advantages + values
- Bootstrap Value: We need the value of the very last state (the `next_value` in the code) to kickstart the backward calculation.
- Reverse Loop: We iterate backwards from the end of the episode to the beginning to allow rewards to “flow” back in time.
- Delta Calculation: We compute the TD error $\delta_t$. The `nextnonterminal` term ensures we don’t look past the end of an episode (if the game ended, the next value is 0).
- Recursive Advantage: We calculate the GAE using `lastgaelam` to store the advantage from the previous step (which is $\hat{A}_{t+1}$, since we are going backwards).
- Returns: Finally, we compute the target returns ($\hat{A}_t + V(s_t)$), which we will use to train the Value network.
3. Optimization (PPO Update)
Finally, we use the collected experience to update the neural network. We loop through the data multiple times (update_epochs) and calculate the Total Loss (Policy + Value - Entropy) to update the weights.
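The code below assumes the rollout buffers have already been flattened into one large batch. A sketch of that preparation step, again CleanRL-style (names like `b_inds` and `clipfracs` match their use below; `np` is NumPy, and the exact reshape is an assumption based on the buffer shapes above):

# Flatten the (num_steps, num_envs, ...) rollout tensors into one batch dimension
b_obs        = obs.reshape((-1, 6, 84, 84))
b_logprobs   = logprobs.reshape(-1)
b_actions    = actions.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns    = returns.reshape(-1)
b_values     = values.reshape(-1)
b_inds = np.arange(args.batch_size)  # batch_size = num_steps * num_envs; shuffled each epoch
clipfracs = []                       # diagnostic: how often the ratio gets clipped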
# Optimizing the policy and value network
for epoch in range(args.update_epochs):
    np.random.shuffle(b_inds)
    for start in range(0, args.batch_size, args.minibatch_size):
        end = start + args.minibatch_size
        mb_inds = b_inds[start:end]
        # 1. Re-Evaluation
        _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
        logratio = newlogprob - b_logprobs[mb_inds]
        ratio = logratio.exp()
        # 2. Diagnostics (KL & Clipping)
        with torch.no_grad():
            old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
        mb_advantages = b_advantages[mb_inds]
        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
        # 3. Policy Loss Calculation
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()
        # 4. Value Loss Calculation
        newvalue = newvalue.view(-1)
        if args.clip_vloss:
            v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
            v_clipped = b_values[mb_inds] + torch.clamp(
                newvalue - b_values[mb_inds],
                -args.clip_coef,
                args.clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
        # 5. Final Optimization Step
        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()
    if args.target_kl is not None:
        if approx_kl > args.target_kl:
            break
- Re-Evaluation: We pass the batch of observations back through the network to get new probabilities and values. This is crucial because the policy changes slightly after every mini-batch update.
- Diagnostics (KL & Clipping): We calculate the approximate KL divergence to measure how much the policy has changed. The code uses the estimator $\hat{D}_{\mathrm{KL}} \approx \mathbb{E}_t\big[(r_t - 1) - \log r_t\big]$. If this change is too drastic (larger than `target_kl`), we stop the update early to preserve training stability. We also track `clipfracs` to see how often the PPO clipping limits are triggered.
- Policy Loss: We implement the Clipped Surrogate Objective here. `torch.clamp` handles the clipping range $[1-\epsilon,\, 1+\epsilon]$, and `torch.max` (since we are minimizing the negative objective) picks the pessimistic bound.
- Value Loss: We calculate the MSE between predicted values and actual returns. Note that we also clip the value updates (`v_clipped`) to prevent the critic from changing too drastically in one go.
- Final Optimization: We combine the three terms: `pg_loss` (Policy), `v_loss` (Value), and `entropy_loss`. We then perform standard backpropagation (`loss.backward()`) and an optimizer step (`optimizer.step()`).
Backend
The backend is built with FastAPI and serves as the bridge between the browser and the Reinforcement Learning model. Because the agent was trained on raw pixels from the Atari emulator, the backend must run the exact same environment to ensure valid inference.
Tech Stack
| Technology | Role | Description |
|---|---|---|
| FastAPI | API Framework | High-performance Python framework for handling WebSocket connections. |
| PettingZoo | Environment | The Multi-Agent Reinforcement Learning (MARL) environment wrapper for Atari Pong. |
| PyTorch | Inference | Runs the trained PPO agent to predict actions from game states. |
| OpenCV | Rendering | Processes raw pixel frames from the emulator into JPEG images for the frontend. |
| SuperSuit | Preprocessing | Applies the same frame skipping, resizing, and stacking used during training. |
| uv | Package Manager | Fast, secure, and user-friendly package manager for Python. |
| WebSocket | Communication | Enables real-time bidirectional communication between the browser and backend. |
| Docker | Containerization | Ensures consistent environments across development, testing, and production. |
Architecture: Server-Side Rendering
Unlike typical web games where the game logic runs in JavaScript on the client, this project uses Server-Side Rendering (SSR) for the game state.
The RL agent’s “brain” is a Convolutional Neural Network (CNN) trained on specific pixel patterns (84x84 grayscale, stacked 4 frames deep) produced by the ALE (Arcade Learning Environment). Replicating the exact physics and rendering quirks of the Atari 2600 in a browser-based JavaScript emulator is incredibly difficult and prone to “distribution shift”, where slight visual differences confuse the agent.
Instead, this application runs the actual PettingZoo environment on the server and streams the visual output to the client.
- Server: Runs the game loop, queries the AI for actions, renders the frame.
- Network: Sends the frame (Base64 JPEG) via WebSocket.
- Client: Displays the image and sends back user keystrokes.
Real-Time Streaming via WebSockets
To achieve the low-latency required for a playable Pong game, the application relies on WebSockets rather than standard HTTP requests.
A traditional HTTP model (Request → Response) would add too much overhead for streaming 15-30 frames per second. WebSockets enable real-time streaming (see the endpoint sketch after this list) by:
- Persistent Connection: The “handshake” happens once. After that, data flows freely without repeating header overhead.
- Full-Duplex Communication: We can push visual updates (Server → Client) at the exact same time the user pushes keystrokes (Client → Server).
- Event-Driven: The backend triggers a network send immediately after a frame renders, ensuring the client sees the game state as “live” as possible.
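As a rough illustration of how this wiring might look, here is a minimal FastAPI WebSocket endpoint sketch. It is not the project's actual endpoint: `receive_input` is a hypothetical coroutine that updates `state["human_action"]` from incoming messages, and `game_loop` is the loop shown later in this section:

import asyncio
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/ws/game")
async def game_endpoint(websocket: WebSocket):
    await websocket.accept()  # one-time handshake; the connection then stays open
    state["running"] = True
    try:
        # Full duplex: stream frames out while listening for keystrokes in parallel.
        # receive_input is a hypothetical reader that updates state["human_action"].
        await asyncio.gather(game_loop(), receive_input(websocket))
    finally:
        state["running"] = False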
Implementation Details
1. Environment Setup
The application uses the exact same SuperSuit wrappers as the training phase to ensure the agent sees what it expects.
def create_env():
    """Create Pong environment - using rgb_array for manual rendering"""
    env = pong_v3.parallel_env(render_mode="rgb_array")
    # Critical: Must match training preprocessing exactly
    env = ss.max_observation_v0(env, 2)                          # Maximize over 2 frames (flicker fix)
    env = ss.frame_skip_v0(env, 4)                               # Skip 4 frames (standard Atari)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)  # Clip rewards to [-1, 1]
    env = ss.color_reduction_v0(env, mode="B")                   # B&W
    env = ss.resize_v1(env, x_size=84, y_size=84)                # Downsample to 84x84
    env = ss.frame_stack_v1(env, 4)                              # Stack 4 frames
    return env
2. The Game Loop
The core of the backend is an asyncio loop handling the WebSocket connection. It manages the game state, synchronizes the Human and AI actions, and maintains a stable frame rate.
async def game_loop():
    env = create_env()
    observations, infos = env.reset()
    while state["running"]:
        if is_paused:
            await asyncio.sleep(0.1)
            continue
        actions = {}  # actions for both paddles this frame
        # 1. AI Inference (Right Paddle)
        if "first_0" in env.agents:
            obs_tensor = torch.from_numpy(observations["first_0"]).float().unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                action = agent.get_action(obs_tensor)
            actions["first_0"] = action.item()
        # 2. Human Input (Left Paddle)
        if "second_0" in env.agents:
            actions["second_0"] = state["human_action"]
        # 3. Environment Step
        observations, rewards, terminations, truncations, infos = env.step(actions)
        # 4. Render & Send
        frame = env.render()
        frame_base64 = encode_frame(frame)  # Helper to convert numpy -> jpg base64
        await websocket.send_json({
            "frame": f"data:image/jpeg;base64,{frame_base64}",
            "reward": float(rewards.get("second_0", 0)),
            "scores": episode_rewards
        })
        # Cap at ~15 FPS to match standard Atari play speed
        await asyncio.sleep(1 / 15)
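The `encode_frame` helper referenced above is not shown in the snippet. A plausible implementation using OpenCV and base64, consistent with the tech-stack table (an assumption, not the project's exact code):

import base64
import cv2

def encode_frame(frame):
    """Convert an RGB frame from env.render() into a base64-encoded JPEG string."""
    bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR channel order
    ok, buffer = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    return base64.b64encode(buffer).decode("utf-8")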