Sheet 3.1: Gradient descent by hand
Author: Michael Franke
This short notebook optimizes a parameter by gradient descent without using PyTorch’s optimizer. The purpose is to demonstrate how vanilla GD works under the hood. We reuse the previous example of finding the MLE for a Gaussian mean.
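To see what gradient descent has to achieve here, recall the standard result for this model: the loss to minimize is the negative log-likelihood of the data under a unit-variance Gaussian, whose gradient with respect to the location parameter $\mu$ has a simple closed form:

$$
\mathcal{L}(\mu) = -\sum_{i=1}^{n} \log \mathcal{N}(x_i \mid \mu, 1),
\qquad
\frac{\partial \mathcal{L}}{\partial \mu} = \sum_{i=1}^{n} (\mu - x_i) = n \, (\mu - \bar{x}).
$$

Each gradient step $\mu \leftarrow \mu - \text{lr} \cdot n(\mu - \bar{x})$ therefore shrinks the distance to the empirical mean $\bar{x}$ by a factor of $1 - \text{lr} \cdot n$; with the learning rate and sample size used below, that factor is $0.9$ per step.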
Packages
We will need the usual packages.
import torch
import warnings
warnings.filterwarnings('ignore')
Training data
The training data are `nObs` samples from a standard normal.
nObs = 10000
trueLocation = 0 # mean of a normal
trueDist = torch.distributions.Normal(loc=trueLocation, scale=1.0)
trainData = trueDist.sample([nObs])
empirical_mean = torch.mean(trainData)
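As a quick sanity check, the sample mean and standard deviation should be close to the true location of 0 and scale of 1:

print(empirical_mean.item(), trainData.std().item())
# roughly 0.0 and 1.0; exact values depend on the random sample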
Training by manual gradient descent
We will actually train two parameters on the same data in parallel. `location` will be updated by hand; `location2` will be updated with PyTorch’s `SGD` optimizer. We will use the same learning rate for both.
location = torch.tensor(1.0, requires_grad=True)
location2 = torch.tensor(1.0, requires_grad=True)
learningRate = 0.00001
nTrainingSteps = 100
opt = torch.optim.SGD([location2], lr = learningRate)
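Before turning to the training loop, here is the general shape of a correct manual update, as a minimal sketch on a toy objective (minimizing (w - 3)**2, whose minimum is at w = 3): backpropagate, update inside `torch.no_grad()`, then reset the gradient, because `backward()` accumulates gradients rather than overwriting them.

# toy objective, for illustration only: minimize (w - 3)^2
w = torch.tensor(0.0, requires_grad=True)
for _ in range(100):
    loss_toy = (w - 3) ** 2
    loss_toy.backward()          # accumulates the gradient into w.grad
    with torch.no_grad():        # the update itself must not be tracked
        w -= 0.1 * w.grad
    w.grad.zero_()               # reset before the next backward() call
print(w.item())                  # close to 3.0

The same three-part pattern applies to the Gaussian example below.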
The training loop below first updates `location` by hand and then updates `location2` with the built-in `SGD` optimizer. Every 5 rounds we output the current values of `location` and `location2`, as well as the difference between them.
But, oh no! What’s this? There must be a bunch of mistakes in this code! See Exercise below.
print('\n%5s %15s %15s %15s' %
      ("step", "estimate", "estimate2", "difference"))

for i in range(nTrainingSteps):
    # manual computation
    prediction = torch.distributions.Normal(loc=location, scale=1.0)
    loss = -torch.sum(prediction.log_prob(trainData))
    loss.backward()
    with torch.no_grad():
        # we must embed this under 'torch.no_grad()' b/c we
        # do not want this update step to affect the gradients
        location -= learningRate * location.grad
        location.grad = torch.tensor(1.0)
    # using PyTorch optimizer
    prediction2 = torch.distributions.Normal(loc=location2, scale=1.0)
    loss2 = -torch.sum(prediction2.log_prob(trainData-1))
    loss2.backward()
    opt.step()
    opt.zero_grad()
    # print output
    if (i+1) % 5 == 0:
        print('\n%5s %-2.14f %-2.14f %2.14f' %
              (i + 1, location.item(), location2.item(),
               location.item() - location2.item()))
#+begin_example
 step        estimate       estimate2      difference
    5 0.00685740495101 -0.99302691221237 0.99988431716338
   10 0.00685102678835 -0.99303966760635 0.99989069439471
   15 0.00684725819156 -0.99304723739624 0.99989449558780
   20 0.00684503605589 -0.99305170774460 0.99989674380049
   25 0.00684372335672 -0.99305438995361 0.99989811331034
   30 0.00684294663370 -0.99305593967438 0.99989888630807
   35 0.00684248795733 -0.99305683374405 0.99989932170138
   40 0.00684221973643 -0.99305737018585 0.99989958992228
   45 0.00684205908328 -0.99305766820908 0.99989972729236
   50 0.00684196501970 -0.99305790662766 0.99989987164736
   55 0.00684191146865 -0.99305790662766 0.99989981809631
   60 0.00684187747538 -0.99305790662766 0.99989978410304
   65 0.00684185326099 -0.99305790662766 0.99989975988865
   70 0.00684183789417 -0.99305790662766 0.99989974452183
   75 0.00684183835983 -0.99305790662766 0.99989974498749
   80 0.00684183696285 -0.99305790662766 0.99989974359050
   85 0.00684183742851 -0.99305790662766 0.99989974405617
   90 0.00684183789417 -0.99305790662766 0.99989974452183
   95 0.00684183835983 -0.99305790662766 0.99989974498749
  100 0.00684183696285 -0.99305790662766 0.99989974359050
#+end_example
Exercise 3.1.1: Understand vanilla gradient descent
Find and correct all mistakes in this code block. When you are done, the parameters should show no difference at any update step, and they should both converge to the empirical mean.