## Part 5: the Metropolis-adjusted Langevin algorithm (MALA)

### Introduction

This is the fifth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to use Langevin dynamics to construct an approximate MCMC scheme using the gradient of the log target distribution. Each step of the algorithm involved simulating from the Euler-Maruyama approximation to the transition kernel of the process, based on some pre-specified step size, $\Delta t$. We can improve the accuracy of this approximation by making the step size smaller, but this will come at the expense of a more slowly mixing MCMC chain.
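For reference, one Euler-Maruyama step of the Langevin diffusion can be written as below. Here $\pi$ is the target density, and we assume a diagonal pre-conditioning matrix $M$, which corresponds to the `pre` argument of the kernels in this post:

$$
\theta^\star = \theta + \frac{1}{2} M \nabla \log \pi(\theta)\,\Delta t + \sqrt{\Delta t}\, M^{1/2} Z, \qquad Z \sim N(0, I).
$$

Each step moves the chain a distance of order $\sqrt{\Delta t}$. This is why shrinking $\Delta t$ to reduce the discretisation error means that more iterations are needed to explore the posterior.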

Fortunately, there is an easy way to make the algorithm "exact" (in the sense that the equilibrium distribution of the Markov chain will be the exact target distribution) for any finite step size, $\Delta t$: simply use the Euler-Maruyama approximation as the proposal distribution in a Metropolis-Hastings algorithm. This is the Metropolis-adjusted Langevin algorithm (MALA). There are various ways this could be coded up, but here, for clarity, a higher-order function (HoF) for generating a MALA kernel will be used, and this function will in turn call on a HoF for generating a Metropolis-Hastings (M-H) kernel.
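Written out, the MALA proposal density and the M-H acceptance probability are as follows. Here $M$ is a diagonal pre-conditioner (the `pre` argument in the code):

$$
q(\theta^\star \mid \theta) = N\!\left(\theta^\star;\ \theta + \tfrac{1}{2} M \nabla \log \pi(\theta)\,\Delta t,\ M\,\Delta t\right),
$$

$$
\alpha(\theta, \theta^\star) = \min\left\{1,\ \frac{\pi(\theta^\star)\, q(\theta \mid \theta^\star)}{\pi(\theta)\, q(\theta^\star \mid \theta)}\right\}.
$$

The drift term depends on the current state, so $q$ is not symmetric and the ratio $q(\theta \mid \theta^\star)/q(\theta^\star \mid \theta)$ does not cancel. On the log scale, this ratio is exactly the quantity `a` computed in each `mhKernel` implementation, with `dprop(new, old)` playing the role of $\log q(\text{new} \mid \text{old})$.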

## Implementations

### R

First we need a function to generate an M-H kernel.

```r
mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)          # propose a new state
        llprop = logPost(prop)   # log posterior at the proposal
        # log M-H ratio; dprop(new, old) is the log proposal density
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)  # accept
        else
            list(x=x, ll=ll)         # reject
    }
```

Then we can easily write a function for returning a MALA kernel that makes use of this M-H function.

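```r
malaKernel = function(lpi, glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    # mean of the Euler-Maruyama (Langevin) proposal
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    mhKernel(lpi, function(x) rnorm(length(x), advance(x), spre*sdt),
             function(new, old) sum(dnorm(new, advance(old), spre*sdt, log=TRUE)))
}
```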

Notice that our MALA function requires as input both the gradient of the log posterior (for the proposal) and the log posterior itself (for the M-H correction). Other details are as we have already seen; see the full runnable script.

### Python

Again, we need an M-H kernel

```python
import numpy as np

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        # log M-H ratio; dprop(new, old) is the log proposal density
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if np.log(np.random.rand()) < a:
            x = prop
            ll = lp
        return x, ll
    return kernel
```
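As a quick check of `mhKernel` on its own, here is a minimal sketch that uses it for plain random-walk Metropolis. It assumes an illustrative one-dimensional standard normal target and a unit-variance Gaussian random-walk proposal, neither of which is part of the logistic regression example. The kernel is repeated so that the block runs on its own.

```python
import numpy as np

def mhKernel(lpost, rprop, dprop=lambda new, old: 1.):
    # Same M-H kernel as above: maps (x, ll) to the next (x, ll)
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if np.log(np.random.rand()) < a:
            x, ll = prop, lp
        return x, ll
    return kernel

# Illustrative target: N(0, 1), log density up to an additive constant
lpost = lambda x: -0.5 * x**2
# Symmetric random-walk proposal, so the default dprop is adequate
rw = mhKernel(lpost, lambda x: x + np.random.randn())

np.random.seed(42)
x = 0.0
ll = lpost(x)
samples = np.empty(20000)
for i in range(samples.size):
    x, ll = rw(x, ll)
    samples[i] = x

print("sample mean and variance:", samples.mean(), samples.var())
```

The random-walk proposal is symmetric, so the default `dprop` is adequate: whatever constant it returns, the terms `dprop(x, prop)` and `dprop(prop, x)` cancel in `a`.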

and then a MALA kernel

```python
import numpy as np
import scipy as sp
import scipy.stats  # makes sp.stats available

def malaKernel(lpi, glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    # mean of the Euler-Maruyama (Langevin) proposal
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    return mhKernel(lpi, lambda x: advance(x) + np.random.randn(len(x))*spre*sdt,
        lambda new, old: np.sum(sp.stats.norm.logpdf(new, loc=advance(old), scale=spre*sdt)))
```

See the full runnable script for further details.
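To see the MALA kernel in action, here is a minimal usage sketch. It assumes an illustrative five-dimensional standard normal target with a hand-coded gradient, not the logistic regression posterior. Both kernel functions are repeated so that the block is self-contained. The step size `dt = 0.5` is an arbitrary choice for this toy target, not a recommendation.

```python
import numpy as np
import scipy as sp
import scipy.stats

def mhKernel(lpost, rprop, dprop=lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if np.log(np.random.rand()) < a:
            x, ll = prop, lp
        return x, ll
    return kernel

def malaKernel(lpi, glpi, dt=1e-4, pre=1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    return mhKernel(lpi,
        lambda x: advance(x) + np.random.randn(len(x))*spre*sdt,
        lambda new, old: np.sum(sp.stats.norm.logpdf(new, loc=advance(old), scale=spre*sdt)))

# Illustrative target: 5-d standard normal, log density up to a constant
lpi = lambda x: -0.5*np.sum(x**2)
glpi = lambda x: -x   # hand-coded gradient of lpi

np.random.seed(1)
kern = malaKernel(lpi, glpi, dt=0.5)
x = np.zeros(5)
ll = lpi(x)
n = 5000
samples = np.empty((n, 5))
accepted = 0
for i in range(n):
    x_new, ll = kern(x, ll)
    # a rejected step returns the current state unchanged
    accepted += not np.array_equal(x_new, x)
    x = x_new
    samples[i] = x

print("acceptance rate:", accepted / n)
print("sample mean and variance:", samples.mean(), samples.var())
```

Counting changed states is a simple way to estimate the acceptance rate here, because a rejected step returns the current state unchanged.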

#### JAX

If we want our algorithm to run fast, and if we want to exploit automatic differentiation to avoid the need to manually compute gradients, then we can easily convert the above code to use JAX.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        # branch-free accept/reject, so the kernel can be jit-compiled
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

def malaKernel(lpi, dt = 1e-4, pre = 1):
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    glpi = jit(jax.grad(lpi))  # gradient via automatic differentiation
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    return mhKernel(lpi, jit(lambda k, x: advance(x) +
                             jax.random.normal(k, x.shape)*spre*sdt),
                    jit(lambda new, old:
                        jnp.sum(jsp.stats.norm.logpdf(new,
                            loc=advance(old), scale=spre*sdt))))
```

See the full runnable script for further details.

### Scala

```scala
import breeze.linalg.*
import breeze.numerics.*
import breeze.stats.distributions.*
import breeze.stats.distributions.Rand.VariableSeed.randBasis

type DVD = DenseVector[Double]

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
): ((S, Double)) => (S, Double) =
  val r = Uniform(0.0,1.0)
  state =>
    val (x0, ll0) = state
    val x = rprop(x0)
    val ll = logPost(x)
    val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
    if (math.log(r.draw()) < a)
      (x, ll)
    else
      (x0, ll0)

def malaKernel(lpi: DVD => Double, glpi: DVD => DVD, pre: DVD, dt: Double = 1e-4) =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  val p = pre.length
  // mean of the Euler-Maruyama (Langevin) proposal
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glpi(beta))
  def rprop(beta: DVD): DVD =
    advance(beta) + spre.map(s => Gaussian(0.0, s*sdt).draw())
  def dprop(n: DVD, o: DVD): Double =
    val ao = advance(o)
    (0 until p).map(i => Gaussian(ao(i), spre(i)*sdt).logPdf(n(i))).sum
  mhKernel(lpi, rprop, dprop)
```

See the full runnable script for further details.

### Haskell

```haskell
import Control.Monad (replicateM)
import Numeric.LinearAlgebra
import Statistics.Distribution
import Statistics.Distribution.Normal (normalDistr)
import Statistics.Distribution.Uniform (uniformDistr)
import System.Random.Stateful (StatefulGen)

mhKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) ->
  (s -> s -> Double) -> g -> (s, Double) -> m (s, Double)
mhKernel logPost rprop dprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

malaKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) ->
  Vector Double -> Double -> g ->
  (Vector Double, Double) -> m (Vector Double, Double)
malaKernel lpi glpi pre dt g = let
  sdt = sqrt dt
  spre = cmap sqrt pre
  p = size pre
  -- mean of the Euler-Maruyama (Langevin) proposal
  advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  rprop beta g = do
    zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
    let z = fromList zl
    return $ advance(beta) + (scalar sdt)*spre*z
  dprop n o = let
    ao = advance o
    in sum $ (\i -> logDensity (normalDistr (ao!i)
                      ((spre!i)*sdt)) (n!i)) <$> [0..(p-1)]
  in mhKernel lpi rprop dprop g
```

See the full runnable script for further details.

### Dex

Recall that Dex supports automatic differentiation, so we don’t need to provide gradients.

```
def mhKernel {s} (lpost: s -> Float) (rprop: s -> Key -> s) (dprop: s -> s -> Float)
    (sll: (s & Float)) (k: Key) : (s & Float) =
  (x0, ll0) = sll
  [k1, k2] = split_key k
  x = rprop x0 k1
  ll = lpost x
  a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u = rand k2
  select (log u < a) (x, ll) (x0, ll0)

def malaKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float) :
    ((Fin n)=>Float & Float) -> Key -> ((Fin n)=>Float & Float) =
  sdt = sqrt dt
  spre = sqrt pre
  v = dt .* pre
  vinv = map (\ x. 1.0/x) v
  glp = grad lpi  -- gradient by automatic differentiation
  def advance (beta: (Fin n)=>Float) : (Fin n)=>Float =
    beta + (0.5*dt) .* (pre*(glp beta))
  def rprop (beta: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
    (advance beta) + sdt .* (spre*(randn_vec k))
  def dprop (new: (Fin n)=>Float) (old: (Fin n)=>Float) : Float =
    ao = advance old
    diff = new - ao
    -- Gaussian log density, up to an additive constant that cancels in the M-H ratio
    -0.5 * sum ((log v) + diff*diff*vinv)
  mhKernel lpi rprop dprop
```

See the full runnable script for further details.

## Next steps

MALA gives us an MCMC algorithm that exploits gradient information to generate "informed" M-H proposals. But it still has a rather "diffusive" character, making it difficult to tune in such a way that large moves are likely to be accepted in challenging high-dimensional situations.
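The step-size trade-off behind this can be seen in a small experiment. This is a sketch under stated assumptions. The target is an illustrative 100-dimensional standard normal, and the pre-conditioner is the identity. The hypothetical helper `mala_acceptance` runs MALA directly (rather than through `mhKernel`) and returns the empirical acceptance rate together with the mean squared jump distance per iteration. The numbers depend on the seed and say nothing specific about the logistic regression posterior.

```python
import numpy as np

def mala_acceptance(dt, d=100, n=2000, seed=0):
    """Hypothetical helper: run MALA on an N(0, I_d) target.

    Returns (acceptance rate, mean squared jump distance per iteration).
    """
    rng = np.random.default_rng(seed)
    lpi = lambda x: -0.5*np.sum(x**2)   # log target, up to a constant
    glpi = lambda x: -x                 # its gradient
    advance = lambda x: x + 0.5*glpi(x)*dt
    # log q(new | old), dropping the normalising constant (it cancels in the M-H ratio)
    lq = lambda new, old: -0.5*np.sum((new - advance(old))**2)/dt
    x = rng.standard_normal(d)          # start the chain in equilibrium
    ll = lpi(x)
    accepted, sq_jump = 0, 0.0
    for _ in range(n):
        prop = advance(x) + np.sqrt(dt)*rng.standard_normal(d)
        lp = lpi(prop)
        a = lp - ll + lq(x, prop) - lq(prop, x)
        if np.log(rng.uniform()) < a:
            sq_jump += np.sum((prop - x)**2)
            x, ll = prop, lp
            accepted += 1
    return accepted/n, sq_jump/n

for dt in [0.01, 0.1, 0.5, 1.0, 2.0]:
    acc, msjd = mala_acceptance(dt)
    print(f"dt={dt}: acceptance {acc:.2f}, mean squared jump {msjd:.2f}")
```

Very small `dt` values give near-certain acceptance but tiny moves. Large ones propose big moves that are mostly rejected. The mean squared jump distance is one crude way to compare the two effects. For a fixed `dt`, the acceptance rate also falls as the dimension `d` grows, which is part of what makes MALA awkward to tune in high dimensions.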

The Langevin dynamics on which MALA is based can be interpreted as the (over-damped) stochastic dynamics of a particle moving in a potential energy field corresponding to minus the log posterior. It turns out that the corresponding deterministic dynamics can be exploited to generate proposals better able to make large moves while still having a high probability of acceptance. This is the idea behind Hamiltonian Monte Carlo (HMC), which we’ll look at next.