Multivariate Gaussian marginal likelihood via THAMES

set.seed(2023)
library(thames)

To use the function thames() we only need two things (sketched schematically below):

1. A T × d matrix of parameters drawn from the posterior distribution. The columns correspond to the parameters (dimension d) and the rows to the T posterior draws.
2. A vector of length T containing the unnormalized log posterior values (the sum of the log prior and the log likelihood at each drawn parameter).
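
Schematically (the names params and lps are illustrative, not the package's formal argument names; the log posterior vector comes first, as in the calls below):

# params: T x d matrix, one posterior draw per row
# lps:    numeric vector of length T, log prior + log likelihood at each draw
# thames(lps, params)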

To illustrate the use of the thames() function we give a toy example on multivariate Gaussian data. For more details on the method, see Metodiev, M., Perrot-Dockès, M., Ouadah, S., Irons, N. J., and Raftery, A. E. (2023), Easily Computed Marginal Likelihoods from Posterior Simulation Using the THAMES Estimator.

Here the data $y_i$, $i = 1, \dots, n$, are drawn independently from a multivariate normal distribution:
$$ y_i \overset{\text{iid}}{\sim} \mathcal{N}_d(\mu, I_d), $$
along with a prior distribution on the mean vector $\mu$:
$$ \mu \sim \mathcal{N}_d(0, s_0 I_d), $$
with $s_0 > 0$. It can be shown that the posterior distribution of the mean vector $\mu$ given the data $D = \{y_1, \dots, y_n\}$ is
$$ \mu \mid D \sim \mathcal{N}_d(m_n \bar{y}, s_n I_d), $$
where $m_n = n/(n + 1/s_0)$, $\bar{y} = (1/n)\sum_{i=1}^n y_i$, and $s_n = 1/(n + 1/s_0)$.

Toy example: d = 1, n = 20

We fix s0 (the prior variance) and the true mean μ* to 1 (the results are similar for other values).

s0      <- 1
mu_star <- 1
n       <- 20
d       <- 1

We simulate values of Y and compute the associated sn and mn (note that the code's mn is the posterior mean $m_n\bar{y}$):

library(mvtnorm)

Y  <- rmvnorm(n, mu_star, diag(d))
sn <- 1/(n + 1/s0)
mn <- sum(Y)/(n + 1/s0)
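
As a quick sanity check of the conjugate updating (a sketch; grid, unnorm, and post_grid are ad hoc names), we can normalize the product of prior and likelihood on a grid and compare it with the exact $\mathcal{N}(m_n\bar{y}, s_n)$ density:

# d = 1 here, so the scalar dnorm() suffices
grid   <- seq(mn - 5*sqrt(sn), mn + 5*sqrt(sn), length.out = 1001)
unnorm <- sapply(grid, function(m) dnorm(m, 0, sqrt(s0)) * prod(dnorm(Y, m, 1)))
post_grid <- unnorm / (sum(unnorm) * (grid[2] - grid[1]))  # Riemann-sum normalization
max(abs(post_grid - dnorm(grid, mn, sqrt(sn))))            # should be near 0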

To use thames() we need a sample drawn from the posterior of the parameters. In this toy example the only parameter is μ, and we can draw from the posterior exactly. (More generally, MCMC can be used to obtain approximate posterior samples.) Here we take 2000 samples.

mu_sample <- rmvnorm(2000, mean = mn, sigma = sn*diag(d))
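
Since these draws are exact, their sample mean and variance should be close to mn and sn (up to Monte Carlo error of order $1/\sqrt{2000}$):

c(mean(mu_sample), mn)            # sample mean vs. exact posterior mean
c(var(as.vector(mu_sample)), sn)  # sample variance vs. exact posterior variance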

Now we calculate the unnormalized log posterior at each posterior draw $\mu^{(i)}$.

We first evaluate the log prior at each sample:

reg_log_prior <- function(param, sig2) {
  d <- length(param)
  dmvnorm(param, rep(0, d), sig2 * diag(d), log = TRUE)
}

log_prior <- apply(mu_sample, 1, reg_log_prior, s0)
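
In one dimension the prior is just a normal density with standard deviation $\sqrt{s_0}$, so we can spot-check the multivariate evaluation against dnorm():

all.equal(log_prior, dnorm(mu_sample[, 1], 0, sqrt(s0), log = TRUE))
#> [1] TRUE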

and the log likelihood of the data at each sampled parameter:

reg_log_likelihood <- function(param, X) {
  d <- length(param)
  sum(dmvnorm(X, param, diag(d), log = TRUE))
}

log_likelihood <- apply(mu_sample, 1, reg_log_likelihood, Y)

and then sum the two to get the unnormalized log posterior:

log_post <- log_prior + log_likelihood

We can now estimate the marginal likelihood using THAMES:

result <- thames(log_post, mu_sample)

The THAMES estimate of the log marginal likelihood is then

-result$log_zhat_inv
#> [1] -26.36138
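
Because the log posterior values are unnormalized, adding a constant to them shifts the estimate by exactly that constant on the log scale; a quick invariance check (the name shifted is illustrative):

shifted <- thames(log_post + 100, mu_sample)
-shifted$log_zhat_inv  # equals the estimate above plus 100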

Since the estimator targets the reciprocal of the marginal likelihood, negating the lower and upper bounds of its confidence interval (based on asymptotic normality of the estimator; a 95% interval by default) gives, respectively, the upper and lower bounds for the log marginal likelihood:

-result$log_zhat_inv_L 
#> [1] -26.32752
-result$log_zhat_inv_U
#> [1] -26.39413

If we instead want a 90% confidence interval, we specify a lower quantile of 0.05:

result_90 <- thames(log_post,mu_sample, p = 0.05)
-result_90$log_zhat_inv_L 
#> [1] -26.33305
-result_90$log_zhat_inv_U
#> [1] -26.38894

To check our estimate, we can calculate the Gaussian log marginal likelihood analytically as $$ \ell(y) = -\frac{nd}{2}\log(2\pi)-\frac{d}{2}\log(s_0n+1)-\frac{1}{2}\sum_{i=1}^n\|y_i\|^2+ \frac{n^2}{2(n+1/s_0)}\|\bar{y}\|^2. $$

- n*d*log(2*pi)/2 - d*log(s0*n+1)/2 - sum(Y^2)/2 + sum(colSums(Y)^2)/(2*(n+1/s0))
#> [1] -26.34293
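
For reuse below, the same formula can be wrapped in a small helper (a hypothetical convenience function, not part of thames):

gaussian_log_evidence <- function(Y, s0) {
  n <- nrow(Y); d <- ncol(Y)
  -n*d*log(2*pi)/2 - d*log(s0*n + 1)/2 - sum(Y^2)/2 +
    sum(colSums(Y)^2)/(2*(n + 1/s0))
}
gaussian_log_evidence(Y, s0)
#> [1] -26.34293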

In higher dimensions

We can do exactly the same thing when $y_i \in \mathbb{R}^d$ with $d > 1$.

s0      <- 1
n       <- 20
d       <- 2
mu_star <- rep(1, d)
Y  <- rmvnorm(n, mu_star, diag(d))
sn <- 1/(n + 1/s0)
mn <- colSums(Y)/(n + 1/s0)
mu_sample <- rmvnorm(2000, mean = mn, sigma = sn*diag(d))

mvg_log_post <- function(param, X, sig2) {
  d <- length(param)
  l <- sum(dmvnorm(X, param, diag(d), log = TRUE))
  p <- dmvnorm(param, rep(0, d), sig2 * diag(d), log = TRUE)
  p + l
}
log_post <- apply(mu_sample, 1, mvg_log_post, Y, s0)
result <- thames(log_post, mu_sample)
-result$log_zhat_inv
#> [1] -54.4726
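
Note that mvg_log_post simply combines the prior and likelihood functions from the first example into a single call; a spot check on one draw:

all.equal(mvg_log_post(mu_sample[1, ], Y, s0),
          reg_log_prior(mu_sample[1, ], s0) + reg_log_likelihood(mu_sample[1, ], Y))
#> [1] TRUE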

To check our estimate, we again calculate the Gaussian log marginal likelihood analytically:

- n*d*log(2*pi)/2 - d*log(s0*n+1)/2 - sum(Y^2)/2 + sum(colSums(Y)^2)/(2*(n+1/s0))
#> [1] -54.46546
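
Equivalently, with the helper defined earlier:

gaussian_log_evidence(Y, s0)
#> [1] -54.46546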