Bayesian inference for a logistic regression model (Part 6)

Part 6: Hamiltonian Monte Carlo (HMC)

Introduction

This is the sixth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to construct an MCMC algorithm utilising gradient information by considering a Langevin equation having our target distribution of interest as its equilibrium. This equation has a physical interpretation in terms of the stochastic dynamics of a particle in a potential equal to minus the log of the target density. It turns out that thinking about the deterministic dynamics of a particle in such a potential can lead to more efficient MCMC algorithms.

Hamiltonian dynamics

Hamiltonian dynamics is often presented as an extension of a fairly general version of Lagrangian dynamics. However, for our purposes a rather simple version is quite sufficient, based on basic concepts from Newtonian dynamics, familiar from school. Inspired by our Langevin example, we will consider the dynamics of a particle in a potential function V(q). We will see later why we want V(q) = -\log \pi(q) for our target of interest, \pi(\cdot). In the context of Hamiltonian (and Lagrangian) dynamics we typically use q as our position variable, rather than x.

The potential function induces a (conservative) force on the particle equal to -\nabla V(q) when the particle is at position q. Then Newton’s second law of motion, "F=ma", takes the form

\displaystyle \nabla V(q) + m \ddot{q} = 0.

In Newtonian mechanics, we often consider the position vector q as 3-dimensional. Here it will be n-dimensional, where n is the number of variables in our target. We can then think of our second law as governing a single n-dimensional particle of mass m, or n one-dimensional particles all of mass m. But in this latter case, there is no need to assume that all particles have the same mass, and we could instead write our law of motion as

\displaystyle \nabla V(q) + M \ddot{q} = 0,

where M is a diagonal matrix. But in fact, since we could change coordinates, there’s no reason to require that M is diagonal. All we need is that M is positive definite, so that we don’t have negative mass in any coordinate direction.

We will take the above equation as the fundamental law governing our dynamical system of interest. The motivation from Newtonian dynamics is interesting, but not required. What is important is that the dynamics of such a system are conservative, in a way that we will shortly make precise.

Our law of motion is a second-order differential equation, since it involves the second derivative of q with respect to time. If you’ve ever studied differential equations, you’ll know that there is an easy way to turn a second-order equation into a first-order equation with twice the dimension by augmenting the system with the velocities. Here, it is more convenient to augment the system with "momentum" variables, p, which we define as p = M\dot{q}. Since M is constant, \dot{p} = M\ddot{q} = -\nabla V(q), and so we can write our second-order system as a pair of first-order equations

\displaystyle \dot{q} = M^{-1}p

\displaystyle \dot{p} = -\nabla V(q)

These are, in fact, Hamilton’s equations for this system, though this isn’t how they are typically written.

If we define the kinetic energy as

\displaystyle T(p) = \frac{1}{2}p^\text{T}M^{-1}p,

then the Hamiltonian

\displaystyle H(q,p) = V(q) + T(p),

representing the total energy in the system, is conserved, since

\displaystyle \dot{H} = \nabla V\cdot \dot{q} + \dot{p}^\text{T}M^{-1}p = \nabla V\cdot \dot{q} + \dot{p}^\text{T}\dot{q} = [\nabla V + \dot{p}]\cdot\dot{q} = 0.

So, if we obey our Hamiltonian dynamics, our trajectory in (q,p)-space will follow contours of the Hamiltonian. It’s also clear that the system is time-reversible, so flipping the sign of the momentum p and integrating will exactly reverse the direction in which the contours are traversed. Another quite important property of Hamiltonian dynamics is that they are volume preserving. This can be verified by checking that the divergence of the flow is zero.

\displaystyle \nabla\cdot(\dot{q},\dot{p}) = \nabla_q\cdot\dot{q} + \nabla_p\cdot\dot{p} = 0,

since \dot{q} is a function of p only and \dot{p} is a function of q only.
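
As a quick sanity check of these three properties, consider the simplest non-trivial case: V(q) = q^2/2 with M = 1 in one dimension. Hamilton's equations are then \dot{q} = p, \dot{p} = -q, which can be solved exactly, and the flow is a rotation of the (q,p)-plane. The following Python sketch checks conservation, reversibility and volume preservation numerically. The function names and starting values are purely illustrative and are not part of the code for this series.

```python
import math

def hamiltonian(q, p):
    # H(q, p) = V(q) + T(p) with V(q) = q^2/2 and M = 1
    return 0.5 * q * q + 0.5 * p * p

def flow(q, p, t):
    # Exact solution of q' = p, p' = -q after time t:
    # a rotation of the (q, p)-plane
    c, s = math.cos(t), math.sin(t)
    return (c * q + s * p, -s * q + c * p)

q0, p0, t = 1.3, -0.4, 2.5
q1, p1 = flow(q0, p0, t)

# 1. The Hamiltonian is conserved along the trajectory
energy_drift = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# 2. Time reversibility: flip the momentum, integrate for the same
#    time, and flip the momentum again to recover the starting point
qb, pb = flow(q1, -p1, t)
reversal_error = max(abs(qb - q0), abs(-pb - p0))

# 3. Volume preservation: the flow map is linear with Jacobian
#    [[c, s], [-s, c]], whose determinant is c^2 + s^2
c, s = math.cos(t), math.sin(t)
jacobian_det = c * c + s * s
```

All three quantities agree with the theory up to floating point error: the drift and the reversal error are essentially zero, and the Jacobian determinant is one.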

Hamiltonian Monte Carlo (HMC)

In Hamiltonian Monte Carlo we introduce an augmented target distribution,

\displaystyle \tilde \pi(q,p) \propto \exp[-H(q,p)]

It is clear from this definition that moves leaving the Hamiltonian invariant will also leave the augmented target density unchanged. By following the Hamiltonian dynamics, we will be able to make big (reversible) moves in the space that will be accepted with high probability. Also, our target factorises into two independent components as

\displaystyle \tilde \pi(q,p) \propto \exp[-V(q)]\exp[-T(p)],

and so choosing V(q)=-\log \pi(q) will ensure that the q-marginal is our real target of interest, \pi(\cdot). It’s also clear that our p-marginal is \mathcal N(0,M). This is also the full-conditional for p, so re-sampling p from this distribution and leaving q unchanged is a Gibbs move that will leave the augmented target invariant. Re-sampling p will be necessary to properly explore our augmented target, since this will move us to a different contour of H.

So, an idealised version of HMC would proceed as follows: First, update p by sampling from its known tractable marginal. Second, update p and q jointly by following the Hamiltonian dynamics. If this second move is regarded as a (deterministic) reversible M-H proposal, it will be accepted with probability one since it leaves the augmented target density unchanged. If we could exactly integrate Hamilton’s equations, this would be fine. In practice, however, we will need to use some imperfect numerical method for the integration step. Just as for MALA, we can regard the numerical method as a M-H proposal and correct for the fact that it is imperfect, preserving the exact augmented target distribution.

Hamiltonian systems admit nice numerical integration schemes called symplectic integrators. In HMC a simple alternating Euler method is typically used, known as the leap-frog algorithm. The component updates are all shear transformations, and therefore volume preserving, and exact reversibility is ensured by starting and ending with a half-step update of the momentum variables. In principle, to ensure reversibility of the proposal the momentum variables should be sign-flipped (reversed) to finish, but in practice this doesn’t matter since it doesn’t affect the evaluation of the Hamiltonian and it will then get refreshed, anyway.

So, advancing our system by a time step \epsilon can be done with

\displaystyle p(t+\epsilon/2) := p(t) - \frac{\epsilon}{2}\nabla V(q(t))

\displaystyle q(t+\epsilon) := q(t) + \epsilon M^{-1}p(t+\epsilon/2)

\displaystyle p(t+\epsilon) := p(t+\epsilon/2) - \frac{\epsilon}{2}\nabla V(q(t+\epsilon))

It is clear that if many such updates are chained together, adjacent momentum updates can be collapsed together, giving rise to the "leap-frog" nature of the algorithm, and therefore requiring roughly one gradient evaluation per \epsilon update, rather than two. This integrator is volume preserving and exactly reversible. For reasonably small \epsilon it follows the Hamiltonian dynamics reasonably well, but not exactly, and so it does not exactly preserve the Hamiltonian. However, it does make a good M-H proposal, and reasonable acceptance probabilities can often be obtained by chaining together l updates to advance the time of the system by T=l\epsilon. The "optimal" values of l and \epsilon will be highly problem dependent, but values of l=20 or l=50 are not unusual. There are various more-or-less standard methods for tuning these, but we will not consider them here.
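
To make the leap-frog update concrete, here is a minimal Python sketch for a one-dimensional standard normal target, so that V(q) = q^2/2 and \nabla V(q) = q, with unit mass by default. The names leapfrog, grad_V and H, the starting point, and the choices \epsilon = 0.1 and l = 20 are all illustrative assumptions rather than part of the code for this series.

```python
def grad_V(q):
    # V(q) = q^2/2, minus the log density of N(0, 1) up to a constant
    return q

def H(q, p, m=1.0):
    # Total energy: potential plus kinetic
    return 0.5 * q * q + 0.5 * p * p / m

def leapfrog(q, p, eps, l, m=1.0):
    # Initial half-step for the momentum
    p = p - 0.5 * eps * grad_V(q)
    for i in range(l):
        # Full step for the position
        q = q + eps * p / m
        # Adjacent momentum half-steps are collapsed into one full
        # step, except for the final half-step
        if i < l - 1:
            p = p - eps * grad_V(q)
        else:
            p = p - 0.5 * eps * grad_V(q)
    return q, p

q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, eps=0.1, l=20)

# The Hamiltonian is approximately, but not exactly, conserved
energy_error = abs(H(q1, p1) - H(q0, p0))

# Flipping the momentum and applying the same map retraces the path
qb, pb = leapfrog(q1, -p1, eps=0.1, l=20)
reversal_error = max(abs(qb - q0), abs(-pb - p0))
```

Here the energy error is small but non-zero, which is why a M-H correction is needed, whereas the reversal error is zero up to floating point rounding.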

Note that since our HMC update on the augmented space consists of a Gibbs move and a M-H update, it is important that our M-H kernel does not keep or thread through the old log target density from the previous M-H update, since the Gibbs move will have changed it in the meantime.
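
The complete update can be sketched in a few lines of plain Python for a one-dimensional N(0,1) target with unit mass. This is only an illustration under those assumptions: the function names, the seed and the tuning values \epsilon = 0.2 and l = 10 are arbitrary choices, not taken from the implementations in this series. Note that the augmented log density of the current state is evaluated afresh after the momentum has been re-sampled.

```python
import math
import random

def lpi(q):
    # Log density of the N(0, 1) target, up to an additive constant
    return -0.5 * q * q

def glpi(q):
    # Gradient of the log target
    return -q

def leapfrog(q, p, eps, l):
    p = p + 0.5 * eps * glpi(q)
    for i in range(l):
        q = q + eps * p
        p = p + (eps if i < l - 1 else 0.5 * eps) * glpi(q)
    return q, -p  # the sign flip makes the proposal reversible

def hmc_step(q, rng, eps=0.2, l=10):
    # Gibbs move: refresh the momentum from its N(0, 1) marginal
    p = rng.gauss(0.0, 1.0)
    # Evaluate the augmented log density afresh, after the refresh
    log_old = lpi(q) - 0.5 * p * p
    q_new, p_new = leapfrog(q, p, eps, l)
    log_new = lpi(q_new) - 0.5 * p_new * p_new
    # M-H accept/reject for the deterministic leap-frog proposal
    if math.log(1.0 - rng.random()) < log_new - log_old:
        return q_new, True
    return q, False

rng = random.Random(42)
q, draws, accepted = 3.0, [], 0
for _ in range(20000):
    q, acc = hmc_step(q, rng)
    draws.append(q)
    accepted += acc

mean = sum(draws) / len(draws)
var = sum((x - mean) ** 2 for x in draws) / len(draws)
accept_rate = accepted / len(draws)
```

For this easy target the sample mean and variance should be close to 0 and 1, and because the leap-frog integrator nearly conserves the Hamiltonian, almost all proposals are accepted.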

Implementations

R

We need a M-H kernel that does not thread through the old log density.

mhKernel = function(logPost, rprop)
    function(x) {
        prop = rprop(x)
        a = logPost(prop) - logPost(x)
        if (log(runif(1)) < a)
            prop
        else
            x
    }

We can then use this to construct a M-H move as part of our HMC update.

hmcKernel = function(lpi, glpi, eps = 1e-4, l=10, dmm = 1) {
    sdmm = sqrt(dmm)
    leapf = function(q, p) {
        p = p + 0.5*eps*glpi(q)
        for (i in 1:l) {
            q = q + eps*p/dmm
            if (i < l)
                p = p + eps*glpi(q)
            else
                p = p + 0.5*eps*glpi(q)
        }
        list(q=q, p=-p)
    }
    alpi = function(x)
        lpi(x$q) - 0.5*sum((x$p^2)/dmm)
    rprop = function(x)
        leapf(x$q, x$p)
    mhk = mhKernel(alpi, rprop)
    function(q) {
        d = length(q)
        x = list(q=q, p=rnorm(d, 0, sdmm))
        mhk(x)$q
    }
}

See the full runnable script for further details.

Python

First a M-H kernel,

def mhKernel(lpost, rprop):
    def kernel(x):
        prop = rprop(x)
        a = lpost(prop) - lpost(x)
        if (np.log(np.random.rand()) < a):
            x = prop
        return x
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l=10, dmm = 1):
    sdmm = np.sqrt(dmm)
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return (q, -p)
    def alpi(x):
        (q, p) = x
        return lpi(q) - 0.5*np.sum((p**2)/dmm)
    def rprop(x):
        (q, p) = x
        return leapf(q, p)
    mhk = mhKernel(alpi, rprop)
    def kern(q):
        d = len(q)
        p = np.random.randn(d)*sdmm
        return mhk((q, p))[0]
    return kern

See the full runnable script for further details.

JAX

Again, we want an appropriate M-H kernel,

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        ll = lpost(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x)
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l = 10, dmm = 1):
    sdmm = jnp.sqrt(dmm)
    @jit
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return jnp.concatenate((q, -p))
    @jit
    def alpi(x):
        d = len(x) // 2
        return lpi(x[jnp.array(range(d))]) - 0.5*jnp.sum((x[jnp.array(range(d,2*d))]**2)/dmm)
    @jit
    def rprop(k, x):
        d = len(x) // 2
        return leapf(x[jnp.array(range(d))], x[jnp.array(range(d, 2*d))])
    mhk = mhKernel(alpi, rprop)
    @jit
    def kern(k, q):
        key0, key1 = jax.random.split(k)
        d = len(q)
        x = jnp.concatenate((q, jax.random.normal(key0, [d])*sdmm))
        return mhk(key1, x)[jnp.array(range(d))]
    return kern

There is something a little bit strange about this implementation: since the proposal for the M-H move is deterministic, the function rprop just ignores the RNG key that is passed to it. We could tidy this up by making a M-H function especially for deterministic proposals. We won’t pursue this here, but this issue will crop up again later in some of the other functional languages.

See the full runnable script for further details.

Scala

A M-H kernel,

def mhKern[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): (S) => S =
    val r = Uniform(0.0,1.0)
    x0 =>
      val x = rprop(x0)
      val ll0 = logPost(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a) x else x0

and an HMC kernel.

def hmcKernel(lpi: DVD => Double, glpi: DVD => DVD, dmm: DVD,
  eps: Double = 1e-4, l: Int = 10) =
  val sdmm = sqrt(dmm)
  def leapf(q: DVD, p: DVD): (DVD, DVD) = 
    @tailrec def go(q0: DVD, p0: DVD, l: Int): (DVD, DVD) =
      val q = q0 + eps*(p0/:/dmm)
      val p = if (l > 1)
        p0 + eps*glpi(q)
      else
        p0 + 0.5*eps*glpi(q)
      if (l == 1)
        (q, -p)
      else
        go(q, p, l-1)
    go(q, p + 0.5*eps*glpi(q), l)
  def alpi(x: (DVD, DVD)): Double =
    val (q, p) = x
    lpi(q) - 0.5*sum(pow(p,2) /:/ dmm)
  def rprop(x: (DVD, DVD)): (DVD, DVD) =
    val (q, p) = x
    leapf(q, p)
  val mhk = mhKern(alpi, rprop)
  (q: DVD) =>
    val d = q.length
    val p = sdmm map (sd => Gaussian(0,sd).draw())
    mhk((q, p))._1

See the full runnable script for further details.

Haskell

A M-H kernel:

mdKernel :: (StatefulGen g m) => (s -> Double) -> (s -> s) -> g -> s -> m s
mdKernel logPost prop g x0 = do
  let x = prop x0
  let ll0 = logPost x0
  let ll = logPost x
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then x
        else x0
  return next

Note that here we are using a M-H kernel specifically for deterministic proposals, since there is no non-determinism signalled in the type signature of prop. We can then use this to construct our HMC kernel.

hmcKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> Vector Double ->
  Double -> Int -> g ->
  Vector Double -> m (Vector Double)
hmcKernel lpi glpi dmm eps l g = let
  sdmm = cmap sqrt dmm
  leapf q p = let
    go q0 p0 l = let
      q = q0 + (scalar eps)*p0/dmm
      p = if (l > 1)
        then p0 + (scalar eps)*(glpi q)
        else p0 + (scalar (eps/2))*(glpi q)
      in if (l == 1)
      then (q, -p)
      else go q p (l - 1)
    in go q (p + (scalar (eps/2))*(glpi q)) l
  alpi x = let
    (q, p) = x
    in (lpi q) - 0.5*(sumElements (p*p/dmm))
  prop x = let
    (q, p) = x
    in leapf q p
  mk = mdKernel alpi prop g
  in (\q0 -> do
         let d = size q0
         zl <- (replicateM d . genContVar (normalDistr 0.0 1.0)) g
         let z = fromList zl
         let p0 = sdmm * z
         (q, p) <- mk (q0, p0)
         return q)

See the full runnable script for further details.

Dex

Again we can use a M-H kernel specific to deterministic proposals.

def mdKernel {s} (lpost: s -> Float) (prop: s -> s)
    (x0: s) (k: Key) : s =
  x = prop x0
  ll0 = lpost x0
  ll = lpost x
  a = ll - ll0
  u = rand k
  select (log u < a) x x0

and use this to construct an HMC kernel.

def hmcKernel {n} (lpi: (Fin n)=>Float -> Float)
    (dmm: (Fin n)=>Float) (eps: Float) (l: Nat)
    (q0: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdmm = sqrt dmm
  idmm = map (\x. 1.0/x) dmm
  glpi = grad lpi
  def leapf (q0: (Fin n)=>Float) (p0: (Fin n)=>Float) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    p1 = p0 + (eps/2) .* (glpi q0)
    q1 = q0 + eps .* (p1*idmm)
    (q, p) = apply_n l (q1, p1) \(qo, po).
      pn = po + eps .* (glpi qo)
      qn = qo + eps .* (pn*idmm)
      (qn, pn)
    pf = p + (eps/2) .* (glpi q)
    (q, -pf)
  def alpi (qp: ((Fin n)=>Float & (Fin n)=>Float)) : Float =
    (q, p) = qp
    (lpi q) - 0.5*(sum (p*p*idmm))
  def prop (qp: ((Fin n)=>Float & (Fin n)=>Float)) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    (q, p) = qp
    leapf q p
  mk = mdKernel alpi prop
  [k1, k2] = split_key k
  z = randn_vec k1
  p0 = sdmm * z
  (q, p) = mk (q0, p0) k2
  q

Note that the gradient is obtained via automatic differentiation. See the full runnable script for details.

Next steps

This was the main place that I was trying to get to when I started this series of posts. For differentiable log-posteriors (as we have in the case of Bayesian logistic regression), HMC is a pretty good algorithm for reasonably efficient posterior exploration. But there are lots of places we could go from here. We could explore the tuning of MCMC algorithms, or HMC extensions such as NUTS. We could look at MCMC algorithms that are specifically tailored to the logistic regression problem, or we could look at new MCMC algorithms for differentiable targets based on piecewise deterministic Markov processes. Alternatively, we could temporarily abandon MCMC and look at SMC or ABC approaches. Another possibility would be to abandon this multi-language approach and have a bit of a deep dive into Dex, which I think has the potential to be a great programming language for statistical computing. All of these are possibilities for the future, but I’ve a busy few weeks coming up, so the frequency of these posts is likely to substantially decrease.

Remember that all of the code associated with this series of posts is available from this GitHub repo.


Bayesian inference for a logistic regression model (Part 5)

Part 5: the Metropolis-adjusted Langevin algorithm (MALA)

Introduction

This is the fifth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to use Langevin dynamics to construct an approximate MCMC scheme using the gradient of the log target distribution. Each step of the algorithm involved simulating from the Euler-Maruyama approximation to the transition kernel of the process, based on some pre-specified step size, \Delta t. We can improve the accuracy of this approximation by making the step size smaller, but this will come at the expense of a more slowly mixing MCMC chain.

Fortunately, there is an easy way to make the algorithm "exact" (in the sense that the equilibrium distribution of the Markov chain will be the exact target distribution), for any finite step size, \Delta t, simply by using the Euler-Maruyama approximation as the proposal distribution in a Metropolis-Hastings algorithm. This is the Metropolis-adjusted Langevin algorithm (MALA). There are various ways this could be coded up, but here, for clarity, a higher-order function (HoF) for generating a MALA kernel will be used, and this function will in turn call on a HoF for generating a Metropolis-Hastings kernel.
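
Spelling this out may help when reading the code. Writing A for the diagonal pre-conditioning matrix (pre in the code) and x^\star for the proposed value, the proposal density is

\displaystyle q(x^\star|x) = \mathcal{N}\left(x^\star;\ x+\frac{1}{2}A\nabla\log\pi(x)\Delta t,\ A\Delta t\right),

and the usual Metropolis-Hastings acceptance probability for a move from x to x^\star is

\displaystyle \alpha(x^\star|x) = \min\left\{1, \frac{\pi(x^\star)\,q(x|x^\star)}{\pi(x)\,q(x^\star|x)}\right\}.

Unlike a random walk proposal, this q is not symmetric, so the proposal densities do not cancel. In the M-H kernels below, dprop(new, old) returns \log q(\text{new}|\text{old}), and the variable a is the log of the ratio above,

\displaystyle a = \log\pi(x^\star) - \log\pi(x) + \log q(x|x^\star) - \log q(x^\star|x).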

Implementations

R

First we need a function to generate a M-H kernel.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Then we can easily write a function for returning a MALA kernel that makes use of this M-H function.

malaKernel = function(lpi, glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    mhKernel(lpi, function(x) rnorm(length(x), advance(x), spre*sdt),
             function(new, old) sum(dnorm(new, advance(old), spre*sdt, log=TRUE)))
}

Notice that our MALA function requires as input both the gradient of the log posterior (for the proposal) and the log posterior itself (for the M-H correction). Other details are as we have already seen – see the full runnable script.

Python

Again, we need a M-H kernel

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

and then a MALA kernel

def malaKernel(lpi, glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    return mhKernel(lpi, lambda x: advance(x) + np.random.randn(len(x))*spre*sdt,
            lambda new, old: np.sum(sp.stats.norm.logpdf(new, loc=advance(old), scale=spre*sdt)))

See the full runnable script for further details.

JAX

If we want our algorithm to run fast, and if we want to exploit automatic differentiation to avoid the need to manually compute gradients, then we can easily convert the above code to use JAX.

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

def malaKernel(lpi, dt = 1e-4, pre = 1):
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    return mhKernel(lpi, jit(lambda k, x: advance(x) +
                             jax.random.normal(k, x.shape)*spre*sdt),
            jit(lambda new, old:
                jnp.sum(jsp.stats.norm.logpdf(new,
                      loc=advance(old), scale=spre*sdt))))

See the full runnable script for further details.

Scala

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

def malaKernel(lpi: DVD => Double, glpi: DVD => DVD, pre: DVD, dt: Double = 1e-4) =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  val p = pre.length
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glpi(beta))
  def rprop(beta: DVD): DVD =
    advance(beta) + sdt*spre.map(Gaussian(0,_).sample())
  def dprop(n: DVD, o: DVD): Double = 
    val ao = advance(o)
    (0 until p).map(i => Gaussian(ao(i), spre(i)*sdt).logPdf(n(i))).sum
  mhKernel(lpi, rprop, dprop)

See the full runnable script for further details.

Haskell

mhKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) ->
  (s -> s -> Double) -> g -> (s, Double) -> m (s, Double)
mhKernel logPost rprop dprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

malaKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> 
  Vector Double -> Double -> g ->
  (Vector Double, Double) -> m (Vector Double, Double)
malaKernel lpi glpi pre dt g = let
  sdt = sqrt dt
  spre = cmap sqrt pre
  p = size pre
  advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  rprop beta g = do
    zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
    let z = fromList zl
    return $ advance(beta) + (scalar sdt)*spre*z
  dprop n o = let
    ao = advance o
    in sum $ (\i -> logDensity (normalDistr (ao!i) 
      ((spre!i)*sdt)) (n!i)) <$> [0..(p-1)]
  in mhKernel lpi rprop dprop g

See the full runnable script for further details.

Dex

Recall that Dex is differentiable, so we don’t need to provide gradients.

def mhKernel {s} (lpost: s -> Float) (rprop: s -> Key -> s) (dprop: s -> s -> Float)
    (sll: (s & Float)) (k: Key) : (s & Float) =
  (x0, ll0) = sll
  [k1, k2] = split_key k
  x = rprop x0 k1
  ll = lpost x
  a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u = rand k2
  select (log u < a) (x, ll) (x0, ll0)

def malaKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float) :
    ((Fin n)=>Float & Float) -> Key -> ((Fin n)=>Float & Float) =
  sdt = sqrt dt
  spre = sqrt pre
  glp = grad lpi
  v = dt .* pre
  vinv = map (\ x. 1.0/x) v
  def advance (beta: (Fin n)=>Float) : (Fin n)=>Float =
    beta + (0.5*dt) .* (pre*(glp beta))
  def rprop (beta: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
    (advance beta) + sdt .* (spre*(randn_vec k))
  def dprop (new: (Fin n)=>Float) (old: (Fin n)=>Float) : Float =
    ao = advance old
    diff = new - ao
    -0.5 * sum ((log v) + diff*diff*vinv)
  mhKernel lpi rprop dprop

See the full runnable script for further details.
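
To see the effect of the M-H correction in isolation, here is a small dependency-free Python sketch of MALA for a one-dimensional N(0,1) target, using a deliberately large step size of \Delta t = 0.5. Everything in it (function names, seed, step size, chain length) is an illustrative assumption and is separate from the implementations above. Because the proposal variance is the same in both directions, the normalising constant of the proposal density cancels and is omitted.

```python
import math
import random

def lpi(x):
    # Log density of N(0, 1), up to an additive constant
    return -0.5 * x * x

def advance(x, dt):
    # Deterministic drift step; here grad log pi(x) = -x
    return x + 0.5 * (-x) * dt

def log_q(new, old, dt):
    # Log density of the N(advance(old), dt) proposal, up to a constant
    return -0.5 * (new - advance(old, dt)) ** 2 / dt

def mala_step(x, dt, rng):
    prop = advance(x, dt) + math.sqrt(dt) * rng.gauss(0.0, 1.0)
    # Log acceptance ratio, including the asymmetric proposal terms
    a = lpi(prop) - lpi(x) + log_q(x, prop, dt) - log_q(prop, x, dt)
    return prop if math.log(1.0 - rng.random()) < a else x

dt = 0.5
rng = random.Random(1)
x, draws = 0.0, []
for _ in range(200000):
    x = mala_step(x, dt, rng)
    draws.append(x)

mean = sum(draws) / len(draws)
sample_var = sum((d - mean) ** 2 for d in draws) / len(draws)
```

Despite the coarse step size, the sample variance should be close to the true value of 1, since the accept/reject step removes the discretisation bias.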

Next steps

MALA gives us an MCMC algorithm that exploits gradient information to generate "informed" M-H proposals. But it still has a rather "diffusive" character, making it difficult to tune in such a way that large moves are likely to be accepted in challenging high-dimensional situations.

The Langevin dynamics on which MALA is based can be interpreted as the (over-damped) stochastic dynamics of a particle moving in a potential energy field corresponding to minus the log posterior. It turns out that the corresponding deterministic dynamics can be exploited to generate proposals better able to make large moves while still having a high probability of acceptance. This is the idea behind Hamiltonian Monte Carlo (HMC), which we’ll look at next.

Bayesian inference for a logistic regression model (Part 4)

Part 4: Gradients and the Langevin algorithm

Introduction

This is the fourth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how the Metropolis algorithm could be used to generate a Markov chain targeting our posterior distribution. In high dimensions the diffusive nature of the Metropolis random walk proposal becomes increasingly inefficient. It is therefore natural to try and develop algorithms that use additional information about the target distribution. In the case of a differentiable log posterior target, a natural first step in this direction is to try and make use of gradient information.

Gradient of a logistic regression model

There are various ways to derive the gradient of our logistic regression model, but it might be simplest to start from the first form of the log likelihood that we deduced in Part 2:

\displaystyle l(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta])

We can write this out in component form as

\displaystyle l(\beta;y) = \sum_i\sum_j y_iX_{ij}\beta_j - \sum_i\log\left(1+\exp\left[\sum_jX_{ij}\beta_j\right]\right).

Differentiating wrt \beta_k gives

\displaystyle \frac{\partial l}{\partial \beta_k} = \sum_i y_iX_{ik} - \sum_i \frac{\exp\left[\sum_j X_{ij}\beta_j\right]X_{ik}}{1+\exp\left[\sum_j X_{ij}\beta_j\right]}.

It’s then reasonably clear that stitching all of the partial derivatives together will give the gradient vector

\displaystyle \nabla l = X^\textsf{T}\left[ y - \frac{\mathbf{1}}{\mathbf{1}+\exp[-X\beta]} \right].

This is the gradient of the log likelihood, but we also need the gradient of the log prior. Since we are assuming independent \beta_i \sim N(0,v_i) priors, it is easy to see that the gradient of the log prior is just -\beta\circ v^{-1}. It is the sum of these two terms that gives the gradient of the log posterior.
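
Since it is easy to make sign or indexing errors in a derivation like this, it is worth checking the result against finite differences. The Python sketch below does this on a tiny made-up data set, using only the standard library. The data, the prior variances and the function names are all invented for illustration and have nothing to do with the data set used in this series.

```python
import math

# Tiny made-up data set: 4 observations, intercept plus one covariate
X = [[1.0, 0.5], [1.0, -1.2], [1.0, 2.0], [1.0, 0.1]]
y = [1.0, 0.0, 1.0, 0.0]
v = [100.0, 1.0]  # prior variances v_i (illustrative values)

def linpred(beta):
    # The linear predictor X beta
    return [sum(xij * bj for xij, bj in zip(row, beta)) for row in X]

def lpost(beta):
    # Log posterior up to a constant: log likelihood plus log prior
    ll = sum(yi * e - math.log1p(math.exp(e))
             for yi, e in zip(y, linpred(beta)))
    lprior = -0.5 * sum(b * b / vi for b, vi in zip(beta, v))
    return ll + lprior

def glp(beta):
    # Analytic gradient: X^T (y - 1/(1 + exp(-X beta))) - beta/v
    resid = [yi - 1.0 / (1.0 + math.exp(-e))
             for yi, e in zip(y, linpred(beta))]
    gll = [sum(X[i][k] * resid[i] for i in range(len(X)))
           for k in range(len(beta))]
    return [g - b / vi for g, b, vi in zip(gll, beta, v)]

def numeric_grad(f, beta, h=1e-6):
    # Central finite differences, one coordinate at a time
    g = []
    for k in range(len(beta)):
        up, dn = list(beta), list(beta)
        up[k] += h
        dn[k] -= h
        g.append((f(up) - f(dn)) / (2 * h))
    return g

beta = [0.3, -0.7]
max_diff = max(abs(a - n)
               for a, n in zip(glp(beta), numeric_grad(lpost, beta)))
```

The largest discrepancy max_diff should be tiny, limited only by the accuracy of the finite difference approximation.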

R

In R we can implement our gradient function as

glp = function(beta) {
    glpr = -beta/(pscale*pscale)
    gll = as.vector(t(X) %*% (y - 1/(1 + exp(-X %*% beta))))
    glpr + gll
}

Python

In Python we could use

def glp(beta):
    glpr = -beta/(pscale*pscale)
    gll = (X.T).dot(y - 1/(1 + np.exp(-X.dot(beta))))
    return (glpr + gll)

We don’t really need a JAX version, since JAX can auto-diff the log posterior for us.

Scala

  def glp(beta: DVD): DVD =
    val glpr = -beta /:/ pvar
    val gll = (X.t)*(y - ones/:/(ones + exp(-X*beta)))
    glpr + gll

Haskell

Using hmatrix we could use something like

glp :: Matrix Double -> Vector Double -> Vector Double -> Vector Double
glp x y b = let
  glpr = -b / (fromList [100.0, 1, 1, 1, 1, 1, 1, 1])
  gll = (tr x) #> (y - (scalar 1)/((scalar 1) + (cmap exp (-x #> b))))
  in glpr + gll

There’s something interesting to say about Haskell and auto-diff, but getting into this now will be too much of a distraction. I may come back to it in some future post.

Dex

Dex is differentiable, so we don’t need a gradient function – we can just use grad lpost. However, for interest and comparison purposes we could nevertheless implement it directly with something like

prscale = map (\ x. 1.0/x) pscale

def glp (b: (Fin 8)=>Float) : (Fin 8)=>Float =
  glpr = -b*prscale*prscale
  gll = (transpose x) **. (y - (map (\eta. 1.0/(1.0 + eta)) (exp (-x **. b))))
  glpr + gll

Langevin diffusions

Now that we have a way of computing the gradient of the log of our target density we need some MCMC algorithms that can make good use of it. In this post we will look at a simple approximate MCMC algorithm derived from an overdamped Langevin diffusion model. In subsequent posts we’ll look at more sophisticated, exact MCMC algorithms.

The multivariate stochastic differential equation (SDE)

\displaystyle dX_t = \frac{1}{2}\nabla\log\pi(X_t)dt + dW_t

has \pi(\cdot) as its equilibrium distribution. Informally, an SDE of this form is a continuous time process with infinitesimal transition kernel

\displaystyle X_{t+dt}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}\nabla\log\pi(x_t)dt,\mathbf{I}dt\right).

There are various more-or-less formal ways to see that \pi(\cdot) is stationary. A good way is to check it satisfies the Fokker–Planck equation with zero LHS. A less formal approach would be to see that the infinitesimal transition kernel for the process satisfies detailed balance with \pi(\cdot).
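
To sketch the first of these checks, write \rho(x,t) for the density of X_t. For this SDE the Fokker–Planck equation is

\displaystyle \frac{\partial \rho}{\partial t} = -\nabla\cdot\left[\frac{1}{2}\rho\nabla\log\pi\right] + \frac{1}{2}\nabla^2\rho.

Substituting \rho=\pi and using \pi\nabla\log\pi = \nabla\pi, the right-hand side becomes

\displaystyle -\frac{1}{2}\nabla\cdot\nabla\pi + \frac{1}{2}\nabla^2\pi = 0,

so a process started with distribution \pi(\cdot) retains that distribution for all time.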

Similar arguments show that for any fixed positive definite matrix A, the SDE

\displaystyle dX_t = \frac{1}{2}A\nabla\log\pi(X_t)dt + A^{1/2}dW_t

also has \pi(\cdot) as a stationary distribution. It is quite common to choose a diagonal matrix A to put the components of X_t on a common scale.

The unadjusted Langevin algorithm

Simulating exact sample paths from SDEs such as the overdamped Langevin diffusion model is typically difficult (though not necessarily impossible), so we instead want something simple and tractable as the basis of our MCMC algorithms. Here we will just simulate from the Euler–Maruyama approximation of the process by choosing a small but finite time step \Delta t and using the transition kernel

\displaystyle X_{t+\Delta t}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}A\nabla\log\pi(x_t)\Delta t, A\Delta t\right)

as the basis of our MCMC method. For sufficiently small \Delta t this should accurately approximate the Langevin dynamics, leading to an equilibrium distribution very close to \pi(\cdot). That said, we would like to choose \Delta t as large as we can get away with, since that will lead to a more rapidly mixing MCMC chain. Below are some implementations of this kernel for a diagonal pre-conditioning matrix.

Implementation

R

We can create a kernel for the unadjusted Langevin algorithm in R with the following function.

ulKernel = function(glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    function(x, ll) rnorm(length(x), advance(x), spre*sdt)
}

Here, we can pass in pre, which is expected to be a vector representing the diagonal of the pre-conditioning matrix, A. We can then use this kernel to generate an MCMC chain as we have seen previously. See the full runnable script for further details.

Python

def ulKernel(glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    def kernel(x):
        return advance(x) + np.random.randn(len(x))*spre*sdt
    return kernel

See the full runnable script for further details.

JAX

def ulKernel(lpi, dt = 1e-4, pre = 1):
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    @jit
    def kernel(key, x):
        return advance(x) + jax.random.normal(key, x.shape)*spre*sdt
    return kernel

Note how for JAX we can just pass in the log posterior, and the gradient function can be obtained by automatic differentiation. See the full runnable script for further details.

Scala

def ulKernel(glp: DVD => DVD, pre: DVD, dt: Double): DVD => DVD =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glp(beta))
  beta => advance(beta) + sdt*spre.map(Gaussian(0,_).sample())

See the full runnable script for further details.

Haskell

ulKernel :: (StatefulGen g m) =>
  (Vector Double -> Vector Double) -> Vector Double -> Double -> g ->
  Vector Double -> m (Vector Double)
ulKernel glpi pre dt g beta = do
  let sdt = sqrt dt
  let spre = cmap sqrt pre
  let p = size pre
  let advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
  let z = fromList zl
  return $ advance(beta) + (scalar sdt)*spre*z

See the full runnable script for further details.

Dex

In Dex we can write a function that accepts a gradient function

def ulKernel {n} (glpi: (Fin n)=>Float -> (Fin n)=>Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdt = sqrt dt
  spre = sqrt pre
  b + (((0.5)*dt) .* (pre*(glpi b))) +
    (sdt .* (spre*(randn_vec k)))

or we can write a function that accepts a log posterior, and uses auto-diff to construct the gradient

def ulKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  glpi = grad lpi
  sdt = sqrt dt
  spre = sqrt pre
  b + ((0.5)*dt) .* (pre*(glpi b)) +
    sdt .* (spre*(randn_vec k))

and since Dex is statically typed, we can’t easily mix these functions up.

See the full runnable scripts, without and with auto-diff.

Next steps

In this post we have seen how to construct an MCMC algorithm that makes use of gradient information. But this algorithm is approximate. In the next post we’ll see how to correct for the approximation by using the Langevin updates as proposals within a Metropolis-Hastings algorithm.

Bayesian inference for a logistic regression model (Part 3)

Part 3: The Metropolis algorithm

Introduction

This is the third part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we derived the log posterior for the model and implemented it in a variety of programming languages and libraries. In this post we will construct a Markov chain having the posterior as its equilibrium.

MCMC

Detailed balance

A homogeneous Markov chain with transition kernel p(\theta_{n+1}|\theta_n) is said to satisfy detailed balance for some target distribution \pi(\theta) if

\displaystyle \pi(\theta)p(\theta'|\theta) = \pi(\theta')p(\theta|\theta'), \quad \forall \theta, \theta'

Integrating both sides with respect to \theta, and using the fact that p(\theta|\theta') integrates to one over \theta, gives

\displaystyle \int_\Theta \pi(\theta)p(\theta'|\theta)d\theta = \pi(\theta'),

from which it is clear that \pi(\cdot) is a stationary distribution of the chain (and the chain is reversible). Under fairly mild regularity conditions we expect \pi(\cdot) to be the equilibrium distribution of the chain.

For a given target \pi(\cdot) we would like to find an easy-to-sample-from transition kernel p(\cdot|\cdot) that satisfies detailed balance. This will then give us a way to (asymptotically) generate samples from our target.

In the context of Bayesian inference, the target \pi(\theta) will typically be the posterior distribution, which in the previous post we wrote as \pi(\theta|y). Here we drop the notational dependence on y, since MCMC can be used for any target distribution of interest.

Metropolis-Hastings

Suppose we have a fairly arbitrary easy-to-sample-from transition kernel q(\theta_{n+1}|\theta_n) and a target of interest, \pi(\cdot). Metropolis-Hastings (M-H) is a strategy for using q(\cdot|\cdot) to construct a new transition kernel p(\cdot|\cdot) satisfying detailed balance for \pi(\cdot).

The kernel p(\theta'|\theta) can be described algorithmically as follows:

  1. Call the current state of the chain \theta. Generate a proposal \theta^\star by simulating from q(\theta^\star|\theta).
  2. Compute the acceptance probability

\displaystyle \alpha(\theta^\star|\theta) = \min\left[1,\frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}\right].

  3. With probability \alpha(\theta^\star|\theta) return new state \theta'=\theta^\star, otherwise return \theta'=\theta.

It is clear from the algorithmic description that this kernel will have a point mass at \theta'=\theta, but that for \theta'\not=\theta the transition kernel will be p(\theta'|\theta)=q(\theta'|\theta)\alpha(\theta'|\theta). But then

\displaystyle \pi(\theta)p(\theta'|\theta) = \min[\pi(\theta)q(\theta'|\theta),\pi(\theta')q(\theta|\theta')]

is symmetric in \theta and \theta', and so detailed balance is satisfied. Since detailed balance is trivial at the point mass at \theta=\theta' we are done.
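The argument above can be checked numerically on a finite state space, where the kernel is just a matrix. This is a small sketch in plain Python, assuming a made-up four-state target and an arbitrary asymmetric proposal matrix; mh_transition_matrix is a name invented for this illustration.

```python
def mh_transition_matrix(pi, q):
    """Build the M-H transition matrix P for a finite state space.

    pi: target probabilities (only needed up to a constant).
    q:  proposal matrix, q[i][j] = probability of proposing j from i.
    """
    n = len(pi)
    P = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j and q[i][j] > 0.0:
                alpha = min(1.0, (pi[j] * q[j][i]) / (pi[i] * q[i][j]))
                P[i][j] = q[i][j] * alpha
        P[i][i] = 1.0 - sum(P[i])  # point mass: rejected proposals stay put
    return P

pi = [0.1, 0.2, 0.3, 0.4]
# An asymmetric proposal; each row sums to one
q = [[0.0, 0.5, 0.3, 0.2],
     [0.6, 0.0, 0.2, 0.2],
     [0.1, 0.1, 0.0, 0.8],
     [0.3, 0.3, 0.4, 0.0]]
P = mh_transition_matrix(pi, q)

# Detailed balance: pi_i P_ij == pi_j P_ji for all pairs
db_gap = max(abs(pi[i] * P[i][j] - pi[j] * P[j][i])
             for i in range(4) for j in range(4))
# Stationarity: (pi P)_j == pi_j
stat_gap = max(abs(sum(pi[i] * P[i][j] for i in range(4)) - pi[j])
               for j in range(4))
print(db_gap, stat_gap)
```

Both gaps should be zero up to floating point rounding, whatever valid proposal matrix is used.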

Metropolis algorithm

It is often convenient to generate proposals perturbatively, using a distribution that is symmetric about the current state of the chain. But then q(\theta'|\theta)=q(\theta|\theta'), and so q(\cdot|\cdot) drops out of the acceptance probability. This is the Metropolis algorithm.
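To see the cancellation concretely, here is a short sketch in plain Python comparing the log proposal ratio for a Gaussian random walk with that of an asymmetric alternative. The "shrink" proposal, which centres the proposal on 0.9 times the current state, is a hypothetical example chosen purely for contrast, and all function names are invented for this illustration.

```python
import math

def log_norm_pdf(x, mean, sd):
    """Log density of N(mean, sd^2) evaluated at x."""
    z = (x - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def log_q_ratio(dprop, theta, theta_star):
    """log q(theta | theta_star) - log q(theta_star | theta)."""
    return dprop(theta, theta_star) - dprop(theta_star, theta)

def rw(new, old):
    """Random walk proposal: centred on the current state, hence symmetric."""
    return log_norm_pdf(new, old, 0.5)

def shrink(new, old):
    """Hypothetical asymmetric proposal: centred on a shrunken current state."""
    return log_norm_pdf(new, 0.9 * old, 0.5)

print(log_q_ratio(rw, 1.0, 1.7))      # zero: the q terms cancel
print(log_q_ratio(shrink, 1.0, 1.7))  # non-zero: the correction is needed
```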

Some computational tricks

To generate an event with probability \alpha(\theta^\star|\theta), we can generate a u\sim U(0,1) and accept if u < \alpha(\theta^\star|\theta). This is convenient for several reasons. First, it means that we can ignore the "min", and just accept if

\displaystyle u < \frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}

since u\leq 1 regardless. Better still, we can take logs, and accept if

\displaystyle \log u < \log\pi(\theta^\star) - \log\pi(\theta) + \log q(\theta|\theta^\star) - \log q(\theta^\star|\theta),

so there is no need to evaluate any raw densities. Again, in the case of a symmetric proposal distribution, the q(\cdot|\cdot) terms can be dropped.
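The practical reason for working on the log scale is numerical underflow. The sketch below, in plain Python, assumes a symmetric proposal and made-up log density values of the size one might see with a reasonable amount of data; accept_raw and accept_log are names invented for this illustration.

```python
import math

def accept_raw(u, pi_new, pi_old):
    """Acceptance test on raw densities (symmetric proposal)."""
    return u < pi_new / pi_old

def accept_log(u, lp_new, lp_old):
    """The same test on the log scale; u must be strictly positive."""
    return math.log(u) < lp_new - lp_old

# Hypothetical log densities of a plausible magnitude
lp_old, lp_new = -1000.0, -999.0

print(math.exp(lp_old), math.exp(lp_new))  # both underflow to zero
print(accept_log(0.5, lp_new, lp_old))     # the log scale test still works
try:
    accept_raw(0.5, math.exp(lp_new), math.exp(lp_old))
except ZeroDivisionError:
    print("raw density ratio is 0/0")
```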

Another trick worth noting applies to the simple M-H algorithm described, where a single update is used for the entire state space (rather than multiple component-wise updates, for example) and the same M-H kernel is applied repeatedly to generate successive states of a Markov chain. In that case the \log \pi(\theta) term (which in the context of Bayesian inference will typically be the log posterior) will already have been computed at the previous update (irrespective of whether or not the previous move was accepted). So if we are careful about how we pass on that old value to the next iteration, we can avoid recomputing the log posterior, and our algorithm will only require one log posterior evaluation per iteration rather than two. In functional programming languages it is often convenient to pass around this current log posterior density evaluation explicitly, effectively augmenting the state space of the Markov chain to include the log posterior density.
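As a rough illustration of this bookkeeping, the following plain Python sketch carries the pair (x, ll) through a Metropolis kernel and counts log posterior evaluations. It assumes a toy standard normal target and a symmetric random walk proposal; the counter dictionary and function names are invented for this illustration.

```python
import math
import random

def mh_kernel(lpost, rprop, rng):
    """Metropolis kernel on the augmented state (x, ll), where ll = lpost(x)."""
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)  # the only log posterior evaluation per update
        # 1 - random() lies in (0, 1], so the log is always defined
        if math.log(1.0 - rng.random()) < lp - ll:
            return prop, lp
        return x, ll
    return kernel

counter = {"calls": 0}

def lpost(x):
    """Toy standard normal log density that counts its own evaluations."""
    counter["calls"] += 1
    return -0.5 * x * x

rng = random.Random(42)
kernel = mh_kernel(lpost, lambda x: x + rng.gauss(0.0, 1.0), rng)

x, ll = 0.0, -math.inf  # ll = -inf forces acceptance of the first proposal
iters = 1000
for _ in range(iters):
    x, ll = kernel(x, ll)

print(counter["calls"], iters)
```

Starting the chain with a log density of minus infinity means that the initial state never needs to be evaluated, at the cost of always accepting the first proposal.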

HoF for a M-H kernel

Since I’m a fan of functional programming, we will adopt a functional style throughout, and start by creating a higher-order function (HoF) that accepts a log-posterior and proposal kernel as input and returns a Metropolis kernel as output.

R

In R we can write a function to create a M-H kernel as follows.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Note that the kernel returned requires as input both a current state x and its associated log-posterior, ll. The new state and its log-posterior density are returned together in a list.

We need to use this transition kernel to simulate a Markov chain by successive substitution of newly simulated values back into the kernel. In more sophisticated programming languages we will use streams for this, but in R we can just use a for loop to sample values and write the states into the rows of a matrix.

mcmc = function(init, kernel, iters = 10000, thin = 10, verb = TRUE) {
    p = length(init)
    ll = -Inf
    mat = matrix(0, nrow = iters, ncol = p)
    colnames(mat) = names(init)
    x = init
    if (verb) 
        message(paste(iters, "iterations"))
    for (i in 1:iters) {
        if (verb) 
            message(paste(i, ""), appendLF = FALSE)
        for (j in 1:thin) {
            pair = kernel(x, ll)
            x = pair$x
            ll = pair$ll
            }
        mat[i, ] = x
        }
    if (verb) 
        message("Done.")
    mat
}

Then, in the context of our running logistic regression example, and using the log-posterior from the previous post, we can construct our kernel and run it as follows.

pre = c(10.0,1,1,1,1,1,5,1)
out = mcmc(init, mhKernel(lpost,
          function(x) x + pre*rnorm(p, 0, 0.02)), thin=1000)

Note the use of a symmetric proposal, so the proposal density is not required. Also note the use of a larger proposal variance for the intercept term and the second last covariate. See the full runnable script for further details.

Python

We can do something very similar to R in Python using NumPy. Our HoF for constructing a M-H kernel is

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

Our Markov chain runner function is

def mcmc(init, kernel, thin = 10, iters = 10000, verb = True):
    p = len(init)
    ll = -np.inf
    mat = np.zeros((iters, p))
    x = init
    if (verb):
        print(str(iters) + " iterations")
    for i in range(iters):
        if (verb):
            print(str(i), end=" ", flush=True)
        for j in range(thin):
            x, ll = kernel(x, ll)
        mat[i,:] = x
    if (verb):
        print("\nDone.", flush=True)
    return mat

We can use this code in the context of our logistic regression example as follows.

pre = np.array([10.,1.,1.,1.,1.,1.,5.,1.])

def rprop(beta):
    return beta + 0.02*pre*np.random.randn(p)

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.

JAX

The above R and Python scripts are fine, but both languages are rather slow for this kind of workload. Fortunately it’s rather straightforward to convert the Python code to JAX to obtain a quite amazing speed-up. We can write our M-H kernel as

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

and our MCMC runner function as

def mcmc(init, kernel, thin = 10, iters = 10000):
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, iters)
    @jit
    def step(s, k):
        [x, ll] = s
        x, ll = kernel(k, x, ll)
        s = [x, ll]
        return s, s
    @jit
    def iter(s, k):
        keys = jax.random.split(k, thin)
        _, states = jax.lax.scan(step, s, keys)
        final = [states[0][thin-1], states[1][thin-1]]
        return final, final
    ll = -np.inf
    x = init
    _, states = jax.lax.scan(iter, [x, ll], keys)
    return states[0]

There are really only two slightly tricky things about this code.

The first relates to the way JAX handles pseudo-random numbers. Since JAX is a pure functional eDSL, it can’t be used in conjunction with the typical pseudo-random number generators often used in imperative programming languages which rely on a global mutable state. This can be dealt with reasonably straightforwardly by explicitly passing around the random number state. There is a standard way of doing this that has been common practice in functional programming languages for decades. However, this standard approach is very sequential, and so doesn’t work so well in a parallel context. JAX therefore uses a splittable random number generator, where new states are created by splitting the current state into two (or more). We’ll come back to this when we get to the Haskell examples.
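To make the idea of splitting concrete, here is a toy sketch in plain Python in which keys are byte strings and child keys are derived by hashing. This is purely illustrative and is not how JAX's generator works internally; the functions split and uniform below are invented for this example and only mimic the shape of the interface.

```python
import hashlib

def split(key, n=2):
    """Derive n child keys from a parent key by hashing (toy construction)."""
    return [hashlib.sha256(key + bytes([i])).digest() for i in range(n)]

def uniform(key):
    """Map a key deterministically to a float in [0, 1)."""
    digest = hashlib.sha256(key + b"u").digest()
    return (int.from_bytes(digest[:8], "big") >> 11) / 2.0**53

root = b"seed-42"
k0, k1 = split(root)
print(uniform(k0), uniform(k1))    # draws from two different keys
print(uniform(k0) == uniform(k0))  # same key, same draw: no hidden state
```

Since a draw is a pure function of its key, reusing a key silently reuses the random number, which is why each key should be either split or consumed, but not both.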

The second thing that might be unfamiliar to imperative programmers is the use of the scan operation (jax.lax.scan) to generate the Markov chain rather than a "for" loop. But scans are standard operations in most functional programming languages.

We can then call this code for our logistic regression example with

pre = jnp.array([10.,1.,1.,1.,1.,1.,5.,1.]).astype(jnp.float32)

@jit
def rprop(key, beta):
    return beta + 0.02*pre*jax.random.normal(key, [p])

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.
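For readers unfamiliar with scans, the behaviour of the scan used in the runner above can be mimicked in a few lines of plain Python. This is only a sketch of the semantics, a function threading a carried state through a sequence of inputs and collecting an output at each step, with a deterministic running total standing in for the MCMC kernel.

```python
def scan(f, init, xs):
    """Thread a carry through xs, collecting one output per step."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def step(total, x):
    """A deterministic stand-in for a kernel: the carry is a running total."""
    total = total + x
    return total, total

final, states = scan(step, 0, [1, 2, 3, 4])
print(final, states)  # 10 [1, 3, 6, 10]
```

In the MCMC runner the carry is the pair of state and log posterior, and the inputs being scanned over are the random number keys.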

Scala

In Scala we can use a similar approach to that already seen for defining a HoF to return a M-H kernel.

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

Note that Scala’s static typing does not prevent us from defining a function that is polymorphic in the type of the chain state, which we here call S. Also note that we are adopting a pragmatic approach to random number generation, exploiting the fact that Scala is not a pure functional language, using a mutable generator, and omitting to capture the non-determinism of the rprop function (and the returned kernel) in its type signature. In Scala this is a choice, and we could adopt a purer approach if preferred. We’ll see what such an approach will look like in Haskell, coming up next.

Now that we have the kernel, we don’t need to write an explicit runner function since Scala has good support for streaming data. There are many more-or-less sophisticated ways that we can work with data streams in Scala, and the choice depends partly on how pure one is being about tracking effects (such as non-determinism), but here I’ll just use the simple LazyList from the standard library for unfolding the kernel into an infinite MCMC chain before thinning and truncating appropriately.

  val pre = DenseVector(10.0,1.0,1.0,1.0,1.0,1.0,5.0,1.0)
  def rprop(beta: DVD): DVD = beta + pre *:* (DenseVector(Gaussian(0.0,0.02).sample(p).toArray))
  val kern = mhKernel(lpost, rprop)
  val s = LazyList.iterate((init, -Inf))(kern) map (_._1)
  val out = s.drop(150).thin(1000).take(10000)

See the full runnable script for further details.
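The same unfold, drop, thin and take pattern can be sketched with Python generators. The code below is an analogue rather than a translation: it assumes that thinning keeps the first element and then every thin-th element after it, and it uses a deterministic kernel so that the indices retained are easy to see. The names iterate and run_chain are invented for this illustration.

```python
from itertools import islice

def iterate(f, x0):
    """Infinite lazy stream x0, f(x0), f(f(x0)), ..."""
    x = x0
    while True:
        yield x
        x = f(x)

def run_chain(kernel, init, burn, thin, n):
    """Drop `burn` states, then keep every `thin`-th state until n are kept."""
    return list(islice(iterate(kernel, init), burn, burn + thin * n, thin))

# With the kernel x -> x + 1 started at 0, each state equals its own index
out = run_chain(lambda x: x + 1, 0, burn=150, thin=1000, n=3)
print(out)  # [150, 1150, 2150]
```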

Haskell

Since Haskell is a pure functional language, we need to have some convention regarding pseudo-random number generation. Haskell supports several styles. The most commonly adopted approach wraps a mutable generator up in a monad. The typical alternative is to use a pure functional generator and either explicitly thread the state through code or hide this in a monad similar to the standard approach. However, Haskell also supports the use of splittable generators, so we can consider all three approaches for comparative purposes. The approach taken does affect the code and the type signatures, and even the streaming data abstractions most appropriate for chain generation.

Starting with a HoF for producing a Metropolis kernel, an approach using the standard monadic generators could look like

mKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) -> 
           g -> (s, Double) -> m (s, Double)
mKernel logPost rprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

Note how non-determinism is captured in the type signatures by the monad m. The explicit pure approach is to thread the generator through non-deterministic functions.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> (s, g)) -> 
            g -> (s, Double) -> ((s, Double), g)
mKernelP logPost rprop g (x0, ll0) = let
  (x, g1) = rprop x0 g
  ll = logPost(x)
  a = ll - ll0
  (u, g2) = uniformR (0, 1) g1
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in (next, g2)

Here the updated random number generator state is returned from each non-deterministic function for passing on to subsequent non-deterministic functions. This explicit sequencing of operations makes it possible to wrap the generator state in a state monad giving code very similar to the stateful monadic generator approach, but as already discussed, the sequential nature of this approach makes it unattractive in parallel and concurrent settings.

Fortunately the standard Haskell pure generator is splittable, meaning that we can adopt a splitting approach similar to JAX if we prefer, since this is much more parallel-friendly.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> s) -> 
            g -> (s, Double) -> (s, Double)
mKernelP logPost rprop g (x0, ll0) = let
  (g1, g2) = split g
  x = rprop x0 g1
  ll = logPost(x)
  a = ll - ll0
  u = unif g2
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in next

Here non-determinism is signalled by passing a generator state (often called a "key" in the context of splittable generators) into a function. Functions receiving a key are responsible for splitting it to ensure that no key is ever used more than once.

Once we have a kernel, we need to unfold our Markov chain. When using the monadic generator approach, it is most natural to unfold using a monadic stream

mcmc :: (StatefulGen g m) =>
  Int -> Int -> s -> (g -> s -> m s) -> g -> MS.Stream m s
mcmc it th x0 kern g = MS.iterateNM it (stepN th (kern g)) x0

stepN :: (Monad m) => Int -> (a -> m a) -> (a -> m a)
stepN n fa = if (n == 1)
  then fa
  else (\x -> (fa x) >>= (stepN (n-1) fa))

whereas for the explicit approaches it is more natural to unfold into a regular infinite data stream. So, for the explicit sequential approach we could use

mcmcP :: (RandomGen g) => s -> (g -> s -> (s, g)) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = kern (snd xg) (fst xg)
      in (x1, (x1, g1))

and with the splittable approach we could use

mcmcP :: (RandomGen g) =>
  s -> (g -> s -> s) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = xg
      -- split first, so the key given to the kernel is never reused
      (gk, g2) = split g1
      x2 = kern gk x1
      in (x2, (x2, g2))

Calling these functions for our logistic regression example is similar to what we have seen before, but again there are minor syntactic differences depending on the approach. For further details see the full runnable scripts for the monadic approach, the pure sequential approach, and the splittable approach.

Dex

Dex is a pure functional language and uses a splittable random number generator, so the style we use is similar to JAX (or Haskell using a splittable generator). We can generate a Metropolis kernel with

def mKernel {s} (lpost: s -> Float) (rprop: Key -> s -> s) : 
    Key -> (s & Float) -> (s & Float) =
  def kern (k: Key) (sll: (s & Float)) : (s & Float) =
    (x0, ll0) = sll
    [k1, k2] = split_key k
    x = rprop k1 x0
    ll = lpost x
    a = ll - ll0
    u = rand k2
    select (log u < a) (x, ll) (x0, ll0)
  kern

We can then unfold our Markov chain with

def markov_chain {s} (k: Key) (init: s) (kern: Key -> s -> s) (its: Nat) :
    Fin its => s =
  with_state init \st.
    for i:(Fin its).
      x = kern (ixkey k i) (get st)
      st := x
      x

Here we combine Dex’s state effect with a for loop to unfold the stream. See the full runnable script for further details.

Next steps

As previously discussed, none of these codes are optimised, so care should be taken not to over-interpret running times. However, JAX and Dex are noticeably faster than the alternatives, even running on a single CPU core. Another interesting feature of both JAX and Dex is that they are differentiable. This makes it very easy to develop algorithms using gradient information. In subsequent posts we will think about the gradient of our example log-posterior and how we can use gradient information to develop "better" sampling algorithms.

The complete runnable scripts are all available from this public github repo.

The smfsb R package

Introduction

In the previous post I gave a brief introduction to the third edition of my textbook, Stochastic modelling for systems biology. The algorithms described in the book are illustrated by implementations in R. These implementations are collected together in an R package on CRAN called smfsb. This post will provide a brief introduction to the package and its capabilities.

Installation

The package is on CRAN – see the CRAN package page for details. So the simplest way to install it is to enter

install.packages("smfsb")

at the R command prompt. This will install the latest version that is on CRAN. Once installed, the package can be loaded with

library(smfsb)

The package is well-documented, so further information can be obtained with the usual R mechanisms, such as

vignette(package="smfsb")
vignette("smfsb")
help(package="smfsb")
?StepGillespie
example(StepCLE1D)

The version of the package on CRAN is almost certainly what you want. However, the package is developed on R-Forge – see the R-Forge project page for details. So the very latest version of the package can always be installed with

install.packages("smfsb", repos="http://R-Forge.R-project.org")

if you have a reason for wanting it.

A brief tutorial

The vignette gives a quick introduction to the library, which I don’t need to repeat verbatim here. If you are new to the package, I recommend working through that before continuing. Here I’ll concentrate on some of the new features associated with the third edition.

Simulating stochastic kinetic models

Much of the book is concerned with the simulation of stochastic kinetic models using exact and approximate algorithms. Although the primary focus of the text is the application to modelling of intra-cellular processes, the methods are also appropriate for population modelling of ecological and epidemic processes. For example, we can start by simulating a simple susceptible-infectious-recovered (SIR) disease epidemic model.

set.seed(2)
data(spnModels)

stepSIR = StepGillespie(SIR)
plot(simTs(SIR$M, 0, 8, 0.05, stepSIR),
  main="Exact simulation of the SIR model")

Exact simulation of the SIR epidemic model

The focus of the text is stochastic simulation of discrete models, so that is the obvious place to start. But there is also support for continuous deterministic simulation.

plot(simTs(SIR$M, 0, 8, 0.05, StepEulerSPN(SIR)),
  main="Euler simulation of the SIR model")

Euler simulation of the SIR model

My favourite toy population dynamics model is the Lotka-Volterra (LV) model, so I tend to use this frequently as a running example throughout the book. We can simulate this (exactly) as follows.

stepLV = StepGillespie(LV)
plot(simTs(LV$M, 0, 30, 0.2, stepLV),
  main="Exact simulation of the LV model")

Exact simulation of the Lotka-Volterra model

Stochastic reaction-diffusion modelling

The first two editions of the book were almost exclusively concerned with well-mixed systems, where spatial effects are ignorable. One of the main new features of the third edition is the inclusion of a new chapter on spatially extended systems. The focus is on models related to the reaction diffusion master equation (RDME) formulation, rather than individual particle-based simulations. For these models, space is typically divided into a regular grid of voxels, with reactions taking place as normal within each voxel, and additional reaction events included, corresponding to the diffusion of particles to adjacent voxels. So to specify such models, we just need an initial condition, a reaction model, and diffusion coefficients (one for each reacting species). So, we can carry out exact simulation of an RDME model for a 1D spatial domain as follows.

N=20; T=30
x0=matrix(0, nrow=2, ncol=N)
rownames(x0) = c("x1", "x2")
x0[,round(N/2)] = LV$M
stepLV1D = StepGillespie1D(LV, c(0.6, 0.6))
xx = simTs1D(x0, 0, T, 0.2, stepLV1D, verb=TRUE)
image(xx[1,,], main="Prey", xlab="Space", ylab="Time")

Discrete 1D simulation of the LV model

image(xx[2,,], main="Predator", xlab="Space", ylab="Time")

Discrete 1D simulation of the LV model
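To illustrate how diffusion enters as additional "reaction" events, here is a minimal sketch in plain Python of exact simulation of pure diffusion (no reactions) on a 1D grid of voxels. It makes several assumptions that may not match the package: periodic boundaries, a single species, and a convention in which each particle jumps to each neighbouring voxel at rate d. The function diffuse_1d is invented for this illustration.

```python
import random

def diffuse_1d(x, d, t_end, rng):
    """Gillespie simulation of pure diffusion on a 1D grid of voxels.

    x: particle counts per voxel; d: jump rate per particle per direction.
    Periodic boundaries are assumed. Returns the counts at time t_end.
    """
    x = list(x)  # work on a copy
    n = len(x)
    t = 0.0
    total = sum(x)           # diffusion conserves the number of particles
    rate = 2.0 * d * total   # each particle can jump left or right
    if rate == 0.0:
        return x
    while True:
        t += rng.expovariate(rate)  # time to the next jump event
        if t > t_end:
            return x
        # choose a particle uniformly: voxel i has probability x[i] / total
        i = rng.choices(range(n), weights=x)[0]
        j = (i + rng.choice((-1, 1))) % n
        x[i] -= 1
        x[j] += 1

x0 = [0] * 20
x0[10] = 100
out = diffuse_1d(x0, 0.6, 5.0, random.Random(1))
print(out, sum(out))
```

A full RDME simulator would add the within-voxel reaction hazards to the total event rate, and choose between reaction and diffusion events in proportion to their hazards.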
Exact simulation of discrete stochastic reaction diffusion systems is very expensive (and the reference implementation provided in the package is very inefficient), so we will often use diffusion approximations based on the CLE.

stepLV1DC = StepCLE1D(LV, c(0.6, 0.6))
xx = simTs1D(x0, 0, T, 0.2, stepLV1DC)
image(xx[1,,], main="Prey", xlab="Space", ylab="Time")

Spatial CLE simulation of the 1D LV model

image(xx[2,,], main="Predator", xlab="Space", ylab="Time")

Spatial CLE simulation of the 1D LV model

We can think of this algorithm as an explicit numerical integration of the obvious SPDE approximation to the exact model.

The package also includes support for simulation of 2D systems. Again, we can use the Spatial CLE to speed things up.

m=70; n=50; T=10
data(spnModels)
x0=array(0, c(2,m,n))
dimnames(x0)[[1]]=c("x1", "x2")
x0[,round(m/2),round(n/2)] = LV$M
stepLV2D = StepCLE2D(LV, c(0.6,0.6), dt=0.05)
xx = simTs2D(x0, 0, T, 0.5, stepLV2D)
N = dim(xx)[4]
image(xx[1,,,N],main="Prey",xlab="x",ylab="y")

Spatial CLE simulation of the 2D LV model

image(xx[2,,,N],main="Predator",xlab="x",ylab="y")

Spatial CLE simulation of the 2D LV model

Bayesian parameter inference

Although much of the book is concerned with the problem of forward simulation, the final chapters are concerned with the inverse problem of estimating model parameters, such as reaction rate constants, from data. A computational Bayesian approach is adopted, with the main emphasis being placed on “likelihood free” methods, which rely on forward simulation to avoid explicit computation of sample path likelihoods. The second edition included some rudimentary code for a likelihood free particle marginal Metropolis-Hastings (PMMH) particle Markov chain Monte Carlo (pMCMC) algorithm. The third edition includes a more complete and improved implementation, in addition to approximate inference algorithms based on approximate Bayesian computation (ABC).

The key function underpinning the PMMH approach is pfMLLik, which computes an estimate of marginal model log-likelihood using a (bootstrap) particle filter. There is a new implementation of this function with the third edition. There is also a generic implementation of the Metropolis-Hastings algorithm, metropolisHastings, which can be combined with pfMLLik to create a PMMH algorithm. PMMH algorithms are very slow, but a full demo of how to use these functions for parameter inference is included in the package and can be run with

demo(PMCMC)
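To give a feel for what a bootstrap particle filter estimate of the marginal log-likelihood involves, here is a minimal sketch in plain Python. It is not the pfMLLik implementation: it assumes a toy linear Gaussian state space model with a single parameter phi, a fixed number of particles, and multinomial resampling at every observation. All names are invented for this illustration.

```python
import math
import random

def log_norm_pdf(x, mean, sd):
    """Log density of N(mean, sd^2) evaluated at x."""
    z = (x - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def pf_loglik(ys, phi, n=200, rng=None):
    """Bootstrap particle filter estimate of log p(ys | phi) for the toy model
    x_t = phi * x_{t-1} + N(0, 1),  y_t = x_t + N(0, 1),  x_0 ~ N(0, 1)."""
    if rng is None:
        rng = random.Random(0)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ll = 0.0
    for y in ys:
        # propagate each particle by forward simulation from the model
        xs = [phi * x + rng.gauss(0.0, 1.0) for x in xs]
        # weight by the observation density, via log-sum-exp for stability
        lw = [log_norm_pdf(y, x, 1.0) for x in xs]
        m = max(lw)
        w = [math.exp(l - m) for l in lw]
        ll += m + math.log(sum(w) / n)
        # multinomial resampling in proportion to the weights
        xs = rng.choices(xs, weights=w, k=n)
    return ll

# Simulate some data with phi = 0.8, then compare two parameter values
rng = random.Random(1)
x, ys = 0.0, []
for _ in range(50):
    x = 0.8 * x + rng.gauss(0.0, 1.0)
    ys.append(x + rng.gauss(0.0, 1.0))
print(pf_loglik(ys, 0.8), pf_loglik(ys, -0.8))
```

Only forward simulation from the state model is needed, which is what makes the approach "likelihood free". Plugging a noisy estimate like this into a Metropolis-Hastings acceptance ratio is the essence of PMMH.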

Simple rejection-based ABC methods are facilitated by the (very simple) function abcRun, which just samples from a prior and then carries out independent simulations in parallel before computing summary statistics. A simple illustration of the use of the function is given below.

data(LVdata)
rprior <- function() { exp(c(runif(1, -3, 3),runif(1,-8,-2),runif(1,-4,2))) }
rmodel <- function(th) { simTs(c(50,100), 0, 30, 2, stepLVc, th) }
sumStats <- identity
ssd = sumStats(LVperfect)
distance <- function(s) {
    diff = s - ssd
    sqrt(sum(diff*diff))
}
rdist <- function(th) { distance(sumStats(rmodel(th))) }
out = abcRun(10000, rprior, rdist)
q=quantile(out$dist, c(0.01, 0.05, 0.1))
print(q)
##       1%       5%      10% 
## 772.5546 845.8879 881.0573
accepted = out$param[out$dist < q[1],]
print(summary(accepted))
##        V1                V2                  V3         
##  Min.   :0.06498   Min.   :0.0004467   Min.   :0.01887  
##  1st Qu.:0.16159   1st Qu.:0.0012598   1st Qu.:0.04122  
##  Median :0.35750   Median :0.0023488   Median :0.14664  
##  Mean   :0.68565   Mean   :0.0046887   Mean   :0.36726  
##  3rd Qu.:0.86708   3rd Qu.:0.0057264   3rd Qu.:0.36870  
##  Max.   :4.76773   Max.   :0.0309364   Max.   :3.79220
print(summary(log(accepted)))
##        V1                V2               V3         
##  Min.   :-2.7337   Min.   :-7.714   Min.   :-3.9702  
##  1st Qu.:-1.8228   1st Qu.:-6.677   1st Qu.:-3.1888  
##  Median :-1.0286   Median :-6.054   Median :-1.9198  
##  Mean   :-0.8906   Mean   :-5.877   Mean   :-1.9649  
##  3rd Qu.:-0.1430   3rd Qu.:-5.163   3rd Qu.:-0.9978  
##  Max.   : 1.5619   Max.   :-3.476   Max.   : 1.3329
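The structure of the rejection algorithm is perhaps easier to see on a problem small enough to check by eye. The following plain Python sketch mirrors the shape of the R code above (sample from the prior, simulate, compute distances, keep the closest 1%), but for a made-up toy problem: inferring the mean of a normal distribution from the sample mean of 10 observations, with a hypothetical observed summary statistic of 2.0. The function abc_run here is a serial stand-in written for this illustration, not the package function.

```python
import random

def abc_run(n, rprior, rdist):
    """Draw n parameters from the prior and compute a distance for each."""
    params = [rprior() for _ in range(n)]
    dists = [rdist(th) for th in params]
    return params, dists

rng = random.Random(3)
s_obs = 2.0  # hypothetical observed summary statistic

def rprior():
    """Uniform prior on the unknown mean."""
    return rng.uniform(-5.0, 5.0)

def rmodel(th):
    """Simulated summary statistic: mean of 10 draws from N(th, 1)."""
    return sum(rng.gauss(th, 1.0) for _ in range(10)) / 10.0

def rdist(th):
    """Distance between simulated and observed summary statistics."""
    return abs(rmodel(th) - s_obs)

params, dists = abc_run(5000, rprior, rdist)
cutoff = sorted(dists)[len(dists) // 100]  # roughly the 1% distance quantile
accepted = [th for th, d in zip(params, dists) if d < cutoff]
post_mean = sum(accepted) / len(accepted)
print(len(accepted), post_mean)
```

The accepted parameters should be concentrated around 2, with a spread reflecting both the sampling variability of the summary statistic and the tolerance implied by the cutoff.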

Naive rejection-based ABC algorithms are notoriously inefficient, so the library also includes an implementation of a more efficient, sequential version of ABC, often known as ABC-SMC, in the function abcSmc. This function requires specification of a perturbation kernel to “noise up” the particles at each algorithm sweep. Again, the implementation is parallel, using the parallel package to run the required simulations in parallel on multiple cores. A simple illustration of use is given below.

rprior <- function() { c(runif(1, -3, 3), runif(1, -8, -2), runif(1, -4, 2)) }
dprior <- function(x, ...) { dunif(x[1], -3, 3, ...) + 
    dunif(x[2], -8, -2, ...) + dunif(x[3], -4, 2, ...) }
rmodel <- function(th) { simTs(c(50,100), 0, 30, 2, stepLVc, exp(th)) }
rperturb <- function(th){th + rnorm(3, 0, 0.5)}
dperturb <- function(thNew, thOld, ...){sum(dnorm(thNew, thOld, 0.5, ...))}
sumStats <- identity
ssd = sumStats(LVperfect)
distance <- function(s) {
    diff = s - ssd
    sqrt(sum(diff*diff))
}
rdist <- function(th) { distance(sumStats(rmodel(th))) }
out = abcSmc(5000, rprior, dprior, rdist, rperturb,
    dperturb, verb=TRUE, steps=6, factor=5)
## 6 5 4 3 2 1
print(summary(out))
##        V1                V2               V3        
##  Min.   :-2.9961   Min.   :-7.988   Min.   :-3.999  
##  1st Qu.:-1.9001   1st Qu.:-6.786   1st Qu.:-3.428  
##  Median :-1.2571   Median :-6.167   Median :-2.433  
##  Mean   :-1.0789   Mean   :-6.014   Mean   :-2.196  
##  3rd Qu.:-0.2682   3rd Qu.:-5.261   3rd Qu.:-1.161  
##  Max.   : 2.1128   Max.   :-2.925   Max.   : 1.706

We can then plot some results with

hist(out[,1],30,main="log(c1)")

ABC-SMC posterior for the LV model

hist(out[,2],30,main="log(c2)")

ABC-SMC posterior for the LV model

hist(out[,3],30,main="log(c3)")

ABC-SMC posterior for the LV model

Although the inference methods are illustrated in the book in the context of parameter inference for stochastic kinetic models, their implementation is generic, and can be used with any appropriate parameter inference problem.

The smfsbSBML package

smfsbSBML is another R package associated with the third edition of the book. This package is not on CRAN due to its dependency on a package not on CRAN, and hence is slightly less straightforward to install. Follow the available installation instructions to install the package. Once installed, you should be able to load the package with

library(smfsbSBML)

This package provides a function for reading in SBML files and parsing them into the simulatable stochastic Petri net (SPN) objects used by the main smfsb R package. Examples of suitable SBML models are included in the main smfsb GitHub repo. An appropriate SBML model can be read and parsed with a command like:

model = sbml2spn("mySbmlModel.xml")

The resulting value, model, is an SPN object which can be passed in to simulation functions such as StepGillespie for constructing stochastic simulation algorithms.

Other software

In addition to the above R packages, I also have some Python scripts for converting between SBML and the SBML-shorthand notation I use in the book. See the SBML-shorthand page for further details.

Although R is a convenient language for teaching and learning about stochastic simulation, it isn’t ideal for serious research-level scientific computing or computational statistics. So for the third edition of the book I have also developed scala-smfsb, a library written in the Scala programming language, which re-implements all of the models and algorithms from the third edition of the book in Scala, a fast, efficient, strongly-typed, compiled, functional programming language. I’ll give an introduction to this library in a subsequent post, but in the meantime, it is already well documented, so see the scala-smfsb repo for further details, including information on installation, getting started, a tutorial, examples, API docs, etc.

Source

This blog post started out as an RMarkdown document, the source of which can be found here.

Data frames and tables in Scala

Introduction

To statisticians and data scientists used to working in R, the concept of a data frame is one of the most natural and basic starting points for statistical computing and data analysis. It always surprises me that data frames aren’t a core concept in most programming languages’ standard libraries, since they are essentially a representation of a relational database table, and relational databases are pretty ubiquitous in data processing and related computing. For statistical modelling and data science, having functions designed for data frames is much more elegant than using functions designed to work directly on vectors and matrices, for example. Trivial things like being able to refer to columns by a readable name rather than a numeric index make a huge difference, before we even get into issues like columns of heterogeneous types, coherent handling of missing data, etc. This is why modelling in R is typically nicer than in certain other languages I could mention, where libraries for scientific and numerical computing existed for a long time before libraries for data frames were added to the language ecosystem.

To build good libraries for statistical computing in Scala, it will be helpful to build those libraries using a good data frame implementation. With that in mind I’ve started to look for existing Scala data frame libraries and to compare them.

A simple data manipulation task

For this post I’m going to consider a very simple data manipulation task: first reading in a CSV file from disk into a data frame object, then filtering out some rows, then adding a derived column, then finally writing the data frame back to disk as a CSV file. We will start by looking at how this would be done in R. First we need an example CSV file. Since many R packages contain example datasets, we will use one of those. We will export Cars93 from the MASS package:

library(MASS)
write.csv(Cars93,"cars93.csv",row.names=FALSE)

If MASS isn’t installed, it can be installed with a simple install.packages("MASS"). The above code snippet generates a CSV file to be used for the example. Typing ?Cars93 will give some information about the dataset, including the original source.

Our analysis task is going to be to load the file from disk, filter out cars with EngineSize larger than 4 (litres), add a new column to the data frame, WeightKG, containing the weight of the car in KG, derived from the column Weight (in pounds), and then write back to disk in CSV format. This is the kind of thing that R excels at (pun intended):

df=read.csv("cars93.csv")
print(dim(df))
df = df[df$EngineSize<=4.0,]
print(dim(df))
df$WeightKG = df$Weight*0.453592
print(dim(df))
write.csv(df,"cars93m.csv",row.names=FALSE)
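
For contrast, here is a minimal sketch of the same filter-then-derive steps using nothing but plain Scala collections. The case classes, the object name and the three rows are invented for illustration (they are not taken from Cars93), and CSV reading and writing is left out entirely.

```scala
object PlainCollections {
  // Hypothetical row types: only three of the Cars93 columns are modelled
  case class Car(model: String, engineSize: Double, weight: Int)
  case class CarKg(model: String, engineSize: Double, weight: Int, weightKg: Double)

  // Made-up rows, purely for illustration
  val cars = List(Car("A", 1.8, 2500), Car("B", 4.6, 4000), Car("C", 3.0, 3500))

  // Keep rows with engineSize <= 4.0, then derive weightKg from weight (pounds)
  val result: List[CarKg] =
    cars.filter(_.engineSize <= 4.0)
        .map(c => CarKg(c.model, c.engineSize, c.weight, c.weight * 0.453592))

  def main(args: Array[String]): Unit = result.foreach(println)
}
```

This is type-safe, but adding a column means declaring a new row type by hand, which is exactly the kind of friction a data frame library is supposed to remove.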

Now let’s see how a similar task could be accomplished using Scala data frames.

Data frames and tables in Scala

Saddle

Saddle is probably the best known data frame library for Scala. It is strongly influenced by the pandas library for Python. A simple Saddle session for accomplishing this task might proceed as follows:

val file = CsvFile("cars93.csv")
val df = CsvParser.parse(file).withColIndex(0)
println(df)
val df2 = df.rfilter(_("EngineSize").
             mapValues(CsvParser.parseDouble).at(0)<=4.0)
println(df2)
val wkg=df2.col("Weight").mapValues(CsvParser.parseDouble).
             mapValues(_*0.453592).setColIndex(Index("WeightKG"))
val df3=df2.joinPreserveColIx(wkg.mapValues(_.toString))
println(df3)
df3.writeCsvFile("saddle-out.csv")

Although this looks OK, it’s not completely satisfactory, as the data frame is actually representing a matrix of Strings. Although you can have a data frame containing columns of any type, since Saddle data frames are backed by a matrix object (with type corresponding to the common super-type), the handling of columns of heterogeneous types always seems rather cumbersome. I suspect that it is this clumsy handling of heterogeneously typed columns that has motivated the development of alternative data frame libraries for Scala.

Scala-datatable

Scala-datatable is a lightweight minimal immutable data table for Scala, with good support for columns of differing types. However, it is currently really very minimal, and doesn’t have CSV import or export, for example. That said, there are several CSV libraries for Scala, so it’s quite easy to write functions to import from CSV into a datatable and write CSV back out from one. I’ve a couple of example functions, readCsv() and writeCsv() in the full code examples associated with this post. Now since datatable supports heterogeneous column types and I don’t want to write a type guesser, my readCsv() function expects information regarding the column types. This could be relaxed with a bit of effort. An example session follows:

    val colTypes=Map("DriveTrain" -> StringCol, 
                     "Min.Price" -> Double, 
                     "Cylinders" -> Int, 
                     "Horsepower" -> Int, 
                     "Length" -> Int, 
                     "Make" -> StringCol, 
                     "Passengers" -> Int, 
                     "Width" -> Int, 
                     "Fuel.tank.capacity" -> Double, 
                     "Origin" -> StringCol, 
                     "Wheelbase" -> Int, 
                     "Price" -> Double, 
                     "Luggage.room" -> Double, 
                     "Weight" -> Int, 
                     "Model" -> StringCol, 
                     "Max.Price" -> Double, 
                     "Manufacturer" -> StringCol, 
                     "EngineSize" -> Double, 
                     "AirBags" -> StringCol, 
                     "Man.trans.avail" -> StringCol, 
                     "Rear.seat.room" -> Double, 
                     "RPM" -> Int, 
                     "Turn.circle" -> Double, 
                     "MPG.highway" -> Int, 
                     "MPG.city" -> Int, 
                     "Rev.per.mile" -> Int, 
                     "Type" -> StringCol)
    val df=readCsv("Cars93",new FileReader("cars93.csv"),colTypes)
    println(df.length,df.columns.length)
    val df2=df.filter(row=>row.as[Double]("EngineSize")<=4.0).toDataTable
    println(df2.length,df2.columns.length)

    val oldCol=df2.columns("Weight").as[Int]
    val newCol=new DataColumn[Double]("WeightKG",oldCol.data.map{_.toDouble*0.453592})
    val df3=df2.columns.add(newCol).get
    println(df3.length,df3.columns.length)

    writeCsv(df3,new File("out.csv"))

Apart from the declaration of column types, the code is actually a little bit cleaner than the corresponding Saddle code, and the column types are all properly preserved and appropriately handled. However, a significant limitation of this data frame is that it doesn’t seem to have special handling of missing values, requiring some kind of manually coded “special value” approach from users of this data frame. This is likely to limit the appeal of this library for general statistical and data science applications.
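
To make the missing-value point concrete, here is a small sketch of one way a user could encode missing entries in plain Scala, using Option rather than a magic number. It is not part of scala-datatable; the object name, the helper and the values are all made up for illustration.

```scala
object MissingValues {
  // A column in which None marks a missing entry (values are invented)
  val weightLb: Vector[Option[Double]] = Vector(Some(2500.0), None, Some(3500.0))

  // Deriving a new column: a missing input stays missing in the output
  val weightKg: Vector[Option[Double]] = weightLb.map(_.map(_ * 0.453592))

  // A summary has to decide what to do with missing entries; here they are dropped
  def meanOfPresent(col: Vector[Option[Double]]): Option[Double] = {
    val present = col.flatten
    if (present.isEmpty) None else Some(present.sum / present.length)
  }

  def main(args: Array[String]): Unit = {
    println(weightKg)
    println(meanOfPresent(weightLb))
  }
}
```

A library with built-in missing-data support does this bookkeeping for every column and every operation, which is why its absence matters for statistical work.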

Framian

Framian is a full-featured data frame library for Scala, open-sourced by Pellucid analytics. It is strongly influenced by R data frame libraries, and aims to provide most of the features that R users would expect. It has good support for clean handling of heterogeneously typed columns (using shapeless), handles missing data, and includes good CSV import:

val df=Csv.parseFile(new File("cars93.csv")).labeled.toFrame
println(""+df.rows+" "+df.cols)
val df2=df.filter(Cols("EngineSize").as[Double])( _ <= 4.0 )
println(""+df2.rows+" "+df2.cols)
val df3=df2.map(Cols("Weight").as[Int],"WeightKG")(r=>r.toDouble*0.453592)
println(""+df3.rows+" "+df3.cols)
println(df3.colIndex)
val csv = Csv.fromFrame(new CsvFormat(",", header = true))(df3)
new PrintWriter("out.csv") { write(csv.toString); close }

This is arguably the cleanest solution so far. Unfortunately the output isn’t quite right(!), as there currently seems to be a bug in Csv.fromFrame which causes the ordering of columns to get out of sync with the ordering of the column headers. Presumably this bug will soon be fixed, and if not it is easy to write a CSV writer for these frames, as I did above for scala-datatable.

Spark DataFrames

The three data frames considered so far are all standard single-machine, non-distributed, in-memory objects. The Scala data frame implementation currently subject to the most social media buzz is a different beast entirely. A DataFrame object has recently been added to Apache Spark. I’ve already discussed the problems of first developing a data analysis library without data frames and then attempting to bolt a data frame object on top post-hoc. Spark has repeated this mistake, but it’s still much better to have a data frame in Spark than not. Spark is a Scala framework for the distributed processing and analysis of huge datasets on a cluster. I will discuss it further in future posts. If you have a legitimate need for this kind of set-up, then Spark is a pretty impressive piece of technology (though note that there are competitors, such as Flink). However, for datasets that can be analysed on a single machine, Spark seems like a rather slow and clunky sledgehammer to crack a nut. So, for datasets in the terabyte range and above, Spark DataFrames are great, but for datasets smaller than a few gigabytes, it’s probably not the best solution. With those caveats in mind, here’s how to solve our problem using Spark DataFrames (and the spark-csv library) in the Spark Shell:

val df = sqlContext.read.format("com.databricks.spark.csv").
                         option("header", "true").
                         option("inferSchema","true").
                         load("cars93.csv")
val df2=df.filter("EngineSize <= 4.0")
val col=df2.col("Weight")*0.453592
val df3=df2.withColumn("WeightKG",col)
df3.write.format("com.databricks.spark.csv").
                         option("header","true").
                         save("out-csv")

Summary

If you really need a distributed data frame library, then you will probably want to use Spark. However, for the vast majority of statistical modelling and data science tasks, Spark is likely to be unnecessarily complex and heavyweight. The other three libraries considered all have pros and cons. They are all largely one-person hobby projects, quite immature, and not currently under very active development. Saddle is fine for when you just want to add column headings to a matrix. Scala-datatable is lightweight and immutable, if you don’t care about missing values. On balance, I think Framian is probably the most full-featured “batteries included” R-like data frame, and so is likely to be most attractive to statisticians and data scientists. However, it’s pretty immature, and the dependence on shapeless may be of concern to those who prefer libraries to be lean and devoid of sorcery!

I’d be really interested to know of other people’s experiences of these libraries, so please do comment if you have any views, and especially if you have opinions on the relative merits of the different libraries.

The full source code for all of these examples, including sbt build files, can be found in a new github repo I’ve created for the code examples associated with this blog.

Calling R from Scala sbt projects using rscala

Overview

In the previous post I showed how the rscala package (which has replaced the jvmr package) can be used to call Scala code from within R. In this post I will show how to call R from Scala code. I have previously described how to do this using jvmr. This post is really just an update to show how things work with rscala.

Since I’m focusing here on Scala sbt projects, I’m assuming that sbt is installed, in addition to rscala (described in the previous post). The only “trick” required for calling back to R from Scala is telling sbt where the rscala jar file is located. You can find the location from the R console as illustrated by the following session:

> library(rscala)
> rscala::rscalaJar("2.11")
[1] "/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.2/rscala/java/rscala_2.11-1.0.6.jar"

This location (which will obviously be different for you) can then be added in to your sbt classpath by adding the following line to your build.sbt file:

unmanagedJars in Compile += file("/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.2/rscala/java/rscala_2.11-1.0.6.jar")

Once this is done, calling out to R from your Scala sbt project can be carried out as described in the rscala documentation. For completeness, a working example is given below.

Example

In this example I will use Scala to simulate some data consistent with a Poisson regression model, and then push the data to R to fit it using the R function glm(), and then pull back the fitted regression coefficients into Scala. This is obviously a very artificial example, but the point is to show how it is possible to call back to R for some statistical procedure that may be “missing” from Scala.

The dependencies for this project are described in the file build.sbt

name := "rscala test"

version := "0.1"

scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature")

libraryDependencies  ++= Seq(
            "org.scalanlp" %% "breeze" % "0.10",
            "org.scalanlp" %% "breeze-natives" % "0.10"
)

resolvers ++= Seq(
            "Sonatype Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/",
            "Sonatype Releases" at "https://oss.sonatype.org/content/repositories/releases/"
)

unmanagedJars in Compile += file("/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.2/rscala/java/rscala_2.11-1.0.6.jar")

scalaVersion := "2.11.6"

The complete Scala program is contained in the file PoisReg.scala

import org.ddahl.rscala.callback._
import breeze.stats.distributions._
import breeze.linalg._

object ScalaToRTest {

  def main(args: Array[String]) = {

    // first simulate some data consistent with a Poisson regression model
    val x = Uniform(50,60).sample(1000)
    val eta = x map { xi => (xi * 0.1) - 3 }
    val mu = eta map { math.exp(_) }
    val y = mu map { Poisson(_).draw }
    
    // call to R to fit the Poisson regression model
    val R = RClient() // initialise an R interpreter
    R.x=x.toArray // send x to R
    R.y=y.toArray // send y to R
    R.eval("mod <- glm(y~x,family=poisson())") // fit the model in R
    // pull the fitted coefficients back into scala
    val beta = DenseVector[Double](R.evalD1("mod$coefficients"))

    // print the fitted coefficients
    println(beta)

  }

}

If these two files are put in an empty directory, the code can be compiled and run by typing sbt run from the command prompt in the relevant directory. The commented code should be self-explanatory, but see the rscala documentation for further details. In particular, the rscala scaladoc is useful.
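
For readers curious about what the glm() call is doing on the R side, the fit can be sketched in plain Scala as iteratively reweighted least squares (IRLS) for a single covariate with a log link. This is an illustrative sketch under those assumptions, not the code that R or rscala uses, and the object and method names are invented here.

```scala
object PoissonIrls {
  /** Fit log(mu) = b0 + b1*x by IRLS; returns (b0, b1). */
  def fit(x: Array[Double], y: Array[Double], iters: Int = 25): (Double, Double) = {
    // Weighted sum over observations
    def wsum(f: Int => Double): Double = x.indices.map(f).sum

    // Start the linear predictor from the data (shifted to avoid log(0))
    var eta = y.map(yi => math.log(yi + 0.1))
    var b0 = 0.0
    var b1 = 0.0
    for (_ <- 0 until iters) {
      val mu = eta.map(e => math.exp(e))
      // Working response for the log link; the working weights are mu itself
      val z = Array.tabulate(x.length)(i => eta(i) + (y(i) - mu(i)) / mu(i))
      // Weighted least squares of z on x, in closed form for one covariate
      val sw = mu.sum
      val xbar = wsum(i => mu(i) * x(i)) / sw
      val zbar = wsum(i => mu(i) * z(i)) / sw
      val sxz = wsum(i => mu(i) * (x(i) - xbar) * (z(i) - zbar))
      val sxx = wsum(i => mu(i) * (x(i) - xbar) * (x(i) - xbar))
      b1 = sxz / sxx
      b0 = zbar - b1 * xbar
      eta = x.map(xi => b0 + b1 * xi)
    }
    (b0, b1)
  }

  def main(args: Array[String]): Unit = {
    // Noise-free "data": y equals its mean, so the fit should recover roughly (-3, 0.1)
    val x = Array.tabulate(100)(i => 50.0 + 0.1 * i)
    val y = x.map(xi => math.exp(0.1 * xi - 3.0))
    println(fit(x, y))
  }
}
```

A fixed number of iterations is used for simplicity; a real implementation would monitor the deviance for convergence and also return standard errors, which is part of why calling out to glm() is attractive.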

Calling Scala code from R using rscala

Introduction

In a previous post I looked at how to call Scala code from R using a CRAN package called jvmr. This package now seems to have been replaced by a new package called rscala. Like the old package, it requires a pre-existing Java installation. Unlike the old package, however, it no longer depends on rJava, which may simplify some installations. The rscala package is well documented, with a reference manual and a draft paper. In this post I will concentrate on the issue of calling sbt-based projects with dependencies on external libraries (such as breeze).

On a system with Java installed, it should be possible to install the rscala package with a simple

install.packages("rscala")

from the R command prompt. Calling

library(rscala)

will check that it has worked. The package will do a sensible search for a Scala installation and use it if it can find one. If it can’t find one (or can only find an installation older than 2.10.x), it will fail. In this case you can download and install a Scala installation specifically for rscala using the command

rscala::scalaInstall()

This option is likely to be attractive to sbt (or IDE) users who don’t like to rely on a system-wide scala installation.

A Gibbs sampler in Scala using Breeze

For illustration I’m going to use a Scala implementation of a Gibbs sampler. The Scala code, gibbs.scala is given below:

package gibbs

object Gibbs {

    import scala.annotation.tailrec
    import scala.math.sqrt
    import breeze.stats.distributions.{Gamma,Gaussian}

    case class State(x: Double, y: Double) {
      override def toString: String = x.toString + " , " + y + "\n"
    }

    def nextIter(s: State): State = {
      val newX = Gamma(3.0, 1.0/((s.y)*(s.y)+4.0)).draw
      State(newX, Gaussian(1.0/(newX+1), 1.0/sqrt(2*newX+2)).draw)
    }

    @tailrec def nextThinnedIter(s: State,left: Int): State =
      if (left==0) s else nextThinnedIter(nextIter(s),left-1)

    def genIters(s: State, stop: Int, thin: Int): List[State] = {
      @tailrec def go(s: State, left: Int, acc: List[State]): List[State] =
        if (left>0)
          go(nextThinnedIter(s,thin), left-1, s::acc)
          else acc
      go(s,stop,Nil).reverse
    }

    def main(args: Array[String]) = {
      if (args.length != 3) {
        println("Usage: sbt \"run <outFile> <iters> <thin>\"")
        sys.exit(1)
      } else {
        val outF=args(0)
        val iters=args(1).toInt
        val thin=args(2).toInt
        val out = genIters(State(0.0,0.0),iters,thin)
        val s = new java.io.FileWriter(outF)
        s.write("x , y\n")
        out foreach { it => s.write(it.toString) }
        s.close
      }
    }

}
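
It may help to read the two draws in nextIter as full conditionals. Assuming Breeze's Gamma is parameterised by shape and scale, and Gaussian by mean and standard deviation, the updates correspond to the following, where the second argument of the normal is a variance. The joint density on the last line is my reading of a target consistent with these conditionals, for x > 0.

```latex
x \mid y \sim \mathrm{Gamma}\!\left(3,\ \text{rate} = y^2 + 4\right), \qquad
y \mid x \sim \mathrm{N}\!\left(\frac{1}{x+1},\ \frac{1}{2(x+1)}\right),

f(x,y) \propto x^2 \exp\left\{ -x y^2 - y^2 + 2y - 4x \right\}.
```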

This code requires Scala and the Breeze scientific library in order to build. We can specify this in an sbt build file, which should be called build.sbt and placed in the same directory as the Scala code.

name := "gibbs"

version := "0.1"

scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature")

libraryDependencies  ++= Seq(
            "org.scalanlp" %% "breeze" % "0.10",
            "org.scalanlp" %% "breeze-natives" % "0.10"
)

resolvers ++= Seq(
            "Sonatype Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/",
            "Sonatype Releases" at "https://oss.sonatype.org/content/repositories/releases/"
)

scalaVersion := "2.11.6"

Now, from a system command prompt in the directory where the files are situated, it should be possible to download all dependencies and compile and run the code with a simple

sbt "run output.csv 50000 1000"

sbt magically manages all of the dependencies for us so that we don’t have to worry about them. However, for calling from R, it may be desirable to run the code without running sbt. There are several ways to achieve this, but the simplest is to build an “assembly jar” or “fat jar”, which is a Java byte-code file containing all code and libraries required in order to run the code on any system with a Java installation.

To build an assembly jar, first create a subdirectory called project (the name matters), and in it place two files. The first should be called assembly.sbt, and should contain the line

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0")

Since the version of the assembly tool can depend on the version of sbt, it is also best to fix the version of sbt being used by creating another file in the project directory called build.properties, which should contain the line

sbt.version=0.13.7

Now return to the parent directory and run

sbt assembly

If this works, it should create a fat jar target/scala-2.11/gibbs-assembly-0.1.jar. You can check it works by running

java -jar target/scala-2.11/gibbs-assembly-0.1.jar output.csv 10000 10

Assuming that it does, you are now ready to try running the code from within R.

Calling via R system calls

Since this code takes a relatively long time to run, calling it from R via simple system calls isn’t a particularly terrible idea. For example, we can do this from the R command prompt with the following commands

system("java -jar target/scala-2.11/gibbs-assembly-0.1.jar output.csv 50000 1000")
out=read.csv("output.csv")
library(smfsb)
mcmcSummary(out,rows=2)

This works fine, but is a bit clunky. Tighter integration between R and Scala would be useful, which is where rscala comes in.

Calling assembly Scala projects via rscala

rscala provides a very simple way to embed a Scala interpreter within an R session, to be able to execute Scala expressions from R and to have the results returned back to the R session for further processing. The main issue with using this in practice is managing dependencies on external libraries and setting the Scala classpath correctly. By using an assembly jar we can bypass most of these issues, and it becomes trivial to call our Scala code direct from the R interpreter, as the following code illustrates.

library(rscala)
sc=scalaInterpreter("target/scala-2.11/gibbs-assembly-0.1.jar")
sc%~%'import gibbs.Gibbs._'
out=sc%~%'genIters(State(0.0,0.0),50000,1000).toArray.map{s=>Array(s.x,s.y)}'
library(smfsb)
mcmcSummary(out,rows=2)

Here we call the genIters function directly, rather than via the main method. This function returns an immutable List of States. Since R doesn’t understand this, we map it to an Array of Arrays, which R then unpacks into an R matrix for us to store in the matrix out.
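
To see exactly which states end up in the returned list, the same recursion can be run on a deterministic stand-in for the sampler. In this sketch the state is just an iteration counter and nextIter is a hypothetical replacement for Gibbs.nextIter; the recursion itself mirrors the code above.

```scala
object ThinningDemo {
  import scala.annotation.tailrec

  // Hypothetical stand-in for Gibbs.nextIter: the "state" counts iterations
  def nextIter(s: Int): Int = s + 1

  // Advance the chain by `left` steps without storing intermediate states
  @tailrec def nextThinnedIter(s: Int, left: Int): Int =
    if (left == 0) s else nextThinnedIter(nextIter(s), left - 1)

  // Collect `stop` states, each `thin` steps apart, starting with the initial state
  def genIters(s: Int, stop: Int, thin: Int): List[Int] = {
    @tailrec def go(s: Int, left: Int, acc: List[Int]): List[Int] =
      if (left > 0) go(nextThinnedIter(s, thin), left - 1, s :: acc)
      else acc
    go(s, stop, Nil).reverse
  }

  def main(args: Array[String]): Unit =
    println(genIters(0, 5, 10)) // List(0, 10, 20, 30, 40)
}
```

So the output has exactly `stop` rows, the first of which is the initial state, and consecutive rows are `thin` sampler iterations apart.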

Summary

The CRAN package rscala makes it very easy to embed a Scala interpreter within an R session. However, for most non-trivial statistical computing problems, the Scala code will have dependence on external scientific libraries such as Breeze. The standard way to easily manage external dependencies in the Scala ecosystem is sbt. Given an sbt-based Scala project, it is easy to generate an assembly jar in order to initialise the rscala Scala interpreter with the classpath needed to call arbitrary Scala functions. This provides very convenient inter-operability between R and Scala for many statistical computing applications.

Calling R from Scala sbt projects

[Update: The jvmr package has been replaced by the rscala package. There is a new version of this post which replaces this one.]

Overview

In previous posts I’ve shown how the jvmr CRAN R package can be used to call Scala sbt projects from R and inline Scala Breeze code in R. In this post I will show how to call to R from a Scala sbt project. This requires that R and the jvmr CRAN R package are installed on your system, as described in the previous posts. Since I’m focusing here on Scala sbt projects, I’m also assuming that sbt is installed.

The only “trick” required for calling back to R from Scala is telling sbt where the jvmr jar file is located. You can find the location from the R console as illustrated by the following session:

> library(jvmr)
> .jvmr.jar
[1] "/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.1/jvmr/java/jvmr_2.11-2.11.2.1.jar"

This location (which will obviously be different for you) can then be added in to your sbt classpath by adding the following line to your build.sbt file:

unmanagedJars in Compile += file("/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.1/jvmr/java/jvmr_2.11-2.11.2.1.jar")

Once this is done, calling out to R from your Scala sbt project can be carried out as described in the jvmr documentation. For completeness, a working example is given below.

Example

In this example I will use Scala to simulate some data consistent with a Poisson regression model, and then push the data to R to fit it using the R function glm(), and then pull back the fitted regression coefficients into Scala. This is obviously a very artificial example, but the point is to show how it is possible to call back to R for some statistical procedure that may be “missing” from Scala.

The dependencies for this project are described in the file build.sbt

name := "jvmr test"

version := "0.1"

scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature")

libraryDependencies  ++= Seq(
            "org.scalanlp" %% "breeze" % "0.10",
            "org.scalanlp" %% "breeze-natives" % "0.10"
)

resolvers ++= Seq(
            "Sonatype Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/",
            "Sonatype Releases" at "https://oss.sonatype.org/content/repositories/releases/"
)

unmanagedJars in Compile += file("/home/ndjw1/R/x86_64-pc-linux-gnu-library/3.1/jvmr/java/jvmr_2.11-2.11.2.1.jar")

scalaVersion := "2.11.2"

The complete Scala program is contained in the file PoisReg.scala

import org.ddahl.jvmr.RInScala
import breeze.stats.distributions._
import breeze.linalg._

object ScalaToRTest {

  def main(args: Array[String]) = {

    // first simulate some data consistent with a Poisson regression model
    val x = Uniform(50,60).sample(1000)
    val eta = x map { xi => (xi * 0.1) - 3 }
    val mu = eta map { math.exp(_) }
    val y = mu map { Poisson(_).draw }
    
    // call to R to fit the Poisson regression model
    val R = RInScala() // initialise an R interpreter
    R.x=x.toArray // send x to R
    R.y=y.toArray // send y to R
    R.eval("mod <- glm(y~x,family=poisson())") // fit the model in R
    // pull the fitted coefficients back into scala
    val beta = DenseVector[Double](R.toVector[Double]("mod$coefficients"))

    // print the fitted coefficients
    println(beta)

  }

}

If these two files are put in an empty directory, the code can be compiled and run by typing sbt run from the command prompt in the relevant directory. The commented code should be self-explanatory, but see the jvmr documentation for further details.

Inlining Scala Breeze code in R using jvmr and sbt

[Update: The CRAN package “jvmr” has been replaced by a new package “rscala”. Rather than completely re-write this post, I’ve just created a github gist containing a new function, breezeInterpreter(), which works similarly to the function breezeInit() in this post. Usage information is given at the top of the gist.]

Introduction

In the previous post I showed how to call Scala code from R using sbt and jvmr. The approach described in that post is the one I would recommend for any non-trivial piece of Scala code – mixing up code from different languages in the same source code file is not a good strategy in general. That said, for very small snippets of code, it can sometimes be convenient to inline Scala code directly into an R source code file. The canonical example of this is a computationally intensive algorithm being prototyped in R which has a slow inner loop. If the inner loop can be recoded in a few lines of Scala, it would be nice to just inline this directly into the R code without having to create a separate Scala project. The CRAN package jvmr provides a very simple and straightforward way to do this. However, as discussed in the last post, most Scala code for statistical computing (even short and simple code) is likely to rely on Breeze for special functions, probability distributions, non-uniform random number generation, numerical linear algebra, etc. In this post we will see how to use sbt in order to make sure that the Breeze library and all of its dependencies are downloaded and cached, and to provide a correct classpath with which to initialise a jvmr scalaInterpreter session.

Setting up

Configuring your system to be able to inline Scala Breeze code is very easy. You just need to install Java, R and sbt. Then install the CRAN R package jvmr. At this point you have everything you need except for the R function breezeInit, given at the end of this post. I’ve deferred the function to the end of the post as it is a bit ugly, and the details of it are not important. All it does is get sbt to ensure that Breeze is correctly downloaded and cached and then start a scalaInterpreter with Breeze on the classpath. With this function available, we can use it as the following R session illustrates:

> b=breezeInit()
> b['import breeze.stats.distributions._']
NULL
> b['Poisson(10).sample(20).toArray']
 [1] 13 14 13 10  7  6 15 14  5 10 14 11 15  8 11 12  6  7  5  7
> summary(b['Gamma(3,2).sample(10000).toArray'])
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
 0.2124  3.4630  5.3310  5.9910  7.8390 28.5200 
> 

So we see that Scala Breeze code can be inlined directly into an R session, and if we are careful about return types, have the results of Scala expressions automatically unpack back into convenient R data structures.

Summary

In this post I have shown how easy it is to inline Scala Breeze code into R using sbt in conjunction with the CRAN package jvmr. This has many potential applications, with the most obvious being the desire to recode slow inner loops from R to Scala. This should give performance quite comparable with alternatives such as Rcpp, with the advantage being that you get to write beautiful, elegant, functional Scala code instead of horrible, ugly, imperative C++ code! 😉

The breezeInit function

The actual breezeInit() function is given below. It is a little ugly, but very simple. It is obviously easy to customise for different libraries and library versions as required. All of the hard work is done by sbt which must be installed and on the default system path in order for this function to work.

breezeInit<-function()
{
  library(jvmr)
  sbtStr="name := \"tmp\"

version := \"0.1\"

libraryDependencies  ++= Seq(
            \"org.scalanlp\" %% \"breeze\" % \"0.10\",
            \"org.scalanlp\" %% \"breeze-natives\" % \"0.10\"
)

resolvers ++= Seq(
            \"Sonatype Snapshots\" at \"https://oss.sonatype.org/content/repositories/snapshots/\",
            \"Sonatype Releases\" at \"https://oss.sonatype.org/content/repositories/releases/\"
)

scalaVersion := \"2.11.2\"

lazy val printClasspath = taskKey[Unit](\"Dump classpath\")

printClasspath := {
  (fullClasspath in Runtime value) foreach {
    e => print(e.data+\"!\")
  }
}
"
  tmps=file(file.path(tempdir(),"build.sbt"),"w")
  cat(sbtStr,file=tmps)
  close(tmps)
  owd=getwd()
  setwd(tempdir())
  cpstr=system2("sbt","printClasspath",stdout=TRUE)
  cpst=cpstr[length(cpstr)]
  cpsp=strsplit(cpst,"!")[[1]]
  cp=cpsp[2:(length(cpsp)-1)]
  si=scalaInterpreter(cp,use.jvmr.class.path=FALSE)
  setwd(owd)
  si
}