Gradient Descent Example
Problem
Explain one iteration of gradient descent for $f(x) = x^2$,
starting from $x_0 = 4$ with learning rate $\eta = 0.1$.
Step-by-Step Explanation
1. Compute the Gradient
$$f'(x) = 2x$$
At $x_0 = 4$:
$$f'(4) = 2 \cdot 4 = 8$$
2. Gradient Descent Update Rule
$$x_1 = x_0 - \eta f'(x_0)$$
Substitute values:
$$x_1 = 4 - 0.1 \cdot 8 = 3.2$$
3. Function Values
$$f(x_0) = 4^2 = 16, \qquad f(x_1) = 3.2^2 = 10.24$$
📌 The function value decreases, so the step moves us closer to the minimum at $x = 0$.
Final Numerical Result
$$x_1 = 3.2, \qquad f(x_1) = 10.24$$
Interpretation
- The gradient at $x_0 = 4$ is positive ($f'(4) = 8$), so it points away from the minimum
- Moving in the negative gradient direction reduces $f(x)$
- Repeating this process with $\eta = 0.1$ converges to the global minimum at $x = 0$
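The points above can be checked numerically. The sketch below assumes $f(x) = x^2$, $x_0 = 4$ and $\eta = 0.1$, as in this example, and compares a step against the gradient with a step along it. The names `descent_step` and `ascent_step` are illustrative only.

```python
def f(x):
    return x ** 2

def grad_f(x):
    return 2 * x

x0, eta = 4.0, 0.1  # values assumed from this example

descent_step = x0 - eta * grad_f(x0)  # move against the gradient
ascent_step = x0 + eta * grad_f(x0)   # move along the gradient, for contrast

print(f"gradient at x0       = {grad_f(x0):.2f}")
print(f"f(x0)                = {f(x0):.2f}")
print(f"f after descent step = {f(descent_step):.2f}")
print(f"f after ascent step  = {f(ascent_step):.2f}")
```

The descent step lowers the function value while the ascent step raises it, which is why the update rule subtracts the gradient.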
Python Code and Visualization
```python
import numpy as np
import matplotlib.pyplot as plt

# Objective function and its derivative
def f(x):
    return x**2

def grad_f(x):
    return 2*x

# Starting point and learning rate
x0 = 4
eta = 0.1

# One gradient descent update
x1 = x0 - eta * grad_f(x0)

# Plot the curve with the points before and after the step
x_vals = np.linspace(-5, 5, 400)
y_vals = f(x_vals)
plt.figure()
plt.plot(x_vals, y_vals)
plt.scatter([x0, x1], [f(x0), f(x1)])
plt.plot([x0, x1], [f(x0), f(x1)])
plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("One Iteration of Gradient Descent for f(x) = x^2")
plt.show()
```
Problem
Minimize $f(x) = x^2$
using gradient descent, starting from $x_0 = 4$ with learning rate $\eta = 0.1$.
Step 1: Mathematical Setup
Gradient
$$f'(x) = 2x$$
Update Rule
$$x_{k+1} = x_k - \eta f'(x_k) = x_k - 0.1 \cdot 2x_k = 0.8\,x_k$$
📌 Each iteration multiplies the current value by 0.8, moving it closer to zero.
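Because every update multiplies the iterate by the same factor, the iterates should follow the closed form $x_k = 0.8^k x_0$. The sketch below checks this against the iterative update, assuming $f(x) = x^2$, $x_0 = 4$ and $\eta = 0.1$. The helper name `closed_form` is hypothetical.

```python
import math

eta = 0.1
x0 = 4.0
factor = 1 - 2 * eta  # 0.8, since x - eta * 2x = (1 - 2 * eta) * x

def closed_form(k):
    """Iterate after k updates, assuming f(x) = x**2."""
    return factor ** k * x0

x = x0
for k in range(1, 11):
    x = x - eta * 2 * x  # one gradient descent update
    assert math.isclose(x, closed_form(k), rel_tol=1e-12)

print(f"x_10 iterative   = {x:.6f}")
print(f"x_10 closed form = {closed_form(10):.6f}")
```

The same factor also suggests why the step size matters: the iterates shrink only while $|1 - 2\eta| < 1$.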
Step 2: Iterations (First Few)
| Iteration $k$ | $x_k$ | $f(x_k)$ |
|---|---|---|
| 0 | 4.0000 | 16.0000 |
| 1 | 3.2000 | 10.2400 |
| 2 | 2.5600 | 6.5536 |
| 3 | 2.0480 | 4.1943 |
| 4 | 1.6384 | 2.6844 |
| ⋮ | ⋮ | ⋮ |
| $k \to \infty$ | $\to 0$ | $\to 0$ |
The sequence converges to the global minimum at $x = 0$, where $f(x) = 0$.
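The first rows of the table can be reproduced with a short loop. This is a minimal sketch assuming $f(x) = x^2$, $x_0 = 4$ and $\eta = 0.1$; the list name `rows` is illustrative.

```python
eta = 0.1
x = 4.0

rows = []
for k in range(5):
    rows.append((k, x, x ** 2))  # record iteration, x_k, f(x_k)
    x = x - eta * 2 * x          # gradient descent update with f'(x) = 2x

for k, x_k, f_k in rows:
    print(f"| {k} | {x_k:.4f} | {f_k:.4f} |")
```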
Step 3: Stopping Criterion
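We stop when:
$$|f'(x_k)| < \varepsilon$$
with tolerance $\varepsilon = 10^{-4}$, or when the maximum number of iterations (50) is reached.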
```python
import numpy as np
import matplotlib.pyplot as plt

# Define function and gradient
def f(x):
    return x**2

def grad_f(x):
    return 2*x

# Parameters
eta = 0.1         # learning rate
x0 = 4.0          # starting point
tolerance = 1e-4  # stop once |f'(x)| falls below this value
max_iter = 50     # upper bound on the number of updates

# Gradient Descent Iterations
x_vals_iter = [x0]
x = x0
for i in range(max_iter):
    grad = grad_f(x)
    if abs(grad) < tolerance:
        break
    x = x - eta * grad
    x_vals_iter.append(x)

# Prepare plot
x_plot = np.linspace(-5, 5, 400)
y_plot = f(x_plot)
plt.figure()
plt.plot(x_plot, y_plot)
plt.scatter(x_vals_iter, [f(xi) for xi in x_vals_iter])
plt.plot(x_vals_iter, [f(xi) for xi in x_vals_iter])
plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("Gradient Descent Iterations for f(x) = x^2")
plt.show()
```
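It can be useful to know whether a run stopped because the gradient fell below the tolerance or because it hit the iteration cap. The sketch below wraps the same loop in a hypothetical helper, `gradient_descent`, that reports both; it assumes $f(x) = x^2$, so the gradient is $2x$. With $x_0 = 4$, $\eta = 0.1$ and tolerance $10^{-4}$, the condition $|2 \cdot 0.8^k \cdot 4| < 10^{-4}$ first holds at $k = 51$, so a cap of 50 updates is reached just before the tolerance test would pass.

```python
def gradient_descent(x0, eta, tolerance, max_iter):
    """Return (x, updates, converged) for f(x) = x**2 (hypothetical helper)."""
    x = x0
    for updates in range(max_iter):
        grad = 2 * x
        if abs(grad) < tolerance:
            return x, updates, True  # gradient small enough: stop early
        x = x - eta * grad
    # Iteration cap reached: report whether the tolerance happens to hold
    return x, max_iter, abs(2 * x) < tolerance

x_a, n_a, ok_a = gradient_descent(4.0, 0.1, 1e-4, max_iter=50)
x_b, n_b, ok_b = gradient_descent(4.0, 0.1, 1e-4, max_iter=100)

print(n_a, ok_a)  # prints: 50 False
print(n_b, ok_b)  # prints: 51 True
```

Returning the `converged` flag alongside the iterate makes it explicit which of the two stopping conditions ended the run.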