Sheet 3.2: Optimizing an RSA model

Author: Michael Franke

Here we explore how to use PyTorch to find optimized values for the parameters of a vanilla RSA model for reference games. This serves two purposes: (i) it gives us a chance to practice the basics of parameter optimization in PyTorch; and (ii) we learn to think about models as objects that can (and must!) be critically tested with respect to their predictive ability.

To fit a vanilla RSA model, we use data from Qing & Franke (2015). A Bayesian data analysis for this data set and model setup is provided in this chapter of problang.org.

Packages

We need to import the `torch` package for the main functionality. To have a convenient handle, we import the `torch.nn.functional` module under the alias `F`. We use it to access the normalization function for tensors, `F.normalize`. We use the `warnings` package to suppress all warning messages in the notebook.

import torch
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')
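
Since `F.normalize` is used throughout this sheet, here is a minimal sketch of what it does with `p = 1`: it divides each slice along `dim` by the sum of its absolute values, so non-negative rows (or columns) become probability vectors. The tensor `counts` is a made-up example and not part of the data used below.

```python
import torch
import torch.nn.functional as F

# hypothetical counts, only for illustration
counts = torch.tensor([[1.0, 3.0],
                       [2.0, 2.0]])

# p = 1 divides by the sum of absolute values; dim = 1 normalizes each row
row_normalized = F.normalize(counts, p=1, dim=1)
# dim = 0 normalizes each column instead
col_normalized = F.normalize(counts, p=1, dim=0)

print(row_normalized)  # each row sums to 1
print(col_normalized)  # each column sums to 1
```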

Context model

The context model for the reference game is the same as we used before (in Sheet 1.1).

##################################################
## context model
##################################################

object_names     = ['blue_circle', 'green_square', 'blue_square']
utterance_names  = ['blue', 'circle', 'green', 'square']
semantic_meaning = torch.tensor(
    # blue circle, green square, blue square
    [[1, 0, 1],  # blue
     [1, 0, 0],  # circle
     [0, 1, 0],  # green
     [0, 1, 1]],  # square
    dtype = torch.float32
    )
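
To make the layout of the matrix concrete, the following sketch reads off, for each utterance (row), the objects (columns) it is true of. The dictionary name `extensions` is introduced here only for illustration.

```python
import torch

object_names    = ['blue_circle', 'green_square', 'blue_square']
utterance_names = ['blue', 'circle', 'green', 'square']
semantic_meaning = torch.tensor(
    [[1, 0, 1],   # blue
     [1, 0, 0],   # circle
     [0, 1, 0],   # green
     [0, 1, 1]],  # square
    dtype=torch.float32)

# for each utterance (row), collect the objects (columns) it is true of
extensions = {
    utt: [obj for obj, truth in zip(object_names, row.tolist()) if truth == 1.0]
    for utt, row in zip(utterance_names, semantic_meaning)
}
print(extensions)
```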

The empirical data

We use empirical data from Qing & Franke (2015). There were three tasks: (i) speaker production choice, (ii) listener interpretation choice, and (iii) salience prior elicitation. All three were forced-choice tasks, in which participants had to select a single option from a small list of options.

In the speaker production task, participants were presented with the three referents. They were told which object they should refer to. They selected one option from the list of available utterances.

In the listener interpretation task, participants were presented with the three referents and an utterance. They selected the object that they thought the speaker meant to refer to with that utterance.

In the salience prior elicitation task, participants again saw all three referents. They were told that the speaker wanted to refer to one of these objects with a word in a language they did not know. Again, they were asked to select the object they thought the speaker wanted to refer to. Since this task removes any basis for reasoning about semantic meaning, it is argued to provide a salience baseline: a measure of which object is a likely topic of conversation.

We use the data from the salience prior condition to feed into the pragmatic listener model. The data from the speaker production and the listener interpretation tasks is our training data, i.e., what we want to explain.

##################################################
## data to fit
##################################################

salience_prior = F.normalize(torch.tensor([71,139,30],
                                          dtype = torch.float32),
                             p = 1, dim = 0)

# matrix of number of utterance choices for each state
# (rows: objects, columns: utterances)
production_data = torch.tensor([[9, 135, 0, 0],
                                [0, 0, 119, 25],
                                [63, 0, 0, 81]])

# matrix of number of object choices for each ambiguous utterance
# (rows: utterances, columns: objects)
interpretation_data = torch.tensor([[66, 0, 115],   # "blue"
                                    [0, 117, 62]])  # "square"
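
As a quick sanity check on these counts (a sketch, with the raw salience counts stored under the assumed name `salience_counts`), we can compute the number of observations per condition. These row totals are the total counts passed to the multinomial distributions in the optimization section below.

```python
import torch

production_data = torch.tensor([[9, 135, 0, 0],
                                [0, 0, 119, 25],
                                [63, 0, 0, 81]])
interpretation_data = torch.tensor([[66, 0, 115],
                                    [0, 117, 62]])
salience_counts = torch.tensor([71, 139, 30])

# number of observations per row (one row = one experimental condition)
production_totals     = production_data.sum(dim=1)      # per referent
interpretation_totals = interpretation_data.sum(dim=1)  # per utterance
print(production_totals, interpretation_totals)

# L1-normalizing the salience counts is the same as dividing by their total
salience_by_hand = salience_counts / salience_counts.sum()
print(salience_by_hand)
```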

The RSA model (in PyTorch)

Here is an implementation of the vanilla RSA model in PyTorch.

##################################################
## RSA model (forward pass)
##################################################

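def RSA(alpha, cost_adjectives):
    # utterance costs: only the adjectives 'blue' and 'green' are costly
    costs = torch.tensor([1.0, 0, 1.0, 0]) * cost_adjectives
    # literal listener: P(object | utterance), rows are utterances
    literal_listener   = F.normalize(semantic_meaning, p = 1, dim = 1)
    # pragmatic speaker: P(utterance | object), rows are objects
    pragmatic_speaker  = F.normalize(torch.t(literal_listener)**alpha *
                                     torch.exp(-alpha * costs), p = 1, dim = 1)
    # pragmatic listener: P(object | utterance), weighted by the salience prior
    pragmatic_listener = F.normalize(torch.t(pragmatic_speaker) * salience_prior, p = 1, dim = 1)
    return {'speaker': pragmatic_speaker, 'listener': pragmatic_listener}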

print("speaker predictions:\n", RSA(1, 1.6)['speaker'])
speaker predictions:
 tensor([[0.0917, 0.9083, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2876, 0.7124],
        [0.1680, 0.0000, 0.0000, 0.8320]])
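
Written out, the speaker rule computed by the function above is \(P_{S}(u \mid s) \propto P_{lit}(s \mid u)^{\alpha} \ \exp(-\alpha \ C(u))\), and the listener rule is \(P_{L}(s \mid u) \propto P_{sal}(s) \ P_{S}(u \mid s)\). The symbols \(P_{S}\), \(P_{L}\) and \(C(u)\) are notation chosen here for illustration. As a check by hand, the following sketch recomputes the first row of the speaker predictions (referent `blue_circle`) in plain Python, without PyTorch.

```python
import math

alpha, cost_adjectives = 1.0, 1.6

# literal listener probabilities of 'blue_circle' given each utterance
# (blue, circle, green, square): 'blue' is true of two objects, 'circle' of one
literal_blue_circle = [0.5, 1.0, 0.0, 0.0]
costs = [cost_adjectives, 0.0, cost_adjectives, 0.0]  # adjectives are costly

# unnormalized speaker scores: P_lit(s|u)^alpha * exp(-alpha * cost(u))
scores = [p**alpha * math.exp(-alpha * c)
          for p, c in zip(literal_blue_circle, costs)]
speaker_blue_circle = [s / sum(scores) for s in scores]

print([round(p, 4) for p in speaker_blue_circle])  # [0.0917, 0.9083, 0.0, 0.0]
```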

Parameters to optimize

The vanilla RSA model has two free parameters: the optimality parameter \(\alpha\) and a parameter for utterance cost, here restricted to a single number for the cost of an adjective (relative to a noun). Since we want to optimize the values of these parameters, we ask PyTorch to track gradients for them by setting `requires_grad=True`.

##################################################
## model parameters to fit
##################################################

alpha           = torch.tensor(1.0, requires_grad=True) # soft-max parameter
cost_adjectives = torch.tensor(0.0, requires_grad=True) # differential cost of 'adjectives'
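
To illustrate what `requires_grad=True` buys us, here is a minimal sketch on a toy function (the names `x` and `y` are hypothetical and unrelated to the RSA model): after calling `backward()`, the derivative with respect to the tracked tensor is available in its `.grad` attribute.

```python
import torch

# a leaf tensor that PyTorch should track for differentiation
x = torch.tensor(3.0, requires_grad=True)

y = x**2 + 2 * x   # any differentiable function of x
y.backward()       # computes dy/dx and stores it in x.grad

print(x.grad)      # dy/dx = 2x + 2, which is 8 at x = 3
```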

Optimization

To optimize the model parameters with stochastic gradient descent, we first instantiate an optimizer object, which we tell about the parameters to optimize. Then we iterate the training cycle: each time we call the RSA model (feed-forward pass) with the current parameter values, compute the negative log-likelihood of the data as the loss, backpropagate to obtain gradients, and let the optimizer update the parameters.
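
Before the full loop, the bare pattern of the training cycle can be sketched on a toy problem. The parameter `theta` and the quadratic loss are assumptions made purely for illustration; only the sequence of calls carries over to the RSA case.

```python
import torch

theta = torch.tensor(0.0, requires_grad=True)  # hypothetical toy parameter
opt = torch.optim.SGD([theta], lr=0.1)

for step in range(100):
    loss = (theta - 2.0)**2   # forward pass: toy loss with minimum at theta = 2
    loss.backward()           # compute d(loss)/d(theta)
    opt.step()                # move theta against the gradient
    opt.zero_grad()           # reset gradients for the next iteration

print(theta.item())  # close to 2.0
```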

##################################################
## optimization
##################################################

opt = torch.optim.SGD([alpha, cost_adjectives], lr = 0.0001)

# output header
print('\n%5s %24s %15s %15s' %
      ("step", "loss", "alpha", "cost") )

for i in range(4000):

    RSA_prediction      = RSA(alpha, cost_adjectives)
    speaker_pred        = RSA_prediction['speaker']
    Multinomial_speaker = torch.distributions.multinomial.Multinomial(144, probs = speaker_pred)
    logProbs_speaker    = Multinomial_speaker.log_prob(production_data)

    listener_pred          = RSA_prediction['listener']
    # row 0 of the listener predictions belongs to utterance 'blue'
    Multinomial_listener_0 = torch.distributions.multinomial.Multinomial(181, probs = listener_pred[0,])
    logProbs_listener_0    = Multinomial_listener_0.log_prob(interpretation_data[0,])
    # row 3 of the listener predictions belongs to utterance 'square'
    Multinomial_listener_1 = torch.distributions.multinomial.Multinomial(179, probs = listener_pred[3,])
    logProbs_listener_1    = Multinomial_listener_1.log_prob(interpretation_data[1,])

    loss = -torch.sum(logProbs_speaker) - logProbs_listener_0 - logProbs_listener_1

    loss.backward()

    if (i+1) % 250 == 0:
        print('%5d %24.5f %15.5f %15.5f' %
              (i + 1, loss.item(), alpha.item(),
               cost_adjectives.item()) )

    opt.step()
    opt.zero_grad()

 step                     loss           alpha            cost
  250                 21.74205         2.12154         0.17193
  500                 16.10578         2.47786         0.15869
  750                 15.55774         2.58906         0.15650
 1000                 15.50400         2.62389         0.15597
 1250                 15.49873         2.63481         0.15582
 1500                 15.49818         2.63825         0.15577
 1750                 15.49814         2.63933         0.15576
 2000                 15.49815         2.63966         0.15575
 2250                 15.49813         2.63977         0.15575
 2500                 15.49815         2.63979         0.15575
 2750                 15.49815         2.63979         0.15575
 3000                 15.49815         2.63979         0.15575
 3250                 15.49815         2.63979         0.15575
 3500                 15.49815         2.63979         0.15575
 3750                 15.49815         2.63979         0.15575
 4000                 15.49815         2.63979         0.15575
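
The loss reported in each step is the negative sum of multinomial log-probabilities. To make that quantity less of a black box, here is a sketch that computes one such log-probability by hand in plain Python. The helper `multinomial_log_prob` and the probabilities in `probs_toy` are hypothetical: they are not the fitted model predictions.

```python
import math

def multinomial_log_prob(counts, probs):
    """Log-probability of observed counts under a multinomial distribution."""
    n = sum(counts)
    # log of the multinomial coefficient n! / (k_1! * ... * k_m!)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
    # sum of k_i * log(p_i); categories with zero count contribute nothing
    log_powers = sum(k * math.log(p) for k, p in zip(counts, probs) if k > 0)
    return log_coef + log_powers

counts_blue = [66, 0, 115]     # observed choices for utterance "blue"
probs_toy   = [0.4, 0.0, 0.6]  # hypothetical predictions, not fitted values

log_prob_blue = multinomial_log_prob(counts_blue, probs_toy)
print(log_prob_blue)
```

Predictions that match the observed proportions more closely yield a higher log-probability and hence a lower loss.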

Exercise 2.3.1: Comparing model variants

  1. We have so far implemented the literal listener as \(P_{lit}(s \mid u) \propto L_{ij}\). But some RSA models also include the salience prior in the literal listener model, whereas we have so far used it only in the pragmatic listener. Under this alternative construction the literal listener would be defined as \(P_{lit}(s \mid u) \propto P_{sal}(s) \ L_{ij}\). Change the `RSA` function to implement this alternative definition. (Hint: you only need to add this string somewhere in the code: `* salience_prior`.) Run the model otherwise as is. Inspect the output of the optimization loop. Use this information to draw conclusions about which of the two model variants is a better predictor of the data.

  2. Go back to the original model. We now want to address whether we actually need the cost parameter. Run the original model (with a literal listener that does not use the salience prior), but optimize only the \(\alpha\) parameter. The cost parameter should be initialized to 0 and stay fixed at that value. Fit the model and use the output to draw conclusions about which model is better: with or without a flexible cost parameter.
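
As a hint for the second exercise, the mechanism for holding a parameter fixed can be sketched on a toy loss. The names `a` and `b` and the loss itself are hypothetical and not part of the RSA model. The point is that an optimizer only ever updates the tensors it was given.

```python
import torch

a = torch.tensor(0.0, requires_grad=True)  # hypothetical parameter to optimize
b = torch.tensor(0.0)                      # hypothetical parameter held fixed

# only tensors handed to the optimizer are ever updated
opt = torch.optim.SGD([a], lr=0.1)

for step in range(200):
    loss = (a - 1.0)**2 + (b - 1.0)**2  # toy loss, not the RSA likelihood
    loss.backward()
    opt.step()
    opt.zero_grad()

print(a.item(), b.item())  # a moves toward 1.0, b stays at 0.0
```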

References

Qing, C., & Franke, M. (2015). Variations on a Bayesian theme: Comparing Bayesian models of referential reasoning. In H. Zeevat, & H. Schmitz (Eds.), Bayesian Natural Language Semantics and Pragmatics (pp. 201–220). Berlin: Springer.