XX: Mixture models

Bayesian regression: theory & practice

Author

Michael Franke

This tutorial discusses a minimal example of a mixture model. After introducing the main idea behind mixture models, a fictitious (minimal) data set is analyzed first with a hand-written Stan program, and then with a mixture regression model using brms.

Preamble

Here is code to load (and if necessary, install) required packages, and to set some global options (for plotting and efficient fitting of Bayesian models).

# install packages from CRAN (unless installed)
pckgs_needed <- c(
  "tidyverse",
  "brms",
  "rstan",
  "rstanarm",
  "remotes",
  "tidybayes",
  "bridgesampling",
  "shinystan",
  "mgcv"
)
pckgs_installed <- installed.packages()[,"Package"]
pckgs_2_install <- pckgs_needed[!(pckgs_needed %in% pckgs_installed)]
if(length(pckgs_2_install)) {
  install.packages(pckgs_2_install)
} 

# install additional packages from GitHub (unless installed)
if (! "aida" %in% pckgs_installed) {
  remotes::install_github("michael-franke/aida-package")
}
if (! "faintr" %in% pckgs_installed) {
  remotes::install_github("michael-franke/faintr")
}
if (! "cspplot" %in% pckgs_installed) {
  remotes::install_github("CogSciPrag/cspplot")
}

# load the required packages
x <- lapply(pckgs_needed, library, character.only = TRUE)
library(aida)
library(faintr)
library(cspplot)

# these options help Stan run faster
options(mc.cores = parallel::detectCores())

# use the CSP-theme for plotting
theme_set(theme_csp())

# global color scheme from CSP
project_colors = cspplot::list_colors() |> pull(hex)
# names(project_colors) <- cspplot::list_colors() |> pull(name)

# setting theme colors globally
scale_colour_discrete <- function(...) {
  scale_colour_manual(..., values = project_colors)
}
scale_fill_discrete <- function(...) {
   scale_fill_manual(..., values = project_colors)
}

Finite mixtures: Motivation & set-up

To motivate mixture modeling, let’s look at a fictitious data set. There are two sets of measurements: for each of two species of plants, we measured the height of 25 exemplars.

heights_A <- c(6.94, 11.77, 8.97, 12.2, 8.48, 
               9.29, 13.03, 13.58, 7.63, 11.47, 
               10.24, 8.99, 8.29, 10.01, 9.47, 
               9.92, 6.83, 11.6, 10.29, 10.7, 
               11, 8.68, 11.71, 10.09, 9.7)

heights_B <- c(11.45, 11.89, 13.35, 11.56, 13.78, 
               12.12, 10.41, 11.99, 12.27, 13.43, 
               10.91, 9.13, 9.25, 9.94, 13.5, 
               11.26, 10.38, 13.78, 9.35, 11.67, 
               11.32, 11.98, 12.92, 12.03, 12.02) + 4

Here’s how this data is distributed:

ffm_data <- tibble(
  A = heights_A,
  B = heights_B
) |> 
  pivot_longer(
    cols      = everything(),
    names_to  = 'species',
    values_to = 'height'
  )

ffm_data |> 
  ggplot(aes(x = height)) +
  geom_density(aes(color = species), size = 2) +
  geom_rug(aes(color = species), size = 1.5) +
  theme(legend.position = 'none')

Now suppose that (for whatever reason) we get the data without information which measure was from which group. If we plot the data without this information, the picture looks like this:

flower_heights <- c(heights_A, heights_B)
tibble(flower_heights) |> 
  ggplot(aes(x = flower_heights)) + 
  geom_rug(size = 2) +
  geom_density(size = 2) +
  xlab("height")

Data may often look like this, showing signs of bi- or multi-modality, i.e., having several “humps” or apparent local areas of higher concentration. If we fit a single Gaussian to this data, the result might look like this:

# using the descriptive means/SD for a quick "best fit"
mu    <- mean(flower_heights)
sigma <- sd(flower_heights)
tibble(
  source  = c(rep("data", length(flower_heights)), rep("fit", 1000)),
  height  = c(flower_heights, rnorm(1000, mu, sigma))
) |>  
ggplot(aes(x = height, fill=source)) +
  geom_density(size = 2, alpha = 0.3)

If you see a posterior predictive check that looks like the picture above, you know that something is systematically amiss with your model: it assumes a single peak (a unimodal response distribution), but the data shows signs of multi-modality (several “peaks” or centers of high density).

Models that allow multiple “peaks” in the response distribution are called mixture models. When dealing with a Gaussian likelihood function, as in the case at hand, we speak of Gaussian mixture models (GMMs). In general, a mixture model incorporates the idea that the data was generated by more than one process (alternative lingo: that observations were sampled from different populations).

Let \(\langle f_1, \dots, f_k \rangle\) be \(k\) likelihood functions for data \(Y\). The \(k\)-mixture model for \(Y\) explains the data as a weighted combination of these components, with mixture weights \(\alpha\) (a probability vector). Procedurally, think of inferring for each data point \(y_{i}\) which mixture component \(k(i)\) it most likely belongs to (i.e., by which of the \(k\) different processes it may have been generated; or, from which of the \(k\) different populations it was sampled). The probability that any \(y_{i}\) is in class \(j\) is given by \(\alpha_{j}\), so that \(\alpha\) represents the overall probabilities of the different components (sub-populations). This results in a mixture likelihood function which can be written like so:

\[f^{\mathrm{MM}}(y_i) = \sum_{j=1}^{k} \alpha_{j} \, f_{j}(y_i)\]
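To make this formula concrete, here is a minimal R sketch (using made-up parameter values, not estimates) that computes the mixture density of a few observations as the weighted sum of two Gaussian component densities:

# mixture density as a weighted sum of component densities (made-up values)
mixture_density <- function(y, alpha, mu, sigma) {
  # for each observation: sum_j alpha_j * f_j(y_i)
  sapply(y, function(y_i) sum(alpha * dnorm(y_i, mean = mu, sd = sigma)))
}
mixture_density(
  y     = c(10, 13, 16),
  alpha = c(0.5, 0.5),
  mu    = c(10, 16),
  sigma = c(2, 1.5)
)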

For the current case, we assume that there are just two components (because we see two “humps”, or because we have a priori information that there are exactly two groups). Concretely, for each data point \(y_i\), \(i \in \{1, \dots, N\}\), we are going to estimate how likely it is that data point \(i\) was sampled from normal distribution “Number 0”, with \(\mu_0\) and \(\sigma_0\), or from normal distribution “Number 1”, with \(\mu_1\) and \(\sigma_1\). Naturally, all \(\mu_{0,1}\) and \(\sigma_{0,1}\) are estimated from the data, as are the group-indicator variables \(z_i\). There is also a global parameter \(p\) which indicates how likely any data point is to come from one of the two distributions (you’ll think about which one below!). Here’s the full model we will work with (modulo an additional ordering constraint, as discussed below):

\[ \begin{align*} p & \sim \text{Beta}(1,1) \\ z_i & \sim \text{Bernoulli}(p) \\ \mu_{0,1} & \sim \mathcal{N}(12, 10) \\ \sigma_{0,1} & \sim \text{log-normal}(0, 1) \\ y_i & \sim \mathcal{N}(\mu_{z_i}, \sigma_{z_i}) \end{align*} \]
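To get a feeling for the generative process this model describes, here is a small R sketch that simulates fake data from it (using fixed, made-up parameter values instead of prior draws):

# forward simulation from the two-component mixture (made-up parameter values)
set.seed(1234)
N_fake     <- 50
p_true     <- 0.5                                       # probability that z_i = 1
z_fake     <- rbinom(N_fake, size = 1, prob = p_true)   # latent group indicators z_i
mu_true    <- c(10, 16)                                 # mu_0, mu_1
sigma_true <- c(2, 1.5)                                 # sigma_0, sigma_1
# y_i ~ Normal(mu_{z_i}, sigma_{z_i})
y_fake     <- rnorm(N_fake, mean = mu_true[z_fake + 1], sd = sigma_true[z_fake + 1])

tibble(height = y_fake) |> 
  ggplot(aes(x = height)) +
  geom_density(size = 2) +
  geom_rug(size = 1.5)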

Exercise 1a: Draw the model

Draw a graphical representation of this mixture model.

FILL ME

A Gaussian mixture model in Stan

The Stan model

We are going to pack the data together for fitting the Stan model:

data_GMM <- list(
  y = flower_heights,
  N = length(flower_heights)
)

Below is the Stan code for this model. It is also given in the file Gaussian-mixture-01-basic.stan. A few comments on this code:

  1. There is no occurrence of the variable \(z_i\), as it has been marginalized out. We do this by incrementing the log-score manually, using target += log_sum_exp(alpha). (A short R illustration of this trick follows the Stan code below.)
  2. We declare the vector mu to be of a type we have not seen before: we want the vector to be ordered. We will come back to this later; don’t worry about it for now.
data {
  int<lower=1> N;            // number of observations
  real y[N];                 // observed heights
}
parameters {
  real<lower=0,upper=1> p;   // mixing weight
  ordered[2] mu;             // component means (ordered)
  vector<lower=0>[2] sigma;  // component standard deviations
}
model {
  p ~ beta(1,1);
  mu ~ normal(12, 10);
  sigma ~ lognormal(0, 1);
  for (i in 1:N) {
    vector[2] alpha;
    alpha[1] = log(p)   + normal_lpdf(y[i] | mu[1], sigma[1]);
    alpha[2] = log(1-p) + normal_lpdf(y[i] | mu[2], sigma[2]);
    target += log_sum_exp(alpha); // marginalize over the latent group indicator
  }
}
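To see what the marginalization over \(z_i\) amounts to, here is a small R check (for a single made-up observation and made-up parameter values): the log-sum-exp of the two component-wise terms equals the log of the mixture density computed directly.

# R illustration of the log-sum-exp marginalization (made-up values)
y_i      <- 12
p_ex     <- 0.4
mu_ex    <- c(10, 16)
sigma_ex <- c(2, 1.5)
log_alpha <- c(
  log(p_ex)     + dnorm(y_i, mu_ex[1], sigma_ex[1], log = TRUE),
  log(1 - p_ex) + dnorm(y_i, mu_ex[2], sigma_ex[2], log = TRUE)
)
# numerically stable log-sum-exp
log_sum_exp <- function(x) max(x) + log(sum(exp(x - max(x))))
log_sum_exp(log_alpha)
# same value as the log of the mixture density computed directly
log(p_ex * dnorm(y_i, mu_ex[1], sigma_ex[1]) + (1 - p_ex) * dnorm(y_i, mu_ex[2], sigma_ex[2]))

We now fit this model to the flower-height data: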
stan_fit_2b_GMM <- stan(
  file = 'stan-files/Gaussian-mixture-01-basic.stan',
  data = data_GMM
)
stan_fit_2b_GMM
Inference for Stan model: anon_model.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

            mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff
p           0.54    0.01 0.14    0.25    0.46    0.54    0.63    0.82   692
mu[1]      10.42    0.03 0.85    8.96    9.84   10.34   10.92   12.38   712
mu[2]      15.70    0.02 0.63   14.05   15.39   15.79   16.12   16.64   812
sigma[1]    2.08    0.02 0.58    1.12    1.67    2.01    2.44    3.35   753
sigma[2]    1.41    0.02 0.50    0.68    1.08    1.34    1.65    2.77   727
lp__     -127.32    0.07 2.02 -132.54 -128.37 -126.86 -125.81 -124.77   739
         Rhat
p           1
mu[1]       1
mu[2]       1
sigma[1]    1
sigma[2]    1
lp__        1

Samples were drawn using NUTS(diag_e) at Sat Feb 10 10:26:55 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
Exercise 1b: Interpret this outcome

Interpret these results! Focus on parameters \(p\), \(\mu_1\) and \(\mu_2\). What does \(p\) capture in this implementation? Do the (mean) estimated values make sense?

Yes, they do make sense. In this implementation, \(p\) is the mixture weight of the first component, i.e., the one with the lower mean (since mu is ordered), which is group \(A\) in our case. The model infers that there are roughly equally many data points from each group, which is indeed the case. The model also recovers the descriptive means of each group!

ffm_data |> 
  group_by(species) |> 
  summarise(
    mean     = mean(height),
    std_dev  = sd(height)
  )
# A tibble: 2 × 3
  species  mean std_dev
  <chr>   <dbl>   <dbl>
1 A        10.0    1.76
2 B        15.7    1.38

An unidentifiable model

Let’s run the model in file Gaussian-mixture-02-unindentifiable.stan, which is exactly the same as before but with vector mu being an unordered vector of reals.

stan_fit_2c_GMM <- stan(
  file = 'stan-files/Gaussian-mixture-02-unindentifiable.stan',
  data = data_GMM,
  # set a seed for reproducible results
  seed = 1734
)

Here’s a summary of the outcome:

stan_fit_2c_GMM
Inference for Stan model: anon_model.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

            mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff
p           0.50    0.02 0.14    0.21    0.40    0.50    0.59    0.77    81
mu[1]      12.85    1.80 2.77    9.13   10.20   11.71   15.72   16.50     2
mu[2]      13.26    1.75 2.77    9.19   10.44   14.76   15.82   16.50     3
sigma[1]    1.75    0.21 0.61    0.80    1.30    1.66    2.12    3.11     8
sigma[2]    1.74    0.21 0.64    0.80    1.27    1.64    2.10    3.18     9
lp__     -128.86    0.05 1.86 -133.37 -129.79 -128.46 -127.51 -126.51  1241
         Rhat
p        1.04
mu[1]    2.72
mu[2]    2.41
sigma[1] 1.18
sigma[2] 1.16
lp__     1.00

Samples were drawn using NUTS(diag_e) at Sat Feb 10 10:27:17 2024.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
Exercise 1c: Interpret model output

What is remarkable here? Explain what happened. Explain in what sense this model is “unidentifiable”.

Hint: Explore the parameters with high \(\hat{R}\) values. When a model fit seems problematic, a nice tool to explore what might be amiss is the package shinystan. You could do this:

shinystan::launch_shinystan(stan_fit_2c_GMM)

Then head over to the tab “Explore” and have a look at some of the parameters.

The \(\hat{R}\) values of the mean parameters are substantially above 1, suggesting that the model did not converge. But if we look at trace plots for these parameters, for example, we see that \(\mu_1\) has “locked into” group A for some chains, and into group B for some other chains.

The model is therefore unidentifiable in the sense that, without requiring that mu is ordered, \(\mu_1\) could end up describing either group A or group B, and which one it takes on depends on the random initialization of each chain. Requiring that mu be ordered breaks this symmetry.
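Such trace plots can also be produced directly from the fit object (a quick sketch, assuming stan_fit_2c_GMM from above is still in memory):

# trace plots of the two mean parameters, separately for each chain
rstan::traceplot(stan_fit_2c_GMM, pars = "mu")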

Posterior predictive check

We can extend the (identifiable) model from above to also output samples from the posterior predictive distribution. This is given in file Gaussian-mixture-03-withPostPred.stan. Let’s run this model, collect the posterior predictive samples in a variable called yrep and draw a density plot.

stan_fit_2d_GMM <- stan(
  file = 'stan-files/Gaussian-mixture-03-withPostPred.stan',
  data = data_GMM,
  # only return the posterior predictive samples
  pars = c('yrep')
)
tibble(
  source  = c(rep("data", length(flower_heights)), rep("PostPred", length(extract(stan_fit_2d_GMM)$yrep))),
  height = c(flower_heights, extract(stan_fit_2d_GMM)$yrep)
) |>  
  ggplot(aes(x = height, fill=source, color = source)) +
  geom_density(size = 2, alpha = 0.3)

Exercise 1d: Scrutinize posterior predictive check

Does this look like a distribution that could have generated the data?

The distribution looks plausible enough. The visual fit in these density plots is not perfect, partly because the two densities are estimated from very different numbers of samples (50 observations vs. many thousands of posterior predictive samples).
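To make the comparison a little more even-handed, one option (a sketch, assuming the fit object stan_fit_2d_GMM from above) is to compare the data against an equally sized random subset of the posterior predictive samples:

# compare the data against a random subset of posterior predictive samples
set.seed(42)
yrep_all <- extract(stan_fit_2d_GMM)$yrep
tibble(
  source = c(rep("data", length(flower_heights)),
             rep("PostPred (subsample)", length(flower_heights))),
  height = c(flower_heights, sample(yrep_all, length(flower_heights)))
) |>  
  ggplot(aes(x = height, fill = source, color = source)) +
  geom_density(size = 2, alpha = 0.3)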

A Gaussian mixture model in brms

We can also run this finite mixture model in brms. Fitting the parameters of a single Gaussian is like fitting an intercept-only simple linear regression model. Finite mixtures can be added in brms via the family argument, using the function brms::mixture(). Here we define a finite mixture of two Gaussians, but more flexibility is possible.

brms_fit_2e_GMM <- brm(
  # intercept only model
  formula = y ~ 1, 
  data = data_GMM, 
  # declare that the likelihood should be a mixture
  family = mixture(gaussian, gaussian),
  # use weakly informative priors on mu  
  prior = c(
    prior(normal(12, 10), Intercept, dpar = mu1),
    prior(normal(12, 10), Intercept, dpar = mu2)
  )
) 

Let’s look at the model fit:

brms_fit_2e_GMM
 Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: y ~ 1 
   Data: data_GMM (Number of observations: 50) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept    10.48      0.97     8.86    12.49 1.01      711      618
mu2_Intercept    15.63      0.76    13.48    16.67 1.00     1106      574

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     2.22      0.65     1.20     3.60 1.00      847     1299
sigma2     1.56      0.71     0.76     3.19 1.00      853      579
theta1     0.54      0.15     0.21     0.85 1.00      859      698
theta2     0.46      0.15     0.15     0.79 1.00      859      698

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Let’s also look at the Stan code that brms produced in the background for this model, in order to find out how it is related to the hand-written Stan model from above:

brms_fit_2e_GMM$model
// generated with brms 2.20.1
functions {
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  vector[2] con_theta;  // prior concentration
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
}
parameters {
  real<lower=0> sigma1;  // dispersion parameter
  real<lower=0> sigma2;  // dispersion parameter
  simplex[2] theta;  // mixing proportions
  ordered[2] ordered_Intercept;  // to identify mixtures
}
transformed parameters {
  // identify mixtures via ordering of the intercepts
  real Intercept_mu1 = ordered_Intercept[1];
  // identify mixtures via ordering of the intercepts
  real Intercept_mu2 = ordered_Intercept[2];
  // mixing proportions
  real<lower=0,upper=1> theta1;
  real<lower=0,upper=1> theta2;
  real lprior = 0;  // prior contributions to the log posterior
  theta1 = theta[1];
  theta2 = theta[2];
  lprior += normal_lpdf(Intercept_mu1 | 12, 10);
  lprior += student_t_lpdf(sigma1 | 3, 0, 4.2)
    - 1 * student_t_lccdf(0 | 3, 0, 4.2);
  lprior += normal_lpdf(Intercept_mu2 | 12, 10);
  lprior += student_t_lpdf(sigma2 | 3, 0, 4.2)
    - 1 * student_t_lccdf(0 | 3, 0, 4.2);
  lprior += dirichlet_lpdf(theta | con_theta);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu1 = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] mu2 = rep_vector(0.0, N);
    mu1 += Intercept_mu1;
    mu2 += Intercept_mu2;
    // likelihood of the mixture model
    for (n in 1:N) {
      real ps[2];
      ps[1] = log(theta1) + normal_lpdf(Y[n] | mu1[n], sigma1);
      ps[2] = log(theta2) + normal_lpdf(Y[n] | mu2[n], sigma2);
      target += log_sum_exp(ps);
    }
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_mu1_Intercept = Intercept_mu1;
  // actual population-level intercept
  real b_mu2_Intercept = Intercept_mu2;
}

Now, your job. Look at the two previous outputs and answer the following questions:

Exercise 2a

Is the brms-model the exact same as the model in the previous section (model coded directly in Stan)?

No. The priors on the means are the same (we set them explicitly), but the priors on the standard deviations differ: the hand-written Stan model uses a log-normal prior, whereas brms uses a truncated Student’s t prior by default. The prior on the mixture weights is also expressed differently (a Beta prior on \(p\) vs. a Dirichlet prior on theta), although these may amount to the same thing.
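One way to check which priors brms actually used (assuming the fitted object from above is available) is:

# show the priors used for the brms mixture model
brms::prior_summary(brms_fit_2e_GMM)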

Exercise 2b

What is the equivalent of the variable alpha from the model of the previous section in this new brms-generated code?

That’s the variable ps.

Exercise 2c

What is the equivalent of the variable p from the hand-written Stan model in this new brms-generated code?

That’s theta[1] (exposed as theta1 in the transformed parameters block).

Exercise 2d

Is the brms code generating posterior predictive samples?

No! This is not strictly necessary, since posterior predictive samples can also be generated later from the fitted object.
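For example (a sketch using the fitted object from above), a quick posterior predictive check can be produced with:

# density overlay of the data vs. posterior predictive draws from the brms fit
brms::pp_check(brms_fit_2e_GMM, ndraws = 50)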

Exercise 2e

What is the prior probability in the brms-generated model of any given data point \(y_i\) to be from the first or second mixture component? Can you even tell from the code?

We cannot tell from this code alone. The concentration parameters of the Dirichlet prior on theta are stored in the variable con_theta, which is supplied from outside as part of the data, so from the generated code we can only guess. (A good guess would be that it is an unbiased prior, i.e., 50/50 on average.)
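One way to resolve this guess (a sketch, assuming the fitted object from above is available) is to inspect the data that brms passed to Stan:

# con_theta holds the concentration parameters of the Dirichlet prior on theta
brms::standata(brms_fit_2e_GMM)$con_theta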