Backpropagation: The Heart of Neural Network Training
Backpropagation is the algorithm that makes deep learning possible. Let's understand how it works mathematically and implement it from scratch.
The Forward Pass
Consider a simple 3-layer neural network (input, hidden, and output layers). The forward pass can be described as:

$$z^{[1]} = W^{[1]} x + b^{[1]}, \qquad a^{[1]} = \sigma\left(z^{[1]}\right)$$

$$z^{[2]} = W^{[2]} a^{[1]} + b^{[2]}, \qquad a^{[2]} = \sigma\left(z^{[2]}\right)$$

where $\sigma$ is the activation function (e.g., sigmoid, ReLU).
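To keep the matrix dimensions straight, here is a minimal sketch of the shapes involved, assuming $n_x$ inputs, $n_h$ hidden units, and $n_y$ outputs (these symbols are introduced here for illustration; they mirror input_size, hidden_size, and output_size in the implementation below):

$$x \in \mathbb{R}^{n_x}, \quad W^{[1]} \in \mathbb{R}^{n_h \times n_x}, \quad b^{[1]} \in \mathbb{R}^{n_h}, \quad W^{[2]} \in \mathbb{R}^{n_y \times n_h}, \quad b^{[2]} \in \mathbb{R}^{n_y}$$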
The Loss Function
For binary classification, we use the cross-entropy loss:

$$L = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log a^{[2](i)} + \left(1 - y^{(i)}\right) \log\left(1 - a^{[2](i)}\right) \right]$$

where $m$ is the number of training examples.
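As a quick sanity check on the sign and scale of this loss, consider a single example (the numbers are made up for illustration): a confident correct prediction is penalized lightly, a confident wrong one heavily.

$$y = 1,\; a^{[2]} = 0.9 \;\Rightarrow\; L = -\log 0.9 \approx 0.105, \qquad y = 1,\; a^{[2]} = 0.1 \;\Rightarrow\; L = -\log 0.1 \approx 2.303$$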
Backpropagation Equations
The beauty of backpropagation lies in the chain rule. For the output layer:

$$\delta^{[2]} = \frac{\partial L}{\partial z^{[2]}} = a^{[2]} - y$$

For the hidden layer:

$$\delta^{[1]} = \frac{\partial L}{\partial z^{[1]}} = \left(W^{[2]}\right)^T \delta^{[2]} \odot \sigma'\!\left(z^{[1]}\right)$$

where $\odot$ denotes element-wise multiplication.
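One step worth spelling out is why the output-layer error takes such a simple form. With a sigmoid output and the cross-entropy loss above, the chain rule gives (for a single example):

$$\frac{\partial L}{\partial a^{[2]}} = -\frac{y}{a^{[2]}} + \frac{1-y}{1-a^{[2]}}, \qquad \frac{\partial a^{[2]}}{\partial z^{[2]}} = a^{[2]}\left(1 - a^{[2]}\right)$$

$$\delta^{[2]} = \frac{\partial L}{\partial a^{[2]}} \cdot \frac{\partial a^{[2]}}{\partial z^{[2]}} = a^{[2]} - y$$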
Weight and Bias Updates
The gradients for weights and biases are:

$$\frac{\partial L}{\partial W^{[2]}} = \frac{1}{m}\, \delta^{[2]} \left(a^{[1]}\right)^T, \qquad \frac{\partial L}{\partial b^{[2]}} = \frac{1}{m} \sum_{i=1}^{m} \delta^{[2](i)}$$

$$\frac{\partial L}{\partial W^{[1]}} = \frac{1}{m}\, \delta^{[1]} X^T, \qquad \frac{\partial L}{\partial b^{[1]}} = \frac{1}{m} \sum_{i=1}^{m} \delta^{[1](i)}$$
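A dimension check is a quick way to catch transposition mistakes in these formulas. Using the shapes introduced earlier and a batch of $m$ examples (so activations and errors carry an extra $m$ dimension), the first gradient works out as:

$$\delta^{[2]} \in \mathbb{R}^{n_y \times m}, \quad \left(a^{[1]}\right)^T \in \mathbb{R}^{m \times n_h} \;\Rightarrow\; \frac{\partial L}{\partial W^{[2]}} = \frac{1}{m}\, \delta^{[2]} \left(a^{[1]}\right)^T \in \mathbb{R}^{n_y \times n_h},$$

which matches the shape of $W^{[2]}$, as it must.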
Implementation
```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-np.clip(z, -500, 500)))

def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1 - s)

class NeuralNetwork:
    def __init__(self, input_size, hidden_size, output_size):
        # Initialize weights randomly
        self.W1 = np.random.randn(hidden_size, input_size) * 0.01
        self.b1 = np.zeros((hidden_size, 1))
        self.W2 = np.random.randn(output_size, hidden_size) * 0.01
        self.b2 = np.zeros((output_size, 1))

    def forward(self, X):
        self.z1 = self.W1 @ X + self.b1
        self.a1 = sigmoid(self.z1)
        self.z2 = self.W2 @ self.a1 + self.b2
        self.a2 = sigmoid(self.z2)
        return self.a2

    def backward(self, X, y, learning_rate=0.01):
        m = X.shape[1]

        # Backward propagation
        dz2 = self.a2 - y
        dW2 = (1/m) * dz2 @ self.a1.T
        db2 = (1/m) * np.sum(dz2, axis=1, keepdims=True)
        dz1 = (self.W2.T @ dz2) * sigmoid_derivative(self.z1)
        dW1 = (1/m) * dz1 @ X.T
        db1 = (1/m) * np.sum(dz1, axis=1, keepdims=True)

        # Update parameters
        self.W1 -= learning_rate * dW1
        self.b1 -= learning_rate * db1
        self.W2 -= learning_rate * dW2
        self.b2 -= learning_rate * db2
```
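To show how the class is meant to be driven, here is a minimal training-loop sketch on a made-up toy problem. The dataset, layer sizes, learning rate, and epoch count are all illustrative assumptions, not part of the original implementation:

```python
import numpy as np

# Toy binary-classification data: 2 features, 200 examples (columns are examples).
rng = np.random.default_rng(0)
X = rng.standard_normal((2, 200))
y = (X[0:1, :] * X[1:2, :] > 0).astype(float)  # label 1 when both features share a sign

net = NeuralNetwork(input_size=2, hidden_size=8, output_size=1)

for epoch in range(5000):
    a2 = net.forward(X)
    # Cross-entropy loss; eps guards against log(0)
    eps = 1e-8
    loss = -np.mean(y * np.log(a2 + eps) + (1 - y) * np.log(1 - a2 + eps))
    net.backward(X, y, learning_rate=0.5)
    if epoch % 1000 == 0:
        print(f"epoch {epoch}: loss = {loss:.4f}")

accuracy = np.mean((net.forward(X) > 0.5) == y)
print(f"training accuracy: {accuracy:.2f}")
```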
Key Insights
- Chain Rule: Backpropagation is just the chain rule applied systematically
- Efficiency: We compute gradients in reverse order, reusing computations
- Scalability: The algorithm scales to networks of any depth
Understanding these mathematical foundations is crucial for debugging neural networks and developing intuition about their behavior.
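One concrete debugging technique that follows directly from these equations is numerical gradient checking: compare the analytic gradient against a central-difference estimate of the loss. Below is a minimal sketch for the W2 gradient of the NeuralNetwork class above, assuming a single sigmoid output unit (the helper names, perturbation size, and error threshold are conventional choices for illustration, not from the original post):

```python
import numpy as np

def cross_entropy(a2, y, eps=1e-12):
    # Mean binary cross-entropy over the batch; eps guards against log(0).
    return -np.mean(y * np.log(a2 + eps) + (1 - y) * np.log(1 - a2 + eps))

def check_dW2(net, X, y, eps=1e-5):
    """Compare the analytic dL/dW2 with a central-difference estimate."""
    # Analytic gradient, recomputed exactly as in backward().
    m = X.shape[1]
    net.forward(X)
    dz2 = net.a2 - y
    dW2_analytic = (1 / m) * dz2 @ net.a1.T

    # Numerical gradient: perturb each entry of W2 and re-evaluate the loss.
    dW2_numeric = np.zeros_like(net.W2)
    for i in range(net.W2.shape[0]):
        for j in range(net.W2.shape[1]):
            old = net.W2[i, j]
            net.W2[i, j] = old + eps
            loss_plus = cross_entropy(net.forward(X), y)
            net.W2[i, j] = old - eps
            loss_minus = cross_entropy(net.forward(X), y)
            net.W2[i, j] = old
            dW2_numeric[i, j] = (loss_plus - loss_minus) / (2 * eps)

    # Relative error should be tiny (roughly 1e-7 or less) if backprop is correct.
    rel_error = np.linalg.norm(dW2_analytic - dW2_numeric) / (
        np.linalg.norm(dW2_analytic) + np.linalg.norm(dW2_numeric) + 1e-12
    )
    return rel_error
```

The same loop can be repeated for W1, b1, and b2; a large relative error on any parameter points directly at the corresponding line of the backward pass.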