Metropolis-Hastings MCMC when the proposal and target have differing support

Introduction

Very often it is desirable to use Metropolis-Hastings MCMC for a target distribution which does not have full support (for example, one corresponding to a non-negative random variable), together with a proposal distribution which does (for example, a Gaussian random walk proposal). This isn’t a problem at all, but on more than one occasion I have come across students getting it wrong, so I thought it might be useful to have a brief post on how to do it right, on what people sometimes get wrong and why, and then on how to correct the wrong method in order to make it right…

A simple example

For this post we will consider a simple Ga(2,1) target distribution, with density

\pi(x) = xe^{-x},\quad x\geq 0.

Of course this is a very simple distribution, and there are many straightforward ways to simulate it directly, but for this post we will use a random walk Metropolis-Hastings (MH) scheme with standard Gaussian innovations. So, if the current state of the chain is x, a proposed new value x^\star will be generated from

f(x^\star|x) = \phi(x^\star-x),

where \phi(\cdot) is the standard normal density. This proposed new value is accepted with probability \min\{1,A\}, where

\displaystyle A = \frac{\pi(x^\star)}{\pi(x)} \frac{f(x|x^\star)}{f(x^\star|x)} = \frac{\pi(x^\star)}{\pi(x)} \frac{\phi(x-x^\star)}{\phi(x^\star-x)} = \frac{\pi(x^\star)}{\pi(x)} ,

since the standard normal density is symmetric.
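For our Ga(2,1) target, this acceptance ratio reduces to

\displaystyle A = \frac{x^\star e^{-x^\star}}{xe^{-x}},

which is exactly the expression hard-coded in the met2 sampler below.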

Correct implementation

We can easily implement this using R as follows:

met1=function(iters)
  {
    xvec=numeric(iters)
    x=1
    for (i in 1:iters) {
      xs=x+rnorm(1)  # Gaussian random walk proposal
      A=dgamma(xs,2,1)/dgamma(x,2,1)  # dgamma correctly returns 0 for xs<0
      if (runif(1)<A)
        x=xs
      xvec[i]=x
    }
    return(xvec)
  }

We can run it, plot the results and check it against the true target with the following commands.

iters=1000000
out=met1(iters)
hist(out,100,freq=FALSE,main="met1")
curve(dgamma(x,2,1),add=TRUE,col=2,lwd=2)

If you have a slow computer, you may prefer to use iters=100000. The above code uses R’s built-in gamma density. Alternatively, we can hard-code the density as follows.

met2=function(iters)
  {
    xvec=numeric(iters)
    x=1
    for (i in 1:iters) {
      xs=x+rnorm(1)
      A=xs*exp(-xs)/(x*exp(-x))
      if (runif(1)<A)
        x=xs
      xvec[i]=x
    }
    return(xvec)
  }

We can run this code using the following commands, to verify that it does work as expected.

out=met2(iters)
hist(out,100,freq=FALSE,main="met2")
curve(dgamma(x,2,1),add=TRUE,col=2,lwd=2)

However, there is a potential problem with the above code that we have got away with in this instance, and which often catches people out. We have hard-coded the density for x>0 without checking the sign of x. Here we get away with it because a negative proposal leads to a negative acceptance ratio, which is rejected straight away. This is not always the case: consider, for example, a Ga(3,1) distribution, with density proportional to x^2e^{-x}; the hard-coded expression would then be positive for negative proposals, which could therefore be wrongly accepted. So really we should check the sign of x^\star and reject immediately if it is not within the support of the target.
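For example, a minimal sketch of the fix (met2a is just an illustrative name) adds the explicit support check to met2:

met2a=function(iters)
  {
    xvec=numeric(iters)
    x=1
    for (i in 1:iters) {
      xs=x+rnorm(1)
      if (xs>0) {  # only evaluate the acceptance ratio within the support
        A=xs*exp(-xs)/(x*exp(-x))
        if (runif(1)<A)
          x=xs
      }  # otherwise reject immediately and stay at the current value
      xvec[i]=x
    }
    return(xvec)
  }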

Although this problem often catches people out, it tends not to be a big issue in practice, as it typically leads to an obviously incorrect sampler, or a sampler which crashes, and is relatively simple to debug and fix.

An incorrect sampler

The problem I want to focus on here is more subtle, but closely related. It is clear that any x^\star<0 should be rejected. With the above code, such values are indeed rejected, and the sampler advances to the next iteration. However, in more complex samplers, where an update like this might be one tiny part of a massive sampler with a very high-dimensional state space, it seems like a bit of a "waste" of a MH move to propose a negative value, throw it away, and move on. It therefore seems tempting to keep sampling x^\star values until a positive value is obtained, and then evaluate the acceptance ratio and decide whether or not to accept. We could code up this sampler as follows.

met3=function(iters)
  {
    xvec=numeric(iters)
    x=1
    for (i in 1:iters) {
      repeat {  # rejection sample the proposal until a positive value is obtained
        xs=x+rnorm(1)
        if (xs>0)
          break
      }
      A=xs*exp(-xs)/(x*exp(-x))  # unadjusted acceptance ratio - the subtle bug
      if (runif(1)<A)
        x=xs
      xvec[i]=x
    }
    return(xvec)
  }

As reasonable as this idea may at first seem, it does not lead to a sampler having the desired target, as can be verified using the following commands.

out=met3(iters)
hist(out,100,freq=FALSE,main="met3")
curve(dgamma(x,2,1),add=TRUE,col=2,lwd=2)

So, this sampler seems to be sampling something close to the desired target, but not the same. This raises a couple of questions. First and most important, can we fix this sampler so that it does sample the correct target (yes), and second, can we figure out what target density the incorrect sampler is actually sampling (again, yes)? Let’s start with the issue of how to fix the sampler, as this will also help us to understand what the incorrect sampler is doing.

Fixing the truncated sampler

By repeatedly sampling from the proposal until we obtain a positive value, we are actually implementing a rejection sampler for sampling from the proposal distribution truncated at zero. This is a perfectly reasonable proposal distribution, so we can use it provided that we use the correct MH acceptance ratio. Now, the truncated proposal has the same functional form as the untruncated density, apart from the restricted support and a normalising constant. Indeed, this may be why people often assume this method will work: normalising constants often don’t matter in MH schemes. However, the normalising constant only doesn’t matter if it is independent of the state, and here it is not… Explicitly, we have

f(x^\star|x) \propto \phi(x^\star-x),\quad x^\star>0.

Including the normalising constant we have

\displaystyle f(x^\star|x) = \frac{\phi(x^\star-x)}{\Phi(x)},\quad x^\star>0,

where \Phi(x) is the standard normal CDF. Consequently, the correct acceptance ratio to use with this proposal is

\displaystyle A = \frac{\pi(x^\star)}{\pi(x)} \frac{\phi(x-x^\star)}{\phi(x^\star-x)}\frac{\Phi(x)}{\Phi(x^\star)} =   \frac{\pi(x^\star)}{\pi(x)}\frac{\Phi(x)}{\Phi(x^\star)},

where we see that the normalising constants do not cancel out. We can modify the previous sampler to use the correct acceptance ratio as follows.

met4=function(iters)
  {
    xvec=numeric(iters)
    x=1
    for (i in 1:iters) {
      repeat {
        xs=x+rnorm(1)
        if (xs>0)
          break
      }
      A=xs*exp(-xs)/(x*exp(-x))
      A=A*pnorm(x)/pnorm(xs)  # correct for the state-dependent normalising constants
      if (runif(1)<A)
        x=xs
      xvec[i]=x
    }
    return(xvec)
  }

We can verify that this sampler leads to the correct target with the following commands.

out=met4(iters)
hist(out,100,freq=FALSE,main="met4")
curve(dgamma(x,2,1),add=TRUE,col=2,lwd=2)

So, truncating the proposal at zero is fine, provided that you modify the acceptance ratio accordingly.

What does the incorrect sampler target?

Now that we understand why the naive truncated sampler is wrong and how to fix it, we can, out of curiosity, wonder what distribution that sampler actually targets. Since we now know what proposal we are actually using, we can re-write the acceptance ratio as

\displaystyle A = \frac{\pi(x^\star)\Phi(x^\star)}{\pi(x)\Phi(x)}\frac{\frac{\phi(x-x^\star)}{\Phi(x^\star)}}{\frac{\phi(x^\star-x)}{\Phi(x)}},

from which it is clear that the actual target of this chain is

\tilde\pi(x) \propto \pi(x)\Phi(x),

or

\tilde\pi(x)\propto xe^{-x}\Phi(x),\quad x\geq 0.

The constant of proportionality is not immediately obvious, but is tractable, and turns out to be a nice undergraduate exercise in integration by parts, leading to

\displaystyle \tilde\pi(x) = \frac{2\sqrt{2\pi}}{2+\sqrt{2\pi}}xe^{-x}\Phi(x),\quad x\geq 0.
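For those who want to check the exercise: integrating by parts with u=\Phi(x) and dv=xe^{-x}\,dx (so that v=-(1+x)e^{-x}), then using the identity e^{-x}\phi(x)=e^{1/2}\phi(x+1) together with the substitution t=x+1, gives

\displaystyle \int_0^\infty xe^{-x}\Phi(x)\,dx = \Big[-(1+x)e^{-x}\Phi(x)\Big]_0^\infty + \int_0^\infty (1+x)e^{-x}\phi(x)\,dx = \frac{1}{2} + e^{1/2}\int_1^\infty t\phi(t)\,dt = \frac{1}{2} + \frac{1}{\sqrt{2\pi}} = \frac{2+\sqrt{2\pi}}{2\sqrt{2\pi}},

and inverting this integral gives the normalising constant above.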

We can verify this using the following commands.

out=met3(iters)
hist(out,100,freq=FALSE,main="met3")
curve(dgamma(x,2,1)*pnorm(x)*2*sqrt(2*pi)/(sqrt(2*pi)+2),add=TRUE,col=3,lwd=2)

Now that we know the actual target of the incorrect sampler, we can compare it with the correct target as follows.

curve(dgamma(x,2,1),0,10,col=2,lwd=2,main="Densities")
curve(dgamma(x,2,1)*pnorm(x)*2*sqrt(2*pi)/(sqrt(2*pi)+2),add=TRUE,col=3,lwd=2)

So we see that the distributions are different, but not so different that one would immediately suspect an error on the basis of a sample of output. This makes it a difficult bug to track down.
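One way to quantify how close the two targets are is to compare moments. For example, here is a quick numerical sketch of the mean of the incorrect target (const is just a local name for the normalising constant), using R’s built-in integrate function:

const=2*sqrt(2*pi)/(sqrt(2*pi)+2)
integrate(function(x) x*const*dgamma(x,2,1)*pnorm(x),0,Inf)

This should come out at around 2.14, compared with the Ga(2,1) mean of 2, which is close enough to be easily missed when eyeballing a histogram.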

Summary

There is no problem in principle with using a proposal with full support for a target with limited support in MH algorithms. However, it is important to check whether a proposed value is within the support of the target and to reject the move if it is not. If you are concerned that such a scheme might be inefficient, it is possible to use a truncated proposal, provided that you modify the MH acceptance ratio to include the relevant normalising constants. If you don’t modify the acceptance ratio, you will get a sampler which targets the wrong distribution, but one often quite similar to the correct target, making it a difficult bug to spot and track down.

Gibbs sampling a Gaussian Markov random field (GMRF) using Java

Introduction

As I’ve explained previously, I’m gradually coming around to the idea of using Java for the development of MCMC codes, and I’m starting to build up a collection of simple examples for getting started. One of the advantages of Java is that it includes a standard cross-platform GUI library. This might not seem like the most important requirement for MCMC, but it can actually be very handy in several contexts, particularly for monitoring convergence. One obvious context is image analysis, where it can be useful to monitor image reconstructions as the sampler is running. In this post I’ll show three very small, simple Java classes which together provide an application for running a Gibbs sampler on a (non-stationary, unconditioned) Gaussian Markov random field.

The model is essentially that the distribution of each pixel is defined intrinsically, dependent only on its four nearest neighbours on a rectangular lattice, and here the distribution will be Gaussian with mean equal to the sample mean of the four neighbouring pixels and a fixed (unit) variance. On its own this isn’t especially useful, but it is a key component of many image analysis applications.
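In symbols, if \partial_{ij} denotes the set of lattice neighbours of cell (i,j) (two at a corner, three along an edge, four in the interior), then each Gibbs update samples from the full conditional

\displaystyle x_{ij}|x_{-ij} \sim N\left(\frac{1}{|\partial_{ij}|}\sum_{(k,l)\in\partial_{ij}}x_{kl},\ 1\right).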

A simple Java implementation

We will start with the class MrfApp containing the main method for the application:

MrfApp.java

import java.io.*;
class MrfApp {
    public static void main(String[] arg)
	throws IOException
    {
	Mrf mrf;
	System.out.println("started program");
	mrf=new Mrf(800,600);
	System.out.println("created mrf object");
	mrf.update(1000);
	System.out.println("done updates");
	mrf.saveImage("mrf.png");
	System.out.println("finished program");
	mrf.frame.dispose();
	System.exit(0);
    }
}

Hopefully this code is largely self-explanatory, but it relies on a class called Mrf, which contains all of the logic associated with the GMRF.

Mrf.java

import java.io.*;
import java.util.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;


class Mrf 
{
    int n,m;
    double[][] cells;
    Random rng;
    BufferedImage bi;
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;
    
    Mrf(int n_arg,int m_arg)
    {
	n=n_arg;
	m=m_arg;
	cells=new double[n][m];
	rng=new Random();
	bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
	wr=bi.getRaster();
	frame=new JFrame("MRF");
	frame.setSize(n,m);
	frame.add(new ImagePanel(bi));
	frame.setVisible(true);
    }
    
    public void saveImage(String filename)
	throws IOException
    {
	ImageIO.write(bi,"PNG",new File(filename));
    }
    
    public void updateImage()
    {
	// find the current range of cell values, then map linearly onto 0..255
	double mx=-1e+100;
	double mn=1e+100;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (cells[i][j]>mx) { mx=cells[i][j]; }
		if (cells[i][j]<mn) { mn=cells[i][j]; }
	    }
	}
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		int level=(int) (255*(cells[i][j]-mn)/(mx-mn));
		wr.setSample(i,j,0,level);
	    }
	}
	frame.repaint();
    }
    
    public void update(int num)
    {
	for (int i=0;i<num;i++) {
	    updateOnce();
	}
    }
    
    private void updateOnce()
    {
	double mean;
	// sweep over the lattice, sampling each cell from its full conditional;
	// the if-else chain below handles the edge and corner cases
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (i==0) {
		    if (j==0) {
			mean=0.5*(cells[0][1]+cells[1][0]);
		    } 
		    else if (j==m-1) {
			mean=0.5*(cells[0][j-1]+cells[1][j]);
		    } 
		    else {
			mean=(cells[0][j-1]+cells[0][j+1]+cells[1][j])/3.0;
		    }
		}
		else if (i==n-1) {
		    if (j==0) {
			mean=0.5*(cells[i][1]+cells[i-1][0]);
		    }
		    else if (j==m-1) {
			mean=0.5*(cells[i][j-1]+cells[i-1][j]);
		    }
		    else {
			mean=(cells[i][j-1]+cells[i][j+1]+cells[i-1][j])/3.0;
		    }
		}
		else if (j==0) {
		    mean=(cells[i-1][0]+cells[i+1][0]+cells[i][1])/3.0;
		}
		else if (j==m-1) {
		    mean=(cells[i-1][j]+cells[i+1][j]+cells[i][j-1])/3.0;
		}
		else {
		    mean=0.25*(cells[i][j-1]+cells[i][j+1]+cells[i+1][j]
			       +cells[i-1][j]);
		}
		cells[i][j]=mean+rng.nextGaussian();
	    }
	}
	updateImage();
    }
    
}

This class contains a few simple methods for creating and updating the GMRF, and also for maintaining and updating a graphical view of the GMRF as the sampler is running. The Gibbs sampler update itself is encoded in the final method, updateOnce, and most of the code is to deal with edge and corner cases (in the literal rather than metaphorical sense!). This is called repeatedly by the method update for the required number of iterations. At the end of each iteration, updateOnce triggers updateImage, which refreshes the image associated with the GMRF. The GMRF itself is stored in a 2-dimensional array of doubles, but an image pixel typically consists of a grayscale value represented by an unsigned byte – that is, an integer from 0 to 255. So updateImage scans through the GMRF to find the maximum and minimum values and then maps the GMRF values linearly onto the 0 to 255 scale. The image itself is set up by the constructor method, Mrf. This class relies on an additional class called ImagePanel, which is a simple GUI panel for displaying images:

ImagePanel.java

import java.awt.*;
import java.awt.image.*;
import javax.swing.*;

class ImagePanel extends JPanel {

	protected BufferedImage image;

	public ImagePanel(BufferedImage image) {
		this.image=image;
		Dimension dim=new Dimension(image.getWidth(),image.getHeight());
		setPreferredSize(dim);
		setMinimumSize(dim);
		revalidate();
		repaint();
	}

	public void paintComponent(Graphics g) {
		super.paintComponent(g); // paint the panel background before drawing
		g.drawImage(image,0,0,this);
	}

}

This completes the application, which can be compiled and run from the command line with

javac *.java
java MrfApp

This should compile the code and run the application, which will show a GMRF updating for 1000 iterations. When the 1000 iterations are complete, the application writes the final image to a file and then quits.

Using Parallel COLT

The above classes are very convenient, as they should work with any standard Java installation. However, in more complex scenarios, it is likely that a math library such as Parallel COLT will be required. In this case it will make sense to make use of features in the COLT library, such as random number generators and 2d matrix objects. We can adapt the above application by replacing the MrfApp and Mrf classes with the following versions (the ImagePanel class remains unchanged):

MrfApp.java

import java.io.*;
import cern.jet.random.tdouble.engine.*;

class MrfApp {

    public static void main(String[] arg)
	throws IOException
    {
	Mrf mrf;
	int seed=1234;
	System.out.println("started program");
        DoubleRandomEngine rngEngine=new DoubleMersenneTwister(seed);
	mrf=new Mrf(800,600,rngEngine);
	System.out.println("created mrf object");
	mrf.update(1000);
	System.out.println("done updates");
	mrf.saveImage("mrf.png");
	System.out.println("finished program");
	mrf.frame.dispose();
	System.exit(0);
    }

}

Mrf.java

import java.io.*;
import java.util.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;
import cern.jet.random.tdouble.*;
import cern.jet.random.tdouble.engine.*;
import cern.colt.matrix.tdouble.impl.*;

class Mrf 
{
    int n,m;
    DenseDoubleMatrix2D cells;
    DoubleRandomEngine rng;
    Normal rngN;
    BufferedImage bi;
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;
    
    Mrf(int n_arg,int m_arg,DoubleRandomEngine rng)
    {
	n=n_arg;
	m=m_arg;
	cells=new DenseDoubleMatrix2D(n,m);
	this.rng=rng;
	rngN=new Normal(0.0,1.0,rng);
	bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
	wr=bi.getRaster();
	frame=new JFrame("MRF");
	frame.setSize(n,m);
	frame.add(new ImagePanel(bi));
	frame.setVisible(true);
    }
    
    public void saveImage(String filename)
	throws IOException
    {
	ImageIO.write(bi,"PNG",new File(filename));
    }
    
    public void updateImage()
    {
	double mx=-1e+100;
	double mn=1e+100;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (cells.getQuick(i,j)>mx) { mx=cells.getQuick(i,j); }
		if (cells.getQuick(i,j)<mn) { mn=cells.getQuick(i,j); }
	    }
	}
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		int level=(int) (255*(cells.getQuick(i,j)-mn)/(mx-mn));
		wr.setSample(i,j,0,level);
	    }
	}
	frame.repaint();
    }
    
    public void update(int num)
    {
	for (int i=0;i<num;i++) {
	    updateOnce();
	}
    }
    
    private void updateOnce()
    {
	double mean;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (i==0) {
		    if (j==0) {
			mean=0.5*(cells.getQuick(0,1)+cells.getQuick(1,0));
		    } 
		    else if (j==m-1) {
			mean=0.5*(cells.getQuick(0,j-1)+cells.getQuick(1,j));
		    } 
		    else {
			mean=(cells.getQuick(0,j-1)+cells.getQuick(0,j+1)+cells.getQuick(1,j))/3.0;
		    }
		}
		else if (i==n-1) {
		    if (j==0) {
			mean=0.5*(cells.getQuick(i,1)+cells.getQuick(i-1,0));
		    }
		    else if (j==m-1) {
			mean=0.5*(cells.getQuick(i,j-1)+cells.getQuick(i-1,j));
		    }
		    else {
			mean=(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)+cells.getQuick(i-1,j))/3.0;
		    }
		}
		else if (j==0) {
		    mean=(cells.getQuick(i-1,0)+cells.getQuick(i+1,0)+cells.getQuick(i,1))/3.0;
		}
		else if (j==m-1) {
		    mean=(cells.getQuick(i-1,j)+cells.getQuick(i+1,j)+cells.getQuick(i,j-1))/3.0;
		}
		else {
		    mean=0.25*(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)+cells.getQuick(i+1,j)
			       +cells.getQuick(i-1,j));
		}
		cells.setQuick(i,j,mean+rngN.nextDouble());
	    }
	}
	updateImage();
    }
    
}

Again, the code should be reasonably self-explanatory, and will compile and run in the same way, provided that Parallel COLT is installed and in your classpath. This version runs approximately twice as fast as the previous version on all of the machines I’ve tried it on.
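If you want to check the relative speed on your own machine, a crude sketch (using only standard Java, and assuming the MrfApp main method above) is to wrap the update call in a timer:

long start=System.currentTimeMillis();
mrf.update(1000);
System.out.println("1000 updates took "+(System.currentTimeMillis()-start)+" ms");

This is only a rough wall-clock measure, but it is sufficient for comparing the two versions.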

Reference

I have found the following book very useful for understanding how to work with images in Java:

Hunt, K.A. (2010) The Art of Image Processing with Java, A K Peters/CRC Press.