Bayesian inference for a logistic regression model (Part 3)

Part 3: The Metropolis algorithm

Introduction

This is the third part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we derived the log posterior for the model and implemented it in a variety of programming languages and libraries. In this post we will construct a Markov chain having the posterior as its equilibrium.

MCMC

Detailed balance

A homogeneous Markov chain with transition kernel p(\theta_{n+1}|\theta_n) is said to satisfy detailed balance for some target distribution \pi(\theta) if

\displaystyle \pi(\theta)p(\theta'|\theta) = \pi(\theta')p(\theta|\theta'), \quad \forall \theta, \theta'

Integrating both sides wrt \theta, and noting that p(\theta|\theta') integrates to one over \theta, gives

\displaystyle \int_\Theta \pi(\theta)p(\theta'|\theta)d\theta = \pi(\theta'),

from which it is clear that \pi(\cdot) is a stationary distribution of the chain (and the chain is reversible). Under fairly mild regularity conditions we expect \pi(\cdot) to be the equilibrium distribution of the chain.

For a given target \pi(\cdot) we would like to find an easy-to-sample-from transition kernel p(\cdot|\cdot) that satisfies detailed balance. This will then give us a way to (asymptotically) generate samples from our target.

In the context of Bayesian inference, the target \pi(\theta) will typically be the posterior distribution, which in the previous post we wrote as \pi(\theta|y). Here we drop the notational dependence on y, since MCMC can be used for any target distribution of interest.

Metropolis-Hastings

Suppose we have a fairly arbitrary easy-to-sample-from transition kernel q(\theta_{n+1}|\theta_n) and a target of interest, \pi(\cdot). Metropolis-Hastings (M-H) is a strategy for using q(\cdot|\cdot) to construct a new transition kernel p(\cdot|\cdot) satisfying detailed balance for \pi(\cdot).

The kernel p(\theta'|\theta) can be described algorithmically as follows:

  1. Call the current state of the chain \theta. Generate a proposal \theta^\star by simulating from q(\theta^\star|\theta).
  2. Compute the acceptance probability

\displaystyle \alpha(\theta^\star|\theta) = \min\left[1,\frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}\right].

  3. With probability \alpha(\theta^\star|\theta) return new state \theta'=\theta^\star, otherwise return \theta'=\theta.

It is clear from the algorithmic description that this kernel will have a point mass at \theta'=\theta, but that for \theta'\not=\theta the transition kernel will be p(\theta'|\theta)=q(\theta'|\theta)\alpha(\theta'|\theta). But then

\displaystyle \pi(\theta)p(\theta'|\theta) = \min[\pi(\theta)q(\theta'|\theta),\pi(\theta')q(\theta|\theta')]

is symmetric in \theta and \theta', and so detailed balance is satisfied. Since detailed balance is trivial at the point mass at \theta=\theta' we are done.
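As a sanity check on the argument above, here is a minimal Python sketch on a three-state space, where the integrals become sums. The target weights and the (deliberately asymmetric) proposal matrix are made up purely for illustration, and the function name mh_matrix is my own.

```python
from itertools import product

# Hypothetical three-state target, known only up to a normalising constant.
weights = [1.0, 3.0, 6.0]
total = sum(weights)
pi = [w / total for w in weights]

# An arbitrary, deliberately asymmetric proposal matrix: q[i][j] = q(j | i).
q = [
    [0.2, 0.5, 0.3],
    [0.6, 0.1, 0.3],
    [0.3, 0.3, 0.4],
]

def mh_matrix(weights, q):
    """Build the M-H transition matrix p[i][j] = p(j | i)."""
    n = len(weights)
    p = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                ratio = (weights[j] * q[j][i]) / (weights[i] * q[i][j])
                p[i][j] = q[i][j] * min(1.0, ratio)
        # Whatever mass is not moved elsewhere stays put (the point mass).
        p[i][i] = 1.0 - sum(p[i])
    return p

p = mh_matrix(weights, q)

# Detailed balance: pi(i) p(j|i) == pi(j) p(i|j) for every pair of states.
balance_gap = max(abs(pi[i] * p[i][j] - pi[j] * p[j][i])
                  for i, j in product(range(3), repeat=2))

# Stationarity: sum_i pi(i) p(j|i) == pi(j) for every j.
stationary_gap = max(abs(sum(pi[i] * p[i][j] for i in range(3)) - pi[j])
                     for j in range(3))

print(balance_gap, stationary_gap)
```

Both gaps should be zero up to floating point rounding. Note that only the unnormalised weights enter the acceptance ratio, which is exactly why M-H is so convenient for posteriors known only up to a constant.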

Metropolis algorithm

It is often convenient to generate proposals perturbatively, using a distribution that is symmetric about the current state of the chain. But then q(\theta'|\theta)=q(\theta|\theta'), and so q(\cdot|\cdot) drops out of the acceptance probability. This is the Metropolis algorithm.

Some computational tricks

To generate an event with probability \alpha(\theta^\star|\theta), we can generate a u\sim U(0,1) and accept if u < \alpha(\theta^\star|\theta). This is convenient for several reasons. First, it means that we can ignore the "min", and just accept if

\displaystyle u < \frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}

since u\leq 1 regardless. Better still, we can take logs, and accept if

\displaystyle \log u < \log\pi(\theta^\star) - \log\pi(\theta) + \log q(\theta|\theta^\star) - \log q(\theta^\star|\theta),

so there is no need to evaluate any raw densities. Again, in the case of a symmetric proposal distribution, the q(\cdot|\cdot) terms can be dropped.
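To see why the log form matters in practice, here is a small Python sketch assuming a symmetric proposal. The two log-density values are hypothetical, but of a magnitude that is quite typical for a posterior built from many observations; accept_log is just an illustrative name.

```python
import math
import random

def accept_log(log_pi_prop, log_pi_cur, log_u):
    """Metropolis accept/reject on the log scale (symmetric proposal assumed)."""
    return log_u < log_pi_prop - log_pi_cur

# Hypothetical log-density values at the current and proposed states.
log_pi_cur = -1200.0
log_pi_prop = -1201.0

# Both raw densities underflow to exactly 0.0 in double precision,
# so the raw ratio pi(prop)/pi(cur) cannot be formed.
raw_cur = math.exp(log_pi_cur)
raw_prop = math.exp(log_pi_prop)

# On the log scale the comparison is unproblematic: the log ratio is -1.
log_ratio = log_pi_prop - log_pi_cur

rng = random.Random(1)
n = 20000
# 1.0 - rng.random() lies in (0, 1], which keeps log() away from zero.
accepted = sum(accept_log(log_pi_prop, log_pi_cur, math.log(1.0 - rng.random()))
               for _ in range(n))
rate = accepted / n
print(raw_cur, raw_prop, log_ratio, rate)
```

The empirical acceptance rate should be close to exp(-1), even though neither density can be represented as a floating point number.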

Another trick is worth noting. Suppose we use the simple M-H algorithm described, with a single update for the entire state space (rather than multiple component-wise updates, for example), and apply the same M-H kernel repeatedly to generate successive states of a Markov chain. Then the \log \pi(\theta) term (which in the context of Bayesian inference will typically be the log posterior) will already be available from the previous update, irrespective of whether or not the previous move was accepted. So if we are careful about how we pass that old value on to the next iteration, we can avoid recomputing the log posterior, and our algorithm will require only one log posterior evaluation per iteration rather than two. In functional programming languages it is often convenient to pass around this current log posterior density evaluation explicitly, effectively augmenting the state space of the Markov chain to include the log posterior density.
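The saving is easy to verify by counting evaluations. The following Python sketch is only illustrative: the counting wrapper and run_chain are hypothetical helpers, the target is a standard normal, and (unlike the code later in this post, which starts from a log-posterior of minus infinity) it spends one evaluation on the initial state.

```python
import math
import random

def counting(f):
    """Wrap f so that we can count how many times it is called."""
    def wrapped(x):
        wrapped.calls += 1
        return f(x)
    wrapped.calls = 0
    return wrapped

def run_chain(log_post, x0, iters, rng, step=1.0):
    """Random-walk Metropolis carrying the pair (x, ll) between iterations."""
    x, ll = x0, log_post(x0)          # one evaluation to initialise the cache
    out = []
    for _ in range(iters):
        prop = x + rng.uniform(-step, step)
        ll_prop = log_post(prop)      # the only evaluation in this iteration
        if math.log(1.0 - rng.random()) < ll_prop - ll:
            x, ll = prop, ll_prop     # accept: the cached value moves with the state
        out.append(x)
    return out

log_post = counting(lambda x: -0.5 * x * x)   # standard normal, up to a constant
chain = run_chain(log_post, 0.0, 5000, random.Random(42))
print(log_post.calls)
```

The count is iters + 1, rather than the roughly 2 * iters that a naive implementation recomputing the current log posterior would need.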

HoF for a M-H kernel

Since I’m a fan of functional programming, we will adopt a functional style throughout, and start by creating a higher-order function (HoF) that accepts a log-posterior and proposal kernel as input and returns a Metropolis kernel as output.

R

In R we can write a function to create a M-H kernel as follows.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

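Note that the kernel returned requires as input both a current state x and its associated log-posterior, ll, and returns the new state together with its log-posterior. Note also that dprop enters the acceptance calculation on the log scale, so a user-supplied dprop should return the log of the proposal density; the constant default simply cancels.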

We need to use this transition kernel to simulate a Markov chain by successive substitution of newly simulated values back into the kernel. In more sophisticated programming languages we will use streams for this, but in R we can just use a for loop to sample values and write the states into the rows of a matrix.

mcmc = function(init, kernel, iters = 10000, thin = 10, verb = TRUE) {
    p = length(init)
    ll = -Inf
    mat = matrix(0, nrow = iters, ncol = p)
    colnames(mat) = names(init)
    x = init
    if (verb) 
        message(paste(iters, "iterations"))
    for (i in 1:iters) {
        if (verb) 
            message(paste(i, ""), appendLF = FALSE)
        for (j in 1:thin) {
            pair = kernel(x, ll)
            x = pair$x
            ll = pair$ll
            }
        mat[i, ] = x
        }
    if (verb) 
        message("Done.")
    mat
}

Then, in the context of our running logistic regression example, and using the log-posterior from the previous post, we can construct our kernel and run it as follows.

pre = c(10.0,1,1,1,1,1,5,1)
out = mcmc(init, mhKernel(lpost,
          function(x) x + pre*rnorm(p, 0, 0.02)), thin=1000)

Note the use of a symmetric proposal, so the proposal density is not required. Also note the use of a larger proposal variance for the intercept term and the second last covariate. See the full runnable script for further details.

Python

We can do something very similar to R in Python using NumPy. Our HoF for constructing a M-H kernel is

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

Our Markov chain runner function is

def mcmc(init, kernel, thin = 10, iters = 10000, verb = True):
    p = len(init)
    ll = -np.inf
    mat = np.zeros((iters, p))
    x = init
    if (verb):
        print(str(iters) + " iterations")
    for i in range(iters):
        if (verb):
            print(str(i), end=" ", flush=True)
        for j in range(thin):
            x, ll = kernel(x, ll)
        mat[i,:] = x
    if (verb):
        print("\nDone.", flush=True)
    return mat

We can use this code in the context of our logistic regression example as follows.

pre = np.array([10.,1.,1.,1.,1.,1.,5.,1.])

def rprop(beta):
    return beta + 0.02*pre*np.random.randn(p)

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.
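The examples so far use a symmetric proposal, so the dprop argument is never exercised. As a hedged illustration of when it matters, the sketch below uses a multiplicative (log-normal) random walk to target a Gamma(3, 1) density on the positive reals. This toy target, the tuning constant SIGMA and the helper sample_mean are my own choices, not part of the logistic regression example, and the kernel is re-declared (as mh_kernel, using only the standard library) so that the block is self-contained.

```python
import math
import random

rng = random.Random(2024)

def mh_kernel(lpost, rprop, dprop=lambda new, old: 1.0):
    """Same shape as mhKernel above; dprop(new, old) is the LOG density q(new | old)."""
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if math.log(1.0 - rng.random()) < a:
            return prop, lp
        return x, ll
    return kernel

SIGMA = 0.8  # hypothetical tuning parameter for the proposal

def lpost(x):
    """Log density of a Gamma(shape=3, rate=1) target, up to a constant."""
    return 2.0 * math.log(x) - x

def rprop(x):
    """Multiplicative (log-normal) random walk: always proposes a positive value."""
    return x * math.exp(SIGMA * rng.gauss(0.0, 1.0))

def dprop(new, old):
    """Log-normal log density of new given old, dropping constants that cancel."""
    z = (math.log(new) - math.log(old)) / SIGMA
    return -math.log(new) - 0.5 * z * z

def sample_mean(kernel, x0=1.0, iters=40000, burn=2000):
    """Run the chain and average the post-burn-in states."""
    x, ll = x0, -math.inf
    total = 0.0
    for i in range(iters + burn):
        x, ll = kernel(x, ll)
        if i >= burn:
            total += x
    return total / iters

mean_with = sample_mean(mh_kernel(lpost, rprop, dprop))
mean_without = sample_mean(mh_kernel(lpost, rprop))
print(mean_with, mean_without)
```

With the correction the sample mean should be close to 3, the mean of the target. Omitting dprop silently targets a density proportional to \pi(x)/x instead, which here is a Gamma(2, 1) with mean 2, so forgetting the proposal density for an asymmetric proposal gives plausible-looking but wrong answers.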

JAX

The above R and Python scripts are fine, but both languages are rather slow for this kind of workload. Fortunately it’s rather straightforward to convert the Python code to JAX to obtain quite amazing speed-up. We can write our M-H kernel as

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

and our MCMC runner function as

def mcmc(init, kernel, thin = 10, iters = 10000):
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, iters)
    @jit
    def step(s, k):
        [x, ll] = s
        x, ll = kernel(k, x, ll)
        s = [x, ll]
        return s, s
    @jit
    def iter(s, k):
        keys = jax.random.split(k, thin)
        _, states = jax.lax.scan(step, s, keys)
        final = [states[0][thin-1], states[1][thin-1]]
        return final, final
    ll = -np.inf
    x = init
    _, states = jax.lax.scan(iter, [x, ll], keys)
    return states[0]

There are really only two slightly tricky things about this code.

The first relates to the way JAX handles pseudo-random numbers. Since JAX is a pure functional eDSL, it can’t be used in conjunction with the typical pseudo-random number generators often used in imperative programming languages which rely on a global mutable state. This can be dealt with reasonably straightforwardly by explicitly passing around the random number state. There is a standard way of doing this that has been common practice in functional programming languages for decades. However, this standard approach is very sequential, and so doesn’t work so well in a parallel context. JAX therefore uses a splittable random number generator, where new states are created by splitting the current state into two (or more). We’ll come back to this when we get to the Haskell examples.
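For readers who have not met splittable generators before, the idea can be sketched in a few lines of plain Python. This is a toy construction based on hashing, intended only to convey the interface; it is not the algorithm that JAX uses, and the names split and uniform here are my own standalone functions.

```python
import hashlib

def split(key, n=2):
    """Derive n child keys from a parent key (toy construction, not JAX's algorithm)."""
    return [hashlib.sha256(key + b"/split/" + str(i).encode()).digest()
            for i in range(n)]

def uniform(key):
    """Deterministically map a key to a float in [0, 1) using 53 bits of a hash."""
    digest = hashlib.sha256(key + b"/uniform").digest()
    return (int.from_bytes(digest[:8], "big") >> 11) / float(1 << 53)

root = hashlib.sha256(b"seed-42").digest()
k0, k1 = split(root)

# Pure: the same key always yields the same draw.
same = uniform(k0) == uniform(k0)

# Order-independent: child keys can be consumed in any order (or in parallel).
forward = [uniform(k) for k in split(root, 4)]
backward = [uniform(k) for k in reversed(split(root, 4))][::-1]

print(same, forward == backward)
```

Because a draw depends only on the key passed in, there is no hidden state to thread through the computation, which is what makes this style friendly to parallel and vectorised code. The flip side is the discipline it demands: each key should be used once, either for a draw or for a split.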

The second thing that might be unfamiliar to imperative programmers is the use of the scan operation (jax.lax.scan) to generate the Markov chain rather than a "for" loop. But scans are standard operations in most functional programming languages.
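If scans are new to you, the semantics can be captured by a short pure Python function. This sketch mirrors the argument order of jax.lax.scan (a step function mapping a carry and an input to a new carry and an output), but it is only a model of the behaviour, returning a Python list rather than a stacked array.

```python
def scan(f, init, xs):
    """Thread a carry through xs, collecting one output per element.

    f maps (carry, x) to (new_carry, output).
    """
    carry = init
    outputs = []
    for x in xs:
        carry, y = f(carry, x)
        outputs.append(y)
    return carry, outputs

def step(total, x):
    """Running sum: the carry and the emitted output are the same value."""
    total = total + x
    return total, total

final, running = scan(step, 0, [1, 2, 3, 4])
print(final, running)  # 10 [1, 3, 6, 10]
```

In the MCMC runner above, the carry is the pair of state and log-posterior, the inputs are the random keys, and the collected outputs are the sampled states.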

We can then call this code for our logistic regression example with

pre = jnp.array([10.,1.,1.,1.,1.,1.,5.,1.]).astype(jnp.float32)

@jit
def rprop(key, beta):
    return beta + 0.02*pre*jax.random.normal(key, [p])

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.

Scala

In Scala we can use a similar approach to that already seen for defining a HoF to return a M-H kernel.

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

Note that Scala’s static typing does not prevent us from defining a function that is polymorphic in the type of the chain state, which we here call S. Also note that we are adopting a pragmatic approach to random number generation, exploiting the fact that Scala is not a pure functional language, using a mutable generator, and omitting to capture the non-determinism of the rprop function (and the returned kernel) in its type signature. In Scala this is a choice, and we could adopt a purer approach if preferred. We’ll see what such an approach will look like in Haskell, coming up next.

Now that we have the kernel, we don’t need to write an explicit runner function since Scala has good support for streaming data. There are many more-or-less sophisticated ways that we can work with data streams in Scala, and the choice depends partly on how pure one is being about tracking effects (such as non-determinism), but here I’ll just use the simple LazyList from the standard library for unfolding the kernel into an infinite MCMC chain before thinning and truncating appropriately.

  val pre = DenseVector(10.0,1.0,1.0,1.0,1.0,1.0,5.0,1.0)
  def rprop(beta: DVD): DVD = beta + pre *:* (DenseVector(Gaussian(0.0,0.02).sample(p).toArray))
  val kern = mhKernel(lpost, rprop)
  val s = LazyList.iterate((init, -Inf))(kern) map (_._1)
  val out = s.drop(150).thin(1000).take(10000)

See the full runnable script for further details.

Haskell

Since Haskell is a pure functional language, we need to have some convention regarding pseudo-random number generation. Haskell supports several styles. The most commonly adopted approach wraps a mutable generator up in a monad. The typical alternative is to use a pure functional generator and either explicitly thread the state through code or hide this in a monad similar to the standard approach. However, Haskell also supports the use of splittable generators, so we can consider all three approaches for comparative purposes. The approach taken does affect the code and the type signatures, and even the streaming data abstractions most appropriate for chain generation.

Starting with a HoF for producing a Metropolis kernel, an approach using the standard monadic generators could look like

mKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) -> 
           g -> (s, Double) -> m (s, Double)
mKernel logPost rprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

Note how non-determinism is captured in the type signatures by the monad m. The explicit pure approach is to thread the generator through non-deterministic functions.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> (s, g)) -> 
            g -> (s, Double) -> ((s, Double), g)
mKernelP logPost rprop g (x0, ll0) = let
  (x, g1) = rprop x0 g
  ll = logPost(x)
  a = ll - ll0
  (u, g2) = uniformR (0, 1) g1
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in (next, g2)

Here the updated random number generator state is returned from each non-deterministic function for passing on to subsequent non-deterministic functions. This explicit sequencing of operations makes it possible to wrap the generator state in a state monad giving code very similar to the stateful monadic generator approach, but as already discussed, the sequential nature of this approach makes it unattractive in parallel and concurrent settings.

Fortunately the standard Haskell pure generator is splittable, meaning that we can adopt a splitting approach similar to JAX if we prefer, since this is much more parallel-friendly.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> s) -> 
            g -> (s, Double) -> (s, Double)
mKernelP logPost rprop g (x0, ll0) = let
  (g1, g2) = split g
  x = rprop x0 g1
  ll = logPost(x)
  a = ll - ll0
  u = unif g2
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in next

Here non-determinism is signalled by passing a generator state (often called a "key" in the context of splittable generators) into a function. Functions receiving a key are responsible for splitting it to ensure that no key is ever used more than once.

Once we have a kernel, we need to unfold our Markov chain. When using the monadic generator approach, it is most natural to unfold using a monadic stream

mcmc :: (StatefulGen g m) =>
  Int -> Int -> s -> (g -> s -> m s) -> g -> MS.Stream m s
mcmc it th x0 kern g = MS.iterateNM it (stepN th (kern g)) x0

stepN :: (Monad m) => Int -> (a -> m a) -> (a -> m a)
stepN n fa = if (n == 1)
  then fa
  else (\x -> (fa x) >>= (stepN (n-1) fa))

whereas for the explicit approaches it is more natural to unfold into a regular infinite data stream. So, for the explicit sequential approach we could use

mcmcP :: (RandomGen g) => s -> (g -> s -> (s, g)) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = kern (snd xg) (fst xg)
      in (x1, (x1, g1))

and with the splittable approach we could use

mcmcP :: (RandomGen g) =>
  s -> (g -> s -> s) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = xg
      -- split first, so the key given to the kernel is never reused
      (g2, g3) = split g1
      x2 = kern g2 x1
      in (x2, (x2, g3))

Calling these functions for our logistic regression example is similar to what we have seen before, but again there are minor syntactic differences depending on the approach. For further details see the full runnable scripts for the monadic approach, the pure sequential approach, and the splittable approach.

Dex

Dex is a pure functional language and uses a splittable random number generator, so the style we use is similar to JAX (or Haskell using a splittable generator). We can generate a Metropolis kernel with

def mKernel {s} (lpost: s -> Float) (rprop: Key -> s -> s) : 
    Key -> (s & Float) -> (s & Float) =
  def kern (k: Key) (sll: (s & Float)) : (s & Float) =
    (x0, ll0) = sll
    [k1, k2] = split_key k
    x = rprop k1 x0
    ll = lpost x
    a = ll - ll0
    u = rand k2
    select (log u < a) (x, ll) (x0, ll0)
  kern

We can then unfold our Markov chain with

def markov_chain {s} (k: Key) (init: s) (kern: Key -> s -> s) (its: Nat) :
    Fin its => s =
  with_state init \st.
    for i:(Fin its).
      x = kern (ixkey k i) (get st)
      st := x
      x

Here we combine Dex’s state effect with a for loop to unfold the stream. See the full runnable script for further details.

Next steps

As previously discussed, none of these codes are optimised, so care should be taken not to over-interpret running times. However, JAX and Dex are noticeably faster than the alternatives, even running on a single CPU core. Another interesting feature of both JAX and Dex is that they are differentiable. This makes it very easy to develop algorithms using gradient information. In subsequent posts we will think about the gradient of our example log-posterior and how we can use gradient information to develop "better" sampling algorithms.

The complete runnable scripts are all available from this public GitHub repo.

MCMC as a Stream

Introduction

This weekend I’ve been preparing some material for my upcoming Scala for statistical computing short course. As part of the course, I thought it would be useful to walk through how to think about and structure MCMC codes, and in particular, how to think about MCMC algorithms as infinite streams of state. This material is reasonably stand-alone, so it seems suitable for a blog post. Complete runnable code for the examples in this post is available from my blog repo.

A simple MH sampler

For this post I will just consider a trivial toy Metropolis algorithm using a Uniform random walk proposal to target a standard normal distribution. I’ve considered this problem before on my blog, so if you aren’t very familiar with Metropolis-Hastings algorithms, you might want to quickly review my post on Metropolis-Hastings MCMC algorithms in R before continuing. At the end of that post, I gave the following R code for the Metropolis sampler:

metrop3<-function(n=1000,eps=0.5) 
{
        vec=vector("numeric", n)
        x=0
        oldll=dnorm(x,log=TRUE)
        vec[1]=x
        for (i in 2:n) {
                can=x+runif(1,-eps,eps)
                loglik=dnorm(can,log=TRUE)
                loga=loglik-oldll
                if (log(runif(1)) < loga) { 
                        x=can
                        oldll=loglik
                        }
                vec[i]=x
        }
        vec
}

I will begin this post with a fairly direct translation of this algorithm into Scala:

def metrop1(n: Int = 1000, eps: Double = 0.5): DenseVector[Double] = {
    val vec = DenseVector.fill(n)(0.0)
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    vec(0) = x
    (1 until n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      vec(i) = x
    }
    vec
}

This code works, and is reasonably fast and efficient, but there are several issues with it from a functional programmer’s perspective. One issue is that we have committed to storing all MCMC output in RAM in a DenseVector. This probably isn’t an issue here, but for some big problems we might prefer to not store the full set of states, but to just print the states to (say) the console, for possible re-direction to a file. It is easy enough to modify the code to do this:

def metrop2(n: Int = 1000, eps: Double = 0.5): Unit = {
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    (1 to n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      println(x)
    }
}

But now we have two versions of the algorithm: one for storing results locally, and one for streaming results to the console. This is clearly unsatisfactory, but we shall return to this issue shortly. Another issue that will jump out at functional programmers is the reliance on mutable variables for storing the state and old likelihood. Let’s fix that now by re-writing the algorithm as a tail-recursion.

@tailrec
def metrop3(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop3(n - 1, eps, can, loglik)
      else
        metrop3(n - 1, eps, x, oldll)
    }
  }

This has eliminated the vars, and is just as fast and efficient as the previous version of the code. Note that the @tailrec annotation is optional – it just signals to the compiler that we want it to throw an error if for some reason it cannot eliminate the tail call. However, this is for the print-to-console version of the code. What if we actually want to keep the iterations in RAM for subsequent analysis? We can keep the values in an accumulator, as follows.

@tailrec
def metrop4(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0)
      DenseVector(acc.reverse.toArray)
    else {
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop4(n - 1, eps, can, loglik, can :: acc)
      else
        metrop4(n - 1, eps, x, oldll, x :: acc)
    }
}

Factoring out the updating logic

This is all fine, but we haven’t yet addressed the issue of having different versions of the code depending on what we want to do with the output. The problem is that we have tied up the logic of advancing the Markov chain with what to do with the output. What we need to do is separate out the code for advancing the state. We can do this by defining a new function.

def newState(x: Double, oldll: Double, eps: Double): (Double, Double) = {
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

This function takes as input a current state and associated log likelihood and returns a new state and log likelihood following the execution of one step of a MH algorithm. This separates the concern of state updating from the rest of the code. So now if we want to write code that prints the state, we can write it as

  @tailrec
  def metrop5(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val ns = newState(x, oldll, eps)
      metrop5(n - 1, eps, ns._1, ns._2)
    }
  }

and if we want to accumulate the set of states visited, we can write that as

  @tailrec
  def metrop6(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0) DenseVector(acc.reverse.toArray) else {
      val ns = newState(x, oldll, eps)
      metrop6(n - 1, eps, ns._1, ns._2, ns._1 :: acc)
    }
  }

Both of these functions call newState to do the real work, and concentrate on what to do with the sequence of states. However, both of these functions repeat the logic of how to iterate over the sequence of states.

MCMC as a stream

Ideally we would like to abstract out the details of how to do state iteration from the code as well. Most functional languages have some concept of a Stream, which represents a (potentially infinite) sequence of states. The Stream can embody the logic of how to perform state iteration, allowing us to abstract that away from our code, as well.

To do this, we will restructure our code slightly so that it more clearly maps old state to new state.

def nextState(eps: Double)(state: (Double, Double)): (Double, Double) = {
    val x = state._1
    val oldll = state._2
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

The "real" state of the chain is just x, but if we want to avoid recalculation of the old likelihood, then we need to make this part of the chain’s state. We can use this nextState function in order to construct a Stream.

  def metrop7(eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Stream[Double] =
    Stream.iterate((x, oldll))(nextState(eps)) map (_._1)

The result of calling this is an infinite stream of states. Obviously it isn’t computed – that would require infinite computation – but it captures the logic of iteration and computation in a Stream, which can be thought of as a lazy List. We can get values out by converting the Stream to a regular collection, being careful to truncate the Stream to one of finite length beforehand! For example, metrop7().drop(1000).take(10000).toArray will do a burn-in of 1,000 iterations followed by a main monitoring run of length 10,000, capturing the results in an Array. Note that metrop7().drop(1000).take(10000) is a Stream, and so nothing is actually computed until the toArray is encountered. Conversely, if printing to console is required, just replace the .toArray with .foreach(println).
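The same idea carries over to languages without a built-in lazy list. As a rough Python analogue (a sketch only, with metrop_stream and its seed argument being names I have made up), a generator plays the role of the infinite stream and itertools.islice handles burn-in, thinning and truncation in one go.

```python
import math
import random
from itertools import islice

def metrop_stream(eps=0.5, x=0.0, seed=0):
    """Infinite generator of Metropolis states targeting a standard normal."""
    rng = random.Random(seed)
    ll = -0.5 * x * x
    while True:
        yield x
        can = x + rng.uniform(-eps, eps)
        ll_can = -0.5 * can * can
        if math.log(1.0 - rng.random()) < ll_can - ll:
            x, ll = can, ll_can

# Nothing is simulated when the generator is created or sliced...
burn_in, thin, keep = 1000, 10, 2000
sliced = islice(metrop_stream(), burn_in, burn_in + thin * keep, thin)

# ...only when the result is forced into a concrete collection.
draws = list(sliced)
print(len(draws), sum(draws) / len(draws))
```

Here the first 1,000 states are discarded, and then every 10th state is kept until 2,000 have been collected. Replacing list(sliced) with a loop that prints each value gives the stream-to-console behaviour without changing the sampler.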

The above stream-based approach to MCMC iteration is clean and elegant, and deals nicely with issues like burn-in and thinning (which can be handled similarly). This is how I typically write MCMC codes these days. However, functional programming purists would still have issues with this approach, as it isn’t quite pure functional: the code has a side-effect, which is to mutate the state of the under-pinning pseudo-random number generator. If the code were pure, calling nextState with the same inputs would always give the same result. Clearly this isn’t the case here, as we have specifically designed the function to be stochastic, returning a randomly sampled value from the desired probability distribution. So nextState represents a function for randomly sampling from a conditional probability distribution.

A pure functional approach

Now, ultimately all code has side-effects, or there would be no point in running it! But in functional programming the desire is to make as much of the code as possible pure, and to push side-effects to the very edges of the code. So it’s fine to have side-effects in your main method, but not buried deep in your code. Here the side-effect is at the very heart of the code, which is why it is potentially an issue.

To keep things as simple as possible, at this point we will stop worrying about carrying forward the old likelihood, and hard-code a value of eps. Generalisation is straightforward. We can make our code pure by instead defining a function which represents the conditional probability distribution itself. For this we use a probability monad, which in Breeze is called Rand. We can couple together such functions using monadic binds (flatMap in Scala), expressed most neatly using for-comprehensions. So we can write our transition kernel as

def kernel(x: Double): Rand[Double] = for {
    innov <- Uniform(-0.5, 0.5)
    can = x + innov
    oldll = Gaussian(0.0, 1.0).logPdf(x)
    loglik = Gaussian(0.0, 1.0).logPdf(can)
    loga = loglik - oldll
    u <- Uniform(0.0, 1.0)
} yield if (math.log(u) < loga) can else x

This is now pure – the same input x will always return the same probability distribution – the conditional distribution of the next state given the current state. We can draw random samples from this distribution if we must, but it’s probably better to work as long as possible with pure functions. So next we need to encapsulate the iteration logic. Breeze has a MarkovChain object which can take kernels of this form and return a stochastic Process object representing the iteration logic, as follows.

MarkovChain(0.0)(kernel).
  steps.
  drop(1000).
  take(10000).
  foreach(println)

The steps method contains the logic of how to advance the state of the chain. But again note that no computation actually takes place until the foreach method is encountered – this is when the sampling occurs and the side-effects happen.
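To make the idea of a probability monad a little more concrete, here is a toy Python analogue. It is emphatically not how Breeze implements Rand; the class, its map and flat_map methods, and the steps helper are all hypothetical names chosen to mirror the Scala code above. A distribution is represented as a deferred sampler, and randomness is only consumed when draw is finally called with an explicit generator.

```python
import math
import random

class Rand:
    """A distribution represented as a deferred sampler (toy probability monad)."""
    def __init__(self, sampler):
        self._sampler = sampler          # function: random.Random -> value

    def draw(self, rng):
        return self._sampler(rng)

    def map(self, f):
        return Rand(lambda rng: f(self.draw(rng)))

    def flat_map(self, f):
        # f returns another Rand; sample from it using the same generator.
        return Rand(lambda rng: f(self.draw(rng)).draw(rng))

def uniform(lo, hi):
    return Rand(lambda rng: rng.uniform(lo, hi))

def log_pdf(x):
    return -0.5 * x * x                  # standard normal, up to a constant

def kernel(x):
    """Return the conditional distribution of the next state; nothing is sampled here."""
    def decide(innov):
        can = x + innov
        log_a = log_pdf(can) - log_pdf(x)
        # Use 1 - u, which lies in (0, 1], to keep the log finite.
        return uniform(0.0, 1.0).map(
            lambda u: can if math.log(1.0 - u) < log_a else x)
    return uniform(-0.5, 0.5).flat_map(decide)

def steps(init, kern, n, rng):
    """Force n transitions; this is the only place randomness is consumed."""
    x, out = init, []
    for _ in range(n):
        x = kern(x).draw(rng)
        out.append(x)
    return out

chain_a = steps(0.0, kernel, 2000, random.Random(7))
chain_b = steps(0.0, kernel, 2000, random.Random(7))
print(chain_a == chain_b)
```

Calling kernel(x) merely builds a description of the conditional distribution, so it is a pure function of x. All of the non-determinism is pushed out to the call to steps, where the generator is supplied, and supplying the same seeded generator reproduces the same chain.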

Metropolis-Hastings is a common use-case for Markov chains, so Breeze actually has a helper method built-in that will construct a MH sampler directly from an initial state, a proposal kernel, and a (log) target.

MarkovChain.
  metropolisHastings(0.0, (x: Double) =>
  Uniform(x - 0.5, x + 0.5))(x =>
  Gaussian(0.0, 1.0).logPdf(x)).
  steps.
  drop(1000).
  take(10000).
  toArray

Note that if you are using the MH functionality in Breeze, it is important to make sure that you are using version 0.13 (or later), as I fixed a few issues with the MH code shortly prior to the 0.13 release.

Summary

Viewing MCMC algorithms as infinite streams of state is useful for writing elegant, generic, flexible code. Streams occur everywhere in programming, and so there are lots of libraries for working with them. In this post I used the simple Stream from the Scala standard library, but there are much more powerful and flexible stream libraries for Scala, including fs2 and Akka-streams. But whatever libraries you are using, the fundamental concepts are the same. The most straightforward approach to implementation is to define impure stochastic streams to consume. However, a pure functional approach is also possible, and the Breeze library defines some useful functions to facilitate this approach. I’m still a little bit ambivalent about whether the pure approach is worth the additional cognitive overhead, but it’s certainly very interesting and worth playing with and thinking about the pros and cons.

Complete runnable code for the examples in this post is available from my blog repo.

Getting started with parallel MCMC

Introduction to parallel MCMC for Bayesian inference, using C, MPI, the GSL and SPRNG

Introduction

This post is aimed at people who already know how to code up Markov Chain Monte Carlo (MCMC) algorithms in C, but are interested in how to parallelise their code to run on multi-core machines and HPC clusters. I discussed different languages for coding MCMC algorithms in a previous post. The advantage of C is that it is fast, standard and has excellent scientific library support. Ultimately, people pursuing this route will be interested in running their code on large clusters of fast servers, but for the purposes of development and testing, this really isn’t necessary. A single dual-core laptop, or similar, is absolutely fine. I develop and test on a dual-core laptop running Ubuntu Linux, so that is what I will assume for the rest of this post.

There are several possible environments for parallel computing, but I will focus on the Message-Passing Interface (MPI). This is a well-established standard for parallel computing, there are many implementations, and it is by far the most commonly used high performance computing (HPC) framework today. Even if you are ultimately interested in writing code for novel architectures such as GPUs, learning the basics of parallel computation using MPI will be time well spent.

MPI

The whole point of MPI is that it is a standard, so code written for one implementation should run fine with any other. There are many implementations. On Linux platforms, the most popular are OpenMPI, LAM, and MPICH. There are various pros and cons associated with each implementation, and if installing on a powerful HPC cluster, serious consideration should be given to which will be the most beneficial. For basic development and testing, however, it really doesn’t matter which is used. I use OpenMPI on my Ubuntu laptop, which can be installed with a simple:

sudo apt-get install openmpi-bin libopenmpi-dev

That’s it! You’re ready to go… You can test your installation with a simple “Hello world” program such as:

#include <stdio.h>
#include <mpi.h>

int main (int argc,char **argv)
{
  int rank, size;
  MPI_Init (&argc, &argv);
  MPI_Comm_rank (MPI_COMM_WORLD, &rank);
  MPI_Comm_size (MPI_COMM_WORLD, &size);	
  printf( "Hello world from process %d of %d\n", rank, size );
  MPI_Finalize();
  return 0;
}

You should be able to compile this with

mpicc -o helloworld helloworld.c

and run (on 2 processors) with

mpirun -np 2 helloworld

GSL

If you are writing non-trivial MCMC codes, you are almost certainly going to need to use a sophisticated math library and associated random number generation (RNG) routines. I typically use the GSL. On Ubuntu, the GSL can be installed with:

sudo apt-get install gsl-bin libgsl0-dev

A simple program to generate some non-uniform random numbers is given below.

#include <stdio.h>   /* printf */
#include <stdlib.h>  /* exit, EXIT_SUCCESS */
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

int main(void)
{
  int i; double z; gsl_rng *r;
  r = gsl_rng_alloc(gsl_rng_mt19937);
  gsl_rng_set(r,0);
  for (i=0;i<10;i++) {
    z = gsl_ran_gaussian(r,1.0);
    printf("z(%d) = %f\n",i,z);
  }
  exit(EXIT_SUCCESS);
}

This can be compiled with a command like:

gcc gsl_ran_demo.c -o gsl_ran_demo -lgsl -lgslcblas

and run with

./gsl_ran_demo

SPRNG

When writing parallel Monte Carlo codes, it is important to be able to use independent streams of random numbers on each processor. Although it is tempting to “fudge” things by using a random number generator with a different seed on each processor, this does not guarantee independence of the streams, and an unfortunate choice of seeds could potentially lead to bad behaviour of your algorithm. The solution to this problem is to use a parallel random number generator (PRNG), designed specifically to give independent streams on different processors. Unfortunately the GSL does not have native support for such parallel random number generators, so an external library should be used. SPRNG 2.0 is a popular choice. SPRNG is designed so that it can be used with MPI, but also that it does not have to be. This is an issue, as the standard binary packages distributed with Ubuntu (libsprng2, libsprng2-dev) are compiled without MPI support. If you are going to be using SPRNG with MPI, things are simpler with MPI support, so it makes sense to download sprng2.0b.tar.gz from the SPRNG web site, and build it from source, carefully following the instructions for including MPI support. If you are not familiar with building libraries from source, you may need help from someone who is.

Once you have compiled SPRNG with MPI support, you can test it with the following code:

#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>
#define SIMPLE_SPRNG
#define USE_MPI
#include "sprng.h"

int main(int argc,char *argv[])
{
  double rn; int i,k;
  MPI_Init(&argc,&argv);
  MPI_Comm_rank(MPI_COMM_WORLD,&k);
  init_sprng(DEFAULT_RNG_TYPE,0,SPRNG_DEFAULT);
  for (i=0;i<10;i++)
  {
    rn = sprng();
    printf("Process %d, random number %d: %f\n", k, i+1, rn);
  }
  MPI_Finalize();
  exit(EXIT_SUCCESS);
}

which can be compiled with a command like:

mpicc -I/usr/local/src/sprng2.0/include -L/usr/local/src/sprng2.0/lib -o sprng_demo sprng_demo.c -lsprng -lgmp

You will need to edit paths here to match your installation. If it builds, it can be run on 2 processors with a command like:

mpirun -np 2 sprng_demo

If it doesn’t build, there are many possible reasons. Check the error messages carefully. However, if the compilation fails at the linking stage with obscure messages about not being able to find certain SPRNG MPI functions, one possibility is that the SPRNG library has not been compiled with MPI support.

The problem with SPRNG is that it only provides a uniform random number generator. Of course we would really like to be able to use the SPRNG generator in conjunction with all of the sophisticated GSL routines for non-uniform random number generation. Many years ago I wrote a small piece of code to accomplish this, gsl-sprng.h. Download this and put it in your include path for the following example:

#include <stdio.h>   /* printf */
#include <stdlib.h>  /* exit, EXIT_SUCCESS */
#include <mpi.h>
#include <gsl/gsl_rng.h>
#include "gsl-sprng.h"
#include <gsl/gsl_randist.h>

int main(int argc,char *argv[])
{
  int i,k,po; gsl_rng *r;
  MPI_Init(&argc,&argv);
  MPI_Comm_rank(MPI_COMM_WORLD,&k);
  r=gsl_rng_alloc(gsl_rng_sprng20);
  for (i=0;i<10;i++)
  {
    po = gsl_ran_poisson(r,2.0);
    printf("Process %d, random number %d: %d\n", k, i+1, po);
  }
  MPI_Finalize();
  exit(EXIT_SUCCESS);
}

A new GSL RNG, gsl_rng_sprng20, is created by including gsl-sprng.h immediately after gsl_rng.h. If you want to set a seed, do so in the usual GSL way, but make sure to set it to be the same on each processor. I have had several emails recently from people who claim that gsl-sprng.h “doesn’t work”. All I can say is that it still works for me! I suspect the problem is that people are attempting to use it with a version of SPRNG without MPI support. That won’t work… Check that the previous SPRNG example works first.

I can compile and run the above code with

mpicc -I/usr/local/src/sprng2.0/include -L/usr/local/src/sprng2.0/lib -o gsl-sprng_demo gsl-sprng_demo.c -lsprng -lgmp -lgsl -lgslcblas
mpirun -np 2 gsl-sprng_demo

Parallel Monte Carlo

Now we have parallel random number streams, we can think about how to do parallel Monte Carlo simulations. Here is a simple example:

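#include <stdio.h>   /* printf */
#include <stdlib.h>  /* exit, EXIT_SUCCESS */
#include <math.h>
#include <mpi.h>
#include <gsl/gsl_rng.h>
#include "gsl-sprng.h"

int main(int argc,char *argv[])
{
  /* ksum accumulates this process's partial sum, so it must start at zero */
  int i,k,N; double u,ksum=0.0,Nsum=0.0; gsl_rng *r;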
  MPI_Init(&argc,&argv);
  MPI_Comm_size(MPI_COMM_WORLD,&N);
  MPI_Comm_rank(MPI_COMM_WORLD,&k);
  r=gsl_rng_alloc(gsl_rng_sprng20);
  for (i=0;i<10000;i++) {
    u = gsl_rng_uniform(r);
    ksum += exp(-u*u);
  }
  MPI_Reduce(&ksum,&Nsum,1,MPI_DOUBLE,MPI_SUM,0,MPI_COMM_WORLD);
  if (k == 0) {
    printf("Monte carlo estimate is %f\n", (Nsum/10000)/N );
  }
  MPI_Finalize();
  exit(EXIT_SUCCESS);
}

which calculates a Monte Carlo estimate of the integral

\displaystyle I=\int_0^1 \exp(-u^2)du

using 10k variates on each available processor. The MPI command MPI_Reduce is used to summarise the values obtained independently in each process. I compile and run with

mpicc -I/usr/local/src/sprng2.0/include -L/usr/local/src/sprng2.0/lib -o monte-carlo monte-carlo.c -lsprng -lgmp -lgsl -lgslcblas -lm
mpirun -np 2 monte-carlo

Parallel chains MCMC

Once parallel Monte Carlo has been mastered, it is time to move on to parallel MCMC. First it makes sense to understand how to run parallel MCMC chains in an MPI environment. I will illustrate this with a simple Metropolis-Hastings algorithm to sample a standard normal using uniform proposals, as discussed in a previous post. Here an independent chain is run on each processor, and the results are written into separate files.

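#include <stdio.h>   /* FILE, fprintf, sprintf, perror */
#include <stdlib.h>  /* atoi, atof, EXIT_SUCCESS, EXIT_FAILURE */
#include <gsl/gsl_rng.h>
#include "gsl-sprng.h"
#include <gsl/gsl_randist.h>
#include <mpi.h>

int main(int argc,char *argv[])
{
  int k,i,iters; double x,can,a,alpha; gsl_rng *r;
  FILE *s; char filename[32];  /* room for ranks of 1000 and above */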
  MPI_Init(&argc,&argv);
  MPI_Comm_rank(MPI_COMM_WORLD,&k);
  if ((argc != 3)) {
    if (k == 0)
      fprintf(stderr,"Usage: %s <iters> <alpha>\n",argv[0]);
    MPI_Finalize(); return(EXIT_FAILURE);
  }
  iters=atoi(argv[1]); alpha=atof(argv[2]);
  r=gsl_rng_alloc(gsl_rng_sprng20);
  sprintf(filename,"chain-%03d.tab",k);
  s=fopen(filename,"w");
  if (s==NULL) {
    perror("Failed open");
    MPI_Finalize(); return(EXIT_FAILURE);
  }
  x = gsl_ran_flat(r,-20,20);
  fprintf(s,"Iter X\n");
  for (i=0;i<iters;i++) {
    can = x + gsl_ran_flat(r,-alpha,alpha);
    a = gsl_ran_ugaussian_pdf(can) / gsl_ran_ugaussian_pdf(x);
    if (gsl_rng_uniform(r) < a)
      x = can;
    fprintf(s,"%d %f\n",i,x);
  }
  fclose(s);
  MPI_Finalize(); return(EXIT_SUCCESS);
}

I can compile and run this with the following commands

mpicc -I/usr/local/src/sprng2.0/include -L/usr/local/src/sprng2.0/lib -o mcmc mcmc.c -lsprng -lgmp -lgsl -lgslcblas
mpirun -np 2 mcmc 100000 1

Parallelising a single MCMC chain

The parallel chains approach turns out to be surprisingly effective in practice. Obviously the disadvantage of that approach is that “burn in” has to be repeated on every processor, which limits how much efficiency gain can be achieved by running across many processors. Consequently it is often desirable to try to parallelise a single MCMC chain. As MCMC algorithms are inherently sequential, parallelisation is not completely trivial, and most (but not all) approaches to parallelising a single MCMC chain focus on the parallelisation of each iteration. In order for this to be worthwhile, it is necessary that the problem being considered is non-trivial, having a large state space. The strategy is then to divide the state space into “chunks” which can be updated in parallel. I don’t have time to go through a real example in detail in this blog post, but fortunately I wrote a book chapter on this topic almost 10 years ago which is still valid and relevant today. The citation details are:

Wilkinson, D. J. (2005) Parallel Bayesian Computation, Chapter 16 in E. J. Kontoghiorghes (ed.) Handbook of Parallel Computing and Statistics, Marcel Dekker/CRC Press, 481-512.

The book was eventually published in 2005 after a long delay. The publisher which originally commissioned the handbook (Marcel Dekker) was taken over by CRC Press before publication, and the project lay dormant for a couple of years until the new publisher picked it up again and decided to proceed with publication. I have a draft of my original submission in PDF which I recommend reading for further information. The code examples used are also available for download, including several of the examples used in this post, as well as an extended case study on parallelisation of a single chain for Bayesian inference in a stochastic volatility model. Although the chapter is nearly 10 years old, the issues discussed are all still remarkably up-to-date, and the code examples all still work. I think that is a testament to the stability of the technology adopted (C, MPI, GSL). Some of the other handbook chapters have not stood the test of time so well.

For basic information on getting started with MPI and key MPI commands for implementing parallel MCMC algorithms, the above-mentioned book chapter is a reasonable place to start. Read it all through carefully, run the examples, and carefully study the code for the parallel stochastic volatility example. Once that is understood, you should find it possible to start writing your own parallel MCMC algorithms. For further information about more sophisticated MPI usage and additional commands, I find the annotated specification, MPI – The complete reference, to be as good a source as any.
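As a very small taste of the “chunks” idea for a single chain, here is one possible way of deciding which components of a state vector each process looks after. This is a generic block decomposition and not necessarily the layout used in the book chapter; chunk_start and chunk_end are names I have made up for this sketch, with k the process rank, N the number of processes and n the number of components.

```c
/* First index of the chunk owned by process k of N, when a state vector
   with n components is split into N contiguous blocks. */
int chunk_start(int k, int N, int n)
{
  return (int) (((long long) k * n) / N);
}

/* One past the last index of the chunk owned by process k of N. */
int chunk_end(int k, int N, int n)
{
  return chunk_start(k + 1, N, n);
}
```

Process k would then loop over components chunk_start(k,N,n) up to (but not including) chunk_end(k,N,n) within each iteration. The blocks cover every component exactly once, and differ in size by at most one when n is not a multiple of N. Whether the chunks can legitimately be updated at the same time depends on the dependence structure of the model, which is exactly the kind of issue discussed in the chapter.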