Tuning particle MCMC algorithms

Several recent papers discuss how to tune the number of particles used in the particle filter within a particle MCMC algorithm such as particle marginal Metropolis-Hastings (PMMH). Three such papers are:

I have discussed pseudo-marginal MCMC and particle MCMC algorithms in previous posts. It will be useful to refer back to these posts if these topics are unfamiliar. Within particle MCMC algorithms (and pseudo-marginal MCMC algorithms, more generally), an unbiased estimate of marginal likelihood is constructed using a number of particles. The more particles that are used, the better the estimate of marginal likelihood is, and the more the resulting MCMC algorithm will behave like a “real” marginal MCMC algorithm. For a small number of particles, the algorithm will still have exactly the correct target, but the noise in the unbiased estimator of marginal likelihood will lead to poor mixing of the MCMC chain. The idea is to use just enough particles to ensure that there isn’t “too much” noise in the unbiased estimator, but not to waste lots of time producing a super-accurate estimate of marginal likelihood if that isn’t necessary to ensure good mixing of the MCMC chain.

The papers above try to give theoretical justifications for certain “rules of thumb” that are commonly used in practice. One widely adopted scheme is to tune the number of particles so that the variance of the log of the estimate of marginal likelihood is around one. The obvious questions are “where?” and “why?”, and these questions turn out to be connected. As we will see, there isn’t really a good answer to the “where?” question, but what people usually do is use a pilot run to get an estimate of the posterior mean, or mode, or MLE, and then pick one and tune the noise variance at that particular parameter value. As to “why?”, well, the papers above make various (slightly different) assumptions, all of which lead to trading off mixing against computation time to obtain an “optimal” number of particles. They don’t all agree that the variance of the noise should be exactly 1, but they all agree to an order of magnitude.
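To make the rule of thumb concrete, here is a minimal Scala sketch of tuning at a single fixed pilot parameter value. Everything in it is an assumption for illustration: the names TuneParticles, sampleVar and tune are hypothetical, and toyEst is a toy stand-in for a real particle filter, with Gaussian noise whose variance is simply 100/n for n particles. The search doubles the number of particles until the sample variance of repeated log-likelihood estimates is at most the target.

```scala
import scala.annotation.tailrec
import scala.util.Random

object TuneParticles {

  // Unbiased sample variance of a collection of values.
  def sampleVar(xs: Seq[Double]): Double = {
    val m = xs.sum / xs.length
    xs.map(x => (x - m) * (x - m)).sum / (xs.length - 1)
  }

  // Double the number of particles n until the sample variance of `reps`
  // repeated log-likelihood estimates is at most `target`.
  // logLikEst(n) must return one fresh noisy estimate based on n particles,
  // always evaluated at the same (pilot) parameter value.
  @tailrec
  def tune(logLikEst: Int => Double, n: Int,
           reps: Int = 200, target: Double = 1.0): Int = {
    val v = sampleVar(Vector.fill(reps)(logLikEst(n)))
    if (v <= target) n else tune(logLikEst, 2 * n, reps, target)
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(42)
    // Toy stand-in for a particle filter (an assumption, not a real model):
    // the noise variance of the log estimate is 100/n.
    val toyEst: Int => Double =
      n => -50.0 + math.sqrt(100.0 / n) * rng.nextGaussian()
    println(tune(toyEst, 10))
  }

}
```

For this toy estimator the true noise variance 100/n first drops below one at n = 160 on the doubling sequence starting from 10, so the search should usually stop there. With a real particle filter each call is expensive, so the number of repetitions is itself a trade-off, and the answer only applies to the parameter value at which the estimates were computed.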

All of the above papers make the assumption that the noise distribution associated with the marginal likelihood estimate is independent of the parameter at which it is being evaluated, which explains why there isn’t a really good answer to the “where?” question – under the assumption it doesn’t matter what parameter value is used for tuning – they are all the same! Easy. Except that’s quite a big assumption, so it would be nice to know that it is reasonable, and unfortunately it isn’t. Let’s look at an example to see what goes wrong.

Example

In Chapter 10 of my book I look in detail at constructing a PMMH algorithm for inferring the parameters of a discretely observed stochastic Lotka-Volterra model. I’ve stepped through the computational details in a previous post which you should refer back to for the necessary background. Following that post, we can construct a particle filter to return an unbiased estimate of marginal likelihood using the following R code (which relies on the smfsb CRAN package):

library(smfsb)
# data
data(LVdata)
data=as.timedData(LVnoise10)
noiseSD=10
# measurement error model
dataLik <- function(x,t,y,log=TRUE,...)
{
    ll=sum(dnorm(y,x,noiseSD,log=TRUE))
    if (log)
        return(ll)
    else
        return(exp(ll))
}
# now define a sampler for the prior on the initial state
simx0 <- function(N,t0,...)
{
    mat=cbind(rpois(N,50),rpois(N,100))
    colnames(mat)=c("x1","x2")
    mat
}
# construct particle filter
mLLik=pfMLLik(150,simx0,0,stepLVc,dataLik,data)

Again, see the relevant previous post for details. So now mLLik() is a function that will return the log of an unbiased estimate of marginal likelihood (based on 150 particles) given a parameter value at which to evaluate.

What we are currently wondering is whether the noise in the estimate is independent of the parameter at which it is evaluated. We can investigate this for this filter easily by looking at how the estimate varies as the first parameter (prey birth rate) varies. The following code computes a log likelihood estimate across a range of values and plots the result.

mLLik1=function(x){mLLik(th=c(th1=x,th2=0.005,th3=0.6))}
x=seq(0.7,1.3,length.out=5001)
y=sapply(x,mLLik1)
plot(x[y>-1e10],y[y>-1e10])

The resulting plot is as follows:

[Figure: Log marginal likelihood estimate plotted against th1, the prey birth rate]

So, looking at the plot, it is very clear that the noise variance certainly isn’t constant as the parameter varies – it varies substantially. Furthermore, the way in which it varies is “dangerous”, in that the noise is smallest in the vicinity of the MLE. So, if a parameter close to the MLE is chosen for tuning the number of particles, this will ensure that the noise is small close to the MLE, but not elsewhere in parameter space. This could have bad consequences for the mixing of the MCMC algorithm as it explores the tails of the posterior distribution.

So with the above in mind, how should one tune the number of particles in a pMCMC algorithm? I can’t give a general answer, but I can explain what I do. We can’t rely on theory, so a pragmatic approach is required. The above rule of thumb usually gives a good starting point for exploration. Then I just directly optimise ESS per CPU second of the pMCMC algorithm from pilot runs for varying numbers of particles (and other tuning parameters in the algorithm). ESS is “effective sample size”, which can be estimated using the effectiveSize() function in the coda CRAN package. Ugly and brutish, but it works…
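As a rough illustration of the “ESS per second” criterion, here is a minimal Scala sketch. It is not what coda does: effectiveSize() uses a spectral estimate, whereas the hypothetical ess function below uses a crude truncated sum of sample autocorrelations. The timing uses wall-clock time rather than CPU time, and the AR(1) chains are toy stand-ins for real pMCMC output. All names (EssSketch, autocorr, ess, essPerSecond, ar1) are assumptions for illustration.

```scala
import scala.util.Random

object EssSketch {

  // Lag-k sample autocorrelation of a chain.
  def autocorr(xs: Vector[Double], k: Int): Double = {
    val n = xs.length
    val m = xs.sum / n
    val c0 = xs.map(x => (x - m) * (x - m)).sum
    val ck = (0 until n - k).map(i => (xs(i) - m) * (xs(i + k) - m)).sum
    ck / c0
  }

  // Crude ESS: n / (1 + 2 * sum of autocorrelations), truncating the sum
  // at the first non-positive sample autocorrelation.
  // A simple stand-in, NOT the estimator used by coda's effectiveSize().
  def ess(xs: Vector[Double]): Double = {
    val rhoSum = Iterator.from(1)
      .takeWhile(_ < xs.length)
      .map(k => autocorr(xs, k))
      .takeWhile(_ > 0.0)
      .sum
    xs.length / (1.0 + 2.0 * rhoSum)
  }

  // Run a sampler, time it (wall-clock), and return ESS per second.
  def essPerSecond(run: () => Vector[Double]): Double = {
    val t0 = System.nanoTime()
    val chain = run()
    val secs = (System.nanoTime() - t0) / 1e9
    ess(chain) / secs
  }

  // Toy AR(1) chain standing in for the output of an MCMC run.
  def ar1(phi: Double, n: Int, rng: Random): Vector[Double] =
    Iterator.iterate(0.0)(x => phi * x + rng.nextGaussian()).take(n).toVector

  def main(args: Array[String]): Unit = {
    val rng = new Random(1)
    val mild = essPerSecond(() => ar1(0.5, 20000, rng))
    val sticky = essPerSecond(() => ar1(0.95, 20000, rng))
    println(s"ESS/sec: phi=0.5 -> $mild, phi=0.95 -> $sticky")
  }

}
```

In practice the function passed to essPerSecond would be a pilot run of the pMCMC algorithm for one candidate number of particles, and the candidate with the largest ESS per second wins. Timings of very short runs are noisy, so pilot runs need to be long enough for the comparison to be meaningful.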


A functional Gibbs sampler in Scala

For many years I’ve had a passing interest in functional programming and languages which support functional programming approaches. I’m also quite interested in MOOCs and their future role in higher education. So I recently signed up for my first on-line course, Functional Programming Principles in Scala, via Coursera. I’m around half way through the course at the time of writing, and I’m enjoying it very much. I knew that I didn’t know much about Scala before starting the course, but during the course I’ve also learned that I didn’t know as much about functional programming as I thought I did, either! 😉 The course itself is very interesting, the assignments are well designed and appropriately challenging, and the web infrastructure to support the course is working well. I suspect I’ll try other on-line courses in the future.

Functional programming emphasises immutability, and discourages imperative programming approaches that use variables that can be modified during run-time. There are many advantages to immutability, especially in the context of parallel and concurrent programming, which is becoming increasingly important as multi-core systems become the norm. I’ve always found functional programming to be intellectually appealing, but have often worried about the practicalities of using functional programming in the context of scientific computing where many algorithms are iterative in nature, and are typically encoded using imperative approaches. The Scala programming language is appealing to me as it supports both imperative and functional styles of programming, as well as object oriented approaches. However, as a result of taking this course I am now determined to pursue functional approaches further, and get more of a feel for how practical they are for scientific computing applications.

For my first experiment, I’m going back to my post describing a Gibbs sampler in various languages. See that post for further details of the algorithm. In that post I did have an example implementation in Scala, which looked like this:

object GibbsSc {
 
    import cern.jet.random.tdouble.engine.DoubleMersenneTwister
    import cern.jet.random.tdouble.Normal
    import cern.jet.random.tdouble.Gamma
    import Math.sqrt
    import java.util.Date
 
    def main(args: Array[String]): Unit = {
        val N=50000
        val thin=1000
        val rngEngine=new DoubleMersenneTwister(new Date)
        val rngN=new Normal(0.0,1.0,rngEngine)
        val rngG=new Gamma(1.0,1.0,rngEngine)
        var x=0.0
        var y=0.0
        println("Iter x y")
        for (i <- 0 until N) {
            for (j <- 0 until thin) {
                x=rngG.nextDouble(3.0,y*y+4)
                y=rngN.nextDouble(1.0/(x+1),1.0/sqrt(2*x+2))
            }
            println(i+" "+x+" "+y)
        }
    }
 
}

At the time I wrote that post I knew even less about Scala than I do now, so I created the code by starting from the Java version and removing all of the annoying clutter! 😉 Clearly this code has an imperative style, utilising variables (declared with var) x and y having mutable state that is updated by a nested for loop. This algorithm is typical of the kind I use every day, so if I can’t re-write this in a more functional style, removing all mutable variables from my code, then I’m not going to get very far with functional programming!

In fact it is very easy to re-write this in a more functional style without utilising mutable variables. One possible approach is presented below.

object FunGibbs {
 
    import cern.jet.random.tdouble.engine.DoubleMersenneTwister
    import cern.jet.random.tdouble.Normal
    import cern.jet.random.tdouble.Gamma
    import java.util.Date
    import scala.math.sqrt

    val rngEngine=new DoubleMersenneTwister(new Date)
    val rngN=new Normal(0.0,1.0,rngEngine)
    val rngG=new Gamma(1.0,1.0,rngEngine)

    class State(val x: Double,val y: Double)

    def nextIter(s: State): State = {
         val newX=rngG.nextDouble(3.0,(s.y)*(s.y)+4.0)
         new State(newX, 
              rngN.nextDouble(1.0/(newX+1),1.0/sqrt(2*newX+2)))
    }

    def nextThinnedIter(s: State,left: Int): State = {
       if (left==0) s 
       else nextThinnedIter(nextIter(s),left-1)
    }

    def genIters(s: State,current: Int,stop: Int,thin: Int): State = {
         if (current<=stop) {
             println(current+" "+s.x+" "+s.y)
             genIters(nextThinnedIter(s,thin),current+1,stop,thin)
         }
         else s
    }

    def main(args: Array[String]): Unit = {
        println("Iter x y")
        genIters(new State(0.0,0.0),1,50000,1000)
    }

}

Although it is a few lines longer, it is a fairly clean implementation, and doesn’t look like a hack. Like many functional programs, this one makes extensive use of recursion. This is one of the things that has always concerned me about functional programming: many scientific computing applications involve lots of iteration, and that can potentially translate into very deep recursion. The above program makes around 50 million recursive calls in total, and the outer recursion in genIters alone has an apparent depth of 50,000. However, it runs fine without crashing, despite the fact that most programming languages will crash out with a stack overflow with recursion depths of more than a couple of thousand. So why doesn’t this crash? It runs fine because the recursion I used is a special form of recursion known as a tail call. Most functional (and some imperative) programming languages automatically perform tail call elimination, which essentially turns the tail call into an iteration which runs very fast without creating new stack frames. The Scala compiler does this for directly self-recursive tail calls like the ones used here. In fact, this functional version of the code runs at roughly the same speed as the iterative version I presented first (perhaps just a few percent slower, though I haven’t done careful timings), and runs well within a factor of 2 of imperative C code. So actually this seems perfectly practical so far, and I’m looking forward to experimenting more with functional programming approaches to statistical computation over the coming months…
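To see the difference a tail call makes, here is a small self-contained sketch, separate from the Gibbs sampler above. The names TailCallSketch, sumTo and sumToNaive are just for illustration. The @tailrec annotation from the Scala standard library asks the compiler to check that a method really is tail-recursive, and compilation fails if it is not.

```scala
import scala.annotation.tailrec

object TailCallSketch {

  // Tail-recursive: the recursive call is the last thing evaluated, so the
  // compiler can turn it into a loop. @tailrec makes compilation fail if
  // the call is not in tail position.
  @tailrec
  def sumTo(n: Long, acc: Long = 0L): Long =
    if (n == 0L) acc else sumTo(n - 1L, acc + n)

  // Not tail-recursive: the addition happens after the recursive call
  // returns, so every call needs its own stack frame.
  def sumToNaive(n: Long): Long =
    if (n == 0L) 0L else n + sumToNaive(n - 1L)

  def main(args: Array[String]): Unit = {
    // 50 million recursive calls, but no growth in the stack
    println(sumTo(50000000L)) // prints 1250000025000000
  }

}
```

Calling sumToNaive with a similarly large argument would be expected to fail with a StackOverflowError on a default JVM stack, since nothing is eliminated there. Adding @tailrec to nextThinnedIter and genIters in the sampler above would be one way to have the compiler confirm that they are optimised.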