MCMC programming in R, Python, Java and C

EDIT 16/7/11 – this post has now been largely superseded by a new post!

MCMC

Markov chain Monte Carlo (MCMC) is a powerful simulation technique for exploring high-dimensional probability distributions. It is particularly useful for exploring posterior probability distributions that arise in Bayesian statistics. Although there are some generic tools (such as WinBUGS and JAGS) for doing MCMC, for non-standard problems it is often desirable to code up MCMC algorithms from scratch. It is then natural to wonder what programming languages might be good for this purpose. There are hundreds of programming languages one could in principle use for MCMC programming, but it is necessary to use a language with a good scientific library, including good random number generation routines. I have used a large number of programming languages over the years, but these days I mostly program in R, Python, Java or C. I find each of these languages interesting and useful, with different strengths and weaknesses, and I find that no one of these languages dominates any of the others in all situations.

Example

For the purposes of this post we will focus on Gibbs sampling for a simple bivariate distribution defined on the half-plane x>0.

$f(x,y) = k x^2 \exp\{-xy^2 - y^2 + 2y - 4x\}$

The statistical motivation is not important for this post, but this is the kind of distribution which arises naturally in Bayesian inference for the mean and variance of a normal random sample. Gibbs sampling is a simple MCMC scheme which samples in turn from the full-conditional distributions. In this case, the full-conditional for x is $\mathrm{Gamma}(3,\, y^2+4)$, where $y^2+4$ is a rate parameter. The full-conditional for y follows from completing the square in y, which gives $\exp\{-(x+1)(y - 1/(x+1))^2\}$, so it is $N(1/(x+1),\, 1/(2(x+1)))$, where the second argument is the variance. The resulting Gibbs sampling algorithm is very simple and works very efficiently, but for illustrative purposes, we will consider generating 20,000 iterations with a “thin” of 500 iterations, that is, 10 million Gibbs updates in total, of which every 500th is kept.
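The full-conditionals can be checked numerically before writing any sampler. The sketch below is only a sanity check under stated assumptions: the function names, grid limits and grid sizes are my own choices, not part of the original example. It integrates the unnormalised density over a grid and compares the conditional moments with those of a Gamma with shape 3 and rate y^2+4 (mean 3/(y^2+4)) and of a normal with mean 1/(x+1) and variance 1/(2(x+1)), the variance obtained by completing the square in y.

```python
import math

def density(x, y):
    """Unnormalised target: x^2 exp(-x y^2 - y^2 + 2y - 4x), for x > 0."""
    return x * x * math.exp(-x * y * y - y * y + 2 * y - 4 * x)

def conditional_moments_y(x, lo=-10.0, hi=10.0, n=20001):
    """Mean and variance of y given x, by a Riemann sum on [lo, hi]."""
    h = (hi - lo) / (n - 1)
    ys = [lo + i * h for i in range(n)]
    w = [density(x, y) for y in ys]
    total = sum(w)
    mean = sum(y * wi for y, wi in zip(ys, w)) / total
    var = sum((y - mean) ** 2 * wi for y, wi in zip(ys, w)) / total
    return mean, var

def conditional_mean_x(y, hi=40.0, n=40001):
    """Mean of x given y, by a Riemann sum on [0, hi]."""
    h = hi / (n - 1)
    xs = [i * h for i in range(n)]
    w = [density(x, y) for x in xs]
    return sum(x * wi for x, wi in zip(xs, w)) / sum(w)

# Compare the numerical moments with the closed forms at x = 1 and y = 1.
mean_y, var_y = conditional_moments_y(1.0)
print("y | x=1: mean", mean_y, "vs", 1.0 / (1.0 + 1.0))
print("y | x=1: var ", var_y, "vs", 1.0 / (2.0 * (1.0 + 1.0)))
print("x | y=1: mean", conditional_mean_x(1.0), "vs", 3.0 / (1.0 + 4.0))
```

A grid sum is crude, but the integrands here are smooth and decay quickly, so it is more than accurate enough to catch a wrong mean or variance.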

Implementations

R

Although R has many flaws, it is well suited to programming with data, and has a huge array of statistical libraries associated with it. Like many statisticians, I probably use R more than any other language in my day-to-day work. It is the language I always use for the analysis of MCMC output. It is therefore natural to think about using R for coding up MCMC algorithms. Below is a simple function to implement a Gibbs sampler for this problem.

gibbs=function(N,thin)
{
	mat=matrix(0,ncol=3,nrow=N)
	mat[,1]=1:N
	x=0
	y=0
	for (i in 1:N) {
		for (j in 1:thin) {
			x=rgamma(1,3,y*y+4)
			# full conditional for y has variance 1/(2(x+1)); rnorm takes a standard deviation
			y=rnorm(1,1/(x+1),1/sqrt(2*x+2))
		}
		mat[i,2:3]=c(x,y)
	}
	mat=data.frame(mat)
	names(mat)=c("Iter","x","y")
	mat
}

This function stores the output in a data frame. It can be called and the data written out to a file as follows:

writegibbs=function(N=20000,thin=500)
{
	mat=gibbs(N,thin)
	write.table(mat,"data.tab",row.names=FALSE)
}
tim=system.time(writegibbs())
print(tim)

This also times how long the function takes to run. We will return to timings later. As already mentioned, I always analyse MCMC output in R. Below is some R code to compare the contours of the true distribution with the contours of the empirical distribution obtained by a 2d kernel smoother.

fun=function(x,y)
{
	x*x*exp(-x*y*y-y*y+2*y-4*x)
}

mat=read.table("data.tab",header=TRUE)
op=par(mfrow=c(2,1))
x=seq(0,4,0.1)
y=seq(-2,4,0.1)
z=outer(x,y,fun)
contour(x,y,z,main="Contours of actual distribution")
require(KernSmooth)
fit=bkde2D(as.matrix(mat[,2:3]),c(0.1,0.1))
contour(fit$x1,fit$x2,fit$fhat,main="Contours of empirical distribution")
par(op)
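As a numerical complement to the contour plots, the sample means and standard deviations of the output can be computed directly. The sketch below uses Python, one of the other languages covered in this post, and only the standard library. The function names `summarise` and `main` are hypothetical, and it assumes the whitespace-separated layout with a header row used for data.tab throughout this post (R's write.table quotes the header names, so the quotes are stripped).

```python
import math

def summarise(lines):
    """Return {column: (mean, sd)} for a whitespace-separated table with a header row."""
    rows = [line.split() for line in lines if line.strip()]
    header = [name.strip('"') for name in rows[0]]  # write.table quotes the names
    cols = list(zip(*[[float(v) for v in row] for row in rows[1:]]))
    out = {}
    for name, col in zip(header, cols):
        n = len(col)
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / (n - 1)  # sample variance
        out[name] = (mean, math.sqrt(var))
    return out

def main(path="data.tab"):
    """Print the mean and standard deviation of each column in the file."""
    with open(path) as fh:
        for name, (mean, sd) in summarise(fh).items():
            print(name, mean, sd)
```

Calling `main()` in the directory containing data.tab prints one line per column. The Iter column is summarised too, which is harmless but not informative.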

Python

Python is a great general purpose programming language. In an ideal world, I would probably use Python for all of my programming needs. These days it is also a reasonable platform for scientific computing, due to the scipy project. It is certainly possible to develop MCMC algorithms in Python. A simple Python script for this problem can be written as follows.

import random,math

def gibbs(N=20000,thin=500):
    x=0
    y=0
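    print("Iter  x  y")
    for i in range(N):
        for j in range(thin):
            # gammavariate takes (shape, scale), so the rate y*y+4 is inverted
            x=random.gammavariate(3,1.0/(y*y+4))
            # full conditional for y has variance 1/(2(x+1)); gauss takes a standard deviation
            y=random.gauss(1.0/(x+1),1.0/math.sqrt(2*x+2))
        print(i,x,y)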
    
gibbs()

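This can be run with a command like python gibbs.py > data.tab. Note that Python uses a different parametrisation of the gamma distribution to R: random.gammavariate takes a shape and a scale, whereas the third argument of R's rgamma is a rate, which is why the rate y*y+4 is inverted in the Python code.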
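The parametrisation difference is easy to get wrong, so a quick check is worthwhile. The sketch below is illustrative only: the helper name `gamma_rate`, the seed and the sample size are arbitrary choices of mine. It wraps random.gammavariate so that it accepts a rate, and compares the sample mean with the theoretical mean shape/rate.

```python
import random

def gamma_rate(rng, shape, rate):
    """Draw from Gamma(shape, rate) via random.gammavariate, which expects a scale."""
    return rng.gammavariate(shape, 1.0 / rate)

rng = random.Random(42)  # fixed seed, chosen arbitrarily, for repeatability
shape, rate = 3.0, 4.0
draws = [gamma_rate(rng, shape, rate) for _ in range(100000)]
sample_mean = sum(draws) / len(draws)

# If the scale and rate were confused, the sample mean would be near
# shape * rate = 12 rather than shape / rate = 0.75.
print("sample mean:", sample_mean, "theoretical mean:", shape / rate)
```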

C

The language I have used most for the development of MCMC algorithms is C (usually strict ANSI C). C is fast and efficient, and gives the programmer lots of control over how computations are carried out. The GSL is an excellent scientific library for C, which includes good random number generation facilities. A program for this problem is given below.

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

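int main(void)
{
  int N=20000;
  int thin=500;
  int i,j;
  gsl_rng *r = gsl_rng_alloc(gsl_rng_mt19937);
  double x=0;
  double y=0;
  printf("Iter x y\n");
  for (i=0;i<N;i++) {
    for (j=0;j<thin;j++) {
      /* gsl_ran_gamma takes (shape, scale), so the rate y*y+4 is inverted */
      x=gsl_ran_gamma(r,3.0,1.0/(y*y+4));
      /* full conditional for y has variance 1/(2(x+1)) */
      y=1.0/(x+1)+gsl_ran_gaussian(r,1.0/sqrt(2*x+2));
    }
    printf("%d %f %f\n",i,x,y);
  }
  gsl_rng_free(r);
  return 0;
}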

This can be compiled and run (on Linux) with commands like:

gcc -O2 gibbs.c -o gibbs -lgsl -lgslcblas -lm
time ./gibbs > data.tab

Java

Java is slightly nicer to code in than C, and has certain advantages over C in web application and cloud computing environments. It also has a few reasonable scientific libraries. I mainly use Parallel COLT, which (despite the name) is a general purpose scientific library, and not just for parallel applications. Using this library, code for this problem is given below.

import java.util.*;

class Gibbs
{

    public static void main(String[] arg)
    {
	int N=20000;
	int thin=500;
	cern.jet.random.tdouble.engine.DoubleRandomEngine rngEngine=
	    new cern.jet.random.tdouble.engine.DoubleMersenneTwister(new Date());
	cern.jet.random.tdouble.Normal rngN=
	    new cern.jet.random.tdouble.Normal(0.0,1.0,rngEngine);
	cern.jet.random.tdouble.Gamma rngG=
	    new cern.jet.random.tdouble.Gamma(1.0,1.0,rngEngine);
	double x=0;
	double y=0;
	System.out.println("Iter x y");
	for (int i=0;i<N;i++) {
	    for (int j=0;j<thin;j++) {
		x=rngG.nextDouble(3.0,y*y+4);
		// full conditional for y has variance 1/(2(x+1)); nextDouble takes a standard deviation
		y=rngN.nextDouble(1.0/(x+1),1.0/Math.sqrt(2*x+2));
	    }
	    System.out.println(i+" "+x+" "+y);
	}
    }

}

Assuming that Parallel COLT is in your classpath, this can be compiled and run with commands like:

javac Gibbs.java
time java Gibbs > data.tab

Timings

Complex MCMC algorithms can be really slow to run (weeks or months), so execution speed really does matter. So, how do these languages compare for this particular simple problem? R is the slowest, so let’s think about the speed relative to R (on my laptop). The Python code ran 2.4 times faster than R. To me, that is not really worth worrying too much about. Differences in coding style and speed-up tricks can make much bigger differences than this. So I wouldn’t choose between these two on the basis of speed differential alone. Python is likely to be nicer to develop in, but the convenience of doing the computation in R, the standard platform for statistical analysis, could nevertheless tip things in favour of R. C was the fastest, and this is the reason that most of my MCMC code development has been traditionally done in C. It was around 60 times faster than R. This difference is worth worrying about, and can make the difference between an MCMC code being practical to run or not. So how about Java? Many people assume that Java must be much slower than C, as it is byte-code compiled, and runs in a virtual machine. In this case it is around 40 times faster than R. This is within a factor of 2 of C. This is quite typical in my experience, and I have seen codes which run faster in Java. Again, I wouldn’t choose between C and Java purely on the basis of speed, but rather on speed of development and ease of deployment. Java can have advantages over C on both counts, which is why I’ve been experimenting more with Java for MCMC codes recently.
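For anyone wanting to get a rough feel for these numbers on their own machine without waiting for a full run, one option is to time a short run and extrapolate. The sketch below is a minimal illustration in Python, not the code used for the timings above: the names `gibbs` and `time_gibbs`, the seed and the small default run length are my own assumptions. The sampler stores the draws in a list rather than printing them, so it times the sampling only and not the output, and it uses the conditional variance 1/(2(x+1)) for y obtained by completing the square in the target density.

```python
import math
import random
import time

def gibbs(N, thin, rng):
    """Gibbs sampler for the example density; returns a list of (x, y) pairs."""
    x = 0.0
    y = 0.0
    samples = []
    for i in range(N):
        for j in range(thin):
            # gammavariate takes (shape, scale), so the rate y*y+4 is inverted
            x = rng.gammavariate(3, 1.0 / (y * y + 4))
            # standard deviation from the conditional variance 1/(2(x+1))
            y = rng.gauss(1.0 / (x + 1), 1.0 / math.sqrt(2 * x + 2))
        samples.append((x, y))
    return samples

def time_gibbs(N=200, thin=50, seed=1):
    """Time a short run; returns (elapsed seconds, samples)."""
    rng = random.Random(seed)
    start = time.perf_counter()
    samples = gibbs(N, thin, rng)
    elapsed = time.perf_counter() - start
    return elapsed, samples

elapsed, samples = time_gibbs()
per_update = elapsed / (200 * 50)
# Linear extrapolation to the 20,000 x 500 run used in this post.
print("estimated seconds for the full run:", per_update * 20000 * 500)
```

The extrapolation assumes the cost per update is constant, which is reasonable here because each update does a fixed amount of work. It will underestimate the full run slightly, since writing 20,000 lines of output is not included.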

Summary

R and Python are slow, on account of their dynamic type systems, but are quick and easy to develop in. Java and C are much faster. The speed advantages of Java and C can be important in the context of complex MCMC algorithms. One possibility that is often put forward is to prototype in a dynamic language like R or Python and then port (the slow parts) to a statically typed language later. This could be time efficient, as the debugging and re-factoring can take place in the dynamic language where it is easier, then just re-coded fairly directly into the statically typed language once the code is working well. Actually, I haven’t found this to be time efficient for the development of MCMC codes. That could be just me, but it could be a feature of MCMC algorithms in that the dynamic code runs too slowly to be amenable to quick and simple debugging.
