Bayesian regression: theory & practice

02c: Categorical predictors (exercises)

Author

Michael Franke & Timo Roettger

Here is code to load (and if necessary, install) required packages, and to set some global options (for plotting and efficient fitting of Bayesian models).

# install packages from CRAN (unless installed)
pckgs_needed <- c(
  "tidyverse",
  "brms",
  "remotes",
  "tidybayes"
)
pckgs_installed <- installed.packages()[,"Package"]
pckgs_2_install <- pckgs_needed[!(pckgs_needed %in% pckgs_installed)]
if(length(pckgs_2_install)) {
  install.packages(pckgs_2_install)
} 

# install additional packages from GitHub (unless installed)
if (! "aida" %in% pckgs_installed) {
  remotes::install_github("michael-franke/aida-package")
}
if (! "faintr" %in% pckgs_installed) {
  remotes::install_github("michael-franke/faintr")
}
if (! "cspplot" %in% pckgs_installed) {
  remotes::install_github("CogSciPrag/cspplot")
}

# load the required packages
x <- lapply(pckgs_needed, library, character.only = TRUE)
library(aida)
library(faintr)
library(cspplot)

# these options help Stan run faster
options(mc.cores = parallel::detectCores())

# use the CSP-theme for plotting
theme_set(theme_csp())

# global color scheme from CSP
project_colors = cspplot::list_colors()[c(1,3,4,5,2,6:14),"hex", drop = TRUE]

# setting theme colors globally
scale_colour_discrete <- function(...) {
  scale_colour_manual(..., values = project_colors)
}
scale_fill_discrete <- function(...) {
   scale_fill_manual(..., values = project_colors)
}

# load the mouse-tracking data set from the aida package
dolphin <- aida::data_MT

Regression w/ multiple categorical predictors

We want to regress log RT against the full combination of categorical factors group, condition, and prototype_label.

log RT ~ group * condition * prototype_label

The research hypotheses we would like to investigate are:

  1. Typical trials are faster than atypical ones.
  2. CoM trials are slower than the other kinds of trials (straight and curved), both taken together and considered separately.
  3. ‘straight’ trials are faster than ‘curved’ trials.
  4. Click trials are slower than touch trials.

But for this to work (without at least mildly informative priors), we would need a sufficient number of observations in each cell. So, let’s check:

dolphin |>
  mutate(group = as_factor(group),
         condition = as_factor(condition),
         prototype_label = as_factor(prototype_label)) |>
  count(group, condition, prototype_label, .drop = FALSE) |>
  arrange(n)
# A tibble: 20 × 4
   group condition prototype_label     n
   <fct> <fct>     <fct>           <int>
 1 touch Atypical  dCoM2               0
 2 touch Typical   dCoM2               0
 3 touch Atypical  dCoM                9
 4 click Atypical  dCoM2              11
 5 click Typical   dCoM2              11
 6 touch Typical   dCoM               14
 7 touch Atypical  cCoM               21
 8 touch Typical   cCoM               31
 9 click Atypical  cCoM               32
10 click Atypical  curved             36
11 touch Atypical  curved             37
12 click Typical   cCoM               48
13 click Atypical  dCoM               50
14 click Typical   dCoM               52
15 touch Typical   curved             72
16 click Typical   curved             84
17 click Atypical  straight          189
18 touch Atypical  straight          263
19 click Typical   straight          494
20 touch Typical   straight          598

So, there are cells for which we have no observations at all. For simplicity, we therefore just lump all “change of mind”-type trajectories into one category:

dolphin_prepped <-
  dolphin |>
  mutate(
    prototype_label = case_when(
     prototype_label %in% c('curved', 'straight') ~ prototype_label,
     TRUE ~ 'CoM'
    ),
    prototype_label = factor(prototype_label,
                             levels = c('straight', 'curved', 'CoM')))

dolphin_prepped |>
  select(RT, prototype_label)
# A tibble: 2,052 × 2
      RT prototype_label
   <dbl> <fct>          
 1   950 straight       
 2  1251 straight       
 3   930 curved         
 4   690 curved         
 5   951 CoM            
 6  1079 CoM            
 7  1050 CoM            
 8   830 straight       
 9   700 straight       
10   810 straight       
# … with 2,042 more rows
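
To see what the formula with the full combination of factors asks the model to estimate, it can help to expand it into a design matrix. The following is a minimal sketch on a made-up grid with one row per cell (the object names `cells` and `X` are illustrative only); the factor levels mirror those of dolphin_prepped.

```r
# Toy grid with one row per cell; the level orders mirror dolphin_prepped
# (the grid itself is made up for illustration).
cells <- expand.grid(
  group           = factor(c("click", "touch")),
  condition       = factor(c("Atypical", "Typical")),
  prototype_label = factor(c("straight", "curved", "CoM"),
                           levels = c("straight", "curved", "CoM"))
)

# Dummy-coded design matrix implied by the three-way interaction formula
X <- model.matrix(~ group * condition * prototype_label, data = cells)

ncol(X)      # number of coefficients: one per cell, 2 * 2 * 3
colnames(X)  # intercept, main effects, two-way and three-way interactions
```

With as many coefficients as cells, every cell needs enough observations of its own, which is why the empty cells had to be dealt with first.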

Here is a plot of the data to be analyzed:

dolphin_prepped |>
  ggplot(aes(x = log(RT), fill = condition)) +
  geom_density(alpha = 0.4) +
  facet_grid(group ~ prototype_label)

Exercise 1

Exercise 1a

Use brm() to run a linear regression model for the data set dolphin_prepped and the formula:

log RT ~ group * condition * prototype_label

Set the prior for all population-level slope coefficients to a reasonable, weakly-informative but unbiased prior.

Show solution
fit <- brm(
  formula = log(RT) ~ group * condition * prototype_label,
  prior   = prior(student_t(1, 0, 3), class = "b"),
  data    = dolphin_prepped
  )
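
To get a feeling for how weakly informative the prior in this solution is, one can look at its central intervals. This is a small sketch that rescales quantiles of the standard t distribution; the variable names are illustrative.

```r
# Central intervals of a Student-t prior with df = 1, location 0, scale 3,
# obtained by rescaling quantiles of the standard t distribution.
prior_scale <- 3
probs <- c(0.025, 0.25, 0.75, 0.975)
prior_quantiles <- prior_scale * qt(probs, df = 1)
round(prior_quantiles, 2)

# An additive shift of b on the log scale multiplies RT by exp(b),
# so the central 50% of the prior spans these multiplicative factors:
exp(prior_quantiles[2:3])
```

Even the central 50% of the prior allows for very large multiplicative changes in RT, and the prior is symmetric around zero, so it does not favor either direction.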

Exercise 1b

Plot the posteriors for population-level slope coefficients using the tidybayes package in order to:

  1. determine which combination of factor levels is the default cell
  2. check which coefficients have 95% CIs that do not include zero
  3. try to use this latter information to address any of our research hypotheses (stated above)

Show solution
tidybayes::summarise_draws(fit)
tidybayes::gather_draws(fit, `b_.*`, regex = TRUE) |>
  filter(.variable != "b_Intercept") |>
  ggplot(aes(y = .variable, x = .value)) +
  tidybayes::stat_halfeye() +
  labs(x = "", y = "") +
  geom_vline(xintercept = 0, lty = "dashed")

# the default cell is for click-atypical-straight

# coefficients with 95% CIs that do not include zero are:
#   grouptouch, conditionTypical, prototype_labelCoM

# none of these give us direct information about our research hypotheses
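
Why do the individual coefficients not speak directly to the hypotheses? With dummy coding, each coefficient is a difference relative to the default cell, and the predicted mean of any other cell is a sum of several coefficients. The sketch below illustrates this with simulated stand-in data (not the dolphin data), two factors instead of three, and a least-squares fit via `lm()` in place of `brm()` to keep it fast.

```r
set.seed(1)

# Simulated stand-in data (not the dolphin data): 20 observations per cell
toy <- expand.grid(
  group     = factor(c("click", "touch")),
  condition = factor(c("Atypical", "Typical")),
  rep       = 1:20
)
toy$y <- rnorm(nrow(toy), mean = 7)

# Least-squares fit using R's default treatment (dummy) coding
b <- coef(lm(y ~ group * condition, data = toy))

# Predicted mean of the touch/Typical cell: intercept plus every coefficient
# whose dummy variables are all 1 in that cell
cell_touch_typical <- b[["(Intercept)"]] + b[["grouptouch"]] +
  b[["conditionTypical"]] + b[["grouptouch:conditionTypical"]]

# For this saturated model the sum matches the empirical cell mean
empirical <- mean(toy$y[toy$group == "touch" & toy$condition == "Typical"])
all.equal(cell_touch_typical, empirical)
```

A hypothesis such as "typical trials are faster than atypical ones" concerns averages over several cells, so it involves combinations of coefficients rather than any single one.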

Exercise 1c

Use the faintr package to get information relevant for the current research hypotheses. Interpret each result with respect to what we may conclude from it.

Show solution
# 1. Typical trials are faster than atypical ones.
# -> There is overwhelming evidence that this is true
#    (given the data and the model).

faintr::compare_groups(
  fit,
  lower  = condition == 'Typical',
  higher = condition == 'Atypical'
)

# 2. CoM trials are slower than the other kinds of trials
#    (straight and curved) together, and respectively.
# -> There is overwhelming evidence that this is true
#    (given the data and the model).

faintr::compare_groups(
  fit,
  lower  = prototype_label != 'CoM',
  higher = prototype_label == 'CoM'
)

faintr::compare_groups(
  fit,
  lower  = prototype_label == 'straight',
  higher = prototype_label == 'CoM'
)

# 3. 'straight' trials are faster than 'curved' trials.
# -> There is no evidence for this hypothesis
#    (given the data and the model).

faintr::compare_groups(
  fit,
  lower  = prototype_label == 'straight',
  higher = prototype_label == 'curved'
)

# 4. Click trials are slower than touch trials.
# -> There is overwhelming evidence that this is true
#    (given the data and the model).

faintr::compare_groups(
  fit,
  lower  = group == 'touch',
  higher = group == 'click'
)
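
Conceptually, a comparison of this kind can be thought of as operating on posterior draws of cell means. The following sketch is not faintr's implementation; it uses made-up draws for a two-factor design and assumes that the cells within each group are averaged without weighting.

```r
set.seed(2)
n_draws <- 4000

# Made-up posterior draws of cell means on the log-RT scale (one column per cell)
cell_draws <- data.frame(
  click_Atypical = rnorm(n_draws, 7.50, 0.03),
  click_Typical  = rnorm(n_draws, 7.40, 0.03),
  touch_Atypical = rnorm(n_draws, 7.30, 0.03),
  touch_Typical  = rnorm(n_draws, 7.25, 0.03)
)

# Average the cells belonging to each group, draw by draw (unweighted)
typical  <- rowMeans(cell_draws[, c("click_Typical", "touch_Typical")])
atypical <- rowMeans(cell_draws[, c("click_Atypical", "touch_Atypical")])

# Posterior of the difference and the share of draws that agree with
# 'Atypical is slower than Typical'
diff_draws <- atypical - typical
mean(diff_draws)
mean(diff_draws > 0)
```

A share of draws close to 1 is what the comments above describe as overwhelming evidence, given the data and the model.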

Regression w/ metric & categorical predictors

Exercise 2

Exercise 2a

Create a new dataframe that contains only the mean values of RT and MAD for each animal (exemplar), separately for correct and incorrect responses. Print out the head of the new dataframe.

Show solution
# aggregate
dolphin_agg <- dolphin |> 
  group_by(exemplar, correct) |> 
  dplyr::summarize(MAD = mean(MAD, na.rm = TRUE),
                   RT = mean(RT, na.rm = TRUE))
  
# let's have a look
head(dolphin_agg)

Exercise 2b

Run a linear regression using brms. MAD is the dependent variable (i.e. the measure) and both RT and correct are independent variables (MAD ~ RT + correct). Tip: the coefficients might be very small, so make sure the output is printed with enough decimal places.

Show solution
# specify the model 
model2 = brm(
  # model formula
  MAD ~ RT + correct, 
  # data
  data = dolphin_agg
  )

print(summary(model2), digits = 5)

Try to understand the coefficient table. There is one coefficient for RT and one coefficient for correct, which gives you the change in MAD from incorrect to correct responses.
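
One way to read the coefficient table is to write out the linear predictor by hand. The sketch below uses hypothetical coefficient values, chosen only for illustration and not taken from the fitted model; `predict_MAD` is an illustrative helper, not a brms function.

```r
# Hypothetical coefficient values, chosen only for illustration
b_Intercept <- 100
b_RT        <- 0.05
b_correct   <- -20

# Linear predictor of the model MAD ~ RT + correct (correct coded as 0/1)
predict_MAD <- function(RT, correct) {
  b_Intercept + b_RT * RT + b_correct * correct
}

RT_grid <- c(1000, 2000, 3000)
incorrect_line <- predict_MAD(RT_grid, correct = 0)
correct_line   <- predict_MAD(RT_grid, correct = 1)

# The two lines share the slope b_RT and differ by b_correct everywhere
correct_line - incorrect_line
```

Because the model has no interaction term, the two predicted lines are parallel: the coefficient for correct only shifts the intercept.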

Exercise 2c

Plot a scatter plot of MAD ~ RT and color code it for correct responses (Tip: Make sure that correct is treated as a factor and not a numeric vector). Draw two predicted lines into the scatterplot. One for correct responses (“lightblue”) and one for incorrect responses (“orange”).

Show solution
dolphin_agg$correct <- as.factor(as.character(dolphin_agg$correct))

# extract model parameters:
model_intercept <- summary(model2)$fixed[1,1]
model_RT <- summary(model2)$fixed[2,1]
model_correct <- summary(model2)$fixed[3,1]

# plot
ggplot(data = dolphin_agg, 
       aes(x = RT, 
           y = MAD,
           color = correct)) + 
  geom_abline(intercept = model_intercept, slope = model_RT, color = "orange", size  = 2) +
  geom_abline(intercept = model_intercept + model_correct , slope = model_RT, color = "lightblue",size  = 2) +
  geom_point(size = 3, alpha = 0.3)

Exercise 2d

Extract the posteriors for the coefficients of both RT and correct from the model output (use the spread_draws() function), and calculate their means and 67% credible intervals. Print out the head of the aggregated dataframe.

Show solution
# get posteriors for the relevant coefficients
posteriors2 <- model2 |>
  spread_draws(b_RT, b_correct) |>
  select(b_RT, b_correct) |> 
  gather(key = "parameter", value = "posterior")

# aggregate
posteriors2_agg <- posteriors2 |> 
  group_by(parameter) |> 
  summarise(mean_posterior = mean(posterior),
            `67lowerCrI` = HDInterval::hdi(posterior, credMass = 0.67)[1],
            `67higherCrI` = HDInterval::hdi(posterior, credMass = 0.67)[2]
            )

# print out
posteriors2_agg
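
The idea behind a highest-density interval can also be spelled out by hand: among all intervals that contain the requested share of draws, pick the shortest. The following is a sketch of that idea on made-up draws; `hdi_by_hand` is an illustrative helper and is not claimed to reproduce the HDInterval package exactly.

```r
set.seed(3)
draws <- rnorm(4000, mean = 0.05, sd = 0.02)  # stand-in for posterior draws

# Shortest interval containing a given share of the sorted draws
hdi_by_hand <- function(x, mass = 0.67) {
  x <- sort(x)
  n_in <- ceiling(mass * length(x))           # draws inside the interval
  starts <- seq_len(length(x) - n_in + 1)     # candidate lower endpoints
  widths <- x[starts + n_in - 1] - x[starts]  # width of each candidate
  best <- which.min(widths)
  c(lower = x[best], upper = x[best + n_in - 1])
}

interval <- hdi_by_hand(draws)
interval

# share of draws inside the interval, which should be close to 0.67
mean(draws >= interval["lower"] & draws <= interval["upper"])
```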

Exercise 2e

Plot the scatterplot from 2c and add regression lines for correct and incorrect responses based on 50 sampled tuples of posterior coefficient values.

Show solution
# sample 50 random numbers from 4000 samples
random_50 <- sample(1:4000, 50, replace = FALSE)
  
# wrangle data frame
posteriors3 <- model2 |>
  spread_draws(b_Intercept, b_RT, b_correct) |>
  select(b_Intercept, b_RT, b_correct) |> 
  # filter by the row numbers in random_50
  slice(random_50)
  
# plot
ggplot(data = dolphin_agg, 
       aes(x = RT, 
           y = MAD, 
           color = correct)) + 
  geom_abline(data = posteriors3,
              aes(intercept = b_Intercept, slope = b_RT), 
              color = "orange", size  = 0.1) +
  geom_abline(data = posteriors3,
              aes(intercept = b_Intercept + b_correct, slope = b_RT), 
              color = "lightblue", size  = 0.1) +
  geom_point(size = 3, alpha = 0.3)

Exercise 2f

Given our model and our data, calculate the evidence ratio for the hypothesis that correct responses exhibit larger MADs than incorrect responses.

Show solution
hypothesis(model2, 'correct > 0')
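
For a one-sided hypothesis like this one, the evidence ratio can be understood as posterior odds: the posterior probability of the hypothesis divided by that of its complement. The sketch below computes this from made-up draws (a stand-in for the posterior of the coefficient for correct), so the numbers carry no meaning for the actual model.

```r
set.seed(4)
b_correct_draws <- rnorm(4000, mean = 5, sd = 10)  # made-up posterior draws

# Posterior probability of the one-sided hypothesis 'b_correct > 0'
post_prob <- mean(b_correct_draws > 0)

# Evidence ratio: posterior odds of the hypothesis against its complement
evid_ratio <- post_prob / (1 - post_prob)

c(post_prob = post_prob, evid_ratio = evid_ratio)
```

An evidence ratio above 1 means that more posterior mass lies on the side of the hypothesis than against it.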

Exercise 3

Here is an aggregated data set dolphin_agg for you.

# aggregate
dolphin_agg <- dolphin %>% 
  group_by(group, exemplar) %>% 
  dplyr::summarize(MAD = median(MAD, na.rm = TRUE),
                   RT = median(RT, na.rm = TRUE)) %>% 
  mutate(log_RT = log(RT))

Exercise 3a

Standardize (“z-transform”) log_RT such that the mean is at zero and 1 unit corresponds to the standard deviation. Name it log_RT_s.

Show solution
# scale() returns a one-column matrix, so convert it to a plain numeric vector
dolphin_agg$log_RT_s <- as.numeric(scale(dolphin_agg$log_RT))
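
The z-transform can also be written out by hand, which makes explicit what standardization does. This is a small sketch on made-up reaction times, not on the dolphin data.

```r
log_RT <- log(c(650, 800, 950, 1200, 1500))  # made-up reaction times

# z-transform: subtract the mean, divide by the standard deviation
log_RT_s <- (log_RT - mean(log_RT)) / sd(log_RT)

mean(log_RT_s)  # zero, up to floating-point error
sd(log_RT_s)    # one

# scale() does the same but returns a one-column matrix
all.equal(as.numeric(scale(log_RT)), log_RT_s)
```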

Exercise 3b

Run a linear model with brms that predicts MAD based on log_RT_s, group, and their two-way interaction.

Show solution
model1 = brm(
  MAD ~ log_RT_s * group, 
  data = dolphin_agg
  )

Exercise 3c

Plot MAD (y) against log_RT_s (x) in a scatter plot and color-code for group. Plot the regression lines for the click and the touch group into the plot and don’t forget to take possible interactions into account.

Show solution
# extract posterior means for model coefficients
Intercept = summary(model1)$fixed[1,1]
log_RT = summary(model1)$fixed[2,1]
group = summary(model1)$fixed[3,1]
interaction = summary(model1)$fixed[4,1]

# plot
ggplot(data = dolphin_agg, 
       aes(x = log_RT_s, 
           y = MAD, 
           color = group)) + 
  geom_point(size = 3, alpha = 0.3) +
  geom_vline(xintercept = 0, lty = "dashed") +
  geom_abline(intercept = Intercept, slope = log_RT, 
              color = project_colors[1], size = 2) +
  geom_abline(intercept = Intercept + group, slope = log_RT + interaction, 
              color = project_colors[2], size = 2) 

Exercise 3d

Specify very skeptical priors for all three coefficients. Use a normal distribution with mean = 0 and sd = 10. Rerun the model with those priors.

Show solution
# specify priors
priors_model2 <- c(
  set_prior("normal(0,10)", class = "b", coef = "log_RT_s"),
  set_prior("normal(0,10)", class = "b", coef = "grouptouch"),
  set_prior("normal(0,10)", class = "b", coef = "log_RT_s:grouptouch")
)

# model
model2 = brm(
  MAD ~ log_RT_s * group, 
  data = dolphin_agg,
  prior = priors_model2
  )

Exercise 3e

Compare the model output of model1 to model2. What are the differences and what are the reasons for these differences?

Show solution
# We can compare the model predictions by looking at the coefficients / plotting them:
summary(model1)
summary(model2)

# extract posterior means for model coefficients
Intercept = summary(model2)$fixed[1,1]
log_RT = summary(model2)$fixed[2,1]
group = summary(model2)$fixed[3,1]
interaction = summary(model2)$fixed[4,1]

# plot
ggplot(data = dolphin_agg, 
       aes(x = log_RT_s, 
           y = MAD, 
           color = group)) + 
  geom_point(size = 3, alpha = 0.3) +
  geom_vline(xintercept = 0, lty = "dashed") +
  geom_abline(intercept = Intercept, slope = log_RT, 
              color = project_colors[1], size = 2) +
  geom_abline(intercept = Intercept + group, slope = log_RT + interaction, 
              color = project_colors[2], size = 2) 

# ANSWER: The coefficients are much smaller in magnitude in model2, with the
# interaction term being close to zero. As a result, the lines in the plot are
# closer together and run almost parallel. The reason for this change lies in
# the priors. We defined the priors of model2 rather narrowly around zero,
# which down-weights coefficient values far from zero. This is a case of the
# prior dominating the posterior.
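
The pull of a narrow prior can be illustrated with the textbook case of a normal prior combined with a normal approximation to the likelihood, where the posterior mean is a precision-weighted average. This is a simplified sketch with hypothetical numbers, not what brms computes for the full model; `posterior_mean` is an illustrative helper.

```r
# Precision-weighted compromise between a normal prior centered on zero
# and a normal approximation to the likelihood of a single coefficient
posterior_mean <- function(estimate, se, prior_sd) {
  prior_precision <- 1 / prior_sd^2
  data_precision  <- 1 / se^2
  (data_precision * estimate) / (prior_precision + data_precision)
}

# Hypothetical data-based estimate of 60 with standard error 25
shrunk   <- posterior_mean(estimate = 60, se = 25, prior_sd = 10)
unshrunk <- posterior_mean(estimate = 60, se = 25, prior_sd = 1000)

c(narrow_prior = shrunk, wide_prior = unshrunk)
```

When the prior standard deviation is small relative to the uncertainty in the data, the posterior mean is pulled strongly toward zero; with a wide prior it stays close to the data-based estimate.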