Sheet 6.3: Decoding strategies#
Given a (black-box) function that gives us a next-word probability, how do we use it to generate natural-sounding text?
This tutorial explores a number of options, using the GPT-2 distribution provided by 🤗's `transformers` package.
The tutorial closely follows this blog post: https://huggingface.co/blog/how-to-generate
We will look at the following decoding strategies (in this order):
pure sampling
soft-max sampling
greedy sampling
beam search
top-\(k\) sampling
top-\(p\) sampling
Preparation (installs, imports, defs)#
!pip install sentencepiece
!pip install datasets
!pip install transformers
# import relevant packages
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
# convenience function for nicer output
def pretty_print(s):
    print("Output:\n" + 100 * '-')
    print(tokenizer.decode(s, skip_special_tokens=True))
# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='pt')
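As a quick sanity check (not part of the original blog post), we can inspect what the model is actually conditioned on: the token IDs and the corresponding sub-word strings.
# inspect the encoded context: token IDs and their sub-word strings
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))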
Pure sampling approach#
In a pure sampling approach, we just sample each next word with exactly the probability assigned to it by the LM. Notice that this process, therefore, is non-deterministic. We can force replicable results, though, by setting a seed.
# set a seed for reproducibility (if you want)
# torch.manual_seed(1996)
# use the function `model.generate` from the `transformers` package to sample by
# setting `do_sample=True` and knocking out top-k filtering (see below)
sample_output = model.generate(
    input_ids,       # context to continue
    do_sample=True,  # use sampling (not beam search, see below)
    max_length=50,   # generate at most 50 tokens (including the input)
    top_k=0          # top_k=0 deactivates top-k filtering (sample from the full distribution)
)
pretty_print(sample_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog Kass; we have a good book out, and a cab so she can hang out as she waits for us: A 404 - Love Noni "Perhaps the best way to describe murder in the arts is to say
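To make explicit what `generate` is doing under the hood here, the following is a minimal by-hand sketch of pure sampling (our own illustration, not from the blog post): at each step we compute the model's next-word distribution and draw from it with exactly the probabilities the model assigns.
# minimal by-hand sketch of pure sampling (10 steps for brevity)
generated = input_ids
for _ in range(10):
    with torch.no_grad():
        logits = model(generated).logits
    # next-word distribution at the last position
    probs = torch.softmax(logits[:, -1, :], dim=-1)
    # sample the next token with exactly its model probability
    next_token = torch.multinomial(probs, num_samples=1)
    generated = torch.cat([generated, next_token], dim=-1)
pretty_print(generated[0])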
Exercise 6.3.1:
How good is this production? Is it grammatical? Locally meaningful, globally meaningful?
[optional] Try sampling 100 single next-words for your initial sentence fragment. (Remember not to set a seed.)
Soft-max sampling#
In soft-max sampling, the probability of sampling word \(w_i\) is \(P_{\text{sample}}(w_i \mid w_{1:i-1}) \propto \exp (\frac{1}{\tau} P_{\text{M}}(w_i \mid w_{1:i-1}) )\), where \(\tau\) is a temperature parameter.
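To get a feeling for what \(\tau\) does, here is a toy illustration of the formula above (our own sketch, with made-up probabilities): low temperatures sharpen the distribution, high temperatures flatten it towards uniformity. (Note that 🤗's `temperature` parameter divides the logits rather than the probabilities, but the qualitative sharpening/flattening effect of \(\tau\) is the same.)
# toy next-word distribution (made up for illustration)
p = torch.tensor([0.5, 0.3, 0.15, 0.05])
for tau in [0.1, 1.0, 5.0]:
    # apply the formula from above: P_sample(w) ∝ exp(P_M(w) / tau)
    q = torch.exp(p / tau)
    q = q / q.sum()
    print(f"tau = {tau}: {[round(x, 3) for x in q.tolist()]}")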
# same as before, but with the `temperature` parameter
SM_sample_output = model.generate(
    input_ids,       # context to continue
    do_sample=True,  # use sampling (not beam search, see below)
    max_length=50,   # generate at most 50 tokens (including the input)
    top_k=0,         # deactivate top-k filtering (sample from the full distribution)
    temperature=0.7  # soft-max temperature parameter
)
pretty_print(SM_sample_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but in the end she is a bit stubborn and stubborn and she is not very happy with me anymore. This is the problem with the dog. I had her being more aggressive and more aggressive with me. I
Exercise 6.3.2:
How good is this production? Is it grammatical? Locally meaningful, globally meaningful?
Predict what will happen if you set \(\tau=5\). Try it out.
Greedy sampling#
In greedy sampling, we don’t actually sample but just take the most likely next word at every step. Greedy sampling is equivalent to soft-max sampling in the limit \(\tau \to 0\).
# greedy sampling is the default of the `model.generate` function
greedy_output = model.generate(input_ids, max_length=50)
pretty_print(greedy_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.
I'm not sure if I'll
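Again as a minimal by-hand sketch (our own illustration): greedy decoding just takes the argmax of the next-word distribution at each step.
# minimal by-hand sketch of greedy decoding (10 steps for brevity)
generated = input_ids
for _ in range(10):
    with torch.no_grad():
        logits = model(generated).logits
    # take the single most likely next token (no sampling)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)
pretty_print(generated[0])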
Exercise 6.3.3:
How good is this production? Is it grammatical? Locally meaningful, globally meaningful?
Is greedy sampling guaranteed to select the most likely sequence? Or can it be led astray?
Beam search#
In simplified terms, beam search is a parallel search procedure that keeps a number \(k\) of candidate sequences (beams) open at each choice point, dropping the least likely ones as it goes along. (There is actually no unanimity on what exactly beam search means for NLG.)
# `early_stopping=True` stops the search as soon as `num_beams` complete candidates are found
beam_output = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    early_stopping=True
)
pretty_print(beam_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.
I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll
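To peek at the parallel hypotheses that beam search keeps open, we can ask `generate` to return several beams via `num_return_sequences` (which must not exceed `num_beams`):
# return all 5 beams instead of only the best one
beam_outputs = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    num_return_sequences=5,  # must be <= num_beams
    early_stopping=True
)
for i, beam in enumerate(beam_outputs):
    print(f"--- beam {i} ---")
    print(tokenizer.decode(beam, skip_special_tokens=True))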
Exercise 6.3.4:
How good is this production? Is it grammatical? Locally meaningful, globally meaningful?
Try out the option
no_repeat_ngram_size=2
and see if it improves the results. This option suppresses the repetition of \(n\)-grams of the given size. Play around with the number \(n\) supplied.
Top-\(k\) sampling#
This sampling scheme looks at the \(k\) most likely next words, renormalizes their probabilities, and samples only from this restricted set:
top_k_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    top_k=50  # setting `top_k` restricts sampling to the 50 most likely next words
)
pretty_print(top_k_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but it's not for me to try to please a puppy."
Michele, who lives in Fort Meade, Ga., said she never really wanted to "lose the love" of her
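By hand, a single top-k sampling step looks roughly like this (our own sketch, assuming \(k = 50\)):
# minimal by-hand sketch of one top-k sampling step
with torch.no_grad():
    logits = model(input_ids).logits[:, -1, :]
k = 50
# keep only the k most likely next words ...
top_logits, top_indices = torch.topk(logits, k, dim=-1)
# ... renormalize over this restricted set, and sample from it
probs = torch.softmax(top_logits, dim=-1)
next_token = top_indices.gather(-1, torch.multinomial(probs, num_samples=1))
print(tokenizer.decode(next_token[0]))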
Exercise 6.3.5:
How good is this production? Is it grammatical? Locally meaningful, globally meaningful?
Top-\(p\) sampling#
Top-\(p\) sampling is similar to top-\(k\) sampling, but instead of restricting sampling to a fixed number \(k\) of most likely words, it restricts it to the smallest set of most likely words whose summed probability exceeds the threshold \(p\).
top_p_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    top_p=0.9,  # set the top-p parameter here
    top_k=0     # deactivate top-k filtering (otherwise the default top_k=50 also applies)
)
pretty_print(top_p_output[0])
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog and he's a really nice person. The second time we got to our house, my mother came over and brought a bucket of water. I'm very happy with it. She was just a little upset that I
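By hand, the nucleus (the set of words we sample from) can be computed like this (our own sketch, assuming \(p = 0.9\)):
# minimal by-hand sketch of one top-p sampling step
with torch.no_grad():
    logits = model(input_ids).logits[:, -1, :]
probs = torch.softmax(logits, dim=-1)
# sort words by probability and take the smallest set whose mass exceeds p
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative = torch.cumsum(sorted_probs, dim=-1)
cutoff = int((cumulative < 0.9).sum()) + 1  # number of words in the nucleus
# renormalize over the nucleus and sample from it
kept = sorted_probs[:, :cutoff] / sorted_probs[:, :cutoff].sum()
next_token = sorted_indices[:, :cutoff].gather(-1, torch.multinomial(kept, num_samples=1))
print(f"nucleus size: {cutoff}")
print(tokenizer.decode(next_token[0]))
Note that the nucleus size varies from step to step, which is precisely where top-\(p\) differs from top-\(k\).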
Exercise 6.3.6:
How good (grammatical, locally and globally coherent) is this output?
In which cases would the next-word predictions of top-\(k\) and top-\(p\) diverge quite a bit?
Exercise 6.3.7: Comparison of different decoding schemes.
Which of the decoding schemes included in this worksheet is a special case of which other decoding scheme(s)? E.g., X is a special case of Y if the behavior of X is obtained when we set certain parameters of Y to specific values.