Bayesian inference for a logistic regression model (Part 6)

Part 6: Hamiltonian Monte Carlo (HMC)

Introduction

This is the sixth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to construct an MCMC algorithm utilising gradient information by considering a Langevin equation having our target distribution of interest as its equilibrium. This equation has a physical interpretation in terms of the stochastic dynamics of a particle in a potential equal to minus the log of the target density. It turns out that thinking about the deterministic dynamics of a particle in such a potential can lead to more efficient MCMC algorithms.

Hamiltonian dynamics

Hamiltonian dynamics is often presented as an extension of a fairly general version of Lagrangian dynamics. However, for our purposes a rather simple version is quite sufficient, based on basic concepts from Newtonian dynamics, familiar from school. Inspired by our Langevin example, we will consider the dynamics of a particle in a potential function V(q). We will see later why we want V(q) = -\log \pi(q) for our target of interest, \pi(\cdot). In the context of Hamiltonian (and Lagrangian) dynamics we typically use q as our position variable, rather than x.

The potential function induces a (conservative) force on the particle equal to -\nabla V(q) when the particle is at position q. Then Newton’s second law of motion, "F=ma", takes the form

\displaystyle \nabla V(q) + m \ddot{q} = 0.

In Newtonian mechanics, we often consider the position vector q as 3-dimensional. Here it will be n-dimensional, where n is the number of variables in our target. We can then think of our second law as governing a single n-dimensional particle of mass m, or n one-dimensional particles all of mass m. But in this latter case, there is no need to assume that all particles have the same mass, and we could instead write our law of motion as

\displaystyle \nabla V(q) + M \ddot{q} = 0,

where M is a diagonal matrix. But in fact, since we could change coordinates, there’s no reason to require that M is diagonal. All we need is that M is positive definite, so that we don’t have negative mass in any coordinate direction.

We will take the above equation as the fundamental law governing our dynamical system of interest. The motivation from Newtonian dynamics is interesting, but not required. What is important is that the dynamics of such a system are conservative, in a way that we will shortly make precise.

Our law of motion is a second-order differential equation, since it involves the second derivative of q wrt time. If you’ve ever studied differential equations, you’ll know that there is an easy way to turn a second order equation into a first order equation with twice the dimension by augmenting the system with the velocities. Here, it is more convenient to augment the system with "momentum" variables, p, which we define as p = M\dot{q}. Then we can write our second order system as a pair of first order equations

\displaystyle \dot{q} = M^{-1}p

\displaystyle \dot{p} = -\nabla V(q)

These are, in fact, Hamilton’s equations for this system, though this isn’t how they are typically written.

If we define the kinetic energy as

\displaystyle T(p) = \frac{1}{2}p^\text{T}M^{-1}p,

then the Hamiltonian

\displaystyle H(q,p) = V(q) + T(p),

representing the total energy in the system, is conserved, since

\displaystyle \dot{H} = \nabla V\cdot \dot{q} + \dot{p}^\text{T}M^{-1}p = \nabla V\cdot \dot{q} + \dot{p}^\text{T}\dot{q} = [\nabla V + \dot{p}]\cdot\dot{q} = 0.

So, if we obey our Hamiltonian dynamics, our trajectory in (q,p)-space will follow contours of the Hamiltonian. It’s also clear that the system is time-reversible, so flipping the sign of the momentum p and integrating will exactly reverse the direction in which the contours are traversed. Another quite important property of Hamiltonian dynamics is that they are volume preserving. This can be verified by checking that the divergence of the flow is zero.

\displaystyle \nabla\cdot(\dot{q},\dot{p}) = \nabla_q\cdot\dot{q} + \nabla_p\cdot\dot{p} = 0,

since \dot{q} is a function of p only and \dot{p} is a function of q only.
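
As a quick sanity check of these three properties, here is a minimal Python sketch for the one-dimensional case V(q) = q^2/2 with unit mass, where Hamilton’s equations can be solved exactly (the flow is a rotation in the (q,p)-plane). This is an illustration under those assumptions rather than code from the accompanying repo, and the names hamiltonian and flow are made up for the example.

```python
import math

def hamiltonian(q, p):
    """Total energy H(q, p) = V(q) + T(p) for V(q) = q^2/2 and unit mass."""
    return 0.5 * q * q + 0.5 * p * p

def flow(q, p, t):
    """Exact solution of dq/dt = p, dp/dt = -q after time t (a rotation)."""
    c, s = math.cos(t), math.sin(t)
    return (q * c + p * s, -q * s + p * c)

q0, p0, t = 1.0, 0.5, 2.3
q1, p1 = flow(q0, p0, t)

# Energy conservation: H is unchanged along the trajectory.
energy_drift = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# Time reversibility: flip the momentum, integrate for the same time,
# then flip the momentum back to recover the starting point.
qr, pr = flow(q1, -p1, t)
reverse_error = max(abs(qr - q0), abs(-pr - p0))

# Volume preservation: the Jacobian of the flow map is [[c, s], [-s, c]],
# whose determinant is cos^2(t) + sin^2(t) = 1.
jacobian_det = math.cos(t) ** 2 + math.sin(t) ** 2
```

All three quantities agree with the theory up to floating point rounding.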

Hamiltonian Monte Carlo (HMC)

In Hamiltonian Monte Carlo we introduce an augmented target distribution,

\displaystyle \tilde \pi(q,p) \propto \exp[-H(q,p)]

It is clear from this definition that moves leaving the Hamiltonian invariant will also leave the augmented target density unchanged. By following the Hamiltonian dynamics, we will be able to make big (reversible) moves in the space that will be accepted with high probability. Also, our target factorises into two independent components as

\displaystyle \tilde \pi(q,p) \propto \exp[-V(q)]\exp[-T(p)],

and so choosing V(q)=-\log \pi(q) will ensure that the q-marginal is our real target of interest, \pi(\cdot). It’s also clear that our p-marginal is \mathcal N(0,M). This is also the full-conditional for p, so re-sampling p from this distribution and leaving q unchanged is a Gibbs move that will leave the augmented target invariant. Re-sampling p will be necessary to properly explore our augmented target, since this will move us to a different contour of H.

So, an idealised version of HMC would proceed as follows. First, update p by sampling from its known tractable marginal. Second, update p and q jointly by following the Hamiltonian dynamics. If this second move is regarded as a (deterministic) reversible M-H proposal, it will be accepted with probability one, since it leaves the augmented target density unchanged. If we could exactly integrate Hamilton’s equations, this would be fine. In practice, however, we will need to use some imperfect numerical method for the integration step. Just as for MALA, we can regard the numerical method as a M-H proposal and correct for the fact that it is imperfect, preserving the exact augmented target distribution.

Hamiltonian systems admit nice numerical integration schemes called symplectic integrators. In HMC a simple alternating Euler method is typically used, known as the leap-frog algorithm. The component updates are all shear transformations, and therefore volume preserving, and exact reversibility is ensured by starting and ending with a half-step update of the momentum variables. In principle, to ensure reversibility of the proposal the momentum variables should be sign-flipped (reversed) to finish, but in practice this doesn’t matter since it doesn’t affect the evaluation of the Hamiltonian and it will then get refreshed, anyway.

So, advancing our system by a time step \epsilon can be done with

\displaystyle p(t+\epsilon/2) := p(t) - \frac{\epsilon}{2}\nabla V(q(t))

\displaystyle q(t+\epsilon) := q(t) + \epsilon M^{-1}p(t+\epsilon/2)

\displaystyle p(t+\epsilon) := p(t+\epsilon/2) - \frac{\epsilon}{2}\nabla V(q(t+\epsilon))

It is clear that if many such updates are chained together, adjacent momentum updates can be collapsed together, giving rise to the "leap-frog" nature of the algorithm, and therefore requiring roughly one gradient evaluation per \epsilon update, rather than two. This integrator is volume preserving and exactly reversible. For reasonably small \epsilon it follows the Hamiltonian dynamics reasonably well, but not exactly, and so it does not exactly preserve the Hamiltonian. However, it does make a good M-H proposal, and reasonable acceptance probabilities can often be obtained by chaining together l updates to advance the time of the system by T=l\epsilon. The "optimal" values of l and \epsilon will be highly problem dependent, but values of l=20 or l=50 are not unusual. There are various more-or-less standard methods for tuning these, but we will not consider them here.
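
To make the leap-frog scheme concrete, here is a minimal scalar sketch in Python, assuming a standard normal target so that V(q) = q^2/2 and \nabla V(q) = q. The function names are illustrative assumptions, and the code is written in terms of \nabla V (as in the update equations above) rather than the gradient of the log target used in the implementations below.

```python
def leapfrog(q, p, grad_v, eps, l, mass=1.0):
    """Advance scalar (q, p) by l leap-frog steps of size eps."""
    p = p - 0.5 * eps * grad_v(q)      # initial half-step for the momentum
    for i in range(l):
        q = q + eps * p / mass         # full step for the position
        if i < l - 1:
            p = p - eps * grad_v(q)    # two adjacent half-steps merged into one
    p = p - 0.5 * eps * grad_v(q)      # final half-step for the momentum
    return q, p

def grad_v(q):
    """Gradient of V(q) = q^2/2, corresponding to a standard normal target."""
    return q

def hamiltonian(q, p):
    """H(q, p) = V(q) + T(p) with unit mass."""
    return 0.5 * q * q + 0.5 * p * p

q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, grad_v, eps=0.1, l=20)

# The Hamiltonian is only approximately conserved by the integrator...
energy_error = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# ...but the map is reversible up to rounding: flip the momentum, apply the
# same integrator, and flip back to return to the starting point.
qr, pr = leapfrog(q1, -p1, grad_v, eps=0.1, l=20)
reverse_error = max(abs(qr - q0), abs(-pr - p0))
```

Here energy_error is small but not zero, which is exactly why a M-H correction is needed, while reverse_error is at the level of floating point rounding.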

Note that since our HMC update on the augmented space consists of a Gibbs move and a M-H update, it is important that our M-H kernel does not keep or thread through the old log target density from the previous M-H update, since the Gibbs move will have changed it in the meantime.

Implementations

R

We need a M-H kernel that does not thread through the old log density.

mhKernel = function(logPost, rprop)
    function(x) {
        prop = rprop(x)
        a = logPost(prop) - logPost(x)
        if (log(runif(1)) < a)
            prop
        else
            x
    }

We can then use this to construct a M-H move as part of our HMC update.

hmcKernel = function(lpi, glpi, eps = 1e-4, l=10, dmm = 1) {
    sdmm = sqrt(dmm)
    leapf = function(q, p) {
        p = p + 0.5*eps*glpi(q)
        for (i in 1:l) {
            q = q + eps*p/dmm
            if (i < l)
                p = p + eps*glpi(q)
            else
                p = p + 0.5*eps*glpi(q)
        }
        list(q=q, p=-p)
    }
    alpi = function(x)
        lpi(x$q) - 0.5*sum((x$p^2)/dmm)
    rprop = function(x)
        leapf(x$q, x$p)
    mhk = mhKernel(alpi, rprop)
    function(q) {
        d = length(q)
        x = list(q=q, p=rnorm(d, 0, sdmm))
        mhk(x)$q
    }
}

See the full runnable script for further details.

Python

First a M-H kernel,

def mhKernel(lpost, rprop):
    def kernel(x):
        prop = rprop(x)
        a = lpost(prop) - lpost(x)
        if (np.log(np.random.rand()) < a):
            x = prop
        return x
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l=10, dmm = 1):
    sdmm = np.sqrt(dmm)
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return (q, -p)
    def alpi(x):
        (q, p) = x
        return lpi(q) - 0.5*np.sum((p**2)/dmm)
    def rprop(x):
        (q, p) = x
        return leapf(q, p)
    mhk = mhKernel(alpi, rprop)
    def kern(q):
        d = len(q)
        p = np.random.randn(d)*sdmm
        return mhk((q, p))[0]
    return kern

See the full runnable script for further details.
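
As a rough usage sketch, the following standard-library Python mirrors the structure of the kernel above for a scalar standard normal target and checks that the sample mean and variance look sensible. It is a simplified scalar analogue written for illustration, not the numpy version from the post; the names hmc_kernel and rng, the seed, and the tuning values eps=0.2 and l=8 are all assumptions made for this toy example.

```python
import math
import random

def hmc_kernel(lpi, glpi, eps, l, dmm=1.0, rng=random):
    """Scalar analogue of hmcKernel: returns a function mapping q to the next q."""
    sdmm = math.sqrt(dmm)

    def leapf(q, p):
        p = p + 0.5 * eps * glpi(q)
        for i in range(l):
            q = q + eps * p / dmm
            if i < l - 1:
                p = p + eps * glpi(q)
            else:
                p = p + 0.5 * eps * glpi(q)
        return q, -p

    def alpi(q, p):
        # log of the augmented target, up to an additive constant
        return lpi(q) - 0.5 * p * p / dmm

    def kernel(q):
        p = rng.gauss(0.0, sdmm)                 # Gibbs refresh of the momentum
        q_prop, p_prop = leapf(q, p)             # deterministic proposal
        a = alpi(q_prop, p_prop) - alpi(q, p)    # log acceptance ratio
        # 1 - random() lies in (0, 1], which avoids log(0)
        return q_prop if math.log(1.0 - rng.random()) < a else q

    return kernel

# Toy target: standard normal, so lpi(q) = -q^2/2 and glpi(q) = -q.
rng = random.Random(42)
kern = hmc_kernel(lambda q: -0.5 * q * q, lambda q: -q, eps=0.2, l=8, rng=rng)

q, samples = 3.0, []
for _ in range(5000):
    q = kern(q)
    samples.append(q)
samples = samples[500:]                          # discard burn-in

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
```

The integration time l*eps = 1.6 is chosen close to a quarter period of the exact dynamics for this target, so that successive samples are nearly uncorrelated. For a real posterior no such closed-form guidance is available and the values need tuning.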

JAX

Again, we want an appropriate M-H kernel,

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        ll = lpost(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x)
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l = 10, dmm = 1):
    sdmm = jnp.sqrt(dmm)
    @jit
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return jnp.concatenate((q, -p))
    @jit
    def alpi(x):
        d = len(x) // 2
        return lpi(x[jnp.array(range(d))]) - 0.5*jnp.sum((x[jnp.array(range(d,2*d))]**2)/dmm)
    @jit
    def rprop(k, x):
        d = len(x) // 2
        return leapf(x[jnp.array(range(d))], x[jnp.array(range(d, 2*d))])
    mhk = mhKernel(alpi, rprop)
    @jit
    def kern(k, q):
        key0, key1 = jax.random.split(k)
        d = len(q)
        x = jnp.concatenate((q, jax.random.normal(key0, [d])*sdmm))
        return mhk(key1, x)[jnp.array(range(d))]
    return kern

There is something a little bit strange about this implementation: since the proposal for the M-H move is deterministic, the function rprop just ignores the RNG key that is passed to it. We could tidy this up by making a M-H function especially for deterministic proposals. We won’t pursue this here, but this issue will crop up again later in some of the other functional languages.

See the full runnable script for further details.

Scala

A M-H kernel,

def mhKern[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): (S) => S =
    val r = Uniform(0.0,1.0)
    x0 =>
      val x = rprop(x0)
      val ll0 = logPost(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a) x else x0

and an HMC kernel.

def hmcKernel(lpi: DVD => Double, glpi: DVD => DVD, dmm: DVD,
  eps: Double = 1e-4, l: Int = 10) =
  val sdmm = sqrt(dmm)
  def leapf(q: DVD, p: DVD): (DVD, DVD) = 
    @tailrec def go(q0: DVD, p0: DVD, l: Int): (DVD, DVD) =
      val q = q0 + eps*(p0/:/dmm)
      val p = if (l > 1)
        p0 + eps*glpi(q)
      else
        p0 + 0.5*eps*glpi(q)
      if (l == 1)
        (q, -p)
      else
        go(q, p, l-1)
    go(q, p + 0.5*eps*glpi(q), l)
  def alpi(x: (DVD, DVD)): Double =
    val (q, p) = x
    lpi(q) - 0.5*sum(pow(p,2) /:/ dmm)
  def rprop(x: (DVD, DVD)): (DVD, DVD) =
    val (q, p) = x
    leapf(q, p)
  val mhk = mhKern(alpi, rprop)
  (q: DVD) =>
    val d = q.length
    val p = sdmm map (sd => Gaussian(0,sd).draw())
    mhk((q, p))._1

See the full runnable script for further details.

Haskell

A M-H kernel:

mdKernel :: (StatefulGen g m) => (s -> Double) -> (s -> s) -> g -> s -> m s
mdKernel logPost prop g x0 = do
  let x = prop x0
  let ll0 = logPost x0
  let ll = logPost x
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then x
        else x0
  return next

Note that here we are using a M-H kernel specifically for deterministic proposals, since there is no non-determinism signalled in the type signature of prop. We can then use this to construct our HMC kernel.

hmcKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> Vector Double ->
  Double -> Int -> g ->
  Vector Double -> m (Vector Double)
hmcKernel lpi glpi dmm eps l g = let
  sdmm = cmap sqrt dmm
  leapf q p = let
    go q0 p0 l = let
      q = q0 + (scalar eps)*p0/dmm
      p = if (l > 1)
        then p0 + (scalar eps)*(glpi q)
        else p0 + (scalar (eps/2))*(glpi q)
      in if (l == 1)
      then (q, -p)
      else go q p (l - 1)
    in go q (p + (scalar (eps/2))*(glpi q)) l
  alpi x = let
    (q, p) = x
    in (lpi q) - 0.5*(sumElements (p*p/dmm))
  prop x = let
    (q, p) = x
    in leapf q p
  mk = mdKernel alpi prop g
  in (\q0 -> do
         let d = size q0
         zl <- (replicateM d . genContVar (normalDistr 0.0 1.0)) g
         let z = fromList zl
         let p0 = sdmm * z
         (q, p) <- mk (q0, p0)
         return q)

See the full runnable script for further details.

Dex

Again we can use a M-H kernel specific to deterministic proposals.

def mdKernel {s} (lpost: s -> Float) (prop: s -> s)
    (x0: s) (k: Key) : s =
  x = prop x0
  ll0 = lpost x0
  ll = lpost x
  a = ll - ll0
  u = rand k
  select (log u < a) x x0

and use this to construct an HMC kernel.

def hmcKernel {n} (lpi: (Fin n)=>Float -> Float)
    (dmm: (Fin n)=>Float) (eps: Float) (l: Nat)
    (q0: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdmm = sqrt dmm
  idmm = map (\x. 1.0/x) dmm
  glpi = grad lpi
  def leapf (q0: (Fin n)=>Float) (p0: (Fin n)=>Float) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    p1 = p0 + (eps/2) .* (glpi q0)
    q1 = q0 + eps .* (p1*idmm)
    (q, p) = apply_n l (q1, p1) \(qo, po).
      pn = po + eps .* (glpi qo)
      qn = qo + eps .* (pn*idmm)
      (qn, pn)
    pf = p + (eps/2) .* (glpi q)
    (q, -pf)
  def alpi (qp: ((Fin n)=>Float & (Fin n)=>Float)) : Float =
    (q, p) = qp
    (lpi q) - 0.5*(sum (p*p*idmm))
  def prop (qp: ((Fin n)=>Float & (Fin n)=>Float)) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    (q, p) = qp
    leapf q p
  mk = mdKernel alpi prop
  [k1, k2] = split_key k
  z = randn_vec k1
  p0 = sdmm * z
  (q, p) = mk (q0, p0) k2
  q

Note that the gradient is obtained via automatic differentiation. See the full runnable script for details.

Next steps

This was the main place that I was trying to get to when I started this series of posts. For differentiable log-posteriors (as we have in the case of Bayesian logistic regression), HMC is a pretty good algorithm for reasonably efficient posterior exploration. But there are lots of places we could go from here. We could explore the tuning of MCMC algorithms, or HMC extensions such as NUTS. We could look at MCMC algorithms that are specifically tailored to the logistic regression problem, or we could look at new MCMC algorithms for differentiable targets based on piecewise deterministic Markov processes. Alternatively, we could temporarily abandon MCMC and look at SMC or ABC approaches. Another possibility would be to abandon this multi-language approach and have a bit of a deep dive into Dex, which I think has the potential to be a great programming language for statistical computing. All of these are possibilities for the future, but I’ve a busy few weeks coming up, so the frequency of these posts is likely to substantially decrease.

Remember that all of the code associated with this series of posts is available from this GitHub repo.

Bayesian inference for a logistic regression model (Part 5)

Part 5: the Metropolis-adjusted Langevin algorithm (MALA)

Introduction

This is the fifth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to use Langevin dynamics to construct an approximate MCMC scheme using the gradient of the log target distribution. Each step of the algorithm involved simulating from the Euler-Maruyama approximation to the transition kernel of the process, based on some pre-specified step size, \Delta t. We can improve the accuracy of this approximation by making the step size smaller, but this will come at the expense of a more slowly mixing MCMC chain.

Fortunately, there is an easy way to make the algorithm "exact" (in the sense that the equilibrium distribution of the Markov chain will be the exact target distribution), for any finite step size, \Delta t, simply by using the Euler-Maruyama approximation as the proposal distribution in a Metropolis-Hastings algorithm. This is the Metropolis-adjusted Langevin algorithm (MALA). There are various ways this could be coded up, but here, for clarity, a higher-order function (HoF) for generating a MALA kernel will be used, and this function will in turn call on a HoF for generating a Metropolis-Hastings kernel.
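
Before looking at the full implementations, here is a minimal scalar sketch of the idea in standard-library Python, assuming a standard normal target and no pre-conditioning. The names mala_kernel and rng, the seed, and the step size dt=1.0 are illustrative assumptions and are not taken from the accompanying scripts.

```python
import math
import random

def mala_kernel(lpi, glpi, dt, rng=random):
    """Scalar MALA kernel: Euler-Maruyama proposal plus a M-H correction."""
    sdt = math.sqrt(dt)

    def advance(x):
        return x + 0.5 * glpi(x) * dt            # drift part of the Langevin update

    def dprop(new, old):
        # log density of N(advance(old), dt) at new, omitting the normalising
        # constant, which is the same in both directions and so cancels
        return -0.5 * (new - advance(old)) ** 2 / dt

    def kernel(x):
        prop = rng.gauss(advance(x), sdt)
        a = lpi(prop) - lpi(x) + dprop(x, prop) - dprop(prop, x)
        # 1 - random() lies in (0, 1], which avoids log(0)
        return prop if math.log(1.0 - rng.random()) < a else x

    return kernel

# Toy target: standard normal, so lpi(x) = -x^2/2 and glpi(x) = -x.
rng = random.Random(1)
kern = mala_kernel(lambda x: -0.5 * x * x, lambda x: -x, dt=1.0, rng=rng)

x, samples = 0.0, []
for _ in range(21000):
    x = kern(x)
    samples.append(x)
samples = samples[1000:]                         # discard burn-in

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
```

Even with a step size as large as dt=1.0, where the unadjusted chain would be noticeably biased, the M-H correction keeps the sample variance close to the true value of 1.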

Implementations

R

First we need a function to generate a M-H kernel.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Then we can easily write a function for returning a MALA kernel that makes use of this M-H function.

malaKernel = function(lpi, glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    mhKernel(lpi, function(x) rnorm(length(x), advance(x), spre*sdt),
             function(new, old) sum(dnorm(new, advance(old), spre*sdt, log=TRUE)))
}

Notice that our MALA function requires as input both the gradient of the log posterior (for the proposal) and the log posterior itself (for the M-H correction). Other details are as we have already seen – see the full runnable script.

Python

Again, we need a M-H kernel

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

and then a MALA kernel

def malaKernel(lpi, glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    # the dimension is taken from x, rather than from a global variable
    return mhKernel(lpi, lambda x: advance(x) + np.random.randn(len(x))*spre*sdt,
            lambda new, old: np.sum(sp.stats.norm.logpdf(new, loc=advance(old), scale=spre*sdt)))

See the full runnable script for further details.

JAX

If we want our algorithm to run fast, and if we want to exploit automatic differentiation to avoid the need to manually compute gradients, then we can easily convert the above code to use JAX.

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

def malaKernel(lpi, dt = 1e-4, pre = 1):
    # differentiate the function passed in, not a global log posterior
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    return mhKernel(lpi, jit(lambda k, x: advance(x) +
                             jax.random.normal(k, x.shape)*spre*sdt),
            jit(lambda new, old:
                jnp.sum(jsp.stats.norm.logpdf(new,
                      loc=advance(old), scale=spre*sdt))))

See the full runnable script for further details.

Scala

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

def malaKernel(lpi: DVD => Double, glpi: DVD => DVD, pre: DVD, dt: Double = 1e-4) =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  val p = pre.length
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glpi(beta))
  def rprop(beta: DVD): DVD =
    advance(beta) + sdt*spre.map(Gaussian(0,_).sample())
  def dprop(n: DVD, o: DVD): Double = 
    val ao = advance(o)
    (0 until p).map(i => Gaussian(ao(i), spre(i)*sdt).logPdf(n(i))).sum
  mhKernel(lpi, rprop, dprop)

See the full runnable script for further details.

Haskell

mhKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) ->
  (s -> s -> Double) -> g -> (s, Double) -> m (s, Double)
mhKernel logPost rprop dprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

malaKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> 
  Vector Double -> Double -> g ->
  (Vector Double, Double) -> m (Vector Double, Double)
malaKernel lpi glpi pre dt g = let
  sdt = sqrt dt
  spre = cmap sqrt pre
  p = size pre
  advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  rprop beta g = do
    zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
    let z = fromList zl
    return $ advance(beta) + (scalar sdt)*spre*z
  dprop n o = let
    ao = advance o
    in sum $ (\i -> logDensity (normalDistr (ao!i) 
      ((spre!i)*sdt)) (n!i)) <$> [0..(p-1)]
  in mhKernel lpi rprop dprop g

See the full runnable script for further details.

Dex

Recall that Dex is differentiable, so we don’t need to provide gradients.

def mhKernel {s} (lpost: s -> Float) (rprop: s -> Key -> s) (dprop: s -> s -> Float)
    (sll: (s & Float)) (k: Key) : (s & Float) =
  (x0, ll0) = sll
  [k1, k2] = split_key k
  x = rprop x0 k1
  ll = lpost x
  a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u = rand k2
  select (log u < a) (x, ll) (x0, ll0)

def malaKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float) :
    ((Fin n)=>Float & Float) -> Key -> ((Fin n)=>Float & Float) =
  sdt = sqrt dt
  spre = sqrt pre
  glp = grad lpi
  v = dt .* pre
  vinv = map (\ x. 1.0/x) v
  def advance (beta: (Fin n)=>Float) : (Fin n)=>Float =
    beta + (0.5*dt) .* (pre*(glp beta))
  def rprop (beta: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
    (advance beta) + sdt .* (spre*(randn_vec k))
  def dprop (new: (Fin n)=>Float) (old: (Fin n)=>Float) : Float =
    ao = advance old
    diff = new - ao
    -0.5 * sum ((log v) + diff*diff*vinv)
  mhKernel lpi rprop dprop

See the full runnable script for further details.

Next steps

MALA gives us an MCMC algorithm that exploits gradient information to generate "informed" M-H proposals. But it still has a rather "diffusive" character, making it difficult to tune in such a way that large moves are likely to be accepted in challenging high-dimensional situations.

The Langevin dynamics on which MALA is based can be interpreted as the (over-damped) stochastic dynamics of a particle moving in a potential energy field corresponding to minus the log posterior. It turns out that the corresponding deterministic dynamics can be exploited to generate proposals better able to make large moves while still having a high probability of acceptance. This is the idea behind Hamiltonian Monte Carlo (HMC), which we’ll look at next.

Bayesian inference for a logistic regression model (Part 4)

Part 4: Gradients and the Langevin algorithm

Introduction

This is the fourth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how the Metropolis algorithm could be used to generate a Markov chain targeting our posterior distribution. In high dimensions the diffusive nature of the Metropolis random walk proposal becomes increasingly inefficient. It is therefore natural to try and develop algorithms that use additional information about the target distribution. In the case of a differentiable log posterior target, a natural first step in this direction is to try and make use of gradient information.

Gradient of a logistic regression model

There are various ways to derive the gradient of our logistic regression model, but it might be simplest to start from the first form of the log likelihood that we deduced in Part 2:

\displaystyle l(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta])

We can write this out in component form as

\displaystyle l(\beta;y) = \sum_i\sum_j y_iX_{ij}\beta_j - \sum_i\log\left(1+\exp\left[\sum_jX_{ij}\beta_j\right]\right).

Differentiating wrt \beta_k gives

\displaystyle \frac{\partial l}{\partial \beta_k} = \sum_i y_iX_{ik} - \sum_i \frac{\exp\left[\sum_j X_{ij}\beta_j\right]X_{ik}}{1+\exp\left[\sum_j X_{ij}\beta_j\right]}.

It’s then reasonably clear that stitching all of the partial derivatives together will give the gradient vector

\displaystyle \nabla l = X^\textsf{T}\left[ y - \frac{\mathbf{1}}{\mathbf{1}+\exp[-X\beta]} \right].

This is the gradient of the log likelihood, but we also need the gradient of the log prior. Since we are assuming independent \beta_i \sim N(0,v_i) priors, it is easy to see that the gradient of the log prior is just -\beta\circ v^{-1}. It is the sum of these two terms that gives the gradient of the log posterior.
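
A simple way to gain confidence in a hand-derived gradient is to compare it against finite differences. The following standard-library Python sketch does this for a tiny made-up data set; the data, the prior variances v, and the function names log_post and grad_log_post are all assumptions for illustration only, and are not the data or code used in this series.

```python
import math

def log_post(beta, X, y, v):
    """Log posterior up to a constant: log likelihood plus N(0, v_j) log priors."""
    ll = 0.0
    for xi, yi in zip(X, y):
        eta = sum(a * b for a, b in zip(xi, beta))
        ll += yi * eta - math.log(1.0 + math.exp(eta))
    lprior = -0.5 * sum(b * b / vj for b, vj in zip(beta, v))
    return ll + lprior

def grad_log_post(beta, X, y, v):
    """Analytic gradient: X^T (y - 1/(1 + exp(-X beta))) - beta/v."""
    g = [-b / vj for b, vj in zip(beta, v)]
    for xi, yi in zip(X, y):
        eta = sum(a * b for a, b in zip(xi, beta))
        resid = yi - 1.0 / (1.0 + math.exp(-eta))
        for k, xik in enumerate(xi):
            g[k] += resid * xik
    return g

# Tiny made-up data set: an intercept column plus one covariate.
X = [[1.0, 0.5], [1.0, -1.2], [1.0, 2.0], [1.0, 0.1]]
y = [1.0, 0.0, 1.0, 0.0]
v = [100.0, 1.0]
beta = [0.3, -0.7]

# Central finite differences as an independent check on the formula.
h = 1e-6
numeric = []
for k in range(len(beta)):
    up = list(beta)
    up[k] += h
    dn = list(beta)
    dn[k] -= h
    numeric.append((log_post(up, X, y, v) - log_post(dn, X, y, v)) / (2 * h))

analytic = grad_log_post(beta, X, y, v)
max_abs_diff = max(abs(a - n) for a, n in zip(analytic, numeric))
```

If the derivation is right, max_abs_diff should be tiny compared with the size of the gradient components.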

R

In R we can implement our gradient function as

glp = function(beta) {
    glpr = -beta/(pscale*pscale)
    gll = as.vector(t(X) %*% (y - 1/(1 + exp(-X %*% beta))))
    glpr + gll
}

Python

In Python we could use

def glp(beta):
    glpr = -beta/(pscale*pscale)
    gll = (X.T).dot(y - 1/(1 + np.exp(-X.dot(beta))))
    return (glpr + gll)

We don’t really need a JAX version, since JAX can auto-diff the log posterior for us.

Scala

  def glp(beta: DVD): DVD =
    val glpr = -beta /:/ pvar
    val gll = (X.t)*(y - ones/:/(ones + exp(-X*beta)))
    glpr + gll

Haskell

Using hmatrix we could use something like

glp :: Matrix Double -> Vector Double -> Vector Double -> Vector Double
glp x y b = let
  glpr = -b / (fromList [100.0, 1, 1, 1, 1, 1, 1, 1])
  gll = (tr x) #> (y - (scalar 1)/((scalar 1) + (cmap exp (-x #> b))))
  in glpr + gll

There’s something interesting to say about Haskell and auto-diff, but getting into this now will be too much of a distraction. I may come back to it in some future post.

Dex

Dex is differentiable, so we don’t need a gradient function – we can just use grad lpost. However, for interest and comparison purposes we could nevertheless implement it directly with something like

prscale = map (\ x. 1.0/x) pscale

def glp (b: (Fin 8)=>Float) : (Fin 8)=>Float =
  glpr = -b*prscale*prscale
  gll = (transpose x) **. (y - (map (\eta. 1.0/(1.0 + eta)) (exp (-x **. b))))
  glpr + gll

Langevin diffusions

Now that we have a way of computing the gradient of the log of our target density we need some MCMC algorithms that can make good use of it. In this post we will look at a simple approximate MCMC algorithm derived from an overdamped Langevin diffusion model. In subsequent posts we’ll look at more sophisticated, exact MCMC algorithms.

The multivariate stochastic differential equation (SDE)

\displaystyle dX_t = \frac{1}{2}\nabla\log\pi(X_t)dt + dW_t

has \pi(\cdot) as its equilibrium distribution. Informally, an SDE of this form is a continuous time process with infinitesimal transition kernel

\displaystyle X_{t+dt}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}\nabla\log\pi(x_t)dt,\mathbf{I}dt\right).

There are various more-or-less formal ways to see that \pi(\cdot) is stationary. A good way is to check that it satisfies the Fokker–Planck equation with zero LHS (that is, with the time derivative of the density set to zero). A less formal approach would be to see that the infinitesimal transition kernel for the process satisfies detailed balance with \pi(\cdot).
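
For readers who want to see the Fokker–Planck check written out, here is a short sketch of the calculation, assuming that \pi is smooth and that p(x,t) denotes the density of X_t (notation introduced here for the purpose of the derivation).

```latex
% Fokker-Planck equation for dX_t = (1/2) \nabla\log\pi(X_t) dt + dW_t
\frac{\partial p}{\partial t}
  = -\nabla\cdot\left[\frac{1}{2}\nabla\log\pi(x)\,p(x,t)\right]
    + \frac{1}{2}\nabla^2 p(x,t)

% Substitute p(x,t) = \pi(x) and use \pi\nabla\log\pi = \nabla\pi:
-\nabla\cdot\left[\frac{1}{2}\nabla\pi(x)\right] + \frac{1}{2}\nabla^2\pi(x)
  = -\frac{1}{2}\nabla^2\pi(x) + \frac{1}{2}\nabla^2\pi(x) = 0
```

So the right-hand side vanishes at p = \pi, which means the density does not change in time once the process is distributed according to \pi.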

Similar arguments show that for any fixed positive definite matrix A, the SDE

\displaystyle dX_t = \frac{1}{2}A\nabla\log\pi(X_t)dt + A^{1/2}dW_t

also has \pi(\cdot) as a stationary distribution. It is quite common to choose a diagonal matrix A to put the components of X_t on a common scale.

The unadjusted Langevin algorithm

Simulating exact sample paths from SDEs such as the overdamped Langevin diffusion model is typically difficult (though not necessarily impossible), so we instead want something simple and tractable as the basis of our MCMC algorithms. Here we will just simulate from the Euler–Maruyama approximation of the process by choosing a small but finite time step \Delta t and using the transition kernel

\displaystyle X_{t+\Delta t}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}A\nabla\log\pi(x_t)\Delta t, A\Delta t\right)

as the basis of our MCMC method. For sufficiently small \Delta t this should accurately approximate the Langevin dynamics, leading to an equilibrium distribution very close to \pi(\cdot). That said, we would like to choose \Delta t as large as we can get away with, since that will lead to a more rapidly mixing MCMC chain. Below are some implementations of this kernel for a diagonal pre-conditioning matrix.
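
To get a feel for the size of the approximation error, consider a scalar standard normal target with A=1. The update is then linear, and the stationary variance of the unadjusted chain can be worked out exactly as 1/(1 - \Delta t/4), rather than the true value of 1. The Python sketch below checks this by iterating the variance recursion; it is an illustrative toy calculation, and the function name is an assumption.

```python
def ula_stationary_variance(dt, n_iter=10_000):
    """Iterate the variance recursion of the unadjusted Langevin chain.

    For a standard normal target, grad log pi(x) = -x, so the update is
    x' = (1 - dt/2) x + sqrt(dt) z with z ~ N(0, 1), which gives
    Var(x') = (1 - dt/2)^2 Var(x) + dt.
    """
    a = 1.0 - 0.5 * dt
    v = 1.0
    for _ in range(n_iter):
        v = a * a * v + dt
    return v

# Compare the iterated variance with the closed form 1/(1 - dt/4).
results = [(dt, ula_stationary_variance(dt), 1.0 / (1.0 - dt / 4.0))
           for dt in (0.5, 0.1, 0.01)]
```

The bias shrinks as the step size decreases, but it never vanishes for a finite step size, which is the sense in which this algorithm is only approximate.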

Implementation

R

We can create a kernel for the unadjusted Langevin algorithm in R with the following function.

ulKernel = function(glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    function(x, ll) rnorm(length(x), advance(x), spre*sdt)
}

Here, we can pass in pre, which is expected to be a vector representing the diagonal of the pre-conditioning matrix, A. We can then use this kernel to generate an MCMC chain as we have seen previously. See the full runnable script for further details.

Python

def ulKernel(glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    def kernel(x):
        # the dimension is taken from x, rather than from a global variable
        return advance(x) + np.random.randn(len(x))*spre*sdt
    return kernel

See the full runnable script for further details.

JAX

def ulKernel(lpi, dt = 1e-4, pre = 1):
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    @jit
    def kernel(key, x):
        # the dimension is taken from x, rather than from a global variable
        return advance(x) + jax.random.normal(key, x.shape)*spre*sdt
    return kernel

Note how for JAX we can just pass in the log posterior, and the gradient function can be obtained by automatic differentiation. See the full runnable script for further details.

Scala

def ulKernel(glp: DVD => DVD, pre: DVD, dt: Double): DVD => DVD =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glp(beta))
  beta => advance(beta) + sdt*spre.map(Gaussian(0,_).sample())

See the full runnable script for further details.

Haskell

ulKernel :: (StatefulGen g m) =>
  (Vector Double -> Vector Double) -> Vector Double -> Double -> g ->
  Vector Double -> m (Vector Double)
ulKernel glpi pre dt g beta = do
  let sdt = sqrt dt
  let spre = cmap sqrt pre
  let p = size pre
  let advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
  let z = fromList zl
  return $ advance(beta) + (scalar sdt)*spre*z

See the full runnable script for further details.

Dex

In Dex we can write a function that accepts a gradient function

def ulKernel {n} (glpi: (Fin n)=>Float -> (Fin n)=>Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdt = sqrt dt
  spre = sqrt pre
  b + (((0.5)*dt) .* (pre*(glpi b))) +
    (sdt .* (spre*(randn_vec k)))

or we can write a function that accepts a log posterior, and uses auto-diff to construct the gradient

def ulKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  glpi = grad lpi
  sdt = sqrt dt
  spre = sqrt pre
  b + ((0.5)*dt) .* (pre*(glpi b)) +
    sdt .* (spre*(randn_vec k))

and since Dex is statically typed, we can’t easily mix these functions up.

See the full runnable scripts, without and with auto-diff.

Next steps

In this post we have seen how to construct an MCMC algorithm that makes use of gradient information. But this algorithm is approximate. In the next post we’ll see how to correct for the approximation by using the Langevin updates as proposals within a Metropolis-Hastings algorithm.

Bayesian inference for a logistic regression model (Part 3)

Part 3: The Metropolis algorithm

Introduction

This is the third part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we derived the log posterior for the model and implemented it in a variety of programming languages and libraries. In this post we will construct a Markov chain having the posterior as its equilibrium.

MCMC

Detailed balance

A homogeneous Markov chain with transition kernel p(\theta_{n+1}|\theta_n) is said to satisfy detailed balance for some target distribution \pi(\theta) if

\displaystyle \pi(\theta)p(\theta'|\theta) = \pi(\theta')p(\theta|\theta'), \quad \forall \theta, \theta'

Integrating both sides with respect to \theta, and noting that p(\theta|\theta') integrates to one, gives

\displaystyle \int_\Theta \pi(\theta)p(\theta'|\theta)d\theta = \pi(\theta'),

from which it is clear that \pi(\cdot) is a stationary distribution of the chain (and the chain is reversible). Under fairly mild regularity conditions we expect \pi(\cdot) to be the equilibrium distribution of the chain.

For a given target \pi(\cdot) we would like to find an easy-to-sample-from transition kernel p(\cdot|\cdot) that satisfies detailed balance. This will then give us a way to (asymptotically) generate samples from our target.

In the context of Bayesian inference, the target \pi(\theta) will typically be the posterior distribution, which in the previous post we wrote as \pi(\theta|y). Here we drop the notational dependence on y, since MCMC can be used for any target distribution of interest.

Metropolis-Hastings

Suppose we have a fairly arbitrary easy-to-sample-from transition kernel q(\theta_{n+1}|\theta_n) and a target of interest, \pi(\cdot). Metropolis-Hastings (M-H) is a strategy for using q(\cdot|\cdot) to construct a new transition kernel p(\cdot|\cdot) satisfying detailed balance for \pi(\cdot).

The kernel p(\theta'|\theta) can be described algorithmically as follows:

  1. Call the current state of the chain \theta. Generate a proposal \theta^\star by simulating from q(\theta^\star|\theta).
  2. Compute the acceptance probability

\displaystyle \alpha(\theta^\star|\theta) = \min\left[1,\frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}\right].

  3. With probability \alpha(\theta^\star|\theta) return new state \theta'=\theta^\star, otherwise return \theta'=\theta.

It is clear from the algorithmic description that this kernel will have a point mass at \theta'=\theta, but that for \theta'\not=\theta the transition kernel will be p(\theta'|\theta)=q(\theta'|\theta)\alpha(\theta'|\theta). But then

\displaystyle \pi(\theta)p(\theta'|\theta) = \min[\pi(\theta)q(\theta'|\theta),\pi(\theta')q(\theta|\theta')]

is symmetric in \theta and \theta', and so detailed balance is satisfied. Since detailed balance is trivial at the point mass at \theta=\theta' we are done.
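The argument above can be checked numerically on a small finite state space, where the kernels are just matrices. The following Python sketch builds the M-H transition matrix from an arbitrary proposal matrix and then checks detailed balance and stationarity. The three-state target pi and the asymmetric proposal q are made-up values chosen purely for illustration.

```python
def mh_matrix(pi, q):
    """Transition matrix of the M-H kernel built from proposal matrix q."""
    n = len(pi)
    p = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j and q[i][j] > 0.0:
                alpha = min(1.0, (pi[j] * q[j][i]) / (pi[i] * q[i][j]))
                p[i][j] = q[i][j] * alpha
        # point mass at the current state: rejections and "stay" proposals
        p[i][i] = 1.0 - sum(p[i])
    return p

pi = [0.2, 0.5, 0.3]        # hypothetical target distribution
q = [[0.1, 0.6, 0.3],       # hypothetical asymmetric proposal (rows sum to 1)
     [0.4, 0.2, 0.4],
     [0.7, 0.2, 0.1]]
p = mh_matrix(pi, q)

# Detailed balance: pi[i] * p[i][j] == pi[j] * p[j][i] for all i, j
db_gap = max(abs(pi[i] * p[i][j] - pi[j] * p[j][i])
             for i in range(3) for j in range(3))

# Stationarity: one step of the chain started from pi returns pi
pi_next = [sum(pi[i] * p[i][j] for i in range(3)) for j in range(3)]
print(db_gap, pi_next)
```

Here the sum over i plays the role of the integral over \theta in the continuous case.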

Metropolis algorithm

It is often convenient to generate proposals perturbatively, using a distribution that is symmetric about the current state of the chain. But then q(\theta'|\theta)=q(\theta|\theta'), and so q(\cdot|\cdot) drops out of the acceptance probability. This is the Metropolis algorithm.

Some computational tricks

To generate an event with probability \alpha(\theta^\star|\theta), we can generate a u\sim U(0,1) and accept if u < \alpha(\theta^\star|\theta). This is convenient for several reasons. First, it means that we can ignore the "min", and just accept if

\displaystyle u < \frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}

since u\leq 1 regardless. Better still, we can take logs, and accept if

\displaystyle \log u < \log\pi(\theta^\star) - \log\pi(\theta) + \log q(\theta|\theta^\star) - \log q(\theta^\star|\theta),

so there is no need to evaluate any raw densities. Again, in the case of a symmetric proposal distribution, the q(\cdot|\cdot) terms can be dropped.
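As a quick illustration, the three forms of the acceptance test (with the "min", without it, and on the log scale) can be compared directly. This is a minimal Python sketch for the symmetric proposal case; the function names and the range of log densities tried are just assumptions made for the example.

```python
import math
import random

def accept_min(u, pi_new, pi_old):
    """Literal acceptance rule, including the min."""
    return u < min(1.0, pi_new / pi_old)

def accept_ratio(u, pi_new, pi_old):
    """Same rule with the min dropped."""
    return u < pi_new / pi_old

def accept_log(u, lpi_new, lpi_old):
    """Same rule on the log scale, using only log densities."""
    return math.log(u) < lpi_new - lpi_old

rng = random.Random(0)
mismatches = 0
for _ in range(10000):
    u = rng.random()
    lpi_old = rng.uniform(-5.0, 0.0)
    lpi_new = rng.uniform(-5.0, 0.0)
    a = accept_min(u, math.exp(lpi_new), math.exp(lpi_old))
    b = accept_ratio(u, math.exp(lpi_new), math.exp(lpi_old))
    c = accept_log(u, lpi_new, lpi_old)
    if not (a == b == c):
        mismatches += 1
print(mismatches)

# With very small densities only the log-scale rule is usable:
# exp(-2000) underflows to 0.0, so the raw ratio would be 0/0.
print(math.exp(-2000.0), accept_log(0.5, -1999.0, -2000.0))
```

The log-scale version is the one used in all of the kernels that follow.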

Another trick is worth noting. Suppose we use the simple M-H algorithm described, with a single update for the entire state space (rather than multiple component-wise updates, for example), and that the same M-H kernel is used repeatedly to generate successive states of a Markov chain. Then the \log \pi(\theta) term (which in the context of Bayesian inference will typically be the log posterior) will have been computed at the previous update, irrespective of whether or not the previous move was accepted. So if we are careful about how we pass on that old value to the next iteration, we can avoid recomputing the log posterior, and our algorithm will only require one log posterior evaluation per iteration rather than two. In functional programming languages it is often convenient to pass around this current log posterior density evaluation explicitly, effectively augmenting the state space of the Markov chain to include the log posterior density.

HoF for a M-H kernel

Since I’m a fan of functional programming, we will adopt a functional style throughout, and start by creating a higher-order function (HoF) that accepts a log-posterior and proposal kernel as input and returns a Metropolis kernel as output.

R

In R we can write a function to create a M-H kernel as follows.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Note that the kernel returned requires as input both a current state x and its associated log-posterior, ll. The new state and log-posterior densities are returned.

We need to use this transition kernel to simulate a Markov chain by successive substitution of newly simulated values back into the kernel. In more sophisticated programming languages we will use streams for this, but in R we can just use a for loop to sample values and write the states into the rows of a matrix.

mcmc = function(init, kernel, iters = 10000, thin = 10, verb = TRUE) {
    p = length(init)
    ll = -Inf
    mat = matrix(0, nrow = iters, ncol = p)
    colnames(mat) = names(init)
    x = init
    if (verb) 
        message(paste(iters, "iterations"))
    for (i in 1:iters) {
        if (verb) 
            message(paste(i, ""), appendLF = FALSE)
        for (j in 1:thin) {
            pair = kernel(x, ll)
            x = pair$x
            ll = pair$ll
            }
        mat[i, ] = x
        }
    if (verb) 
        message("Done.")
    mat
}

Then, in the context of our running logistic regression example, and using the log-posterior from the previous post, we can construct our kernel and run it as follows.

pre = c(10.0,1,1,1,1,1,5,1)
out = mcmc(init, mhKernel(lpost,
          function(x) x + pre*rnorm(p, 0, 0.02)), thin=1000)

Note the use of a symmetric proposal, so the proposal density is not required. Also note the use of a larger proposal variance for the intercept term and the second last covariate. See the full runnable script for further details.

Python

We can do something very similar to R in Python using NumPy. Our HoF for constructing a M-H kernel is

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

Our Markov chain runner function is

def mcmc(init, kernel, thin = 10, iters = 10000, verb = True):
    p = len(init)
    ll = -np.inf
    mat = np.zeros((iters, p))
    x = init
    if (verb):
        print(str(iters) + " iterations")
    for i in range(iters):
        if (verb):
            print(str(i), end=" ", flush=True)
        for j in range(thin):
            x, ll = kernel(x, ll)
        mat[i,:] = x
    if (verb):
        print("\nDone.", flush=True)
    return mat

We can use this code in the context of our logistic regression example as follows.

pre = np.array([10.,1.,1.,1.,1.,1.,5.,1.])

def rprop(beta):
    return beta + 0.02*pre*np.random.randn(p)

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.
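For readers who want something to run without the logistic regression data, here is a self-contained toy version of the same pattern using only the Python standard library. It is a sketch rather than the script linked above: the target is a scalar standard normal (up to a constant), the proposal is a symmetric random walk, and the names mh_kernel and counter are assumptions of this example. It also counts log posterior evaluations, to confirm that passing ll around gives one evaluation per iteration.

```python
import math
import random

rng = random.Random(42)          # arbitrary fixed seed
counter = {"evals": 0}           # counts calls to the log target

def lpost(x):
    """Toy log target: standard normal, up to an additive constant."""
    counter["evals"] += 1
    return -0.5 * x * x

def mh_kernel(lpost, rprop):
    """Metropolis kernel for a symmetric proposal, carrying ll in the state."""
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        if math.log(1.0 - rng.random()) < lp - ll:
            return prop, lp
        return x, ll
    return kernel

kernel = mh_kernel(lpost, lambda x: x + rng.gauss(0.0, 1.0))

x, ll = 0.0, -math.inf           # ll = -inf means the first proposal is accepted
iters = 50000
draws = []
for _ in range(iters):
    x, ll = kernel(x, ll)
    draws.append(x)

mean = sum(draws) / iters
var = sum((d - mean) ** 2 for d in draws) / iters
print(counter["evals"], mean, var)   # mean and variance should be near 0 and 1
```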

JAX

The above R and Python scripts are fine, but both languages are rather slow for this kind of workload. Fortunately it’s rather straightforward to convert the Python code to JAX to obtain a quite amazing speed-up. We can write our M-H kernel as

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

and our MCMC runner function as

def mcmc(init, kernel, thin = 10, iters = 10000):
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, iters)
    @jit
    def step(s, k):
        [x, ll] = s
        x, ll = kernel(k, x, ll)
        s = [x, ll]
        return s, s
    @jit
    def iter(s, k):
        keys = jax.random.split(k, thin)
        _, states = jax.lax.scan(step, s, keys)
        final = [states[0][thin-1], states[1][thin-1]]
        return final, final
    ll = -np.inf
    x = init
    _, states = jax.lax.scan(iter, [x, ll], keys)
    return states[0]

There are really only two slightly tricky things about this code.

The first relates to the way JAX handles pseudo-random numbers. Since JAX is a pure functional eDSL, it can’t be used in conjunction with the typical pseudo-random number generators often used in imperative programming languages which rely on a global mutable state. This can be dealt with reasonably straightforwardly by explicitly passing around the random number state. There is a standard way of doing this that has been common practice in functional programming languages for decades. However, this standard approach is very sequential, and so doesn’t work so well in a parallel context. JAX therefore uses a splittable random number generator, where new states are created by splitting the current state into two (or more). We’ll come back to this when we get to the Haskell examples.

The second thing that might be unfamiliar to imperative programmers is the use of the scan operation (jax.lax.scan) to generate the Markov chain rather than a "for" loop. But scans are standard operations in most functional programming languages.
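For anyone who has not met scans before, the idea can be illustrated in plain Python with itertools.accumulate, which threads a state through a step function and collects every intermediate state. This is only an analogy: jax.lax.scan has a somewhat different interface, with a step function returning both the carried state and an output. The step function below is a made-up stand-in for a kernel, with the "randomness" supplied as an input sequence.

```python
import itertools
import random

def step(x, z):
    """One transition: the new state depends on the old state and fresh noise."""
    return 0.9 * x + z

rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(5)]

# Imperative version: a "for" loop mutating the current state
x, loop_states = 0.0, []
for z in noise:
    x = step(x, z)
    loop_states.append(x)

# Scan version: no mutation, the state is threaded through `step`
scan_states = list(itertools.accumulate(noise, step, initial=0.0))[1:]

print(loop_states == scan_states)
```

In the JAX runner, the array of keys plays the role of the noise sequence here.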

We can then call this code for our logistic regression example with

pre = jnp.array([10.,1.,1.,1.,1.,1.,5.,1.]).astype(jnp.float32)

@jit
def rprop(key, beta):
    return beta + 0.02*pre*jax.random.normal(key, [p])

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.

Scala

In Scala we can use a similar approach to that already seen for defining a HoF to return a M-H kernel.

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

Note that Scala’s static typing does not prevent us from defining a function that is polymorphic in the type of the chain state, which we here call S. Also note that we are adopting a pragmatic approach to random number generation, exploiting the fact that Scala is not a pure functional language, using a mutable generator, and omitting to capture the non-determinism of the rprop function (and the returned kernel) in its type signature. In Scala this is a choice, and we could adopt a purer approach if preferred. We’ll see what such an approach will look like in Haskell, coming up next.

Now that we have the kernel, we don’t need to write an explicit runner function since Scala has good support for streaming data. There are many more-or-less sophisticated ways that we can work with data streams in Scala, and the choice depends partly on how pure one is being about tracking effects (such as non-determinism), but here I’ll just use the simple LazyList from the standard library for unfolding the kernel into an infinite MCMC chain before thinning and truncating appropriately.

  val pre = DenseVector(10.0,1.0,1.0,1.0,1.0,1.0,5.0,1.0)
  def rprop(beta: DVD): DVD = beta + pre *:* (DenseVector(Gaussian(0.0,0.02).sample(p).toArray))
  val kern = mhKernel(lpost, rprop)
  val s = LazyList.iterate((init, -Inf))(kern) map (_._1)
  val out = s.drop(150).thin(1000).take(10000)

See the full runnable script for further details.
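The same unfold, drop, thin and truncate pattern can be sketched in Python using a generator as the lazy stream. This is an illustration of the idea rather than a translation of the Scala script: the helper names are hypothetical, the kernel is a deterministic counter so that the result is easy to follow, and I am assuming that thinning keeps the first element of each block.

```python
import itertools

def iterate(kernel, init):
    """Lazily yield init, kernel(init), kernel(kernel(init)), ..."""
    state = init
    while True:
        yield state
        state = kernel(state)

def thin(stream, th):
    """Keep every th-th element of a (possibly infinite) stream."""
    return itertools.islice(stream, 0, None, th)

chain = iterate(lambda x: x + 1, 0)             # stand-in for an MCMC kernel
burned = itertools.islice(chain, 150, None)     # drop the first 150 states
out = list(itertools.islice(thin(burned, 1000), 5))
print(out)   # [150, 1150, 2150, 3150, 4150]
```

Nothing is computed until the final list is forced, so the infinite chain is never materialised.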

Haskell

Since Haskell is a pure functional language, we need to have some convention regarding pseudo-random number generation. Haskell supports several styles. The most commonly adopted approach wraps a mutable generator up in a monad. The typical alternative is to use a pure functional generator and either explicitly thread the state through code or hide this in a monad similar to the standard approach. However, Haskell also supports the use of splittable generators, so we can consider all three approaches for comparative purposes. The approach taken does affect the code and the type signatures, and even the streaming data abstractions most appropriate for chain generation.

Starting with a HoF for producing a Metropolis kernel, an approach using the standard monadic generators could look like

mKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) -> 
           g -> (s, Double) -> m (s, Double)
mKernel logPost rprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

Note how non-determinism is captured in the type signatures by the monad m. The explicit pure approach is to thread the generator through non-deterministic functions.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> (s, g)) -> 
            g -> (s, Double) -> ((s, Double), g)
mKernelP logPost rprop g (x0, ll0) = let
  (x, g1) = rprop x0 g
  ll = logPost(x)
  a = ll - ll0
  (u, g2) = uniformR (0, 1) g1
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in (next, g2)

Here the updated random number generator state is returned from each non-deterministic function for passing on to subsequent non-deterministic functions. This explicit sequencing of operations makes it possible to wrap the generator state in a state monad giving code very similar to the stateful monadic generator approach, but as already discussed, the sequential nature of this approach makes it unattractive in parallel and concurrent settings.
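The explicit threading style is not specific to Haskell, and can be mimicked in Python by writing every non-deterministic function as a pure function from a generator state to a pair of a value and a new state. The sketch below uses a toy linear congruential generator purely for illustration (it is not a generator anyone should use for real work), and the names next_uniform, rprop and m_kernel_p are assumptions of this example.

```python
import math

def next_uniform(g):
    """Pure toy LCG: map a generator state to (draw in (0,1), next state)."""
    g1 = (1664525 * g + 1013904223) % 2**32
    return (g1 + 0.5) / 2**32, g1

def rprop(x, g):
    """Symmetric uniform random-walk proposal, returning (proposal, next state)."""
    u, g1 = next_uniform(g)
    return x + (u - 0.5), g1

def m_kernel_p(log_post, rprop, g, state):
    """Metropolis kernel with the generator state threaded through explicitly."""
    x0, ll0 = state
    x, g1 = rprop(x0, g)
    ll = log_post(x)
    u, g2 = next_uniform(g1)
    nxt = (x, ll) if math.log(u) < ll - ll0 else (x0, ll0)
    return nxt, g2

def run(g, n):
    """Run n steps from a fixed initial state, returning the chain and final g."""
    state, out = (0.0, -math.inf), []
    for _ in range(n):
        state, g = m_kernel_p(lambda x: -0.5 * x * x, rprop, g, state)
        out.append(state[0])
    return out, g

chain_a, g_a = run(12345, 100)
chain_b, g_b = run(12345, 100)
print(chain_a == chain_b, g_a == g_b)   # same input state, same output
```

Because there is no hidden mutable state, the whole chain is a deterministic function of the initial generator state, but each draw has to wait for the state produced by the previous one.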

Fortunately the standard Haskell pure generator is splittable, meaning that we can adopt a splitting approach similar to JAX if we prefer, since this is much more parallel-friendly.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> s) -> 
            g -> (s, Double) -> (s, Double)
mKernelP logPost rprop g (x0, ll0) = let
  (g1, g2) = split g
  x = rprop x0 g1
  ll = logPost(x)
  a = ll - ll0
  u = unif g2
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in next

Here non-determinism is signalled by passing a generator state (often called a "key" in the context of splittable generators) into a function. Functions receiving a key are responsible for splitting it to ensure that no key is ever used more than once.

Once we have a kernel, we need to unfold our Markov chain. When using the monadic generator approach, it is most natural to unfold using a monadic stream

mcmc :: (StatefulGen g m) =>
  Int -> Int -> s -> (g -> s -> m s) -> g -> MS.Stream m s
mcmc it th x0 kern g = MS.iterateNM it (stepN th (kern g)) x0

stepN :: (Monad m) => Int -> (a -> m a) -> (a -> m a)
stepN n fa = if (n == 1)
  then fa
  else (\x -> (fa x) >>= (stepN (n-1) fa))

whereas for the explicit approaches it is more natural to unfold into a regular infinite data stream. So, for the explicit sequential approach we could use

mcmcP :: (RandomGen g) => s -> (g -> s -> (s, g)) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = kern (snd xg) (fst xg)
      in (x1, (x1, g1))

and with the splittable approach we could use

mcmcP :: (RandomGen g) =>
  s -> (g -> s -> s) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = xg
      (g2, g3) = split g1  -- g2 seeds the next step, g3 is consumed here
      x2 = kern g3 x1
      in (x2, (x2, g2))

Calling these functions for our logistic regression example is similar to what we have seen before, but again there are minor syntactic differences depending on the approach. For further details see the full runnable scripts for the monadic approach, the pure sequential approach, and the splittable approach.

Dex

Dex is a pure functional language and uses a splittable random number generator, so the style we use is similar to JAX (or Haskell using a splittable generator). We can generate a Metropolis kernel with

def mKernel {s} (lpost: s -> Float) (rprop: Key -> s -> s) : 
    Key -> (s & Float) -> (s & Float) =
  def kern (k: Key) (sll: (s & Float)) : (s & Float) =
    (x0, ll0) = sll
    [k1, k2] = split_key k
    x = rprop k1 x0
    ll = lpost x
    a = ll - ll0
    u = rand k2
    select (log u < a) (x, ll) (x0, ll0)
  kern

We can then unfold our Markov chain with

def markov_chain {s} (k: Key) (init: s) (kern: Key -> s -> s) (its: Nat) :
    Fin its => s =
  with_state init \st.
    for i:(Fin its).
      x = kern (ixkey k i) (get st)
      st := x
      x

Here we combine Dex’s state effect with a for loop to unfold the stream. See the full runnable script for further details.

Next steps

As previously discussed, none of these codes are optimised, so care should be taken not to over-interpret running times. However, JAX and Dex are noticeably faster than the alternatives, even running on a single CPU core. Another interesting feature of both JAX and Dex is that they are differentiable. This makes it very easy to develop algorithms using gradient information. In subsequent posts we will think about the gradient of our example log-posterior and how we can use gradient information to develop "better" sampling algorithms.

The complete runnable scripts are all available from this public GitHub repo.

Bayesian inference for a logistic regression model (Part 2)

Part 2: The log posterior

Introduction

This is the second part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we looked at the basic modelling concepts, and how to fit the model using a variety of PPLs. In this post we will prepare for doing MCMC by considering the problem of computing the unnormalised log posterior for the model. We will then see how this posterior can be implemented in several different languages and libraries.

Derivation

Basic structure

In Bayesian inference the posterior distribution is just the conditional distribution of the model parameters given the data, and therefore proportional to the joint distribution of the model and data. We often write this as

\displaystyle \pi(\theta|y) \propto \pi(\theta,y) = \pi(\theta)\pi(y|\theta).

Taking logs we have

\displaystyle \log \pi(\theta, y) = \log \pi(\theta) + \log \pi(y|\theta).

So (up to an additive constant) the log posterior is just the sum of the log prior and log likelihood. There are many good (numerical) reasons why we try to work exclusively with the log posterior and try to avoid ever evaluating the raw posterior density.
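One of those numerical reasons is underflow. As a small illustration in Python (with made-up data), multiplying together a few hundred density values quickly produces a number too small to represent in double precision, whereas the corresponding sum of logs is perfectly well behaved.

```python
import math
from statistics import NormalDist

density = NormalDist(0.0, 1.0).pdf
xs = [3.0] * 500                   # hypothetical values, each 3 SDs from the mean

raw = 1.0
for x in xs:
    raw *= density(x)              # product of 500 numbers of size about 0.004

log_total = sum(math.log(density(x)) for x in xs)
print(raw, log_total)              # raw has underflowed to 0.0
```

Once the raw density has underflowed to zero, all information about the relative plausibility of different parameter values has been lost.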

For our example logistic regression model, the parameter vector \theta is just the vector of regression coefficients, \beta. We assumed independent mean zero normal priors for the components of this vector, so the log prior is just the sum of logs of normal densities. Many scientific libraries will have built-in functions for returning the log-pdf of standard distributions, but if an explicit form is required, the log of the density of a N(0,\sigma^2) at x is just

\displaystyle -\log(2\pi)/2 - \log|\sigma| - x^2/(2\sigma^2),

and the initial constant term normalising the density can often be dropped.
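As a sanity check on this formula, here is a short Python sketch of the log prior written out explicitly, using the prior standard deviations adopted in this series (10 for the intercept and 1 for the other coefficients). It compares the explicit form against the standard library's normal density, and confirms that dropping the terms that do not involve \beta only shifts the log prior by a constant. The two test vectors are arbitrary.

```python
import math
from statistics import NormalDist

PSCALE = [10.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]   # prior SDs

def norm_logpdf(x, sigma):
    """Log density of N(0, sigma^2) at x, written out explicitly."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(abs(sigma))
            - x * x / (2.0 * sigma * sigma))

def lprior(beta):
    return sum(norm_logpdf(b, s) for b, s in zip(beta, PSCALE))

def lprior_unnorm(beta):
    """As lprior, but with every term that does not depend on beta dropped."""
    return sum(-b * b / (2.0 * s * s) for b, s in zip(beta, PSCALE))

beta1 = [0.5, -1.0, 0.2, 0.0, 1.5, -0.3, 0.7, 0.1]   # arbitrary test points
beta2 = [-2.0, 0.4, 0.0, 1.1, -0.6, 0.9, -1.2, 0.3]

direct = sum(math.log(NormalDist(0.0, s).pdf(b)) for b, s in zip(beta1, PSCALE))
const1 = lprior(beta1) - lprior_unnorm(beta1)
const2 = lprior(beta2) - lprior_unnorm(beta2)
print(lprior(beta1), direct, const1, const2)
```

Since MCMC only ever uses differences of log posterior values, the dropped constant has no effect on the resulting chain.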

Log-likelihood (first attempt)

Information from the data comes into the log posterior via the log-likelihood. The typical way to derive the likelihood for problems of this type is to assume the usual binary encoding of the data (success 1, failure 0). Then, for a Bernoulli observation with probability p_i,\ i=1,\ldots,n, the likelihood associated with observation y_i is

\displaystyle f(y_i|p_i) = \left[ \hphantom{1-}p_i \quad :\ y_i=1 \atop 1-p_i \quad :\ y_i=0 \right. \quad = \quad p_i^{y_i}(1-p_i)^{1-y_i}.

Taking logs and then switching to parameter \eta_i=\text{logit}(p_i) we have

\displaystyle \log f(y_i|\eta_i) = y_i\eta_i - \log(1+e^{\eta_i}),

and summing over n observations gives the log likelihood

\displaystyle \log\pi(y|\eta) \equiv \ell(\eta;y) = y\cdot \eta - \mathbf{1}\cdot\log(\mathbf{1}+\exp{\eta}).

In the context of logistic regression, \eta is the linear predictor, so \eta=X\beta, giving

\displaystyle \ell(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta]).

This is a perfectly good way of expressing the log-likelihood, and we will come back to it later when we want the gradient of the log-likelihood, but it turns out that there is a similar-but-different way of deriving it that results in an expression that is equivalent but slightly cheaper to evaluate.

Log-likelihood (second attempt)

For our second attempt, we will assume that the data is coded in a different way. Instead of the usual binary encoding, we will assume that the observation \tilde y_i is 1 for success and -1 for failure. This isn’t really a problem, since the two encodings are related by \tilde y_i = 2y_i-1. This new encoding is convenient in the context of a logit parameterisation since then

\displaystyle f(y_i|\eta_i) = \left[ p_i \ :\ \tilde y_i=1\atop 1-p_i\ :\ \tilde y_i=-1 \right. \ = \ \left[ (1+e^{-\eta_i})^{-1} \ :\ \tilde y_i=1\atop (1+e^{\eta_i})^{-1} \ :\ \tilde y_i=-1 \right. \ = \ (1+e^{-\tilde y_i\eta_i})^{-1} ,

and hence

\displaystyle \log f(y_i|\eta_i) = -\log(1+e^{-\tilde y_i\eta_i}).

Summing over observations gives

\displaystyle \ell(\eta;\tilde y) = -\mathbf{1}\cdot \log(\mathbf{1}+\exp[-\tilde y\circ \eta]),

where \circ denotes the Hadamard product. Substituting \eta=X\beta gives the log-likelihood

\displaystyle \ell(\beta;\tilde y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-\tilde y\circ X\beta]).

This likelihood is a bit cheaper to evaluate than the one previously derived. If we prefer to write in terms of the original data encoding, we can obviously do so as

\displaystyle \ell(\beta; y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-(2y-\mathbf{1})\circ (X\beta)]),

and in practice, it is this version that is typically used. To be clear, as an algebraic function of \beta and y the two functions are different. But they coincide for binary vectors y which is all that matters.
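This last point is easy to check numerically. The following Python sketch evaluates both forms of the log-likelihood on a tiny made-up data set, first with a binary response vector, where they agree, and then with a non-binary vector, where they do not. The design matrix, coefficients and function names are all hypothetical.

```python
import math

def linpred(row, beta):
    """Linear predictor for one row of the design matrix."""
    return sum(xj * bj for xj, bj in zip(row, beta))

def ll_first(beta, X, y):
    """First form: y.eta - sum(log(1 + exp(eta)))."""
    return sum(yi * linpred(r, beta) - math.log1p(math.exp(linpred(r, beta)))
               for r, yi in zip(X, y))

def ll_second(beta, X, y):
    """Second form: -sum(log(1 + exp(-(2y - 1) * eta)))."""
    return sum(-math.log1p(math.exp(-(2.0 * yi - 1.0) * linpred(r, beta)))
               for r, yi in zip(X, y))

X = [[1.0, 0.5, -1.2],      # hypothetical design matrix with intercept column
     [1.0, -0.3, 0.8],
     [1.0, 1.5, 0.1],
     [1.0, -2.0, -0.4]]
beta = [0.3, -0.7, 1.1]
y_bin = [1.0, 0.0, 1.0, 0.0]
y_frac = [0.5, 0.5, 0.5, 0.5]    # not a valid response, used only for comparison

gap_bin = abs(ll_first(beta, X, y_bin) - ll_second(beta, X, y_bin))
gap_frac = abs(ll_first(beta, X, y_frac) - ll_second(beta, X, y_frac))
print(gap_bin, gap_frac)    # essentially zero for binary y, clearly non-zero otherwise
```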

Implementation

R

In R we can create functions for evaluating the log-likelihood, log-prior and log-posterior as follows (assuming that X and y are in scope).

ll = function(beta)
    sum(-log(1 + exp(-(2*y - 1)*(X %*% beta))))

lprior = function(beta)
    dnorm(beta[1], 0, 10, log=TRUE) + sum(dnorm(beta[-1], 0, 1, log=TRUE))

lpost = function(beta) ll(beta) + lprior(beta)

Python

In Python (with NumPy and SciPy) we can define equivalent functions with

def ll(beta):
    return np.sum(-np.log(1 + np.exp(-(2*y - 1)*(X.dot(beta)))))

def lprior(beta):
    return (sp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            np.sum(sp.stats.norm.logpdf(beta[range(1,p)], loc=0, scale=1)))

def lpost(beta):
    return ll(beta) + lprior(beta)

JAX

Python, like R, is a dynamic language, and relatively slow for MCMC algorithms. JAX is a tensor computation framework for Python that embeds a pure functional differentiable array processing language inside Python. JAX can JIT-compile high-performance code for both CPU and GPU, and has good support for parallelism. It is rapidly becoming the preferred way to develop high-performance sampling algorithms within the Python ecosystem. We can encode our log-posterior in JAX as follows.

@jit
def ll(beta):
    return jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1)*jnp.dot(X, beta))))

@jit
def lprior(beta):
    return (jsp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            jnp.sum(jsp.stats.norm.logpdf(beta[jnp.array(range(1,p))], loc=0, scale=1)))

@jit
def lpost(beta):
    return ll(beta) + lprior(beta)

Scala

JAX is a pure functional programming language embedded in Python. Pure functional programming languages are intrinsically more scalable and compositional than imperative languages such as R and Python, and are much better suited to exploit concurrency and parallelism. I’ve given a bunch of talks about this recently, so if you are interested in this, perhaps start with the materials for my Laplace’s Demon talk. Scala and Haskell are arguably the current best popular general purpose functional programming languages, so it is possibly interesting to consider the use of these languages for the development of scalable statistical inference codes. Since both languages are statically typed compiled functional languages with powerful type systems, they can be highly performant. However, neither is optimised for numerical (tensor) computation, so you should not expect that they will have performance comparable with optimised tensor computation frameworks such as JAX. We can encode our log-posterior in Scala (with Breeze) as follows:

  def ll(beta: DVD): Double =
      sum(-log(ones + exp(-1.0*(2.0*y - ones)*:*(X * beta))))

  def lprior(beta: DVD): Double =
    Gaussian(0,10).logPdf(beta(0)) + 
      sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))

  def lpost(beta: DVD): Double = ll(beta) + lprior(beta)

Spark

Apache Spark is a Scala library for distributed "big data" processing on clusters of machines. Despite fundamental differences, there is a sense in which Spark for Scala is a bit analogous to JAX for Python: both Spark and JAX are concerned with scalability, but they are targeting rather different aspects of scalability: JAX is concerned with getting regular sized data processing algorithms to run very fast (on GPUs), whereas Spark is concerned with running huge data processing tasks quickly by distributing work over clusters of machines. Despite obvious differences, the fundamental pure functional computational model adopted by both systems is interestingly similar: both systems are based on lazy transformations of immutable data structures using pure functions. This is a fundamental pattern for scalable data processing transcending any particular language, library or framework. We can encode our log posterior in Spark as follows.

    def ll(beta: DVD): Double = 
      df.map{row =>
        val y = row.getAs[Double](0)
        val x = BDV.vertcat(BDV(1.0),toBDV(row.getAs[DenseVector](1)))
        -math.log(1.0 + math.exp(-1.0*(2.0*y-1.0)*(x.dot(beta))))}.reduce(_+_)
    def lprior(beta: DVD): Double =
      Gaussian(0,10).logPdf(beta(0)) +
        sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))
    def lpost(beta: DVD): Double =
      ll(beta) + lprior(beta)

Haskell

Haskell is an old, lazy pure functional programming language with an advanced type system, and remains the preferred language for the majority of functional programming language researchers. Hmatrix is the standard high performance numerical linear algebra library for Haskell, so we can use it to encode our log-posterior as follows.

ll :: Matrix Double -> Vector Double -> Vector Double -> Double
ll x y b = (negate) (vsum (cmap log (
                              (scalar 1) + (cmap exp (cmap (negate) (
                                                         (((scalar 2) * y) - (scalar 1)) * (x #> b)
                                                         )
                                                     )))))

pscale :: [Double] -- prior standard deviations
pscale = [10.0, 1, 1, 1, 1, 1, 1, 1]
lprior :: Vector Double -> Double
lprior b = sum $  (\x -> logDensity (normalDistr 0.0 (snd x)) (fst x)) <$> (zip (toList b) pscale)
           
lpost :: Matrix Double -> Vector Double -> Vector Double -> Double
lpost x y b = (ll x y b) + (lprior b)

Again, a reminder that, here and elsewhere, there are various optimisations that could be done that I’m not bothering with. This is all just proof-of-concept code.

Dex

JAX proves that a pure functional DSL for tensor computation can be extremely powerful and useful. But embedding such a language in a dynamic imperative language like Python has a number of drawbacks. Dex is an experimental statically typed stand-alone DSL for differentiable array and tensor programming that attempts to combine some of the correctness and composability benefits of powerful statically typed functional languages like Scala and Haskell with the performance benefits of tensor computation systems like JAX. It is currently rather early in its development, but seems very interesting, and is already quite usable. We can encode our log-posterior in Dex as follows.

def ll (b: (Fin 8)=>Float) : Float =
  neg $  sum (log (map (\ x. (exp x) + 1) ((map (\ yi. 1 - 2*yi) y)*(x **. b))))

pscale = [10.0, 1, 1, 1, 1, 1, 1, 1] -- prior SDs
prscale = map (\ x. 1.0/x) pscale

def lprior (b: (Fin 8)=>Float) : Float =
  bs = b*prscale
  neg $  sum ((log pscale) + (0.5 .* (bs*bs)))

def lpost (b: (Fin 8)=>Float) : Float =
  (ll b) + (lprior b)
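
For reference, the Haskell and Dex encodings above evaluate the same quantity, which can be written out explicitly. The notation here is my own shorthand rather than anything fixed by the code: x_i is the ith row of the covariate matrix, y_i \in \{0,1\} is the ith response, and \sigma_j is the jth prior standard deviation (the entries of pscale). The log-likelihood is

\displaystyle l(b) = -\sum_i \log\left(1 + \exp\left[-(2y_i - 1)\, x_i \cdot b\right]\right),

and the log-prior is

\displaystyle \log \pi(b) = -\sum_j \left[\log \sigma_j + \frac{b_j^2}{2\sigma_j^2}\right] + \text{const}.

The Haskell version includes the constant (a term -\frac{1}{2}\log 2\pi per coefficient, from the full normal log-density), whereas the Dex version drops it. This makes no difference for MCMC, since only differences in the log-posterior matter.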

Next steps

Now that we have a way of evaluating the log posterior, we can think about constructing Markov chains having the posterior as their equilibrium distribution. In the next post we will look at one of the simplest ways of doing this: the Metropolis algorithm.

Complete runnable scripts are available from this public GitHub repo.

A probability monad for the bootstrap particle filter

Introduction

In the previous post I showed how to write your own general-purpose monadic probabilistic programming language from scratch in 50 lines of (Scala) code. That post is a pre-requisite for this one, so if you haven’t read it, go back and have a quick skim through it before proceeding. In that post I tried to keep everything as simple as possible, but at the expense of both elegance and efficiency. In this post I’ll address one problem with the implementation from that post – the memory (and computational) overhead associated with forming the Cartesian product of particle sets during monadic binding (flatMap). So if particle sets are typically of size N, then the Cartesian product is of size N^2, and multinomial resampling is applied to this set of size N^2 in order to sample back down to a set of size N. But this isn’t actually necessary. We can directly construct a set of size N, certainly saving memory, but also potentially saving computation time if the conditional distribution (on the right of the monadic bind) can be efficiently sampled. If we do this we will have a probability monad encapsulating the logic of a bootstrap particle filter, such as is often used for computing the filtering distribution of a state-space model in time series analysis. This simple change won’t solve the computational issues associated with deep monadic binding, but does solve the memory problem, and can lead to computationally efficient algorithms so long as care is taken in the formulation of probabilistic programs to ensure that deep monadic binding doesn’t occur. We’ll discuss that issue in the context of state-space models later, once we have our new SMC-based probability monad.

Materials for this post can be found in my blog repo, and a draft of this post itself can be found in the form of an executable tut document.

An SMC-based monad

The idea behind the approach to binding used in this monad is to mimic the “predict” step of a bootstrap particle filter. Here, for each particle in the source distribution, exactly one particle is drawn from the required conditional distribution and paired with the source particle, preserving the source particle’s original weight. So, in order to operationalise this, we will need to add a draw method to our probability monad. It will also simplify things to add a flatMap method to our Particle type constructor.

To follow along, you can type sbt console from the min-ppl2 directory of my blog repo, then paste blocks of code one at a time.

  import breeze.stats.{distributions => bdist}
  import breeze.linalg.DenseVector
  import cats._
  import cats.implicits._

  implicit val numParticles = 2000

  case class Particle[T](v: T, lw: Double) { // value and log-weight
    def map[S](f: T => S): Particle[S] = Particle(f(v), lw)
    def flatMap[S](f: T => Particle[S]): Particle[S] = {
      val ps = f(v)
      Particle(ps.v, lw + ps.lw)
    }
  }

I’ve added a dependence on cats here, so that we can use some derived methods, later. To take advantage of this, we must provide evidence that our custom types conform to standard type class interfaces. For example, we can provide evidence that Particle[_] is a monad as follows.

  implicit val particleMonad = new Monad[Particle] {
    def pure[T](t: T): Particle[T] = Particle(t, 0.0)
    def flatMap[T,S](pt: Particle[T])(f: T => Particle[S]): Particle[S] = pt.flatMap(f)
    def tailRecM[T,S](t: T)(f: T => Particle[Either[T,S]]): Particle[S] = ???
  }

The technical details are not important for this post, but we’ll see later what this can give us.

We can now define our Prob[_] monad in the following way.

  trait Prob[T] {
    val particles: Vector[Particle[T]]
    def draw: Particle[T]
    def mapP[S](f: T => Particle[S]): Prob[S] = Empirical(particles map (_ flatMap f))
    def map[S](f: T => S): Prob[S] = mapP(v => Particle(f(v), 0.0))
    def flatMap[S](f: T => Prob[S]): Prob[S] = mapP(f(_).draw)
    def resample(implicit N: Int): Prob[T] = {
      val lw = particles map (_.lw)
      val mx = lw reduce (math.max(_,_))
      val rw = lw map (lwi => math.exp(lwi - mx))
      val law = mx + math.log(rw.sum/(rw.length))
      val ind = bdist.Multinomial(DenseVector(rw.toArray)).sample(N)
      val newParticles = ind map (i => particles(i))
      Empirical(newParticles.toVector map (pi => Particle(pi.v, law)))
    }
    def cond(ll: T => Double): Prob[T] = mapP(v => Particle(v, ll(v)))
    def empirical: Vector[T] = resample.particles.map(_.v)
  }

  case class Empirical[T](particles: Vector[Particle[T]]) extends Prob[T] {
    def draw: Particle[T] = {
      val lw = particles map (_.lw)
      val mx = lw reduce (math.max(_,_))
      val rw = lw map (lwi => math.exp(lwi - mx))
      val law = mx + math.log(rw.sum/(rw.length))
      val idx = bdist.Multinomial(DenseVector(rw.toArray)).draw
      Particle(particles(idx).v, law)
    }
  }

As before, if you are pasting code blocks into the REPL, you will need to use :paste mode to paste these two definitions together.

The essential structure is similar to that from the previous post, but with a few notable differences. Most fundamentally, we now require any concrete implementation to provide a draw method returning a single particle from the distribution. As before, we are not worrying about purity of functional code here, and are using a standard random number generator with a globally mutating state. We can define a mapP method (for “map particle”) using the new flatMap method on Particle, and then use that to define map and flatMap for Prob[_]. Crucially, draw is used to define flatMap without requiring a Cartesian product of distributions to be formed.
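
To make the difference in particle-set size concrete, here is a minimal stand-alone sketch (plain Scala, no Breeze or Cats) contrasting the two binding strategies on bare vectors of (value, log-weight) pairs. The names BindSketch, bindProduct and bindDraw are made up for illustration and are not part of the code in this post, and neither function does any resampling. The conditional distribution is represented by a hypothetical function returning a whole particle set in the first case, and a single draw in the second.

```scala
object BindSketch {
  // A weighted particle as a bare (value, log-weight) pair: a simplification
  // of the Particle case class used in this post.
  type P[T] = (T, Double)

  // Product-style bind: every source particle is paired with every particle
  // of its conditional distribution, so N source particles give N * N results.
  def bindProduct[T, S](ps: Vector[P[T]])(f: T => Vector[P[S]]): Vector[P[S]] =
    ps.flatMap { case (v, lw) =>
      f(v).map { case (s, lws) => (s, lw + lws) }
    }

  // Draw-style bind (the "predict" step): each source particle is paired with
  // exactly one draw from its conditional, so N source particles give N results.
  def bindDraw[T, S](ps: Vector[P[T]])(draw: T => P[S]): Vector[P[S]] =
    ps.map { case (v, lw) =>
      val (s, lws) = draw(v)
      (s, lw + lws)
    }

  def main(args: Array[String]): Unit = {
    val rng = new scala.util.Random(42)
    val n = 100
    val source: Vector[P[Double]] = Vector.fill(n)((rng.nextGaussian(), 0.0))
    // Hypothetical conditional distribution: given x, a normal centred on x.
    val product = bindProduct(source)(x => Vector.fill(n)((x + rng.nextGaussian(), 0.0)))
    val drawn = bindDraw(source)(x => (x + rng.nextGaussian(), 0.0))
    println(s"product bind: ${product.length} particles")
    println(s"draw bind: ${drawn.length} particles")
  }
}
```

With n = 100, the product bind reports 10000 particles and the draw bind reports 100, which is the memory saving that the new flatMap is designed to give.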

We add a draw method to our Empirical[_] implementation. This method is computationally intensive (it is O(N) in the number of particles), so it will still be computationally problematic to chain several flatMaps together, but this will no longer be N^2 in memory utilisation. Note that again we carefully set the weight of the drawn particle so that its raw weight is the average of the raw weights of the empirical distribution. This is needed to propagate conditioning information correctly back through flatMaps. There is obviously some code duplication between the draw method on Empirical and the resample method on Prob, but I’m not sure it’s worth factoring out.

It is worth noting that neither flatMap nor cond triggers resampling, so the user of the library is now responsible for resampling when appropriate.

We can provide evidence that Prob[_] forms a monad just like we did Particle[_].

  implicit val probMonad = new Monad[Prob] {
    def pure[T](t: T): Prob[T] = Empirical(Vector(Particle(t, 0.0)))
    def flatMap[T,S](pt: Prob[T])(f: T => Prob[S]): Prob[S] = pt.flatMap(f)
    def tailRecM[T,S](t: T)(f: T => Prob[Either[T,S]]): Prob[S] = ???
  }

Again, we’ll want to be able to create a distribution from an unweighted collection of values.

  def unweighted[T](ts: Vector[T], lw: Double = 0.0): Prob[T] =
    Empirical(ts map (Particle(_, lw)))

We will again define an implementation for distributions with tractable likelihoods, which are therefore easy to condition on. They will typically also be easy to draw from efficiently, and we will use this fact, too.

  trait Dist[T] extends Prob[T] {
    def ll(obs: T): Double
    def ll(obs: Seq[T]): Double = obs map (ll) reduce (_+_)
    def fit(obs: Seq[T]): Prob[T] = mapP(v => Particle(v, ll(obs)))
    def fitQ(obs: Seq[T]): Prob[T] = Empirical(Vector(Particle(obs.head, ll(obs))))
    def fit(obs: T): Prob[T] = fit(List(obs))
    def fitQ(obs: T): Prob[T] = fitQ(List(obs))
  }

We can give implementations of this for a few standard distributions.

  case class Normal(mu: Double, v: Double)(implicit N: Int) extends Dist[Double] {
    lazy val particles = unweighted(bdist.Gaussian(mu, math.sqrt(v)).
      sample(N).toVector).particles
    def draw = Particle(bdist.Gaussian(mu, math.sqrt(v)).draw, 0.0)
    def ll(obs: Double) = bdist.Gaussian(mu, math.sqrt(v)).logPdf(obs)
  }

  case class Gamma(a: Double, b: Double)(implicit N: Int) extends Dist[Double] {
    lazy val particles = unweighted(bdist.Gamma(a, 1.0/b).
      sample(N).toVector).particles
    def draw = Particle(bdist.Gamma(a, 1.0/b).draw, 0.0)
    def ll(obs: Double) = bdist.Gamma(a, 1.0/b).logPdf(obs)
  }

  case class Poisson(mu: Double)(implicit N: Int) extends Dist[Int] {
    lazy val particles = unweighted(bdist.Poisson(mu).
      sample(N).toVector).particles
    def draw = Particle(bdist.Poisson(mu).draw, 0.0)
    def ll(obs: Int) = bdist.Poisson(mu).logProbabilityOf(obs)
  }

Note that we now have to provide an (efficient) draw method for each implementation, returning a single draw from the distribution as a Particle with a (raw) weight of 1 (log weight of 0).

We are done. It’s a few more lines of code than that from the previous post, but this is now much closer to something that could be used in practice to solve actual inference problems using a reasonable number of particles. But to do so we will need to be careful to avoid deep monadic binding. This is easiest to explain with a concrete example.

Using the SMC-based probability monad in practice

Monadic binding and applicative structure

As explained in the previous post, using Scala’s for-expressions for monadic binding gives a natural and elegant PPL for our probability monad “for free”. This is fine, and in general there is no reason why using it should lead to inefficient code. However, for this particular probability monad implementation, it turns out that deep monadic binding comes with a huge performance penalty. For a concrete example, consider the following specification, perhaps of a prior distribution over some independent parameters.

    val prior = for {
      x <- Normal(0,1)
      y <- Gamma(1,1)
      z <- Poisson(10)
    } yield (x,y,z)

Don’t paste that into the REPL – it will take an age to complete!

Again, I must emphasise that there is nothing wrong with this specification, and there is no reason in principle why such a specification can’t be computationally efficient – it’s just a problem for our particular probability monad. We can begin to understand the problem by thinking about how this will be de-sugared by the compiler. Roughly speaking, the above will de-sugar to the following nested flatMaps.

    val prior2 =
      Normal(0,1) flatMap {x =>
        Gamma(1,1) flatMap {y =>
          Poisson(10) map {z =>
            (x,y,z)}}}

Again, beware of pasting this into the REPL.

So, although written from top to bottom, the nesting is such that the flatMaps collapse from the bottom-up. The second flatMap (the first to collapse) isn’t such a problem here, as the Poisson has an O(1) draw method. But the result is an empirical distribution, which has an O(N) draw method. So the first flatMap (the second to collapse) is an O(N^2) operation. By extension, it’s easy to see that the computational cost of nested flatMaps will be exponential in the number of monadic binds. So, looking back at the for expression, the problem is that there are three <-. The last one isn’t a problem since it corresponds to a map, and the second last one isn’t a problem, since the final distribution is tractable with an O(1) draw method. The problem is the first <-, corresponding to a flatMap of one empirical distribution with respect to another. For our particular probability monad, it’s best to avoid these if possible.
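
The growth in cost can be illustrated with a very crude counting model. This is only a sketch under stated assumptions, not a measurement of the actual code: I assume a draw from a tractable distribution costs one unit, a draw from an empirical distribution of n particles costs n units, and each flatMap rebuilds its right-hand side and then draws from it once per source particle. The object BindCost and the function cost are hypothetical names introduced just for this illustration.

```scala
object BindCost {
  // Crude cost model (an assumption, not a measurement):
  //  - drawing from a tractable distribution costs 1 unit;
  //  - drawing from an empirical distribution of n particles costs n units;
  //  - a flatMap over n source particles rebuilds its right-hand side and
  //    then draws from it, once per source particle.
  // depth = number of nested flatMaps.
  def cost(n: Long, depth: Int): Long =
    if (depth <= 1) n                  // n draws from a tractable distribution
    else n * (cost(n, depth - 1) + n)  // rebuild the inner result, then an O(n) draw

  def main(args: Array[String]): Unit =
    (1 to 3).foreach(d => println(s"depth $d: ${cost(2000L, d)} units"))
}
```

Under these assumptions each extra level of nesting multiplies the cost by roughly another factor of n, so with n = 2000 even a few levels of nesting quickly become impractical.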

The interesting thing to note here is that because the distributions are independent, there is no need for them to be sequenced. They could be defined in any order. In this case it makes sense to use the applicative structure implied by the monad.

Now, since we have told cats that Prob[_] is a monad, it can provide appropriate applicative methods for us automatically. In Cats, every monad is assumed to be also an applicative functor (which is true in Cartesian closed categories, and Cats implicitly assumes that all functors and monads are defined over CCCs). So we can give an alternative specification of the above prior using applicative composition.

 val prior3 = Applicative[Prob].tuple3(Normal(0,1), Gamma(1,1), Poisson(10))
// prior3: Wrapped.Prob[(Double, Double, Int)] = Empirical(Vector(Particle((-0.057088546468105204,0.03027578552505779,9),0.0), Particle((-0.43686658266043743,0.632210127012762,14),0.0), Particle((-0.8805715148936012,3.4799656228544706,4),0.0), Particle((-0.4371726407147289,0.0010707859994652403,12),0.0), Particle((2.0283297088320755,1.040984491158822,10),0.0), Particle((1.2971862986495886,0.189166705596747,14),0.0), Particle((-1.3111333817551083,0.01962422606642761,9),0.0), Particle((1.6573851896142737,2.4021836368401415,9),0.0), Particle((-0.909927220984726,0.019595551644771683,11),0.0), Particle((0.33888133893822464,0.2659823344145805,10),0.0), Particle((-0.3300797295729375,3.2714740256437667,10),0.0), Particle((-1.8520554352884224,0.6175322756460341,10),0.0), Particle((0.541156780497547...

This one is mathematically equivalent, but safe to paste into your REPL, as it does not involve deep monadic binding, and can be used whenever we want to compose together independent components of a probabilistic program. Note that “tupling” is not the only possibility – Cats provides a range of functions for manipulating applicative values.

This is one way to avoid deep monadic binding, but another strategy is to just break up a large for expression into separate smaller for expressions. We can examine this strategy in the context of state-space modelling.

Particle filtering for a non-linear state-space model

We can now re-visit the DGLM example from the previous post. We began by declaring some observations and a prior.

    val data = List(2,1,0,2,3,4,5,4,3,2,1)
// data: List[Int] = List(2, 1, 0, 2, 3, 4, 5, 4, 3, 2, 1)

    val prior = for {
      w <- Gamma(1, 1)
      state0 <- Normal(0.0, 2.0)
    } yield (w, List(state0))
// prior: Wrapped.Prob[(Double, List[Double])] = Empirical(Vector(Particle((4.220683377724395,List(0.37256749723762683)),0.0), Particle((0.4436668049925418,List(-1.0053578391265572)),0.0), Particle((0.9868899648436931,List(-0.6985099310193449)),0.0), Particle((0.13474375773634908,List(0.9099291736792412)),0.0), Particle((1.9654021747685184,List(-0.042127103727998175)),0.0), Particle((0.21761202474220223,List(1.1074616830012525)),0.0), Particle((0.31037163527711015,List(0.9261849914020324)),0.0), Particle((1.672438830781466,List(0.01678529855289384)),0.0), Particle((0.2257151759143097,List(2.5511304854128354)),0.0), Particle((0.3046489890769499,List(3.2918304533361398)),0.0), Particle((1.5115941814057159,List(-1.633612165168878)),0.0), Particle((1.4185906813831506,List(-0.8460922678989864))...

Looking carefully at the for-expression, there are just two <-, and the distribution on the RHS of the flatMap is tractable, so this is just O(N). So far so good.

Next, let’s look at the function to add a time point, which previously looked something like the following.

    def addTimePointSIS(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
        _ <- Poisson(math.exp(ns)).fitQ(obs)
      } yield (w, ns :: states)
    }
// addTimePointSIS: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

Recall that our new probability monad does not automatically trigger resampling, so applying this function in a fold will lead to a simple sequential importance sampling (SIS) particle filter. Typically, the bootstrap particle filter includes resampling after each time point, giving a special case of a sampling importance resampling (SIR) particle filter, which we could instead write as follows.

    def addTimePointSimple(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      val updated = for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
        _ <- Poisson(math.exp(ns)).fitQ(obs)
      } yield (w, ns :: states)
      updated.resample
    }
// addTimePointSimple: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

This works fine, but we can see that there are three <- in this for expression. This leads to a flatMap with an empirical distribution on the RHS, and hence is O(N^2). But this is simple enough to fix, by separating the updating process into separate “predict” and “update” steps, which is how people typically formulate particle filters for state-space models, anyway. Here we could write that as

    def addTimePoint(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      val predict = for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
      }
      yield (w, ns :: states)
      val updated = for {
        tup <- predict
        (w, states) = tup
        st = states.head
        _ <- Poisson(math.exp(st)).fitQ(obs)
      } yield (w, states)
      updated.resample
    }
// addTimePoint: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

By breaking the for expression into two: the first for the “predict” step and the second for the “update” (conditioning on the observation), we get two O(N) operations, which for large N is clearly much faster. We can then run the filter by folding over the observations.

  import breeze.stats.{meanAndVariance => meanVar}
// import breeze.stats.{meanAndVariance=>meanVar}

  val mod = data.foldLeft(prior)(addTimePoint(_,_)).empirical
// Conditioning on observation: 2
// Conditioning on observation: 1
// Conditioning on observation: 0
// Conditioning on observation: 2
// Conditioning on observation: 3
// Conditioning on observation: 4
// Conditioning on observation: 5
// Conditioning on observation: 4
// Conditioning on observation: 3
// Conditioning on observation: 2
// Conditioning on observation: 1
// mod: Vector[(Double, List[Double])] = Vector((0.24822528144246606,List(0.06290285371838457, 0.01633338109272575, 0.8997103339551227, 1.5058726341571411, 1.0579925693609091, 1.1616536515200064, 0.48325623593870665, 0.8457351097543767, -0.1988290999293708, -0.4787511341321954, -0.23212497417019512, -0.15327432440577277)), (1.111430233331792,List(0.6709342824443849, 0.009092797044165657, -0.13203367846117453, 0.4599952735399485, 1.3779288637042504, 0.6176597963402879, 0.6680455419800753, 0.48289163013446945, -0.5994001698510807, 0.4860969602653898, 0.10291798193078927, 1.2878325765987266)), (0.6118925941009055,List(0.6421161986636132, 0.679470360928868, 1.0552459559203342, 1.200835166087372, 1.3690372269589233, 1.8036766847282912, 0.6229883551656629, 0.14872642198313774, -0.122700856878725...

  meanVar(mod map (_._1)) // w
// res0: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.2839184023932576,0.07391602428256917,2000)

  meanVar(mod map (_._2.reverse.head)) // initial state
// res1: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.26057368528422714,0.4802810202354611,2000)

  meanVar(mod map (_._2.head)) // final state
// res2: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.5448036669181697,0.28293080584600894,2000)

Summary and conclusions

Here we have just done some minor tidying up of the rather naive probability monad from the previous post to produce an SMC-based probability monad with improved performance characteristics. Again, we get an embedded probabilistic programming language “for free”. Although the language itself is very flexible, allowing us to construct more-or-less arbitrary probabilistic programs for Bayesian inference problems, we saw that a bug/feature of this particular inference algorithm is that care must be taken to avoid deep monadic binding if reasonable performance is to be obtained. In most cases this is simple to achieve by using applicative composition or by breaking up large for expressions.

There are still many issues and inefficiencies associated with this PPL. In particular, if the main intended application is to state-space models, it would make more sense to tailor the algorithms and implementations to exactly that case. OTOH, if the main concern is a generic PPL, then it would make sense to make the PPL independent of the particular inference algorithm. These are both potential topics for future posts.

Software

  • min-ppl2 – code associated with this blog post
  • Rainier – a more efficient PPL with similar syntax
  • monad-bayes – a Haskell library exploring related ideas

Write your own general-purpose monadic probabilistic programming language from scratch in 50 lines of (Scala) code

Background

In May I attended a great workshop on advances and challenges in machine learning languages at the CMS in Cambridge. There was a good mix of people from different disciplines, and a bit of a theme around probabilistic programming. The workshop schedule includes links to many of the presentations, and is generally worth browsing. In particular, it includes a link to the slides for my presentation on a compositional approach to scalable Bayesian computation and probabilistic programming. I’ve given a few talks on this kind of thing over the last couple of years, at Newcastle, at the Isaac Newton Institute in Cambridge (twice), and at CIRM in France. But I think I explained things best at this workshop at the CMS, though my impression could partly have been a reflection of the more interested and relevant audience. In the talk I started with a basic explanation of why ideas from category theory and functional programming can help to solve problems in statistical computing in a more composable and scalable way, before moving on to discuss probability monads and their fundamental connection to probabilistic programming. The take home message from the talk is that if you have a generic inference algorithm, expressing the logic in the context of probability monads can give you an embedded probabilistic programming language (PPL) for that inference algorithm essentially “for free”.

So, during my talk I said something a little foolhardy. I can’t remember my exact words, but while presenting the idea behind an SMC-based probability monad I said something along the lines of “one day I will write a blog post on how to write a probabilistic programming language from scratch in 50 lines of code, and this is how I’ll do it”! Rather predictably (with hindsight), immediately after my talk about half a dozen people all pleaded with me to urgently write the post! I’ve been a little busy since then, but now that things have settled down a little for the summer, I’ve some time to think and code, so here is that post.

Introduction

The idea behind this post is to show that, if you think about the problem in the right way, and use a programming language with syntactic support for monadic composition, then producing a flexible, general, compositional, embedded domain specific language (DSL) for probabilistic programming based on a given generic inference algorithm is no more effort than hard-coding two or three illustrative examples. You would need to code up two or three examples for a paper anyway, but providing a PPL is way more useful. There is also an interesting converse to this, which is that if you can’t easily produce a PPL for your “general” inference algorithm, then perhaps it isn’t quite as “general” as you thought. I’ll try to resist exploring that here…

To illustrate these principles I want to develop a fairly minimal PPL, so that the complexities of the inference algorithm don’t hide the simplicity of the PPL embedding. Importance sampling with resampling is probably the simplest useful generic Bayesian inference algorithm to implement, so that’s what I’ll use. Note that there are many limitations of the approach that I will adopt, which will make it completely unsuitable for “real” problems. In particular, this implementation is: inefficient, in terms of both compute time and memory usage; statistically inefficient for deep nesting and repeated conditioning, due to the particle degeneracy problem; specific to a particular probability monad; strictly evaluated; impure (due to mutation of global random number state); etc. All of these things are easily fixed, but all at the expense of greater abstraction, complexity and lines of code. I’ll probably discuss some of these generalisations and improvements in future posts, but for this post I want to keep everything as short and simple as practical. It’s also worth mentioning that there is nothing particularly original here. Many people have written about monadic embedded PPLs, and several have used an SMC-based monad for illustration. I’ll give some pointers to useful further reading at the end.

The language, in 50 lines of code

Without further ado, let’s just write the PPL. I’m using plain Scala, with just a dependency on the Breeze scientific library, which I’m going to use for simulating random numbers from standard distributions, and evaluation of their log densities. I have a directory of materials associated with this post in a git repo. This post is derived from an executable tut document (so you know it works), which can be found here. If you just want to follow along copying code at the command prompt, just run sbt from an empty or temp directory, and copy the following to spin up a Scala console with the Breeze dependency:

set libraryDependencies += "org.scalanlp" %% "breeze" % "1.0-RC4"
set libraryDependencies += "org.scalanlp" %% "breeze-natives" % "1.0-RC4"
set scalaVersion := "2.13.0"
console

We start with a couple of Breeze imports

import breeze.stats.{distributions => bdist}
import breeze.linalg.DenseVector

which are not strictly necessary, but clean up the subsequent code. We are going to use a set of weighted particles to represent a probability distribution empirically, so we’ll start by defining an appropriate ADT for these:

implicit val numParticles = 300

case class Particle[T](v: T, lw: Double) { // value and log-weight
  def map[S](f: T => S): Particle[S] = Particle(f(v), lw)
}

We also include a map method for pushing a particle through a transformation, and a default number of particles for sampling and resampling. 300 particles are enough for illustrative purposes. Ideally it would be good to increase this for more realistic experiments. We can use this particle type to build our main probability monad as follows.

trait Prob[T] {
  val particles: Vector[Particle[T]]
  def map[S](f: T => S): Prob[S] = Empirical(particles map (_ map f))
  def flatMap[S](f: T => Prob[S]): Prob[S] = {
    Empirical((particles map (p => {
      f(p.v).particles.map(psi => Particle(psi.v, p.lw + psi.lw))
    })).flatten).resample
  }
  def resample(implicit N: Int): Prob[T] = {
    val lw = particles map (_.lw)
    val mx = lw reduce (math.max(_,_))
    val rw = lw map (lwi => math.exp(lwi - mx))
    val law = mx + math.log(rw.sum/(rw.length))
    val ind = bdist.Multinomial(DenseVector(rw.toArray)).sample(N)
    val newParticles = ind map (i => particles(i))
    Empirical(newParticles.toVector map (pi => Particle(pi.v, law)))
  }
  def cond(ll: T => Double): Prob[T] =
    Empirical(particles map (p => Particle(p.v, p.lw + ll(p.v))))
  def empirical: Vector[T] = resample.particles.map(_.v)
}

case class Empirical[T](particles: Vector[Particle[T]]) extends Prob[T]

Note that if you are pasting into the Scala REPL you will need to use :paste mode for this. So Prob[_] is our base probability monad trait, and Empirical[_] is our simplest implementation, which is just a collection of weighted particles. The method flatMap forms the naive product of empirical measures and then resamples in order to stop an explosion in the number of particles. There are two things worth noting about the resample method. The first is that the log-sum-exp trick is being used to avoid overflow and underflow when the log weights are exponentiated. The second is that although the method returns an equally weighted set of particles, the log weights are all set in order that the average raw weight of the output set matches the average raw weight of the input set. This is a little tricky to explain, but it turns out to be necessary in order to correctly propagate conditioning information back through multiple monadic binds (flatMaps). The cond method allows conditioning of a distribution using an arbitrary log-likelihood. It is included for comparison with some other implementations I will refer to later, but we won’t actually be using it, so we could save two lines of code here if necessary. The empirical method just extracts an unweighted set of values from a distribution for subsequent analysis.
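
The two points about the resample method can be seen in isolation in the following stand-alone sketch, which uses only the Scala standard library. It is an illustration under my own assumptions rather than the implementation above: ResampleSketch and its methods are hypothetical names, particles are bare (value, log-weight) pairs, and the multinomial sampling is done by a simple inversion of the cumulative weights instead of Breeze’s Multinomial.

```scala
object ResampleSketch {
  // Log of the average raw weight, computed stably by subtracting the
  // maximum log-weight before exponentiating (the log-sum-exp trick).
  def logMeanWeight(lw: Vector[Double]): Double = {
    val mx = lw.reduce(math.max(_, _))
    val rw = lw.map(lwi => math.exp(lwi - mx)) // relative weights, at most 1
    mx + math.log(rw.sum / rw.length)
  }

  // Naive version for comparison: exponentiates the log-weights directly,
  // so very negative log-weights all underflow to zero.
  def naiveLogMeanWeight(lw: Vector[Double]): Double =
    math.log(lw.map(lwi => math.exp(lwi)).sum / lw.length)

  // Multinomial resampling by inverting the cumulative relative weights.
  // Every output particle gets the log of the average input raw weight.
  def resample[T](ps: Vector[(T, Double)], n: Int,
                  rng: scala.util.Random): Vector[(T, Double)] = {
    val lw = ps.map(_._2)
    val mx = lw.reduce(math.max(_, _))
    val rw = lw.map(lwi => math.exp(lwi - mx))
    val cum = rw.scanLeft(0.0)(_ + _).tail // cumulative relative weights
    val law = logMeanWeight(lw)
    Vector.fill(n) {
      val u = rng.nextDouble() * cum.last
      val i = cum.indexWhere(_ > u)
      // Guard against round-off at the very top of the range.
      val idx = if (i < 0) ps.length - 1 else i
      (ps(idx)._1, law)
    }
  }
}
```

For example, with two log-weights both equal to -1000, logMeanWeight returns -1000, whereas naiveLogMeanWeight underflows and returns negative infinity.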

It will be handy to have a function to turn a bunch of unweighted particles into a set of particles with equal weights (a sort-of inverse of the empirical method just described), so we can define that as follows.

def unweighted[T](ts: Vector[T], lw: Double = 0.0): Prob[T] =
  Empirical(ts map (Particle(_, lw)))

Probabilistic programming is essentially trivial if we only care about forward sampling. But interesting PPLs allow us to condition on observed values of random variables. In the context of SMC, this is simplest when the distribution being conditioned has a tractable log-likelihood. So we can now define an extension of our probability monad for distributions with a tractable log-likelihood, and define a bunch of convenient conditioning (or “fitting”) methods using it.

trait Dist[T] extends Prob[T] {
  def ll(obs: T): Double
  def ll(obs: Seq[T]): Double = obs map (ll) reduce (_+_)
  def fit(obs: Seq[T]): Prob[T] =
    Empirical(particles map (p => Particle(p.v, p.lw + ll(obs))))
  def fitQ(obs: Seq[T]): Prob[T] = Empirical(Vector(Particle(obs.head, ll(obs))))
  def fit(obs: T): Prob[T] = fit(List(obs))
  def fitQ(obs: T): Prob[T] = fitQ(List(obs))
}

The only unimplemented method is ll(). The fit method re-weights a particle set according to the observed log-likelihood. For convenience, it also returns a particle cloud representing the posterior-predictive distribution of an iid value from the same distribution. This is handy, but comes at the expense of introducing an additional particle cloud. So, if you aren’t interested in the posterior predictive, you can avoid this cost by using the fitQ method (for “fit quick”), which doesn’t return anything useful. We’ll see examples of this in practice, shortly. Note that the fitQ methods aren’t strictly required for our “minimal” PPL, so we can save a couple of lines by omitting them if necessary. Similarly for the variants which allow conditioning on a collection of iid observations from the same distribution.

At this point we are essentially done. But for convenience, we can define a few standard distributions to help get new users of our PPL started. Of course, since the PPL is embedded, it is trivial to add our own additional distributions later.

case class Normal(mu: Double, v: Double)(implicit N: Int) extends Dist[Double] {
  lazy val particles = unweighted(bdist.Gaussian(mu, math.sqrt(v)).sample(N).toVector).particles
  def ll(obs: Double) = bdist.Gaussian(mu, math.sqrt(v)).logPdf(obs) }

case class Gamma(a: Double, b: Double)(implicit N: Int) extends Dist[Double] {
  lazy val particles = unweighted(bdist.Gamma(a, 1.0/b).sample(N).toVector).particles
  def ll(obs: Double) = bdist.Gamma(a, 1.0/b).logPdf(obs) }

case class Poisson(mu: Double)(implicit N: Int) extends Dist[Int] {
  lazy val particles = unweighted(bdist.Poisson(mu).sample(N).toVector).particles
  def ll(obs: Int) = bdist.Poisson(mu).logProbabilityOf(obs) }

Note that I’ve parameterised the Normal and Gamma the way that statisticians usually do, and not the way they are usually parameterised in scientific computing libraries (such as Breeze).
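As a small illustration of the difference in conventions, the following sketch converts from the (mean, variance) and (shape, rate) parameterisations used for Normal and Gamma here to the (mean, standard deviation) and (shape, scale) conventions, mirroring the math.sqrt(v) and 1.0/b conversions in the definitions above. The helper names are hypothetical and the code does not depend on Breeze.

```scala
// Convert from the conventions used in this post to the library-style conventions
def sdFromVariance(v: Double): Double = math.sqrt(v)   // Normal: variance -> standard deviation
def scaleFromRate(b: Double): Double = 1.0 / b         // Gamma: rate -> scale

// Moments of a Gamma(a, b) distribution with shape a and rate b
def gammaMean(a: Double, b: Double): Double = a / b
def gammaVariance(a: Double, b: Double): Double = a / (b * b)

val sd100 = sdFromVariance(100.0)        // standard deviation of a Normal(0, 100)
val gMean = gammaMean(1.0, 0.1)          // mean of a Gamma(1, 0.1)
val gVar = gammaVariance(1.0, 0.1)       // variance of a Gamma(1, 0.1)
```

So, under the convention used here, Normal(0, 100) has a standard deviation of 10, and Gamma(1, 0.1) has a mean of 10 and a variance of about 100.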

That’s it! This is a complete, general-purpose, composable, monadic PPL, in 50 (actually, 48, and fewer still if you discount trailing braces) lines of code. Let’s now see how it works in practice.

Examples

Normal random sample

We’ll start off with just about the simplest slightly interesting example I can think of: Bayesian inference for the mean and variance of a normal distribution from a random sample.

import breeze.stats.{meanAndVariance => meanVar}
// import breeze.stats.{meanAndVariance=>meanVar}

val mod = for {
  mu <- Normal(0, 100)
  tau <- Gamma(1, 0.1)
  _ <- Normal(mu, 1.0/tau).fitQ(List(8.0,9,7,7,8,10))
} yield (mu,tau)
// mod: Wrapped.Prob[(Double, Double)] = Empirical(Vector(Particle((8.718127116254472,0.93059589932682),-15.21683812389373), Particle((7.977706390420308,1.1575288208065433),-15.21683812389373), Particle((7.977706390420308,1.1744750937611985),-15.21683812389373), Particle((7.328100552769214,1.1181787982959164),-15.21683812389373), Particle((7.977706390420308,0.8283737237370494),-15.21683812389373), Particle((8.592847414557049,2.2934836446009026),-15.21683812389373), Particle((8.718127116254472,1.498741032928539),-15.21683812389373), Particle((8.592847414557049,0.2506065368748732),-15.21683812389373), Particle((8.543283880264225,1.127386759627675),-15.21683812389373), Particle((7.977706390420308,1.3508728798704925),-15.21683812389373), Particle((7.977706390420308,1.1134430556990933),-15.2168...

val modEmp = mod.empirical
// modEmp: Vector[(Double, Double)] = Vector((7.977706390420308,0.8748006833362748), (6.292345096890432,0.20108091703626174), (9.15330820843396,0.7654238730107492), (8.960935105658741,1.027712984079369), (7.455292602273359,0.49495749079351836), (6.911716909394562,0.7739749058662421), (6.911716909394562,0.6353785792877397), (7.977706390420308,1.1744750937611985), (7.977706390420308,1.1134430556990933), (8.718127116254472,1.166399872049532), (8.763777227034538,1.0468304705769353), (8.718127116254472,0.93059589932682), (7.328100552769214,1.6166695922250236), (8.543283880264225,0.4689300351248357), (8.543283880264225,2.0028918490755094), (7.536025958690963,0.6282318170458533), (7.328100552769214,1.6166695922250236), (7.049843463553113,0.20149378088848635), (7.536025958690963,2.3565657669819897...

meanVar(modEmp map (_._1)) // mu
// res0: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(8.311171010932343,0.4617800639333532,300)

meanVar(modEmp map (_._2)) // tau
// res1: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.940762723934599,0.23641881704888842,300)

Note the use of the empirical method to turn the distribution into an unweighted set of particles for Monte Carlo analysis. Anyway, the main point is that the syntactic sugar for monadic binds (flatMaps) provided by Scala’s for-expressions (similar to do-notation in Haskell) leads to readable code not so different to that in well-known general-purpose PPLs such as BUGS, JAGS, or Stan. There are some important differences, however. In particular, the embedded DSL has probabilistic programs as regular values in the host language. These may be manipulated and composed like other values. This makes this probabilistic programming language more composable than the aforementioned languages, which makes it much simpler to build large, complex probabilistic programs from simpler, well-tested, components, in a scalable way. That is, this PPL we have obtained “for free” is actually in many ways better than most well-known PPLs.
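Turning a weighted particle set into an unweighted one is typically done by resampling. The following is a minimal sketch of multinomial resampling, one common way of doing this. It is not necessarily how the empirical method in this PPL is implemented, and the function name resample and its signature are assumptions made for illustration.

```scala
import scala.util.Random

// Multinomial resampling: draw n values with probability proportional to exp(log-weight)
def resample[T](vs: Vector[T], lws: Vector[Double], n: Int, rng: Random): Vector[T] = {
  val mx = lws.max
  val ws = lws.map(lw => math.exp(lw - mx))   // unnormalised weights, the largest is 1
  val cum = ws.scanLeft(0.0)(_ + _).tail      // cumulative sums of the weights
  val total = cum.last
  Vector.fill(n) {
    val u = rng.nextDouble() * total          // uniform on [0, total)
    val i = cum.indexWhere(_ > u)             // first cumulative sum exceeding u
    vs(if (i < 0) vs.length - 1 else i)       // guard against rounding at the upper end
  }
}

// Three values, the last with a negligible weight
val draws = resample(Vector("a", "b", "c"), Vector(-1.0, 0.0, -1000.0), 300, new Random(42))
```

The resulting draws are equally weighted, so they can be summarised with ordinary Monte Carlo estimates such as sample means and variances.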

Noisy measurements of a count

Here we’ll look at the problem of inference for a discrete count given some noisy iid continuous measurements of it.

val mod = for {
  count <- Poisson(10)
  tau <- Gamma(1, 0.1)
  _ <- Normal(count, 1.0/tau).fitQ(List(4.2,5.1,4.6,3.3,4.7,5.3))
} yield (count, tau)
// mod: Wrapped.Prob[(Int, Double)] = Empirical(Vector(Particle((5,4.488795220669575),-11.591037521513753), Particle((5,1.7792314573063672),-11.591037521513753), Particle((5,2.5238021156137673),-11.591037521513753), Particle((4,3.280754333896923),-11.591037521513753), Particle((5,2.768438569482849),-11.591037521513753), Particle((4,1.3399975573518912),-11.591037521513753), Particle((5,1.1792835858615431),-11.591037521513753), Particle((5,1.989491156206883),-11.591037521513753), Particle((4,0.7825254987152054),-11.591037521513753), Particle((5,2.7113936834028793),-11.591037521513753), Particle((5,3.7615196800240387),-11.591037521513753), Particle((4,1.6833300961124709),-11.591037521513753), Particle((5,2.749183220798113),-11.591037521513753), Particle((5,2.1074062883430202),-11.591037521513...

val modEmp = mod.empirical
// modEmp: Vector[(Int, Double)] = Vector((4,3.243786594839479), (4,1.5090869158886693), (4,1.280656912383482), (5,2.0616356908358195), (5,3.475433097869503), (5,1.887582611202514), (5,2.8268877720514745), (5,0.9193261688050818), (4,1.7063629502805908), (5,2.116414832864841), (5,3.775508828984636), (5,2.6774941123762814), (5,2.937859946593459), (5,1.2047689975166402), (5,2.5658806161572656), (5,1.925890364268593), (4,1.0194093176888832), (5,1.883288825936725), (5,4.9503779454422965), (5,0.9045613180858916), (4,1.5795027943928661), (5,1.925890364268593), (5,2.198539449287062), (5,1.791363956348445), (5,0.9853760689818026), (4,1.6541388923071607), (5,2.599899960899971), (4,1.8904423810277957), (5,3.8983183765907836), (5,1.9242319515895554), (5,2.8268877720514745), (4,1.772120802027519), (5,2...

meanVar(modEmp map (_._1.toDouble)) // count
// res2: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(4.670000000000004,0.23521739130434777,300)

meanVar(modEmp map (_._2)) // tau
// res3: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(1.9678279101913874,0.9603971613375548,300)

I’ve included this mainly as an example of inference for a discrete-valued parameter. There are people out there who will tell you that discrete parameters are bad/evil/impossible. This isn’t true – discrete parameters are cool!

Linear model

Because our PPL is embedded, we can take full advantage of the power of the host programming language to build our models. Let’s explore this in the context of Bayesian estimation of a linear model. We’ll start with some data.

val x = List(1.0,2,3,4,5,6)
// x: List[Double] = List(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

val y = List(3.0,2,4,5,5,6)
// y: List[Double] = List(3.0, 2.0, 4.0, 5.0, 5.0, 6.0)

val xy = x zip y
// xy: List[(Double, Double)] = List((1.0,3.0), (2.0,2.0), (3.0,4.0), (4.0,5.0), (5.0,5.0), (6.0,6.0))

Now, our (simple) linear regression model will be parameterised by an intercept, alpha, a slope, beta, and a residual variance, v. So, for convenience, let’s define an ADT representing a particular linear model.

case class Param(alpha: Double, beta: Double, v: Double)
// defined class Param

Now we can define a prior distribution over models as follows.

val prior = for {
  alpha <- Normal(0,10)
  beta <- Normal(0,4)
  v <- Gamma(1,0.1)
} yield Param(alpha, beta, v)
// prior: Wrapped.Prob[Param] = Empirical(Vector(Particle(Param(-2.392517550699654,-3.7516090283880095,1.724680963054379),0.0), Particle(Param(7.60982717067903,-1.4318199629361292,2.9436745225038545),0.0), Particle(Param(-1.0281832158124837,-0.2799562317845073,4.05125312048092),0.0), Particle(Param(-1.0509321093485073,-2.4733837587060448,0.5856868459456287),0.0), Particle(Param(7.678898742733517,0.15616204936412104,5.064540017623097),0.0), Particle(Param(-3.392028985658713,-0.694412176170572,7.452625596437611),0.0), Particle(Param(3.0310535934425324,-2.97938526497514,2.138446100857938),0.0), Particle(Param(3.016959696424399,1.3370878561954143,6.18957854813488),0.0), Particle(Param(2.6956505371497066,1.058845844793446,5.257973123790336),0.0), Particle(Param(1.496225540527873,-1.573936445746...

Since our language doesn’t include any direct syntactic support for fitting regression models, we can define our own function for conditioning a distribution over models on a data point, which we can then apply to our prior as a fold over the available data.

def addPoint(current: Prob[Param], obs: (Double, Double)): Prob[Param] = for {
    p <- current
    (x, y) = obs
    _ <- Normal(p.alpha + p.beta * x, p.v).fitQ(y)
  } yield p
// addPoint: (current: Wrapped.Prob[Param], obs: (Double, Double))Wrapped.Prob[Param]

val mod = xy.foldLeft(prior)(addPoint(_,_)).empirical
// mod: Vector[Param] = Vector(Param(1.4386051853067798,0.8900831186754122,4.185564696221981), Param(0.5530582357040271,1.1296886766045509,3.468527573093037), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(3.68291303096638,0.4781372802435529,5.151665328789926), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(3.68291303096638,0.4781372802435529,5.151665328789926), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), ...

meanVar(mod map (_.alpha))
// res4: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(1.5740812481283812,1.893684802867127,300)

meanVar(mod map (_.beta))
// res5: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.7690238868623273,0.1054479268115053,300)

meanVar(mod map (_.v))
// res6: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(3.5240853748668695,2.793386340338213,300)

We could easily add syntactic support to our language to enable the fitting of regression-style models, as is done in Rainier, of which more later.
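The pattern of conditioning on data one point at a time, as a fold, can be checked in a setting where the answer is available in closed form. The sketch below is independent of the PPL: it uses the standard conjugate update for a normal mean with known observation variance. The names NormalBelief and addObs are hypothetical.

```scala
// Belief about an unknown mean, summarised by a normal distribution (hypothetical type)
final case class NormalBelief(mean: Double, variance: Double)

// Conjugate update for a single observation y ~ N(mean, obsVar), with obsVar known
def addObs(obsVar: Double)(b: NormalBelief, y: Double): NormalBelief = {
  val prec = 1.0 / b.variance + 1.0 / obsVar   // posterior precision
  NormalBelief((b.mean / b.variance + y / obsVar) / prec, 1.0 / prec)
}

val ys = List(8.0, 9.0, 7.0, 7.0, 8.0, 10.0)
val priorBelief = NormalBelief(0.0, 100.0)

// Condition on the data one point at a time, as a fold over the observations
val posteriorBelief = ys.foldLeft(priorBelief)((b, y) => addObs(1.0)(b, y))
```

Folding the single-observation update over the data gives the same posterior as conditioning on all six observations at once, which is the property that the addPoint fold relies on.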

Dynamic generalised linear model

The previous examples have been fairly simple, so let’s finish with something a bit less trivial. Our language is quite flexible enough to allow the analysis of a dynamic generalised linear model (DGLM). Here we’ll fit a Poisson DGLM with a log-link and a simple Brownian state evolution. More complex models are more-or-less similarly straightforward. The model is parameterised by an initial state, state0, and an evolution variance, w.

val data = List(2,1,0,2,3,4,5,4,3,2,1)
// data: List[Int] = List(2, 1, 0, 2, 3, 4, 5, 4, 3, 2, 1)

val prior = for {
  w <- Gamma(1, 1)
  state0 <- Normal(0.0, 2.0)
} yield (w, List(state0))
// prior: Wrapped.Prob[(Double, List[Double])] = Empirical(Vector(Particle((0.12864918092587044,List(-2.862479260552014)),0.0), Particle((1.1706344622093179,List(1.6138397233532091)),0.0), Particle((0.757288087950638,List(-0.3683499919402798)),0.0), Particle((2.755201217523856,List(-0.6527488751780317)),0.0), Particle((0.7535085397802043,List(0.5135562407906502)),0.0), Particle((1.1630726564525629,List(0.9703146201262348)),0.0), Particle((1.0080345715326213,List(-0.375686732266234)),0.0), Particle((4.603723117526974,List(-1.6977366375222938)),0.0), Particle((0.2870669117815037,List(2.2732160435099433)),0.0), Particle((2.454675218313211,List(-0.4148287542786906)),0.0), Particle((0.3612534201761152,List(-1.0099270904161748)),0.0), Particle((0.29578453393473114,List(-2.4938128878051966)),0.0)...

We can define a function to create a new hidden state, prepend it to the list of hidden states, and condition on the observed value at that time point as follows.

def addTimePoint(current: Prob[(Double, List[Double])],
  obs: Int): Prob[(Double, List[Double])] = for {
  tup <- current
  (w, states) = tup
  os = states.head
  ns <- Normal(os, w)
  _ <- Poisson(math.exp(ns)).fitQ(obs)
} yield (w, ns :: states)
// addTimePoint: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

We then run our (augmented state) particle filter as a fold over the time series.

val mod = data.foldLeft(prior)(addTimePoint(_,_)).empirical
// mod: Vector[(Double, List[Double])] = Vector((0.053073252551193446,List(0.8693030057529023, 1.2746526177834938, 1.020307245610461, 1.106341696651584, 1.070777529635013, 0.8749041525303247, 0.9866999164354662, 0.4082577920509255, 0.06903234462140699, -0.018835642776197814, -0.16841912034400547, -0.08919045681401294)), (0.0988871875952762,List(-0.24241948109998607, 0.09321618969352086, 0.9650532206325375, 1.1738734442767293, 1.2272325310228442, 0.9791695328246326, 0.5576319082578128, -0.0054280215024367084, 0.4256621012454391, 0.7486862644576158, 0.8193517409118243, 0.5928750312493785)), (0.16128799384962295,List(-0.30371187329667104, -0.3976854602292066, 0.5869357473774455, 0.9881090696832543, 1.2095181380307558, 0.7211231597865506, 0.8085486452269925, 0.2664373341459165, -0.627344024142...

meanVar(mod map (_._1)) // w
// res7: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.29497487517435844,0.0831412016262515,300)

meanVar(mod map (_._2.reverse.head)) // state0 (initial state)
// res8: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.04617218427664018,0.372844704533101,300)

meanVar(mod map (_._2.head)) // stateN (final state)
// res9: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.4937178761565612,0.2889287607470016,300)
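To make the generative model being fitted explicit, here is a sketch that forward-simulates from a Poisson DGLM with a log-link and a random walk state. It uses only the standard library, and the function names rpois and simDGLM are hypothetical. The Poisson sampler uses Knuth's multiplication method, which is only sensible for small means.

```scala
import scala.util.Random

// Knuth's method for sampling a Poisson count (adequate for small means only)
def rpois(lambda: Double, rng: Random): Int = {
  val limit = math.exp(-lambda)
  var k = 0
  var p = rng.nextDouble()
  while (p > limit) {
    k += 1
    p *= rng.nextDouble()
  }
  k
}

// Simulate n time points: random walk state with evolution variance w,
// and Poisson counts with mean exp(state)
def simDGLM(n: Int, state0: Double, w: Double, rng: Random): List[(Double, Int)] = {
  val states = Iterator.iterate(state0)(s => s + math.sqrt(w) * rng.nextGaussian()).drop(1).take(n).toList
  states.map(s => (s, rpois(math.exp(s), rng)))
}

val sim = simDGLM(11, 0.0, 0.3, new Random(1))
```

Each element of sim pairs a hidden state with the count observed at that time point, which is the structure that addTimePoint conditions on, one observation at a time.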

Summary, conclusions, and further reading

So, we’ve seen how we can build a fully functional, general-purpose, compositional, monadic PPL from scratch in 50 lines of code, and we’ve seen how we can use it to solve real, analytically intractable Bayesian inference problems of non-trivial complexity. Of course, there are many limitations to using exactly this PPL implementation in practice. The algorithm becomes intolerably slow for deeply nested models, and uses unreasonably large amounts of RAM for large numbers of particles. It also suffers from a particle degeneracy problem if there are too many conditioning events. But it is important to understand that these are all deficiencies of the naive inference algorithm used, not the PPL itself. The PPL is flexible and compositional and can be used to build models of arbitrary size and complexity – it just needs to be underpinned by a better, more efficient, inference algorithm. Rainier is a Scala library I’ve blogged about previously which uses a very similar PPL to the one described here, but is instead underpinned by a fast, efficient, HMC algorithm. With my student Jonny Law, we have recently arXived a paper on Functional probabilistic programming for scalable Bayesian modelling, discussing some of these issues, and exploring the compositional nature of monadic PPLs (somewhat glossed over in this post).
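Particle degeneracy is often monitored using the effective sample size (ESS) of the weights. The following sketch computes the usual ESS estimate, (sum of weights) squared divided by the sum of squared weights, from a vector of log-weights. The function name ess is an assumption for illustration and is not part of the PPL in this post.

```scala
// Effective sample size from log-weights: (sum w)^2 / sum(w^2)
def ess(lws: Vector[Double]): Double = {
  val mx = lws.max
  val w = lws.map(lw => math.exp(lw - mx))   // rescale so that the largest weight is 1
  val s = w.sum
  s * s / w.map(x => x * x).sum
}

// All weights equal: the ESS is the number of particles
val essEqual = ess(Vector.fill(300)(-15.2))

// One dominant particle: the ESS collapses towards 1
val essDegenerate = ess(Vector(0.0) ++ Vector.fill(299)(-50.0))
```

An ESS close to the number of particles indicates a healthy particle set, whereas an ESS close to 1 indicates that a single particle carries almost all of the weight.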

Since the same PPL can be underpinned by different inference algorithms encapsulated as probability monads, an obvious question is whether it is possible to abstract the PPL away from the inference algorithm implementation. Of course, the answer is “yes”, and this has been explored to great effect in papers such as Practical probabilistic programming with monads and Functional programming for modular Bayesian inference. Note that they use the cond approach to conditioning, which looks a bit unwieldy, but is equivalent to fitting. As well as allowing alternative inference algorithms to be applied to the same probabilistic program, it also enables the composing of inference algorithms – for example, composing a MH algorithm with an SMC algorithm in order to get a PMMH algorithm. The ideas are implemented in an embedded DSL for Haskell, monad-bayes. If you are not used to Haskell, the syntax will probably seem a bit more intimidating than Scala’s, but the semantics are actually quite similar, with the main semantic difference being that Scala is strictly evaluated by default, whereas Haskell is lazily evaluated by default. Both languages support both lazy and strict evaluation – the difference relates simply to default behaviour, but is important nevertheless.

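Papers

  • Functional probabilistic programming for scalable Bayesian modelling
  • Practical probabilistic programming with monads
  • Functional programming for modular Bayesian inference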

Software

  • min-ppl – code associated with this blog post
  • Rainier – a more efficient PPL with similar syntax
  • monad-bayes – a Haskell library exploring related ideas

Bayesian hierarchical modelling with Rainier

Introduction

In the previous post I gave a brief introduction to Rainier, a new HMC-based probabilistic programming library/DSL for Scala. In that post I assumed that people were using the latest source version of the library. Since then, version 0.1.1 of the library has been released, so in this post I will demonstrate use of the released version of the software (using the binaries published to Sonatype), and will walk through a slightly more interesting example – a dynamic linear state space model with unknown static parameters. This is similar to, but slightly different from, the DLM example in the Rainier library. So to follow along with this post, all that is required is SBT.

An interactive session

First run SBT from an empty directory, and paste the following at the SBT prompt:

set libraryDependencies  += "com.stripe" %% "rainier-plot" % "0.1.1"
set scalaVersion := "2.12.4"
console

This should give a Scala REPL with appropriate dependencies (rainier-plot has all of the relevant transitive dependencies). We’ll begin with some imports, and then simulate some synthetic data from a dynamic linear state space model with an AR(1) latent state and Gaussian noise on the observations.

import com.stripe.rainier.compute._
import com.stripe.rainier.core._
import com.stripe.rainier.sampler._

implicit val rng = ScalaRNG(1)
val n = 60 // number of observations/time points
val mu = 3.0 // AR(1) mean
val a = 0.95 // auto-regressive parameter
val sig = 0.2 // AR(1) SD
val sigD = 3.0 // observational SD
val state = Stream.
  iterate(0.0)(x => mu + (x - mu) * a + sig * rng.standardNormal).
  take(n).toVector
val obs = state.map(_ + sigD * rng.standardNormal)
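In symbols, the simulation code above corresponds to the following state space model. The notation is introduced here for illustration only: x_t is the latent state, y_t is the observation, \sigma corresponds to sig, \sigma_D corresponds to sigD, and \varepsilon_t and \eta_t are independent standard normal variates.

```latex
x_0 = 0, \qquad
x_t = \mu + a\,(x_{t-1} - \mu) + \sigma\,\varepsilon_t, \qquad
y_t = x_t + \sigma_D\,\eta_t .
```

Rearranging the state equation gives the equivalent form

```latex
x_t = (1 - a)\,\mu + a\,x_{t-1} + \sigma\,\varepsilon_t ,
```

so the conditional mean of the new state given the old one is (1 - a)\mu + a x_{t-1}.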

Now we have some synthetic data, let’s think about building a probabilistic program for this model. Start with a prior.

case class Static(mu: Real, a: Real, sig: Real, sigD: Real)
val prior = for {
  mu <- Normal(0, 10).param
  a <- Normal(1, 0.1).param
  sig <- Gamma(2,1).param
  sigD <- Gamma(2,2).param
  sp <- Normal(0, 50).param
} yield (Static(mu, a, sig, sigD), List(sp))

Note the use of a case class for wrapping the static parameters. Next, let’s define a function to add a state and associated observation to an existing model.

def addTimePoint(current: RandomVariable[(Static, List[Real])],
                     datum: Double) = for {
  tup <- current
  static = tup._1
  states = tup._2
  os = states.head
  ns <- Normal(((Real.one - static.a) * static.mu) + (static.a * os),
                 static.sig).param
  _ <- Normal(ns, static.sigD).fit(datum)
} yield (static, ns :: states)

Given this, we can generate the probabilistic program for our model as a fold over the data initialised with the prior.

val fullModel = obs.foldLeft(prior)(addTimePoint(_, _))

If we don’t want to keep samples for all of the variables, we can focus on the parameters of interest, wrapping the results in a Map for convenient sampling and plotting.

val model = for {
  tup <- fullModel
  static = tup._1
  states = tup._2
} yield
  Map("mu" -> static.mu,
  "a" -> static.a,
  "sig" -> static.sig,
  "sigD" -> static.sigD,
  "SP" -> states.reverse.head)

We can sample with

val out = model.sample(HMC(3), 100000, 10000 * 500, 500)

(this will take several minutes) and plot some diagnostics with

import com.cibo.evilplot.geometry.Extent
import com.stripe.rainier.plot.EvilTracePlot._

val truth = Map("mu" -> mu, "a" -> a, "sigD" -> sigD,
  "sig" -> sig, "SP" -> state(0))
render(traces(out, truth), "traceplots.png",
  Extent(1200, 1400))
render(pairs(out, truth), "pairs.png")

This writes trace plots to traceplots.png and pairwise plots to pairs.png.

Everything looks good.

Summary

Rainier is a monadic embedded DSL for probabilistic programming in Scala. We can use standard functional combinators and for-expressions for building models to sample, and then run an efficient HMC algorithm on the resulting probability monad in order to obtain samples from the posterior distribution of the model.

See the Rainier repo for further details.

Monadic probabilistic programming in Scala with Rainier

Introduction

Rainier is an interesting new probabilistic programming library for Scala recently open-sourced by Stripe. Probabilistic programming languages provide a computational framework for building and fitting Bayesian models to data. There are many interesting probabilistic programming languages, and there is currently a lot of innovation happening with probabilistic programming languages embedded in strongly typed functional programming languages such as Scala and Haskell. However, most such languages tend to be developed by people lacking expertise in statistics and numerics, leading to elegant, composable languages which work well for toy problems, but don’t scale well to the kinds of practical problems that applied statisticians are interested in. Conversely, there are a few well-known probabilistic programming languages developed by and for statisticians which have efficient inference engines, but are hampered by inflexible, inelegant languages and APIs. Rainier is interesting because it is an attempt to bridge the gap between these two worlds: it has a functional, composable, extensible, monadic API, yet is backed by a very efficient, high-performance scalable inference engine, using HMC and a static compute graph for reverse-mode AD. Clearly there will be some loss of generality associated with choosing an efficient inference algorithm (e.g. for HMC, there needs to be a fixed number of parameters and they must all be continuous), but it still covers a large proportion of the class of hierarchical models commonly used in applied statistical modelling.

In this post I’ll give a quick introduction to Rainier using an interactive session requiring only that SBT is installed and the Rainier repo is downloaded or cloned.

Interactive session

To follow along with this post just clone, or download and unpack, the Rainier repo, and run SBT from the top-level Rainier directory and paste commands. First start a Scala REPL.

project rainierPlot
console

Before we start building models, we need some data. For this post we will focus on a simple logistic regression model, and so we will begin by simulating some synthetic data consistent with such a model.

val r = new scala.util.Random(0)
val N = 1000
val beta0 = 0.1
val beta1 = 0.3
val x = (1 to N) map { i =>
  3.0 * r.nextGaussian
}
val theta = x map { xi =>
  beta0 + beta1 * xi
}
def expit(x: Double): Double = 1.0 / (1.0 + math.exp(-x))
val p = theta map expit
val y = p map (pi => (r.nextDouble < pi))

Now we have some synthetic data, we can fit the model and see if we are able to recover the “true” parameters used to generate the synthetic data. In Rainier, we build models by declaring probabilistic programs for the model and the data, and then run an inference engine to generate samples from the posterior distribution.

Start with a bunch of Rainier imports:

import com.stripe.rainier.compute._
import com.stripe.rainier.core._
import com.stripe.rainier.sampler._
import com.stripe.rainier.repl._

Now we want to build a model. We do so by describing the joint distribution of parameters and data. Rainier has a few built-in distributions, and these can be combined using standard functional monadic combinators such as map, zip, flatMap, etc., to create a probabilistic program representing a probability monad for the model. Due to the monadic nature of such probabilistic programs, it is often most natural to declare them using a for-expression.

val model = for {
  beta0 <- Normal(0, 5).param
  beta1 <- Normal(0, 5).param
  _ <- Predictor.from{x: Double =>
      {
        val theta = beta0 + beta1 * x
        val p = Real(1.0) / (Real(1.0) + (Real(0.0) - theta).exp)
        Categorical.boolean(p)
      }
    }.fit(x zip y)
} yield Map("b0"->beta0, "b1"->beta1)

This kind of construction is very natural for anyone familiar with monadic programming in Scala, but will no doubt be a little mysterious otherwise. RandomVariable is the probability monad used for HMC sampling, and these can be constructed from Distributions using .param (for unobserved parameters) and .fit (for variables with associated observations). Predictor is just a convenience for observations corresponding to covariate information. model is therefore a RandomVariable over beta0 and beta1, the two unobserved parameters of interest. Note that I briefly discussed this kind of pure functional approach to describing probabilistic programs (using Rand from Breeze) in my post on MCMC as a stream.
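For readers who find the Predictor construction opaque, the sketch below writes out, in plain Scala, the Bernoulli log-likelihood that a logistic regression of this form corresponds to. It is only an illustration of the likelihood term, not of Rainier's internals, and it ignores the Normal(0, 5) priors. The function name logLik and the toy data are assumptions.

```scala
// Log-likelihood of a logistic regression with intercept b0 and slope b1
def logLik(b0: Double, b1: Double, data: Seq[(Double, Boolean)]): Double =
  data.map { case (xi, yi) =>
    val theta = b0 + b1 * xi
    // log(expit(theta)) if yi is true, log(1 - expit(theta)) otherwise
    if (yi) -math.log1p(math.exp(-theta)) else -math.log1p(math.exp(theta))
  }.sum

// A tiny toy data set of (covariate, outcome) pairs
val toy = Seq((-2.0, false), (0.5, true), (3.0, true))
val llNull = logLik(0.0, 0.0, toy)   // every outcome has probability 0.5
val llGood = logLik(0.1, 1.0, toy)   // coefficients consistent with the toy data
```

With both coefficients at zero every outcome has probability one half, so the log-likelihood is the number of observations times log(0.5). Coefficients that agree with the data give a larger value.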

Now we have our probabilistic program, we can sample from it using HMC as follows.

implicit val rng = ScalaRNG(3)
val its = 10000
val thin = 5
val out = model.sample(HMC(5), 10000, its*thin, thin)
println(out.take(10))

The argument to HMC() is the number of leapfrog steps to take per iteration.
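To illustrate what a leapfrog step is, here is a minimal one-dimensional sketch of the leapfrog integrator for a unit-mass particle in a potential V. It is a textbook version written for illustration, not Rainier's implementation, and the function name leapfrog and its signature are assumptions.

```scala
// `steps` leapfrog steps of size eps for a unit-mass particle in one dimension,
// where gradV is the gradient of the potential V
def leapfrog(gradV: Double => Double, eps: Double, steps: Int)(q0: Double, p0: Double): (Double, Double) = {
  var q = q0
  var p = p0 - 0.5 * eps * gradV(q)          // initial half step for the momentum
  for (i <- 1 to steps) {
    q = q + eps * p                          // full step for the position
    if (i < steps) p = p - eps * gradV(q)    // full step for the momentum, except after the last position update
  }
  p = p - 0.5 * eps * gradV(q)               // final half step for the momentum
  (q, p)
}

// Standard normal target: V(q) = q^2 / 2, so gradV(q) = q
val (q1, p1) = leapfrog(q => q, 0.1, 5)(1.0, 0.0)
val energy0 = 0.5 * 1.0 * 1.0 + 0.5 * 0.0 * 0.0   // total energy at the start
val energy1 = 0.5 * q1 * q1 + 0.5 * p1 * p1       // total energy at the end

// Negating the momentum and integrating again retraces the path
val (qBack, pBack) = leapfrog(q => q, 0.1, 5)(q1, -p1)
```

The total energy is approximately conserved along the trajectory, and the integrator is reversible. These are the two properties that make leapfrog trajectories useful as Metropolis-Hastings proposals.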

Finally, we can use EvilPlot to look at the HMC output and check that we have managed to reasonably recover the true parameters associated with our synthetic data.

import com.cibo.evilplot.geometry.Extent
import com.stripe.rainier.plot.EvilTracePlot._

render(traces(out, truth = Map("b0" -> beta0, "b1" -> beta1)),
  "traceplots.png", Extent(1200, 1000))
render(pairs(out, truth = Map("b0" -> beta0, "b1" -> beta1)), "pairs.png")

Everything looks good, and the sampling is very fast!

Further reading

For further information, see the Rainier repo. In particular, start with the tour of Rainier’s core, which gives a more detailed introduction to how Rainier works than this post. Those interested in how the efficient AD works may want to read about the compute graph, and the implementation notes explain how it all fits together. There is some basic ScalaDoc for the core package, and also some examples (including this one), and there’s a gitter channel for asking questions. This is a very new project, so there are a few minor bugs and wrinkles in the initial release, but development is progressing rapidly, so I fully expect the library to get properly battle-hardened over the next few months.

For those unfamiliar with the monadic approach to probabilistic programming, Ścibior et al. (2015) is probably a good starting point.

Comonads for scientific and statistical computing in Scala

Introduction

In a previous post I’ve given a brief introduction to monads in Scala, aimed at people interested in scientific and statistical computing. Monads are a concept from category theory which turn out to be exceptionally useful for solving many problems in functional programming. But most categorical concepts have a dual, usually prefixed with “co”, so the dual of a monad is the comonad. Comonads turn out to be especially useful for formulating algorithms from scientific and statistical computing in an elegant way. In this post I’ll illustrate their use in signal processing, image processing, numerical integration of PDEs, and Gibbs sampling (of an Ising model). Comonads enable the extension of a local computation to a global computation, and this pattern crops up all over the place in statistical computing.

Monads and comonads

Simplifying massively, from the viewpoint of a Scala programmer, a monad is a mappable (functor) type class augmented with the methods pure and flatMap:

trait Monad[M[_]] extends Functor[M] {
  def pure[T](v: T): M[T]
  def flatMap[T,S](v: M[T])(f: T => M[S]): M[S]
}

In category theory, the dual of a concept is typically obtained by “reversing the arrows”. Here that means reversing the direction of the methods pure and flatMap to get extract and coflatMap, respectively.

trait Comonad[W[_]] extends Functor[W] {
  def extract[T](v: W[T]): T
  def coflatMap[T,S](v: W[T])(f: W[T] => S): W[S]
}

So, while pure allows you to wrap plain values in a monad, extract allows you to get a value out of a comonad. So you can always get a value out of a comonad (unlike a monad). Similarly, while flatMap allows you to transform a monad using a function returning a monad, coflatMap allows you to transform a comonad using a function which collapses a comonad to a single value. It is coflatMap (sometimes called extend) which can extend a local computation (producing a single value) to the entire comonad. We’ll look at how that works in the context of some familiar examples.

Applying a linear filter to a data stream

One of the simplest examples of a comonad is an infinite stream of data. I’ve discussed streams in a previous post. By focusing on infinite streams we know the stream will never be empty, so there will always be a value that we can extract. Which value does extract give? For a Stream encoded as some kind of lazy list, the only value we actually know is the value at the head of the stream, with subsequent values to be lazily computed as required. So the head of the list is the only reasonable value for extract to return.

Understanding coflatMap is a bit more tricky, but it is coflatMap that provides us with the power to apply a non-trivial statistical computation to the stream. The input is a function which transforms a stream into a value. In our example, that will be a function which computes a weighted average of the first few values and returns that weighted average as the result. But the return type of coflatMap must be a stream of such computations. Following the types, a few minutes’ thought reveals that the only reasonable thing to do is to return the stream formed by applying the weighted average function to all sub-streams, recursively. So, for a Stream s (of type Stream[T]) and an input function f: Stream[T] => S, we form a stream whose head is f(s) and whose tail is coflatMap(f) applied to s.tail. Again, since we are working with an infinite stream, we don’t have to worry about whether or not the tail is empty. This gives us our comonadic Stream, and it is exactly what we need for applying a linear filter to the data stream.
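The recursion just described can be sketched without any library support. The version below works on a finite List rather than an infinite Stream, so the last few windows are shorter than the weight vector and give truncated sums. The names coflatMapList and weightedAvg are hypothetical, and this is an illustration of the idea rather than the implementation used by any particular library.

```scala
// coflatMap for a list: apply f to the list itself and to each of its non-empty tails
def coflatMapList[T, S](xs: List[T])(f: List[T] => S): List[S] = xs match {
  case Nil       => Nil
  case _ :: tail => f(xs) :: coflatMapList(tail)(f)
}

// A local computation: weighted sum of the first few values of a list
def weightedAvg(weights: List[Double])(xs: List[Double]): Double =
  weights.zip(xs).map { case (w, x) => w * x }.sum

val raw = List(1.0, 2.0, 3.0, 4.0, 5.0)

// Extend the local computation to every position in the list
val smoothed = coflatMapList(raw)(l => weightedAvg(List(0.25, 0.5, 0.25))(l))
```

The head of the result is the local computation applied to the whole list, the next element is the same computation applied to the tail, and so on.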

In Scala, Cats is a library providing type classes from category theory, and instances of those type classes for parametrised types in the standard library. In particular, it provides us with comonadic functionality for the standard Scala Stream. Let’s start by defining a stream corresponding to the logistic map.

import cats._
import cats.implicits._

val lam = 3.7
def s = Stream.iterate(0.5)(x => lam*x*(1-x))
s.take(10).toList
// res0: List[Double] = List(0.5, 0.925, 0.25668749999999985,
//  0.7059564011718747, 0.7680532550204203, 0.6591455741499428, ...

Let us now suppose that we want to apply a linear filter to this stream, in order to smooth the values. The idea behind using comonads is that you figure out how to generate one desired value, and let coflatMap take care of applying the same logic to the rest of the structure. So here, we need a function to generate the first filtered value (since extract is focused on the head of the stream). A simple first attempt at a function to do this might look like the following.

  def linearFilterS(weights: Stream[Double])(s: Stream[Double]): Double =
    (weights, s).parMapN(_*_).sum

This aligns each weight in parallel with a corresponding value from the stream, and combines them using multiplication. The resulting (hopefully finite length) stream is then summed (with addition). We can test this with

linearFilterS(Stream(0.25,0.5,0.25))(s)
// res1: Double = 0.651671875

and let coflatMap extend this computation to the rest of the stream with something like:

s.coflatMap(linearFilterS(Stream(0.25,0.5,0.25))).take(5).toList
// res2: List[Double] = List(0.651671875, 0.5360828502929686, ...
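For readers without Cats to hand, the same computation can be sketched using only the standard library. This version assumes that parMapN zips the two streams, as described above, and so uses zip directly. It also relies on the fact that coflatMap(f) amounts to mapping f over the stream of all tails. The names LinearFilterSketch, linearFilterZip and tailsOf are hypothetical.

```scala
// A dependency-free sketch of the stream filter (names are illustrative).
object LinearFilterSketch {
  val lam = 3.7
  def s: Stream[Double] = Stream.iterate(0.5)(x => lam * x * (1 - x))

  // zip truncates to the shorter argument, so a finite weight stream
  // picks out only the first few values of the infinite data stream.
  def linearFilterZip(weights: Stream[Double])(s: Stream[Double]): Double =
    weights.zip(s).map { case (w, x) => w * x }.sum

  // The stream of all sub-streams: s, s.tail, s.tail.tail, ...
  def tailsOf[T](s: Stream[T]): Stream[Stream[T]] = Stream.iterate(s)(_.tail)

  val w: Stream[Double] = Stream(0.25, 0.5, 0.25)

  // Mapping the filter over every tail plays the role of coflatMap.
  val smoothed: Stream[Double] = tailsOf(s).map(linearFilterZip(w))
}
```

The first two values of smoothed should agree, up to floating point rounding, with the 0.651671875 and 0.5360828502929686 shown above.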

This is all completely fine, but our linearFilterS function is specific to the Stream comonad, despite the fact that all we’ve used about it in the function is that it is parallelly composable and foldable. We can make this much more generic as follows:

  def linearFilter[F[_]: Foldable, G[_]](
    weights: F[Double], s: F[Double]
  )(implicit ev: NonEmptyParallel[F, G]): Double =
    (weights, s).parMapN(_*_).fold

This uses some fairly advanced Scala concepts which I don’t want to get into right now (I should also acknowledge that I had trouble getting the syntax right for this, and got help from Fabio Labella (@SystemFw) on the Cats gitter channel). But this version is more generic, and can be used to linearly filter other data structures than Stream. We can use this for regular Streams as follows:

s.coflatMap(s => linearFilter(Stream(0.25,0.5,0.25),s))
// res3: scala.collection.immutable.Stream[Double] = Stream(0.651671875, ?)

But we can apply this new filter to other collections. This could be other, more sophisticated, streams such as provided by FS2, Monix or Akka streams. But it could also be a non-stream collection, such as List:

val sl = s.take(10).toList
sl.coflatMap(sl => linearFilter(List(0.25,0.5,0.25),sl))
// res4: List[Double] = List(0.651671875, 0.5360828502929686, ...
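One thing worth being aware of with a finite collection is what happens near its end. The sketch below hand-rolls a coflatMap for List that applies the function to every non-empty tail, which is my understanding of how the Cats instance behaves, and uses zip in place of parMapN. The names ListCoflatMapSketch, coflatMapList and filter3 are made up for the example.

```scala
// A sketch of coflatMap on a finite List (names are illustrative).
object ListCoflatMapSketch {
  // Apply f to the list and to each of its non-empty tails.
  def coflatMapList[A, B](xs: List[A])(f: List[A] => B): List[B] = xs match {
    case Nil       => Nil
    case _ :: tail => f(xs) :: coflatMapList(tail)(f)
  }

  // Three-point filter; zip silently drops weights that have no
  // matching data value, so short tails use fewer weights.
  def filter3(xs: List[Double]): Double =
    List(0.25, 0.5, 0.25).zip(xs).map { case (w, x) => w * x }.sum

  val smoothed: List[Double] = coflatMapList(List(4.0, 8.0, 4.0, 8.0))(filter3)
}
```

Here smoothed is List(6.0, 6.0, 5.0, 2.0). The last two values are computed from truncated windows whose weights no longer sum to one, so the final few elements of a filtered List should be treated with some care.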

Assuming that we have the Breeze scientific library available, we can plot the raw and smoothed trajectories.

def myFilter(s: Stream[Double]): Double =
  linearFilter(Stream(0.25, 0.5, 0.25),s)
val n = 500
import breeze.plot._
import breeze.linalg._
val fig = Figure(s"The (smoothed) logistic map (lambda=$lam)")
val p0 = fig.subplot(3,1,0)
p0 += plot(linspace(1,n,n),s.take(n))
p0.ylim = (0.0,1.0)
p0.title = s"The logistic map (lambda=$lam)"
val p1 = fig.subplot(3,1,1)
p1 += plot(linspace(1,n,n),s.coflatMap(myFilter).take(n))
p1.ylim = (0.0,1.0)
p1.title = "Smoothed by a simple linear filter"
val p2 = fig.subplot(3,1,2)
p2 += plot(linspace(1,n,n),s.coflatMap(myFilter).coflatMap(myFilter).
  coflatMap(myFilter).coflatMap(myFilter).coflatMap(myFilter).take(n))
p2.ylim = (0.0,1.0)
p2.title = "Smoothed with 5 applications of the linear filter"
fig.refresh

Image processing and the heat equation

Streaming data is in no way the only context in which a comonadic approach facilitates an elegant approach to scientific and statistical computing. Comonads crop up anywhere where we want to extend a computation that is local to a small part of a data structure to the full data structure. Another commonly cited area of application of comonadic approaches is image processing (I should acknowledge that this section of the post is very much influenced by a blog post on comonadic image processing in Haskell). However, the kinds of operations used in image processing are in many cases very similar to the operations used in finite difference approaches to numerical integration of partial differential equations (PDEs) such as the heat equation, so in this section I will blur (sic) the distinction between the two, and numerically integrate the 2D heat equation in order to Gaussian blur a noisy image.

First we need a simple image type which can have pixels of arbitrary type T (this is very important – all functors must be fully type polymorphic).

  import scala.collection.parallel.immutable.ParVector
  case class Image[T](w: Int, h: Int, data: ParVector[T]) {
    def apply(x: Int, y: Int): T = data(x*h+y)
    def map[S](f: T => S): Image[S] = Image(w, h, data map f)
    def updated(x: Int, y: Int, value: T): Image[T] =
      Image(w,h,data.updated(x*h+y,value))
  }

Here I’ve chosen to back the image with a parallel immutable vector. This wasn’t necessary, but since this type has a map operation which automatically parallelises over multiple cores, any map operations applied to the image will be automatically parallelised. This will ultimately lead to all of our statistical computations being automatically parallelised without us having to think about it.

As it stands, this image isn’t comonadic, since it doesn’t implement extract or coflatMap. Unlike the case of Stream, there isn’t really a uniquely privileged pixel, so it’s not clear what extract should return. For many data structures of this type, we make them comonadic by adding a “cursor” pointing to a “current” element of interest, and use this as the focus for computations applied with coflatMap. This is simplest to explain by example. We can define our “pointed” image type as follows:

  case class PImage[T](x: Int, y: Int, image: Image[T]) {
    def extract: T = image(x, y)
    def map[S](f: T => S): PImage[S] = PImage(x, y, image map f)
    def coflatMap[S](f: PImage[T] => S): PImage[S] = PImage(
      x, y, Image(image.w, image.h,
      (0 until (image.w * image.h)).toVector.par.map(i => {
        val xx = i / image.h
        val yy = i % image.h
        f(PImage(xx, yy, image))
      })))

A closing brace is missing here, as I’m not quite finished. Here x and y represent the location of our cursor, so extract returns the value of the pixel indexed by our cursor. Similarly, coflatMap forms an image where the value of the image at each location is the result of applying the function f to the image which had the cursor set to that location. Clearly f should use the cursor in some way, otherwise the image will have the same value at every pixel location. Note that map and coflatMap operations will be automatically parallelised. The intuitive idea behind coflatMap is that it extends local computations. For the stream example, the local computation was a linear combination of nearby values. Similarly, in image analysis problems, we often want to apply a linear filter to nearby pixels. We can get at the pixel at the cursor location using extract, but we probably also want to be able to move the cursor around to nearby locations. We can do that by adding some appropriate methods to complete the class definition.

    def up: PImage[T] = {
      val py = y-1
      val ny = if (py >= 0) py else (py + image.h)
      PImage(x,ny,image)
    }
    def down: PImage[T] = {
      val py = y+1
      val ny = if (py < image.h) py else (py - image.h)
      PImage(x,ny,image)
    }
    def left: PImage[T] = {
      val px = x-1
      val nx = if (px >= 0) px else (px + image.w)
      PImage(nx,y,image)
    }
    def right: PImage[T] = {
      val px = x+1
      val nx = if (px < image.w) px else (px - image.w)
      PImage(nx,y,image)
    }
  }

Here each method returns a new pointed image with the cursor shifted by one pixel in the appropriate direction. Note that I’ve used periodic boundary conditions here, which often makes sense for numerical integration of PDEs, but makes less sense for real image analysis problems. Note that we have embedded all “indexing” issues inside the definition of our classes. Now that we have it, none of the statistical algorithms that we develop will involve any explicit indexing. This makes it much less likely to develop algorithms containing bugs corresponding to “off-by-one” or flipped axis errors.
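To make the behaviour of the cursor and the wrap-around concrete, here is a cut-down sketch of the same idea on a tiny grid. It is not the class above: it is backed by a plain sequential Vector, the four movement methods are collapsed into one hypothetical move method using Math.floorMod, and the names PointedGridSketch, Grid and PGrid are invented for the example.

```scala
// A minimal pointed grid with periodic boundaries (illustrative names).
object PointedGridSketch {
  // Same layout as Image above: element (x, y) lives at index x*h + y.
  case class Grid[T](w: Int, h: Int, data: Vector[T]) {
    def apply(x: Int, y: Int): T = data(x * h + y)
  }

  case class PGrid[T](x: Int, y: Int, grid: Grid[T]) {
    def extract: T = grid(x, y)

    // Rebuild every cell by pointing the cursor at it and applying f.
    def coflatMap[S](f: PGrid[T] => S): PGrid[S] =
      PGrid(x, y, Grid(grid.w, grid.h,
        Vector.tabulate(grid.w * grid.h)(i =>
          f(PGrid(i / grid.h, i % grid.h, grid)))))

    // floorMod wraps negative and overflowing indices back into range.
    def move(dx: Int, dy: Int): PGrid[T] =
      PGrid(Math.floorMod(x + dx, grid.w), Math.floorMod(y + dy, grid.h), grid)
  }

  // A 3 x 2 grid in which the value stored at (x, y) is 10*x + y.
  val g: PGrid[Int] = PGrid(0, 0, Grid(3, 2, Vector(0, 1, 10, 11, 20, 21)))

  // Replace every cell by the value of its left-hand neighbour.
  val leftNeighbour: PGrid[Int] = g.coflatMap(p => p.move(-1, 0).extract)
}
```

The data of leftNeighbour is Vector(20, 21, 0, 1, 10, 11): the column x = 0 has wrapped around to pick up the values from x = 2, and the local function never mentions an index.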

This class is now fine for our requirements. But if we wanted Cats to understand that this structure is really a comonad (perhaps because we wanted to use derived methods, such as coflatten), we would need to provide evidence for this. The details aren’t especially important for this post, but we can do it simply as follows:

  implicit val pimageComonad = new Comonad[PImage] {
    def extract[A](wa: PImage[A]) = wa.extract
    def coflatMap[A,B](wa: PImage[A])(f: PImage[A] => B): PImage[B] =
      wa.coflatMap(f)
    def map[A,B](wa: PImage[A])(f: A => B): PImage[B] = wa.map(f)
  }

It’s handy to have some functions for converting Breeze dense matrices back and forth with our image class.

  import breeze.linalg.{Vector => BVec, _}
  def BDM2I[T](m: DenseMatrix[T]): Image[T] =
    Image(m.cols, m.rows, m.data.toVector.par)
  def I2BDM(im: Image[Double]): DenseMatrix[Double] =
    new DenseMatrix(im.h,im.w,im.data.toArray)

Now we are ready to see how to use this in practice. Let’s start by defining a very simple linear filter.

def fil(pi: PImage[Double]): Double = (2*pi.extract+
  pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract)/6.0

This simple filter can be used to “smooth” or “blur” an image. However, from a more sophisticated viewpoint, exactly this type of filter can be used to represent one time step of a numerical method for time integration of the 2D heat equation. Now we can simulate a noisy image and apply our filter to it using coflatMap:
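To see where the heat equation comes in, here is a sketch of the standard explicit finite difference scheme. The notation is introduced just for this aside: u^t_{x,y} is the pixel value at location (x,y) after t applications of the filter, \kappa is the diffusion coefficient, and \Delta t and \Delta x are the time step and grid spacing. The 2D heat equation is

\displaystyle \frac{\partial u}{\partial t} = \kappa \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right),

and replacing the derivatives with forward differences in time and central differences in space gives the update

\displaystyle u^{t+1}_{x,y} = u^t_{x,y} + r \left( u^t_{x,y-1} + u^t_{x,y+1} + u^t_{x-1,y} + u^t_{x+1,y} - 4u^t_{x,y} \right), \quad r = \frac{\kappa\,\Delta t}{\Delta x^2}.

Choosing r = 1/6 gives

\displaystyle u^{t+1}_{x,y} = \frac{1}{6}\left( 2u^t_{x,y} + u^t_{x,y-1} + u^t_{x,y+1} + u^t_{x-1,y} + u^t_{x+1,y} \right),

which is exactly the computation carried out by fil. This value of r is within the usual stability bound r \leq 1/4 for the explicit scheme in two dimensions.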

import breeze.stats.distributions.Gaussian
val bdm = DenseMatrix.tabulate(200,250){case (i,j) => math.cos(
  0.1*math.sqrt((i*i+j*j))) + Gaussian(0.0,2.0).draw}
val pim0 = PImage(0,0,BDM2I(bdm))
def pims = Stream.iterate(pim0)(_.coflatMap(fil))

Note that here, rather than just applying the filter once, I’ve generated an infinite stream of pointed images, each one representing an additional application of the linear filter. Thus the sequence represents the time solution of the heat equation with initial condition corresponding to our simulated noisy image.

We can render the first few frames to check that it seems to be working.

import breeze.plot._
val fig = Figure("Diffusing a noisy image")
pims.take(25).zipWithIndex.foreach{case (pim,i) => {
  val p = fig.subplot(5,5,i)
  p += image(I2BDM(pim.image))
}}

Note that the numerical integration is carried out in parallel on all available cores automatically. Other image filters can be applied, and other (parabolic) PDEs can be numerically integrated in an essentially similar way.

Gibbs sampling the Ising model

Another place where the concept of extending a local computation to a global computation crops up is in the context of Gibbs sampling a high-dimensional probability distribution by cycling through the sampling of each variable in turn from its full-conditional distribution. I’ll illustrate this here using the Ising model, so that I can reuse the pointed image class from above, but the principles apply to any Gibbs sampling problem. In particular, the Ising model that we consider has a conditional independence structure corresponding to a graph of a square lattice. As above, we will use the comonadic structure of the square lattice to construct a Gibbs sampler. However, we can construct a Gibbs sampler for arbitrary graphical models in an essentially identical way by using a graph comonad.

Let’s begin by simulating a random image containing +/-1s:

import breeze.stats.distributions.{Binomial,Bernoulli}
val beta = 0.4
val bdm = DenseMatrix.tabulate(500,600){
  case (i,j) => (new Binomial(1,0.2)).draw
}.map(_*2 - 1) // random matrix of +/-1s
val pim0 = PImage(0,0,BDM2I(bdm))

We can use this to initialise our Gibbs sampler. We now need a Gibbs kernel representing the update of each pixel.

def gibbsKernel(pi: PImage[Int]): Int = {
   val sum = pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract
   val p1 = math.exp(beta*sum)
   val p2 = math.exp(-beta*sum)
   val probplus = p1/(p1+p2)
   if (new Bernoulli(probplus).draw) 1 else -1
}
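The form of probplus follows from the full-conditional distribution. As a brief sketch, assuming the usual parameterisation of the Ising model, with x_i \in \{-1,+1\} and the sum taken over neighbouring pairs on the lattice, the joint distribution is

\displaystyle \pi(x) \propto \exp\left\{ \beta \sum_{i \sim j} x_i x_j \right\}.

Only the terms involving x_i matter for its full conditional, so writing s_i for the sum of the four neighbours of pixel i (notation introduced here), we get

\displaystyle P(x_i = +1 \mid x_{-i}) = \frac{e^{\beta s_i}}{e^{\beta s_i} + e^{-\beta s_i}} = \frac{1}{1 + e^{-2\beta s_i}}.

The middle expression is p1/(p1+p2) in the kernel above.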

So far so good, but there are a couple of issues that we need to consider before we plough ahead and start coflatMapping. The first is that pure functional programmers will object to the fact that this function is not pure. It is a stochastic function which has the side-effect of mutating the random number state. I’m just going to duck that issue here, as I’ve previously discussed how to fix it using probability monads, and I don’t want it to distract us here.

However, there is a more fundamental problem here relating to parallel versus sequential application of Gibbs kernels. coflatMap is conceptually parallel (irrespective of how it is implemented) in that all computations used to build the new comonad are based solely on the information available in the starting comonad. On the other hand, detailed balance of the Markov chain will only be preserved if the kernels for each pixel are applied sequentially. So if we coflatMap this kernel over the image we will break detailed balance. I should emphasise that this has nothing to do with the fact that I’ve implemented the pointed image using a parallel vector. Exactly the same issue would arise if we switched to backing the image with a regular (sequential) immutable Vector.

The trick here is to recognise that if we coloured alternate pixels black and white using a chequerboard pattern, then all of the black pixels are conditionally independent given the white pixels and vice-versa. Conditionally independent pixels can be updated by parallel application of a Gibbs kernel. So we just need separate kernels for updating odd and even pixels.

def oddKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 != 0) pi.extract else gibbsKernel(pi)
def evenKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 == 0) pi.extract else gibbsKernel(pi)

Each of these kernels can be coflatMapped over the image preserving detailed balance of the chain. So we can now construct an infinite stream of MCMC iterations as follows.

def pims = Stream.iterate(pim0)(_.coflatMap(oddKernel).
  coflatMap(evenKernel))
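One caveat, which follows from the periodic boundary conditions used by the pointed image, is that the chequerboard argument needs every neighbour of a pixel to have the opposite colour, including the neighbours reached by wrapping around an edge. The sketch below checks this directly for a given image size. The names ChequerboardSketch, neighbours and colouringIsValid are hypothetical and not part of the code above.

```scala
// Check that no pixel shares a colour with any of its four neighbours
// on a w x h lattice with periodic boundaries (illustrative names).
object ChequerboardSketch {
  // The four nearest neighbours of (x, y), wrapping at the edges.
  def neighbours(x: Int, y: Int, w: Int, h: Int): List[(Int, Int)] = List(
    (Math.floorMod(x - 1, w), y), (Math.floorMod(x + 1, w), y),
    (x, Math.floorMod(y - 1, h)), (x, Math.floorMod(y + 1, h)))

  // Colour is the parity of x + y, as in oddKernel and evenKernel.
  def colouringIsValid(w: Int, h: Int): Boolean =
    (0 until w).forall(x => (0 until h).forall(y =>
      neighbours(x, y, w, h).forall { case (nx, ny) =>
        (x + y) % 2 != (nx + ny) % 2 }))
}
```

The check passes for the 600 by 500 image used here, but fails for, say, a 5 by 6 image, where pixels in the first and last columns are wrap-around neighbours with the same colour. So with periodic boundaries it seems safest to keep both dimensions even.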

We can animate the first few iterations with:

import breeze.plot._
val fig = Figure("Ising model Gibbs sampler")
fig.width = 1000
fig.height = 800
pims.take(50).zipWithIndex.foreach{case (pim,i) => {
  print(s"$i ")
  fig.clear
  val p = fig.subplot(1,1,0)
  p.title = s"Ising model: frame $i"
  p += image(I2BDM(pim.image.map{_.toDouble}))
  fig.refresh
}}
println

Here I have a movie showing the first 1000 iterations. Note that YouTube seems to have over-compressed it, but you should get the basic idea.

Again, note that this MCMC sampler runs in parallel on all available cores, automatically. This issue of odd/even pixel updating emphasises another issue that crops up a lot in functional programming: very often, thinking about how to express an algorithm functionally leads to an algorithm which parallelises naturally. For general graphs, figuring out which groups of nodes can be updated in parallel is essentially the graph colouring problem. I’ve discussed this previously in relation to parallel MCMC in:

Wilkinson, D. J. (2005) Parallel Bayesian Computation, Chapter 16 in E. J. Kontoghiorghes (ed.) Handbook of Parallel Computing and Statistics, Marcel Dekker/CRC Press, 481-512.

Further reading

There are quite a few blog posts discussing comonads in the context of Haskell. In particular, the post on comonads for image analysis I mentioned previously, and this one on cellular automata. Bartosz’s post on comonads gives some connection back to the mathematical origins. Runar’s Scala comonad tutorial is the best source I know for comonads in Scala.

Full runnable code corresponding to this blog post is available from my blog repo.