Generalized additive models in brms

Bayesian regression: theory & practice

Author

Michael Franke

This tutorial provides both a conceptual and a practical introduction to fitting generalized additive models (GAMs) in brms. GAMs approximate "wiggly" (non-linear) curves with smoothing splines. The central idea to internalize here is that we can think of smoothing splines as a kind of random effect. This is, indeed, how brms handles GAMs.

The material here leans heavily on T.J. Mahr’s excellent tutorial.

Preamble

Here is code to load (and if necessary, install) required packages, and to set some global options (for plotting and efficient fitting of Bayesian models).

Toggle code
# install packages from CRAN (unless installed)
pckgs_needed <- c(
  "tidyverse",
  "brms",
  "rstan",
  "rstanarm",
  "remotes",
  "tidybayes",
  "bridgesampling",
  "shinystan",
  "mgcv"
)
pckgs_installed <- installed.packages()[,"Package"]
pckgs_2_install <- pckgs_needed[!(pckgs_needed %in% pckgs_installed)]
if(length(pckgs_2_install)) {
  install.packages(pckgs_2_install)
} 

# install additional packages from GitHub (unless installed)
if (! "aida" %in% pckgs_installed) {
  remotes::install_github("michael-franke/aida-package")
}
if (! "faintr" %in% pckgs_installed) {
  remotes::install_github("michael-franke/faintr")
}
if (! "cspplot" %in% pckgs_installed) {
  remotes::install_github("CogSciPrag/cspplot")
}

# load the required packages
x <- lapply(pckgs_needed, library, character.only = TRUE)
library(aida)
library(faintr)
library(cspplot)

# these options help Stan run faster
options(mc.cores = parallel::detectCores())

# use the CSP-theme for plotting
theme_set(theme_csp())

# global color scheme from CSP
project_colors = cspplot::list_colors() |> pull(hex)
# names(project_colors) <- cspplot::list_colors() |> pull(name)

# setting theme colors globally
scale_colour_discrete <- function(...) {
  scale_colour_manual(..., values = project_colors)
}
scale_fill_discrete <- function(...) {
   scale_fill_manual(..., values = project_colors)
}

Data set and mission statement

Let’s start with a common example for introductions to GAMs, the non-linear relationship between time (after impact) and acceleration (of the head of a motorcycle driver). The data is from the MASS package and it looks like this:

Toggle code
# motorcycle data
data_motor <- MASS::mcycle |> 
  tibble() |> 
  unique()

# plot the data
data_motor |> 
  ggplot(aes(x = times, y = accel)) +
  geom_smooth(color = project_colors[1]) +
  geom_point()

Notice that the plotting function geom_smooth already provides us with a smoothing line for the central tendency of this "wiggly data". But this default smooth is not very good. We can probably do better.

The goals for the remainder of this tutorial are:

  1. To understand how a wiggly curve can be constructed from a set of “atomic wiggles”.
  2. To see the parallel between (i) fitting weights over “atomic wiggles”, with a penalty against excessive wiggliness, and fitting (ii) group-level effects.
  3. To be able to run a Bayesian regression model with penalized smoothing splines using brms.

Wiggly curves as mixtures over “atomic wiggles”

Let’s play a game! It’s called “make a wiggle”. Here’s how it’s played: I give you a bunch of “atomic wiggles”, you adjust some weights, and out pops … TADA! … a smoothing spline.

Or, for those dull at heart, in technical terms: I give you a set of basis functions. This set is also called a basis. Then you build intuitions about how, by adjusting numerical weights for the basis functions, we can flexibly approximate non-linear data.

The following code gives you a function get_basis() which takes as input a sequence of observations \(x\) for which we would like to construct a non-linear curve. The output is a set of "atomic wiggles" (see plot below). The function also takes an integer value k as input, which determines the number of "atomic wiggles" that you get.

Do not worry about the details of this function, better to directly look at its output. But if you must know: this function uses the gam function from the mgcv package to construct (indirectly, through the model matrix internally built up by mgcv::gam) k cubic regression splines. The brms package uses the function mgcv::s, which is also called here in the model formula (more on this below).

Toggle code
# create a set of /k/ splines for vector /x/
get_basis <- function(x, k = 20){
  # fit a dummy GAM
  fit_gam <- mgcv::gam(
    formula = y ~ s(times, bs = "cr", k = k),
    data    = tibble(y = 1, times = x))
  # extract model matrix
  model_matrix <- model.matrix(fit_gam)
  # wrangle to long tibble
  some_curves <- model_matrix |> 
    as_tibble() |> 
    mutate(x = x) |> 
    pivot_longer(cols = -x) |> 
    mutate(
      name = str_replace(name, "s\\(times\\).", "curve_"))
  return(some_curves)
}

Let’s construct a basis then and plot it:

Toggle code
n_x = 1000
x = seq(0,60, length.out = n_x)

some_curves <- get_basis(x)

# plot the basis
some_curves |> 
  ggplot(aes(x = x, y = value, color = name, group = name)) +
  geom_line(size=1.5) + 
  scale_colour_manual(values = c(project_colors, project_colors)) +
  theme(legend.position="none") +
  ylab("") + ggtitle("your basis") +
  xlab("")

To play "make a wiggle" you also need a vector of weights, one for each basis function. Here is a function which, given a weight vector, computes the weighted sum of all basis functions for a smooth prediction.

Toggle code
make_a_wiggle <- function(some_curves, weights) {
  
  your_wiggly_line <- some_curves |> 
    mutate(
      # repeat the weight vector once per x-value (one weight per basis function)
      weight = rep(weights, times = n_distinct(x)),
      weighted_value = value * weight
      ) |> 
    group_by(x) |> 
    summarize(wiggly_line = sum(weighted_value)) |> 
    ungroup()
  
  your_wiggle_plot <- your_wiggly_line |> 
    ggplot(aes(x = x, y = wiggly_line)) + 
    geom_line(color = project_colors[1], size=1.5) + 
    ggtitle("your wiggly line") +
    xlab("x") + ylab("")
    
  your_wiggle_plot
}

And here are some examples:

Toggle code
weights <- c(0,2,-1,1,1,2,1,5,5,-7,1,1,1,10,1,1,1,1,1,1)
make_a_wiggle(some_curves, weights)

Toggle code
weights <- c(-1,0.5,-2,-1,0,3,4,3,2,-1,-5,1,1,1,1,2,2,2,2,2)
make_a_wiggle(some_curves, weights)

And a not so wiggly wiggle:

Toggle code
weights <- c(0,1,10,rep(0, times = 17))
make_a_wiggle(some_curves, weights)

Exercise 1: Try a manual fit

Play around with weights and try to find a constellation that approximates the shape of the motorcycle data in data_motor that we plotted above.
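Here is one rough, hand-tuned attempt. The specific weights below are just a guess, not an optimized or particularly close fit; tweak them to get a better match.

Toggle code
# a rough manual approximation of the motorcycle data
# (hand-picked weights; purely illustrative, by no means an optimal solution)
weights_manual <- c(
  0,                           # intercept
  0, 0, 0, 0,                  # curve_1 -- curve_4: flat part before ~15 ms
  -40, -130, -90,              # curve_5 -- curve_7: sharp dip around ~20 ms
  0, 40, 50, 20,               # curve_8 -- curve_11: rebound around ~30 ms
  -10, -15, 0, 0, 0, 0, 0, 0   # curve_12 -- curve_19: decay back towards zero
)
make_a_wiggle(some_curves, weights_manual)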

Penalizing wiggliness and group-level effects

We saw how a vector of weights allows us to approximate a non-linear relationship, once we have a set of elementary basis functions. To fit a non-linear regression model, the data should inform us which vector of weights to use. So, essentially, we can just use the basis functions as predictors in an ordinary linear regression. There are a lot of them, so that's going to be ugly, but let's do it:

Toggle code
motor_basis <- get_basis(data_motor$times, k = 20) |> 
  mutate(accel = rep(data_motor$accel, each = 20)) |> 
  rename(times = x) |> 
  dplyr::filter(name != "(Intercept)") |> 
  pivot_wider(id_cols = c("times", "accel"), names_from = name, values_from = value, values_fn = mean)

fit_motor_FE <- brms::brm(
  formula = accel ~ times + curve_1 + curve_2 + curve_3 + curve_4 + curve_5 + curve_6 + curve_7 + curve_8 + curve_9 + curve_10 +
    curve_11 + curve_12 + curve_13 + curve_14 + curve_15 + curve_16 + curve_17 + curve_18 + curve_19,
  data    = motor_basis,
  prior   = prior(normal(0,50))
)

This really only works if we specify some constraints via the priors for the coefficients! (Try a fit with improper priors!)
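For comparison, here is a sketch of the same model fitted without the explicit prior, i.e., with brms' default flat (improper) priors on the coefficients. Expect a very wiggly, likely overfitted curve and possibly sampling warnings.

Toggle code
# same fixed-effects model, but with brms' default (flat) priors on the coefficients;
# purely for illustration -- the resulting fit will likely be extremely wiggly
fit_motor_FE_flat <- brms::brm(
  formula = accel ~ times + curve_1 + curve_2 + curve_3 + curve_4 + curve_5 + curve_6 + curve_7 + curve_8 + curve_9 + curve_10 +
    curve_11 + curve_12 + curve_13 + curve_14 + curve_15 + curve_16 + curve_17 + curve_18 + curve_19,
  data    = motor_basis
)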

Here is a plot of the posterior central tendency (expected values):

Toggle code
plot_post_linPred <- function(fit, data) {
  postPred <- data |> 
    tidybayes::add_epred_draws(fit) |> 
    group_by(times,accel) |> 
    summarize(
      lower  = tidybayes::hdi(.epred)[1],
      mean   = mean(.epred),
      higher = tidybayes::hdi(.epred)[2]
    )
  
  postPred |> 
    ggplot(aes(x = times,  y = accel)) +
    geom_line(aes(y = mean), color = project_colors[1], size = 2) +
    geom_ribbon(aes(ymin = lower, ymax = higher), 
                fill = project_colors[6], alpha = 0.5) +
    geom_point(size = 2)
}

plot_post_linPred(fit_motor_FE, motor_basis)

That looks like it’s on the right track, but it’s not wiggly enough.

Exercise 2: Increasing smoothness of the predictor

What could we do to make the linear predictor more curvy?

The bigger the fixed-effect coefficients can become, the more the basis functions can affect the predictions. Currently, there are rather strong priors on the coefficients (a Gaussian with standard deviation 50). If we make these priors wider, we should see a much more wiggly fit.

Toggle code
fit_motor_FE_wider <- brms::brm(
  formula = accel ~ times + curve_1 + curve_2 + curve_3 + curve_4 + curve_5 + curve_6 + curve_7 + curve_8 + curve_9 + curve_10 +
    curve_11 + curve_12 + curve_13 + curve_14 + curve_15 + curve_16 + curve_17 + curve_18 + curve_19,
  data    = motor_basis,
  prior   = prior(normal(0,500))
)

plot_post_linPred(fit_motor_FE_wider, motor_basis)

Taking stock so far, all of this shows that:

  • By adjusting weights of basis functions, we can get smoothed curves.
  • These weights can be learned from the data (visual results look good).
  • BUT: the degree of smoothness is determined by the prior on coefficients.

The latter problem is worrisome because we do not want to manually adjust an important parameter like that. Ideally, the smoothness of the curve should be dictated by the data itself. Otherwise, we might overfit (too wiggly) or underfit (remain too close to linear).

Notice that the parameter to tweak is the standard deviation of the normal prior over coefficients. The smaller it is, the more linear the predicted curve; the bigger it is, the wigglier the fit can become. So, we want a model that lets the data decide what this standard deviation should be for the additive offsets to what is otherwise a vanilla regression model. But, hey, that's exactly what group-level effects do, too!

Concretely, we want a predictor \(\eta\) which consists of the usual linear regression part \(X \beta\) plus an additional part \(Z b\), which specifies how much to deviate from the vanilla linear prediction. In the case at hand, the matrix \(Z\) is, essentially, the basis (rows correspond to observations of times, columns to the basis functions). We also want the coefficients \(b\) to be close to zero but with unknown variance, which is to be inferred from the data. The overall structure of this model is therefore:

\[ \begin{align*} \eta &= X \beta + Z b \\ b & \sim \mathcal{N}(0, \sigma_b) \\ \sigma_b & \sim \dots \text{some prior} \dots \end{align*} \]

And that is exactly the structure of a group-level / random-effects model. Random effects are supposed to be small, where what counts as small enough is to be determined by the data. Likewise, the weights of the basis functions are supposed to be small and, likewise, what counts as small is not fixed by hand but governed by the needs of the data. (Here we are again: we should think of group-level effects as putting up an (elastic) fence and letting the sheep graze where they want; a gentle anything-goes within a flexible margin.)

Indeed, brms implements GAMs as, essentially, multilevel models, even though the touch-and-feel may be different.
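If you are curious, one way to see this is to inspect the priors (and, optionally, the generated Stan code) for the GAM formula we will fit below: the smooth term gets its own standard-deviation parameter (class sds), just like a group-level effect.

Toggle code
# inspect the parameter classes brms sets up for the smooth term;
# the output should contain an entry of class "sds" for s(times)
brms::get_prior(
  accel ~ s(times, bs = "cr", k = 20),
  data = data_motor
)

# optionally, look at the generated Stan code to see the multilevel structure
# brms::make_stancode(accel ~ s(times, bs = "cr", k = 20), data = data_motor)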

Implementing GAMs in brms

To run a basic GAM in brms, just wrap the predictor to be smoothed in s(), which is brms' wrapper around the mgcv function of the same name.

Toggle code
fit_motor <- 
  brms::brm(
    formula = accel ~ s(times, bs = "cr", k = 20),
    data    = data_motor
  )

Here is the posterior central tendency for this model, which looks fine:

Toggle code
plot_post_linPred(fit_motor, data_motor)

It is interesting to interpret the summary:

Toggle code
# interpret this
summary(fit_motor)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: accel ~ s(times, bs = "cr", k = 20) 
   Data: data_motor (Number of observations: 132) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Smooth Terms: 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sds(stimes_1)     4.95      1.25     3.12     7.93 1.01      810     1187

Population-Level Effects: 
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept   -25.67      2.00   -29.73   -21.73 1.00     4500     2689
stimes_1      1.66      0.33     1.02     2.31 1.00     4324     2713

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma    22.84      1.48    20.13    25.89 1.00     4728     2960

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

The "Smooth Term" sds(stimes_1) is essentially the standard deviation \(\sigma_b\) in the equations above. (The "sigma" is the usual standard deviation of the Gaussian likelihood function.) The intercept is the usual intercept, and the term "stimes_1" is the slope of the unpenalized linear part. But these coefficients are so entangled with the penalized smooth in which they are wrapped that it is not prudent to draw strong conclusions from them (in my understanding; though I am not certain about this).
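To look at the estimated smooth itself (the \(Z b\) part, so to speak), rather than at the individual coefficients, we can use brms' conditional_smooths() function:

Toggle code
# plot the estimated (penalized) smooth of 'times' with credible bands
plot(brms::conditional_smooths(fit_motor), ask = FALSE)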

Exercises

Exercise 3: Posterior predictives

The previous plot showed ribbons for the 95% HDI of the central tendency (expected values). Make a similar plot for the posterior predictive distribution. Before you code, think about how you expect the plot to differ from the previous one (different curve, more / less uncertainty …).

We expect the central tendency to be similar (albeit a bit noisier, given sampling inaccuracy), but most of all the ribbons will be broader, given that we now quantify uncertainty about where the data points would fall, not about where the central tendency lies.

Toggle code
postPred <- data_motor |> 
  tidybayes::add_predicted_draws(fit_motor) |> 
  group_by(times,accel) |> 
  summarize(
    lower  = tidybayes::hdi(.prediction)[1],
    mean   = mean(.prediction),
    higher = tidybayes::hdi(.prediction)[2]
  )

postPred |> 
  ggplot(aes(x = times,  y = accel)) +
  geom_line(aes(y = mean), color = project_colors[1], size = 2) +
  geom_ribbon(aes(ymin = lower, ymax = higher), 
              fill = project_colors[7], alpha = 0.3) +
  geom_point(size = 2)

Exercise 4: Posterior predictive check

Consult visual posterior predictive checks for the GAM. Does it look alright? Anything systematically amiss?

A simple PPC suggests that the overall distribution of the data is not matched well. This is likely due to the constant residual standard deviation: the model implies a large spread for early time points, which is not borne out by the data. This also explains why the posterior predictive densities are flatter and less peaked than the observed data.

Toggle code
pp_check(fit_motor, ndraws = 40)
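To see the misfit along the time axis more directly, an interval-based check can help (a sketch, assuming bayesplot's "intervals" PPC type, to which brms' pp_check passes through):

Toggle code
# posterior predictive intervals plotted against 'times';
# look for regions where the predictive intervals are much wider than the data spread
pp_check(fit_motor, type = "intervals", x = "times", ndraws = 200)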

Exercise 5: A distributional GAM

Given the problem above, it makes sense to try a distributional model, regressing the "sigma" parameter of the likelihood function on times. Decide whether it is more informative to plot the posterior of the central tendency or the posterior predictive distribution. Interpret what you see.

Here is the model fit:

Toggle code
fit_motor_distributional <- 
  brms::brm(
    formula = brms::bf(accel ~ s(times, bs = "cr", k = 20),
                       sigma ~ times),
    data    = data_motor
  )

We should consult the posterior predictive distribution, because that is where the difference in spread around the central tendency will show up, not in the estimate of the central tendency itself.

Toggle code
postPred <- data_motor |> 
  tidybayes::add_predicted_draws(fit_motor_distributional) |> 
  group_by(times,accel) |> 
  summarize(
    lower  = tidybayes::hdi(.prediction)[1],
    mean   = mean(.prediction),
    higher = tidybayes::hdi(.prediction)[2]
  )

postPred |> 
  ggplot(aes(x = times,  y = accel)) +
  geom_line(aes(y = mean), color = project_colors[1], size = 2) +
  geom_ribbon(aes(ymin = lower, ymax = higher), 
              fill = project_colors[7], alpha = 0.4) +
  geom_point(size = 2)

We see tighter predictive intervals for early time points, but the difference from the non-distributional model is not very large.
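To inspect the distributional part directly, we can also plot the estimated sigma as a function of times (a sketch using brms' conditional_effects with the dpar argument):

Toggle code
# estimated residual standard deviation (sigma) as a function of 'times';
# we expect small values for early time points and larger values later on
brms::conditional_effects(
  fit_motor_distributional,
  effects = "times",
  dpar    = "sigma"
)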

Exercise 6: A GAM for World Temperature

Run a Bayesian GAM for the World Temperature data (aida::data_WorldTemp). Plot the posterior central tendency (expected values) together with the data. Can you use information from the posterior over model parameters to address the question of whether there is a trend towards higher measurements as time progresses? Compare the conclusions you draw from this model to those you would draw from a naive linear model.

First run the models.

Toggle code
fit_temp_splines <- 
  brms::brm(
    formula = avg_temp ~ s(year, bs = "cr", k=20),
    data    = aida::data_WorldTemp
  )

fit_temp_vanilla <- 
  brms::brm(
    formula = avg_temp ~ year,
    data    = aida::data_WorldTemp
  )

Here is the posterior predicted central tendency for the GAM (w/ credible intervals):

Toggle code
postPred <- aida::data_WorldTemp |> 
  tidybayes::add_epred_draws(fit_temp_splines) |> 
  group_by(avg_temp, year) |> 
  summarize(
    lower  = tidybayes::hdi(.epred)[1],
    mean   = mean(.epred),
    higher = tidybayes::hdi(.epred)[2]
  )

postPred |> 
  ggplot(aes(x = year,  y = avg_temp)) +
  geom_line(aes(y = mean), color = project_colors[1], size = 2) +
  geom_ribbon(aes(ymin = lower, ymax = higher), 
              fill = project_colors[6], alpha = 0.4) +
  geom_point(size = 2)

We can compare the summaries of both models:

Toggle code
tidybayes::summarise_draws(fit_temp_splines)[1:4,]
# A tibble: 4 × 10
  variable     mean median      sd     mad     q5    q95  rhat ess_bulk ess_tail
  <chr>       <num>  <num>   <num>   <num>  <num>  <num> <num>    <num>    <num>
1 b_Interce… 8.31   8.31   0.0200  0.0195  8.28   8.35    1.00    5089.    3026.
2 bs_syear_1 0.122  0.122  0.00499 0.00500 0.114  0.130   1.00    4876.    3011.
3 sds_syear… 0.0571 0.0542 0.0192  0.0177  0.0318 0.0930  1.00     665.    1211.
4 sigma      0.326  0.325  0.0144  0.0145  0.303  0.350   1.00    4001.    3232.
Toggle code
tidybayes::summarise_draws(fit_temp_vanilla)[1:4,]
# A tibble: 4 × 10
  variable        mean   median      sd     mad       q5      q95  rhat ess_bulk
  <chr>          <num>    <num>   <num>   <num>    <num>    <num> <num>    <num>
1 b_Intercept -3.52    -3.52    6.23e-1 6.25e-1 -4.53    -2.48     1.00    4763.
2 b_year       0.00628  0.00628 3.31e-4 3.33e-4  0.00573  0.00681  1.00    4762.
3 sigma        0.406    0.405   1.79e-2 1.80e-2  0.378    0.437    1.00    1346.
4 lprior      -3.16    -3.16    1.55e-3 1.55e-3 -3.16    -3.16     1.00    1349.
# ℹ 1 more variable: ess_tail <num>

The population-level slope parameter bs_syear_1 represents the unpenalized linear part of the splines model, but it is difficult to interpret (because we would have to strip off the smoothing splines, so to speak). There seems to be an indication of a positive linear trend underneath the smooth, but I, personally, would not base firm conclusions on it.
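One way to quantify the evidence for a positive trend is to compute the posterior probability of a positive slope in each model (a sketch; the parameter name bs_syear_1 for the spline model is taken from the summary above):

Toggle code
# posterior probability / evidence ratio for a positive year-effect in the vanilla model
brms::hypothesis(fit_temp_vanilla, "year > 0")

# analogous check for the (unpenalized) linear term of the spline model,
# based on the posterior draws of 'bs_syear_1'; interpret with caution (see above)
draws_splines <- brms::as_draws_df(fit_temp_splines)
mean(draws_splines$bs_syear_1 > 0)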
